diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2011-07-26 15:39:26 +0100 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2011-07-26 15:39:26 +0100 |
commit | 004d7447341b938efdec9223f899cf67f6b0dd86 (patch) | |
tree | 3b6da42ae9f5158b7c7e366aa1051e3277d937e6 /decoder/apply_models.cc | |
parent | c3828b0a2deb42de5c7378e93f93f5e69efb304c (diff) | |
parent | f80150140b5273fd1eb0dfb34bdd789c4cbd35e6 (diff) |
Merge remote branch 'agesmundo/master'
Diffstat (limited to 'decoder/apply_models.cc')
-rw-r--r-- | decoder/apply_models.cc | 196 |
1 files changed, 190 insertions, 6 deletions
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc index 9390c809..62eff262 100644 --- a/decoder/apply_models.cc +++ b/decoder/apply_models.cc @@ -17,6 +17,10 @@ #include "hg.h" #include "ff.h" +#define NORMAL_CP 1 +#define FAST_CP 2 +#define FAST_CP_2 3 + using namespace std; using namespace std::tr1; @@ -164,13 +168,15 @@ public: const SentenceMetadata& sm, const Hypergraph& i, int pop_limit, - Hypergraph* o) : + Hypergraph* o, + int s = NORMAL_CP ) : models(m), smeta(sm), in(i), out(*o), D(in.nodes_.size()), - pop_limit_(pop_limit) { + pop_limit_(pop_limit), + strategy_(s){ if (!SILENT) cerr << " Applying feature functions (cube pruning, pop_limit = " << pop_limit_ << ')' << endl; node_states_.reserve(kRESERVE_NUM_NODES); } @@ -186,7 +192,15 @@ public: if (!SILENT) cerr << " "; for (int i = 0; i < in.nodes_.size(); ++i) { if (!SILENT && i % every == 0) cerr << '.'; - KBest(i, i == goal_id); + if (strategy_==NORMAL_CP){ + KBest(i, i == goal_id); + } + if (strategy_==FAST_CP){ + KBestFast(i, i == goal_id); + } + if (strategy_==FAST_CP_2){ + KBestFast2(i, i == goal_id); + } } if (!SILENT) { cerr << endl; @@ -283,6 +297,114 @@ public: delete freelist[i]; } + void KBestFast(const int vert_index, const bool is_goal) { + // cerr << "KBest(" << vert_index << ")\n"; + CandidateList& D_v = D[vert_index]; + assert(D_v.empty()); + const Hypergraph::Node& v = in.nodes_[vert_index]; + // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; + const vector<int>& in_edges = v.in_edges_; + CandidateHeap cand; + CandidateList freelist; + cand.reserve(in_edges.size()); + //init with j<0,0> for all rules-edges that lead to node-(NT-span) + for (int i = 0; i < in_edges.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; + const JVector j(edge.tail_nodes_.size(), 0); + cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); + } + // cerr << " making heap of " << cand.size() << " candidates\n"; + make_heap(cand.begin(), cand.end(), HeapCandCompare()); + State2Node state2node; // "buf" in Figure 2 + int pops = 0; + while(!cand.empty() && pops < pop_limit_) { + pop_heap(cand.begin(), cand.end(), HeapCandCompare()); + Candidate* item = cand.back(); + cand.pop_back(); + // cerr << "POPPED: " << *item << endl; + + PushSuccFast(*item, is_goal, &cand); + IncorporateIntoPlusLMForest(item, &state2node, &freelist); + ++pops; + } + D_v.resize(state2node.size()); + int c = 0; + for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){ + D_v[c++] = i->second; + // cerr << "MERGED: " << *i->second << endl; + } + //cerr <<"Node id: "<< vert_index<< endl; + //#ifdef MEASURE_CA + // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl; + // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl; + //#endif + sort(D_v.begin(), D_v.end(), EstProbSorter()); + + // cerr << " expanded to " << D_v.size() << " nodes\n"; + + for (int i = 0; i < cand.size(); ++i) + delete cand[i]; + // freelist is necessary since even after an item merged, it still stays in + // the unique set so it can't be deleted til now + for (int i = 0; i < freelist.size(); ++i) + delete freelist[i]; + } + + void KBestFast2(const int vert_index, const bool is_goal) { + // cerr << "KBest(" << vert_index << ")\n"; + CandidateList& D_v = D[vert_index]; + assert(D_v.empty()); + const Hypergraph::Node& v = in.nodes_[vert_index]; + // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; + const vector<int>& in_edges = v.in_edges_; + CandidateHeap cand; + CandidateList freelist; + cand.reserve(in_edges.size()); + UniqueCandidateSet unique_accepted; + //init with j<0,0> for all rules-edges that lead to node-(NT-span) + for (int i = 0; i < in_edges.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; + const JVector j(edge.tail_nodes_.size(), 0); + cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); + } + // cerr << " making heap of " << cand.size() << " candidates\n"; + make_heap(cand.begin(), cand.end(), HeapCandCompare()); + State2Node state2node; // "buf" in Figure 2 + int pops = 0; + while(!cand.empty() && pops < pop_limit_) { + pop_heap(cand.begin(), cand.end(), HeapCandCompare()); + Candidate* item = cand.back(); + cand.pop_back(); + assert(unique_accepted.insert(item).second); // these should all be unique! + // cerr << "POPPED: " << *item << endl; + + PushSuccFast2(*item, is_goal, &cand, &unique_accepted); + IncorporateIntoPlusLMForest(item, &state2node, &freelist); + ++pops; + } + D_v.resize(state2node.size()); + int c = 0; + for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){ + D_v[c++] = i->second; + // cerr << "MERGED: " << *i->second << endl; + } + //cerr <<"Node id: "<< vert_index<< endl; + //#ifdef MEASURE_CA + // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl; + // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl; + //#endif + sort(D_v.begin(), D_v.end(), EstProbSorter()); + + // cerr << " expanded to " << D_v.size() << " nodes\n"; + + for (int i = 0; i < cand.size(); ++i) + delete cand[i]; + // freelist is necessary since even after an item merged, it still stays in + // the unique set so it can't be deleted til now + for (int i = 0; i < freelist.size(); ++i) + delete freelist[i]; + } + void PushSucc(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* cs) { CandidateHeap& cand = *pcand; for (int i = 0; i < item.j_.size(); ++i) { @@ -300,6 +422,54 @@ public: } } + //PushSucc following unique ancestor generation function + void PushSuccFast(const Candidate& item, const bool is_goal, CandidateHeap* pcand){ + CandidateHeap& cand = *pcand; + for (int i = 0; i < item.j_.size(); ++i) { + JVector j = item.j_; + ++j[i]; + if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { + Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); + cand.push_back(new_cand); + push_heap(cand.begin(), cand.end(), HeapCandCompare()); + } + if(item.j_[i]!=0){ + return; + } + } + } + + //PushSucc only if all ancest Cand are added + void PushSuccFast2(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* ps){ + CandidateHeap& cand = *pcand; + for (int i = 0; i < item.j_.size(); ++i) { + JVector j = item.j_; + ++j[i]; + if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { + Candidate query_unique(*item.in_edge_, j); + if (HasAllAncestors(&query_unique,ps)) { + Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); + cand.push_back(new_cand); + push_heap(cand.begin(), cand.end(), HeapCandCompare()); + } + } + } + } + + bool HasAllAncestors(const Candidate* item, UniqueCandidateSet* cs){ + for (int i = 0; i < item->j_.size(); ++i) { + JVector j = item->j_; + --j[i]; + if (j[i] >=0) { + Candidate query_unique(*item->in_edge_, j); + if (cs->count(&query_unique) == 0) { + return false; + } + } + } + return true; + } + const ModelSet& models; const SentenceMetadata& smeta; const Hypergraph& in; @@ -311,6 +481,7 @@ public: FFStates node_states_; // for each node in the out-HG what is // its q function value? const int pop_limit_; + const int strategy_; //switch Cube Pruning strategy: 1 normal, 2 fast (alg 2), 3 fast_2 (alg 3). (see: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010) }; struct NoPruningRescorer { @@ -412,15 +583,28 @@ void ApplyModelSet(const Hypergraph& in, if (models.stateless() || config.algorithm == IntersectionConfiguration::FULL) { NoPruningRescorer ma(models, smeta, in, out); // avoid overhead of best-first when no state ma.Apply(); - } else if (config.algorithm == IntersectionConfiguration::CUBE) { + } else if (config.algorithm == IntersectionConfiguration::CUBE + || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING + || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2) { int pl = config.pop_limit; const int max_pl_for_large=50; if (pl > max_pl_for_large && in.nodes_.size() > 80000) { pl = max_pl_for_large; cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n"; } - CubePruningRescorer ma(models, smeta, in, pl, out); - ma.Apply(); + if (config.algorithm == IntersectionConfiguration::CUBE) { + CubePruningRescorer ma(models, smeta, in, pl, out); + ma.Apply(); + } + else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING){ + CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP); + ma.Apply(); + } + else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2){ + CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP_2); + ma.Apply(); + } + } else { cerr << "Don't understand intersection algorithm " << config.algorithm << endl; exit(1); |