diff options
Diffstat (limited to 'decoder/tree2string_translator.cc')
-rw-r--r-- | decoder/tree2string_translator.cc | 120 |
1 files changed, 106 insertions, 14 deletions
diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index 1c249836..cd6ee550 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -1,5 +1,6 @@ #include <algorithm> #include <vector> +#include <queue> #include <boost/functional/hash.hpp> #include <unordered_map> #include "tree_fragment.h" @@ -15,11 +16,10 @@ using namespace std; struct Tree2StringGrammarNode { map<unsigned, Tree2StringGrammarNode> next; - string rules; + vector<TRulePtr> rules; }; -void ReadTree2StringGrammar(istream* in, unordered_map<unsigned, Tree2StringGrammarNode>* proots) { - unordered_map<unsigned, Tree2StringGrammarNode>& roots = *proots; +void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root) { string line; while(getline(*in, line)) { size_t pos = line.find("|||"); @@ -28,32 +28,124 @@ void ReadTree2StringGrammar(istream* in, unordered_map<unsigned, Tree2StringGram unsigned xc = 0; while (line[pos - 1] == ' ') { --pos; xc++; } cdec::TreeFragment rule_src(line.substr(0, pos), true); - Tree2StringGrammarNode* cur = &roots[rule_src.root]; - for (auto sym : rule_src) + Tree2StringGrammarNode* cur = root; + ostringstream os; + int lhs = -(rule_src.root & cdec::ALL_MASK); + // build source RHS for SCFG projection + // TODO - this is buggy - it will generate a well-formed SCFG rule + // but it will not generate source strings correctly + vector<int> frhs; + for (auto sym : rule_src) { cur = &cur->next[sym]; + if (sym) { + if (cdec::IsFrontier(sym)) { // frontier symbols -> variables + int nt = (sym & cdec::ALL_MASK); + frhs.push_back(-nt); + } else if (cdec::IsTerminal(sym)) { + frhs.push_back(sym); + } + } + } + os << '[' << TD::Convert(-lhs) << "] |||"; + for (auto x : frhs) { + os << ' '; + if (x < 0) + os << '[' << TD::Convert(-x) << ']'; + else + os << TD::Convert(x); + } pos += 3 + xc; while(line[pos] == ' ') { ++pos; } - size_t pos2 = line.find("|||", pos); - assert(pos2 != string::npos); - while (line[pos2 - 1] == ' ') { --pos2; } - cur->rules = line.substr(pos, pos2 - pos); - cerr << "OUTPUT = '" << cur->rules << "'\n"; + os << " ||| " << line.substr(pos); + TRulePtr rule(new TRule(os.str())); + cur->rules.push_back(rule); } } +struct ParserState { + ParserState() : in_iter(), node() {} + cdec::TreeFragment::iterator in_iter; + ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const int rt) : + in_iter(it), + root_type(rt), + node(n) {} + ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const ParserState& p) : + in_iter(it), + future_work(p.future_work), + root_type(p.root_type), + node(n) {} + vector<ParserState> future_work; + int root_type; // lhs of top level NT + Tree2StringGrammarNode* node; +}; + struct Tree2StringTranslatorImpl { - unordered_map<unsigned, Tree2StringGrammarNode> roots; // root['S'] gives rule network for S rules + Tree2StringGrammarNode root; Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) { ReadFile rf(conf["grammar"].as<vector<string>>()[0]); - ReadTree2StringGrammar(rf.stream(), &roots); + ReadTree2StringGrammar(rf.stream(), &root); } bool Translate(const string& input, SentenceMetadata* smeta, const vector<double>& weights, Hypergraph* minus_lm_forest) { cdec::TreeFragment input_tree(input, false); - cerr << "Tree2StringTranslatorImpl: please implement this!\n"; - return false; + const int kS = -TD::Convert("S"); + Hypergraph hg; + queue<ParserState> q; + q.push(ParserState(input_tree.begin(), &root, kS)); + while(!q.empty()) { + ParserState& s = q.front(); + + if (s.in_iter.at_end()) { // completed a traversal of a subtree + cerr << "I traversed a subtree of the input...\n"; + if (s.node->rules.size()) { + // TODO: build hypergraph + for (auto& r : s.node->rules) + cerr << "I can build: " << r->AsString() << endl; + for (auto& w : s.future_work) + q.push(w); + } else { + cerr << "I can't build anything :(\n"; + } + } else { // more input tree to match + unsigned sym = *s.in_iter; + if (cdec::IsLHS(sym)) { + auto nit = s.node->next.find(sym); + if (nit != s.node->next.end()) { + //cerr << "MATCHED LHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + q.push(ParserState(++s.in_iter, &nit->second, s)); + } + } else if (cdec::IsRHS(sym)) { + //cerr << "Attempting to match RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + cdec::TreeFragment::iterator var = s.in_iter; + var.truncate(); + auto nit1 = s.node->next.find(sym); + auto nit2 = s.node->next.find(*var); + if (nit2 != s.node->next.end()) { + //cerr << "MATCHED VAR RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + ParserState new_s(++var, &nit2->second, s); + ParserState new_work(s.in_iter.remainder(), &root, -(sym & cdec::ALL_MASK)); + new_s.future_work.push_back(new_work); // if this traversal of the input succeeds, future_work goes on the q + q.push(new_s); + } + if (nit1 != s.node->next.end()) { + //cerr << "MATCHED FULL RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + q.push(ParserState(++s.in_iter, &nit1->second, s)); + } + } else if (cdec::IsTerminal(sym)) { + auto nit = s.node->next.find(sym); + if (nit != s.node->next.end()) { + //cerr << "MATCHED TERMINAL: " << TD::Convert(sym) << endl; + q.push(ParserState(++s.in_iter, &nit->second, s)); + } + } else { + cerr << "This can never happen!\n"; abort(); + } + } + q.pop(); + } + minus_lm_forest->swap(hg); } }; |