Diffstat (limited to 'decoder/hg_intersect.cc')
-rw-r--r-- | decoder/hg_intersect.cc | 160 |
1 file changed, 160 insertions, 0 deletions
diff --git a/decoder/hg_intersect.cc b/decoder/hg_intersect.cc
new file mode 100644
index 00000000..02ff752e
--- /dev/null
+++ b/decoder/hg_intersect.cc
@@ -0,0 +1,160 @@
+#include "hg_intersect.h"
+
+#include <vector>
+#include <tr1/unordered_map>
+#include <boost/lexical_cast.hpp>
+#include <boost/functional/hash.hpp>
+
+#include "tdict.h"
+#include "hg.h"
+#include "trule.h"
+#include "wordid.h"
+#include "bottom_up_parser.h"
+
+using boost::lexical_cast;
+using namespace std::tr1;
+using namespace std;
+
+struct RuleFilter {
+  unordered_map<vector<WordID>, bool, boost::hash<vector<WordID> > > exists_;
+  bool true_lattice;
+  RuleFilter(const Lattice& target, int max_phrase_size) {
+    true_lattice = false;
+    for (int i = 0; i < target.size(); ++i) {
+      vector<WordID> phrase;
+      int lim = min(static_cast<int>(target.size()), i + max_phrase_size);
+      for (int j = i; j < lim; ++j) {
+        if (target[j].size() > 1) { true_lattice = true; break; }
+        phrase.push_back(target[j][0].label);
+        exists_[phrase] = true;
+      }
+    }
+    vector<WordID> sos(1, TD::Convert("<s>"));
+    exists_[sos] = true;
+  }
+  bool operator()(const TRule& r) const {
+    // TODO do some smarter filtering for lattices
+    if (true_lattice) return false;  // don't filter "true lattice" input
+    const vector<WordID>& e = r.e();
+    for (int i = 0; i < e.size(); ++i) {
+      if (e[i] <= 0) continue;
+      vector<WordID> phrase;
+      for (int j = i; j < e.size(); ++j) {
+        if (e[j] <= 0) break;
+        phrase.push_back(e[j]);
+        if (exists_.count(phrase) == 0) return true;
+      }
+    }
+    return false;
+  }
+};
+
+static bool FastLinearIntersect(const Lattice& target, Hypergraph* hg) {
+  cerr << "  Fast linear-chain intersection...\n";
+  vector<bool> prune(hg->edges_.size(), false);
+  set<int> cov;
+  map<const TRule*, TRulePtr> inverted_rules;
+  for (int i = 0; i < prune.size(); ++i) {
+    Hypergraph::Edge& edge = hg->edges_[i];
+    if (edge.Arity() == 0) {
+      const int trg_index = edge.prev_i_;
+      const WordID trg = target[trg_index][0].label;
+      assert(edge.rule_->EWords() == 1);
+      TRulePtr& inv_rule = inverted_rules[edge.rule_.get()];
+      if (!inv_rule) {
+        inv_rule.reset(new TRule(*edge.rule_));
+        inv_rule->e_.swap(inv_rule->f_);
+      }
+      prune[i] = (edge.rule_->e_[0] != trg);
+      if (!prune[i]) {
+        cov.insert(trg_index);
+        swap(edge.prev_i_, edge.i_);
+        swap(edge.prev_j_, edge.j_);
+        edge.rule_.swap(inv_rule);
+      }
+    }
+  }
+  hg->PruneEdges(prune, true);
+  return (cov.size() == target.size());
+}
+
+bool HG::Intersect(const Lattice& target, Hypergraph* hg) {
+  // there are a number of faster algorithms available for restricted
+  // classes of hypergraph and/or target.
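+  //
+  // the simplest of these is handled inline: when the forest is a linear
+  // chain and the target lattice is a plain sentence (one arc per column),
+  // intersection reduces to matching each terminal edge against the target
+  // word at its span and pruning the mismatches (FastLinearIntersect above).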
+  if (hg->IsLinearChain() && target.IsSentence())
+    return FastLinearIntersect(target, hg);
+
+  vector<bool> rem(hg->edges_.size(), false);
+  const RuleFilter filter(target, 15);  // TODO make configurable
+  for (int i = 0; i < rem.size(); ++i)
+    rem[i] = filter(*hg->edges_[i].rule_);
+  hg->PruneEdges(rem, true);
+
+  const int nedges = hg->edges_.size();
+  const int nnodes = hg->nodes_.size();
+
+  TextGrammar* g = new TextGrammar;
+  GrammarPtr gp(g);
+  vector<int> cats(nnodes);
+  // each node in the translation forest becomes a "non-terminal" in the new
+  // grammar, create the labels here
+  const string kSEP = "_";
+  for (int i = 0; i < nnodes; ++i) {
+    const char* pstr = "CAT";
+    if (hg->nodes_[i].cat_ < 0)
+      pstr = TD::Convert(-hg->nodes_[i].cat_);
+    cats[i] = TD::Convert(pstr + kSEP + lexical_cast<string>(i)) * -1;
+  }
+
+  // construct the grammar
+  for (int i = 0; i < nedges; ++i) {
+    const Hypergraph::Edge& edge = hg->edges_[i];
+    const vector<WordID>& tgt = edge.rule_->e();
+    const vector<WordID>& src = edge.rule_->f();
+    TRulePtr rule(new TRule);
+    rule->prev_i = edge.i_;
+    rule->prev_j = edge.j_;
+    rule->lhs_ = cats[edge.head_node_];
+    vector<WordID>& f = rule->f_;
+    vector<WordID>& e = rule->e_;
+    f.resize(tgt.size());  // swap source and target, since the parser
+    e.resize(src.size());  // parses using the source side!
+    Hypergraph::TailNodeVector tn(edge.tail_nodes_.size());
+    int ntc = 0;
+    for (int j = 0; j < tgt.size(); ++j) {
+      const WordID& cur = tgt[j];
+      if (cur > 0) {
+        f[j] = cur;
+      } else {
+        tn[ntc++] = cur;
+        f[j] = cats[edge.tail_nodes_[-cur]];
+      }
+    }
+    ntc = 0;
+    for (int j = 0; j < src.size(); ++j) {
+      const WordID& cur = src[j];
+      if (cur > 0) {
+        e[j] = cur;
+      } else {
+        e[j] = tn[ntc++];
+      }
+    }
+    rule->scores_ = edge.feature_values_;
+    rule->parent_rule_ = edge.rule_;
+    rule->ComputeArity();
+    //cerr << "ADD: " << rule->AsString() << endl;
+
+    g->AddRule(rule);
+  }
+  g->SetMaxSpan(target.size() + 1);
+  const string& new_goal = TD::Convert(cats.back() * -1);
+  vector<GrammarPtr> grammars(1, gp);
+  Hypergraph tforest;
+  ExhaustiveBottomUpParser parser(new_goal, grammars);
+  if (!parser.Parse(target, &tforest))
+    return false;
+  else
+    hg->swap(tforest);
+  return true;
+}
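For context, a minimal sketch of how this entry point could be driven, e.g. to constrain a translation forest to a reference sentence. It assumes only the declarations shipped alongside this file (HG::Intersect from hg_intersect.h, LatticeTools::ConvertTextToLattice from lattice.h); the helper name ConstrainToReference is hypothetical, not part of this commit:

  #include <string>

  #include "hg.h"
  #include "hg_intersect.h"
  #include "lattice.h"

  // Illustrative sketch (not in this commit): prune *forest* down to the
  // derivations whose target yield is exactly *ref*; returns false when
  // the reference cannot be produced by the forest at all.
  bool ConstrainToReference(const std::string& ref, Hypergraph* forest) {
    Lattice target;
    LatticeTools::ConvertTextToLattice(ref, &target);  // one arc per word
    return HG::Intersect(target, forest);
  }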