diff options
-rw-r--r-- | decoder/tree2string_translator.cc | 47 |
1 files changed, 47 insertions, 0 deletions
diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 6f65658e..4cd584fb 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -101,6 +101,7 @@ namespace std { struct Tree2StringTranslatorImpl { vector<boost::shared_ptr<Tree2StringGrammarNode>> root; bool add_pass_through_rules; + unsigned remove_grammars; Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) : add_pass_through_rules(conf.count("add_pass_through_rules")) { if (conf.count("grammar")) { @@ -114,11 +115,56 @@ struct Tree2StringTranslatorImpl { } } } + + void CreatePassThroughRules(const cdec::TreeFragment& tree) { + static const int kFID = FD::Convert("PassThrough"); + root.resize(root.size() + 1); + root.back().reset(new Tree2StringGrammarNode); + ++remove_grammars; + for (auto& prod : tree.nodes) { + ostringstream os; + vector<int> rhse, rhsf; + int ntc = 0; + int lhs = -(prod.lhs & cdec::ALL_MASK); + os << '(' << TD::Convert(-lhs); + for (auto& sym : prod.rhs) { + os << ' '; + if (cdec::IsTerminal(sym)) { + os << TD::Convert(sym); + rhse.push_back(sym); + rhsf.push_back(sym); + } else { + unsigned id = tree.nodes[sym & cdec::ALL_MASK].lhs & cdec::ALL_MASK; + os << '[' << TD::Convert(id) << ']'; + rhsf.push_back(-id); + rhse.push_back(-ntc); + ++ntc; + } + } + os << ')'; + cdec::TreeFragment rule_src(os.str(), true); + Tree2StringGrammarNode* cur = root.back().get(); + for (auto sym : rule_src) + cur = &cur->next[sym]; + TRulePtr rule(new TRule(rhse, rhsf, lhs)); + rule->ComputeArity(); + rule->scores_.set_value(kFID, 1.0); + cur->rules.push_back(rule); + } + } + + void RemoveGrammars() { + assert(remove_grammars < root.size()); + root.resize(root.size() - remove_grammars); + } + bool Translate(const string& input, SentenceMetadata* smeta, const vector<double>& weights, Hypergraph* minus_lm_forest) { + remove_grammars = 0; cdec::TreeFragment input_tree(input, false); + if (add_pass_through_rules) CreatePassThroughRules(input_tree); Hypergraph hg; hg.ReserveNodes(input_tree.nodes.size()); vector<int> tree2hg(input_tree.nodes.size() + 1, -1); @@ -235,6 +281,7 @@ void Tree2StringTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv } void Tree2StringTranslator::SentenceCompleteImpl() { + pimpl_->RemoveGrammars(); } std::string Tree2StringTranslator::GetDecoderType() const { |