From b9e6e7e24cc48021090b689e143288e2b7f2b5fc Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 7 Apr 2014 22:56:34 -0400 Subject: track node state for smarter union --- decoder/apply_models.cc | 306 ++++++++++++++++++++++++------------------------ 1 file changed, 155 insertions(+), 151 deletions(-) (limited to 'decoder/apply_models.cc') diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc index 9a8f60be..9f8bbead 100644 --- a/decoder/apply_models.cc +++ b/decoder/apply_models.cc @@ -19,6 +19,7 @@ namespace std { using std::tr1::unordered_map; using std::tr1::unordered_set; } #include +#include "node_state_hash.h" #include "verbose.h" #include "hg.h" #include "ff.h" @@ -229,7 +230,7 @@ public: D.clear(); } - void IncorporateIntoPlusLMForest(Candidate* item, State2Node* s2n, CandidateList* freelist) { + void IncorporateIntoPlusLMForest(size_t head_node_hash, Candidate* item, State2Node* s2n, CandidateList* freelist) { Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_); new_edge->edge_prob_ = item->out_edge_.edge_prob_; Candidate*& o_item = (*s2n)[item->state_]; @@ -238,6 +239,7 @@ public: int& node_id = o_item->node_index_; if (node_id < 0) { Hypergraph::Node* new_node = out.AddNode(in.nodes_[item->in_edge_->head_node_].cat_); + new_node->node_hash = cdec::HashNode(head_node_hash, item->state_); // ID is combination of existing state + residual state node_states_.push_back(item->state_); node_id = new_node->id_; } @@ -287,7 +289,7 @@ public: cand.pop_back(); // cerr << "POPPED: " << *item << endl; PushSucc(*item, is_goal, &cand, &unique_cands); - IncorporateIntoPlusLMForest(item, &state2node, &freelist); + IncorporateIntoPlusLMForest(v.node_hash, item, &state2node, &freelist); ++pops; } D_v.resize(state2node.size()); @@ -306,112 +308,112 @@ public: } void KBestFast(const int vert_index, const bool is_goal) { - // cerr << "KBest(" << vert_index << ")\n"; - CandidateList& D_v = D[vert_index]; - assert(D_v.empty()); - const Hypergraph::Node& v = in.nodes_[vert_index]; - // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; - const vector& in_edges = v.in_edges_; - CandidateHeap cand; - CandidateList freelist; - cand.reserve(in_edges.size()); - //init with j<0,0> for all rules-edges that lead to node-(NT-span) - for (int i = 0; i < in_edges.size(); ++i) { - const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; - const JVector j(edge.tail_nodes_.size(), 0); - cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); - } - // cerr << " making heap of " << cand.size() << " candidates\n"; - make_heap(cand.begin(), cand.end(), HeapCandCompare()); - State2Node state2node; // "buf" in Figure 2 - int pops = 0; - while(!cand.empty() && pops < pop_limit_) { - pop_heap(cand.begin(), cand.end(), HeapCandCompare()); - Candidate* item = cand.back(); - cand.pop_back(); - // cerr << "POPPED: " << *item << endl; - - PushSuccFast(*item, is_goal, &cand); - IncorporateIntoPlusLMForest(item, &state2node, &freelist); - ++pops; - } - D_v.resize(state2node.size()); - int c = 0; - for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){ - D_v[c++] = i->second; - // cerr << "MERGED: " << *i->second << endl; - } - //cerr <<"Node id: "<< vert_index<< endl; - //#ifdef MEASURE_CA - // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<& in_edges = v.in_edges_; + CandidateHeap cand; + CandidateList freelist; + cand.reserve(in_edges.size()); + //init with j<0,0> for all rules-edges that lead to node-(NT-span) + for (int i = 0; i < in_edges.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; + const JVector j(edge.tail_nodes_.size(), 0); + cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); + } + // cerr << " making heap of " << cand.size() << " candidates\n"; + make_heap(cand.begin(), cand.end(), HeapCandCompare()); + State2Node state2node; // "buf" in Figure 2 + int pops = 0; + while(!cand.empty() && pops < pop_limit_) { + pop_heap(cand.begin(), cand.end(), HeapCandCompare()); + Candidate* item = cand.back(); + cand.pop_back(); + // cerr << "POPPED: " << *item << endl; + + PushSuccFast(*item, is_goal, &cand); + IncorporateIntoPlusLMForest(v.node_hash, item, &state2node, &freelist); + ++pops; + } + D_v.resize(state2node.size()); + int c = 0; + for (auto& i : state2node) { + D_v[c++] = i.second; + // cerr << "MERGED: " << *i.second << endl; + } + //cerr <<"Node id: "<< vert_index<< endl; + //#ifdef MEASURE_CA + // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<& in_edges = v.in_edges_; - CandidateHeap cand; - CandidateList freelist; - cand.reserve(in_edges.size()); - UniqueCandidateSet unique_accepted; - //init with j<0,0> for all rules-edges that lead to node-(NT-span) - for (int i = 0; i < in_edges.size(); ++i) { - const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; - const JVector j(edge.tail_nodes_.size(), 0); - cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); - } - // cerr << " making heap of " << cand.size() << " candidates\n"; - make_heap(cand.begin(), cand.end(), HeapCandCompare()); - State2Node state2node; // "buf" in Figure 2 - int pops = 0; - while(!cand.empty() && pops < pop_limit_) { - pop_heap(cand.begin(), cand.end(), HeapCandCompare()); - Candidate* item = cand.back(); - cand.pop_back(); + // cerr << "KBest(" << vert_index << ")\n"; + CandidateList& D_v = D[vert_index]; + assert(D_v.empty()); + const Hypergraph::Node& v = in.nodes_[vert_index]; + // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; + const vector& in_edges = v.in_edges_; + CandidateHeap cand; + CandidateList freelist; + cand.reserve(in_edges.size()); + UniqueCandidateSet unique_accepted; + //init with j<0,0> for all rules-edges that lead to node-(NT-span) + for (int i = 0; i < in_edges.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; + const JVector j(edge.tail_nodes_.size(), 0); + cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); + } + // cerr << " making heap of " << cand.size() << " candidates\n"; + make_heap(cand.begin(), cand.end(), HeapCandCompare()); + State2Node state2node; // "buf" in Figure 2 + int pops = 0; + while(!cand.empty() && pops < pop_limit_) { + pop_heap(cand.begin(), cand.end(), HeapCandCompare()); + Candidate* item = cand.back(); + cand.pop_back(); bool is_new = unique_accepted.insert(item).second; - assert(is_new); // these should all be unique! - // cerr << "POPPED: " << *item << endl; - - PushSuccFast2(*item, is_goal, &cand, &unique_accepted); - IncorporateIntoPlusLMForest(item, &state2node, &freelist); - ++pops; - } - D_v.resize(state2node.size()); - int c = 0; - for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){ - D_v[c++] = i->second; - // cerr << "MERGED: " << *i->second << endl; - } - //cerr <<"Node id: "<< vert_index<< endl; - //#ifdef MEASURE_CA - // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<second; + // cerr << "MERGED: " << *i->second << endl; + } + //cerr <<"Node id: "<< vert_index<< endl; + //#ifdef MEASURE_CA + // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<tail_nodes_[i]].size()) { - Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); - cand.push_back(new_cand); - push_heap(cand.begin(), cand.end(), HeapCandCompare()); - } - if(item.j_[i]!=0){ - return; - } - } + CandidateHeap& cand = *pcand; + for (int i = 0; i < item.j_.size(); ++i) { + JVector j = item.j_; + ++j[i]; + if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { + Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); + cand.push_back(new_cand); + push_heap(cand.begin(), cand.end(), HeapCandCompare()); + } + if(item.j_[i]!=0){ + return; + } + } } //PushSucc only if all ancest Cand are added void PushSuccFast2(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* ps){ - CandidateHeap& cand = *pcand; - for (int i = 0; i < item.j_.size(); ++i) { - JVector j = item.j_; - ++j[i]; - if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { - Candidate query_unique(*item.in_edge_, j); - if (HasAllAncestors(&query_unique,ps)) { - Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); - cand.push_back(new_cand); - push_heap(cand.begin(), cand.end(), HeapCandCompare()); - } - } - } + CandidateHeap& cand = *pcand; + for (int i = 0; i < item.j_.size(); ++i) { + JVector j = item.j_; + ++j[i]; + if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { + Candidate query_unique(*item.in_edge_, j); + if (HasAllAncestors(&query_unique,ps)) { + Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); + cand.push_back(new_cand); + push_heap(cand.begin(), cand.end(), HeapCandCompare()); + } + } + } } bool HasAllAncestors(const Candidate* item, UniqueCandidateSet* cs){ - for (int i = 0; i < item->j_.size(); ++i) { - JVector j = item->j_; - --j[i]; - if (j[i] >=0) { - Candidate query_unique(*item->in_edge_, j); - if (cs->count(&query_unique) == 0) { - return false; - } - } - } - return true; + for (int i = 0; i < item->j_.size(); ++i) { + JVector j = item->j_; + --j[i]; + if (j[i] >=0) { + Candidate query_unique(*item->in_edge_, j); + if (cs->count(&query_unique) == 0) { + return false; + } + } + } + return true; } const ModelSet& models; @@ -491,7 +493,7 @@ public: FFStates node_states_; // for each node in the out-HG what is // its q function value? const int pop_limit_; - const int strategy_; //switch Cube Pruning strategy: 1 normal, 2 fast (alg 2), 3 fast_2 (alg 3). (see: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010) + const int strategy_; //switch Cube Pruning strategy: 1 normal, 2 fast (alg 2), 3 fast_2 (alg 3). (see: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010) }; struct NoPruningRescorer { @@ -507,7 +509,7 @@ struct NoPruningRescorer { typedef unordered_map > State2NodeIndex; - void ExpandEdge(const Hypergraph::Edge& in_edge, bool is_goal, State2NodeIndex* state2node) { + void ExpandEdge(const Hypergraph::Edge& in_edge, bool is_goal, size_t head_node_hash, State2NodeIndex* state2node) { const int arity = in_edge.Arity(); Hypergraph::TailNodeVector ends(arity); for (int i = 0; i < arity; ++i) @@ -531,7 +533,9 @@ struct NoPruningRescorer { } int& head_plus1 = (*state2node)[head_state]; if (!head_plus1) { - head_plus1 = out.AddNode(in_edge.rule_->GetLHS())->id_ + 1; + HG::Node* new_node = out.AddNode(in_edge.rule_->GetLHS()); + new_node->node_hash = cdec::HashNode(head_node_hash, head_state); // ID is combination of existing state + residual state + head_plus1 = new_node->id_ + 1; node_states_.push_back(head_state); nodemap[in_edge.head_node_].push_back(head_plus1 - 1); } @@ -553,7 +557,7 @@ struct NoPruningRescorer { const Hypergraph::Node& node = in.nodes_[node_num]; for (int i = 0; i < node.in_edges_.size(); ++i) { const Hypergraph::Edge& edge = in.edges_[node.in_edges_[i]]; - ExpandEdge(edge, is_goal, &state2node); + ExpandEdge(edge, is_goal, node.node_hash, &state2node); } } @@ -605,16 +609,16 @@ void ApplyModelSet(const Hypergraph& in, cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n"; } if (config.algorithm == IntersectionConfiguration::CUBE) { - CubePruningRescorer ma(models, smeta, in, pl, out); - ma.Apply(); + CubePruningRescorer ma(models, smeta, in, pl, out); + ma.Apply(); } else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING){ - CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP); - ma.Apply(); + CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP); + ma.Apply(); } else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2){ - CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP_2); - ma.Apply(); + CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP_2); + ma.Apply(); } } else { -- cgit v1.2.3