From 10427eab298e019bffed15e71125d34b2e27f468 Mon Sep 17 00:00:00 2001
From: Chris Dyer
Date: Tue, 1 Apr 2014 21:52:15 -0400
Subject: deal with pass through rules

---
 decoder/tree2string_translator.cc | 64 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 64 insertions(+)

diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc
index 6f65658e..4cd584fb 100644
--- a/decoder/tree2string_translator.cc
+++ b/decoder/tree2string_translator.cc
@@ -101,6 +101,7 @@ namespace std {
 struct Tree2StringTranslatorImpl {
   vector> root;
   bool add_pass_through_rules;
+  unsigned remove_grammars = 0;  // entries appended to root since Translate() last reset it
   Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) :
       add_pass_through_rules(conf.count("add_pass_through_rules")) {
     if (conf.count("grammar")) {
@@ -114,11 +115,73 @@ struct Tree2StringTranslatorImpl {
       }
     }
   }
+
+  // Appends one extra grammar to root and fills it with one rule per node
+  // of `tree`. Bumps remove_grammars so RemoveGrammars() can later drop the
+  // grammar again.
+  //
+  // Each rule is filed under the bracketed string "(LHS child ...)" built
+  // from its node: a terminal child is written as its word, a nonterminal
+  // child as "[LABEL]" using the label of the child node.
+  // Every rule gets the "PassThrough" feature set to 1.0.
+  void CreatePassThroughRules(const cdec::TreeFragment& tree) {
+    static const int kFID = FD::Convert("PassThrough");
+    root.resize(root.size() + 1);
+    root.back().reset(new Tree2StringGrammarNode);
+    ++remove_grammars;
+    for (auto& prod : tree.nodes) {
+      ostringstream os;
+      vector<int> rhse, rhsf;
+      int ntc = 0;  // nonterminal children seen so far in this production
+      int lhs = -(prod.lhs & cdec::ALL_MASK);  // masked label, negated
+      os << '(' << TD::Convert(-lhs);
+      for (auto& sym : prod.rhs) {
+        os << ' ';
+        if (cdec::IsTerminal(sym)) {
+          // Terminals are copied unchanged to both rule sides.
+          os << TD::Convert(sym);
+          rhse.push_back(sym);
+          rhsf.push_back(sym);
+        } else {
+          // The masked symbol indexes the child node; use that node's label.
+          unsigned id = tree.nodes[sym & cdec::ALL_MASK].lhs & cdec::ALL_MASK;
+          os << '[' << TD::Convert(id) << ']';
+          rhsf.push_back(-id);
+          rhse.push_back(-ntc);  // k-th nonterminal (0-based) is pushed as -k
+          ++ntc;
+        }
+      }
+      os << ')';
+      cdec::TreeFragment rule_src(os.str(), true);
+      // Walk the new grammar's `next` links along the symbols of rule_src;
+      // the node reached holds this rule.
+      Tree2StringGrammarNode* cur = root.back().get();
+      for (auto sym : rule_src)
+        cur = &cur->next[sym];
+      TRulePtr rule(new TRule(rhse, rhsf, lhs));
+      rule->ComputeArity();
+      rule->scores_.set_value(kFID, 1.0);
+      cur->rules.push_back(rule);
+    }
+  }
+
+  // Drops the last remove_grammars entries of root and resets the counter,
+  // so a second call before the next Translate() removes nothing.
+  void RemoveGrammars() {
+    // <= rather than <: removing every entry is legal (root may hold only
+    // the per-sentence grammar, or be empty with a zero counter).
+    assert(remove_grammars <= root.size());
+    root.resize(root.size() - remove_grammars);
+    remove_grammars = 0;
+  }
+
   bool Translate(const string& input,
                  SentenceMetadata* smeta,
                  const vector& weights,
                  Hypergraph* minus_lm_forest) {
+    remove_grammars = 0;  // nothing has been appended to root for this input yet
     cdec::TreeFragment input_tree(input, false);
+    if (add_pass_through_rules) CreatePassThroughRules(input_tree);  // counted in remove_grammars
     Hypergraph hg;
     hg.ReserveNodes(input_tree.nodes.size());
     vector tree2hg(input_tree.nodes.size() + 1, -1);
@@ -235,6 +298,7 @@ void Tree2StringTranslator::ProcessMarkupHintsImpl(const map& kv
 }
 
 void Tree2StringTranslator::SentenceCompleteImpl() {
+  pimpl_->RemoveGrammars();  // drop whatever Translate() appended to root
 }
 
 std::string Tree2StringTranslator::GetDecoderType() const {
-- 
cgit v1.2.3