diff options
author | Patrick Simianer <p@simianer.de> | 2014-06-12 13:56:42 +0200 |
---|---|---|
committer | Patrick Simianer <p@simianer.de> | 2014-06-12 13:56:42 +0200 |
commit | a39aa79b18347e22ef36ebc0da5a7eb220bcb23f (patch) | |
tree | 2c0f3009f8e381002bfeb82c0ea3bd0c41125761 /decoder | |
parent | 62bd9a4bdcea606d6ff2031fa4b207ef20caac31 (diff) | |
parent | 0e2f8d3d049f06afb08b4639c6a28aa5461cdc78 (diff) |
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'decoder')
36 files changed, 1308 insertions, 617 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 0b61c7cd..8e61c13e 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -4,9 +4,12 @@ noinst_PROGRAMS = \ trule_test \ hg_test \ parser_test \ + t2s_test \ grammar_test TESTS = trule_test parser_test grammar_test hg_test +t2s_test_SOURCES = t2s_test.cc +t2s_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a parser_test_SOURCES = parser_test.cc parser_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a grammar_test_SOURCES = grammar_test.cc @@ -37,7 +40,6 @@ libcdec_a_SOURCES = \ csplit.h \ decoder.h \ earley_composer.h \ - exp_semiring.h \ factored_lexicon_helper.h \ ff.h \ ff_basic.h \ @@ -143,6 +145,7 @@ libcdec_a_SOURCES = \ lattice.cc \ lexalign.cc \ lextrans.cc \ + node_state_hash.h \ tree_fragment.cc \ tree_fragment.h \ maxtrans_blunsom.cc \ diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc index 9a8f60be..9f8bbead 100644 --- a/decoder/apply_models.cc +++ b/decoder/apply_models.cc @@ -19,6 +19,7 @@ namespace std { using std::tr1::unordered_map; using std::tr1::unordered_set; } #include <boost/functional/hash.hpp> +#include "node_state_hash.h" #include "verbose.h" #include "hg.h" #include "ff.h" @@ -229,7 +230,7 @@ public: D.clear(); } - void IncorporateIntoPlusLMForest(Candidate* item, State2Node* s2n, CandidateList* freelist) { + void IncorporateIntoPlusLMForest(size_t head_node_hash, Candidate* item, State2Node* s2n, CandidateList* freelist) { Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_); new_edge->edge_prob_ = item->out_edge_.edge_prob_; Candidate*& o_item = (*s2n)[item->state_]; @@ -238,6 +239,7 @@ public: int& node_id = o_item->node_index_; if (node_id < 0) { Hypergraph::Node* new_node = out.AddNode(in.nodes_[item->in_edge_->head_node_].cat_); + new_node->node_hash = cdec::HashNode(head_node_hash, item->state_); // ID is combination of existing state + residual state node_states_.push_back(item->state_); node_id = new_node->id_; } @@ -287,7 +289,7 @@ public: cand.pop_back(); // cerr << "POPPED: " << *item << endl; PushSucc(*item, is_goal, &cand, &unique_cands); - IncorporateIntoPlusLMForest(item, &state2node, &freelist); + IncorporateIntoPlusLMForest(v.node_hash, item, &state2node, &freelist); ++pops; } D_v.resize(state2node.size()); @@ -306,112 +308,112 @@ public: } void KBestFast(const int vert_index, const bool is_goal) { - // cerr << "KBest(" << vert_index << ")\n"; - CandidateList& D_v = D[vert_index]; - assert(D_v.empty()); - const Hypergraph::Node& v = in.nodes_[vert_index]; - // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; - const vector<int>& in_edges = v.in_edges_; - CandidateHeap cand; - CandidateList freelist; - cand.reserve(in_edges.size()); - //init with j<0,0> for all rules-edges that lead to node-(NT-span) - for (int i = 0; i < in_edges.size(); ++i) { - const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; - const JVector j(edge.tail_nodes_.size(), 0); - cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); - } - // cerr << " making heap of " << cand.size() << " candidates\n"; - make_heap(cand.begin(), cand.end(), HeapCandCompare()); - State2Node state2node; // "buf" in Figure 2 - int pops = 0; - while(!cand.empty() && pops < pop_limit_) { - pop_heap(cand.begin(), cand.end(), HeapCandCompare()); - Candidate* item = cand.back(); - cand.pop_back(); - // cerr << "POPPED: " << *item << endl; - - PushSuccFast(*item, is_goal, &cand); - IncorporateIntoPlusLMForest(item, &state2node, &freelist); - ++pops; - } - D_v.resize(state2node.size()); - int c = 0; - for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){ - D_v[c++] = i->second; - // cerr << "MERGED: " << *i->second << endl; - } - //cerr <<"Node id: "<< vert_index<< endl; - //#ifdef MEASURE_CA - // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl; - // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl; - //#endif - sort(D_v.begin(), D_v.end(), EstProbSorter()); - - // cerr << " expanded to " << D_v.size() << " nodes\n"; - - for (int i = 0; i < cand.size(); ++i) - delete cand[i]; - // freelist is necessary since even after an item merged, it still stays in - // the unique set so it can't be deleted til now - for (int i = 0; i < freelist.size(); ++i) - delete freelist[i]; + // cerr << "KBest(" << vert_index << ")\n"; + CandidateList& D_v = D[vert_index]; + assert(D_v.empty()); + const Hypergraph::Node& v = in.nodes_[vert_index]; + // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; + const vector<int>& in_edges = v.in_edges_; + CandidateHeap cand; + CandidateList freelist; + cand.reserve(in_edges.size()); + //init with j<0,0> for all rules-edges that lead to node-(NT-span) + for (int i = 0; i < in_edges.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; + const JVector j(edge.tail_nodes_.size(), 0); + cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); + } + // cerr << " making heap of " << cand.size() << " candidates\n"; + make_heap(cand.begin(), cand.end(), HeapCandCompare()); + State2Node state2node; // "buf" in Figure 2 + int pops = 0; + while(!cand.empty() && pops < pop_limit_) { + pop_heap(cand.begin(), cand.end(), HeapCandCompare()); + Candidate* item = cand.back(); + cand.pop_back(); + // cerr << "POPPED: " << *item << endl; + + PushSuccFast(*item, is_goal, &cand); + IncorporateIntoPlusLMForest(v.node_hash, item, &state2node, &freelist); + ++pops; + } + D_v.resize(state2node.size()); + int c = 0; + for (auto& i : state2node) { + D_v[c++] = i.second; + // cerr << "MERGED: " << *i.second << endl; + } + //cerr <<"Node id: "<< vert_index<< endl; + //#ifdef MEASURE_CA + // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl; + // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl; + //#endif + sort(D_v.begin(), D_v.end(), EstProbSorter()); + + // cerr << " expanded to " << D_v.size() << " nodes\n"; + + for (int i = 0; i < cand.size(); ++i) + delete cand[i]; + // freelist is necessary since even after an item merged, it still stays in + // the unique set so it can't be deleted til now + for (int i = 0; i < freelist.size(); ++i) + delete freelist[i]; } void KBestFast2(const int vert_index, const bool is_goal) { - // cerr << "KBest(" << vert_index << ")\n"; - CandidateList& D_v = D[vert_index]; - assert(D_v.empty()); - const Hypergraph::Node& v = in.nodes_[vert_index]; - // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; - const vector<int>& in_edges = v.in_edges_; - CandidateHeap cand; - CandidateList freelist; - cand.reserve(in_edges.size()); - UniqueCandidateSet unique_accepted; - //init with j<0,0> for all rules-edges that lead to node-(NT-span) - for (int i = 0; i < in_edges.size(); ++i) { - const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; - const JVector j(edge.tail_nodes_.size(), 0); - cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); - } - // cerr << " making heap of " << cand.size() << " candidates\n"; - make_heap(cand.begin(), cand.end(), HeapCandCompare()); - State2Node state2node; // "buf" in Figure 2 - int pops = 0; - while(!cand.empty() && pops < pop_limit_) { - pop_heap(cand.begin(), cand.end(), HeapCandCompare()); - Candidate* item = cand.back(); - cand.pop_back(); + // cerr << "KBest(" << vert_index << ")\n"; + CandidateList& D_v = D[vert_index]; + assert(D_v.empty()); + const Hypergraph::Node& v = in.nodes_[vert_index]; + // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; + const vector<int>& in_edges = v.in_edges_; + CandidateHeap cand; + CandidateList freelist; + cand.reserve(in_edges.size()); + UniqueCandidateSet unique_accepted; + //init with j<0,0> for all rules-edges that lead to node-(NT-span) + for (int i = 0; i < in_edges.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; + const JVector j(edge.tail_nodes_.size(), 0); + cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); + } + // cerr << " making heap of " << cand.size() << " candidates\n"; + make_heap(cand.begin(), cand.end(), HeapCandCompare()); + State2Node state2node; // "buf" in Figure 2 + int pops = 0; + while(!cand.empty() && pops < pop_limit_) { + pop_heap(cand.begin(), cand.end(), HeapCandCompare()); + Candidate* item = cand.back(); + cand.pop_back(); bool is_new = unique_accepted.insert(item).second; - assert(is_new); // these should all be unique! - // cerr << "POPPED: " << *item << endl; - - PushSuccFast2(*item, is_goal, &cand, &unique_accepted); - IncorporateIntoPlusLMForest(item, &state2node, &freelist); - ++pops; - } - D_v.resize(state2node.size()); - int c = 0; - for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){ - D_v[c++] = i->second; - // cerr << "MERGED: " << *i->second << endl; - } - //cerr <<"Node id: "<< vert_index<< endl; - //#ifdef MEASURE_CA - // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl; - // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl; - //#endif - sort(D_v.begin(), D_v.end(), EstProbSorter()); - - // cerr << " expanded to " << D_v.size() << " nodes\n"; - - for (int i = 0; i < cand.size(); ++i) - delete cand[i]; - // freelist is necessary since even after an item merged, it still stays in - // the unique set so it can't be deleted til now - for (int i = 0; i < freelist.size(); ++i) - delete freelist[i]; + assert(is_new); // these should all be unique! + // cerr << "POPPED: " << *item << endl; + + PushSuccFast2(*item, is_goal, &cand, &unique_accepted); + IncorporateIntoPlusLMForest(v.node_hash, item, &state2node, &freelist); + ++pops; + } + D_v.resize(state2node.size()); + int c = 0; + for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){ + D_v[c++] = i->second; + // cerr << "MERGED: " << *i->second << endl; + } + //cerr <<"Node id: "<< vert_index<< endl; + //#ifdef MEASURE_CA + // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl; + // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl; + //#endif + sort(D_v.begin(), D_v.end(), EstProbSorter()); + + // cerr << " expanded to " << D_v.size() << " nodes\n"; + + for (int i = 0; i < cand.size(); ++i) + delete cand[i]; + // freelist is necessary since even after an item merged, it still stays in + // the unique set so it can't be deleted til now + for (int i = 0; i < freelist.size(); ++i) + delete freelist[i]; } void PushSucc(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* cs) { @@ -434,50 +436,50 @@ public: //PushSucc following unique ancestor generation function void PushSuccFast(const Candidate& item, const bool is_goal, CandidateHeap* pcand){ - CandidateHeap& cand = *pcand; - for (int i = 0; i < item.j_.size(); ++i) { - JVector j = item.j_; - ++j[i]; - if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { - Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); - cand.push_back(new_cand); - push_heap(cand.begin(), cand.end(), HeapCandCompare()); - } - if(item.j_[i]!=0){ - return; - } - } + CandidateHeap& cand = *pcand; + for (int i = 0; i < item.j_.size(); ++i) { + JVector j = item.j_; + ++j[i]; + if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { + Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); + cand.push_back(new_cand); + push_heap(cand.begin(), cand.end(), HeapCandCompare()); + } + if(item.j_[i]!=0){ + return; + } + } } //PushSucc only if all ancest Cand are added void PushSuccFast2(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* ps){ - CandidateHeap& cand = *pcand; - for (int i = 0; i < item.j_.size(); ++i) { - JVector j = item.j_; - ++j[i]; - if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { - Candidate query_unique(*item.in_edge_, j); - if (HasAllAncestors(&query_unique,ps)) { - Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); - cand.push_back(new_cand); - push_heap(cand.begin(), cand.end(), HeapCandCompare()); - } - } - } + CandidateHeap& cand = *pcand; + for (int i = 0; i < item.j_.size(); ++i) { + JVector j = item.j_; + ++j[i]; + if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { + Candidate query_unique(*item.in_edge_, j); + if (HasAllAncestors(&query_unique,ps)) { + Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); + cand.push_back(new_cand); + push_heap(cand.begin(), cand.end(), HeapCandCompare()); + } + } + } } bool HasAllAncestors(const Candidate* item, UniqueCandidateSet* cs){ - for (int i = 0; i < item->j_.size(); ++i) { - JVector j = item->j_; - --j[i]; - if (j[i] >=0) { - Candidate query_unique(*item->in_edge_, j); - if (cs->count(&query_unique) == 0) { - return false; - } - } - } - return true; + for (int i = 0; i < item->j_.size(); ++i) { + JVector j = item->j_; + --j[i]; + if (j[i] >=0) { + Candidate query_unique(*item->in_edge_, j); + if (cs->count(&query_unique) == 0) { + return false; + } + } + } + return true; } const ModelSet& models; @@ -491,7 +493,7 @@ public: FFStates node_states_; // for each node in the out-HG what is // its q function value? const int pop_limit_; - const int strategy_; //switch Cube Pruning strategy: 1 normal, 2 fast (alg 2), 3 fast_2 (alg 3). (see: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010) + const int strategy_; //switch Cube Pruning strategy: 1 normal, 2 fast (alg 2), 3 fast_2 (alg 3). (see: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010) }; struct NoPruningRescorer { @@ -507,7 +509,7 @@ struct NoPruningRescorer { typedef unordered_map<FFState, int, boost::hash<FFState> > State2NodeIndex; - void ExpandEdge(const Hypergraph::Edge& in_edge, bool is_goal, State2NodeIndex* state2node) { + void ExpandEdge(const Hypergraph::Edge& in_edge, bool is_goal, size_t head_node_hash, State2NodeIndex* state2node) { const int arity = in_edge.Arity(); Hypergraph::TailNodeVector ends(arity); for (int i = 0; i < arity; ++i) @@ -531,7 +533,9 @@ struct NoPruningRescorer { } int& head_plus1 = (*state2node)[head_state]; if (!head_plus1) { - head_plus1 = out.AddNode(in_edge.rule_->GetLHS())->id_ + 1; + HG::Node* new_node = out.AddNode(in_edge.rule_->GetLHS()); + new_node->node_hash = cdec::HashNode(head_node_hash, head_state); // ID is combination of existing state + residual state + head_plus1 = new_node->id_ + 1; node_states_.push_back(head_state); nodemap[in_edge.head_node_].push_back(head_plus1 - 1); } @@ -553,7 +557,7 @@ struct NoPruningRescorer { const Hypergraph::Node& node = in.nodes_[node_num]; for (int i = 0; i < node.in_edges_.size(); ++i) { const Hypergraph::Edge& edge = in.edges_[node.in_edges_[i]]; - ExpandEdge(edge, is_goal, &state2node); + ExpandEdge(edge, is_goal, node.node_hash, &state2node); } } @@ -605,16 +609,16 @@ void ApplyModelSet(const Hypergraph& in, cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n"; } if (config.algorithm == IntersectionConfiguration::CUBE) { - CubePruningRescorer ma(models, smeta, in, pl, out); - ma.Apply(); + CubePruningRescorer ma(models, smeta, in, pl, out); + ma.Apply(); } else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING){ - CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP); - ma.Apply(); + CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP); + ma.Apply(); } else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2){ - CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP_2); - ma.Apply(); + CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP_2); + ma.Apply(); } } else { diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc index 8738c8f1..b30f1ec6 100644 --- a/decoder/bottom_up_parser.cc +++ b/decoder/bottom_up_parser.cc @@ -7,6 +7,8 @@ #include <iostream> #include <map> +#include "node_state_hash.h" +#include "nt_span.h" #include "hg.h" #include "array2d.h" #include "tdict.h" @@ -159,7 +161,7 @@ PassiveChart::PassiveChart(const string& goal, chart_(input.size()+1, input.size()+1), nodemap_(input.size()+1, input.size()+1), goal_cat_(TD::Convert(goal) * -1), - goal_rule_(new TRule("[Goal] ||| [" + goal + ",1] ||| [" + goal + ",1]")), + goal_rule_(new TRule("[Goal] ||| [" + goal + "] ||| [1]")), goal_idx_(-1), lc_fid_(FD::Convert("LatticeCost")), unaries_() { @@ -356,5 +358,13 @@ bool ExhaustiveBottomUpParser::Parse(const Lattice& input, kEPS = TD::Convert("*EPS*"); PassiveChart chart(goal_sym_, grammars_, input, forest); const bool result = chart.Parse(); + + if (result) { + for (auto& node : forest->nodes_) { + Span prev; + const Span s = forest->NodeSpan(node.id_, &prev); + node.node_hash = cdec::HashNode(node.cat_, s.l, s.r, prev.l, prev.r); + } + } return result; } diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 31049216..6783cad0 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -490,8 +490,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream } formalism = LowercaseString(str("formalism",conf)); - if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") { - cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n"; + if (formalism != "t2s" && formalism != "t2t" && formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") { + cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 't2s', 't2t', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n"; cerr << dcmdline_options << endl; exit(1); } @@ -626,6 +626,10 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream // set up translation back end if (formalism == "scfg") translator.reset(new SCFGTranslator(conf)); + else if (formalism == "t2s") + translator.reset(new Tree2StringTranslator(conf, false)); + else if (formalism == "t2t") + translator.reset(new Tree2StringTranslator(conf, true)); else if (formalism == "fst") translator.reset(new FSTTranslator(conf)); else if (formalism == "pb") @@ -748,6 +752,16 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { return false; } + // this is mainly used for debugging, eventually this will be an assertion + if (!forest.AreNodesUniquelyIdentified()) { + if (!SILENT) cerr << " *** NODES NOT UNIQUELY IDENTIFIED ***\n"; + } + + if (!forest.ArePreGoalEdgesArity1()) { + cerr << "Pre-goal edges are not arity-1. The decoder requires this.\n"; + abort(); + } + const bool show_tree_structure=conf.count("show_tree_structure"); if (!SILENT) forest_stats(forest," Init. forest",show_tree_structure,oracle.show_derivation); if (conf.count("show_expected_length")) { @@ -811,6 +825,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { forest.swap(rescored_forest); forest.Reweight(cur_weights); if (!SILENT) forest_stats(forest," " + passtr +" forest",show_tree_structure,oracle.show_derivation, conf.count("extract_rules"), extract_file); + // this is mainly used for debugging, eventually this will be an assertion + if (!forest.AreNodesUniquelyIdentified()) { + if (!SILENT) cerr << " *** NODES NOT UNIQUELY IDENTIFIED ***\n"; + } } if (conf.count("show_partition")) { @@ -982,6 +1000,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { forest.edges_[i].rule_ = forest.edges_[i].rule_->parent_rule_; } forest.Reweight(last_weights); + // this is mainly used for debugging, eventually this will be an assertion + if (!forest.AreNodesUniquelyIdentified()) { + if (!SILENT) cerr << " *** NODES NOT UNIQUELY IDENTIFIED ***\n"; + } if (!SILENT) forest_stats(forest," Constr. forest",show_tree_structure,oracle.show_derivation); if (!SILENT) cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl; if (conf.count("show_partition")) { diff --git a/decoder/exp_semiring.h b/decoder/exp_semiring.h deleted file mode 100644 index 2a9034bb..00000000 --- a/decoder/exp_semiring.h +++ /dev/null @@ -1,71 +0,0 @@ -#ifndef _EXP_SEMIRING_H_ -#define _EXP_SEMIRING_H_ - -#include <iostream> - -// this file implements the first-order expectation semiring described -// in Li & Eisner (EMNLP 2009) - -// requirements: -// RType * RType ==> RType -// PType * PType ==> PType -// RType * PType ==> RType -// good examples: -// PType scalar, RType vector -// BAD examples: -// PType vector, RType scalar -template <class PType, class RType> -struct PRPair { - PRPair() : p(), r() {} - // Inside algorithm requires that T(0) and T(1) - // return the 0 and 1 values of the semiring - explicit PRPair(double x) : p(x), r() {} - PRPair(const PType& p, const RType& r) : p(p), r(r) {} - PRPair& operator+=(const PRPair& o) { - p += o.p; - r += o.r; - return *this; - } - PRPair& operator*=(const PRPair& o) { - r = (o.r * p) + (o.p * r); - p *= o.p; - return *this; - } - PType p; - RType r; -}; - -template <class P, class R> -std::ostream& operator<<(std::ostream& o, const PRPair<P,R>& x) { - return o << '<' << x.p << ", " << x.r << '>'; -} - -template <class P, class R> -const PRPair<P,R> operator+(const PRPair<P,R>& a, const PRPair<P,R>& b) { - PRPair<P,R> result = a; - result += b; - return result; -} - -template <class P, class R> -const PRPair<P,R> operator*(const PRPair<P,R>& a, const PRPair<P,R>& b) { - PRPair<P,R> result = a; - result *= b; - return result; -} - -template <class P, class PWeightFunction, class R, class RWeightFunction> -struct PRWeightFunction { - explicit PRWeightFunction(const PWeightFunction& pwf = PWeightFunction(), - const RWeightFunction& rwf = RWeightFunction()) : - pweight(pwf), rweight(rwf) {} - PRPair<P,R> operator()(const HG::Edge& e) const { - const P p = pweight(e); - const R r = rweight(e); - return PRPair<P,R>(p, r * p); - } - const PWeightFunction pweight; - const RWeightFunction rweight; -}; - -#endif diff --git a/decoder/fst_translator.cc b/decoder/fst_translator.cc index 074de4c9..4253b652 100644 --- a/decoder/fst_translator.cc +++ b/decoder/fst_translator.cc @@ -67,6 +67,12 @@ struct FSTTranslatorImpl { Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); forest->ConnectEdgeToHeadNode(hg_edge, goal); forest->Reweight(weights); + + // since we don't do any pruning, the node_hash will be the same for + // every run of the composer + int nc = 0; + for (auto& node : forest->nodes_) + node.node_hash = ++nc; } if (add_pass_through_rules) fst->ClearPassThroughTranslations(); diff --git a/decoder/grammar_test.cc b/decoder/grammar_test.cc index 6d2c6e67..69240139 100644 --- a/decoder/grammar_test.cc +++ b/decoder/grammar_test.cc @@ -33,9 +33,9 @@ BOOST_AUTO_TEST_CASE(TestTextGrammar) { ModelSet models(w, ms); TextGrammar g; - TRulePtr r1(new TRule("[X] ||| a b c ||| A B C ||| 0.1 0.2 0.3", true)); - TRulePtr r2(new TRule("[X] ||| a b c ||| 1 2 3 ||| 0.2 0.3 0.4", true)); - TRulePtr r3(new TRule("[X] ||| a b c d ||| A B C D ||| 0.1 0.2 0.3", true)); + TRulePtr r1(new TRule("[X] ||| a b c ||| A B C ||| 0.1 0.2 0.3")); + TRulePtr r2(new TRule("[X] ||| a b c ||| 1 2 3 ||| 0.2 0.3 0.4")); + TRulePtr r3(new TRule("[X] ||| a b c d ||| A B C D ||| 0.1 0.2 0.3")); cerr << r1->AsString() << endl; g.AddRule(r1); g.AddRule(r2); diff --git a/decoder/hg.cc b/decoder/hg.cc index 7240a8ab..46543b01 100644 --- a/decoder/hg.cc +++ b/decoder/hg.cc @@ -1,14 +1,17 @@ -//TODO: lazily generate feature vectors for hyperarcs (because some of them will be pruned). this means 1) storing ref to rule for those features 2) providing ff interface for regenerating its feature vector from hyperedge+states and probably 3) still caching feat. vect on hyperedge once it's been generated. ff would normally just contribute its weighted score and result state, not component features. however, the hypergraph drops the state used by ffs after rescoring is done, so recomputation would have to start at the leaves and work bottom up. question: which takes more space, feature id+value, or state? - #include "hg.h" #include <algorithm> #include <cassert> #include <numeric> -#include <set> #include <map> #include <iostream> #include <sstream> +#ifndef HAVE_OLD_CPP +# include <unordered_set> +#else +# include <tr1/unordered_set> +namespace std { using std::tr1::unordered_set; } +#endif #include "viterbi.h" #include "inside_outside.h" @@ -17,28 +20,28 @@ using namespace std; -#if 0 -Hypergraph::Edge const* Hypergraph::ViterbiGoalEdge() const -{ - Edge const* r=0; - for (unsigned i=0,e=edges_.size();i<e;++i) { - Edge const& e=edges_[i]; - if (e.rule_ && e.rule_->IsGoal() && (!r || e.edge_prob_ > r->edge_prob_)) - r=&e; - } - return r; +bool Hypergraph::AreNodesUniquelyIdentified() const { + unordered_set<size_t> s(nodes_.size() * 3 + 7); + for (const auto& n : nodes_) + if (!s.insert(n.node_hash).second) + return false; + return true; } -#endif -Hypergraph::Edge const* Hypergraph::ViterbiSortInEdges() -{ +bool Hypergraph::ArePreGoalEdgesArity1() const { + auto& n = nodes_.back(); + for (auto eid : n.in_edges_) + if (edges_[eid].Arity() != 1) return false; + return true; +} + +Hypergraph::Edge const* Hypergraph::ViterbiSortInEdges() { NodeProbs nv; ComputeNodeViterbi(&nv); return SortInEdgesByNodeViterbi(nv); } -Hypergraph::Edge const* Hypergraph::SortInEdgesByNodeViterbi(NodeProbs const& nv) -{ +Hypergraph::Edge const* Hypergraph::SortInEdgesByNodeViterbi(NodeProbs const& nv) { EdgeProbs ev; ComputeEdgeViterbi(nv,&ev); return ViterbiSortInEdges(ev); @@ -375,9 +378,7 @@ bool Hypergraph::PruneInsideOutside(double alpha,double density,const EdgeMask* void Hypergraph::PrintGraphviz() const { int ei = 0; cerr << "digraph G {\n rankdir=LR;\n nodesep=.05;\n"; - for (vector<Edge>::const_iterator i = edges_.begin(); - i != edges_.end(); ++i) { - const Edge& edge=*i; + for (const auto& edge : edges_) { ++ei; static const string none = "<null>"; string rule = (edge.rule_ ? edge.rule_->AsString(false) : none); @@ -399,14 +400,10 @@ void Hypergraph::PrintGraphviz() const { } cerr << " A_" << ei << " -> " << edge.head_node_ << ";\n"; } - for (vector<Node>::const_iterator ni = nodes_.begin(); - ni != nodes_.end(); ++ni) { - cerr << " " << ni->id_ << "[label=\"" << (ni->cat_ < 0 ? TD::Convert(ni->cat_ * -1) : "") - //cerr << " " << ni->id_ << "[label=\"" << ni->cat_ - << " n=" << ni->id_ -// << ",x=" << &*ni -// << ",in=" << ni->in_edges_.size() -// << ",out=" << ni->out_edges_.size() + for (const auto& node : nodes_) { + cerr << " " << node.id_ << "[label=\"" << (node.cat_ < 0 ? TD::Convert(node.cat_ * -1) : "") + << " n=" << node.id_ + << " h=" << node.node_hash << "\"];\n"; } cerr << "}\n"; diff --git a/decoder/hg.h b/decoder/hg.h index 3d8cd9bc..4ed27d87 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -25,6 +25,7 @@ #include "tdict.h" #include "trule.h" #include "prob.h" +#include "exp_semiring.h" #include "indices_after.h" #include "nt_span.h" @@ -141,13 +142,15 @@ namespace HG { // TODO get rid of cat_? // TODO keep cat_ and add span and/or state? :) struct Node { - Node() : id_(), cat_() {} + Node() : node_hash(), id_(), cat_() {} + size_t node_hash; // hash of all the information that makes this node unique int id_; // equal to this object's position in the nodes_ vector WordID cat_; // non-terminal category if <0, 0 if not set WordID NT() const { return -cat_; } EdgesVector in_edges_; // an in edge is an edge with this node as its head. (in edges come from the bottom up to us) indices in edges_ EdgesVector out_edges_; // an out edge is an edge with this node as its tail. (out edges leave us up toward the top/goal). indices in edges_ void copy_fixed(Node const& o) { // nonstructural fields only - structural ones are managed by sorting/pruning/subsetting + node_hash = o.node_hash; cat_=o.cat_; } void copy_reindex(Node const& o,indices_after const& n2,indices_after const& e2) { @@ -191,13 +194,14 @@ public: SetNodeOrigin(nodeid,r); return r; } - Span NodeSpan(int nodeid) const { + Span NodeSpan(int nodeid, Span* prev = nullptr) const { Span s; Node const &n=nodes_[nodeid]; if (!n.in_edges_.empty()) { Edge const& e=edges_[n.in_edges_.front()]; s.l=e.i_; s.r=e.j_; + if (prev) { prev->l = e.prev_i_; prev->r = e.prev_j_; } } return s; } @@ -261,6 +265,13 @@ public: for (int i = 0; i < size; ++i) nodes_[i].id_ = i; } + // if all node states are unique, return true + bool AreNodesUniquelyIdentified() const; + + // the feature function interface assumes that pre-goal edges are + // arity 1 (this simplifies the "final transition" feature computation) + bool ArePreGoalEdgesArity1() const; + // reserves space in the nodes vector to prevent memory locations // from changing void ReserveNodes(size_t n, size_t e = 0) { @@ -527,7 +538,21 @@ struct EdgeFeaturesAndProbWeightFunction { struct TransitionCountWeightFunction { typedef double Weight; - inline double operator()(const HG::Edge& e) const { (void)e; return 1.0; } + inline double operator()(const HG::Edge&) const { return 1.0; } +}; + +template <class P, class PWeightFunction, class R, class RWeightFunction> +struct PRWeightFunction { + explicit PRWeightFunction(const PWeightFunction& pwf = PWeightFunction(), + const RWeightFunction& rwf = RWeightFunction()) : + pweight(pwf), rweight(rwf) {} + PRPair<P,R> operator()(const HG::Edge& e) const { + const P p = pweight(e); + const R r = rweight(e); + return PRPair<P,R>(p, r * p); + } + const PWeightFunction pweight; + const RWeightFunction rweight; }; #endif diff --git a/decoder/hg_intersect.cc b/decoder/hg_intersect.cc index 31a9a1ce..02f5a401 100644 --- a/decoder/hg_intersect.cc +++ b/decoder/hg_intersect.cc @@ -92,7 +92,7 @@ bool Intersect(const Lattice& target, Hypergraph* hg) { return FastLinearIntersect(target, hg); vector<bool> rem(hg->edges_.size(), false); - const RuleFilter filter(target, 15); // TODO make configurable + const RuleFilter filter(target, 9999); // TODO make configurable for (unsigned i = 0; i < rem.size(); ++i) rem[i] = filter(*hg->edges_[i].rule_); hg->PruneEdges(rem, true); diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc index 64c6663e..eb0be3d4 100644 --- a/decoder/hg_io.cc +++ b/decoder/hg_io.cc @@ -1,5 +1,7 @@ #include "hg_io.h" +#include <cstdio> +#include <cstdlib> #include <fstream> #include <sstream> #include <iostream> @@ -15,10 +17,15 @@ using namespace std; struct HGReader : public JSONParser { HGReader(Hypergraph* g) : rp("[X] ||| "), state(-1), hg(*g), nodes_needed(true), edges_needed(true) { nodes = 0; edges = 0; } - void CreateNode(const string& cat, const vector<int>& in_edges) { + void CreateNode(const string& cat, const string& shash, const vector<int>& in_edges) { WordID c = TD::Convert("X") * -1; if (!cat.empty()) c = TD::Convert(cat) * -1; Hypergraph::Node* node = hg.AddNode(c); + char* dend; + if (shash.size()) + node->node_hash = strtoull(shash.c_str(), &dend, 16); + else + node->node_hash = 0; for (int i = 0; i < in_edges.size(); ++i) { if (in_edges[i] >= hg.edges_.size()) { cerr << "JSONParser: in_edges[" << i << "]=" << in_edges[i] @@ -102,17 +109,19 @@ struct HGReader : public JSONParser { ++nodes; in_edges.clear(); cat.clear(); + shash.clear(); state = 9; break; case 9: if (type == JSON_T_OBJECT_END) { //cerr << "Creating NODE\n"; - CreateNode(cat, in_edges); + CreateNode(cat, shash, in_edges); state = 0; break; } assert(type == JSON_T_KEY); cur_key = value->vu.str.value; if (cur_key == "cat") { assert(cat.empty()); state = 10; break; } if (cur_key == "in_edges") { assert(in_edges.empty()); state = 11; break; } + if (cur_key == "node_hash") { assert(shash.empty()); state = 24; break; } cerr << "Syntax error: unexpected key " << cur_key << " in node specification.\n"; return false; case 10: @@ -224,6 +233,12 @@ struct HGReader : public JSONParser { assert(spanc < 4); spans[spanc] = value->vu.integer_value; ++spanc; + break; + case 24: // read node hash + assert(type == JSON_T_STRING); + shash = value->vu.str.value; + state = 9; + break; } return true; } @@ -231,6 +246,7 @@ struct HGReader : public JSONParser { string cat; SmallVectorUnsigned tail; vector<int> in_edges; + string shash; TRulePtr cur_rule; map<int, TRulePtr> rules; vector<int> fdict; @@ -340,6 +356,9 @@ bool HypergraphIO::WriteToJSON(const Hypergraph& hg, bool remove_rules, ostream* o << ",\"cat\":"; JSONParser::WriteEscapedString(TD::Convert(node.cat_ * -1), &o); } + char buf[48]; + sprintf(buf, "%016lX", node.node_hash); + o << ",\"node_hash\":\"" << buf << "\""; o << "}"; } o << "}\n"; diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc index 8519e559..5cb8626a 100644 --- a/decoder/hg_test.cc +++ b/decoder/hg_test.cc @@ -18,8 +18,10 @@ using namespace std; BOOST_FIXTURE_TEST_SUITE( s, HGSetup ); BOOST_AUTO_TEST_CASE(Controlled) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); + cerr << "PATH: " << path << "/hg.tiny\n"; Hypergraph hg; - CreateHG_tiny(&hg); + CreateHG_tiny(path, &hg); SparseVector<double> wts; wts.set_value(FD::Convert("f1"), 0.4); wts.set_value(FD::Convert("f2"), 0.8); @@ -37,10 +39,18 @@ BOOST_AUTO_TEST_CASE(Controlled) { } BOOST_AUTO_TEST_CASE(Union) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg1; Hypergraph hg2; - CreateHG_tiny(&hg1); - CreateHG(&hg2); + CreateHG_tiny(path, &hg1); + CreateHG(path, &hg2); + int nc = 0; + for (auto& node: hg1.nodes_) + node.node_hash = ++nc; + for (auto& node: hg2.nodes_) + node.node_hash = ++nc; + hg1.nodes_.back().node_hash = nc; + SparseVector<double> wts; wts.set_value(FD::Convert("f1"), 0.4); wts.set_value(FD::Convert("f2"), 1.0); @@ -53,8 +63,11 @@ BOOST_AUTO_TEST_CASE(Union) { int l2 = ViterbiPathLength(hg2); cerr << c1 << "\t" << TD::GetString(t1) << endl; cerr << c2 << "\t" << TD::GetString(t2) << endl; + hg1.PrintGraphviz(); + hg2.PrintGraphviz(); HG::Union(hg2, &hg1); hg1.Reweight(wts); + hg1.PrintGraphviz(); c3 = ViterbiESentence(hg1, &t3); int l3 = ViterbiPathLength(hg1); cerr << c3 << "\t" << TD::GetString(t3) << endl; @@ -81,11 +94,19 @@ BOOST_AUTO_TEST_CASE(Union) { BOOST_CHECK_CLOSE(log(list[0].second), log(c4), 1e-4); BOOST_CHECK_EQUAL(list.size(), 6); BOOST_CHECK_CLOSE(log(list.back().second / list.front().second), -97.7, 1e-4); + hg1 = hg2; + BOOST_CHECK_EQUAL(hg1.nodes_.size(), hg2.nodes_.size()); + BOOST_CHECK_EQUAL(hg1.edges_.size(), hg2.edges_.size()); + HG::Union(hg1, &hg2); // this should be a no-op + BOOST_CHECK_EQUAL(hg1.nodes_.size(), hg2.nodes_.size()); + BOOST_CHECK_EQUAL(hg1.edges_.size(), hg2.edges_.size()); + cerr << "DONE UNION\n"; } BOOST_AUTO_TEST_CASE(ControlledKBest) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG(&hg); + CreateHG(path, &hg); vector<double> w(2); w[0]=0.4; w[1]=0.8; hg.Reweight(w); vector<WordID> trans; @@ -107,10 +128,11 @@ BOOST_AUTO_TEST_CASE(ControlledKBest) { BOOST_AUTO_TEST_CASE(InsideScore) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); SparseVector<double> wts; wts.set_value(FD::Convert("f1"), 1.0); Hypergraph hg; - CreateTinyLatticeHG(&hg); + CreateTinyLatticeHG(path, &hg); hg.Reweight(wts); vector<WordID> trans; prob_t cost = ViterbiESentence(hg, &trans); @@ -130,10 +152,11 @@ BOOST_AUTO_TEST_CASE(InsideScore) { BOOST_AUTO_TEST_CASE(PruneInsideOutside) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); SparseVector<double> wts; wts.set_value(FD::Convert("Feature_1"), 1.0); Hypergraph hg; - CreateLatticeHG(&hg); + CreateLatticeHG(path, &hg); hg.Reweight(wts); vector<WordID> trans; prob_t cost = ViterbiESentence(hg, &trans); @@ -152,8 +175,9 @@ BOOST_AUTO_TEST_CASE(PruneInsideOutside) { } BOOST_AUTO_TEST_CASE(TestPruneEdges) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateLatticeHG(&hg); + CreateLatticeHG(path, &hg); SparseVector<double> wts; wts.set_value(FD::Convert("f1"), 1.0); hg.Reweight(wts); @@ -166,8 +190,9 @@ BOOST_AUTO_TEST_CASE(TestPruneEdges) { } BOOST_AUTO_TEST_CASE(TestIntersect) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG_int(&hg); + CreateHG_int(path, &hg); SparseVector<double> wts; wts.set_value(FD::Convert("f1"), 1.0); hg.Reweight(wts); @@ -192,8 +217,9 @@ BOOST_AUTO_TEST_CASE(TestIntersect) { } BOOST_AUTO_TEST_CASE(TestPrune2) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG_int(&hg); + CreateHG_int(path, &hg); SparseVector<double> wts; wts.set_value(FD::Convert("f1"), 1.0); hg.Reweight(wts); @@ -207,8 +233,9 @@ BOOST_AUTO_TEST_CASE(TestPrune2) { } BOOST_AUTO_TEST_CASE(Sample) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateLatticeHG(&hg); + CreateLatticeHG(path, &hg); SparseVector<double> wts; wts.set_value(FD::Convert("Feature_1"), 0.0); hg.Reweight(wts); @@ -220,6 +247,7 @@ BOOST_AUTO_TEST_CASE(Sample) { } BOOST_AUTO_TEST_CASE(PLF) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; string inplf = "((('haupt',-2.06655,1),('hauptgrund',-5.71033,2),),(('grund',-1.78709,1),),(('für\\'',0.1,1),),)"; HypergraphIO::ReadFromPLF(inplf, &hg); @@ -234,8 +262,9 @@ BOOST_AUTO_TEST_CASE(PLF) { } BOOST_AUTO_TEST_CASE(PushWeightsToGoal) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG(&hg); + CreateHG(path, &hg); vector<double> w(2); w[0]=0.4; w[1]=0.8; hg.Reweight(w); vector<WordID> trans; @@ -248,8 +277,9 @@ BOOST_AUTO_TEST_CASE(PushWeightsToGoal) { } BOOST_AUTO_TEST_CASE(TestSpecialKBest) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHGBalanced(&hg); + CreateHGBalanced(path, &hg); vector<double> w(1); w[0]=0; hg.Reweight(w); vector<pair<vector<WordID>, prob_t> > list; @@ -264,8 +294,9 @@ BOOST_AUTO_TEST_CASE(TestSpecialKBest) { } BOOST_AUTO_TEST_CASE(TestGenericViterbi) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG_tiny(&hg); + CreateHG_tiny(path, &hg); SparseVector<double> wts; wts.set_value(FD::Convert("f1"), 0.4); wts.set_value(FD::Convert("f2"), 0.8); @@ -279,8 +310,9 @@ BOOST_AUTO_TEST_CASE(TestGenericViterbi) { } BOOST_AUTO_TEST_CASE(TestGenericInside) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateTinyLatticeHG(&hg); + CreateTinyLatticeHG(path, &hg); SparseVector<double> wts; wts.set_value(FD::Convert("f1"), 1.0); hg.Reweight(wts); @@ -296,8 +328,9 @@ BOOST_AUTO_TEST_CASE(TestGenericInside) { } BOOST_AUTO_TEST_CASE(TestGenericInside2) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG(&hg); + CreateHG(path, &hg); SparseVector<double> wts; wts.set_value(FD::Convert("f1"), 0.4); wts.set_value(FD::Convert("f2"), 0.8); @@ -322,8 +355,9 @@ BOOST_AUTO_TEST_CASE(TestGenericInside2) { } BOOST_AUTO_TEST_CASE(TestAddExpectations) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG(&hg); + CreateHG(path, &hg); SparseVector<double> wts; wts.set_value(FD::Convert("f1"), 0.4); wts.set_value(FD::Convert("f2"), 0.8); @@ -338,8 +372,8 @@ BOOST_AUTO_TEST_CASE(TestAddExpectations) { } BOOST_AUTO_TEST_CASE(Small) { - Hypergraph hg; std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); + Hypergraph hg; CreateSmallHG(&hg, path); SparseVector<double> wts; wts.set_value(FD::Convert("Model_0"), -2.0); @@ -361,6 +395,7 @@ BOOST_AUTO_TEST_CASE(Small) { } BOOST_AUTO_TEST_CASE(JSONTest) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); ostringstream os; JSONParser::WriteEscapedString("\"I don't know\", she said.", &os); BOOST_CHECK_EQUAL("\"\\\"I don't know\\\", she said.\"", os.str()); @@ -370,9 +405,10 @@ BOOST_AUTO_TEST_CASE(JSONTest) { } BOOST_AUTO_TEST_CASE(TestGenericKBest) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - CreateHG(&hg); - //CreateHGBalanced(&hg); + CreateHG(path, &hg); + //CreateHGBalanced(path, &hg); SparseVector<double> wts; wts.set_value(FD::Convert("f1"), 0.4); wts.set_value(FD::Convert("f2"), 1.0); @@ -392,8 +428,9 @@ BOOST_AUTO_TEST_CASE(TestGenericKBest) { } BOOST_AUTO_TEST_CASE(TestReadWriteHG) { + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg,hg2; - CreateHG(&hg); + CreateHG(path, &hg); hg.edges_.front().j_ = 23; hg.edges_.back().prev_i_ = 99; ostringstream os; diff --git a/decoder/hg_test.h b/decoder/hg_test.h index e96cb0b1..b7bab3c2 100644 --- a/decoder/hg_test.h +++ b/decoder/hg_test.h @@ -23,25 +23,13 @@ Name perro_wts="SameFirstLetter 1 LongerThanPrev 1 ShorterThanPrev 1 GlueTop 0.0 // you can inherit from this or just use the static methods struct HGSetup { - enum { - HG, - HG_int, - HG_tiny, - HGBalanced, - LatticeHG, - TinyLatticeHG, - }; - static void CreateHG(Hypergraph* hg); - static void CreateHG_int(Hypergraph* hg); - static void CreateHG_tiny(Hypergraph* hg); - static void CreateHGBalanced(Hypergraph* hg); - static void CreateLatticeHG(Hypergraph* hg); - static void CreateTinyLatticeHG(Hypergraph* hg); - - static void Json(Hypergraph *hg,std::string const& json) { - std::istringstream i(json); - HypergraphIO::ReadFromJSON(&i, hg); - } + static void CreateHG(const std::string& path,Hypergraph* hg); + static void CreateHG_int(const std::string& path,Hypergraph* hg); + static void CreateHG_tiny(const std::string& path, Hypergraph* hg); + static void CreateHGBalanced(const std::string& path,Hypergraph* hg); + static void CreateLatticeHG(const std::string& path,Hypergraph* hg); + static void CreateTinyLatticeHG(const std::string& path,Hypergraph* hg); + static void JsonFile(Hypergraph *hg,std::string f) { ReadFile rf(f); HypergraphIO::ReadFromJSON(rf.stream(), hg); @@ -52,18 +40,6 @@ struct HGSetup { static void CreateSmallHG(Hypergraph *hg, std::string path) { JsonTestFile(hg,path,small_json); } }; -namespace { -Name HGjsons[]= { - "{\"rules\":[1,\"[X] ||| a\",2,\"[X] ||| A [1]\",3,\"[X] ||| c\",4,\"[X] ||| C [1]\",5,\"[X] ||| [1] B [2]\",6,\"[X] ||| [1] b [2]\",7,\"[X] ||| X [1]\",8,\"[X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.8,1,-0.1],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[1,-1],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[2],\"feats\":[0,-0.2,1,-0.1],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[1,3],\"feats\":[0,-1.2,1,-0.2],\"rule\":5},{\"tail\":[1,3],\"feats\":[0,-0.5,1,-1.3],\"rule\":6}],\"node\":{\"in_edges\":[4,5]},\"edges\":[{\"tail\":[4],\"feats\":[0,-0.5,1,-0.8],\"rule\":7},{\"tail\":[4],\"feats\":[0,-0.7,1,-0.9],\"rule\":8}],\"node\":{\"in_edges\":[6,7]}}", -"{\"rules\":[1,\"[X] ||| a\",2,\"[X] ||| b\",3,\"[X] ||| a [1]\",4,\"[X] ||| [1] b\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[0,0.1],\"rule\":1},{\"tail\":[],\"feats\":[0,0.1],\"rule\":2}],\"node\":{\"in_edges\":[0,1],\"cat\":\"X\"},\"edges\":[{\"tail\":[0],\"feats\":[0,0.3],\"rule\":3},{\"tail\":[0],\"feats\":[0,0.2],\"rule\":4}],\"node\":{\"in_edges\":[2,3],\"cat\":\"Goal\"}}", - "{\"rules\":[1,\"[X] ||| <s>\",2,\"[X] ||| X [1]\",3,\"[X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[0,-2,1,-99],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.5,1,-0.8],\"rule\":2},{\"tail\":[0],\"feats\":[0,-0.7,1,-0.9],\"rule\":3}],\"node\":{\"in_edges\":[1,2]}}", - "{\"rules\":[1,\"[X] ||| i\",2,\"[X] ||| a\",3,\"[X] ||| b\",4,\"[X] ||| [1] [2]\",5,\"[X] ||| [1] [2]\",6,\"[X] ||| c\",7,\"[X] ||| d\",8,\"[X] ||| [1] [2]\",9,\"[X] ||| [1] [2]\",10,\"[X] ||| [1] [2]\",11,\"[X] ||| [1] [2]\",12,\"[X] ||| [1] [2]\",13,\"[X] ||| [1] [2]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[1,2],\"feats\":[],\"rule\":4},{\"tail\":[2,1],\"feats\":[],\"rule\":5}],\"node\":{\"in_edges\":[3,4]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":6}],\"node\":{\"in_edges\":[5]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":7}],\"node\":{\"in_edges\":[6]},\"edges\":[{\"tail\":[4,5],\"feats\":[],\"rule\":8},{\"tail\":[5,4],\"feats\":[],\"rule\":9}],\"node\":{\"in_edges\":[7,8]},\"edges\":[{\"tail\":[3,6],\"feats\":[],\"rule\":10},{\"tail\":[6,3],\"feats\":[],\"rule\":11}],\"node\":{\"in_edges\":[9,10]},\"edges\":[{\"tail\":[7,0],\"feats\":[],\"rule\":12},{\"tail\":[0,7],\"feats\":[],\"rule\":13}],\"node\":{\"in_edges\":[11,12]}}", - "{\"rules\":[1,\"[X] ||| [1] a\",2,\"[X] ||| [1] A\",3,\"[X] ||| [1] A A\",4,\"[X] ||| [1] b\",5,\"[X] ||| [1] c\",6,\"[X] ||| [1] B C\",7,\"[X] ||| [1] A B C\",8,\"[X] ||| [1] CC\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[],\"node\":{\"in_edges\":[]},\"edges\":[{\"tail\":[0],\"feats\":[2,-0.3],\"rule\":1},{\"tail\":[0],\"feats\":[2,-0.6],\"rule\":2},{\"tail\":[0],\"feats\":[2,-1.7],\"rule\":3}],\"node\":{\"in_edges\":[0,1,2]},\"edges\":[{\"tail\":[1],\"feats\":[2,-0.5],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[2],\"feats\":[2,-0.6],\"rule\":5},{\"tail\":[1],\"feats\":[2,-0.8],\"rule\":6},{\"tail\":[0],\"feats\":[2,-0.01],\"rule\":7},{\"tail\":[2],\"feats\":[2,-0.8],\"rule\":8}],\"node\":{\"in_edges\":[4,5,6,7]}}", - "{\"rules\":[1,\"[X] ||| [1] a\",2,\"[X] ||| [1] A\",3,\"[X] ||| [1] b\",4,\"[X] ||| [1] B'\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[],\"node\":{\"in_edges\":[]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.2],\"rule\":1},{\"tail\":[0],\"feats\":[0,-0.6],\"rule\":2}],\"node\":{\"in_edges\":[0,1]},\"edges\":[{\"tail\":[1],\"feats\":[0,-0.1],\"rule\":3},{\"tail\":[1],\"feats\":[0,-0.9],\"rule\":4}],\"node\":{\"in_edges\":[2,3]}}", -}; - -} - void AddNullEdge(Hypergraph* hg) { TRule x; x.arity_ = 0; @@ -71,31 +47,36 @@ void AddNullEdge(Hypergraph* hg) { hg->edges_.back().head_node_ = 0; } -void HGSetup::CreateTinyLatticeHG(Hypergraph* hg) { - Json(hg,HGjsons[TinyLatticeHG]); +void HGSetup::CreateTinyLatticeHG(const std::string& path,Hypergraph* hg) { + ReadFile rf(path + "/hg_test.tiny_lattice"); + HypergraphIO::ReadFromJSON(rf.stream(), hg); AddNullEdge(hg); } -void HGSetup::CreateLatticeHG(Hypergraph* hg) { - Json(hg,HGjsons[LatticeHG]); +void HGSetup::CreateLatticeHG(const std::string& path,Hypergraph* hg) { + ReadFile rf(path + "/hg_test.lattice"); + HypergraphIO::ReadFromJSON(rf.stream(), hg); AddNullEdge(hg); } -void HGSetup::CreateHG_tiny(Hypergraph* hg) { - Json(hg,HGjsons[HG_tiny]); +void HGSetup::CreateHG_tiny(const std::string& path, Hypergraph* hg) { + ReadFile rf(path + "/hg_test.tiny"); + HypergraphIO::ReadFromJSON(rf.stream(), hg); } -void HGSetup::CreateHG_int(Hypergraph* hg) { - Json(hg,HGjsons[HG_int]); +void HGSetup::CreateHG_int(const std::string& path,Hypergraph* hg) { + ReadFile rf(path + "/hg_test.hg_int"); + HypergraphIO::ReadFromJSON(rf.stream(), hg); } -void HGSetup::CreateHG(Hypergraph* hg) { - Json(hg,HGjsons[HG]); +void HGSetup::CreateHG(const std::string& path,Hypergraph* hg) { + ReadFile rf(path + "/hg_test.hg"); + HypergraphIO::ReadFromJSON(rf.stream(), hg); } -void HGSetup::CreateHGBalanced(Hypergraph* hg) { - Json(hg,HGjsons[HGBalanced]); +void HGSetup::CreateHGBalanced(const std::string& path,Hypergraph* hg) { + ReadFile rf(path + "/hg_test.hg_balanced"); + HypergraphIO::ReadFromJSON(rf.stream(), hg); } - #endif diff --git a/decoder/hg_union.cc b/decoder/hg_union.cc index 37082976..a659b6bc 100644 --- a/decoder/hg_union.cc +++ b/decoder/hg_union.cc @@ -1,56 +1,104 @@ #include "hg_union.h" +#ifndef HAVE_OLD_CPP +# include <unordered_map> +#else +# include <tr1/unordered_map> +namespace std { using std::tr1::unordered_set; } +#endif + +#include "verbose.h" #include "hg.h" +#include "sparse_vector.h" using namespace std; namespace HG { +static bool EdgesMatch(const HG::Edge& a, const Hypergraph& ahg, const HG::Edge& b, const Hypergraph& bhg) { + const unsigned arity = a.tail_nodes_.size(); + if (arity != b.tail_nodes_.size()) return false; + if (a.rule_->e() != b.rule_->e()) return false; + if (a.rule_->f() != b.rule_->f()) return false; + + for (unsigned i = 0; i < arity; ++i) + if (ahg.nodes_[a.tail_nodes_[i]].node_hash != bhg.nodes_[b.tail_nodes_[i]].node_hash) return false; + const SparseVector<double> diff = a.feature_values_ - b.feature_values_; + for (auto& kv : diff) + if (fabs(kv.second) > 1e-6) return false; + return true; +} + void Union(const Hypergraph& in, Hypergraph* out) { if (&in == out) return; if (out->nodes_.empty()) { out->nodes_ = in.nodes_; out->edges_ = in.edges_; return; } - unsigned noff = out->nodes_.size(); - unsigned eoff = out->edges_.size(); - int ogoal = in.nodes_.size() - 1; - int cgoal = noff - 1; - // keep a single goal node, so add nodes.size - 1 - out->nodes_.resize(out->nodes_.size() + ogoal); - // add all edges - out->edges_.resize(out->edges_.size() + in.edges_.size()); - - for (int i = 0; i < ogoal; ++i) { - const Hypergraph::Node& on = in.nodes_[i]; - Hypergraph::Node& cn = out->nodes_[i + noff]; - cn.id_ = i + noff; - cn.in_edges_.resize(on.in_edges_.size()); - for (unsigned j = 0; j < on.in_edges_.size(); ++j) - cn.in_edges_[j] = on.in_edges_[j] + eoff; - - cn.out_edges_.resize(on.out_edges_.size()); - for (unsigned j = 0; j < on.out_edges_.size(); ++j) - cn.out_edges_[j] = on.out_edges_[j] + eoff; + if (!in.AreNodesUniquelyIdentified()) { + cerr << "Union: Nodes are not uniquely identified in input!\n"; + abort(); + } + if (!out->AreNodesUniquelyIdentified()) { + cerr << "Union: Nodes are not uniquely identified in output!\n"; + abort(); } + if (out->nodes_.back().node_hash != in.nodes_.back().node_hash) { + cerr << "Union: Goal nodes are mismatched!\n a=" << in.nodes_.back().node_hash << " b=" << out->nodes_.back().node_hash << "\n"; + abort(); + } + const int cgoal = out->nodes_.back().id_; - for (unsigned i = 0; i < in.edges_.size(); ++i) { - const Hypergraph::Edge& oe = in.edges_[i]; - Hypergraph::Edge& ce = out->edges_[i + eoff]; - ce.id_ = i + eoff; - ce.rule_ = oe.rule_; - ce.feature_values_ = oe.feature_values_; - if (oe.head_node_ == ogoal) { - ce.head_node_ = cgoal; - out->nodes_[cgoal].in_edges_.push_back(ce.id_); - } else { - ce.head_node_ = oe.head_node_ + noff; + unordered_map<size_t, unsigned> h2n; + for (const auto& node : out->nodes_) + h2n[node.node_hash] = node.id_; + for (const auto& node : in.nodes_) { + if (h2n.count(node.node_hash) == 0) { + HG::Node* new_node = out->AddNode(node.cat_); + new_node->node_hash = node.node_hash; + h2n[node.node_hash] = new_node->id_; } - ce.tail_nodes_.resize(oe.tail_nodes_.size()); - for (unsigned j = 0; j < oe.tail_nodes_.size(); ++j) - ce.tail_nodes_[j] = oe.tail_nodes_[j] + noff; } + double n_exists = 0; + double n_created = 0; + for (const auto& in_node : in.nodes_) { + HG::Node& out_node = out->nodes_[h2n[in_node.node_hash]]; + for (const auto oeid : out_node.in_edges_) { + // TODO hash currently existing edges for quick check for duplication + } + for (const auto ieid : in_node.in_edges_) { + const HG::Edge& in_edge = in.edges_[ieid]; + // TODO: replace slow N^2 check with hashing + bool edge_exists = false; + for (const auto oeid : out_node.in_edges_) { + if (EdgesMatch(in_edge, in, out->edges_[oeid], *out)) { + edge_exists = true; + break; + } + } + if (!edge_exists) { + const unsigned arity = in_edge.tail_nodes_.size(); + TailNodeVector t(arity); + HG::Node& head = out->nodes_[h2n[in_node.node_hash]]; + for (unsigned i = 0; i < arity; ++i) + t[i] = h2n[in.nodes_[in_edge.tail_nodes_[i]].node_hash]; + HG::Edge* new_edge = out->AddEdge(in_edge, t); + out->ConnectEdgeToHeadNode(new_edge, &head); + ++n_created; + //cerr << "Created: " << new_edge->rule_->AsString() << " [head=" << new_edge->head_node_ << "]\n"; + } else { + ++n_exists; + } + // cerr << "Not created: " << in.edges_[ieid].rule_->AsString() << "\n"; + //} + } + } + if (!SILENT) + cerr << " Union: edges_created=" << n_created + << " edges_already_existing=" + << n_exists << " ratio_new=" << (n_created / (n_exists + n_created)) + << endl; out->TopologicallySortNodesAndEdges(cgoal); } diff --git a/decoder/lexalign.cc b/decoder/lexalign.cc index 6adb1892..11f20de7 100644 --- a/decoder/lexalign.cc +++ b/decoder/lexalign.cc @@ -124,6 +124,11 @@ bool LexicalAlign::TranslateImpl(const string& input, pimpl_->BuildTrellis(lattice, *smeta, forest); forest->is_linear_chain_ = true; forest->Reweight(weights); + // since we don't do any pruning, the node_hash will be the same for + // every run of the composer + int nc = 0; + for (auto& node : forest->nodes_) + node.node_hash = ++nc; return true; } diff --git a/decoder/lextrans.cc b/decoder/lextrans.cc index 8c3269bf..74a18c3f 100644 --- a/decoder/lextrans.cc +++ b/decoder/lextrans.cc @@ -280,6 +280,11 @@ bool LexicalTrans::TranslateImpl(const string& input, smeta->SetSourceLength(lattice.size()); if (!pimpl_->BuildTrellis(lattice, *smeta, forest)) return false; forest->Reweight(weights); + // since we don't do any pruning, the node_hash will be the same for + // every run of the composer + int nc = 0; + for (auto& node : forest->nodes_) + node.node_hash = ++nc; return true; } diff --git a/decoder/node_state_hash.h b/decoder/node_state_hash.h new file mode 100644 index 00000000..9fc01a09 --- /dev/null +++ b/decoder/node_state_hash.h @@ -0,0 +1,42 @@ +#ifndef _NODE_STATE_HASH_ +#define _NODE_STATE_HASH_ + +#include <cassert> +#include <cstring> +#include "tdict.h" +#include "murmur_hash3.h" +#include "ffset.h" + +namespace cdec { + + struct FirstPassNode { + FirstPassNode(int cat, int i, int j, int pi, int pj) : s(i), t(j), u(pi), v(pj) { + memset(lhs, 0, 120); + unsigned it = 0; + for (auto& c : TD::Convert(-cat)) { lhs[it++] = c; if (it == 120) break; } + } + char lhs[120]; + short s; + short t; + short u; + short v; + }; + + inline uint64_t HashNode(int cat, int i, int j, int pi, int pj) { + FirstPassNode fpn(cat, i, j, pi, pj); + return MurmurHash3_64(&fpn, sizeof(FirstPassNode), 2654435769U); + } + + inline uint64_t HashNode(uint64_t old_hash, const FFState& state) { + if (state.size() == 0) return old_hash; + uint8_t buf[1024]; + std::memcpy(buf, &old_hash, sizeof(uint64_t)); + assert(state.size() < (1024u - sizeof(uint64_t))); + std::memcpy(&buf[sizeof(uint64_t)], state.begin(), state.size()); + return MurmurHash3_64(buf, sizeof(uint64_t) + state.size(), 2654435769U); + } + +} + +#endif + diff --git a/decoder/nt_span.h b/decoder/nt_span.h index a918f301..6ff9391f 100644 --- a/decoder/nt_span.h +++ b/decoder/nt_span.h @@ -7,7 +7,7 @@ struct Span { int l,r; - Span() : l(-1) { } + Span() : l(-1), r(-1) { } bool is_null() const { return l<0; } void print(std::ostream &o,char const* for_null="") const { if (is_null()) diff --git a/decoder/rule_lexer.h b/decoder/rule_lexer.h index f844e5b2..e15c056d 100644 --- a/decoder/rule_lexer.h +++ b/decoder/rule_lexer.h @@ -9,6 +9,7 @@ struct RuleLexer { typedef void (*RuleCallback)(const TRulePtr& new_rule, const unsigned int ctf_level, const TRulePtr& coarse_rule, void* extra); static void ReadRules(std::istream* in, RuleCallback func, const std::string& fname, void* extra); + static void ReadRule(const std::string&, RuleCallback func, bool mono_rule, void* extra); }; #endif diff --git a/decoder/rule_lexer.ll b/decoder/rule_lexer.ll index 05963d05..d4a8d86b 100644 --- a/decoder/rule_lexer.ll +++ b/decoder/rule_lexer.ll @@ -14,6 +14,7 @@ #include "verbose.h" #include "tree_fragment.h" +bool lex_mono_rules = false; int lex_line = 0; std::istream* scfglex_stream = NULL; RuleLexer::RuleCallback rule_callback = NULL; @@ -119,8 +120,8 @@ void check_and_update_ctf_stack(const TRulePtr& rp) { %} -REAL [\-+]?[0-9]+(\.[0-9]*([eE][-+]*[0-9]+)?)?|inf|[\-+]inf -NT [^\t \[\],]+ +REAL [\-+]?[0-9]+(\.[0-9]*)?([eE][-+]*[0-9]+)? +NT ([^\t \n\r\[\],]+|Goal) %x LHS_END SRC TRG FEATS FEATVAL ALIGNS TREE %% @@ -132,7 +133,7 @@ NT [^\t \[\],]+ <INITIAL>\[{NT}\] { scfglex_tmp_token.assign(yytext + 1, yyleng - 2); scfglex_lhs = -TD::Convert(scfglex_tmp_token); - // std::cerr << scfglex_tmp_token << "\n"; + //std::cerr << "LHS: " << scfglex_tmp_token << "\n"; BEGIN(LHS_END); } @@ -199,9 +200,9 @@ NT [^\t \[\],]+ <SRC>\|\|\| { memset(scfglex_nt_sanity, 0, scfglex_src_arity * sizeof(int)); - BEGIN(TRG); + if (lex_mono_rules) { BEGIN(FEATS); } else { BEGIN(TRG); } } -<SRC>[^ \t]+ { +<SRC>[^ \t\n\r]+ { scfglex_tmp_token.assign(yytext, yyleng); scfglex_src_rhs[scfglex_src_rhs_size] = TD::Convert(scfglex_tmp_token); ++scfglex_src_rhs_size; @@ -217,14 +218,28 @@ NT [^\t \[\],]+ <TRG>\|\|\| { BEGIN(FEATS); } -<TRG>[^ \t]+ { +<TRG>[^ \t\n\r]+ { scfglex_tmp_token.assign(yytext, yyleng); scfglex_trg_rhs[scfglex_trg_rhs_size] = TD::Convert(scfglex_tmp_token); ++scfglex_trg_rhs_size; } <TRG>[ \t]+ { ; } -<TRG,FEATS,ALIGNS,TREE>\n { +<SRC,TRG,FEATS,ALIGNS,TREE>\n { + if (lex_mono_rules) { + if (scfglex_trg_rhs_size != 0) { + std::cerr << "Grammar " << scfglex_fname << " line " << lex_line << ": expected monolingual rule\n"; + abort(); + } + scfglex_trg_arity = scfglex_src_arity; + scfglex_trg_rhs_size = scfglex_src_rhs_size; + int ntc = 0; + for (int i = 0; i < scfglex_src_rhs_size; ++i) + if (scfglex_trg_rhs[i] <= 0) + scfglex_trg_rhs[i] = ntc--; + else + scfglex_trg_rhs[i] = scfglex_src_rhs[i]; + } if (scfglex_src_arity != scfglex_trg_arity) { std::cerr << "Grammar " << scfglex_fname << " line " << lex_line << ": LHS and RHS arity mismatch!\n"; abort(); @@ -243,7 +258,7 @@ NT [^\t \[\],]+ TRulePtr coarse_rp = ((ctf_level == 0) ? TRulePtr() : ctf_rule_stack.top()); rule_callback(rp, ctf_level, coarse_rp, rule_callback_extra); ctf_rule_stack.push(rp); - // std::cerr << rp->AsString() << std::endl; + //std::cerr << "RULE: " << rp->AsString() << std::endl; num_rules++; lex_line++; if (!SILENT) { @@ -317,7 +332,7 @@ NT [^\t \[\],]+ #include "filelib.h" -void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, const std::string& fname, void* extra) { +static void init_default_feature_names() { if (scfglex_phrase_fnames.empty()) { scfglex_phrase_fnames.resize(100); for (int i = 0; i < scfglex_phrase_fnames.size(); ++i) { @@ -326,6 +341,11 @@ void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, const scfglex_phrase_fnames[i] = FD::Convert(os.str()); } } +} + +void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, const std::string& fname, void* extra) { + init_default_feature_names(); + lex_mono_rules = false; lex_line = 1; scfglex_fname = fname; scfglex_stream = in; @@ -334,3 +354,14 @@ void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, const yylex(); } +void RuleLexer::ReadRule(const std::string& srule, RuleCallback func, bool mono, void* extra) { + init_default_feature_names(); + lex_mono_rules = mono; + lex_line = 1; + rule_callback_extra = extra; + rule_callback = func; + yy_scan_string(srule.c_str()); + yylex(); + yylex_destroy(); +} + diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc index 159a1d60..88f62769 100644 --- a/decoder/scfg_translator.cc +++ b/decoder/scfg_translator.cc @@ -47,7 +47,7 @@ GlueGrammar::GlueGrammar(const string& goal_nt, const string& default_nt, const TRulePtr stop_glue(new TRule("[" + goal_nt + "] ||| [" + default_nt + ",1] ||| [1]")); AddRule(stop_glue); RefineRule(stop_glue, ctf_level); - TRulePtr glue(new TRule("[" + goal_nt + "] ||| [" + goal_nt + ",1] ["+ default_nt + ",2] ||| [1] [2] ||| Glue=1")); + TRulePtr glue(new TRule("[" + goal_nt + "] ||| [" + goal_nt + "] ["+ default_nt + "] ||| [1] [2] ||| Glue=1")); AddRule(glue); RefineRule(glue, ctf_level); } diff --git a/decoder/t2s_test.cc b/decoder/t2s_test.cc new file mode 100644 index 00000000..5ebb2662 --- /dev/null +++ b/decoder/t2s_test.cc @@ -0,0 +1,114 @@ +#include "tree_fragment.h" + +#define BOOST_TEST_MODULE T2STest +#include <boost/test/unit_test.hpp> +#include <boost/test/floating_point_comparison.hpp> +#include <iostream> +#include "tdict.h" + +using namespace std; + +BOOST_AUTO_TEST_CASE(TestTreeFragments) { + cdec::TreeFragment tree("(S (NP (DT the) (NN boy)) (VP (V saw) (NP (DT a) (NN cat))))"); + cdec::TreeFragment tree2("(S (NP (DT a) (NN cat)) (VP (V ate) (NP (DT the) (NN cake pie))))"); + vector<unsigned> a, b; + vector<WordID> aw, bw; + cerr << "TREE1: " << tree << endl; + cerr << "TREE2: " << tree2 << endl; + for (auto& sym : tree) { + if (cdec::IsLHS(sym)) cerr << "("; + cerr << TD::Convert(sym & cdec::ALL_MASK) << endl; + if (cdec::IsTerminal(sym)) aw.push_back(sym); else a.push_back(sym); + } + for (auto& sym : tree2) + if (cdec::IsTerminal(sym)) bw.push_back(sym); else b.push_back(sym); + BOOST_CHECK_EQUAL(a.size(), b.size()); + BOOST_CHECK_EQUAL(aw.size() + 1, bw.size()); + BOOST_CHECK_EQUAL(aw.size(), 5); + BOOST_CHECK_EQUAL(TD::GetString(aw), "the boy saw a cat"); + BOOST_CHECK_EQUAL(TD::GetString(bw), "a cat ate the cake pie"); + if (a != b) { + BOOST_CHECK_EQUAL(1,2); + } + + string nts; + for (cdec::TreeFragment::iterator it = tree.begin(); it != tree.end(); ++it) { + if (cdec::IsNT(*it)) { + if (cdec::IsRHS(*it)) it.truncate(); + if (nts.size()) nts += " "; + if (cdec::IsLHS(*it)) nts += "("; + nts += TD::Convert(*it & cdec::ALL_MASK); + if (cdec::IsFrontier(*it)) nts += "*"; + } + } + cerr << "Truncated: " << nts << endl; + BOOST_CHECK_EQUAL(nts, "(S NP* VP*"); + + nts.clear(); + int ntc = 0; + for (auto it = tree.bfs_begin(); it != tree.bfs_end(); ++it) { + if (cdec::IsNT(*it)) { + if (cdec::IsRHS(*it)) { + ++ntc; + if (ntc > 1) it.truncate(); + } + if (nts.size()) nts += " "; + if (cdec::IsLHS(*it)) nts += "("; + nts += TD::Convert(*it & cdec::ALL_MASK); + if (cdec::IsFrontier(*it)) nts += "*"; + } + } + BOOST_CHECK_EQUAL(nts, "(S NP VP* (NP DT* NN*"); +} + +BOOST_AUTO_TEST_CASE(TestSharing) { + cdec::TreeFragment rule1("(S [NP] [VP])", true); + cdec::TreeFragment rule2("(S [NP] (VP [V] [NP]))", true); + string r1,r2; + for (auto sym : rule1) { + if (r1.size()) r1 += " "; + if (cdec::IsLHS(sym)) r1 += "("; + r1 += TD::Convert(sym & cdec::ALL_MASK); + if (cdec::IsFrontier(sym)) r1 += "*"; + } + for (auto sym : rule2) { + if (r2.size()) r2 += " "; + if (cdec::IsLHS(sym)) r2 += "("; + r2 += TD::Convert(sym & cdec::ALL_MASK); + if (cdec::IsFrontier(sym)) r2 += "*"; + } + cerr << rule1 << endl; + cerr << r1 << endl; + cerr << rule2 << endl; + cerr << r2 << endl; + BOOST_CHECK_EQUAL(r1, "(S NP* VP*"); + BOOST_CHECK_EQUAL(r2, "(S NP* VP (VP V* NP*"); +} + +BOOST_AUTO_TEST_CASE(TestEndInvariants) { + cdec::TreeFragment tree("(S (NP (DT the) (NN boy)) (VP (V saw) (NP (DT a) (NN cat))))"); + BOOST_CHECK(tree.end().at_end()); + BOOST_CHECK(!tree.begin().at_end()); +} + +BOOST_AUTO_TEST_CASE(TestBegins) { + cdec::TreeFragment tree("(S (NP (DT the) (NN boy)) (VP (V saw) (NP (DT a) (NN cat))))"); + for (auto it = tree.begin(1); it != tree.end(); ++it) { + cerr << TD::Convert(*it & cdec::ALL_MASK) << endl; + } +} + +BOOST_AUTO_TEST_CASE(TestRemainder) { + cdec::TreeFragment tree("(S (A a) (B b))"); + auto it = tree.begin(); + ++it; + BOOST_CHECK(cdec::IsRHS(*it)); + cerr << tree << endl; + auto itr = it.remainder(); + while(itr != tree.end()) { + cerr << TD::Convert(*itr & cdec::ALL_MASK) << endl; + ++itr; + } +} + + diff --git a/decoder/tagger.cc b/decoder/tagger.cc index 63e855c8..30fb055f 100644 --- a/decoder/tagger.cc +++ b/decoder/tagger.cc @@ -108,6 +108,11 @@ bool Tagger::TranslateImpl(const string& input, pimpl_->BuildTrellis(sequence, forest); forest->Reweight(weights); forest->is_linear_chain_ = true; + // since we don't do any pruning, the node_hash will be the same for + // every run of the composer + int nc = 0; + for (auto& node : forest->nodes_) + node.node_hash = ++nc; return true; } diff --git a/decoder/test_data/hg_test.hg b/decoder/test_data/hg_test.hg new file mode 100644 index 00000000..ef98e9d4 --- /dev/null +++ b/decoder/test_data/hg_test.hg @@ -0,0 +1 @@ +{"rules":[1,"[X] ||| a ||| a",2,"[X] ||| A [X] ||| A [1]",3,"[X] ||| c ||| c",4,"[X] ||| C [X] ||| C [1]",5,"[X] ||| [X] B [X] ||| [1] B [2]",6,"[X] ||| [X] b [X] ||| [1] b [2]",7,"[X] ||| X [X] ||| X [1]",8,"[X] ||| Z [X] ||| Z [1]"],"features":["f1","f2","Feature_1","Feature_0","Model_0","Model_1","Model_2","Model_3","Model_4","Model_5","Model_6","Model_7","LatticeCost"],"edges":[{"tail":[],"spans":[24568,32767,24568,32767],"feats":[],"rule":1}],"node":{"in_edges":[0],"cat":"X"},"edges":[{"tail":[0],"spans":[24568,32767,24568,32767],"feats":[0,-0.8,1,-0.1],"rule":2}],"node":{"in_edges":[1],"cat":"X"},"edges":[{"tail":[],"spans":[24568,32767,24568,32767],"feats":[1,-1],"rule":3}],"node":{"in_edges":[2],"cat":"X"},"edges":[{"tail":[2],"spans":[24568,32767,24568,32767],"feats":[0,-0.2,1,-0.1],"rule":4}],"node":{"in_edges":[3],"cat":"X"},"edges":[{"tail":[1,3],"spans":[24568,32767,24568,32767],"feats":[0,-1.2,1,-0.2],"rule":5},{"tail":[1,3],"spans":[24568,32767,24568,32767],"feats":[0,-0.5,1,-1.3],"rule":6}],"node":{"in_edges":[4,5],"cat":"X"},"edges":[{"tail":[4],"spans":[24568,32767,24568,32767],"feats":[0,-0.5,1,-0.8],"rule":7},{"tail":[4],"spans":[24568,32767,24568,32767],"feats":[0,-0.7,1,-0.9],"rule":8}],"node":{"in_edges":[6,7],"cat":"X"}} diff --git a/decoder/test_data/hg_test.hg_balanced b/decoder/test_data/hg_test.hg_balanced new file mode 100644 index 00000000..0f0f499f --- /dev/null +++ b/decoder/test_data/hg_test.hg_balanced @@ -0,0 +1 @@ +{"rules":[1,"[X] ||| i ||| i",2,"[X] ||| a ||| a",3,"[X] ||| b ||| b",4,"[X] ||| [X] [X] ||| [1] [2]",5,"[X] ||| [X] [X] ||| [1] [2]",6,"[X] ||| c ||| c",7,"[X] ||| d ||| d",8,"[X] ||| [X] [X] ||| [1] [2]",9,"[X] ||| [X] [X] ||| [1] [2]",10,"[X] ||| [X] [X] ||| [1] [2]",11,"[X] ||| [X] [X] ||| [1] [2]",12,"[X] ||| [X] [X] ||| [1] [2]",13,"[X] ||| [X] [X] ||| [1] [2]"],"features":["f1","f2","Feature_1","Feature_0","Model_0","Model_1","Model_2","Model_3","Model_4","Model_5","Model_6","Model_7","LatticeCost"],"edges":[{"tail":[],"spans":[32760,32767,32760,32767],"feats":[],"rule":1}],"node":{"in_edges":[0],"cat":"X"},"edges":[{"tail":[],"spans":[32760,32767,32760,32767],"feats":[],"rule":2}],"node":{"in_edges":[1],"cat":"X"},"edges":[{"tail":[],"spans":[32760,32767,32760,32767],"feats":[],"rule":3}],"node":{"in_edges":[2],"cat":"X"},"edges":[{"tail":[1,2],"spans":[32760,32767,32760,32767],"feats":[],"rule":4},{"tail":[2,1],"spans":[32760,32767,32760,32767],"feats":[],"rule":5}],"node":{"in_edges":[3,4],"cat":"X"},"edges":[{"tail":[],"spans":[32760,32767,32760,32767],"feats":[],"rule":6}],"node":{"in_edges":[5],"cat":"X"},"edges":[{"tail":[],"spans":[32760,32767,32760,32767],"feats":[],"rule":7}],"node":{"in_edges":[6],"cat":"X"},"edges":[{"tail":[4,5],"spans":[32760,32767,32760,32767],"feats":[],"rule":8},{"tail":[5,4],"spans":[32760,32767,32760,32767],"feats":[],"rule":9}],"node":{"in_edges":[7,8],"cat":"X"},"edges":[{"tail":[3,6],"spans":[32760,32767,32760,32767],"feats":[],"rule":10},{"tail":[6,3],"spans":[32760,32767,32760,32767],"feats":[],"rule":11}],"node":{"in_edges":[9,10],"cat":"X"},"edges":[{"tail":[7,0],"spans":[32760,32767,32760,32767],"feats":[],"rule":12},{"tail":[0,7],"spans":[32760,32767,32760,32767],"feats":[],"rule":13}],"node":{"in_edges":[11,12],"cat":"X"}} diff --git a/decoder/test_data/hg_test.hg_int b/decoder/test_data/hg_test.hg_int new file mode 100644 index 00000000..9c4603bc --- /dev/null +++ b/decoder/test_data/hg_test.hg_int @@ -0,0 +1 @@ +{"rules":[1,"[X] ||| a ||| a",2,"[X] ||| b ||| b",3,"[X] ||| a [X] ||| a [1]",4,"[X] ||| [X] b ||| [1] b"],"features":["f1","f2","Feature_1","Feature_0","Model_0","Model_1","Model_2","Model_3","Model_4","Model_5","Model_6","Model_7","LatticeCost"],"edges":[{"tail":[],"spans":[-8200,32767,-8200,32767],"feats":[0,0.1],"rule":1},{"tail":[],"spans":[-8200,32767,-8200,32767],"feats":[0,0.1],"rule":2}],"node":{"in_edges":[0,1],"cat":"X"},"edges":[{"tail":[0],"spans":[-8200,32767,-8200,32767],"feats":[0,0.3],"rule":3},{"tail":[0],"spans":[-8200,32767,-8200,32767],"feats":[0,0.2],"rule":4}],"node":{"in_edges":[2,3],"cat":"Goal"}} diff --git a/decoder/test_data/hg_test.lattice b/decoder/test_data/hg_test.lattice new file mode 100644 index 00000000..29e021c5 --- /dev/null +++ b/decoder/test_data/hg_test.lattice @@ -0,0 +1 @@ +{"rules":[1,"[X] ||| [X] a ||| [1] a",2,"[X] ||| [X] A ||| [1] A",3,"[X] ||| [X] A A ||| [1] A A",4,"[X] ||| [X] b ||| [1] b",5,"[X] ||| [X] c ||| [1] c",6,"[X] ||| [X] B C ||| [1] B C",7,"[X] ||| [X] A B C ||| [1] A B C",8,"[X] ||| [X] CC ||| [1] CC"],"features":["f1","f2","Feature_1","Feature_0","Model_0","Model_1","Model_2","Model_3","Model_4","Model_5","Model_6","Model_7"],"edges":[],"node":{"in_edges":[]},"edges":[{"tail":[0],"feats":[2,-0.3],"rule":1},{"tail":[0],"feats":[2,-0.6],"rule":2},{"tail":[0],"feats":[2,-1.7],"rule":3}],"node":{"in_edges":[0,1,2]},"edges":[{"tail":[1],"feats":[2,-0.5],"rule":4}],"node":{"in_edges":[3]},"edges":[{"tail":[2],"feats":[2,-0.6],"rule":5},{"tail":[1],"feats":[2,-0.8],"rule":6},{"tail":[0],"feats":[2,-0.01],"rule":7},{"tail":[2],"feats":[2,-0.8],"rule":8}],"node":{"in_edges":[4,5,6,7]}}" diff --git a/decoder/test_data/hg_test.tiny b/decoder/test_data/hg_test.tiny new file mode 100644 index 00000000..101b96e9 --- /dev/null +++ b/decoder/test_data/hg_test.tiny @@ -0,0 +1 @@ +{"rules":[1,"[X] ||| <s> ||| <s>",2,"[X] ||| X [X] ||| X [1]",3,"[X] ||| Z [X] ||| Z [1]"],"features":["f1","f2","Feature_1","Feature_0","Model_0","Model_1","Model_2","Model_3","Model_4","Model_5","Model_6","Model_7","LatticeCost"],"edges":[{"tail":[],"spans":[25080,32767,25080,32767],"feats":[0,-2,1,-99],"rule":1}],"node":{"in_edges":[0],"cat":"X"},"edges":[{"tail":[0],"spans":[25080,32767,25080,32767],"feats":[0,-0.5,1,-0.8],"rule":2},{"tail":[0],"spans":[25080,32767,25080,32767],"feats":[0,-0.7,1,-0.9],"rule":3}],"node":{"in_edges":[1,2],"cat":"X"}} diff --git a/decoder/test_data/hg_test.tiny_lattice b/decoder/test_data/hg_test.tiny_lattice new file mode 100644 index 00000000..b9adf3cd --- /dev/null +++ b/decoder/test_data/hg_test.tiny_lattice @@ -0,0 +1 @@ +{"rules":[1,"[X] ||| [X] a ||| [1] a",2,"[X] ||| [X] A ||| [1] A",3,"[X] ||| [X] b ||| [1] b",4,"[X] ||| [X] B' ||| [1] B'"],"features":["f1","f2","Feature_1","Feature_0","Model_0","Model_1","Model_2","Model_3","Model_4","Model_5","Model_6","Model_7"],"edges":[],"node":{"in_edges":[]},"edges":[{"tail":[0],"feats":[0,-0.2],"rule":1},{"tail":[0],"feats":[0,-0.6],"rule":2}],"node":{"in_edges":[0,1]},"edges":[{"tail":[1],"feats":[0,-0.1],"rule":3},{"tail":[1],"feats":[0,-0.9],"rule":4}],"node":{"in_edges":[2,3]}} diff --git a/decoder/test_data/small.json.gz b/decoder/test_data/small.json.gz Binary files differindex 892ba360..f6f37293 100644 --- a/decoder/test_data/small.json.gz +++ b/decoder/test_data/small.json.gz diff --git a/decoder/translator.h b/decoder/translator.h index 72b2f0b0..ba218a0b 100644 --- a/decoder/translator.h +++ b/decoder/translator.h @@ -101,7 +101,8 @@ class RescoreTranslator : public Translator { class Tree2StringTranslatorImpl; class Tree2StringTranslator : public Translator { public: - Tree2StringTranslator(const boost::program_options::variables_map& conf); + Tree2StringTranslator(const boost::program_options::variables_map& conf, + bool has_multiple_states); virtual std::string GetDecoderType() const; protected: bool TranslateImpl(const std::string& src, diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index ac9c0d74..b5b47d5d 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -1,7 +1,11 @@ #include <algorithm> #include <vector> +#include <queue> +#include <map> +#include <unordered_set> +#include <boost/shared_ptr.hpp> #include <boost/functional/hash.hpp> -#include <unordered_map> +#include "fast_lexical_cast.hpp" #include "tree_fragment.h" #include "translator.h" #include "hg.h" @@ -13,60 +17,394 @@ using namespace std; -// root: S -// A implication: (S [A] *INCOMPLETE* -// B implication: (S [A] [B] *INCOMPLETE* -// *0* implication: (S _[A] [B]) -// a implication: (S (A a *INCOMPLETE* [B]) -// a implication: (S (A a a *INCOMPLETE* [B]) -// *0* implication: (S (A a a) _[B]) -// D implication: (S (A a a) (B [D] *INCOMPLETE*) -// *0* implication: (S (A a a) (B _[D])) -// d implication: (S (A a a) (B (D d *INCOMPLETE*)) -// *0* implication: (S (A a a) (B (D d))) -// --there are no further outgoing links possible-- - -// root: S -// A implication: (S [A] *INCOMPLETE* -// B implication: (S [A] [B] *INCOMPLETE* -// *0* implication: (S _[A] [B]) -// *0* implication: (S [A] _[B]) -// b implication: (S [A] (B b *INCOMPLETE*)) struct Tree2StringGrammarNode { map<unsigned, Tree2StringGrammarNode> next; - string rules; + vector<TRulePtr> rules; }; -void ReadTree2StringGrammar(istream* in, unordered_map<unsigned, Tree2StringGrammarNode>* proots) { - unordered_map<unsigned, Tree2StringGrammarNode>& roots = *proots; +// this needs to be rewritten so it is fast and checks errors well +// use a lexer probably +static void ReadTree2StringGrammar(istream* in, Tree2StringGrammarNode* root, bool has_multiple_states) { string line; while(getline(*in, line)) { size_t pos = line.find("|||"); assert(pos != string::npos); assert(pos > 3); - if (line[pos - 1] == ' ') --pos; + unsigned xc = 0; + while (line[pos - 1] == ' ') { --pos; xc++; } cdec::TreeFragment rule_src(line.substr(0, pos), true); + // TODO transducer_state should (optionally?) be read from input + const unsigned transducer_state = 0; + Tree2StringGrammarNode* cur = &root->next[transducer_state]; + ostringstream os; + int lhs = -(rule_src.root & cdec::ALL_MASK); + // build source RHS for SCFG projection + vector<int> frhs; + // we traverse the rule_src in left to right, DFS order + for (auto sym : rule_src) { + //cerr << TD::Convert(sym & cdec::ALL_MASK) << endl; + cur = &cur->next[sym]; + if (cdec::IsFrontier(sym)) { // frontier symbols -> variables + int nt = (sym & cdec::ALL_MASK); + frhs.push_back(-nt); + } else if (cdec::IsTerminal(sym)) { + frhs.push_back(sym); + } // else internal NT, nothing to do + } + os << '[' << TD::Convert(-lhs) << "] |||"; + for (auto x : frhs) { + os << ' '; + if (x < 0) + os << '[' << TD::Convert(-x) << ']'; + else + os << TD::Convert(x); + } + pos += 3 + xc; + while(line[pos] == ' ') { ++pos; } + os << " ||| " << line.substr(pos); + TRulePtr rule(new TRule(os.str())); + // TODO the transducer_state you end up in after using this rule (for each NT) + // needs to be read and encoded somehow in the rule (for use XXX) + cur->rules.push_back(rule); + //cerr << "RULE: " << rule->AsString() << "\n\n"; } } +// represents where in an input parse tree the transducer must continue +// and what state it is in +struct TransducerState { + TransducerState() : input_node_idx(), transducer_state() {} + TransducerState(unsigned n, unsigned q) : input_node_idx(n), transducer_state(q) {} + bool operator==(const TransducerState& o) const { + return input_node_idx == o.input_node_idx && + transducer_state == o.transducer_state; + } + unsigned input_node_idx; + unsigned transducer_state; +}; + +// represents the state of the composition algorithm +struct ParserState { + ParserState() : in_iter(), node() {} + cdec::TreeFragment::iterator in_iter; + ParserState(const cdec::TreeFragment::iterator& it, unsigned q, Tree2StringGrammarNode* n) : + in_iter(it), + task(it.node_idx(), q), + node(n) {} + ParserState(const cdec::TreeFragment::iterator& it, Tree2StringGrammarNode* n, const ParserState& p) : + in_iter(it), + future_work(p.future_work), + task(p.task), + node(n) {} + bool operator==(const ParserState& o) const { + return node == o.node && task == o.task && + future_work == o.future_work && in_iter == o.in_iter; + } + vector<TransducerState> future_work; + TransducerState task; // subtree root where and in what state did the transducer start? + Tree2StringGrammarNode* node; // pointer into grammar trie +}; + +namespace std { + template<> + struct hash<TransducerState> { + size_t operator()(const TransducerState& q) const { + size_t h = boost::hash_value(q.transducer_state); + boost::hash_combine(h, boost::hash_value(q.input_node_idx)); + return h; + } + }; + template<> + struct hash<ParserState> { + size_t operator()(const ParserState& s) const { + size_t h = boost::hash_value(s.node); + for (auto& w : s.future_work) + boost::hash_combine(h, hash<TransducerState>()(w)); + boost::hash_combine(h, hash<TransducerState>()(s.task)); + // TODO hash with iterator + return h; + } + }; +}; + +void AddDummyGoalNode(Hypergraph* hg) { + static const int kGOAL = -TD::Convert("Goal"); + unsigned old_goal_node_idx = hg->nodes_.size() - 1; + int old_goal_cat = hg->nodes_[old_goal_node_idx].cat_; + TRulePtr goal_rule(new TRule("[Goal] ||| [X] ||| [1]")); + goal_rule->f_[0] = old_goal_cat; + HG::Node* goal_node = hg->AddNode(kGOAL); + goal_node->node_hash = 1; + TailNodeVector tail(1, old_goal_node_idx); + HG::Edge* new_edge = hg->AddEdge(goal_rule, tail); + hg->ConnectEdgeToHeadNode(new_edge, goal_node); +} + struct Tree2StringTranslatorImpl { - unordered_map<unsigned, Tree2StringGrammarNode> roots; // root['S'] gives rule network for S rules - Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf) { - ReadFile rf(conf["grammar"].as<vector<string>>()[0]); - ReadTree2StringGrammar(rf.stream(), &roots); + vector<boost::shared_ptr<Tree2StringGrammarNode>> root; + bool add_pass_through_rules; + bool has_multiple_states; + unsigned remove_grammars; + Tree2StringTranslatorImpl(const boost::program_options::variables_map& conf, + bool has_multiple_states) : + add_pass_through_rules(conf.count("add_pass_through_rules")), + has_multiple_states(has_multiple_states) { + if (conf.count("grammar")) { + const vector<string> gf = conf["grammar"].as<vector<string>>(); + root.resize(gf.size()); + unsigned gc = 0; + for (auto& f : gf) { + ReadFile rf(f); + root[gc].reset(new Tree2StringGrammarNode); + ReadTree2StringGrammar(rf.stream(), &*root[gc++], has_multiple_states); + } + } + } + + // loads a per-sentence grammar + void LoadSupplementalGrammar(const string& gfile) { + root.resize(root.size() + 1); + root.back().reset(new Tree2StringGrammarNode); + ++remove_grammars; + ReadFile rf(gfile); + ReadTree2StringGrammar(rf.stream(), root.back().get(), has_multiple_states); + } + + void CreatePassThroughRules(const cdec::TreeFragment& tree) { + static const int kFIDlex = FD::Convert("PassThrough_Lexical"); + static const int kFIDabs = FD::Convert("PassThrough_Abstract"); + static const int kFIDmix = FD::Convert("PassThrough_Mix"); + static const int kFID = FD::Convert("PassThrough"); + static unordered_map<int, int> pntfid; + root.resize(root.size() + 1); + root.back().reset(new Tree2StringGrammarNode); + ++remove_grammars; + unordered_set<vector<int>,boost::hash<vector<int>>> unique_rule_check; + for (auto& prod : tree.nodes) { + int ntc = 0; + int lhs = -(prod.lhs & cdec::ALL_MASK); + int &ntfid = pntfid[lhs]; + if (!ntfid) { + ostringstream fos; + fos << "PassThrough:" << TD::Convert(-lhs); + ntfid = FD::Convert(fos.str()); + } + + // check for duplicate rule in tree + vector<int> key; + key.push_back(prod.lhs); + + bool has_lex = false; + bool has_nt = false; + vector<int> rhse, rhsf; + ostringstream os; + os << '(' << TD::Convert(-lhs); + for (auto& sym : prod.rhs) { + os << ' '; + if (cdec::IsTerminal(sym)) { + has_lex = true; + os << TD::Convert(sym); + rhse.push_back(sym); + rhsf.push_back(sym); + key.push_back(sym); + } else { + has_nt = true; + unsigned id = tree.nodes[sym & cdec::ALL_MASK].lhs & cdec::ALL_MASK; + os << '[' << TD::Convert(id) << ']'; + rhsf.push_back(-id); + rhse.push_back(-ntc); + key.push_back(-id); + ++ntc; + } + } + os << ')'; + if (!unique_rule_check.insert(key).second) continue; + cdec::TreeFragment rule_src(os.str(), true); + Tree2StringGrammarNode* cur = root.back().get(); + // do we need all transducer states here??? a list??? no pass through rules??? + unsigned transducer_state = 0; + cur = &cur->next[transducer_state]; + for (auto sym : rule_src) + cur = &cur->next[sym]; + TRulePtr rule(new TRule(rhse, rhsf, lhs)); + rule->ComputeArity(); + rule->scores_.set_value(ntfid, 1.0); + rule->scores_.set_value(kFID, 1.0); + if (has_lex && has_nt) + rule->scores_.set_value(kFIDmix, 1.0); + else if (has_lex) rule->scores_.set_value(kFIDlex, 1.0); + else if (has_nt) rule->scores_.set_value(kFIDabs, 1.0); + cur->rules.push_back(rule); + } } + + void RemoveGrammars() { + assert(remove_grammars <= root.size()); + root.resize(root.size() - remove_grammars); + } + bool Translate(const string& input, SentenceMetadata* smeta, const vector<double>& weights, Hypergraph* minus_lm_forest) { cdec::TreeFragment input_tree(input, false); - cerr << "Tree2StringTranslatorImpl: please implement this!\n"; - return false; + if (add_pass_through_rules) CreatePassThroughRules(input_tree); + Hypergraph hg; + hg.ReserveNodes(input_tree.nodes.size()); + unordered_map<TransducerState, unsigned> x2hg(input_tree.nodes.size() * 5); + queue<ParserState> q; + unordered_set<ParserState> unique; // only create items one time + for (auto& g : root) { + unsigned q_0 = 0; // TODO initialize q_0 properly once multi-state transducers are supported + auto rit = g->next.find(q_0); + if (rit != g->next.end()) { // does this g have this transducer state? + q.push(ParserState(input_tree.begin(), q_0, &rit->second)); + unique.insert(q.back()); + } + } + if (q.size() == 0) return false; + const TransducerState tree_top = q.front().task; + while(!q.empty()) { + ParserState& s = q.front(); + + if (s.in_iter.at_end()) { // completed a traversal of a subtree + //cerr << "I traversed a subtree of the input rooted at node=" << s.input_node_idx << " sym=" << + // TD::Convert(input_tree.nodes[s.input_node_idx].lhs & cdec::ALL_MASK) << endl; + if (s.node->rules.size()) { + auto it = x2hg.find(s.task); + if (it == x2hg.end()) { + // TODO create composite state symbol that encodes transducer state type? + HG::Node* new_node = hg.AddNode(-(input_tree.nodes[s.task.input_node_idx].lhs & cdec::ALL_MASK)); + new_node->node_hash = std::hash<TransducerState>()(s.task); + it = x2hg.insert(make_pair(s.task, new_node->id_)).first; + } + const unsigned node_id = it->second; + TailNodeVector tail; + for (const auto& n : s.future_work) { + auto it = x2hg.find(n); + if (it == x2hg.end()) { + // TODO create composite state symbol that encodes transducer state type? + HG::Node* new_node = hg.AddNode(-(input_tree.nodes[n.input_node_idx].lhs & cdec::ALL_MASK)); + new_node->node_hash = std::hash<TransducerState>()(n); + it = x2hg.insert(make_pair(n, new_node->id_)).first; + } + tail.push_back(it->second); + } + for (auto& r : s.node->rules) { + assert(tail.size() == r->Arity()); + HG::Edge* new_edge = hg.AddEdge(r, tail); + new_edge->feature_values_ = r->GetFeatureValues(); + // TODO: set i and j + hg.ConnectEdgeToHeadNode(new_edge, &hg.nodes_[node_id]); + } + for (const auto& n : s.future_work) { + const auto it = input_tree.begin(n.input_node_idx); // start tree iterator at node n + for (auto& g : root) { + auto rit = g->next.find(n.transducer_state); + if (rit != g->next.end()) { // does this g have this transducer state? + const ParserState s(it, n.transducer_state, &rit->second); + if (unique.insert(s).second) q.push(s); + } + } + } + } else { + //cerr << "I can't build anything :(\n"; + } + } else { // more input tree to match + unsigned sym = *s.in_iter; + if (cdec::IsLHS(sym)) { + auto nit = s.node->next.find(sym); + if (nit != s.node->next.end()) { + //cerr << "MATCHED LHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + ParserState news(++s.in_iter, &nit->second, s); + if (unique.insert(news).second) q.push(news); + } + } else if (cdec::IsRHS(sym)) { + //cerr << "Attempting to match RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + cdec::TreeFragment::iterator var = s.in_iter; + var.truncate(); + auto nit1 = s.node->next.find(sym); + auto nit2 = s.node->next.find(*var); + if (nit2 != s.node->next.end()) { + //cerr << "MATCHED VAR RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + ++var; + // TODO: find out from rule what the new target state is (the 0 in the next line) + // if it is associated with the rule, we won't know until we match the whole input + // so the 0 may be okay (if this is the case, which is probably the easiest thing, + // then the state must be dealt with when the future work becomes real work) + const TransducerState new_task(s.in_iter.child_node(), 0); + ParserState new_s(var, &nit2->second, s); + new_s.future_work.push_back(new_task); // if this traversal of the input succeeds, future_work goes on the q + if (unique.insert(new_s).second) q.push(new_s); + } + //else { cerr << "did not match [" << TD::Convert(sym & cdec::ALL_MASK) << "]\n"; } + if (nit1 != s.node->next.end()) { + //cerr << "MATCHED FULL RHS: " << TD::Convert(sym & cdec::ALL_MASK) << endl; + const ParserState new_s(++s.in_iter, &nit1->second, s); + if (unique.insert(new_s).second) q.push(new_s); + } + //else { cerr << "did not match " << TD::Convert(sym & cdec::ALL_MASK) << "\n"; } + } else if (cdec::IsTerminal(sym)) { + auto nit = s.node->next.find(sym); + if (nit != s.node->next.end()) { + //cerr << "MATCHED TERMINAL: " << TD::Convert(sym) << endl; + const ParserState new_s(++s.in_iter, &nit->second, s); + if (unique.insert(new_s).second) q.push(new_s); + } + } else { + cerr << "This can never happen!\n"; abort(); + } + } + q.pop(); + } + const auto goal_it = x2hg.find(tree_top); + if (goal_it == x2hg.end()) return false; + //cerr << "Goal node: " << goal << endl; + hg.TopologicallySortNodesAndEdges(goal_it->second); + + // there might be nodes that cannot be derived + // the following takes care of them + vector<bool> prune(hg.edges_.size(), false); + hg.PruneEdges(prune, true); + if (hg.edges_.size() == 0) return false; + // rescoring assumes the goal edge is arity 1 (code laziness), add that here + AddDummyGoalNode(&hg); + + hg.Reweight(weights); + //hg.PrintGraphviz(); + + minus_lm_forest->swap(hg); + return true; } }; -Tree2StringTranslator::Tree2StringTranslator(const boost::program_options::variables_map& conf) : - pimpl_(new Tree2StringTranslatorImpl(conf)) {} +Tree2StringTranslator::Tree2StringTranslator(const boost::program_options::variables_map& conf, + bool has_multiple_states) : + pimpl_(new Tree2StringTranslatorImpl(conf, has_multiple_states)) {} + +void Tree2StringTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv) { + pimpl_->remove_grammars = 0; + if (kv.find("grammar0") != kv.end()) { + cerr << "SGML tag grammar0 is not expected (order is: grammar, grammar1, grammar2, ...)\n"; + abort(); + } + unsigned gc = 0; + set<string> loaded; + while(true) { + string gkey = "grammar"; + if (gc > 0) gkey += boost::lexical_cast<string>(gc); + ++gc; + map<string,string>::const_iterator it = kv.find(gkey); + if (it == kv.end()) break; + const string& gfile = it->second; + if (loaded.count(gfile) == 1) { + cerr << "Attempting to load " << gfile << " twice!\n"; + abort(); + } + loaded.insert(gfile); + pimpl_->LoadSupplementalGrammar(gfile); + } +} bool Tree2StringTranslator::TranslateImpl(const string& input, SentenceMetadata* smeta, @@ -75,10 +413,8 @@ bool Tree2StringTranslator::TranslateImpl(const string& input, return pimpl_->Translate(input, smeta, weights, minus_lm_forest); } -void Tree2StringTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv) { -} - void Tree2StringTranslator::SentenceCompleteImpl() { + pimpl_->RemoveGrammars(); } std::string Tree2StringTranslator::GetDecoderType() const { diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc index d5c30f58..696c8601 100644 --- a/decoder/tree_fragment.cc +++ b/decoder/tree_fragment.cc @@ -2,6 +2,8 @@ #include <cassert> +#include "tdict.h" + using namespace std; namespace cdec { @@ -36,7 +38,7 @@ void TreeFragment::DebugRec(unsigned cur, ostream* out) const { *out << ' '; if (IsFrontier(x)) { *out << '[' << TD::Convert(x & ALL_MASK) << ']'; - } else if (IsInternalNT(x)) { + } else if (IsRHS(x)) { DebugRec(x & ALL_MASK, out); } else { // must be terminal *out << TD::Convert(x); @@ -66,7 +68,7 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned // recursively call parser to deal with constituent ParseRec(tree, afs, cp, symp, np, &cp, &symp, &np); unsigned ind = np - 1; - rhs.push_back(ind | NT_BIT); + rhs.push_back(ind | RHS_BIT); } else { // deal with terminal / nonterminal substitution ++symp; assert(tree[cp] != ' '); @@ -95,7 +97,7 @@ void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned } // continuent has completed, cp is at ), build node const unsigned j = symp; // span from (i,j) // add an internal non-terminal symbol - const unsigned nt = TD::Convert(tree.substr(nt_start, nt_end - nt_start)) | NT_BIT; + const unsigned nt = TD::Convert(tree.substr(nt_start, nt_end - nt_start)) | RHS_BIT; nodes[np] = TreeFragmentProduction(nt, rhs); //cerr << np << " production(" << i << "," << j << ")= " << TD::Convert(nt & ALL_MASK) << " -->"; //for (auto& x : rhs) { diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h index 83cd1c1e..79722b5a 100644 --- a/decoder/tree_fragment.h +++ b/decoder/tree_fragment.h @@ -1,26 +1,44 @@ #ifndef TREE_FRAGMENT #define TREE_FRAGMENT +#include <deque> #include <iostream> #include <vector> #include <string> - -#include "tdict.h" +#include <cassert> +#include <cstddef> namespace cdec { -static const unsigned NT_BIT = 0x40000000u; -static const unsigned FRONTIER_BIT = 0x80000000u; -static const unsigned ALL_MASK = 0x0FFFFFFFu; +class BreadthFirstIterator; +class DepthFirstIterator; + +static const unsigned LHS_BIT = 0x10000000u; +static const unsigned RHS_BIT = 0x20000000u; +static const unsigned FRONTIER_BIT = 0x40000000u; +static const unsigned RESERVED_BIT = 0x80000000u; +static const unsigned ALL_MASK = 0x0FFFFFFFu; + +inline bool IsNT(unsigned x) { + return (x & (LHS_BIT | RHS_BIT | FRONTIER_BIT)); +} -inline bool IsInternalNT(unsigned x) { - return (x & NT_BIT); +inline bool IsLHS(unsigned x) { + return (x & LHS_BIT); +} + +inline bool IsRHS(unsigned x) { + return (x & RHS_BIT); } inline bool IsFrontier(unsigned x) { return (x & FRONTIER_BIT); } +inline bool IsTerminal(unsigned x) { + return (x & ALL_MASK) == x; +} + struct TreeFragmentProduction { TreeFragmentProduction() {} TreeFragmentProduction(int nttype, const std::vector<unsigned>& r) : lhs(nttype), rhs(r) {} @@ -36,6 +54,21 @@ class TreeFragment { // (S (NP a (X b) c d) (VP (V foo) (NP (NN bar)))) explicit TreeFragment(const std::string& tree, bool allow_frontier_sites = false); void DebugRec(unsigned cur, std::ostream* out) const; + typedef DepthFirstIterator iterator; + typedef ptrdiff_t difference_type; + typedef unsigned value_type; + typedef const unsigned * pointer; + typedef const unsigned & reference; + + // default iterator is DFS + iterator begin() const; + iterator begin(unsigned node_idx) const; + iterator end() const; + + BreadthFirstIterator bfs_begin() const; + BreadthFirstIterator bfs_begin(unsigned node_idx) const; + BreadthFirstIterator bfs_end() const; + private: // cp is the character index in the tree // np keeps track of the nodes (nonterminals) that have been built @@ -49,6 +82,202 @@ class TreeFragment { std::vector<TreeFragmentProduction> nodes; }; +struct TFIState { + TFIState() : node(), rhspos(), state() {} + TFIState(unsigned n, int p, unsigned s) : node(n), rhspos(p), state(s) {} + bool operator==(const TFIState& o) const { return node == o.node && rhspos == o.rhspos && state == o.state; } + bool operator!=(const TFIState& o) const { return node != o.node || rhspos != o.rhspos || state != o.state; } + unsigned short node; + short rhspos; + unsigned char state; +}; + +class DepthFirstIterator : public std::iterator<std::forward_iterator_tag, unsigned> { + const TreeFragment* tf_; + std::deque<TFIState> q_; + unsigned sym; + public: + DepthFirstIterator() : tf_(), sym() {} + // used for begin + explicit DepthFirstIterator(const TreeFragment* tf, unsigned node_idx) : tf_(tf) { + q_.push_back(TFIState(node_idx, -1, 0)); + Stage(); + q_.back().state++; + } + // used for end + explicit DepthFirstIterator(const TreeFragment* tf) : tf_(tf) {} + const unsigned& operator*() const { return sym; } + const unsigned* operator->() const { return &sym; } + bool operator==(const DepthFirstIterator& other) const { + return (tf_ == other.tf_) && (q_ == other.q_); + } + bool operator!=(const DepthFirstIterator& other) const { + return (tf_ != other.tf_) || (q_ != other.q_); + } + unsigned node_idx() const { return q_.front().node; } + const DepthFirstIterator& operator++() { + TFIState& s = q_.back(); + if (s.state == 0) { + Stage(); + s.state++; + } else if (s.state == 1) { + const unsigned len = tf_->nodes[s.node].rhs.size(); + s.rhspos++; + if (s.rhspos >= len) { + q_.pop_back(); + while (!q_.empty()) { + TFIState& s = q_.back(); + const unsigned len = tf_->nodes[s.node].rhs.size(); + s.rhspos++; + if (s.rhspos < len) break; + q_.pop_back(); + } + } + Stage(); + } + return *this; + } + DepthFirstIterator operator++(int) { + DepthFirstIterator res = *this; + ++(*this); + return res; + } + // tell iterator not to explore the subtree rooted at sym + // should only be called once per NT symbol encountered + const DepthFirstIterator& truncate() { + assert(IsRHS(sym)); + sym &= ALL_MASK; + sym |= FRONTIER_BIT; + q_.pop_back(); + return *this; + } + unsigned child_node() const { + assert(IsRHS(sym)); + return q_.back().node; + } + DepthFirstIterator remainder() const { + assert(IsRHS(sym)); + return DepthFirstIterator(tf_, q_.back()); + } + bool at_end() const { + return q_.empty(); + } + private: + void Stage() { + if (q_.empty()) return; + const TFIState& s = q_.back(); + if (s.state == 0) { + sym = (tf_->nodes[s.node].lhs & ALL_MASK) | LHS_BIT; + } else if (s.state == 1) { + sym = tf_->nodes[s.node].rhs[s.rhspos]; + if (IsRHS(sym)) { + q_.push_back(TFIState(sym & ALL_MASK, -1, 0)); + sym = tf_->nodes[sym & ALL_MASK].lhs | RHS_BIT; + } + } + } + + // used by remainder + DepthFirstIterator(const TreeFragment* tf, const TFIState& s) : tf_(tf) { + q_.push_back(s); + Stage(); + } +}; + +class BreadthFirstIterator : public std::iterator<std::forward_iterator_tag, unsigned> { + const TreeFragment* tf_; + std::deque<TFIState> q_; + unsigned sym; + public: + BreadthFirstIterator() : tf_(), sym() {} + // used for begin + explicit BreadthFirstIterator(const TreeFragment* tf, unsigned node_idx) : tf_(tf) { + q_.push_back(TFIState(node_idx, 0, 0)); + Stage(); + } + // used for end + explicit BreadthFirstIterator(const TreeFragment* tf) : tf_(tf) {} + const unsigned& operator*() const { return sym; } + const unsigned* operator->() const { return &sym; } + bool operator==(const BreadthFirstIterator& other) const { + return (tf_ == other.tf_) && (q_ == other.q_); + } + bool operator!=(const BreadthFirstIterator& other) const { + return (tf_ != other.tf_) || (q_ != other.q_); + } + unsigned node_idx() const { return q_.front().node; } + const BreadthFirstIterator& operator++() { + TFIState& s = q_.front(); + if (s.state == 0) { + s.state++; + Stage(); + } else { + const unsigned len = tf_->nodes[s.node].rhs.size(); + s.rhspos++; + if (s.rhspos >= len) { + q_.pop_front(); + Stage(); + } else { + Stage(); + } + } + return *this; + } + BreadthFirstIterator operator++(int) { + BreadthFirstIterator res = *this; + ++(*this); + return res; + } + // tell iterator not to explore the subtree rooted at sym + // should only be called once per NT symbol encountered + const BreadthFirstIterator& truncate() { + assert(IsRHS(sym)); + sym &= ALL_MASK; + sym |= FRONTIER_BIT; + q_.pop_back(); + return *this; + } + unsigned child_node() const { + assert(IsRHS(sym)); + return q_.back().node; + } + BreadthFirstIterator remainder() const { + assert(IsRHS(sym)); + return BreadthFirstIterator(tf_, q_.back()); + } + bool at_end() const { + return q_.empty(); + } + private: + void Stage() { + if (q_.empty()) return; + const TFIState& s = q_.front(); + if (s.state == 0) { + sym = (tf_->nodes[s.node].lhs & ALL_MASK) | LHS_BIT; + } else { + sym = tf_->nodes[s.node].rhs[s.rhspos]; + if (IsRHS(sym)) { + q_.push_back(TFIState(sym & ALL_MASK, 0, 0)); + sym = tf_->nodes[sym & ALL_MASK].lhs | RHS_BIT; + } + } + } + + // used by remainder + BreadthFirstIterator(const TreeFragment* tf, const TFIState& s) : tf_(tf) { + q_.push_back(s); + Stage(); + } +}; + +inline TreeFragment::iterator TreeFragment::begin() const { return iterator(this, nodes.size() - 1); } +inline TreeFragment::iterator TreeFragment::begin(unsigned node_idx) const { return iterator(this, node_idx); } +inline TreeFragment::iterator TreeFragment::end() const { return iterator(this); } + +inline BreadthFirstIterator TreeFragment::bfs_begin() const { return BreadthFirstIterator(this, nodes.size() - 1); } +inline BreadthFirstIterator TreeFragment::bfs_begin(unsigned node_idx) const { return BreadthFirstIterator(this, node_idx); } +inline BreadthFirstIterator TreeFragment::bfs_end() const { return BreadthFirstIterator(this); } + inline std::ostream& operator<<(std::ostream& os, const TreeFragment& x) { x.DebugRec(x.nodes.size() - 1, &os); return os; diff --git a/decoder/trule.cc b/decoder/trule.cc index c22baae3..bee211d5 100644 --- a/decoder/trule.cc +++ b/decoder/trule.cc @@ -17,73 +17,16 @@ bool TRule::IsGoal() const { return GetLHS() == kGOAL; } -static WordID ConvertTrgString(const string& w) { - const unsigned len = w.size(); - WordID id = 0; - // [X,0] or [0] - // for target rules, we ignore the category, just keep the index - if (len > 2 && w[0]=='[' && w[len-1]==']' && w[len-2] > '0' && w[len-2] <= '9' && - (len == 3 || (len > 4 && w[len-3] == ','))) { - id = w[len-2] - '0'; - id = 1 - id; - } else { - id = TD::Convert(w); - } - return id; -} - -static WordID ConvertSrcString(const string& w, bool mono = false) { - const unsigned len = w.size(); - // [X,0] - // for source rules, we keep the category and ignore the index (source rules are - // always numbered 1, 2, 3... - if (mono) { - if (len > 2 && w[0]=='[' && w[len-1]==']') { - if (len > 4 && w[len-3] == ',') { - cerr << "[ERROR] Monolingual rules mut not have non-terminal indices:\n " - << w << endl; - exit(1); - } - // TODO check that source indices go 1,2,3,etc. - return TD::Convert(w.substr(1, len-2)) * -1; - } else { - return TD::Convert(w); - } - } else { - if (len > 4 && w[0]=='[' && w[len-1]==']' && w[len-3] == ',' && w[len-2] > '0' && w[len-2] <= '9') { - return TD::Convert(w.substr(1, len-4)) * -1; - } else { - return TD::Convert(w); - } - } -} - -static WordID ConvertLHS(const string& w) { - if (w[0] == '[') { - const unsigned len = w.size(); - if (len < 3) { cerr << "Format error: " << w << endl; exit(1); } - return TD::Convert(w.substr(1, len-2)) * -1; - } else { - return TD::Convert(w) * -1; - } -} - TRule* TRule::CreateRuleSynchronous(const string& rule) { TRule* res = new TRule; - if (res->ReadFromString(rule, true, false)) return res; + if (res->ReadFromString(rule)) return res; cerr << "[ERROR] Failed to creating rule from: " << rule << endl; delete res; return NULL; } TRule* TRule::CreateRulePhrasetable(const string& rule) { - // TODO make this faster - // TODO add configuration for default NT type - if (rule[0] == '[') { - cerr << "Phrasetable rules shouldn't have a LHS / non-terminals:\n " << rule << endl; - return NULL; - } - TRule* res = new TRule("[X] ||| " + rule, true, false); + TRule* res = new TRule("[X] ||| " + rule); if (res->Arity() != 0) { cerr << "Phrasetable rules should have arity 0:\n " << rule << endl; delete res; @@ -93,138 +36,33 @@ TRule* TRule::CreateRulePhrasetable(const string& rule) { } TRule* TRule::CreateRuleMonolingual(const string& rule) { - return new TRule(rule, false, true); + return new TRule(rule, true); } namespace { -// callback for lexer +// callback for single rule lexer int n_assigned=0; -void assign_trule(const TRulePtr& new_rule, const unsigned int ctf_level, const TRulePtr& coarse_rule, void* extra) { - (void) ctf_level; - (void) coarse_rule; - TRule *assignto=(TRule *)extra; - *assignto=*new_rule; - ++n_assigned; -} - -} - -bool TRule::ReadFromString(const string& line, bool strict, bool mono) { - if (!is_single_line_stripped(line)) - cerr<<"\nWARNING: building rule from multi-line string "<<line<<".\n"; - // backed off of this: it's failing to parse TRulePtr glue(new TRule("[" + goal_nt + "] ||| [" + goal_nt + ",1] ["+ default_nt + ",2] ||| [1] [2] ||| Glue=1")); thinks [1] is the features! - if (false && !(mono||strict)) { - // use lexer - istringstream il(line); - n_assigned=0; - RuleLexer::ReadRules(&il,assign_trule,"STRING",this); - if (n_assigned>1) - cerr<<"\nWARNING: more than one rule parsed from multi-line string; kept last: "<<line<<".\n"; - return n_assigned; + void assign_trule(const TRulePtr& new_rule, const unsigned int ctf_level, const TRulePtr& coarse_rule, void* extra) { + (void) ctf_level; + (void) coarse_rule; + *static_cast<TRule*>(extra) = *new_rule; + ++n_assigned; } +} - e_.clear(); - f_.clear(); - scores_.clear(); - - string w; - istringstream is(line); - int format = CountSubstrings(line, "|||"); - if (strict && format < 2) { - cerr << "Bad rule format in strict mode:\n" << line << endl; - return false; - } - if (format >= 2 || (mono && format == 1)) { - while(is>>w && w!="|||") { lhs_ = ConvertLHS(w); } - while(is>>w && w!="|||") { f_.push_back(ConvertSrcString(w, mono)); } - if (!mono) { - while(is>>w && w!="|||") { e_.push_back(ConvertTrgString(w)); } - } - int fv = 0; - if (is) { - string ss; - getline(is, ss); - //cerr << "L: " << ss << endl; - unsigned start = 0; - unsigned len = ss.size(); - const size_t ppos = ss.find(" |||"); - if (ppos != string::npos) { len = ppos; } - while (start < len) { - while(start < len && (ss[start] == ' ' || ss[start] == ';')) - ++start; - if (start == len) break; - unsigned end = start + 1; - while(end < len && (ss[end] != '=' && ss[end] != ' ' && ss[end] != ';')) - ++end; - if (end == len || ss[end] == ' ' || ss[end] == ';') { - //cerr << "PROC: '" << ss.substr(start, end - start) << "'\n"; - // non-named features - if (end != len) { ss[end] = 0; } - string fname = "PhraseModel_X"; - if (fv > 9) { cerr << "Too many phrasetable scores - used named format\n"; abort(); } - fname[12]='0' + fv; - ++fv; - // if the feature set is frozen, this may return zero, indicating an - // undefined feature - const int fid = FD::Convert(fname); - if (fid) - scores_.set_value(fid, atof(&ss[start])); - //cerr << "F: " << fname << " VAL=" << scores_.value(FD::Convert(fname)) << endl; - } else { - const int fid = FD::Convert(ss.substr(start, end - start)); - start = end + 1; - end = start + 1; - while(end < len && (ss[end] != ' ' && ss[end] != ';')) - ++end; - if (end < len) { ss[end] = 0; } - assert(start < len); - if (fid) - scores_.set_value(fid, atof(&ss[start])); - //cerr << "F: " << FD::Convert(fid) << " VAL=" << scores_.value(fid) << endl; - } - start = end + 1; - } - } - } else if (format == 1) { - while(is>>w && w!="|||") { lhs_ = ConvertLHS(w); } - while(is>>w && w!="|||") { e_.push_back(ConvertTrgString(w)); } - f_ = e_; - int x = ConvertLHS("[X]"); - for (unsigned i = 0; i < f_.size(); ++i) - if (f_[i] <= 0) { f_[i] = x; } - } else { - cerr << "F: " << format << endl; - cerr << "[ERROR] Don't know how to read:\n" << line << endl; - } +bool TRule::ReadFromString(const string& line, bool mono) { + n_assigned = 0; + //cerr << "LINE: " << line << " -- mono=" << mono << endl; + RuleLexer::ReadRule(line + '\n', assign_trule, mono, this); + if (n_assigned > 1) + cerr<<"\nWARNING: more than one rule parsed from multi-line string; kept last: "<<line<<".\n"; if (mono) { e_ = f_; - int ci = 0; - for (unsigned i = 0; i < e_.size(); ++i) - if (e_[i] < 0) - e_[i] = ci--; - } - ComputeArity(); - return SanityCheck(); -} - -bool TRule::SanityCheck() const { - vector<int> used(f_.size(), 0); - int ac = 0; - for (unsigned i = 0; i < e_.size(); ++i) { - int ind = e_[i]; - if (ind > 0) continue; - ind = -ind; - if ((++used[ind]) != 1) { - cerr << "[ERROR] e-side variable index " << (ind+1) << " used more than once!\n"; - return false; - } - ac++; - } - if (ac != Arity()) { - cerr << "[ERROR] e-side arity mismatches f-side\n"; - return false; + int ntc = 0; + for (auto& i : e_) + if (i < 0) i = -ntc++; } - return true; + return n_assigned; } void TRule::ComputeArity() { @@ -245,7 +83,7 @@ string TRule::AsString(bool verbose) const { if (w < 0) { int wi = w * -1; ++idx; - os << " [" << TD::Convert(wi) << ',' << idx << ']'; + os << " [" << TD::Convert(wi) << ']'; } else { os << ' ' << TD::Convert(w); } diff --git a/decoder/trule.h b/decoder/trule.h index e9a10bea..cc370757 100644 --- a/decoder/trule.h +++ b/decoder/trule.h @@ -42,6 +42,9 @@ class TRule { scores_.set_value(feat_ids[i], feat_vals[i]); } + TRule(WordID lhs, const WordID* src, int src_size, const WordID* trg, int trg_size, int arity, int pi, int pj) : + e_(trg, trg + trg_size), f_(src, src + src_size), lhs_(lhs), arity_(arity), prev_i(pi), prev_j(pj) {} + bool IsGoal() const; explicit TRule(const std::vector<WordID>& e) : e_(e), lhs_(0), prev_i(-1), prev_j(-1) {} @@ -51,23 +54,18 @@ class TRule { TRule(const TRule& other) : e_(other.e_), f_(other.f_), lhs_(other.lhs_), scores_(other.scores_), arity_(other.arity_), prev_i(-1), prev_j(-1), a_(other.a_) {} - // if mono or strict is true, then lexer won't be used, and //FIXME: > 9 variables won't work - explicit TRule(const std::string& text, bool strict = false, bool mono = false) : prev_i(-1), prev_j(-1) { - ReadFromString(text, strict, mono); + explicit TRule(const std::string& text, bool mono = false) : prev_i(-1), prev_j(-1) { + ReadFromString(text, mono); } - // deprecated, use lexer // make a rule from a hiero-like rule table, e.g. // [X] ||| [X,1] DE [X,2] ||| [X,2] of the [X,1] - // if misformatted, returns NULL static TRule* CreateRuleSynchronous(const std::string& rule); - // deprecated, use lexer // make a rule from a phrasetable entry (i.e., one that has no LHS type), e.g: // el gato ||| the cat ||| Feature_2=0.34 static TRule* CreateRulePhrasetable(const std::string& rule); - // deprecated, use lexer // make a rule from a non-synchrnous CFG representation, e.g.: // [LHS] ||| term1 [NT] term2 [OTHER_NT] [YET_ANOTHER_NT] static TRule* CreateRuleMonolingual(const std::string& rule); @@ -80,11 +78,10 @@ class TRule { std::vector<WordID>* result) const { unsigned vc = 0; result->clear(); - for (std::vector<WordID>::const_iterator i = e_.begin(); i != e_.end(); ++i) { - const WordID& c = *i; + for (const auto& c : e_) { if (c < 1) { ++vc; - const std::vector<WordID>& var_value = *var_values[-c]; + const auto& var_value = *var_values[-c]; std::copy(var_value.begin(), var_value.end(), std::back_inserter(*result)); @@ -99,10 +96,9 @@ class TRule { std::vector<WordID>* result) const { unsigned vc = 0; result->clear(); - for (std::vector<WordID>::const_iterator i = f_.begin(); i != f_.end(); ++i) { - const WordID& c = *i; + for (const auto& c : f_) { if (c < 1) { - const std::vector<WordID>& var_value = *var_values[vc++]; + const auto& var_value = *var_values[vc++]; std::copy(var_value.begin(), var_value.end(), std::back_inserter(*result)); @@ -113,7 +109,7 @@ class TRule { assert(vc == var_values.size()); } - bool ReadFromString(const std::string& line, bool strict = false, bool monolingual = false); + bool ReadFromString(const std::string& line, bool monolingual = false); bool Initialized() const { return e_.size(); } @@ -166,7 +162,6 @@ class TRule { private: TRule(const WordID& src, const WordID& trg) : e_(1, trg), f_(1, src), lhs_(), arity_(), prev_i(), prev_j() {} - bool SanityCheck() const; }; inline size_t hash_value(const TRule& r) { |