diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2010-12-10 13:00:04 -0500 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2010-12-10 13:00:04 -0500 |
commit | 658263d53805482d2bc2dc186476626a01cb4e93 (patch) | |
tree | 1e0be0de4deea445e0670c51564313c11f705b52 | |
parent | 3a4f9a526c58cbc1fe69fe5b3aeefd9639f9c49b (diff) |
extract kbest alignments
-rw-r--r-- | decoder/aligner.cc | 168 | ||||
-rw-r--r-- | decoder/aligner.h | 1 | ||||
-rw-r--r-- | decoder/decoder.cc | 2 | ||||
-rw-r--r-- | decoder/kbest.h | 6 | ||||
-rwxr-xr-x | decoder/oracle_bleu.h | 2 | ||||
-rw-r--r-- | tests/system_tests/hmm/cdec.ini | 2 | ||||
-rw-r--r-- | vest/ces.cc | 2 |
7 files changed, 109 insertions, 74 deletions
diff --git a/decoder/aligner.cc b/decoder/aligner.cc index 43d4e0ce..7830b955 100644 --- a/decoder/aligner.cc +++ b/decoder/aligner.cc @@ -3,8 +3,11 @@ #include <cstdio> #include <set> +#include <boost/scoped_ptr.hpp> + #include "array2d.h" #include "hg.h" +#include "kbest.h" #include "sentence_metadata.h" #include "inside_outside.h" #include "viterbi.h" @@ -178,8 +181,21 @@ void AlignerTools::WriteAlignment(const Lattice& src_lattice, const Hypergraph& in_g, ostream* out, bool map_instead_of_viterbi, + int k_best, // must = 0 if MAP const vector<bool>* edges) { bool fix_up_src_spans = false; + if (k_best > 1 && edges) { + cerr << "ERROR: cannot request multiple best alignments and provide an edge set!\n"; + abort(); + } + if (map_instead_of_viterbi) { + if (k_best != 0) { + cerr << "WARNING: K-best alignment extraction not available for MAP, use --aligner_use_viterbi\n"; + } + k_best = 1; + } else { + if (k_best == 0) k_best = 1; + } const Hypergraph* g = &in_g; HypergraphP new_hg; if (!src_lattice.IsSentence() || @@ -190,83 +206,101 @@ void AlignerTools::WriteAlignment(const Lattice& src_lattice, map_instead_of_viterbi = false; fix_up_src_spans = !src_lattice.IsSentence(); } - if (!map_instead_of_viterbi || edges) { - new_hg = in_g.CreateViterbiHypergraph(edges); - for (int i = 0; i < new_hg->edges_.size(); ++i) - new_hg->edges_[i].edge_prob_ = prob_t::One(); - g = new_hg.get(); - } - vector<prob_t> edge_posteriors(g->edges_.size(), prob_t::Zero()); - vector<WordID> trg_sent; - vector<WordID> src_sent; - if (fix_up_src_spans) { - ViterbiESentence(*g, &src_sent); - } else { - src_sent.resize(src_lattice.size()); - for (int i = 0; i < src_sent.size(); ++i) - src_sent[i] = src_lattice[i][0].label; - } + KBest::KBestDerivations<vector<Hypergraph::Edge const*>, ViterbiPathTraversal> kbest(in_g, k_best); + boost::scoped_ptr<vector<bool> > kbest_edges; - ViterbiFSentence(*g, &trg_sent); + for (int best = 0; best < k_best; ++best) { + const KBest::KBestDerivations<vector<Hypergraph::Edge const*>, ViterbiPathTraversal>::Derivation* d = NULL; + if (!map_instead_of_viterbi) { + d = kbest.LazyKthBest(in_g.nodes_.size() - 1, best); + if (!d) break; // there are fewer than k_best derivations! + const vector<Hypergraph::Edge const*>& yield = d->yield; + kbest_edges.reset(new vector<bool>(in_g.edges_.size(), false)); + for (int i = 0; i < yield.size(); ++i) { + assert(yield[i]->id_ < kbest_edges->size()); + (*kbest_edges)[yield[i]->id_] = true; + } + } + if (!map_instead_of_viterbi || edges) { + if (kbest_edges) edges = kbest_edges.get(); + new_hg = in_g.CreateViterbiHypergraph(edges); + for (int i = 0; i < new_hg->edges_.size(); ++i) + new_hg->edges_[i].edge_prob_ = prob_t::One(); + g = new_hg.get(); + } - if (edges || !map_instead_of_viterbi) { - for (int i = 0; i < edge_posteriors.size(); ++i) - edge_posteriors[i] = prob_t::One(); - } else { - SparseVector<prob_t> posts; - const prob_t z = InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, TransitionEventWeightFunction>(*g, &posts); - for (int i = 0; i < edge_posteriors.size(); ++i) - edge_posteriors[i] = posts[i] / z; - } - vector<set<int> > src_cov(g->edges_.size()); - vector<set<int> > trg_cov(g->edges_.size()); - TargetEdgeCoveragesUsingTree(*g, &trg_cov); + vector<prob_t> edge_posteriors(g->edges_.size(), prob_t::Zero()); + vector<WordID> trg_sent; + vector<WordID> src_sent; + if (fix_up_src_spans) { + ViterbiESentence(*g, &src_sent); + } else { + src_sent.resize(src_lattice.size()); + for (int i = 0; i < src_sent.size(); ++i) + src_sent[i] = src_lattice[i][0].label; + } + + ViterbiFSentence(*g, &trg_sent); + + if (edges || !map_instead_of_viterbi) { + for (int i = 0; i < edge_posteriors.size(); ++i) + edge_posteriors[i] = prob_t::One(); + } else { + SparseVector<prob_t> posts; + const prob_t z = InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, TransitionEventWeightFunction>(*g, &posts); + for (int i = 0; i < edge_posteriors.size(); ++i) + edge_posteriors[i] = posts[i] / z; + } + vector<set<int> > src_cov(g->edges_.size()); + vector<set<int> > trg_cov(g->edges_.size()); + TargetEdgeCoveragesUsingTree(*g, &trg_cov); - if (fix_up_src_spans) - SourceEdgeCoveragesUsingTree(*g, &src_cov); - else - SourceEdgeCoveragesUsingParseIndices(*g, &src_cov); + if (fix_up_src_spans) + SourceEdgeCoveragesUsingTree(*g, &src_cov); + else + SourceEdgeCoveragesUsingParseIndices(*g, &src_cov); - // figure out the src and reference size; - int src_size = src_sent.size(); - int ref_size = trg_sent.size(); - Array2D<prob_t> align(src_size + 1, ref_size, prob_t::Zero()); - for (int c = 0; c < g->edges_.size(); ++c) { - const prob_t& p = edge_posteriors[c]; - const set<int>& srcs = src_cov[c]; - const set<int>& trgs = trg_cov[c]; - for (set<int>::const_iterator si = srcs.begin(); - si != srcs.end(); ++si) { - for (set<int>::const_iterator ti = trgs.begin(); - ti != trgs.end(); ++ti) { - align(*si + 1, *ti) += p; + // figure out the src and reference size; + int src_size = src_sent.size(); + int ref_size = trg_sent.size(); + Array2D<prob_t> align(src_size + 1, ref_size, prob_t::Zero()); + for (int c = 0; c < g->edges_.size(); ++c) { + const prob_t& p = edge_posteriors[c]; + const set<int>& srcs = src_cov[c]; + const set<int>& trgs = trg_cov[c]; + for (set<int>::const_iterator si = srcs.begin(); + si != srcs.end(); ++si) { + for (set<int>::const_iterator ti = trgs.begin(); + ti != trgs.end(); ++ti) { + align(*si + 1, *ti) += p; + } } } - } - new_hg.reset(); - //if (g != &in_g) { g.reset(); } + new_hg.reset(); + //if (g != &in_g) { g.reset(); } - prob_t threshold(0.9); - const bool use_soft_threshold = true; // TODO configure + prob_t threshold(0.9); + const bool use_soft_threshold = true; // TODO configure - Array2D<bool> grid(src_size, ref_size, false); - for (int j = 0; j < ref_size; ++j) { - if (use_soft_threshold) { - threshold = prob_t::Zero(); - for (int i = 0; i <= src_size; ++i) - if (align(i, j) > threshold) threshold = align(i, j); - //threshold *= prob_t(0.99); + Array2D<bool> grid(src_size, ref_size, false); + for (int j = 0; j < ref_size; ++j) { + if (use_soft_threshold) { + threshold = prob_t::Zero(); + for (int i = 0; i <= src_size; ++i) + if (align(i, j) > threshold) threshold = align(i, j); + //threshold *= prob_t(0.99); + } + for (int i = 0; i < src_size; ++i) + grid(i, j) = align(i+1, j) >= threshold; } - for (int i = 0; i < src_size; ++i) - grid(i, j) = align(i+1, j) >= threshold; - } - if (out == &cout) { - // TODO need to do some sort of verbose flag - WriteProbGrid(align, &cerr); - cerr << grid << endl; + if (out == &cout && k_best < 2) { + // TODO need to do some sort of verbose flag + WriteProbGrid(align, &cerr); + cerr << grid << endl; + } + (*out) << TD::GetString(src_sent) << " ||| " << TD::GetString(trg_sent) << " ||| "; + AlignmentPharaoh::SerializePharaohFormat(grid, out); } - (*out) << TD::GetString(src_sent) << " ||| " << TD::GetString(trg_sent) << " ||| "; - AlignmentPharaoh::SerializePharaohFormat(grid, out); }; diff --git a/decoder/aligner.h b/decoder/aligner.h index a088ba6c..a34795c9 100644 --- a/decoder/aligner.h +++ b/decoder/aligner.h @@ -19,6 +19,7 @@ struct AlignerTools { const Hypergraph& g, std::ostream* out, bool map_instead_of_viterbi = true, + int k_best = 0, const std::vector<bool>* edges = NULL); }; diff --git a/decoder/decoder.cc b/decoder/decoder.cc index fb219663..a21b47c0 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -870,7 +870,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { } } if (aligner_mode && !output_training_vector) - AlignerTools::WriteAlignment(smeta.GetSourceLattice(), smeta.GetReference(), forest, &cout, 0 == conf.count("aligner_use_viterbi")); + AlignerTools::WriteAlignment(smeta.GetSourceLattice(), smeta.GetReference(), forest, &cout, 0 == conf.count("aligner_use_viterbi"), kbest ? conf["k_best"].as<int>() : 0); if (write_gradient) { const prob_t ref_z = InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, EdgeFeaturesAndProbWeightFunction>(forest, &ref_exp); ref_exp /= ref_z; diff --git a/decoder/kbest.h b/decoder/kbest.h index 3eb8707b..03a8311c 100644 --- a/decoder/kbest.h +++ b/decoder/kbest.h @@ -12,9 +12,9 @@ namespace KBest { // default, don't filter any derivations from the k-best list + template<typename Dummy> struct NoFilter { - bool operator()(const std::vector<WordID>& yield) { - (void) yield; + bool operator()(const Dummy&) { return false; } }; @@ -32,7 +32,7 @@ namespace KBest { // the lazy k-best algorithm (Algorithm 3) from Huang and Chiang (IWPT 2005) template<typename T, // yield type (returned by Traversal) typename Traversal, - typename DerivationFilter = NoFilter, + typename DerivationFilter = NoFilter<T>, typename WeightType = prob_t, typename WeightFunction = EdgeProb> struct KBestDerivations { diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index d75f50fc..15d48588 100755 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -286,7 +286,7 @@ struct OracleBleu { std::cerr << "Output kbest to " << kbest_out_filename_<<std::endl; if (!unique) - kbest<KBest::NoFilter>(sent_id,forest,k,ko.get(),std::cerr); + kbest<KBest::NoFilter<std::vector<WordID> > >(sent_id,forest,k,ko.get(),std::cerr); else { kbest<KBest::FilterUnique>(sent_id,forest,k,ko.get(),std::cerr); } diff --git a/tests/system_tests/hmm/cdec.ini b/tests/system_tests/hmm/cdec.ini index 04ef0d6e..9e94d7fe 100644 --- a/tests/system_tests/hmm/cdec.ini +++ b/tests/system_tests/hmm/cdec.ini @@ -1,4 +1,4 @@ aligner=true intersection_strategy=full formalism=lexalign -feature_function=MarkovJump +b +feature_function=NewJump diff --git a/vest/ces.cc b/vest/ces.cc index aa341058..4ae6b695 100644 --- a/vest/ces.cc +++ b/vest/ces.cc @@ -42,7 +42,7 @@ void ComputeErrorSurface(const SentenceScorer& ss, const ViterbiEnvelope& ve, Er Lattice ref; LatticeTools::ConvertTextOrPLF(psrc->substr(0, pos), &src); LatticeTools::ConvertTextOrPLF(psrc->substr(pos + 5), &ref); - AlignerTools::WriteAlignment(src, ref, hg, &os, true, &edges); + AlignerTools::WriteAlignment(src, ref, hg, &os, true, 0, &edges); string tstr = os.str(); TD::ConvertSentence(tstr.substr(tstr.rfind(" ||| ") + 5), &trans); } else { |