diff options
author | Kenneth Heafield <github@kheafield.com> | 2012-10-22 12:07:20 +0100 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2012-10-22 12:07:20 +0100 |
commit | 5f98fe5c4f2a2090eeb9d30c030305a70a8347d1 (patch) | |
tree | 9b6002f850e6dea1e3400c6b19bb31a9cdf3067f /rst_parser/arc_factored.h | |
parent | cf9994131993b40be62e90e213b1e11e6b550143 (diff) | |
parent | 21825a09d97c2e0afd20512f306fb25fed55e529 (diff) |
Merge remote branch 'upstream/master'
Conflicts:
Jamroot
bjam
decoder/Jamfile
decoder/cdec.cc
dpmert/Jamfile
jam-files/sanity.jam
klm/lm/Jamfile
klm/util/Jamfile
mira/Jamfile
Diffstat (limited to 'rst_parser/arc_factored.h')
-rw-r--r-- | rst_parser/arc_factored.h | 124 |
1 files changed, 0 insertions, 124 deletions
diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h deleted file mode 100644 index c5481d80..00000000 --- a/rst_parser/arc_factored.h +++ /dev/null @@ -1,124 +0,0 @@ -#ifndef _ARC_FACTORED_H_ -#define _ARC_FACTORED_H_ - -#include <iostream> -#include <cassert> -#include <vector> -#include <utility> -#include <boost/shared_ptr.hpp> -#include "array2d.h" -#include "sparse_vector.h" -#include "prob.h" -#include "weights.h" -#include "wordid.h" - -struct TaggedSentence { - std::vector<WordID> words; - std::vector<WordID> pos; -}; - -struct ArcFeatureFunctions; -struct EdgeSubset { - EdgeSubset() {} - std::vector<short> roots; // unless multiroot trees are supported, this - // will have a single member - std::vector<std::pair<short, short> > h_m_pairs; // h,m start at 0 - // assumes ArcFeatureFunction::PrepareForInput has already been called - void ExtractFeatures(const TaggedSentence& sentence, - const ArcFeatureFunctions& ffs, - SparseVector<double>* features) const; -}; - -class ArcFactoredForest { - public: - ArcFactoredForest() : num_words_() {} - explicit ArcFactoredForest(short num_words) : num_words_(num_words) { - resize(num_words); - } - - unsigned size() const { return num_words_; } - - void resize(unsigned num_words) { - num_words_ = num_words; - root_edges_.clear(); - edges_.clear(); - root_edges_.resize(num_words); - edges_.resize(num_words, num_words); - for (int h = 0; h < num_words; ++h) { - for (int m = 0; m < num_words; ++m) { - edges_(h, m).h = h; - edges_(h, m).m = m; - } - root_edges_[h].h = -1; - root_edges_[h].m = h; - } - } - - // compute the maximum spanning tree based on the current weighting - // using the O(n^2) CLE algorithm - void MaximumSpanningTree(EdgeSubset* st) const; - - // Reweight edges so that edge_prob is the edge's marginals - // optionally returns log partition - void EdgeMarginals(prob_t* p_log_z = NULL); - - // This may not return a tree - void PickBestParentForEachWord(EdgeSubset* st) const; - - struct Edge { - Edge() : h(), m(), features(), edge_prob(prob_t::Zero()) {} - short h; - short m; - SparseVector<weight_t> features; - prob_t edge_prob; - }; - - // set eges_[*].features - void ExtractFeatures(const TaggedSentence& sentence, - const ArcFeatureFunctions& ffs); - - const Edge& operator()(short h, short m) const { - return h >= 0 ? edges_(h, m) : root_edges_[m]; - } - - Edge& operator()(short h, short m) { - return h >= 0 ? edges_(h, m) : root_edges_[m]; - } - - float Weight(short h, short m) const { - return log((*this)(h,m).edge_prob); - } - - template <class V> - void Reweight(const V& weights) { - for (int m = 0; m < num_words_; ++m) { - for (int h = 0; h < num_words_; ++h) { - if (h != m) { - Edge& e = edges_(h, m); - e.edge_prob.logeq(e.features.dot(weights)); - } - } - Edge& e = root_edges_[m]; - e.edge_prob.logeq(e.features.dot(weights)); - } - } - - private: - int num_words_; - std::vector<Edge> root_edges_; - Array2D<Edge> edges_; -}; - -inline std::ostream& operator<<(std::ostream& os, const ArcFactoredForest::Edge& edge) { - return os << "(" << edge.h << " < " << edge.m << ")"; -} - -inline std::ostream& operator<<(std::ostream& os, const EdgeSubset& ss) { - for (unsigned i = 0; i < ss.roots.size(); ++i) - os << "ROOT < " << ss.roots[i] << std::endl; - for (unsigned i = 0; i < ss.h_m_pairs.size(); ++i) - os << ss.h_m_pairs[i].first << " < " << ss.h_m_pairs[i].second << std::endl; - return os; -} - -#endif |