#include <algorithm> #include <vector> #include <queue> #include <map> #include <unordered_map> #include <unordered_set> #include <boost/shared_ptr.hpp> #include <boost/functional/hash.hpp> #include "fast_lexical_cast.hpp" #include "tree_fragment.h" #include "translator.h" #include "hg.h" #include "sentence_metadata.h" #include "filelib.h" #include "stringlib.h" #include "tdict.h" #include "verbose.h" using namespace std; struct Tree2StringGrammarNode { map<unsigned, Tree2StringGrammarNode> next; vector<TRulePtr> rules; }; // this needs to be rewritten so it is fast and checks errors well // use a lexer probably static void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bool has_multiple_states) { string line; int lc = 0; while(getline(*in, line)) { ++lc; if (line.size() == 0 || line[0] == '#') continue; std::vector<StringPiece> fields = TokenizeMultisep(line, " ||| "); if (has_multiple_states && fields.size() < 4) { cerr << "Expected at least 4 fields in rule file but line " << lc << " is:\n" << line << endl; abort(); } if (!has_multiple_states && fields.size() < 3) { cerr << "Expected at least 3 fields in rule file but line " << lc << " is:\n" << line << endl; abort(); } cdec::TreeFragment rule_src(fields[has_multiple_states ? 1 : 0], true); // TODO transducer_state should be read from input const unsigned transducer_state = 0; Tree2StringGrammarNode* cur = &root->next[transducer_state]; ostringstream os; int lhs = -(rule_src.root & cdec::ALL_MASK); // build source RHS for SCFG projection vector<int> frhs; // we traverse the rule_src in left to right, DFS order for (auto sym : rule_src) { //cerr << TD::Convert(sym & cdec::ALL_MASK) << endl; cur = &cur->next[sym]; if (cdec::IsFrontier(sym)) { // frontier symbols -> variables int nt = (sym & cdec::ALL_MASK); frhs.push_back(-nt); } else if (cdec::IsTerminal(sym)) { frhs.push_back(sym); } // else internal NT, nothing to do } os << '[' << TD::Convert(-lhs) << "] |||"; for (auto x : frhs) { os << ' '; if (x < 0) os << '[' << TD::Convert(-x) << ']'; else os << TD::Convert(x); } TRulePtr rule; if (has_multiple_states) { cerr << "Not implemented...\n"; abort(); // TODO read in states } else { os << " ||| " << fields[1] << " ||| " << fields[2]; if (fields.size() > 3) os << " ||| " << fields[3]; rule.reset(new TRule(os.str())); } cur->rules.push_back(rule); //cerr << "RULE: " << rule->AsString() << "\n\n"; } } // represents where in an input parse tree the transducer must continue // and what state it is in struct TransducerState { TransducerState() : input_node_idx(), transducer_state() {} TransducerState(unsigned n, unsigned q) : input_node_idx(n), transducer_state(q) {} bool operator==(const TransducerState& o) const { return input_node_idx == o.input_node_idx && transducer_state == o.transducer_state; } unsigned input_node_idx; unsigned transducer_state; }; // represents the state of the composition algorithm struct ParserState { ParserState() : in_iter(), node() {} cdec::TreeFragment::iterator in_iter; ParserState(const cdec::TreeFragment::iterator& it, unsigned q, Tree2StringGrammarNode* n) : in_iter(it), task(it.node_idx(), q), node(n) {} ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const ParserState& p) : in_iter(it), future_work(p.future_work), task(p.task), node(n) {} bool operator==(const ParserState& o) const { return node == o.node && task == o.task && future_work == o.future_work && in_iter == o.in_iter; } vector<TransducerState> future_work; TransducerState task; // subtree root where and in what state did the transducer start? Tree2StringGrammarNode* node; // pointer into grammar trie }; namespace std { template<> struct hash<TransducerState> { size_t operator()(const TransducerState& q) const { size_t h = boost::hash_value(q.transducer_state); boost::hash_combine(h, boost::hash_value(q.input_node_idx)); return h; } }; template<> struct hash<ParserState> { size_t operator()(const ParserState& s) const { size_t h = boost::hash_value(s.node); for (auto& w : s.future_work) boost::hash_combine(h, hash<TransducerState>()(w)); boost::hash_combine(h, hash<TransducerState>()(s.task)); // TODO hash with iterator return h; } }; }; void AddDummyGoalNode(Hypergraph* hg) { static const int kGOAL = -TD::Convert("Goal"); unsigned old_goal_node_idx = hg->nodes_.size() - 1; int old_goal_cat = hg->nodes_[old_goal_node_idx].cat_; TRulePtr goal_rule(new TRule("[Goal] ||| [X] ||| [1]")); goal_rule->f_[0] = old_goal_cat; HG::Node* goal_node = hg->AddNode(kGOAL); goal_node->node_hash = 1; TailNodeVector tail(1, old_goal_node_idx); HG::Edge* new_edge = hg->AddEdge(goal_rule, tail); hg->ConnectEdgeToHeadNode(new_edge, goal_node); } struct Tree2StringTranslatorImpl { vector<boost::shared_ptr<Tree2StringGrammarNode>> root; bool add_pass_through_rules; bool has_multiple_states; unsigned remove_grammars; Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf, bool has_multiple_states) : add_pass_through_rules(conf.count("add_pass_through_rules")), has_multiple_states(has_multiple_states) { if (conf.count("grammar")) { const vector<string> gf = conf["grammar"].as<vector<string>>(); root.resize(gf.size()); unsigned gc = 0; for (auto& f : gf) { ReadFile rf(f); root[gc].reset(new Tree2StringGrammarNode); ReadTree2StringGrammar(rf.stream(), &*root[gc++], has_multiple_states); } } } // loads a per-sentence grammar void LoadSupplementalGrammar(const string& gfile) { root.resize(root.size() + 1); root.back().reset(new Tree2StringGrammarNode); ++remove_grammars; ReadFile rf(gfile); ReadTree2StringGrammar(rf.stream(), root.back().get(), has_multiple_states); } // src must be fully abstract bool DoesAbstractPassThroughRuleExist(unsigned state, const cdec::TreeFragment& src) const { unsigned len = root.size(); if (len <= 1) return false; --len; for (unsigned i = 0; i < len; ++i) { const Tree2StringGrammarNode* cur = &*root[i]; auto it = cur->next.find(state); if (it == cur->next.end()) continue; cur = &it->second; bool failed = false; vector<int> trg; for (auto sym : src) { it = cur->next.find(sym); if (it == cur->next.end()) { failed = true; break; } if (cdec::IsFrontier(sym)) trg.push_back(-trg.size()); cur = &it->second; } if (failed) continue; // TODO check for destination states in t2t for (auto r : cur->rules) if (r->e_ == trg) return true; } return false; } void CreatePassThroughRules(const cdec::TreeFragment& tree) { static const int kFIDlex = FD::Convert("PassThrough_Lexical"); static const int kFIDabs = FD::Convert("PassThrough_Abstract"); static const int kFIDmix = FD::Convert("PassThrough_Mix"); static const int kFID = FD::Convert("PassThrough"); static unordered_map<int, int> pntfid; root.resize(root.size() + 1); root.back().reset(new Tree2StringGrammarNode); ++remove_grammars; unordered_set<vector<int>,boost::hash<vector<int>>> unique_rule_check; for (auto& prod : tree.nodes) { int ntc = 0; int lhs = -(prod.lhs & cdec::ALL_MASK); int &ntfid = pntfid[lhs]; if (!ntfid) { ostringstream fos; fos << "PassThrough:" << TD::Convert(-lhs); ntfid = FD::Convert(fos.str()); } // check for duplicate rule in tree vector<int> key; key.push_back(prod.lhs); bool has_lex = false; bool has_nt = false; vector<int> rhse, rhsf; ostringstream os; os << '(' << TD::Convert(-lhs); for (auto& sym : prod.rhs) { os << ' '; if (cdec::IsTerminal(sym)) { has_lex = true; os << TD::Convert(sym); rhse.push_back(sym); rhsf.push_back(sym); key.push_back(sym); } else { has_nt = true; unsigned id = tree.nodes[sym & cdec::ALL_MASK].lhs & cdec::ALL_MASK; os << '[' << TD::Convert(id) << ']'; rhsf.push_back(-id); rhse.push_back(-ntc); key.push_back(-id); ++ntc; } } os << ')'; if (!unique_rule_check.insert(key).second) continue; cdec::TreeFragment rule_src(os.str(), true); Tree2StringGrammarNode* cur = root.back().get(); // do we need all transducer states here??? a list??? no pass through rules??? unsigned transducer_state = 0; const bool abstract_rule = (has_nt && !has_lex); // the following reduces ambiguity quite a lot if (abstract_rule && DoesAbstractPassThroughRuleExist(transducer_state, rule_src)) continue; cur = &cur->next[transducer_state]; for (auto sym : rule_src) cur = &cur->next[sym]; TRulePtr rule(new TRule(rhse, rhsf, lhs)); rule->a_.push_back(AlignmentPoint(0, 0)); rule->ComputeArity(); rule->scores_.set_value(ntfid, 1.0); rule->scores_.set_value(kFID, 1.0); if (has_lex && has_nt) rule->scores_.set_value(kFIDmix, 1.0); else if (has_lex) rule->scores_.set_value(kFIDlex, 1.0); else if (has_nt) rule->scores_.set_value(kFIDabs, 1.0); cur->rules.push_back(rule); } } void RemoveGrammars() { assert(remove_grammars <= root.size()); root.resize(root.size() - remove_grammars); } bool Translate(const string& input, SentenceMetadata* smeta, const vector<double>& weights, Hypergraph* minus_lm_forest) { cdec::TreeFragment input_tree(input, false); smeta->src_tree_ = input_tree; smeta->input_type_ = cdec::kTREE; if (add_pass_through_rules) CreatePassThroughRules(input_tree); Hypergraph hg; hg.ReserveNodes(input_tree.nodes.size()); unordered_map<TransducerState, unsigned> x2hg(input_tree.nodes.size() * 5); queue<ParserState> q; unordered_set<ParserState> unique; // only create items one time for (auto& g : root) { unsigned q_0 = 0; // TODO initialize q_0 properly once multi-state transducers are supported auto rit = g->next.find(q_0); if (rit != g->next.end()) { // does this g have this transducer state? q.push(ParserState(input_tree.begin(), q_0, &rit->second)); unique.insert(q.back()); } } if (q.size() == 0) return false; const TransducerState tree_top = q.front().task; while(!q.empty()) { ParserState& s = q.front(); if (s.in_iter.at_end()) { // completed a traversal of a subtree //cerr << "I traversed a subtree of the input rooted at node=" << s.input_node_idx << " sym=" << // TD::Convert(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK) << endl; if (s.node->rules.size()) { auto it = x2hg.find(s.task); if (it == x2hg.end()) { // TODO create composite state symbol that encodes transducer state type? HG::Node* new_node = hg.AddNode(-(input_tree.nodes[s.task.input_node_idx].lhs & cdec::ALL_MASK)); new_node->node_hash = std::hash<TransducerState>()(s.task); it = x2hg.insert(make_pair(s.task, new_node->id_)).first; } const unsigned node_id = it->second; TailNodeVector tail; for (const auto& n : s.future_work) { auto it = x2hg.find(n); if (it == x2hg.end()) { // TODO create composite state symbol that encodes transducer state type? HG::Node* new_node = hg.AddNode(-(input_tree.nodes[n.input_node_idx].lhs & cdec::ALL_MASK)); new_node->node_hash = std::hash<TransducerState>()(n); it = x2hg.insert(make_pair(n, new_node->id_)).first; } tail.push_back(it->second); } for (auto& r : s.node->rules) { assert(tail.size() == r->Arity()); HG::Edge* new_edge = hg.AddEdge(r, tail); new_edge->feature_values_ = r->GetFeatureValues(); auto& inspan = input_tree.nodes[s.task.input_node_idx].span; new_edge->i_ = inspan.first; new_edge->j_ = inspan.second; hg.ConnectEdgeToHeadNode(new_edge, &hg.nodes_[node_id]); } for (const auto& n : s.future_work) { const auto it = input_tree.begin(n.input_node_idx); // start tree iterator at node n for (auto& g : root) { auto rit = g->next.find(n.transducer_state); if (rit != g->next.end()) { // does this g have this transducer state? const ParserState s(it, n.transducer_state, &rit->second); if (unique.insert(s).second) q.push(s); } } } } else { //cerr << "I can't build anything :(\n"; } } else { // more input tree to match unsigned sym = *s.in_iter; if (cdec::IsLHS(sym)) { auto nit = s.node->next.find(sym); if (nit != s.node->next.end()) { //cerr << "MATCHED LHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; ParserState news(++s.in_iter, &nit->second, s); if (unique.insert(news).second) q.push(news); } } else if (cdec::IsRHS(sym)) { //cerr << "Attempting to match RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; cdec::TreeFragment::iterator var = s.in_iter; var.truncate(); auto nit1 = s.node->next.find(sym); auto nit2 = s.node->next.find(*var); if (nit2 != s.node->next.end()) { //cerr << "MATCHED VAR RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; ++var; // TODO: find out from rule what the new target state is (the 0 in the next line) // if it is associated with the rule, we won't know until we match the whole input // so the 0 may be okay (if this is the case, which is probably the easiest thing, // then the state must be dealt with when the future work becomes real work) const TransducerState new_task(s.in_iter.child_node(), 0); ParserState new_s(var, &nit2->second, s); new_s.future_work.push_back(new_task); // if this traversal of the input succeeds, future_work goes on the q if (unique.insert(new_s).second) q.push(new_s); } //else { cerr << "did not match [" << TD::Convert(sym & cdec::ALL_MASK) << "]\n"; } if (nit1 != s.node->next.end()) { //cerr << "MATCHED FULL RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; const ParserState new_s(++s.in_iter, &nit1->second, s); if (unique.insert(new_s).second) q.push(new_s); } //else { cerr << "did not match " << TD::Convert(sym & cdec::ALL_MASK) << "\n"; } } else if (cdec::IsTerminal(sym)) { auto nit = s.node->next.find(sym); if (nit != s.node->next.end()) { //cerr << "MATCHED TERMINAL: " << TD::Convert(sym) << endl; const ParserState new_s(++s.in_iter, &nit->second, s); if (unique.insert(new_s).second) q.push(new_s); } } else { cerr << "This can never happen!\n"; abort(); } } q.pop(); } const auto goal_it = x2hg.find(tree_top); if (goal_it == x2hg.end()) return false; //cerr << "Goal node: " << goal << endl; hg.TopologicallySortNodesAndEdges(goal_it->second); // there might be nodes that cannot be derived // the following takes care of them vector<bool> prune(hg.edges_.size(), false); hg.PruneEdges(prune, true); if (hg.edges_.size() == 0) return false; // rescoring assumes the goal edge is arity 1 (code laziness), add that here AddDummyGoalNode(&hg); hg.Reweight(weights); //hg.PrintGraphviz(); minus_lm_forest->swap(hg); return true; } }; Tree2StringTranslator::Tree2StringTranslator(const boost::program_options::variables_map& conf, bool has_multiple_states) : pimpl_(new Tree2StringTranslatorImpl(conf, has_multiple_states)) {} void Tree2StringTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv) { pimpl_->remove_grammars = 0; if (kv.find("grammar0") != kv.end()) { cerr << "SGML tag grammar0 is not expected (order is: grammar, grammar1, grammar2, ...)\n"; abort(); } unsigned gc = 0; set<string> loaded; while(true) { string gkey = "grammar"; if (gc > 0) gkey += boost::lexical_cast<string>(gc); ++gc; map<string,string>::const_iterator it = kv.find(gkey); if (it == kv.end()) break; const string& gfile = it->second; if (loaded.count(gfile) == 1) { cerr << "Attempting to load " << gfile << " twice!\n"; abort(); } loaded.insert(gfile); pimpl_->LoadSupplementalGrammar(gfile); } } bool Tree2StringTranslator::TranslateImpl(const string& input, SentenceMetadata* smeta, const vector<double>& weights, Hypergraph* minus_lm_forest) { return pimpl_->Translate(input, smeta, weights, minus_lm_forest); } void Tree2StringTranslator::SentenceCompleteImpl() { pimpl_->RemoveGrammars(); } std::string Tree2StringTranslator::GetDecoderType() const { return "tree2string"; }