From 3f01c8ed777aec011181dc515d9d28aa81e8530b Mon Sep 17 00:00:00 2001
From: Chris Dyer <redpony@gmail.com>
Date: Sat, 26 Dec 2009 12:49:06 -0600
Subject: increase intersection speed by a couple orders of magnitude for
 linear chain graphs

---
 decoder/apply_models.cc |  2 ++
 decoder/hg.h            |  8 +++++++-
 decoder/hg_intersect.cc | 24 ++++++++++++++++++++++++
 decoder/lattice.cc      |  4 +++-
 decoder/lattice.h       | 10 +++++++---
 decoder/lexcrf.cc       |  1 +
 decoder/tagger.cc       |  3 +++
 7 files changed, 47 insertions(+), 5 deletions(-)

(limited to 'decoder')

diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc
index a340aa1a..2d8a60d5 100644
--- a/decoder/apply_models.cc
+++ b/decoder/apply_models.cc
@@ -407,5 +407,7 @@ void ApplyModelSet(const Hypergraph& in,
     cerr << "Don't understand intersection algorithm " << config.algorithm << endl;
     exit(1);
   }
+  out->is_linear_chain_ = in.is_linear_chain_;  // TODO remove when this is computed
+                                                // automatically
 }
 
diff --git a/decoder/hg.h b/decoder/hg.h
index 7a2658b8..af8d38d2 100644
--- a/decoder/hg.h
+++ b/decoder/hg.h
@@ -14,7 +14,7 @@
 //  - edges have 1 head, 0..n tails
 class Hypergraph {
  public:
-  Hypergraph() {}
+  Hypergraph() : is_linear_chain_(false) {}
 
   // SmallVector is a fast, small vector<int> implementation for sizes <= 2
   typedef SmallVector TailNodeVector;
@@ -57,6 +57,7 @@ class Hypergraph {
 
   void swap(Hypergraph& other) {
     other.nodes_.swap(nodes_);
+    std::swap(is_linear_chain_, other.is_linear_chain_);
     other.edges_.swap(edges_);
   }
 
@@ -175,6 +176,11 @@ class Hypergraph {
   inline size_t NumberOfNodes() const { return nodes_.size(); }
   inline bool empty() const { return nodes_.empty(); }
 
+  // linear chains can be represented in a number of ways in a hypergraph,
+  // we define them to consist only of lexical translations and monotonic rules
+  inline bool IsLinearChain() const { return is_linear_chain_; }
+  bool is_linear_chain_;
+
   // nodes_ is sorted in topological order
   std::vector<Node> nodes_;
   // edges_ is not guaranteed to be in any particular order
diff --git a/decoder/hg_intersect.cc b/decoder/hg_intersect.cc
index a5e8913a..e414fc19 100644
--- a/decoder/hg_intersect.cc
+++ b/decoder/hg_intersect.cc
@@ -49,7 +49,31 @@ struct RuleFilter {
   }
 };
 
+static bool FastLinearIntersect(const Lattice& target, Hypergraph* hg) {
+  vector<bool> prune(hg->edges_.size(), false);
+  set<int> cov;
+  for (int i = 0; i < prune.size(); ++i) {
+    const Hypergraph::Edge& edge = hg->edges_[i];
+    if (edge.Arity() == 0) {
+      const int trg_index = edge.prev_i_;
+      const WordID trg = target[trg_index][0].label;
+      assert(edge.rule_->EWords() == 1);
+      prune[i] = (edge.rule_->e_[0] != trg);
+      if (!prune[i]) {
+        cov.insert(trg_index);
+      }
+    }
+  }
+  hg->PruneEdges(prune);
+  return (cov.size() == target.size());
+}
+
 bool HG::Intersect(const Lattice& target, Hypergraph* hg) {
+  // there are a number of faster algorithms available for restricted
+  // classes of hypergraph and/or target.
+  if (hg->IsLinearChain() && target.IsSentence())
+    return FastLinearIntersect(target, hg);
+
   vector<bool> rem(hg->edges_.size(), false);
   const RuleFilter filter(target, 15);   // TODO make configurable
   for (int i = 0; i < rem.size(); ++i)
diff --git a/decoder/lattice.cc b/decoder/lattice.cc
index 56bc9551..956e12b4 100644
--- a/decoder/lattice.cc
+++ b/decoder/lattice.cc
@@ -54,8 +54,10 @@ void LatticeTools::ConvertTextToLattice(const string& text, Lattice* pl) {
 void LatticeTools::ConvertTextOrPLF(const string& text_or_plf, Lattice* pl) {
   if (LooksLikePLF(text_or_plf))
     HypergraphIO::PLFtoLattice(text_or_plf, pl);
-  else
+  else {
     ConvertTextToLattice(text_or_plf, pl);
+    pl->is_sentence_ = true;
+  }
   pl->ComputeDistances();
 }
 
diff --git a/decoder/lattice.h b/decoder/lattice.h
index 71589b92..9a1932df 100644
--- a/decoder/lattice.h
+++ b/decoder/lattice.h
@@ -24,18 +24,22 @@ struct LatticeArc {
 class Lattice : public std::vector<std::vector<LatticeArc> > {
   friend void LatticeTools::ConvertTextOrPLF(const std::string& text_or_plf, Lattice* pl);
  public:
-  Lattice() {}
+  Lattice() : is_sentence_(false) {}
   explicit Lattice(size_t t, const std::vector<LatticeArc>& v = std::vector<LatticeArc>()) :
-   std::vector<std::vector<LatticeArc> >(t, v) {}
+   std::vector<std::vector<LatticeArc> >(t, v),
+   is_sentence_(false) {}
   int Distance(int from, int to) const {
     if (dist_.empty())
       return (to - from);
     return dist_(from, to);
   }
-
+  // TODO this should actually be computed based on the contents
+  // of the lattice
+  bool IsSentence() const { return is_sentence_; }
  private:
   void ComputeDistances();
   Array2D<int> dist_;
+  bool is_sentence_;
 };
 
 #endif
diff --git a/decoder/lexcrf.cc b/decoder/lexcrf.cc
index b80d055c..b0e03c69 100644
--- a/decoder/lexcrf.cc
+++ b/decoder/lexcrf.cc
@@ -106,6 +106,7 @@ bool LexicalCRF::Translate(const string& input,
   LatticeTools::ConvertTextToLattice(input, &lattice);
   smeta->SetSourceLength(lattice.size());
   pimpl_->BuildTrellis(lattice, *smeta, forest);
+  forest->is_linear_chain_ = true;
   forest->Reweight(weights);
   return true;
 }
diff --git a/decoder/tagger.cc b/decoder/tagger.cc
index 5a0155cc..68894abd 100644
--- a/decoder/tagger.cc
+++ b/decoder/tagger.cc
@@ -57,6 +57,8 @@ struct TaggerImpl {
         Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector());
         edge->i_ = i;
         edge->j_ = i+1;
+        edge->prev_i_ = i;    // we set these for FastLinearIntersect
+        edge->prev_j_ = i+1;  //      "      "            "
         forest->ConnectEdgeToHeadNode(edge->id_, new_node_id);
       }
       if (prev_node_id >= 0) {
@@ -104,6 +106,7 @@ bool Tagger::Translate(const string& input,
   }
   pimpl_->BuildTrellis(sequence, forest);
   forest->Reweight(weights);
+  forest->is_linear_chain_ = true;
   return true;
 }
 
-- 
cgit v1.2.3