| field | value | date |
|---|---|---|
| author | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-06-22 05:12:27 +0000 |
| committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-06-22 05:12:27 +0000 |
| commit | 0172721855098ca02b207231a654dffa5e4eb1c9 (patch) | |
| tree | 8069c3a62e2d72bd64a2cdeee9724b2679c8a56b /decoder/hg_intersect.cc | |
| parent | 37728b8be4d0b3df9da81fdda2198ff55b4b2d91 (diff) | |
initial checkin
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@2 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'decoder/hg_intersect.cc')
| -rw-r--r-- | decoder/hg_intersect.cc | 160 |

1 file changed, 160 insertions, 0 deletions
```diff
diff --git a/decoder/hg_intersect.cc b/decoder/hg_intersect.cc
new file mode 100644
index 00000000..02ff752e
--- /dev/null
+++ b/decoder/hg_intersect.cc
@@ -0,0 +1,160 @@
+#include "hg_intersect.h"
+
+#include <vector>
+#include <tr1/unordered_map>
+#include <boost/lexical_cast.hpp>
+#include <boost/functional/hash.hpp>
+
+#include "tdict.h"
+#include "hg.h"
+#include "trule.h"
+#include "wordid.h"
+#include "bottom_up_parser.h"
+
+using boost::lexical_cast;
+using namespace std::tr1;
+using namespace std;
+
+struct RuleFilter {
+  unordered_map<vector<WordID>, bool, boost::hash<vector<WordID> > > exists_;
+  bool true_lattice;
+  RuleFilter(const Lattice& target, int max_phrase_size) {
+    true_lattice = false;
+    for (int i = 0; i < target.size(); ++i) {
+      vector<WordID> phrase;
+      int lim = min(static_cast<int>(target.size()), i + max_phrase_size);
+      for (int j = i; j < lim; ++j) {
+        if (target[j].size() > 1) { true_lattice = true; break; }
+        phrase.push_back(target[j][0].label);
+        exists_[phrase] = true;
+      }
+    }
+    vector<WordID> sos(1, TD::Convert("<s>"));
+    exists_[sos] = true;
+  }
+  bool operator()(const TRule& r) const {
+    // TODO do some smarter filtering for lattices
+    if (true_lattice) return false;  // don't filter "true lattice" input
+    const vector<WordID>& e = r.e();
+    for (int i = 0; i < e.size(); ++i) {
+      if (e[i] <= 0) continue;
+      vector<WordID> phrase;
+      for (int j = i; j < e.size(); ++j) {
+        if (e[j] <= 0) break;
+        phrase.push_back(e[j]);
+        if (exists_.count(phrase) == 0) return true;
+      }
+    }
+    return false;
+  }
+};
+
+static bool FastLinearIntersect(const Lattice& target, Hypergraph* hg) {
+  cerr << "  Fast linear-chain intersection...\n";
+  vector<bool> prune(hg->edges_.size(), false);
+  set<int> cov;
+  map<const TRule*, TRulePtr> inverted_rules;
+  for (int i = 0; i < prune.size(); ++i) {
+    Hypergraph::Edge& edge = hg->edges_[i];
+    if (edge.Arity() == 0) {
+      const int trg_index = edge.prev_i_;
+      const WordID trg = target[trg_index][0].label;
+      assert(edge.rule_->EWords() == 1);
+      TRulePtr& inv_rule = inverted_rules[edge.rule_.get()];
+      if (!inv_rule) {
+        inv_rule.reset(new TRule(*edge.rule_));
+        inv_rule->e_.swap(inv_rule->f_);
+      }
+      prune[i] = (edge.rule_->e_[0] != trg);
+      if (!prune[i]) {
+        cov.insert(trg_index);
+        swap(edge.prev_i_, edge.i_);
+        swap(edge.prev_j_, edge.j_);
+        edge.rule_.swap(inv_rule);
+      }
+    }
+  }
+  hg->PruneEdges(prune, true);
+  return (cov.size() == target.size());
+}
+
+bool HG::Intersect(const Lattice& target, Hypergraph* hg) {
+  // there are a number of faster algorithms available for restricted
+  // classes of hypergraph and/or target.
+  if (hg->IsLinearChain() && target.IsSentence())
+    return FastLinearIntersect(target, hg);
+
+  vector<bool> rem(hg->edges_.size(), false);
+  const RuleFilter filter(target, 15);   // TODO make configurable
+  for (int i = 0; i < rem.size(); ++i)
+    rem[i] = filter(*hg->edges_[i].rule_);
+  hg->PruneEdges(rem, true);
+
+  const int nedges = hg->edges_.size();
+  const int nnodes = hg->nodes_.size();
+
+  TextGrammar* g = new TextGrammar;
+  GrammarPtr gp(g);
+  vector<int> cats(nnodes);
+  // each node in the translation forest becomes a "non-terminal" in the new
+  // grammar, create the labels here
+  const string kSEP = "_";
+  for (int i = 0; i < nnodes; ++i) {
+    const char* pstr = "CAT";
+    if (hg->nodes_[i].cat_ < 0)
+      pstr = TD::Convert(-hg->nodes_[i].cat_);
+    cats[i] = TD::Convert(pstr + kSEP + lexical_cast<string>(i)) * -1;
+  }
+
+  // construct the grammar
+  for (int i = 0; i < nedges; ++i) {
+    const Hypergraph::Edge& edge = hg->edges_[i];
+    const vector<WordID>& tgt = edge.rule_->e();
+    const vector<WordID>& src = edge.rule_->f();
+    TRulePtr rule(new TRule);
+    rule->prev_i = edge.i_;
+    rule->prev_j = edge.j_;
+    rule->lhs_ = cats[edge.head_node_];
+    vector<WordID>& f = rule->f_;
+    vector<WordID>& e = rule->e_;
+    f.resize(tgt.size());   // swap source and target, since the parser
+    e.resize(src.size());   // parses using the source side!
+    Hypergraph::TailNodeVector tn(edge.tail_nodes_.size());
+    int ntc = 0;
+    for (int j = 0; j < tgt.size(); ++j) {
+      const WordID& cur = tgt[j];
+      if (cur > 0) {
+        f[j] = cur;
+      } else {
+        tn[ntc++] = cur;
+        f[j] = cats[edge.tail_nodes_[-cur]];
+      }
+    }
+    ntc = 0;
+    for (int j = 0; j < src.size(); ++j) {
+      const WordID& cur = src[j];
+      if (cur > 0) {
+        e[j] = cur;
+      } else {
+        e[j] = tn[ntc++];
+      }
+    }
+    rule->scores_ = edge.feature_values_;
+    rule->parent_rule_ = edge.rule_;
+    rule->ComputeArity();
+    //cerr << "ADD: " << rule->AsString() << endl;
+
+    g->AddRule(rule);
+  }
+  g->SetMaxSpan(target.size() + 1);
+  const string& new_goal = TD::Convert(cats.back() * -1);
+  vector<GrammarPtr> grammars(1, gp);
+  Hypergraph tforest;
+  ExhaustiveBottomUpParser parser(new_goal, grammars);
+  if (!parser.Parse(target, &tforest))
+    return false;
+  else
+    hg->swap(tforest);
+  return true;
+}
```
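For readers skimming the diff: `RuleFilter` indexes every contiguous phrase of the target (up to `max_phrase_size` words) and marks a rule prunable if any contiguous run of terminals on its target side is not among those phrases, since such a rule can never participate in a derivation of the target. Below is a minimal, self-contained sketch of that phrase-filtering idea using plain `std::string` tokens, with an empty string standing in for cdec's non-positive-`WordID` non-terminals; the `PhraseFilter` name and the example sentences are hypothetical, not part of the commit:

```cpp
#include <algorithm>
#include <iostream>
#include <set>
#include <string>
#include <vector>

using namespace std;

// Index every contiguous phrase of the target (up to max_phrase_size words).
// A rule is prunable (operator() returns true) if some contiguous run of
// terminals on its target side is not among the indexed phrases.
struct PhraseFilter {
  set<vector<string> > exists_;
  PhraseFilter(const vector<string>& target, size_t max_phrase_size) {
    for (size_t i = 0; i < target.size(); ++i) {
      vector<string> phrase;
      const size_t lim = min(target.size(), i + max_phrase_size);
      for (size_t j = i; j < lim; ++j) {
        phrase.push_back(target[j]);
        exists_.insert(phrase);
      }
    }
  }
  // An empty string marks a non-terminal gap, mirroring e[j] <= 0 above.
  bool operator()(const vector<string>& rule_target_side) const {
    const vector<string>& e = rule_target_side;
    for (size_t i = 0; i < e.size(); ++i) {
      if (e[i].empty()) continue;
      vector<string> phrase;
      for (size_t j = i; j < e.size() && !e[j].empty(); ++j) {
        phrase.push_back(e[j]);
        if (exists_.count(phrase) == 0) return true;  // unseen phrase: prune
      }
    }
    return false;
  }
};

int main() {
  vector<string> target;
  target.push_back("the"); target.push_back("cat"); target.push_back("sleeps");
  PhraseFilter filter(target, 4);

  vector<string> keep;   // "the cat" occurs in the target
  keep.push_back("the"); keep.push_back("cat");
  vector<string> prune;  // "the dog" does not
  prune.push_back("the"); prune.push_back("dog");

  cout << boolalpha << filter(keep) << " " << filter(prune) << endl;  // false true
  return 0;
}
```

Lattice inputs with more than one arc per position defeat this indexing, which is why the real `RuleFilter` sets `true_lattice` and declines to filter anything in that case.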

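And a hedged sketch of how `HG::Intersect` might be called, e.g. to constrain a decoder's translation forest to a reference sentence during discriminative training. It assumes cdec's `LatticeTools::ConvertTextToLattice` helper from `lattice.h` and a `forest` produced elsewhere by the decoder; treat the exact calls as illustrative rather than as the API this commit pins down:

```cpp
#include <iostream>

#include "hg.h"            // Hypergraph
#include "hg_intersect.h"  // HG::Intersect (added by this commit)
#include "lattice.h"       // Lattice, LatticeTools (assumed helper)

// Illustrative only: 'forest' comes from the decoder; the reference
// sentence here is made up.
void ConstrainToReference(Hypergraph* forest) {
  Lattice ref;
  LatticeTools::ConvertTextToLattice("el gato duerme", &ref);
  if (HG::Intersect(ref, forest)) {
    // 'forest' now holds only derivations whose target-side yield is the
    // reference sentence; the original rules and feature values remain
    // reachable through each rule's parent_rule_ and scores_.
    std::cerr << "intersection kept " << forest->edges_.size() << " edges\n";
  } else {
    std::cerr << "reference is unreachable in this forest\n";
  }
}
```

The swap of `f_` and `e_` when the grammar is built is the key trick: `ExhaustiveBottomUpParser` parses the source side of its rules, so installing each edge's target side as the new source side makes parsing the target sentence equivalent to intersecting it with the forest, and `parent_rule_` and `scores_` carry the original rule and feature values through the construction.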