diff options
Diffstat (limited to 'decoder/apply_models.cc')
| -rw-r--r-- | decoder/apply_models.cc | 306 | 
1 files changed, 155 insertions, 151 deletions
| diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc index 9a8f60be..9f8bbead 100644 --- a/decoder/apply_models.cc +++ b/decoder/apply_models.cc @@ -19,6 +19,7 @@ namespace std { using std::tr1::unordered_map; using std::tr1::unordered_set; }  #include <boost/functional/hash.hpp> +#include "node_state_hash.h"  #include "verbose.h"  #include "hg.h"  #include "ff.h" @@ -229,7 +230,7 @@ public:      D.clear();    } -  void IncorporateIntoPlusLMForest(Candidate* item, State2Node* s2n, CandidateList* freelist) { +  void IncorporateIntoPlusLMForest(size_t head_node_hash, Candidate* item, State2Node* s2n, CandidateList* freelist) {      Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_);      new_edge->edge_prob_ = item->out_edge_.edge_prob_;      Candidate*& o_item = (*s2n)[item->state_]; @@ -238,6 +239,7 @@ public:      int& node_id = o_item->node_index_;      if (node_id < 0) {        Hypergraph::Node* new_node = out.AddNode(in.nodes_[item->in_edge_->head_node_].cat_); +      new_node->node_hash = cdec::HashNode(head_node_hash, item->state_); // ID is combination of existing state + residual state        node_states_.push_back(item->state_);        node_id = new_node->id_;      } @@ -287,7 +289,7 @@ public:        cand.pop_back();        // cerr << "POPPED: " << *item << endl;        PushSucc(*item, is_goal, &cand, &unique_cands); -      IncorporateIntoPlusLMForest(item, &state2node, &freelist); +      IncorporateIntoPlusLMForest(v.node_hash, item, &state2node, &freelist);        ++pops;      }      D_v.resize(state2node.size()); @@ -306,112 +308,112 @@ public:    }    void KBestFast(const int vert_index, const bool is_goal) { -	  // cerr << "KBest(" << vert_index << ")\n"; -	  CandidateList& D_v = D[vert_index]; -	  assert(D_v.empty()); -	  const Hypergraph::Node& v = in.nodes_[vert_index]; -	  // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; -	  const vector<int>& in_edges = v.in_edges_; -	  CandidateHeap cand; -	  CandidateList freelist; -	  cand.reserve(in_edges.size()); -	  //init with j<0,0> for all rules-edges that lead to node-(NT-span) -	  for (int i = 0; i < in_edges.size(); ++i) { -		  const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; -		  const JVector j(edge.tail_nodes_.size(), 0); -		  cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); -	  } -	  // cerr << " making heap of " << cand.size() << " candidates\n"; -	  make_heap(cand.begin(), cand.end(), HeapCandCompare()); -	  State2Node state2node; // "buf" in Figure 2 -	  int pops = 0; -	  while(!cand.empty() && pops < pop_limit_) { -		  pop_heap(cand.begin(), cand.end(), HeapCandCompare()); -		  Candidate* item = cand.back(); -		  cand.pop_back(); -		  // cerr << "POPPED: " << *item << endl; - -		  PushSuccFast(*item, is_goal, &cand); -		  IncorporateIntoPlusLMForest(item, &state2node, &freelist); -		  ++pops; -	  } -	  D_v.resize(state2node.size()); -	  int c = 0; -	  for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){ -		  D_v[c++] = i->second; -		  // cerr << "MERGED: " << *i->second << endl; -	  } -	  //cerr <<"Node id: "<< vert_index<< endl; -	  //#ifdef MEASURE_CA -	  // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl; -	  // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl; -	  //#endif -	  sort(D_v.begin(), D_v.end(), EstProbSorter()); - -	  // cerr << " expanded to " << D_v.size() << " nodes\n"; - -	  for (int i = 0; i < cand.size(); ++i) -		  delete cand[i]; -	  // freelist is necessary since even after an item merged, it still stays in -	  // the unique set so it can't be deleted til now -	  for (int i = 0; i < freelist.size(); ++i) -		  delete freelist[i]; +    // cerr << "KBest(" << vert_index << ")\n"; +    CandidateList& D_v = D[vert_index]; +    assert(D_v.empty()); +    const Hypergraph::Node& v = in.nodes_[vert_index]; +    // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; +    const vector<int>& in_edges = v.in_edges_; +    CandidateHeap cand; +    CandidateList freelist; +    cand.reserve(in_edges.size()); +    //init with j<0,0> for all rules-edges that lead to node-(NT-span) +    for (int i = 0; i < in_edges.size(); ++i) { +      const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; +      const JVector j(edge.tail_nodes_.size(), 0); +      cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); +    } +    // cerr << " making heap of " << cand.size() << " candidates\n"; +    make_heap(cand.begin(), cand.end(), HeapCandCompare()); +    State2Node state2node; // "buf" in Figure 2 +    int pops = 0; +    while(!cand.empty() && pops < pop_limit_) { +      pop_heap(cand.begin(), cand.end(), HeapCandCompare()); +      Candidate* item = cand.back(); +      cand.pop_back(); +      // cerr << "POPPED: " << *item << endl; + +      PushSuccFast(*item, is_goal, &cand); +      IncorporateIntoPlusLMForest(v.node_hash, item, &state2node, &freelist); +      ++pops; +    } +    D_v.resize(state2node.size()); +    int c = 0; +    for (auto& i : state2node) { +      D_v[c++] = i.second; +      // cerr << "MERGED: " << *i.second << endl; +    } +    //cerr <<"Node id: "<< vert_index<< endl; +    //#ifdef MEASURE_CA +    // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl; +    // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl; +    //#endif +    sort(D_v.begin(), D_v.end(), EstProbSorter()); + +    // cerr << " expanded to " << D_v.size() << " nodes\n"; + +    for (int i = 0; i < cand.size(); ++i) +      delete cand[i]; +    // freelist is necessary since even after an item merged, it still stays in +    // the unique set so it can't be deleted til now +    for (int i = 0; i < freelist.size(); ++i) +      delete freelist[i];    }    void KBestFast2(const int vert_index, const bool is_goal) { -	  // cerr << "KBest(" << vert_index << ")\n"; -	  CandidateList& D_v = D[vert_index]; -	  assert(D_v.empty()); -	  const Hypergraph::Node& v = in.nodes_[vert_index]; -	  // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; -	  const vector<int>& in_edges = v.in_edges_; -	  CandidateHeap cand; -	  CandidateList freelist; -	  cand.reserve(in_edges.size()); -	  UniqueCandidateSet unique_accepted; -	  //init with j<0,0> for all rules-edges that lead to node-(NT-span) -	  for (int i = 0; i < in_edges.size(); ++i) { -		  const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; -		  const JVector j(edge.tail_nodes_.size(), 0); -		  cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); -	  } -	  // cerr << " making heap of " << cand.size() << " candidates\n"; -	  make_heap(cand.begin(), cand.end(), HeapCandCompare()); -	  State2Node state2node; // "buf" in Figure 2 -	  int pops = 0; -	  while(!cand.empty() && pops < pop_limit_) { -		  pop_heap(cand.begin(), cand.end(), HeapCandCompare()); -		  Candidate* item = cand.back(); -		  cand.pop_back(); +    // cerr << "KBest(" << vert_index << ")\n"; +    CandidateList& D_v = D[vert_index]; +    assert(D_v.empty()); +    const Hypergraph::Node& v = in.nodes_[vert_index]; +    // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; +    const vector<int>& in_edges = v.in_edges_; +    CandidateHeap cand; +    CandidateList freelist; +    cand.reserve(in_edges.size()); +    UniqueCandidateSet unique_accepted; +    //init with j<0,0> for all rules-edges that lead to node-(NT-span) +    for (int i = 0; i < in_edges.size(); ++i) { +      const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; +      const JVector j(edge.tail_nodes_.size(), 0); +      cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); +    } +    // cerr << " making heap of " << cand.size() << " candidates\n"; +    make_heap(cand.begin(), cand.end(), HeapCandCompare()); +    State2Node state2node; // "buf" in Figure 2 +    int pops = 0; +    while(!cand.empty() && pops < pop_limit_) { +      pop_heap(cand.begin(), cand.end(), HeapCandCompare()); +      Candidate* item = cand.back(); +      cand.pop_back();                    bool is_new = unique_accepted.insert(item).second; -		  assert(is_new); // these should all be unique! -		  // cerr << "POPPED: " << *item << endl; - -		  PushSuccFast2(*item, is_goal, &cand, &unique_accepted); -		  IncorporateIntoPlusLMForest(item, &state2node, &freelist); -		  ++pops; -	  } -	  D_v.resize(state2node.size()); -	  int c = 0; -	  for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){ -		  D_v[c++] = i->second; -		  // cerr << "MERGED: " << *i->second << endl; -	  } -	  //cerr <<"Node id: "<< vert_index<< endl; -	  //#ifdef MEASURE_CA -	  // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl; -	  // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl; -	  //#endif -	  sort(D_v.begin(), D_v.end(), EstProbSorter()); - -	  // cerr << " expanded to " << D_v.size() << " nodes\n"; - -	  for (int i = 0; i < cand.size(); ++i) -		  delete cand[i]; -	  // freelist is necessary since even after an item merged, it still stays in -	  // the unique set so it can't be deleted til now -	  for (int i = 0; i < freelist.size(); ++i) -		  delete freelist[i]; +      assert(is_new); // these should all be unique! +      // cerr << "POPPED: " << *item << endl; + +      PushSuccFast2(*item, is_goal, &cand, &unique_accepted); +      IncorporateIntoPlusLMForest(v.node_hash, item, &state2node, &freelist); +      ++pops; +    } +    D_v.resize(state2node.size()); +    int c = 0; +    for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){ +      D_v[c++] = i->second; +      // cerr << "MERGED: " << *i->second << endl; +    } +    //cerr <<"Node id: "<< vert_index<< endl; +    //#ifdef MEASURE_CA +    // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl; +    // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl; +    //#endif +    sort(D_v.begin(), D_v.end(), EstProbSorter()); + +    // cerr << " expanded to " << D_v.size() << " nodes\n"; + +    for (int i = 0; i < cand.size(); ++i) +      delete cand[i]; +    // freelist is necessary since even after an item merged, it still stays in +    // the unique set so it can't be deleted til now +    for (int i = 0; i < freelist.size(); ++i) +      delete freelist[i];    }    void PushSucc(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* cs) { @@ -434,50 +436,50 @@ public:    //PushSucc following unique ancestor generation function    void PushSuccFast(const Candidate& item, const bool is_goal, CandidateHeap* pcand){ -	  CandidateHeap& cand = *pcand; -	  for (int i = 0; i < item.j_.size(); ++i) { -		  JVector j = item.j_; -		  ++j[i]; -		  if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { -			  Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); -			  cand.push_back(new_cand); -			  push_heap(cand.begin(), cand.end(), HeapCandCompare()); -		  } -		  if(item.j_[i]!=0){ -			  return; -		  } -	  } +    CandidateHeap& cand = *pcand; +    for (int i = 0; i < item.j_.size(); ++i) { +      JVector j = item.j_; +      ++j[i]; +      if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { +        Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); +        cand.push_back(new_cand); +        push_heap(cand.begin(), cand.end(), HeapCandCompare()); +      } +      if(item.j_[i]!=0){ +        return; +      } +    }    }    //PushSucc only if all ancest Cand are added    void PushSuccFast2(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* ps){ -	  CandidateHeap& cand = *pcand; -	  for (int i = 0; i < item.j_.size(); ++i) { -		  JVector j = item.j_; -		  ++j[i]; -		  if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { -			  Candidate query_unique(*item.in_edge_, j); -			  if (HasAllAncestors(&query_unique,ps)) { -				  Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); -				  cand.push_back(new_cand); -				  push_heap(cand.begin(), cand.end(), HeapCandCompare()); -			  } -		  } -	  } +    CandidateHeap& cand = *pcand; +    for (int i = 0; i < item.j_.size(); ++i) { +      JVector j = item.j_; +      ++j[i]; +      if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { +        Candidate query_unique(*item.in_edge_, j); +        if (HasAllAncestors(&query_unique,ps)) { +          Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); +          cand.push_back(new_cand); +          push_heap(cand.begin(), cand.end(), HeapCandCompare()); +        } +      } +    }    }    bool HasAllAncestors(const Candidate* item, UniqueCandidateSet* cs){ -	  for (int i = 0; i < item->j_.size(); ++i) { -		  JVector j = item->j_; -		  --j[i]; -		  if (j[i] >=0) { -			  Candidate query_unique(*item->in_edge_, j); -			  if (cs->count(&query_unique) == 0) { -				  return false; -			  } -		  } -	  } -	  return true; +    for (int i = 0; i < item->j_.size(); ++i) { +      JVector j = item->j_; +      --j[i]; +      if (j[i] >=0) { +        Candidate query_unique(*item->in_edge_, j); +        if (cs->count(&query_unique) == 0) { +          return false; +        } +      } +    } +    return true;    }    const ModelSet& models; @@ -491,7 +493,7 @@ public:    FFStates node_states_;  // for each node in the out-HG what is                               // its q function value?    const int pop_limit_; - const int strategy_;       //switch Cube Pruning strategy: 1 normal, 2 fast (alg 2), 3 fast_2 (alg 3). (see: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010) +  const int strategy_;       //switch Cube Pruning strategy: 1 normal, 2 fast (alg 2), 3 fast_2 (alg 3). (see: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010)  };  struct NoPruningRescorer { @@ -507,7 +509,7 @@ struct NoPruningRescorer {    typedef unordered_map<FFState, int, boost::hash<FFState> > State2NodeIndex; -  void ExpandEdge(const Hypergraph::Edge& in_edge, bool is_goal, State2NodeIndex* state2node) { +  void ExpandEdge(const Hypergraph::Edge& in_edge, bool is_goal, size_t head_node_hash, State2NodeIndex* state2node) {      const int arity = in_edge.Arity();      Hypergraph::TailNodeVector ends(arity);      for (int i = 0; i < arity; ++i) @@ -531,7 +533,9 @@ struct NoPruningRescorer {        }        int& head_plus1 = (*state2node)[head_state];        if (!head_plus1) { -        head_plus1 = out.AddNode(in_edge.rule_->GetLHS())->id_ + 1; +        HG::Node* new_node = out.AddNode(in_edge.rule_->GetLHS()); +        new_node->node_hash = cdec::HashNode(head_node_hash, head_state); // ID is combination of existing state + residual state +        head_plus1 = new_node->id_ + 1;          node_states_.push_back(head_state);          nodemap[in_edge.head_node_].push_back(head_plus1 - 1);        } @@ -553,7 +557,7 @@ struct NoPruningRescorer {      const Hypergraph::Node& node = in.nodes_[node_num];      for (int i = 0; i < node.in_edges_.size(); ++i) {        const Hypergraph::Edge& edge = in.edges_[node.in_edges_[i]]; -      ExpandEdge(edge, is_goal, &state2node); +      ExpandEdge(edge, is_goal, node.node_hash, &state2node);      }    } @@ -605,16 +609,16 @@ void ApplyModelSet(const Hypergraph& in,        cerr << "  Note: reducing pop_limit to " << pl << " for very large forest\n";      }      if      (config.algorithm == IntersectionConfiguration::CUBE) { -    	CubePruningRescorer ma(models, smeta, in, pl, out); -        ma.Apply(); +      CubePruningRescorer ma(models, smeta, in, pl, out); +      ma.Apply();      }      else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING){ -    	CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP); -        ma.Apply(); +      CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP); +      ma.Apply();      }      else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2){ -    	CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP_2); -        ma.Apply(); +      CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP_2); +      ma.Apply();      }    } else { | 
