summaryrefslogtreecommitdiff
path: root/decoder/apply_models.cc
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/apply_models.cc')
-rw-r--r--decoder/apply_models.cc205
1 files changed, 196 insertions, 9 deletions
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc
index 9390c809..40fd27e4 100644
--- a/decoder/apply_models.cc
+++ b/decoder/apply_models.cc
@@ -17,6 +17,10 @@
#include "hg.h"
#include "ff.h"
+#define NORMAL_CP 1
+#define FAST_CP 2
+#define FAST_CP_2 3
+
using namespace std;
using namespace std::tr1;
@@ -164,13 +168,15 @@ public:
const SentenceMetadata& sm,
const Hypergraph& i,
int pop_limit,
- Hypergraph* o) :
+ Hypergraph* o,
+ int s = NORMAL_CP ) :
models(m),
smeta(sm),
in(i),
out(*o),
D(in.nodes_.size()),
- pop_limit_(pop_limit) {
+ pop_limit_(pop_limit),
+ strategy_(s){
if (!SILENT) cerr << " Applying feature functions (cube pruning, pop_limit = " << pop_limit_ << ')' << endl;
node_states_.reserve(kRESERVE_NUM_NODES);
}
@@ -184,9 +190,21 @@ public:
if (num_nodes > 100) every = 10;
assert(in.nodes_[pregoal].out_edges_.size() == 1);
if (!SILENT) cerr << " ";
+ int has = 0;
for (int i = 0; i < in.nodes_.size(); ++i) {
- if (!SILENT && i % every == 0) cerr << '.';
- KBest(i, i == goal_id);
+ if (!SILENT) {
+ int needs = (50 * i / in.nodes_.size());
+ while (has < needs) { cerr << '.'; ++has; }
+ }
+ if (strategy_==NORMAL_CP){
+ KBest(i, i == goal_id);
+ }
+ if (strategy_==FAST_CP){
+ KBestFast(i, i == goal_id);
+ }
+ if (strategy_==FAST_CP_2){
+ KBestFast2(i, i == goal_id);
+ }
}
if (!SILENT) {
cerr << endl;
@@ -258,8 +276,7 @@ public:
make_heap(cand.begin(), cand.end(), HeapCandCompare());
State2Node state2node; // "buf" in Figure 2
int pops = 0;
- int pop_limit_eff=max(1,int(v.promise*pop_limit_));
- while(!cand.empty() && pops < pop_limit_eff) {
+ while(!cand.empty() && pops < pop_limit_) {
pop_heap(cand.begin(), cand.end(), HeapCandCompare());
Candidate* item = cand.back();
cand.pop_back();
@@ -283,6 +300,114 @@ public:
delete freelist[i];
}
+ void KBestFast(const int vert_index, const bool is_goal) {
+ // cerr << "KBest(" << vert_index << ")\n";
+ CandidateList& D_v = D[vert_index];
+ assert(D_v.empty());
+ const Hypergraph::Node& v = in.nodes_[vert_index];
+ // cerr << " has " << v.in_edges_.size() << " in-coming edges\n";
+ const vector<int>& in_edges = v.in_edges_;
+ CandidateHeap cand;
+ CandidateList freelist;
+ cand.reserve(in_edges.size());
+ //init with j<0,0> for all rules-edges that lead to node-(NT-span)
+ for (int i = 0; i < in_edges.size(); ++i) {
+ const Hypergraph::Edge& edge = in.edges_[in_edges[i]];
+ const JVector j(edge.tail_nodes_.size(), 0);
+ cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal));
+ }
+ // cerr << " making heap of " << cand.size() << " candidates\n";
+ make_heap(cand.begin(), cand.end(), HeapCandCompare());
+ State2Node state2node; // "buf" in Figure 2
+ int pops = 0;
+ while(!cand.empty() && pops < pop_limit_) {
+ pop_heap(cand.begin(), cand.end(), HeapCandCompare());
+ Candidate* item = cand.back();
+ cand.pop_back();
+ // cerr << "POPPED: " << *item << endl;
+
+ PushSuccFast(*item, is_goal, &cand);
+ IncorporateIntoPlusLMForest(item, &state2node, &freelist);
+ ++pops;
+ }
+ D_v.resize(state2node.size());
+ int c = 0;
+ for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){
+ D_v[c++] = i->second;
+ // cerr << "MERGED: " << *i->second << endl;
+ }
+ //cerr <<"Node id: "<< vert_index<< endl;
+ //#ifdef MEASURE_CA
+ // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl;
+ // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl;
+ //#endif
+ sort(D_v.begin(), D_v.end(), EstProbSorter());
+
+ // cerr << " expanded to " << D_v.size() << " nodes\n";
+
+ for (int i = 0; i < cand.size(); ++i)
+ delete cand[i];
+ // freelist is necessary since even after an item merged, it still stays in
+ // the unique set so it can't be deleted til now
+ for (int i = 0; i < freelist.size(); ++i)
+ delete freelist[i];
+ }
+
+ void KBestFast2(const int vert_index, const bool is_goal) {
+ // cerr << "KBest(" << vert_index << ")\n";
+ CandidateList& D_v = D[vert_index];
+ assert(D_v.empty());
+ const Hypergraph::Node& v = in.nodes_[vert_index];
+ // cerr << " has " << v.in_edges_.size() << " in-coming edges\n";
+ const vector<int>& in_edges = v.in_edges_;
+ CandidateHeap cand;
+ CandidateList freelist;
+ cand.reserve(in_edges.size());
+ UniqueCandidateSet unique_accepted;
+ //init with j<0,0> for all rules-edges that lead to node-(NT-span)
+ for (int i = 0; i < in_edges.size(); ++i) {
+ const Hypergraph::Edge& edge = in.edges_[in_edges[i]];
+ const JVector j(edge.tail_nodes_.size(), 0);
+ cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal));
+ }
+ // cerr << " making heap of " << cand.size() << " candidates\n";
+ make_heap(cand.begin(), cand.end(), HeapCandCompare());
+ State2Node state2node; // "buf" in Figure 2
+ int pops = 0;
+ while(!cand.empty() && pops < pop_limit_) {
+ pop_heap(cand.begin(), cand.end(), HeapCandCompare());
+ Candidate* item = cand.back();
+ cand.pop_back();
+ assert(unique_accepted.insert(item).second); // these should all be unique!
+ // cerr << "POPPED: " << *item << endl;
+
+ PushSuccFast2(*item, is_goal, &cand, &unique_accepted);
+ IncorporateIntoPlusLMForest(item, &state2node, &freelist);
+ ++pops;
+ }
+ D_v.resize(state2node.size());
+ int c = 0;
+ for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){
+ D_v[c++] = i->second;
+ // cerr << "MERGED: " << *i->second << endl;
+ }
+ //cerr <<"Node id: "<< vert_index<< endl;
+ //#ifdef MEASURE_CA
+ // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl;
+ // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl;
+ //#endif
+ sort(D_v.begin(), D_v.end(), EstProbSorter());
+
+ // cerr << " expanded to " << D_v.size() << " nodes\n";
+
+ for (int i = 0; i < cand.size(); ++i)
+ delete cand[i];
+ // freelist is necessary since even after an item merged, it still stays in
+ // the unique set so it can't be deleted til now
+ for (int i = 0; i < freelist.size(); ++i)
+ delete freelist[i];
+ }
+
void PushSucc(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* cs) {
CandidateHeap& cand = *pcand;
for (int i = 0; i < item.j_.size(); ++i) {
@@ -300,6 +425,54 @@ public:
}
}
+ //PushSucc following unique ancestor generation function
+ void PushSuccFast(const Candidate& item, const bool is_goal, CandidateHeap* pcand){
+ CandidateHeap& cand = *pcand;
+ for (int i = 0; i < item.j_.size(); ++i) {
+ JVector j = item.j_;
+ ++j[i];
+ if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) {
+ Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal);
+ cand.push_back(new_cand);
+ push_heap(cand.begin(), cand.end(), HeapCandCompare());
+ }
+ if(item.j_[i]!=0){
+ return;
+ }
+ }
+ }
+
+ //PushSucc only if all ancest Cand are added
+ void PushSuccFast2(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* ps){
+ CandidateHeap& cand = *pcand;
+ for (int i = 0; i < item.j_.size(); ++i) {
+ JVector j = item.j_;
+ ++j[i];
+ if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) {
+ Candidate query_unique(*item.in_edge_, j);
+ if (HasAllAncestors(&query_unique,ps)) {
+ Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal);
+ cand.push_back(new_cand);
+ push_heap(cand.begin(), cand.end(), HeapCandCompare());
+ }
+ }
+ }
+ }
+
+ bool HasAllAncestors(const Candidate* item, UniqueCandidateSet* cs){
+ for (int i = 0; i < item->j_.size(); ++i) {
+ JVector j = item->j_;
+ --j[i];
+ if (j[i] >=0) {
+ Candidate query_unique(*item->in_edge_, j);
+ if (cs->count(&query_unique) == 0) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+
const ModelSet& models;
const SentenceMetadata& smeta;
const Hypergraph& in;
@@ -311,6 +484,7 @@ public:
FFStates node_states_; // for each node in the out-HG what is
// its q function value?
const int pop_limit_;
+ const int strategy_; //switch Cube Pruning strategy: 1 normal, 2 fast (alg 2), 3 fast_2 (alg 3). (see: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010)
};
struct NoPruningRescorer {
@@ -412,15 +586,28 @@ void ApplyModelSet(const Hypergraph& in,
if (models.stateless() || config.algorithm == IntersectionConfiguration::FULL) {
NoPruningRescorer ma(models, smeta, in, out); // avoid overhead of best-first when no state
ma.Apply();
- } else if (config.algorithm == IntersectionConfiguration::CUBE) {
+ } else if (config.algorithm == IntersectionConfiguration::CUBE
+ || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING
+ || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2) {
int pl = config.pop_limit;
const int max_pl_for_large=50;
if (pl > max_pl_for_large && in.nodes_.size() > 80000) {
pl = max_pl_for_large;
cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n";
}
- CubePruningRescorer ma(models, smeta, in, pl, out);
- ma.Apply();
+ if (config.algorithm == IntersectionConfiguration::CUBE) {
+ CubePruningRescorer ma(models, smeta, in, pl, out);
+ ma.Apply();
+ }
+ else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING){
+ CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP);
+ ma.Apply();
+ }
+ else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2){
+ CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP_2);
+ ma.Apply();
+ }
+
} else {
cerr << "Don't understand intersection algorithm " << config.algorithm << endl;
exit(1);