diff options
Diffstat (limited to 'decoder/apply_models.cc')
-rw-r--r-- | decoder/apply_models.cc | 205 |
1 files changed, 196 insertions, 9 deletions
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc index 9390c809..40fd27e4 100644 --- a/decoder/apply_models.cc +++ b/decoder/apply_models.cc @@ -17,6 +17,10 @@ #include "hg.h" #include "ff.h" +#define NORMAL_CP 1 +#define FAST_CP 2 +#define FAST_CP_2 3 + using namespace std; using namespace std::tr1; @@ -164,13 +168,15 @@ public: const SentenceMetadata& sm, const Hypergraph& i, int pop_limit, - Hypergraph* o) : + Hypergraph* o, + int s = NORMAL_CP ) : models(m), smeta(sm), in(i), out(*o), D(in.nodes_.size()), - pop_limit_(pop_limit) { + pop_limit_(pop_limit), + strategy_(s){ if (!SILENT) cerr << " Applying feature functions (cube pruning, pop_limit = " << pop_limit_ << ')' << endl; node_states_.reserve(kRESERVE_NUM_NODES); } @@ -184,9 +190,21 @@ public: if (num_nodes > 100) every = 10; assert(in.nodes_[pregoal].out_edges_.size() == 1); if (!SILENT) cerr << " "; + int has = 0; for (int i = 0; i < in.nodes_.size(); ++i) { - if (!SILENT && i % every == 0) cerr << '.'; - KBest(i, i == goal_id); + if (!SILENT) { + int needs = (50 * i / in.nodes_.size()); + while (has < needs) { cerr << '.'; ++has; } + } + if (strategy_==NORMAL_CP){ + KBest(i, i == goal_id); + } + if (strategy_==FAST_CP){ + KBestFast(i, i == goal_id); + } + if (strategy_==FAST_CP_2){ + KBestFast2(i, i == goal_id); + } } if (!SILENT) { cerr << endl; @@ -258,8 +276,7 @@ public: make_heap(cand.begin(), cand.end(), HeapCandCompare()); State2Node state2node; // "buf" in Figure 2 int pops = 0; - int pop_limit_eff=max(1,int(v.promise*pop_limit_)); - while(!cand.empty() && pops < pop_limit_eff) { + while(!cand.empty() && pops < pop_limit_) { pop_heap(cand.begin(), cand.end(), HeapCandCompare()); Candidate* item = cand.back(); cand.pop_back(); @@ -283,6 +300,114 @@ public: delete freelist[i]; } + void KBestFast(const int vert_index, const bool is_goal) { + // cerr << "KBest(" << vert_index << ")\n"; + CandidateList& D_v = D[vert_index]; + assert(D_v.empty()); + const Hypergraph::Node& v = in.nodes_[vert_index]; + // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; + const vector<int>& in_edges = v.in_edges_; + CandidateHeap cand; + CandidateList freelist; + cand.reserve(in_edges.size()); + //init with j<0,0> for all rules-edges that lead to node-(NT-span) + for (int i = 0; i < in_edges.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; + const JVector j(edge.tail_nodes_.size(), 0); + cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); + } + // cerr << " making heap of " << cand.size() << " candidates\n"; + make_heap(cand.begin(), cand.end(), HeapCandCompare()); + State2Node state2node; // "buf" in Figure 2 + int pops = 0; + while(!cand.empty() && pops < pop_limit_) { + pop_heap(cand.begin(), cand.end(), HeapCandCompare()); + Candidate* item = cand.back(); + cand.pop_back(); + // cerr << "POPPED: " << *item << endl; + + PushSuccFast(*item, is_goal, &cand); + IncorporateIntoPlusLMForest(item, &state2node, &freelist); + ++pops; + } + D_v.resize(state2node.size()); + int c = 0; + for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){ + D_v[c++] = i->second; + // cerr << "MERGED: " << *i->second << endl; + } + //cerr <<"Node id: "<< vert_index<< endl; + //#ifdef MEASURE_CA + // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl; + // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl; + //#endif + sort(D_v.begin(), D_v.end(), EstProbSorter()); + + // cerr << " expanded to " << D_v.size() << " nodes\n"; + + for (int i = 0; i < cand.size(); ++i) + delete cand[i]; + // freelist is necessary since even after an item merged, it still stays in + // the unique set so it can't be deleted til now + for (int i = 0; i < freelist.size(); ++i) + delete freelist[i]; + } + + void KBestFast2(const int vert_index, const bool is_goal) { + // cerr << "KBest(" << vert_index << ")\n"; + CandidateList& D_v = D[vert_index]; + assert(D_v.empty()); + const Hypergraph::Node& v = in.nodes_[vert_index]; + // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; + const vector<int>& in_edges = v.in_edges_; + CandidateHeap cand; + CandidateList freelist; + cand.reserve(in_edges.size()); + UniqueCandidateSet unique_accepted; + //init with j<0,0> for all rules-edges that lead to node-(NT-span) + for (int i = 0; i < in_edges.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; + const JVector j(edge.tail_nodes_.size(), 0); + cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); + } + // cerr << " making heap of " << cand.size() << " candidates\n"; + make_heap(cand.begin(), cand.end(), HeapCandCompare()); + State2Node state2node; // "buf" in Figure 2 + int pops = 0; + while(!cand.empty() && pops < pop_limit_) { + pop_heap(cand.begin(), cand.end(), HeapCandCompare()); + Candidate* item = cand.back(); + cand.pop_back(); + assert(unique_accepted.insert(item).second); // these should all be unique! + // cerr << "POPPED: " << *item << endl; + + PushSuccFast2(*item, is_goal, &cand, &unique_accepted); + IncorporateIntoPlusLMForest(item, &state2node, &freelist); + ++pops; + } + D_v.resize(state2node.size()); + int c = 0; + for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){ + D_v[c++] = i->second; + // cerr << "MERGED: " << *i->second << endl; + } + //cerr <<"Node id: "<< vert_index<< endl; + //#ifdef MEASURE_CA + // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl; + // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl; + //#endif + sort(D_v.begin(), D_v.end(), EstProbSorter()); + + // cerr << " expanded to " << D_v.size() << " nodes\n"; + + for (int i = 0; i < cand.size(); ++i) + delete cand[i]; + // freelist is necessary since even after an item merged, it still stays in + // the unique set so it can't be deleted til now + for (int i = 0; i < freelist.size(); ++i) + delete freelist[i]; + } + void PushSucc(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* cs) { CandidateHeap& cand = *pcand; for (int i = 0; i < item.j_.size(); ++i) { @@ -300,6 +425,54 @@ public: } } + //PushSucc following unique ancestor generation function + void PushSuccFast(const Candidate& item, const bool is_goal, CandidateHeap* pcand){ + CandidateHeap& cand = *pcand; + for (int i = 0; i < item.j_.size(); ++i) { + JVector j = item.j_; + ++j[i]; + if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { + Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); + cand.push_back(new_cand); + push_heap(cand.begin(), cand.end(), HeapCandCompare()); + } + if(item.j_[i]!=0){ + return; + } + } + } + + //PushSucc only if all ancest Cand are added + void PushSuccFast2(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* ps){ + CandidateHeap& cand = *pcand; + for (int i = 0; i < item.j_.size(); ++i) { + JVector j = item.j_; + ++j[i]; + if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { + Candidate query_unique(*item.in_edge_, j); + if (HasAllAncestors(&query_unique,ps)) { + Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); + cand.push_back(new_cand); + push_heap(cand.begin(), cand.end(), HeapCandCompare()); + } + } + } + } + + bool HasAllAncestors(const Candidate* item, UniqueCandidateSet* cs){ + for (int i = 0; i < item->j_.size(); ++i) { + JVector j = item->j_; + --j[i]; + if (j[i] >=0) { + Candidate query_unique(*item->in_edge_, j); + if (cs->count(&query_unique) == 0) { + return false; + } + } + } + return true; + } + const ModelSet& models; const SentenceMetadata& smeta; const Hypergraph& in; @@ -311,6 +484,7 @@ public: FFStates node_states_; // for each node in the out-HG what is // its q function value? const int pop_limit_; + const int strategy_; //switch Cube Pruning strategy: 1 normal, 2 fast (alg 2), 3 fast_2 (alg 3). (see: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010) }; struct NoPruningRescorer { @@ -412,15 +586,28 @@ void ApplyModelSet(const Hypergraph& in, if (models.stateless() || config.algorithm == IntersectionConfiguration::FULL) { NoPruningRescorer ma(models, smeta, in, out); // avoid overhead of best-first when no state ma.Apply(); - } else if (config.algorithm == IntersectionConfiguration::CUBE) { + } else if (config.algorithm == IntersectionConfiguration::CUBE + || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING + || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2) { int pl = config.pop_limit; const int max_pl_for_large=50; if (pl > max_pl_for_large && in.nodes_.size() > 80000) { pl = max_pl_for_large; cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n"; } - CubePruningRescorer ma(models, smeta, in, pl, out); - ma.Apply(); + if (config.algorithm == IntersectionConfiguration::CUBE) { + CubePruningRescorer ma(models, smeta, in, pl, out); + ma.Apply(); + } + else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING){ + CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP); + ma.Apply(); + } + else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2){ + CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP_2); + ma.Apply(); + } + } else { cerr << "Don't understand intersection algorithm " << config.algorithm << endl; exit(1); |