diff options
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/tree2string_translator.cc | 47 | 
1 files changed, 47 insertions, 0 deletions
| diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 6f65658e..4cd584fb 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -101,6 +101,7 @@ namespace std {  struct Tree2StringTranslatorImpl {    vector<boost::shared_ptr<Tree2StringGrammarNode>> root;    bool add_pass_through_rules; +  unsigned remove_grammars;    Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) :        add_pass_through_rules(conf.count("add_pass_through_rules")) {      if (conf.count("grammar")) { @@ -114,11 +115,56 @@ struct Tree2StringTranslatorImpl {        }      }    } + +  void CreatePassThroughRules(const cdec::TreeFragment& tree) { +    static const int kFID = FD::Convert("PassThrough"); +    root.resize(root.size() + 1); +    root.back().reset(new Tree2StringGrammarNode); +    ++remove_grammars; +    for (auto& prod : tree.nodes) { +      ostringstream os; +      vector<int> rhse, rhsf; +      int ntc = 0; +      int lhs = -(prod.lhs & cdec::ALL_MASK); +      os << '(' << TD::Convert(-lhs); +      for (auto& sym : prod.rhs) { +        os << ' '; +        if (cdec::IsTerminal(sym)) { +          os << TD::Convert(sym); +          rhse.push_back(sym); +          rhsf.push_back(sym); +        } else { +          unsigned id = tree.nodes[sym & cdec::ALL_MASK].lhs & cdec::ALL_MASK; +          os << '[' << TD::Convert(id) << ']'; +          rhsf.push_back(-id); +          rhse.push_back(-ntc); +          ++ntc; +        } +      } +      os << ')'; +      cdec::TreeFragment rule_src(os.str(), true); +      Tree2StringGrammarNode* cur = root.back().get(); +      for (auto sym : rule_src) +        cur = &cur->next[sym]; +      TRulePtr rule(new TRule(rhse, rhsf, lhs)); +      rule->ComputeArity(); +      rule->scores_.set_value(kFID, 1.0); +      cur->rules.push_back(rule); +    } +  } + +  void RemoveGrammars() { +    assert(remove_grammars < root.size()); +    root.resize(root.size() - remove_grammars); +  } +    bool Translate(const string& input,                   SentenceMetadata* smeta,                   const vector<double>& weights,                   Hypergraph* minus_lm_forest) { +    remove_grammars = 0;      cdec::TreeFragment input_tree(input, false); +    if (add_pass_through_rules) CreatePassThroughRules(input_tree);      Hypergraph hg;      hg.ReserveNodes(input_tree.nodes.size());      vector<int> tree2hg(input_tree.nodes.size() + 1, -1); @@ -235,6 +281,7 @@ void Tree2StringTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv  }  void Tree2StringTranslator::SentenceCompleteImpl() { +  pimpl_->RemoveGrammars();  }  std::string Tree2StringTranslator::GetDecoderType() const { | 
