diff options
-rw-r--r-- | decoder/apply_models.cc | 2 | ||||
-rw-r--r-- | decoder/hg.h | 8 | ||||
-rw-r--r-- | decoder/hg_intersect.cc | 24 | ||||
-rw-r--r-- | decoder/lattice.cc | 4 | ||||
-rw-r--r-- | decoder/lattice.h | 10 | ||||
-rw-r--r-- | decoder/lexcrf.cc | 1 | ||||
-rw-r--r-- | decoder/tagger.cc | 3 | ||||
-rw-r--r-- | tests/system_tests/unsup-align/gold.statistics | 32 |
8 files changed, 63 insertions, 21 deletions
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc index a340aa1a..2d8a60d5 100644 --- a/decoder/apply_models.cc +++ b/decoder/apply_models.cc @@ -407,5 +407,7 @@ void ApplyModelSet(const Hypergraph& in, cerr << "Don't understand intersection algorithm " << config.algorithm << endl; exit(1); } + out->is_linear_chain_ = in.is_linear_chain_; // TODO remove when this is computed + // automatically } diff --git a/decoder/hg.h b/decoder/hg.h index 7a2658b8..af8d38d2 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -14,7 +14,7 @@ // - edges have 1 head, 0..n tails class Hypergraph { public: - Hypergraph() {} + Hypergraph() : is_linear_chain_(false) {} // SmallVector is a fast, small vector<int> implementation for sizes <= 2 typedef SmallVector TailNodeVector; @@ -57,6 +57,7 @@ class Hypergraph { void swap(Hypergraph& other) { other.nodes_.swap(nodes_); + std::swap(is_linear_chain_, other.is_linear_chain_); other.edges_.swap(edges_); } @@ -175,6 +176,11 @@ class Hypergraph { inline size_t NumberOfNodes() const { return nodes_.size(); } inline bool empty() const { return nodes_.empty(); } + // linear chains can be represented in a number of ways in a hypergraph, + // we define them to consist only of lexical translations and monotonic rules + inline bool IsLinearChain() const { return is_linear_chain_; } + bool is_linear_chain_; + // nodes_ is sorted in topological order std::vector<Node> nodes_; // edges_ is not guaranteed to be in any particular order diff --git a/decoder/hg_intersect.cc b/decoder/hg_intersect.cc index a5e8913a..e414fc19 100644 --- a/decoder/hg_intersect.cc +++ b/decoder/hg_intersect.cc @@ -49,7 +49,31 @@ struct RuleFilter { } }; +static bool FastLinearIntersect(const Lattice& target, Hypergraph* hg) { + vector<bool> prune(hg->edges_.size(), false); + set<int> cov; + for (int i = 0; i < prune.size(); ++i) { + const Hypergraph::Edge& edge = hg->edges_[i]; + if (edge.Arity() == 0) { + const int trg_index = edge.prev_i_; + const WordID trg = target[trg_index][0].label; + assert(edge.rule_->EWords() == 1); + prune[i] = (edge.rule_->e_[0] != trg); + if (!prune[i]) { + cov.insert(trg_index); + } + } + } + hg->PruneEdges(prune); + return (cov.size() == target.size()); +} + bool HG::Intersect(const Lattice& target, Hypergraph* hg) { + // there are a number of faster algorithms available for restricted + // classes of hypergraph and/or target. + if (hg->IsLinearChain() && target.IsSentence()) + return FastLinearIntersect(target, hg); + vector<bool> rem(hg->edges_.size(), false); const RuleFilter filter(target, 15); // TODO make configurable for (int i = 0; i < rem.size(); ++i) diff --git a/decoder/lattice.cc b/decoder/lattice.cc index 56bc9551..956e12b4 100644 --- a/decoder/lattice.cc +++ b/decoder/lattice.cc @@ -54,8 +54,10 @@ void LatticeTools::ConvertTextToLattice(const string& text, Lattice* pl) { void LatticeTools::ConvertTextOrPLF(const string& text_or_plf, Lattice* pl) { if (LooksLikePLF(text_or_plf)) HypergraphIO::PLFtoLattice(text_or_plf, pl); - else + else { ConvertTextToLattice(text_or_plf, pl); + pl->is_sentence_ = true; + } pl->ComputeDistances(); } diff --git a/decoder/lattice.h b/decoder/lattice.h index 71589b92..9a1932df 100644 --- a/decoder/lattice.h +++ b/decoder/lattice.h @@ -24,18 +24,22 @@ struct LatticeArc { class Lattice : public std::vector<std::vector<LatticeArc> > { friend void LatticeTools::ConvertTextOrPLF(const std::string& text_or_plf, Lattice* pl); public: - Lattice() {} + Lattice() : is_sentence_(false) {} explicit Lattice(size_t t, const std::vector<LatticeArc>& v = std::vector<LatticeArc>()) : - std::vector<std::vector<LatticeArc> >(t, v) {} + std::vector<std::vector<LatticeArc> >(t, v), + is_sentence_(false) {} int Distance(int from, int to) const { if (dist_.empty()) return (to - from); return dist_(from, to); } - + // TODO this should actually be computed based on the contents + // of the lattice + bool IsSentence() const { return is_sentence_; } private: void ComputeDistances(); Array2D<int> dist_; + bool is_sentence_; }; #endif diff --git a/decoder/lexcrf.cc b/decoder/lexcrf.cc index b80d055c..b0e03c69 100644 --- a/decoder/lexcrf.cc +++ b/decoder/lexcrf.cc @@ -106,6 +106,7 @@ bool LexicalCRF::Translate(const string& input, LatticeTools::ConvertTextToLattice(input, &lattice); smeta->SetSourceLength(lattice.size()); pimpl_->BuildTrellis(lattice, *smeta, forest); + forest->is_linear_chain_ = true; forest->Reweight(weights); return true; } diff --git a/decoder/tagger.cc b/decoder/tagger.cc index 5a0155cc..68894abd 100644 --- a/decoder/tagger.cc +++ b/decoder/tagger.cc @@ -57,6 +57,8 @@ struct TaggerImpl { Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector()); edge->i_ = i; edge->j_ = i+1; + edge->prev_i_ = i; // we set these for FastLinearIntersect + edge->prev_j_ = i+1; // " " " forest->ConnectEdgeToHeadNode(edge->id_, new_node_id); } if (prev_node_id >= 0) { @@ -104,6 +106,7 @@ bool Tagger::Translate(const string& input, } pimpl_->BuildTrellis(sequence, forest); forest->Reweight(weights); + forest->is_linear_chain_ = true; return true; } diff --git a/tests/system_tests/unsup-align/gold.statistics b/tests/system_tests/unsup-align/gold.statistics index afc49bfc..2f37c2db 100644 --- a/tests/system_tests/unsup-align/gold.statistics +++ b/tests/system_tests/unsup-align/gold.statistics @@ -7,8 +7,8 @@ +lm_edges 3 +lm_paths 2 +lm_trans blue -constr_nodes 3 -constr_edges 3 +constr_nodes 2 +constr_edges 2 constr_paths 1 -lm_nodes 2 -lm_edges 4 @@ -19,8 +19,8 @@ constr_paths 1 +lm_edges 4 +lm_paths 3 +lm_trans house -constr_nodes 3 -constr_edges 3 +constr_nodes 2 +constr_edges 2 constr_paths 1 -lm_nodes 4 -lm_edges 16 @@ -31,8 +31,8 @@ constr_paths 1 +lm_edges 20 +lm_paths 49 +lm_trans the house -constr_nodes 8 -constr_edges 11 +constr_nodes 7 +constr_edges 10 constr_paths 4 -lm_nodes 4 -lm_edges 12 @@ -43,8 +43,8 @@ constr_paths 4 +lm_edges 16 +lm_paths 25 +lm_trans house blue -constr_nodes 8 -constr_edges 11 +constr_nodes 7 +constr_edges 10 constr_paths 4 -lm_nodes 4 -lm_edges 14 @@ -55,8 +55,8 @@ constr_paths 4 +lm_edges 18 +lm_paths 36 +lm_trans the the -constr_nodes 8 -constr_edges 11 +constr_nodes 7 +constr_edges 10 constr_paths 4 -lm_nodes 2 -lm_edges 5 @@ -67,8 +67,8 @@ constr_paths 4 +lm_edges 5 +lm_paths 4 +lm_trans the -constr_nodes 3 -constr_edges 3 +constr_nodes 2 +constr_edges 2 constr_paths 1 -lm_nodes 4 -lm_edges 14 @@ -79,8 +79,8 @@ constr_paths 1 +lm_edges 18 +lm_paths 36 +lm_trans the the -constr_nodes 8 -constr_edges 11 +constr_nodes 7 +constr_edges 10 constr_paths 4 -lm_nodes 4 -lm_edges 10 @@ -91,6 +91,6 @@ constr_paths 4 +lm_edges 14 +lm_paths 16 +lm_trans end thet -constr_nodes 8 -constr_edges 11 +constr_nodes 7 +constr_edges 10 constr_paths 4 |