summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorandrea gesmundo <andrea.gesmundo@gmail.com>2011-07-08 13:56:42 +0200
committerandrea gesmundo <andrea.gesmundo@gmail.com>2011-07-08 13:56:42 +0200
commit85b7e8ae194d98ca85da1692e2679db8defff91b (patch)
tree670572379dd186f37b6810de4fe3f2116db329d3 /decoder
parent4b73c9b8d22cd490e4bf735641aab592b346e966 (diff)
add Fast Cube Pruning
Diffstat (limited to 'decoder')
-rw-r--r--decoder/apply_models.cc196
-rw-r--r--decoder/apply_models.h6
-rw-r--r--decoder/decoder.cc10
3 files changed, 204 insertions, 8 deletions
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc
index 9390c809..62eff262 100644
--- a/decoder/apply_models.cc
+++ b/decoder/apply_models.cc
@@ -17,6 +17,10 @@
#include "hg.h"
#include "ff.h"
+#define NORMAL_CP 1
+#define FAST_CP 2
+#define FAST_CP_2 3
+
using namespace std;
using namespace std::tr1;
@@ -164,13 +168,15 @@ public:
const SentenceMetadata& sm,
const Hypergraph& i,
int pop_limit,
- Hypergraph* o) :
+ Hypergraph* o,
+ int s = NORMAL_CP ) :
models(m),
smeta(sm),
in(i),
out(*o),
D(in.nodes_.size()),
- pop_limit_(pop_limit) {
+ pop_limit_(pop_limit),
+ strategy_(s){
if (!SILENT) cerr << " Applying feature functions (cube pruning, pop_limit = " << pop_limit_ << ')' << endl;
node_states_.reserve(kRESERVE_NUM_NODES);
}
@@ -186,7 +192,15 @@ public:
if (!SILENT) cerr << " ";
for (int i = 0; i < in.nodes_.size(); ++i) {
if (!SILENT && i % every == 0) cerr << '.';
- KBest(i, i == goal_id);
+ if (strategy_==NORMAL_CP){
+ KBest(i, i == goal_id);
+ }
+ if (strategy_==FAST_CP){
+ KBestFast(i, i == goal_id);
+ }
+ if (strategy_==FAST_CP_2){
+ KBestFast2(i, i == goal_id);
+ }
}
if (!SILENT) {
cerr << endl;
@@ -283,6 +297,114 @@ public:
delete freelist[i];
}
+ void KBestFast(const int vert_index, const bool is_goal) {
+ // cerr << "KBest(" << vert_index << ")\n";
+ CandidateList& D_v = D[vert_index];
+ assert(D_v.empty());
+ const Hypergraph::Node& v = in.nodes_[vert_index];
+ // cerr << " has " << v.in_edges_.size() << " in-coming edges\n";
+ const vector<int>& in_edges = v.in_edges_;
+ CandidateHeap cand;
+ CandidateList freelist;
+ cand.reserve(in_edges.size());
+ //init with j<0,0> for all rules-edges that lead to node-(NT-span)
+ for (int i = 0; i < in_edges.size(); ++i) {
+ const Hypergraph::Edge& edge = in.edges_[in_edges[i]];
+ const JVector j(edge.tail_nodes_.size(), 0);
+ cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal));
+ }
+ // cerr << " making heap of " << cand.size() << " candidates\n";
+ make_heap(cand.begin(), cand.end(), HeapCandCompare());
+ State2Node state2node; // "buf" in Figure 2
+ int pops = 0;
+ while(!cand.empty() && pops < pop_limit_) {
+ pop_heap(cand.begin(), cand.end(), HeapCandCompare());
+ Candidate* item = cand.back();
+ cand.pop_back();
+ // cerr << "POPPED: " << *item << endl;
+
+ PushSuccFast(*item, is_goal, &cand);
+ IncorporateIntoPlusLMForest(item, &state2node, &freelist);
+ ++pops;
+ }
+ D_v.resize(state2node.size());
+ int c = 0;
+ for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){
+ D_v[c++] = i->second;
+ // cerr << "MERGED: " << *i->second << endl;
+ }
+ //cerr <<"Node id: "<< vert_index<< endl;
+ //#ifdef MEASURE_CA
+ // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl;
+ // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl;
+ //#endif
+ sort(D_v.begin(), D_v.end(), EstProbSorter());
+
+ // cerr << " expanded to " << D_v.size() << " nodes\n";
+
+ for (int i = 0; i < cand.size(); ++i)
+ delete cand[i];
+ // freelist is necessary since even after an item merged, it still stays in
+ // the unique set so it can't be deleted til now
+ for (int i = 0; i < freelist.size(); ++i)
+ delete freelist[i];
+ }
+
+ void KBestFast2(const int vert_index, const bool is_goal) {
+ // cerr << "KBest(" << vert_index << ")\n";
+ CandidateList& D_v = D[vert_index];
+ assert(D_v.empty());
+ const Hypergraph::Node& v = in.nodes_[vert_index];
+ // cerr << " has " << v.in_edges_.size() << " in-coming edges\n";
+ const vector<int>& in_edges = v.in_edges_;
+ CandidateHeap cand;
+ CandidateList freelist;
+ cand.reserve(in_edges.size());
+ UniqueCandidateSet unique_accepted;
+ //init with j<0,0> for all rules-edges that lead to node-(NT-span)
+ for (int i = 0; i < in_edges.size(); ++i) {
+ const Hypergraph::Edge& edge = in.edges_[in_edges[i]];
+ const JVector j(edge.tail_nodes_.size(), 0);
+ cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal));
+ }
+ // cerr << " making heap of " << cand.size() << " candidates\n";
+ make_heap(cand.begin(), cand.end(), HeapCandCompare());
+ State2Node state2node; // "buf" in Figure 2
+ int pops = 0;
+ while(!cand.empty() && pops < pop_limit_) {
+ pop_heap(cand.begin(), cand.end(), HeapCandCompare());
+ Candidate* item = cand.back();
+ cand.pop_back();
+ assert(unique_accepted.insert(item).second); // these should all be unique!
+ // cerr << "POPPED: " << *item << endl;
+
+ PushSuccFast2(*item, is_goal, &cand, &unique_accepted);
+ IncorporateIntoPlusLMForest(item, &state2node, &freelist);
+ ++pops;
+ }
+ D_v.resize(state2node.size());
+ int c = 0;
+ for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){
+ D_v[c++] = i->second;
+ // cerr << "MERGED: " << *i->second << endl;
+ }
+ //cerr <<"Node id: "<< vert_index<< endl;
+ //#ifdef MEASURE_CA
+ // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl;
+ // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl;
+ //#endif
+ sort(D_v.begin(), D_v.end(), EstProbSorter());
+
+ // cerr << " expanded to " << D_v.size() << " nodes\n";
+
+ for (int i = 0; i < cand.size(); ++i)
+ delete cand[i];
+ // freelist is necessary since even after an item merged, it still stays in
+ // the unique set so it can't be deleted til now
+ for (int i = 0; i < freelist.size(); ++i)
+ delete freelist[i];
+ }
+
void PushSucc(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* cs) {
CandidateHeap& cand = *pcand;
for (int i = 0; i < item.j_.size(); ++i) {
@@ -300,6 +422,54 @@ public:
}
}
+ //PushSucc following unique ancestor generation function
+ void PushSuccFast(const Candidate& item, const bool is_goal, CandidateHeap* pcand){
+ CandidateHeap& cand = *pcand;
+ for (int i = 0; i < item.j_.size(); ++i) {
+ JVector j = item.j_;
+ ++j[i];
+ if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) {
+ Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal);
+ cand.push_back(new_cand);
+ push_heap(cand.begin(), cand.end(), HeapCandCompare());
+ }
+ if(item.j_[i]!=0){
+ return;
+ }
+ }
+ }
+
+ //PushSucc only if all ancest Cand are added
+ void PushSuccFast2(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* ps){
+ CandidateHeap& cand = *pcand;
+ for (int i = 0; i < item.j_.size(); ++i) {
+ JVector j = item.j_;
+ ++j[i];
+ if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) {
+ Candidate query_unique(*item.in_edge_, j);
+ if (HasAllAncestors(&query_unique,ps)) {
+ Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal);
+ cand.push_back(new_cand);
+ push_heap(cand.begin(), cand.end(), HeapCandCompare());
+ }
+ }
+ }
+ }
+
+ bool HasAllAncestors(const Candidate* item, UniqueCandidateSet* cs){
+ for (int i = 0; i < item->j_.size(); ++i) {
+ JVector j = item->j_;
+ --j[i];
+ if (j[i] >=0) {
+ Candidate query_unique(*item->in_edge_, j);
+ if (cs->count(&query_unique) == 0) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+
const ModelSet& models;
const SentenceMetadata& smeta;
const Hypergraph& in;
@@ -311,6 +481,7 @@ public:
FFStates node_states_; // for each node in the out-HG what is
// its q function value?
const int pop_limit_;
+ const int strategy_; //switch Cube Pruning strategy: 1 normal, 2 fast (alg 2), 3 fast_2 (alg 3). (see: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010)
};
struct NoPruningRescorer {
@@ -412,15 +583,28 @@ void ApplyModelSet(const Hypergraph& in,
if (models.stateless() || config.algorithm == IntersectionConfiguration::FULL) {
NoPruningRescorer ma(models, smeta, in, out); // avoid overhead of best-first when no state
ma.Apply();
- } else if (config.algorithm == IntersectionConfiguration::CUBE) {
+ } else if (config.algorithm == IntersectionConfiguration::CUBE
+ || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING
+ || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2) {
int pl = config.pop_limit;
const int max_pl_for_large=50;
if (pl > max_pl_for_large && in.nodes_.size() > 80000) {
pl = max_pl_for_large;
cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n";
}
- CubePruningRescorer ma(models, smeta, in, pl, out);
- ma.Apply();
+ if (config.algorithm == IntersectionConfiguration::CUBE) {
+ CubePruningRescorer ma(models, smeta, in, pl, out);
+ ma.Apply();
+ }
+ else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING){
+ CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP);
+ ma.Apply();
+ }
+ else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2){
+ CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP_2);
+ ma.Apply();
+ }
+
} else {
cerr << "Don't understand intersection algorithm " << config.algorithm << endl;
exit(1);
diff --git a/decoder/apply_models.h b/decoder/apply_models.h
index a85694aa..19a4c7be 100644
--- a/decoder/apply_models.h
+++ b/decoder/apply_models.h
@@ -13,6 +13,8 @@ struct IntersectionConfiguration {
enum {
FULL,
CUBE,
+ FAST_CUBE_PRUNING,
+ FAST_CUBE_PRUNING_2,
N_ALGORITHMS
};
@@ -25,7 +27,9 @@ enum {
inline std::ostream& operator<<(std::ostream& os, const IntersectionConfiguration& c) {
if (c.algorithm == 0) { os << "FULL"; }
else if (c.algorithm == 1) { os << "CUBE:k=" << c.pop_limit; }
- else if (c.algorithm == 2) { os << "N_ALGORITHMS"; }
+ else if (c.algorithm == 2) { os << "FAST_CUBE_PRUNING"; }
+ else if (c.algorithm == 3) { os << "FAST_CUBE_PRUNING_2"; }
+ else if (c.algorithm == 4) { os << "N_ALGORITHMS"; }
else os << "OTHER";
return os;
}
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 2c3a06de..8a4a1485 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -357,7 +357,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("weights,w",po::value<string>(),"Feature weights file (initial forest / pass 1)")
("feature_function,F",po::value<vector<string> >()->composing(), "Pass 1 additional feature function(s) (-L for list)")
- ("intersection_strategy,I",po::value<string>()->default_value("cube_pruning"), "Pass 1 intersection strategy for incorporating finite-state features; values include Cube_pruning, Full")
+ ("intersection_strategy,I",po::value<string>()->default_value("cube_pruning"), "Pass 1 intersection strategy for incorporating finite-state features; values include Cube_pruning, Full, Fast_cube_pruning, Fast_cube_pruning_2")
("summary_feature", po::value<string>(), "Compute a 'summary feature' at the end of the pass (before any pruning) with name=arg and value=inside-outside/Z")
("summary_feature_type", po::value<string>()->default_value("node_risk"), "Summary feature types: node_risk, edge_risk, edge_prob")
("density_prune", po::value<double>(), "Pass 1 pruning: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")
@@ -597,6 +597,14 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
if (LowercaseString(str(isn.c_str(),conf)) == "full") {
palg = 0;
}
+ if (LowercaseString(conf["intersection_strategy"].as<string>()) == "fast_cube_pruning") {
+ palg = 2;
+ cerr << "Using Fast Cube Pruning intersection (see Algorithm 2 described in: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010).\n";
+ }
+ if (LowercaseString(conf["intersection_strategy"].as<string>()) == "fast_cube_pruning_2") {
+ palg = 3;
+ cerr << "Using Fast Cube Pruning 2 intersection (see Algorithm 3 described in: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010).\n";
+ }
rp.inter_conf.reset(new IntersectionConfiguration(palg, pop_limit));
} else {
break; // TODO alert user if there are any future configurations