summaryrefslogtreecommitdiff
path: root/rst_parser/arc_factored.h
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2012-10-22 12:07:20 +0100
committerKenneth Heafield <github@kheafield.com>2012-10-22 12:07:20 +0100
commit5f98fe5c4f2a2090eeb9d30c030305a70a8347d1 (patch)
tree9b6002f850e6dea1e3400c6b19bb31a9cdf3067f /rst_parser/arc_factored.h
parentcf9994131993b40be62e90e213b1e11e6b550143 (diff)
parent21825a09d97c2e0afd20512f306fb25fed55e529 (diff)
Merge remote branch 'upstream/master'
Conflicts: Jamroot bjam decoder/Jamfile decoder/cdec.cc dpmert/Jamfile jam-files/sanity.jam klm/lm/Jamfile klm/util/Jamfile mira/Jamfile
Diffstat (limited to 'rst_parser/arc_factored.h')
-rw-r--r--rst_parser/arc_factored.h124
1 files changed, 0 insertions, 124 deletions
diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h
deleted file mode 100644
index c5481d80..00000000
--- a/rst_parser/arc_factored.h
+++ /dev/null
@@ -1,124 +0,0 @@
-#ifndef _ARC_FACTORED_H_
-#define _ARC_FACTORED_H_
-
-#include <iostream>
-#include <cassert>
-#include <vector>
-#include <utility>
-#include <boost/shared_ptr.hpp>
-#include "array2d.h"
-#include "sparse_vector.h"
-#include "prob.h"
-#include "weights.h"
-#include "wordid.h"
-
-struct TaggedSentence {
- std::vector<WordID> words;
- std::vector<WordID> pos;
-};
-
-struct ArcFeatureFunctions;
-struct EdgeSubset {
- EdgeSubset() {}
- std::vector<short> roots; // unless multiroot trees are supported, this
- // will have a single member
- std::vector<std::pair<short, short> > h_m_pairs; // h,m start at 0
- // assumes ArcFeatureFunction::PrepareForInput has already been called
- void ExtractFeatures(const TaggedSentence& sentence,
- const ArcFeatureFunctions& ffs,
- SparseVector<double>* features) const;
-};
-
-class ArcFactoredForest {
- public:
- ArcFactoredForest() : num_words_() {}
- explicit ArcFactoredForest(short num_words) : num_words_(num_words) {
- resize(num_words);
- }
-
- unsigned size() const { return num_words_; }
-
- void resize(unsigned num_words) {
- num_words_ = num_words;
- root_edges_.clear();
- edges_.clear();
- root_edges_.resize(num_words);
- edges_.resize(num_words, num_words);
- for (int h = 0; h < num_words; ++h) {
- for (int m = 0; m < num_words; ++m) {
- edges_(h, m).h = h;
- edges_(h, m).m = m;
- }
- root_edges_[h].h = -1;
- root_edges_[h].m = h;
- }
- }
-
- // compute the maximum spanning tree based on the current weighting
- // using the O(n^2) CLE algorithm
- void MaximumSpanningTree(EdgeSubset* st) const;
-
- // Reweight edges so that edge_prob is the edge's marginals
- // optionally returns log partition
- void EdgeMarginals(prob_t* p_log_z = NULL);
-
- // This may not return a tree
- void PickBestParentForEachWord(EdgeSubset* st) const;
-
- struct Edge {
- Edge() : h(), m(), features(), edge_prob(prob_t::Zero()) {}
- short h;
- short m;
- SparseVector<weight_t> features;
- prob_t edge_prob;
- };
-
- // set eges_[*].features
- void ExtractFeatures(const TaggedSentence& sentence,
- const ArcFeatureFunctions& ffs);
-
- const Edge& operator()(short h, short m) const {
- return h >= 0 ? edges_(h, m) : root_edges_[m];
- }
-
- Edge& operator()(short h, short m) {
- return h >= 0 ? edges_(h, m) : root_edges_[m];
- }
-
- float Weight(short h, short m) const {
- return log((*this)(h,m).edge_prob);
- }
-
- template <class V>
- void Reweight(const V& weights) {
- for (int m = 0; m < num_words_; ++m) {
- for (int h = 0; h < num_words_; ++h) {
- if (h != m) {
- Edge& e = edges_(h, m);
- e.edge_prob.logeq(e.features.dot(weights));
- }
- }
- Edge& e = root_edges_[m];
- e.edge_prob.logeq(e.features.dot(weights));
- }
- }
-
- private:
- int num_words_;
- std::vector<Edge> root_edges_;
- Array2D<Edge> edges_;
-};
-
-inline std::ostream& operator<<(std::ostream& os, const ArcFactoredForest::Edge& edge) {
- return os << "(" << edge.h << " < " << edge.m << ")";
-}
-
-inline std::ostream& operator<<(std::ostream& os, const EdgeSubset& ss) {
- for (unsigned i = 0; i < ss.roots.size(); ++i)
- os << "ROOT < " << ss.roots[i] << std::endl;
- for (unsigned i = 0; i < ss.h_m_pairs.size(); ++i)
- os << ss.h_m_pairs[i].first << " < " << ss.h_m_pairs[i].second << std::endl;
- return os;
-}
-
-#endif