summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--decoder/tree2string_translator.cc47
1 files changed, 47 insertions, 0 deletions
diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc
index 6f65658e..4cd584fb 100644
--- a/decoder/tree2string_translator.cc
+++ b/decoder/tree2string_translator.cc
@@ -101,6 +101,7 @@ namespace std {
struct Tree2StringTranslatorImpl {
vector<boost::shared_ptr<Tree2StringGrammarNode>> root;
bool add_pass_through_rules;
+ unsigned remove_grammars;
Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) :
add_pass_through_rules(conf.count("add_pass_through_rules")) {
if (conf.count("grammar")) {
@@ -114,11 +115,56 @@ struct Tree2StringTranslatorImpl {
}
}
}
+
+ void CreatePassThroughRules(const cdec::TreeFragment& tree) {
+ static const int kFID = FD::Convert("PassThrough");
+ root.resize(root.size() + 1);
+ root.back().reset(new Tree2StringGrammarNode);
+ ++remove_grammars;
+ for (auto& prod : tree.nodes) {
+ ostringstream os;
+ vector<int> rhse, rhsf;
+ int ntc = 0;
+ int lhs = -(prod.lhs & cdec::ALL_MASK);
+ os << '(' << TD::Convert(-lhs);
+ for (auto& sym : prod.rhs) {
+ os << ' ';
+ if (cdec::IsTerminal(sym)) {
+ os << TD::Convert(sym);
+ rhse.push_back(sym);
+ rhsf.push_back(sym);
+ } else {
+ unsigned id = tree.nodes[sym & cdec::ALL_MASK].lhs & cdec::ALL_MASK;
+ os << '[' << TD::Convert(id) << ']';
+ rhsf.push_back(-id);
+ rhse.push_back(-ntc);
+ ++ntc;
+ }
+ }
+ os << ')';
+ cdec::TreeFragment rule_src(os.str(), true);
+ Tree2StringGrammarNode* cur = root.back().get();
+ for (auto sym : rule_src)
+ cur = &cur->next[sym];
+ TRulePtr rule(new TRule(rhse, rhsf, lhs));
+ rule->ComputeArity();
+ rule->scores_.set_value(kFID, 1.0);
+ cur->rules.push_back(rule);
+ }
+ }
+
+ void RemoveGrammars() {
+ assert(remove_grammars < root.size());
+ root.resize(root.size() - remove_grammars);
+ }
+
bool Translate(const string& input,
SentenceMetadata* smeta,
const vector<double>& weights,
Hypergraph* minus_lm_forest) {
+ remove_grammars = 0;
cdec::TreeFragment input_tree(input, false);
+ if (add_pass_through_rules) CreatePassThroughRules(input_tree);
Hypergraph hg;
hg.ReserveNodes(input_tree.nodes.size());
vector<int> tree2hg(input_tree.nodes.size() + 1, -1);
@@ -235,6 +281,7 @@ void Tree2StringTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv
}
void Tree2StringTranslator::SentenceCompleteImpl() {
+ pimpl_->RemoveGrammars();
}
std::string Tree2StringTranslator::GetDecoderType() const {