diff options
author | Chris Dyer <redpony@gmail.com> | 2009-12-26 12:49:06 -0600 |
---|---|---|
committer | Chris Dyer <redpony@gmail.com> | 2009-12-26 12:49:06 -0600 |
commit | 3f01c8ed777aec011181dc515d9d28aa81e8530b (patch) | |
tree | 93d9f6bfb9c26cb8334ca97b42b42a27dc1dc323 /decoder | |
parent | 9be811a26da86b87bac8696f155188d9a675e61b (diff) |
increase intersection speed by a couple orders of magnitude for linear chain graphs
Diffstat (limited to 'decoder')
-rw-r--r-- | decoder/apply_models.cc | 2 | ||||
-rw-r--r-- | decoder/hg.h | 8 | ||||
-rw-r--r-- | decoder/hg_intersect.cc | 24 | ||||
-rw-r--r-- | decoder/lattice.cc | 4 | ||||
-rw-r--r-- | decoder/lattice.h | 10 | ||||
-rw-r--r-- | decoder/lexcrf.cc | 1 | ||||
-rw-r--r-- | decoder/tagger.cc | 3 |
7 files changed, 47 insertions, 5 deletions
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc index a340aa1a..2d8a60d5 100644 --- a/decoder/apply_models.cc +++ b/decoder/apply_models.cc @@ -407,5 +407,7 @@ void ApplyModelSet(const Hypergraph& in, cerr << "Don't understand intersection algorithm " << config.algorithm << endl; exit(1); } + out->is_linear_chain_ = in.is_linear_chain_; // TODO remove when this is computed + // automatically } diff --git a/decoder/hg.h b/decoder/hg.h index 7a2658b8..af8d38d2 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -14,7 +14,7 @@ // - edges have 1 head, 0..n tails class Hypergraph { public: - Hypergraph() {} + Hypergraph() : is_linear_chain_(false) {} // SmallVector is a fast, small vector<int> implementation for sizes <= 2 typedef SmallVector TailNodeVector; @@ -57,6 +57,7 @@ class Hypergraph { void swap(Hypergraph& other) { other.nodes_.swap(nodes_); + std::swap(is_linear_chain_, other.is_linear_chain_); other.edges_.swap(edges_); } @@ -175,6 +176,11 @@ class Hypergraph { inline size_t NumberOfNodes() const { return nodes_.size(); } inline bool empty() const { return nodes_.empty(); } + // linear chains can be represented in a number of ways in a hypergraph, + // we define them to consist only of lexical translations and monotonic rules + inline bool IsLinearChain() const { return is_linear_chain_; } + bool is_linear_chain_; + // nodes_ is sorted in topological order std::vector<Node> nodes_; // edges_ is not guaranteed to be in any particular order diff --git a/decoder/hg_intersect.cc b/decoder/hg_intersect.cc index a5e8913a..e414fc19 100644 --- a/decoder/hg_intersect.cc +++ b/decoder/hg_intersect.cc @@ -49,7 +49,31 @@ struct RuleFilter { } }; +static bool FastLinearIntersect(const Lattice& target, Hypergraph* hg) { + vector<bool> prune(hg->edges_.size(), false); + set<int> cov; + for (int i = 0; i < prune.size(); ++i) { + const Hypergraph::Edge& edge = hg->edges_[i]; + if (edge.Arity() == 0) { + const int trg_index = edge.prev_i_; + const WordID trg = target[trg_index][0].label; + assert(edge.rule_->EWords() == 1); + prune[i] = (edge.rule_->e_[0] != trg); + if (!prune[i]) { + cov.insert(trg_index); + } + } + } + hg->PruneEdges(prune); + return (cov.size() == target.size()); +} + bool HG::Intersect(const Lattice& target, Hypergraph* hg) { + // there are a number of faster algorithms available for restricted + // classes of hypergraph and/or target. + if (hg->IsLinearChain() && target.IsSentence()) + return FastLinearIntersect(target, hg); + vector<bool> rem(hg->edges_.size(), false); const RuleFilter filter(target, 15); // TODO make configurable for (int i = 0; i < rem.size(); ++i) diff --git a/decoder/lattice.cc b/decoder/lattice.cc index 56bc9551..956e12b4 100644 --- a/decoder/lattice.cc +++ b/decoder/lattice.cc @@ -54,8 +54,10 @@ void LatticeTools::ConvertTextToLattice(const string& text, Lattice* pl) { void LatticeTools::ConvertTextOrPLF(const string& text_or_plf, Lattice* pl) { if (LooksLikePLF(text_or_plf)) HypergraphIO::PLFtoLattice(text_or_plf, pl); - else + else { ConvertTextToLattice(text_or_plf, pl); + pl->is_sentence_ = true; + } pl->ComputeDistances(); } diff --git a/decoder/lattice.h b/decoder/lattice.h index 71589b92..9a1932df 100644 --- a/decoder/lattice.h +++ b/decoder/lattice.h @@ -24,18 +24,22 @@ struct LatticeArc { class Lattice : public std::vector<std::vector<LatticeArc> > { friend void LatticeTools::ConvertTextOrPLF(const std::string& text_or_plf, Lattice* pl); public: - Lattice() {} + Lattice() : is_sentence_(false) {} explicit Lattice(size_t t, const std::vector<LatticeArc>& v = std::vector<LatticeArc>()) : - std::vector<std::vector<LatticeArc> >(t, v) {} + std::vector<std::vector<LatticeArc> >(t, v), + is_sentence_(false) {} int Distance(int from, int to) const { if (dist_.empty()) return (to - from); return dist_(from, to); } - + // TODO this should actually be computed based on the contents + // of the lattice + bool IsSentence() const { return is_sentence_; } private: void ComputeDistances(); Array2D<int> dist_; + bool is_sentence_; }; #endif diff --git a/decoder/lexcrf.cc b/decoder/lexcrf.cc index b80d055c..b0e03c69 100644 --- a/decoder/lexcrf.cc +++ b/decoder/lexcrf.cc @@ -106,6 +106,7 @@ bool LexicalCRF::Translate(const string& input, LatticeTools::ConvertTextToLattice(input, &lattice); smeta->SetSourceLength(lattice.size()); pimpl_->BuildTrellis(lattice, *smeta, forest); + forest->is_linear_chain_ = true; forest->Reweight(weights); return true; } diff --git a/decoder/tagger.cc b/decoder/tagger.cc index 5a0155cc..68894abd 100644 --- a/decoder/tagger.cc +++ b/decoder/tagger.cc @@ -57,6 +57,8 @@ struct TaggerImpl { Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector()); edge->i_ = i; edge->j_ = i+1; + edge->prev_i_ = i; // we set these for FastLinearIntersect + edge->prev_j_ = i+1; // " " " forest->ConnectEdgeToHeadNode(edge->id_, new_node_id); } if (prev_node_id >= 0) { @@ -104,6 +106,7 @@ bool Tagger::Translate(const string& input, } pimpl_->BuildTrellis(sequence, forest); forest->Reweight(weights); + forest->is_linear_chain_ = true; return true; } |