diff options
Diffstat (limited to 'decoder/tree2string_translator.cc')
-rw-r--r-- | decoder/tree2string_translator.cc | 406 |
1 files changed, 371 insertions, 35 deletions
diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index ac9c0d74..b5b47d5d 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -1,7 +1,11 @@ #include <algorithm> #include <vector> +#include <queue> +#include <map> +#include <unordered_set> +#include <boost/shared_ptr.hpp> #include <boost/functional/hash.hpp> -#include <unordered_map> +#include "fast_lexical_cast.hpp" #include "tree_fragment.h" #include "translator.h" #include "hg.h" @@ -13,60 +17,394 @@ using namespace std; -// root: S -// A implication: (S [A] *INCOMPLETE* -// B implication: (S [A] [B] *INCOMPLETE* -// *0* implication: (S _[A] [B]) -// a implication: (S (A a *INCOMPLETE* [B]) -// a implication: (S (A a a *INCOMPLETE* [B]) -// *0* implication: (S (A a a) _[B]) -// D implication: (S (A a a) (B [D] *INCOMPLETE*) -// *0* implication: (S (A a a) (B _[D])) -// d implication: (S (A a a) (B (D d *INCOMPLETE*)) -// *0* implication: (S (A a a) (B (D d))) -// --there are no further outgoing links possible-- - -// root: S -// A implication: (S [A] *INCOMPLETE* -// B implication: (S [A] [B] *INCOMPLETE* -// *0* implication: (S _[A] [B]) -// *0* implication: (S [A] _[B]) -// b implication: (S [A] (B b *INCOMPLETE*)) struct Tree2StringGrammarNode { map<unsigned, Tree2StringGrammarNode> next; - string rules; + vector<TRulePtr> rules; }; -void ReadTree2StringGrammar(istream* in, unordered_map<unsigned, Tree2StringGrammarNode>* proots) { - unordered_map<unsigned, Tree2StringGrammarNode>& roots = *proots; +// this needs to be rewritten so it is fast and checks errors well +// use a lexer probably +static void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bool has_multiple_states) { string line; while(getline(*in, line)) { size_t pos = line.find("|||"); assert(pos != string::npos); assert(pos > 3); - if (line[pos - 1] == ' ') --pos; + unsigned xc = 0; + while (line[pos - 1] == ' ') { --pos; xc++; } cdec::TreeFragment rule_src(line.substr(0, pos), true); + // TODO transducer_state should (optionally?) be read from input + const unsigned transducer_state = 0; + Tree2StringGrammarNode* cur = &root->next[transducer_state]; + ostringstream os; + int lhs = -(rule_src.root & cdec::ALL_MASK); + // build source RHS for SCFG projection + vector<int> frhs; + // we traverse the rule_src in left to right, DFS order + for (auto sym : rule_src) { + //cerr << TD::Convert(sym & cdec::ALL_MASK) << endl; + cur = &cur->next[sym]; + if (cdec::IsFrontier(sym)) { // frontier symbols -> variables + int nt = (sym & cdec::ALL_MASK); + frhs.push_back(-nt); + } else if (cdec::IsTerminal(sym)) { + frhs.push_back(sym); + } // else internal NT, nothing to do + } + os << '[' << TD::Convert(-lhs) << "] |||"; + for (auto x : frhs) { + os << ' '; + if (x < 0) + os << '[' << TD::Convert(-x) << ']'; + else + os << TD::Convert(x); + } + pos += 3 + xc; + while(line[pos] == ' ') { ++pos; } + os << " ||| " << line.substr(pos); + TRulePtr rule(new TRule(os.str())); + // TODO the transducer_state you end up in after using this rule (for each NT) + // needs to be read and encoded somehow in the rule (for use XXX) + cur->rules.push_back(rule); + //cerr << "RULE: " << rule->AsString() << "\n\n"; } } +// represents where in an input parse tree the transducer must continue +// and what state it is in +struct TransducerState { + TransducerState() : input_node_idx(), transducer_state() {} + TransducerState(unsigned n, unsigned q) : input_node_idx(n), transducer_state(q) {} + bool operator==(const TransducerState& o) const { + return input_node_idx == o.input_node_idx && + transducer_state == o.transducer_state; + } + unsigned input_node_idx; + unsigned transducer_state; +}; + +// represents the state of the composition algorithm +struct ParserState { + ParserState() : in_iter(), node() {} + cdec::TreeFragment::iterator in_iter; + ParserState(const cdec::TreeFragment::iterator& it, unsigned q, Tree2StringGrammarNode* n) : + in_iter(it), + task(it.node_idx(), q), + node(n) {} + ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const ParserState& p) : + in_iter(it), + future_work(p.future_work), + task(p.task), + node(n) {} + bool operator==(const ParserState& o) const { + return node == o.node && task == o.task && + future_work == o.future_work && in_iter == o.in_iter; + } + vector<TransducerState> future_work; + TransducerState task; // subtree root where and in what state did the transducer start? + Tree2StringGrammarNode* node; // pointer into grammar trie +}; + +namespace std { + template<> + struct hash<TransducerState> { + size_t operator()(const TransducerState& q) const { + size_t h = boost::hash_value(q.transducer_state); + boost::hash_combine(h, boost::hash_value(q.input_node_idx)); + return h; + } + }; + template<> + struct hash<ParserState> { + size_t operator()(const ParserState& s) const { + size_t h = boost::hash_value(s.node); + for (auto& w : s.future_work) + boost::hash_combine(h, hash<TransducerState>()(w)); + boost::hash_combine(h, hash<TransducerState>()(s.task)); + // TODO hash with iterator + return h; + } + }; +}; + +void AddDummyGoalNode(Hypergraph* hg) { + static const int kGOAL = -TD::Convert("Goal"); + unsigned old_goal_node_idx = hg->nodes_.size() - 1; + int old_goal_cat = hg->nodes_[old_goal_node_idx].cat_; + TRulePtr goal_rule(new TRule("[Goal] ||| [X] ||| [1]")); + goal_rule->f_[0] = old_goal_cat; + HG::Node* goal_node = hg->AddNode(kGOAL); + goal_node->node_hash = 1; + TailNodeVector tail(1, old_goal_node_idx); + HG::Edge* new_edge = hg->AddEdge(goal_rule, tail); + hg->ConnectEdgeToHeadNode(new_edge, goal_node); +} + struct Tree2StringTranslatorImpl { - unordered_map<unsigned, Tree2StringGrammarNode> roots; // root['S'] gives rule network for S rules - Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) { - ReadFile rf(conf["grammar"].as<vector<string>>()[0]); - ReadTree2StringGrammar(rf.stream(), &roots); + vector<boost::shared_ptr<Tree2StringGrammarNode>> root; + bool add_pass_through_rules; + bool has_multiple_states; + unsigned remove_grammars; + Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf, + bool has_multiple_states) : + add_pass_through_rules(conf.count("add_pass_through_rules")), + has_multiple_states(has_multiple_states) { + if (conf.count("grammar")) { + const vector<string> gf = conf["grammar"].as<vector<string>>(); + root.resize(gf.size()); + unsigned gc = 0; + for (auto& f : gf) { + ReadFile rf(f); + root[gc].reset(new Tree2StringGrammarNode); + ReadTree2StringGrammar(rf.stream(), &*root[gc++], has_multiple_states); + } + } + } + + // loads a per-sentence grammar + void LoadSupplementalGrammar(const string& gfile) { + root.resize(root.size() + 1); + root.back().reset(new Tree2StringGrammarNode); + ++remove_grammars; + ReadFile rf(gfile); + ReadTree2StringGrammar(rf.stream(), root.back().get(), has_multiple_states); + } + + void CreatePassThroughRules(const cdec::TreeFragment& tree) { + static const int kFIDlex = FD::Convert("PassThrough_Lexical"); + static const int kFIDabs = FD::Convert("PassThrough_Abstract"); + static const int kFIDmix = FD::Convert("PassThrough_Mix"); + static const int kFID = FD::Convert("PassThrough"); + static unordered_map<int, int> pntfid; + root.resize(root.size() + 1); + root.back().reset(new Tree2StringGrammarNode); + ++remove_grammars; + unordered_set<vector<int>,boost::hash<vector<int>>> unique_rule_check; + for (auto& prod : tree.nodes) { + int ntc = 0; + int lhs = -(prod.lhs & cdec::ALL_MASK); + int &ntfid = pntfid[lhs]; + if (!ntfid) { + ostringstream fos; + fos << "PassThrough:" << TD::Convert(-lhs); + ntfid = FD::Convert(fos.str()); + } + + // check for duplicate rule in tree + vector<int> key; + key.push_back(prod.lhs); + + bool has_lex = false; + bool has_nt = false; + vector<int> rhse, rhsf; + ostringstream os; + os << '(' << TD::Convert(-lhs); + for (auto& sym : prod.rhs) { + os << ' '; + if (cdec::IsTerminal(sym)) { + has_lex = true; + os << TD::Convert(sym); + rhse.push_back(sym); + rhsf.push_back(sym); + key.push_back(sym); + } else { + has_nt = true; + unsigned id = tree.nodes[sym & cdec::ALL_MASK].lhs & cdec::ALL_MASK; + os << '[' << TD::Convert(id) << ']'; + rhsf.push_back(-id); + rhse.push_back(-ntc); + key.push_back(-id); + ++ntc; + } + } + os << ')'; + if (!unique_rule_check.insert(key).second) continue; + cdec::TreeFragment rule_src(os.str(), true); + Tree2StringGrammarNode* cur = root.back().get(); + // do we need all transducer states here??? a list??? no pass through rules??? + unsigned transducer_state = 0; + cur = &cur->next[transducer_state]; + for (auto sym : rule_src) + cur = &cur->next[sym]; + TRulePtr rule(new TRule(rhse, rhsf, lhs)); + rule->ComputeArity(); + rule->scores_.set_value(ntfid, 1.0); + rule->scores_.set_value(kFID, 1.0); + if (has_lex && has_nt) + rule->scores_.set_value(kFIDmix, 1.0); + else if (has_lex) rule->scores_.set_value(kFIDlex, 1.0); + else if (has_nt) rule->scores_.set_value(kFIDabs, 1.0); + cur->rules.push_back(rule); + } } + + void RemoveGrammars() { + assert(remove_grammars <= root.size()); + root.resize(root.size() - remove_grammars); + } + bool Translate(const string& input, SentenceMetadata* smeta, const vector<double>& weights, Hypergraph* minus_lm_forest) { cdec::TreeFragment input_tree(input, false); - cerr << "Tree2StringTranslatorImpl: please implement this!\n"; - return false; + if (add_pass_through_rules) CreatePassThroughRules(input_tree); + Hypergraph hg; + hg.ReserveNodes(input_tree.nodes.size()); + unordered_map<TransducerState, unsigned> x2hg(input_tree.nodes.size() * 5); + queue<ParserState> q; + unordered_set<ParserState> unique; // only create items one time + for (auto& g : root) { + unsigned q_0 = 0; // TODO initialize q_0 properly once multi-state transducers are supported + auto rit = g->next.find(q_0); + if (rit != g->next.end()) { // does this g have this transducer state? + q.push(ParserState(input_tree.begin(), q_0, &rit->second)); + unique.insert(q.back()); + } + } + if (q.size() == 0) return false; + const TransducerState tree_top = q.front().task; + while(!q.empty()) { + ParserState& s = q.front(); + + if (s.in_iter.at_end()) { // completed a traversal of a subtree + //cerr << "I traversed a subtree of the input rooted at node=" << s.input_node_idx << " sym=" << + // TD::Convert(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK) << endl; + if (s.node->rules.size()) { + auto it = x2hg.find(s.task); + if (it == x2hg.end()) { + // TODO create composite state symbol that encodes transducer state type? + HG::Node* new_node = hg.AddNode(-(input_tree.nodes[s.task.input_node_idx].lhs & cdec::ALL_MASK)); + new_node->node_hash = std::hash<TransducerState>()(s.task); + it = x2hg.insert(make_pair(s.task, new_node->id_)).first; + } + const unsigned node_id = it->second; + TailNodeVector tail; + for (const auto& n : s.future_work) { + auto it = x2hg.find(n); + if (it == x2hg.end()) { + // TODO create composite state symbol that encodes transducer state type? + HG::Node* new_node = hg.AddNode(-(input_tree.nodes[n.input_node_idx].lhs & cdec::ALL_MASK)); + new_node->node_hash = std::hash<TransducerState>()(n); + it = x2hg.insert(make_pair(n, new_node->id_)).first; + } + tail.push_back(it->second); + } + for (auto& r : s.node->rules) { + assert(tail.size() == r->Arity()); + HG::Edge* new_edge = hg.AddEdge(r, tail); + new_edge->feature_values_ = r->GetFeatureValues(); + // TODO: set i and j + hg.ConnectEdgeToHeadNode(new_edge, &hg.nodes_[node_id]); + } + for (const auto& n : s.future_work) { + const auto it = input_tree.begin(n.input_node_idx); // start tree iterator at node n + for (auto& g : root) { + auto rit = g->next.find(n.transducer_state); + if (rit != g->next.end()) { // does this g have this transducer state? + const ParserState s(it, n.transducer_state, &rit->second); + if (unique.insert(s).second) q.push(s); + } + } + } + } else { + //cerr << "I can't build anything :(\n"; + } + } else { // more input tree to match + unsigned sym = *s.in_iter; + if (cdec::IsLHS(sym)) { + auto nit = s.node->next.find(sym); + if (nit != s.node->next.end()) { + //cerr << "MATCHED LHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + ParserState news(++s.in_iter, &nit->second, s); + if (unique.insert(news).second) q.push(news); + } + } else if (cdec::IsRHS(sym)) { + //cerr << "Attempting to match RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + cdec::TreeFragment::iterator var = s.in_iter; + var.truncate(); + auto nit1 = s.node->next.find(sym); + auto nit2 = s.node->next.find(*var); + if (nit2 != s.node->next.end()) { + //cerr << "MATCHED VAR RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + ++var; + // TODO: find out from rule what the new target state is (the 0 in the next line) + // if it is associated with the rule, we won't know until we match the whole input + // so the 0 may be okay (if this is the case, which is probably the easiest thing, + // then the state must be dealt with when the future work becomes real work) + const TransducerState new_task(s.in_iter.child_node(), 0); + ParserState new_s(var, &nit2->second, s); + new_s.future_work.push_back(new_task); // if this traversal of the input succeeds, future_work goes on the q + if (unique.insert(new_s).second) q.push(new_s); + } + //else { cerr << "did not match [" << TD::Convert(sym & cdec::ALL_MASK) << "]\n"; } + if (nit1 != s.node->next.end()) { + //cerr << "MATCHED FULL RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + const ParserState new_s(++s.in_iter, &nit1->second, s); + if (unique.insert(new_s).second) q.push(new_s); + } + //else { cerr << "did not match " << TD::Convert(sym & cdec::ALL_MASK) << "\n"; } + } else if (cdec::IsTerminal(sym)) { + auto nit = s.node->next.find(sym); + if (nit != s.node->next.end()) { + //cerr << "MATCHED TERMINAL: " << TD::Convert(sym) << endl; + const ParserState new_s(++s.in_iter, &nit->second, s); + if (unique.insert(new_s).second) q.push(new_s); + } + } else { + cerr << "This can never happen!\n"; abort(); + } + } + q.pop(); + } + const auto goal_it = x2hg.find(tree_top); + if (goal_it == x2hg.end()) return false; + //cerr << "Goal node: " << goal << endl; + hg.TopologicallySortNodesAndEdges(goal_it->second); + + // there might be nodes that cannot be derived + // the following takes care of them + vector<bool> prune(hg.edges_.size(), false); + hg.PruneEdges(prune, true); + if (hg.edges_.size() == 0) return false; + // rescoring assumes the goal edge is arity 1 (code laziness), add that here + AddDummyGoalNode(&hg); + + hg.Reweight(weights); + //hg.PrintGraphviz(); + + minus_lm_forest->swap(hg); + return true; } }; -Tree2StringTranslator::Tree2StringTranslator(const boost::program_options::variables_map& conf) : - pimpl_(new Tree2StringTranslatorImpl(conf)) {} +Tree2StringTranslator::Tree2StringTranslator(const boost::program_options::variables_map& conf, + bool has_multiple_states) : + pimpl_(new Tree2StringTranslatorImpl(conf, has_multiple_states)) {} + +void Tree2StringTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv) { + pimpl_->remove_grammars = 0; + if (kv.find("grammar0") != kv.end()) { + cerr << "SGML tag grammar0 is not expected (order is: grammar, grammar1, grammar2, ...)\n"; + abort(); + } + unsigned gc = 0; + set<string> loaded; + while(true) { + string gkey = "grammar"; + if (gc > 0) gkey += boost::lexical_cast<string>(gc); + ++gc; + map<string,string>::const_iterator it = kv.find(gkey); + if (it == kv.end()) break; + const string& gfile = it->second; + if (loaded.count(gfile) == 1) { + cerr << "Attempting to load " << gfile << " twice!\n"; + abort(); + } + loaded.insert(gfile); + pimpl_->LoadSupplementalGrammar(gfile); + } +} bool Tree2StringTranslator::TranslateImpl(const string& input, SentenceMetadata* smeta, @@ -75,10 +413,8 @@ bool Tree2StringTranslator::TranslateImpl(const string& input, return pimpl_->Translate(input, smeta, weights, minus_lm_forest); } -void Tree2StringTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv) { -} - void Tree2StringTranslator::SentenceCompleteImpl() { + pimpl_->RemoveGrammars(); } std::string Tree2StringTranslator::GetDecoderType() const { |