summary refs log tree commit diff
path: root/decoder/apply_models.cc
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/apply_models.cc')
-rw-r--r-- decoder/apply_models.cc 306
1 files changed, 155 insertions, 151 deletions
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc
index 9a8f60be..9f8bbead 100644
--- a/decoder/apply_models.cc
+++ b/decoder/apply_models.cc
@@ -19,6 +19,7 @@ namespace std { using std::tr1::unordered_map; using std::tr1::unordered_set; }
#include <boost/functional/hash.hpp>
+#include "node_state_hash.h"
#include "verbose.h"
#include "hg.h"
#include "ff.h"
@@ -229,7 +230,7 @@ public:
D.clear();
}
- void IncorporateIntoPlusLMForest(Candidate* item, State2Node* s2n, CandidateList* freelist) {
+ void IncorporateIntoPlusLMForest(size_t head_node_hash, Candidate* item, State2Node* s2n, CandidateList* freelist) {
Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_);
new_edge->edge_prob_ = item->out_edge_.edge_prob_;
Candidate*& o_item = (*s2n)[item->state_];
@@ -238,6 +239,7 @@ public:
int& node_id = o_item->node_index_;
if (node_id < 0) {
Hypergraph::Node* new_node = out.AddNode(in.nodes_[item->in_edge_->head_node_].cat_);
+ new_node->node_hash = cdec::HashNode(head_node_hash, item->state_); // ID is combination of existing state + residual state
node_states_.push_back(item->state_);
node_id = new_node->id_;
}
@@ -287,7 +289,7 @@ public:
cand.pop_back();
// cerr << "POPPED: " << *item << endl;
PushSucc(*item, is_goal, &cand, &unique_cands);
- IncorporateIntoPlusLMForest(item, &state2node, &freelist);
+ IncorporateIntoPlusLMForest(v.node_hash, item, &state2node, &freelist);
++pops;
}
D_v.resize(state2node.size());
@@ -306,112 +308,112 @@ public:
}
void KBestFast(const int vert_index, const bool is_goal) {
- // cerr << "KBest(" << vert_index << ")\n";
- CandidateList& D_v = D[vert_index];
- assert(D_v.empty());
- const Hypergraph::Node& v = in.nodes_[vert_index];
- // cerr << " has " << v.in_edges_.size() << " in-coming edges\n";
- const vector<int>& in_edges = v.in_edges_;
- CandidateHeap cand;
- CandidateList freelist;
- cand.reserve(in_edges.size());
- //init with j<0,0> for all rules-edges that lead to node-(NT-span)
- for (int i = 0; i < in_edges.size(); ++i) {
- const Hypergraph::Edge& edge = in.edges_[in_edges[i]];
- const JVector j(edge.tail_nodes_.size(), 0);
- cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal));
- }
- // cerr << " making heap of " << cand.size() << " candidates\n";
- make_heap(cand.begin(), cand.end(), HeapCandCompare());
- State2Node state2node; // "buf" in Figure 2
- int pops = 0;
- while(!cand.empty() && pops < pop_limit_) {
- pop_heap(cand.begin(), cand.end(), HeapCandCompare());
- Candidate* item = cand.back();
- cand.pop_back();
- // cerr << "POPPED: " << *item << endl;
-
- PushSuccFast(*item, is_goal, &cand);
- IncorporateIntoPlusLMForest(item, &state2node, &freelist);
- ++pops;
- }
- D_v.resize(state2node.size());
- int c = 0;
- for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){
- D_v[c++] = i->second;
- // cerr << "MERGED: " << *i->second << endl;
- }
- //cerr <<"Node id: "<< vert_index<< endl;
- //#ifdef MEASURE_CA
- // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl;
- // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl;
- //#endif
- sort(D_v.begin(), D_v.end(), EstProbSorter());
-
- // cerr << " expanded to " << D_v.size() << " nodes\n";
-
- for (int i = 0; i < cand.size(); ++i)
- delete cand[i];
- // freelist is necessary since even after an item merged, it still stays in
- // the unique set so it can't be deleted til now
- for (int i = 0; i < freelist.size(); ++i)
- delete freelist[i];
+ // cerr << "KBest(" << vert_index << ")\n";
+ CandidateList& D_v = D[vert_index];
+ assert(D_v.empty());
+ const Hypergraph::Node& v = in.nodes_[vert_index];
+ // cerr << " has " << v.in_edges_.size() << " in-coming edges\n";
+ const vector<int>& in_edges = v.in_edges_;
+ CandidateHeap cand;
+ CandidateList freelist;
+ cand.reserve(in_edges.size());
+ //init with j<0,0> for all rules-edges that lead to node-(NT-span)
+ for (int i = 0; i < in_edges.size(); ++i) {
+ const Hypergraph::Edge& edge = in.edges_[in_edges[i]];
+ const JVector j(edge.tail_nodes_.size(), 0);
+ cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal));
+ }
+ // cerr << " making heap of " << cand.size() << " candidates\n";
+ make_heap(cand.begin(), cand.end(), HeapCandCompare());
+ State2Node state2node; // "buf" in Figure 2
+ int pops = 0;
+ while(!cand.empty() && pops < pop_limit_) {
+ pop_heap(cand.begin(), cand.end(), HeapCandCompare());
+ Candidate* item = cand.back();
+ cand.pop_back();
+ // cerr << "POPPED: " << *item << endl;
+
+ PushSuccFast(*item, is_goal, &cand);
+ IncorporateIntoPlusLMForest(v.node_hash, item, &state2node, &freelist);
+ ++pops;
+ }
+ D_v.resize(state2node.size());
+ int c = 0;
+ for (auto& i : state2node) {
+ D_v[c++] = i.second;
+ // cerr << "MERGED: " << *i.second << endl;
+ }
+ //cerr <<"Node id: "<< vert_index<< endl;
+ //#ifdef MEASURE_CA
+ // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl;
+ // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl;
+ //#endif
+ sort(D_v.begin(), D_v.end(), EstProbSorter());
+
+ // cerr << " expanded to " << D_v.size() << " nodes\n";
+
+ for (int i = 0; i < cand.size(); ++i)
+ delete cand[i];
+ // freelist is necessary since even after an item merged, it still stays in
+ // the unique set so it can't be deleted til now
+ for (int i = 0; i < freelist.size(); ++i)
+ delete freelist[i];
}
void KBestFast2(const int vert_index, const bool is_goal) {
- // cerr << "KBest(" << vert_index << ")\n";
- CandidateList& D_v = D[vert_index];
- assert(D_v.empty());
- const Hypergraph::Node& v = in.nodes_[vert_index];
- // cerr << " has " << v.in_edges_.size() << " in-coming edges\n";
- const vector<int>& in_edges = v.in_edges_;
- CandidateHeap cand;
- CandidateList freelist;
- cand.reserve(in_edges.size());
- UniqueCandidateSet unique_accepted;
- //init with j<0,0> for all rules-edges that lead to node-(NT-span)
- for (int i = 0; i < in_edges.size(); ++i) {
- const Hypergraph::Edge& edge = in.edges_[in_edges[i]];
- const JVector j(edge.tail_nodes_.size(), 0);
- cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal));
- }
- // cerr << " making heap of " << cand.size() << " candidates\n";
- make_heap(cand.begin(), cand.end(), HeapCandCompare());
- State2Node state2node; // "buf" in Figure 2
- int pops = 0;
- while(!cand.empty() && pops < pop_limit_) {
- pop_heap(cand.begin(), cand.end(), HeapCandCompare());
- Candidate* item = cand.back();
- cand.pop_back();
+ // cerr << "KBest(" << vert_index << ")\n";
+ CandidateList& D_v = D[vert_index];
+ assert(D_v.empty());
+ const Hypergraph::Node& v = in.nodes_[vert_index];
+ // cerr << " has " << v.in_edges_.size() << " in-coming edges\n";
+ const vector<int>& in_edges = v.in_edges_;
+ CandidateHeap cand;
+ CandidateList freelist;
+ cand.reserve(in_edges.size());
+ UniqueCandidateSet unique_accepted;
+ //init with j<0,0> for all rules-edges that lead to node-(NT-span)
+ for (int i = 0; i < in_edges.size(); ++i) {
+ const Hypergraph::Edge& edge = in.edges_[in_edges[i]];
+ const JVector j(edge.tail_nodes_.size(), 0);
+ cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal));
+ }
+ // cerr << " making heap of " << cand.size() << " candidates\n";
+ make_heap(cand.begin(), cand.end(), HeapCandCompare());
+ State2Node state2node; // "buf" in Figure 2
+ int pops = 0;
+ while(!cand.empty() && pops < pop_limit_) {
+ pop_heap(cand.begin(), cand.end(), HeapCandCompare());
+ Candidate* item = cand.back();
+ cand.pop_back();
bool is_new = unique_accepted.insert(item).second;
- assert(is_new); // these should all be unique!
- // cerr << "POPPED: " << *item << endl;
-
- PushSuccFast2(*item, is_goal, &cand, &unique_accepted);
- IncorporateIntoPlusLMForest(item, &state2node, &freelist);
- ++pops;
- }
- D_v.resize(state2node.size());
- int c = 0;
- for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){
- D_v[c++] = i->second;
- // cerr << "MERGED: " << *i->second << endl;
- }
- //cerr <<"Node id: "<< vert_index<< endl;
- //#ifdef MEASURE_CA
- // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl;
- // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl;
- //#endif
- sort(D_v.begin(), D_v.end(), EstProbSorter());
-
- // cerr << " expanded to " << D_v.size() << " nodes\n";
-
- for (int i = 0; i < cand.size(); ++i)
- delete cand[i];
- // freelist is necessary since even after an item merged, it still stays in
- // the unique set so it can't be deleted til now
- for (int i = 0; i < freelist.size(); ++i)
- delete freelist[i];
+ assert(is_new); // these should all be unique!
+ // cerr << "POPPED: " << *item << endl;
+
+ PushSuccFast2(*item, is_goal, &cand, &unique_accepted);
+ IncorporateIntoPlusLMForest(v.node_hash, item, &state2node, &freelist);
+ ++pops;
+ }
+ D_v.resize(state2node.size());
+ int c = 0;
+ for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){
+ D_v[c++] = i->second;
+ // cerr << "MERGED: " << *i->second << endl;
+ }
+ //cerr <<"Node id: "<< vert_index<< endl;
+ //#ifdef MEASURE_CA
+ // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl;
+ // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl;
+ //#endif
+ sort(D_v.begin(), D_v.end(), EstProbSorter());
+
+ // cerr << " expanded to " << D_v.size() << " nodes\n";
+
+ for (int i = 0; i < cand.size(); ++i)
+ delete cand[i];
+ // freelist is necessary since even after an item merged, it still stays in
+ // the unique set so it can't be deleted til now
+ for (int i = 0; i < freelist.size(); ++i)
+ delete freelist[i];
}
void PushSucc(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* cs) {
@@ -434,50 +436,50 @@ public:
//PushSucc following unique ancestor generation function
void PushSuccFast(const Candidate& item, const bool is_goal, CandidateHeap* pcand){
- CandidateHeap& cand = *pcand;
- for (int i = 0; i < item.j_.size(); ++i) {
- JVector j = item.j_;
- ++j[i];
- if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) {
- Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal);
- cand.push_back(new_cand);
- push_heap(cand.begin(), cand.end(), HeapCandCompare());
- }
- if(item.j_[i]!=0){
- return;
- }
- }
+ CandidateHeap& cand = *pcand;
+ for (int i = 0; i < item.j_.size(); ++i) {
+ JVector j = item.j_;
+ ++j[i];
+ if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) {
+ Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal);
+ cand.push_back(new_cand);
+ push_heap(cand.begin(), cand.end(), HeapCandCompare());
+ }
+ if(item.j_[i]!=0){
+ return;
+ }
+ }
}
//PushSucc only if all ancestor Candidates have already been added
void PushSuccFast2(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* ps){
- CandidateHeap& cand = *pcand;
- for (int i = 0; i < item.j_.size(); ++i) {
- JVector j = item.j_;
- ++j[i];
- if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) {
- Candidate query_unique(*item.in_edge_, j);
- if (HasAllAncestors(&query_unique,ps)) {
- Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal);
- cand.push_back(new_cand);
- push_heap(cand.begin(), cand.end(), HeapCandCompare());
- }
- }
- }
+ CandidateHeap& cand = *pcand;
+ for (int i = 0; i < item.j_.size(); ++i) {
+ JVector j = item.j_;
+ ++j[i];
+ if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) {
+ Candidate query_unique(*item.in_edge_, j);
+ if (HasAllAncestors(&query_unique,ps)) {
+ Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal);
+ cand.push_back(new_cand);
+ push_heap(cand.begin(), cand.end(), HeapCandCompare());
+ }
+ }
+ }
}
bool HasAllAncestors(const Candidate* item, UniqueCandidateSet* cs){
- for (int i = 0; i < item->j_.size(); ++i) {
- JVector j = item->j_;
- --j[i];
- if (j[i] >=0) {
- Candidate query_unique(*item->in_edge_, j);
- if (cs->count(&query_unique) == 0) {
- return false;
- }
- }
- }
- return true;
+ for (int i = 0; i < item->j_.size(); ++i) {
+ JVector j = item->j_;
+ --j[i];
+ if (j[i] >=0) {
+ Candidate query_unique(*item->in_edge_, j);
+ if (cs->count(&query_unique) == 0) {
+ return false;
+ }
+ }
+ }
+ return true;
}
const ModelSet& models;
@@ -491,7 +493,7 @@ public:
FFStates node_states_; // for each node in the out-HG what is
// its q function value?
const int pop_limit_;
- const int strategy_; //switch Cube Pruning strategy: 1 normal, 2 fast (alg 2), 3 fast_2 (alg 3). (see: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010)
+ const int strategy_; //switch Cube Pruning strategy: 1 normal, 2 fast (alg 2), 3 fast_2 (alg 3). (see: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010)
};
struct NoPruningRescorer {
@@ -507,7 +509,7 @@ struct NoPruningRescorer {
typedef unordered_map<FFState, int, boost::hash<FFState> > State2NodeIndex;
- void ExpandEdge(const Hypergraph::Edge& in_edge, bool is_goal, State2NodeIndex* state2node) {
+ void ExpandEdge(const Hypergraph::Edge& in_edge, bool is_goal, size_t head_node_hash, State2NodeIndex* state2node) {
const int arity = in_edge.Arity();
Hypergraph::TailNodeVector ends(arity);
for (int i = 0; i < arity; ++i)
@@ -531,7 +533,9 @@ struct NoPruningRescorer {
}
int& head_plus1 = (*state2node)[head_state];
if (!head_plus1) {
- head_plus1 = out.AddNode(in_edge.rule_->GetLHS())->id_ + 1;
+ HG::Node* new_node = out.AddNode(in_edge.rule_->GetLHS());
+ new_node->node_hash = cdec::HashNode(head_node_hash, head_state); // ID is combination of existing state + residual state
+ head_plus1 = new_node->id_ + 1;
node_states_.push_back(head_state);
nodemap[in_edge.head_node_].push_back(head_plus1 - 1);
}
@@ -553,7 +557,7 @@ struct NoPruningRescorer {
const Hypergraph::Node& node = in.nodes_[node_num];
for (int i = 0; i < node.in_edges_.size(); ++i) {
const Hypergraph::Edge& edge = in.edges_[node.in_edges_[i]];
- ExpandEdge(edge, is_goal, &state2node);
+ ExpandEdge(edge, is_goal, node.node_hash, &state2node);
}
}
@@ -605,16 +609,16 @@ void ApplyModelSet(const Hypergraph& in,
cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n";
}
if (config.algorithm == IntersectionConfiguration::CUBE) {
- CubePruningRescorer ma(models, smeta, in, pl, out);
- ma.Apply();
+ CubePruningRescorer ma(models, smeta, in, pl, out);
+ ma.Apply();
}
else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING){
- CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP);
- ma.Apply();
+ CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP);
+ ma.Apply();
}
else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2){
- CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP_2);
- ma.Apply();
+ CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP_2);
+ ma.Apply();
}
} else {