summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2010-12-10 13:00:04 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2010-12-10 13:00:04 -0500
commitb8f314dddda3d440164e4772830e3c951ba06ee4 (patch)
tree794e70c9b2492181618c2226a199bd6f0b447feb
parentd53b676548531af8f20a000b10d25f4fcaa811b3 (diff)
extract kbest alignments
-rw-r--r--decoder/aligner.cc168
-rw-r--r--decoder/aligner.h1
-rw-r--r--decoder/decoder.cc2
-rw-r--r--decoder/kbest.h6
-rwxr-xr-xdecoder/oracle_bleu.h2
-rw-r--r--tests/system_tests/hmm/cdec.ini2
-rw-r--r--vest/ces.cc2
7 files changed, 109 insertions, 74 deletions
diff --git a/decoder/aligner.cc b/decoder/aligner.cc
index 43d4e0ce..7830b955 100644
--- a/decoder/aligner.cc
+++ b/decoder/aligner.cc
@@ -3,8 +3,11 @@
#include <cstdio>
#include <set>
+#include <boost/scoped_ptr.hpp>
+
#include "array2d.h"
#include "hg.h"
+#include "kbest.h"
#include "sentence_metadata.h"
#include "inside_outside.h"
#include "viterbi.h"
@@ -178,8 +181,21 @@ void AlignerTools::WriteAlignment(const Lattice& src_lattice,
const Hypergraph& in_g,
ostream* out,
bool map_instead_of_viterbi,
+ int k_best, // must = 0 if MAP
const vector<bool>* edges) {
bool fix_up_src_spans = false;
+ if (k_best > 1 && edges) {
+ cerr << "ERROR: cannot request multiple best alignments and provide an edge set!\n";
+ abort();
+ }
+ if (map_instead_of_viterbi) {
+ if (k_best != 0) {
+ cerr << "WARNING: K-best alignment extraction not available for MAP, use --aligner_use_viterbi\n";
+ }
+ k_best = 1;
+ } else {
+ if (k_best == 0) k_best = 1;
+ }
const Hypergraph* g = &in_g;
HypergraphP new_hg;
if (!src_lattice.IsSentence() ||
@@ -190,83 +206,101 @@ void AlignerTools::WriteAlignment(const Lattice& src_lattice,
map_instead_of_viterbi = false;
fix_up_src_spans = !src_lattice.IsSentence();
}
- if (!map_instead_of_viterbi || edges) {
- new_hg = in_g.CreateViterbiHypergraph(edges);
- for (int i = 0; i < new_hg->edges_.size(); ++i)
- new_hg->edges_[i].edge_prob_ = prob_t::One();
- g = new_hg.get();
- }
- vector<prob_t> edge_posteriors(g->edges_.size(), prob_t::Zero());
- vector<WordID> trg_sent;
- vector<WordID> src_sent;
- if (fix_up_src_spans) {
- ViterbiESentence(*g, &src_sent);
- } else {
- src_sent.resize(src_lattice.size());
- for (int i = 0; i < src_sent.size(); ++i)
- src_sent[i] = src_lattice[i][0].label;
- }
+ KBest::KBestDerivations<vector<Hypergraph::Edge const*>, ViterbiPathTraversal> kbest(in_g, k_best);
+ boost::scoped_ptr<vector<bool> > kbest_edges;
- ViterbiFSentence(*g, &trg_sent);
+ for (int best = 0; best < k_best; ++best) {
+ const KBest::KBestDerivations<vector<Hypergraph::Edge const*>, ViterbiPathTraversal>::Derivation* d = NULL;
+ if (!map_instead_of_viterbi) {
+ d = kbest.LazyKthBest(in_g.nodes_.size() - 1, best);
+ if (!d) break; // there are fewer than k_best derivations!
+ const vector<Hypergraph::Edge const*>& yield = d->yield;
+ kbest_edges.reset(new vector<bool>(in_g.edges_.size(), false));
+ for (int i = 0; i < yield.size(); ++i) {
+ assert(yield[i]->id_ < kbest_edges->size());
+ (*kbest_edges)[yield[i]->id_] = true;
+ }
+ }
+ if (!map_instead_of_viterbi || edges) {
+ if (kbest_edges) edges = kbest_edges.get();
+ new_hg = in_g.CreateViterbiHypergraph(edges);
+ for (int i = 0; i < new_hg->edges_.size(); ++i)
+ new_hg->edges_[i].edge_prob_ = prob_t::One();
+ g = new_hg.get();
+ }
- if (edges || !map_instead_of_viterbi) {
- for (int i = 0; i < edge_posteriors.size(); ++i)
- edge_posteriors[i] = prob_t::One();
- } else {
- SparseVector<prob_t> posts;
- const prob_t z = InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, TransitionEventWeightFunction>(*g, &posts);
- for (int i = 0; i < edge_posteriors.size(); ++i)
- edge_posteriors[i] = posts[i] / z;
- }
- vector<set<int> > src_cov(g->edges_.size());
- vector<set<int> > trg_cov(g->edges_.size());
- TargetEdgeCoveragesUsingTree(*g, &trg_cov);
+ vector<prob_t> edge_posteriors(g->edges_.size(), prob_t::Zero());
+ vector<WordID> trg_sent;
+ vector<WordID> src_sent;
+ if (fix_up_src_spans) {
+ ViterbiESentence(*g, &src_sent);
+ } else {
+ src_sent.resize(src_lattice.size());
+ for (int i = 0; i < src_sent.size(); ++i)
+ src_sent[i] = src_lattice[i][0].label;
+ }
+
+ ViterbiFSentence(*g, &trg_sent);
+
+ if (edges || !map_instead_of_viterbi) {
+ for (int i = 0; i < edge_posteriors.size(); ++i)
+ edge_posteriors[i] = prob_t::One();
+ } else {
+ SparseVector<prob_t> posts;
+ const prob_t z = InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, TransitionEventWeightFunction>(*g, &posts);
+ for (int i = 0; i < edge_posteriors.size(); ++i)
+ edge_posteriors[i] = posts[i] / z;
+ }
+ vector<set<int> > src_cov(g->edges_.size());
+ vector<set<int> > trg_cov(g->edges_.size());
+ TargetEdgeCoveragesUsingTree(*g, &trg_cov);
- if (fix_up_src_spans)
- SourceEdgeCoveragesUsingTree(*g, &src_cov);
- else
- SourceEdgeCoveragesUsingParseIndices(*g, &src_cov);
+ if (fix_up_src_spans)
+ SourceEdgeCoveragesUsingTree(*g, &src_cov);
+ else
+ SourceEdgeCoveragesUsingParseIndices(*g, &src_cov);
- // figure out the src and reference size;
- int src_size = src_sent.size();
- int ref_size = trg_sent.size();
- Array2D<prob_t> align(src_size + 1, ref_size, prob_t::Zero());
- for (int c = 0; c < g->edges_.size(); ++c) {
- const prob_t& p = edge_posteriors[c];
- const set<int>& srcs = src_cov[c];
- const set<int>& trgs = trg_cov[c];
- for (set<int>::const_iterator si = srcs.begin();
- si != srcs.end(); ++si) {
- for (set<int>::const_iterator ti = trgs.begin();
- ti != trgs.end(); ++ti) {
- align(*si + 1, *ti) += p;
+ // figure out the src and reference size;
+ int src_size = src_sent.size();
+ int ref_size = trg_sent.size();
+ Array2D<prob_t> align(src_size + 1, ref_size, prob_t::Zero());
+ for (int c = 0; c < g->edges_.size(); ++c) {
+ const prob_t& p = edge_posteriors[c];
+ const set<int>& srcs = src_cov[c];
+ const set<int>& trgs = trg_cov[c];
+ for (set<int>::const_iterator si = srcs.begin();
+ si != srcs.end(); ++si) {
+ for (set<int>::const_iterator ti = trgs.begin();
+ ti != trgs.end(); ++ti) {
+ align(*si + 1, *ti) += p;
+ }
}
}
- }
- new_hg.reset();
- //if (g != &in_g) { g.reset(); }
+ new_hg.reset();
+ //if (g != &in_g) { g.reset(); }
- prob_t threshold(0.9);
- const bool use_soft_threshold = true; // TODO configure
+ prob_t threshold(0.9);
+ const bool use_soft_threshold = true; // TODO configure
- Array2D<bool> grid(src_size, ref_size, false);
- for (int j = 0; j < ref_size; ++j) {
- if (use_soft_threshold) {
- threshold = prob_t::Zero();
- for (int i = 0; i <= src_size; ++i)
- if (align(i, j) > threshold) threshold = align(i, j);
- //threshold *= prob_t(0.99);
+ Array2D<bool> grid(src_size, ref_size, false);
+ for (int j = 0; j < ref_size; ++j) {
+ if (use_soft_threshold) {
+ threshold = prob_t::Zero();
+ for (int i = 0; i <= src_size; ++i)
+ if (align(i, j) > threshold) threshold = align(i, j);
+ //threshold *= prob_t(0.99);
+ }
+ for (int i = 0; i < src_size; ++i)
+ grid(i, j) = align(i+1, j) >= threshold;
}
- for (int i = 0; i < src_size; ++i)
- grid(i, j) = align(i+1, j) >= threshold;
- }
- if (out == &cout) {
- // TODO need to do some sort of verbose flag
- WriteProbGrid(align, &cerr);
- cerr << grid << endl;
+ if (out == &cout && k_best < 2) {
+ // TODO need to do some sort of verbose flag
+ WriteProbGrid(align, &cerr);
+ cerr << grid << endl;
+ }
+ (*out) << TD::GetString(src_sent) << " ||| " << TD::GetString(trg_sent) << " ||| ";
+ AlignmentPharaoh::SerializePharaohFormat(grid, out);
}
- (*out) << TD::GetString(src_sent) << " ||| " << TD::GetString(trg_sent) << " ||| ";
- AlignmentPharaoh::SerializePharaohFormat(grid, out);
};
diff --git a/decoder/aligner.h b/decoder/aligner.h
index a088ba6c..a34795c9 100644
--- a/decoder/aligner.h
+++ b/decoder/aligner.h
@@ -19,6 +19,7 @@ struct AlignerTools {
const Hypergraph& g,
std::ostream* out,
bool map_instead_of_viterbi = true,
+ int k_best = 0,
const std::vector<bool>* edges = NULL);
};
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index fb219663..a21b47c0 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -870,7 +870,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
}
}
if (aligner_mode && !output_training_vector)
- AlignerTools::WriteAlignment(smeta.GetSourceLattice(), smeta.GetReference(), forest, &cout, 0 == conf.count("aligner_use_viterbi"));
+ AlignerTools::WriteAlignment(smeta.GetSourceLattice(), smeta.GetReference(), forest, &cout, 0 == conf.count("aligner_use_viterbi"), kbest ? conf["k_best"].as<int>() : 0);
if (write_gradient) {
const prob_t ref_z = InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, EdgeFeaturesAndProbWeightFunction>(forest, &ref_exp);
ref_exp /= ref_z;
diff --git a/decoder/kbest.h b/decoder/kbest.h
index 3eb8707b..03a8311c 100644
--- a/decoder/kbest.h
+++ b/decoder/kbest.h
@@ -12,9 +12,9 @@
namespace KBest {
// default, don't filter any derivations from the k-best list
+ template<typename Dummy>
struct NoFilter {
- bool operator()(const std::vector<WordID>& yield) {
- (void) yield;
+ bool operator()(const Dummy&) {
return false;
}
};
@@ -32,7 +32,7 @@ namespace KBest {
// the lazy k-best algorithm (Algorithm 3) from Huang and Chiang (IWPT 2005)
template<typename T, // yield type (returned by Traversal)
typename Traversal,
- typename DerivationFilter = NoFilter,
+ typename DerivationFilter = NoFilter<T>,
typename WeightType = prob_t,
typename WeightFunction = EdgeProb>
struct KBestDerivations {
diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h
index d75f50fc..15d48588 100755
--- a/decoder/oracle_bleu.h
+++ b/decoder/oracle_bleu.h
@@ -286,7 +286,7 @@ struct OracleBleu {
std::cerr << "Output kbest to " << kbest_out_filename_<<std::endl;
if (!unique)
- kbest<KBest::NoFilter>(sent_id,forest,k,ko.get(),std::cerr);
+ kbest<KBest::NoFilter<std::vector<WordID> > >(sent_id,forest,k,ko.get(),std::cerr);
else {
kbest<KBest::FilterUnique>(sent_id,forest,k,ko.get(),std::cerr);
}
diff --git a/tests/system_tests/hmm/cdec.ini b/tests/system_tests/hmm/cdec.ini
index 04ef0d6e..9e94d7fe 100644
--- a/tests/system_tests/hmm/cdec.ini
+++ b/tests/system_tests/hmm/cdec.ini
@@ -1,4 +1,4 @@
aligner=true
intersection_strategy=full
formalism=lexalign
-feature_function=MarkovJump +b
+feature_function=NewJump
diff --git a/vest/ces.cc b/vest/ces.cc
index aa341058..4ae6b695 100644
--- a/vest/ces.cc
+++ b/vest/ces.cc
@@ -42,7 +42,7 @@ void ComputeErrorSurface(const SentenceScorer& ss, const ViterbiEnvelope& ve, Er
Lattice ref;
LatticeTools::ConvertTextOrPLF(psrc->substr(0, pos), &src);
LatticeTools::ConvertTextOrPLF(psrc->substr(pos + 5), &ref);
- AlignerTools::WriteAlignment(src, ref, hg, &os, true, &edges);
+ AlignerTools::WriteAlignment(src, ref, hg, &os, true, 0, &edges);
string tstr = os.str();
TD::ConvertSentence(tstr.substr(tstr.rfind(" ||| ") + 5), &trans);
} else {