summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
Diffstat (limited to 'decoder')
-rw-r--r--decoder/apply_models.cc2
-rw-r--r--decoder/hg.h8
-rw-r--r--decoder/hg_intersect.cc24
-rw-r--r--decoder/lattice.cc4
-rw-r--r--decoder/lattice.h10
-rw-r--r--decoder/lexcrf.cc1
-rw-r--r--decoder/tagger.cc3
7 files changed, 47 insertions, 5 deletions
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc
index a340aa1a..2d8a60d5 100644
--- a/decoder/apply_models.cc
+++ b/decoder/apply_models.cc
@@ -407,5 +407,7 @@ void ApplyModelSet(const Hypergraph& in,
cerr << "Don't understand intersection algorithm " << config.algorithm << endl;
exit(1);
}
+ out->is_linear_chain_ = in.is_linear_chain_; // TODO remove when this is computed
+ // automatically
}
diff --git a/decoder/hg.h b/decoder/hg.h
index 7a2658b8..af8d38d2 100644
--- a/decoder/hg.h
+++ b/decoder/hg.h
@@ -14,7 +14,7 @@
// - edges have 1 head, 0..n tails
class Hypergraph {
public:
- Hypergraph() {}
+ Hypergraph() : is_linear_chain_(false) {}
// SmallVector is a fast, small vector<int> implementation for sizes <= 2
typedef SmallVector TailNodeVector;
@@ -57,6 +57,7 @@ class Hypergraph {
void swap(Hypergraph& other) {
other.nodes_.swap(nodes_);
+ std::swap(is_linear_chain_, other.is_linear_chain_);
other.edges_.swap(edges_);
}
@@ -175,6 +176,11 @@ class Hypergraph {
inline size_t NumberOfNodes() const { return nodes_.size(); }
inline bool empty() const { return nodes_.empty(); }
+ // linear chains can be represented in a number of ways in a hypergraph,
+ // we define them to consist only of lexical translations and monotonic rules
+ inline bool IsLinearChain() const { return is_linear_chain_; }
+ bool is_linear_chain_;
+
// nodes_ is sorted in topological order
std::vector<Node> nodes_;
// edges_ is not guaranteed to be in any particular order
diff --git a/decoder/hg_intersect.cc b/decoder/hg_intersect.cc
index a5e8913a..e414fc19 100644
--- a/decoder/hg_intersect.cc
+++ b/decoder/hg_intersect.cc
@@ -49,7 +49,31 @@ struct RuleFilter {
}
};
+static bool FastLinearIntersect(const Lattice& target, Hypergraph* hg) {
+ vector<bool> prune(hg->edges_.size(), false);
+ set<int> cov;
+ for (int i = 0; i < prune.size(); ++i) {
+ const Hypergraph::Edge& edge = hg->edges_[i];
+ if (edge.Arity() == 0) {
+ const int trg_index = edge.prev_i_;
+ const WordID trg = target[trg_index][0].label;
+ assert(edge.rule_->EWords() == 1);
+ prune[i] = (edge.rule_->e_[0] != trg);
+ if (!prune[i]) {
+ cov.insert(trg_index);
+ }
+ }
+ }
+ hg->PruneEdges(prune);
+ return (cov.size() == target.size());
+}
+
bool HG::Intersect(const Lattice& target, Hypergraph* hg) {
+ // there are a number of faster algorithms available for restricted
+ // classes of hypergraph and/or target.
+ if (hg->IsLinearChain() && target.IsSentence())
+ return FastLinearIntersect(target, hg);
+
vector<bool> rem(hg->edges_.size(), false);
const RuleFilter filter(target, 15); // TODO make configurable
for (int i = 0; i < rem.size(); ++i)
diff --git a/decoder/lattice.cc b/decoder/lattice.cc
index 56bc9551..956e12b4 100644
--- a/decoder/lattice.cc
+++ b/decoder/lattice.cc
@@ -54,8 +54,10 @@ void LatticeTools::ConvertTextToLattice(const string& text, Lattice* pl) {
void LatticeTools::ConvertTextOrPLF(const string& text_or_plf, Lattice* pl) {
if (LooksLikePLF(text_or_plf))
HypergraphIO::PLFtoLattice(text_or_plf, pl);
- else
+ else {
ConvertTextToLattice(text_or_plf, pl);
+ pl->is_sentence_ = true;
+ }
pl->ComputeDistances();
}
diff --git a/decoder/lattice.h b/decoder/lattice.h
index 71589b92..9a1932df 100644
--- a/decoder/lattice.h
+++ b/decoder/lattice.h
@@ -24,18 +24,22 @@ struct LatticeArc {
class Lattice : public std::vector<std::vector<LatticeArc> > {
friend void LatticeTools::ConvertTextOrPLF(const std::string& text_or_plf, Lattice* pl);
public:
- Lattice() {}
+ Lattice() : is_sentence_(false) {}
explicit Lattice(size_t t, const std::vector<LatticeArc>& v = std::vector<LatticeArc>()) :
- std::vector<std::vector<LatticeArc> >(t, v) {}
+ std::vector<std::vector<LatticeArc> >(t, v),
+ is_sentence_(false) {}
int Distance(int from, int to) const {
if (dist_.empty())
return (to - from);
return dist_(from, to);
}
-
+ // TODO this should actually be computed based on the contents
+ // of the lattice
+ bool IsSentence() const { return is_sentence_; }
private:
void ComputeDistances();
Array2D<int> dist_;
+ bool is_sentence_;
};
#endif
diff --git a/decoder/lexcrf.cc b/decoder/lexcrf.cc
index b80d055c..b0e03c69 100644
--- a/decoder/lexcrf.cc
+++ b/decoder/lexcrf.cc
@@ -106,6 +106,7 @@ bool LexicalCRF::Translate(const string& input,
LatticeTools::ConvertTextToLattice(input, &lattice);
smeta->SetSourceLength(lattice.size());
pimpl_->BuildTrellis(lattice, *smeta, forest);
+ forest->is_linear_chain_ = true;
forest->Reweight(weights);
return true;
}
diff --git a/decoder/tagger.cc b/decoder/tagger.cc
index 5a0155cc..68894abd 100644
--- a/decoder/tagger.cc
+++ b/decoder/tagger.cc
@@ -57,6 +57,8 @@ struct TaggerImpl {
Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector());
edge->i_ = i;
edge->j_ = i+1;
+ edge->prev_i_ = i; // we set these for FastLinearIntersect
+ edge->prev_j_ = i+1; // " " "
forest->ConnectEdgeToHeadNode(edge->id_, new_node_id);
}
if (prev_node_id >= 0) {
@@ -104,6 +106,7 @@ bool Tagger::Translate(const string& input,
}
pimpl_->BuildTrellis(sequence, forest);
forest->Reweight(weights);
+ forest->is_linear_chain_ = true;
return true;
}