summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-07-26 15:39:26 +0100
committerChris Dyer <cdyer@cs.cmu.edu>2011-07-26 15:39:26 +0100
commit02c47c31050505eacc2aa9a17e77214a7b456171 (patch)
treead39a2b2c328c6e4a5875e82f97d7de0a4cfa8b4 /decoder
parentb89c1f03c89c6c30b88099e4f3e0c1753d338ea7 (diff)
parentd2ca798de168f2009d3190d31e4b3f4b92985962 (diff)
Merge remote branch 'agesmundo/master'
Diffstat (limited to 'decoder')
-rw-r--r--decoder/apply_models.cc196
-rw-r--r--decoder/apply_models.h6
-rw-r--r--decoder/cdec.cc8
-rw-r--r--decoder/decoder.cc23
-rw-r--r--decoder/decoder.h14
5 files changed, 239 insertions, 8 deletions
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc
index 9390c809..62eff262 100644
--- a/decoder/apply_models.cc
+++ b/decoder/apply_models.cc
@@ -17,6 +17,10 @@
#include "hg.h"
#include "ff.h"
+#define NORMAL_CP 1
+#define FAST_CP 2
+#define FAST_CP_2 3
+
using namespace std;
using namespace std::tr1;
@@ -164,13 +168,15 @@ public:
const SentenceMetadata& sm,
const Hypergraph& i,
int pop_limit,
- Hypergraph* o) :
+ Hypergraph* o,
+ int s = NORMAL_CP ) :
models(m),
smeta(sm),
in(i),
out(*o),
D(in.nodes_.size()),
- pop_limit_(pop_limit) {
+ pop_limit_(pop_limit),
+ strategy_(s){
if (!SILENT) cerr << " Applying feature functions (cube pruning, pop_limit = " << pop_limit_ << ')' << endl;
node_states_.reserve(kRESERVE_NUM_NODES);
}
@@ -186,7 +192,15 @@ public:
if (!SILENT) cerr << " ";
for (int i = 0; i < in.nodes_.size(); ++i) {
if (!SILENT && i % every == 0) cerr << '.';
- KBest(i, i == goal_id);
+ if (strategy_==NORMAL_CP){
+ KBest(i, i == goal_id);
+ }
+ if (strategy_==FAST_CP){
+ KBestFast(i, i == goal_id);
+ }
+ if (strategy_==FAST_CP_2){
+ KBestFast2(i, i == goal_id);
+ }
}
if (!SILENT) {
cerr << endl;
@@ -283,6 +297,114 @@ public:
delete freelist[i];
}
+ void KBestFast(const int vert_index, const bool is_goal) {
+ // cerr << "KBest(" << vert_index << ")\n";
+ CandidateList& D_v = D[vert_index];
+ assert(D_v.empty());
+ const Hypergraph::Node& v = in.nodes_[vert_index];
+ // cerr << " has " << v.in_edges_.size() << " in-coming edges\n";
+ const vector<int>& in_edges = v.in_edges_;
+ CandidateHeap cand;
+ CandidateList freelist;
+ cand.reserve(in_edges.size());
+ //init with j<0,0> for all rules-edges that lead to node-(NT-span)
+ for (int i = 0; i < in_edges.size(); ++i) {
+ const Hypergraph::Edge& edge = in.edges_[in_edges[i]];
+ const JVector j(edge.tail_nodes_.size(), 0);
+ cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal));
+ }
+ // cerr << " making heap of " << cand.size() << " candidates\n";
+ make_heap(cand.begin(), cand.end(), HeapCandCompare());
+ State2Node state2node; // "buf" in Figure 2
+ int pops = 0;
+ while(!cand.empty() && pops < pop_limit_) {
+ pop_heap(cand.begin(), cand.end(), HeapCandCompare());
+ Candidate* item = cand.back();
+ cand.pop_back();
+ // cerr << "POPPED: " << *item << endl;
+
+ PushSuccFast(*item, is_goal, &cand);
+ IncorporateIntoPlusLMForest(item, &state2node, &freelist);
+ ++pops;
+ }
+ D_v.resize(state2node.size());
+ int c = 0;
+ for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){
+ D_v[c++] = i->second;
+ // cerr << "MERGED: " << *i->second << endl;
+ }
+ //cerr <<"Node id: "<< vert_index<< endl;
+ //#ifdef MEASURE_CA
+ // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl;
+ // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl;
+ //#endif
+ sort(D_v.begin(), D_v.end(), EstProbSorter());
+
+ // cerr << " expanded to " << D_v.size() << " nodes\n";
+
+ for (int i = 0; i < cand.size(); ++i)
+ delete cand[i];
+ // freelist is necessary since even after an item merged, it still stays in
+ // the unique set so it can't be deleted til now
+ for (int i = 0; i < freelist.size(); ++i)
+ delete freelist[i];
+ }
+
+ void KBestFast2(const int vert_index, const bool is_goal) {
+ // cerr << "KBest(" << vert_index << ")\n";
+ CandidateList& D_v = D[vert_index];
+ assert(D_v.empty());
+ const Hypergraph::Node& v = in.nodes_[vert_index];
+ // cerr << " has " << v.in_edges_.size() << " in-coming edges\n";
+ const vector<int>& in_edges = v.in_edges_;
+ CandidateHeap cand;
+ CandidateList freelist;
+ cand.reserve(in_edges.size());
+ UniqueCandidateSet unique_accepted;
+ //init with j<0,0> for all rules-edges that lead to node-(NT-span)
+ for (int i = 0; i < in_edges.size(); ++i) {
+ const Hypergraph::Edge& edge = in.edges_[in_edges[i]];
+ const JVector j(edge.tail_nodes_.size(), 0);
+ cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal));
+ }
+ // cerr << " making heap of " << cand.size() << " candidates\n";
+ make_heap(cand.begin(), cand.end(), HeapCandCompare());
+ State2Node state2node; // "buf" in Figure 2
+ int pops = 0;
+ while(!cand.empty() && pops < pop_limit_) {
+ pop_heap(cand.begin(), cand.end(), HeapCandCompare());
+ Candidate* item = cand.back();
+ cand.pop_back();
+ assert(unique_accepted.insert(item).second); // these should all be unique!
+ // cerr << "POPPED: " << *item << endl;
+
+ PushSuccFast2(*item, is_goal, &cand, &unique_accepted);
+ IncorporateIntoPlusLMForest(item, &state2node, &freelist);
+ ++pops;
+ }
+ D_v.resize(state2node.size());
+ int c = 0;
+ for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){
+ D_v[c++] = i->second;
+ // cerr << "MERGED: " << *i->second << endl;
+ }
+ //cerr <<"Node id: "<< vert_index<< endl;
+ //#ifdef MEASURE_CA
+ // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<<endl;
+ // cerr << "countAtEnd (pop/tot): node id: " << vert_index << " (" << count_at_end_pop << "/" << count_at_end_tot << ")"<<endl;
+ //#endif
+ sort(D_v.begin(), D_v.end(), EstProbSorter());
+
+ // cerr << " expanded to " << D_v.size() << " nodes\n";
+
+ for (int i = 0; i < cand.size(); ++i)
+ delete cand[i];
+ // freelist is necessary since even after an item merged, it still stays in
+ // the unique set so it can't be deleted til now
+ for (int i = 0; i < freelist.size(); ++i)
+ delete freelist[i];
+ }
+
void PushSucc(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* cs) {
CandidateHeap& cand = *pcand;
for (int i = 0; i < item.j_.size(); ++i) {
@@ -300,6 +422,54 @@ public:
}
}
+ //PushSucc following unique ancestor generation function
+ void PushSuccFast(const Candidate& item, const bool is_goal, CandidateHeap* pcand){
+ CandidateHeap& cand = *pcand;
+ for (int i = 0; i < item.j_.size(); ++i) {
+ JVector j = item.j_;
+ ++j[i];
+ if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) {
+ Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal);
+ cand.push_back(new_cand);
+ push_heap(cand.begin(), cand.end(), HeapCandCompare());
+ }
+ if(item.j_[i]!=0){
+ return;
+ }
+ }
+ }
+
+ //PushSucc only if all ancest Cand are added
+ void PushSuccFast2(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* ps){
+ CandidateHeap& cand = *pcand;
+ for (int i = 0; i < item.j_.size(); ++i) {
+ JVector j = item.j_;
+ ++j[i];
+ if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) {
+ Candidate query_unique(*item.in_edge_, j);
+ if (HasAllAncestors(&query_unique,ps)) {
+ Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal);
+ cand.push_back(new_cand);
+ push_heap(cand.begin(), cand.end(), HeapCandCompare());
+ }
+ }
+ }
+ }
+
+ bool HasAllAncestors(const Candidate* item, UniqueCandidateSet* cs){
+ for (int i = 0; i < item->j_.size(); ++i) {
+ JVector j = item->j_;
+ --j[i];
+ if (j[i] >=0) {
+ Candidate query_unique(*item->in_edge_, j);
+ if (cs->count(&query_unique) == 0) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+
const ModelSet& models;
const SentenceMetadata& smeta;
const Hypergraph& in;
@@ -311,6 +481,7 @@ public:
FFStates node_states_; // for each node in the out-HG what is
// its q function value?
const int pop_limit_;
+ const int strategy_; //switch Cube Pruning strategy: 1 normal, 2 fast (alg 2), 3 fast_2 (alg 3). (see: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010)
};
struct NoPruningRescorer {
@@ -412,15 +583,28 @@ void ApplyModelSet(const Hypergraph& in,
if (models.stateless() || config.algorithm == IntersectionConfiguration::FULL) {
NoPruningRescorer ma(models, smeta, in, out); // avoid overhead of best-first when no state
ma.Apply();
- } else if (config.algorithm == IntersectionConfiguration::CUBE) {
+ } else if (config.algorithm == IntersectionConfiguration::CUBE
+ || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING
+ || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2) {
int pl = config.pop_limit;
const int max_pl_for_large=50;
if (pl > max_pl_for_large && in.nodes_.size() > 80000) {
pl = max_pl_for_large;
cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n";
}
- CubePruningRescorer ma(models, smeta, in, pl, out);
- ma.Apply();
+ if (config.algorithm == IntersectionConfiguration::CUBE) {
+ CubePruningRescorer ma(models, smeta, in, pl, out);
+ ma.Apply();
+ }
+ else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING){
+ CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP);
+ ma.Apply();
+ }
+ else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2){
+ CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP_2);
+ ma.Apply();
+ }
+
} else {
cerr << "Don't understand intersection algorithm " << config.algorithm << endl;
exit(1);
diff --git a/decoder/apply_models.h b/decoder/apply_models.h
index a85694aa..19a4c7be 100644
--- a/decoder/apply_models.h
+++ b/decoder/apply_models.h
@@ -13,6 +13,8 @@ struct IntersectionConfiguration {
enum {
FULL,
CUBE,
+ FAST_CUBE_PRUNING,
+ FAST_CUBE_PRUNING_2,
N_ALGORITHMS
};
@@ -25,7 +27,9 @@ enum {
inline std::ostream& operator<<(std::ostream& os, const IntersectionConfiguration& c) {
if (c.algorithm == 0) { os << "FULL"; }
else if (c.algorithm == 1) { os << "CUBE:k=" << c.pop_limit; }
- else if (c.algorithm == 2) { os << "N_ALGORITHMS"; }
+ else if (c.algorithm == 2) { os << "FAST_CUBE_PRUNING"; }
+ else if (c.algorithm == 3) { os << "FAST_CUBE_PRUNING_2"; }
+ else if (c.algorithm == 4) { os << "N_ALGORITHMS"; }
else os << "OTHER";
return os;
}
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index 5c40f56e..c671af57 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -19,11 +19,19 @@ int main(int argc, char** argv) {
assert(*in);
string buf;
+#ifdef CP_TIME
+ clock_t time_cp(0);//, end_cp;
+#endif
while(*in) {
getline(*in, buf);
if (buf.empty()) continue;
decoder.Decode(buf);
}
+#ifdef CP_TIME
+ cerr << "Time required for Cube Pruning execution: "
+ << CpTime::Get()
+ << " seconds." << "\n\n";
+#endif
if (show_feature_dictionary) {
int num = FD::NumFeats();
for (int i = 1; i < num; ++i) {
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 2c3a06de..76f31352 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -46,6 +46,13 @@
#include "cfg_options.h"
#endif
+#ifdef CP_TIME
+ clock_t CpTime::time_;
+ void CpTime::Add(clock_t x){time_+=x;}
+ void CpTime::Sub(clock_t x){time_-=x;}
+ double CpTime::Get(){return (double)(time_)/CLOCKS_PER_SEC;}
+#endif
+
static const double kMINUS_EPSILON = -1e-6; // don't be too strict
using namespace std;
@@ -357,7 +364,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("weights,w",po::value<string>(),"Feature weights file (initial forest / pass 1)")
("feature_function,F",po::value<vector<string> >()->composing(), "Pass 1 additional feature function(s) (-L for list)")
- ("intersection_strategy,I",po::value<string>()->default_value("cube_pruning"), "Pass 1 intersection strategy for incorporating finite-state features; values include Cube_pruning, Full")
+ ("intersection_strategy,I",po::value<string>()->default_value("cube_pruning"), "Pass 1 intersection strategy for incorporating finite-state features; values include Cube_pruning, Full, Fast_cube_pruning, Fast_cube_pruning_2")
("summary_feature", po::value<string>(), "Compute a 'summary feature' at the end of the pass (before any pruning) with name=arg and value=inside-outside/Z")
("summary_feature_type", po::value<string>()->default_value("node_risk"), "Summary feature types: node_risk, edge_risk, edge_prob")
("density_prune", po::value<double>(), "Pass 1 pruning: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")
@@ -597,6 +604,14 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
if (LowercaseString(str(isn.c_str(),conf)) == "full") {
palg = 0;
}
+ if (LowercaseString(conf["intersection_strategy"].as<string>()) == "fast_cube_pruning") {
+ palg = 2;
+ cerr << "Using Fast Cube Pruning intersection (see Algorithm 2 described in: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010).\n";
+ }
+ if (LowercaseString(conf["intersection_strategy"].as<string>()) == "fast_cube_pruning_2") {
+ palg = 3;
+ cerr << "Using Fast Cube Pruning 2 intersection (see Algorithm 3 described in: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010).\n";
+ }
rp.inter_conf.reset(new IntersectionConfiguration(palg, pop_limit));
} else {
break; // TODO alert user if there are any future configurations
@@ -798,11 +813,17 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
Timer t("Forest rescoring:");
rp.models->PrepareForInput(smeta);
Hypergraph rescored_forest;
+#ifdef CP_TIME
+ CpTime::Sub(clock());
+#endif
ApplyModelSet(forest,
smeta,
*rp.models,
*rp.inter_conf,
&rescored_forest);
+#ifdef CP_TIME
+ CpTime::Add(clock());
+#endif
forest.swap(rescored_forest);
forest.Reweight(cur_weights);
if (!SILENT) forest_stats(forest," " + passtr +" forest",show_tree_structure,oracle.show_derivation);
diff --git a/decoder/decoder.h b/decoder/decoder.h
index 813400e3..5491369f 100644
--- a/decoder/decoder.h
+++ b/decoder/decoder.h
@@ -7,6 +7,20 @@
#include <boost/shared_ptr.hpp>
#include <boost/program_options/variables_map.hpp>
+#undef CP_TIME
+//#define CP_TIME
+#ifdef CP_TIME
+#include <time.h>
+struct CpTime{
+public:
+ static void Add(clock_t x);
+ static void Sub(clock_t x);
+ static double Get();
+private:
+ static clock_t time_;
+};
+#endif
+
class SentenceMetadata;
struct Hypergraph;
struct DecoderImpl;