From b8f314dddda3d440164e4772830e3c951ba06ee4 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 10 Dec 2010 13:00:04 -0500 Subject: extract kbest alignments --- decoder/aligner.cc | 168 ++++++++++++++++++++++++++++++++--------------------- 1 file changed, 101 insertions(+), 67 deletions(-) (limited to 'decoder/aligner.cc') diff --git a/decoder/aligner.cc b/decoder/aligner.cc index 43d4e0ce..7830b955 100644 --- a/decoder/aligner.cc +++ b/decoder/aligner.cc @@ -3,8 +3,11 @@ #include #include +#include + #include "array2d.h" #include "hg.h" +#include "kbest.h" #include "sentence_metadata.h" #include "inside_outside.h" #include "viterbi.h" @@ -178,8 +181,21 @@ void AlignerTools::WriteAlignment(const Lattice& src_lattice, const Hypergraph& in_g, ostream* out, bool map_instead_of_viterbi, + int k_best, // must = 0 if MAP const vector* edges) { bool fix_up_src_spans = false; + if (k_best > 1 && edges) { + cerr << "ERROR: cannot request multiple best alignments and provide an edge set!\n"; + abort(); + } + if (map_instead_of_viterbi) { + if (k_best != 0) { + cerr << "WARNING: K-best alignment extraction not available for MAP, use --aligner_use_viterbi\n"; + } + k_best = 1; + } else { + if (k_best == 0) k_best = 1; + } const Hypergraph* g = &in_g; HypergraphP new_hg; if (!src_lattice.IsSentence() || @@ -190,83 +206,101 @@ void AlignerTools::WriteAlignment(const Lattice& src_lattice, map_instead_of_viterbi = false; fix_up_src_spans = !src_lattice.IsSentence(); } - if (!map_instead_of_viterbi || edges) { - new_hg = in_g.CreateViterbiHypergraph(edges); - for (int i = 0; i < new_hg->edges_.size(); ++i) - new_hg->edges_[i].edge_prob_ = prob_t::One(); - g = new_hg.get(); - } - vector edge_posteriors(g->edges_.size(), prob_t::Zero()); - vector trg_sent; - vector src_sent; - if (fix_up_src_spans) { - ViterbiESentence(*g, &src_sent); - } else { - src_sent.resize(src_lattice.size()); - for (int i = 0; i < src_sent.size(); ++i) - src_sent[i] = src_lattice[i][0].label; - } + KBest::KBestDerivations, ViterbiPathTraversal> kbest(in_g, k_best); + boost::scoped_ptr > kbest_edges; - ViterbiFSentence(*g, &trg_sent); + for (int best = 0; best < k_best; ++best) { + const KBest::KBestDerivations, ViterbiPathTraversal>::Derivation* d = NULL; + if (!map_instead_of_viterbi) { + d = kbest.LazyKthBest(in_g.nodes_.size() - 1, best); + if (!d) break; // there are fewer than k_best derivations! + const vector& yield = d->yield; + kbest_edges.reset(new vector(in_g.edges_.size(), false)); + for (int i = 0; i < yield.size(); ++i) { + assert(yield[i]->id_ < kbest_edges->size()); + (*kbest_edges)[yield[i]->id_] = true; + } + } + if (!map_instead_of_viterbi || edges) { + if (kbest_edges) edges = kbest_edges.get(); + new_hg = in_g.CreateViterbiHypergraph(edges); + for (int i = 0; i < new_hg->edges_.size(); ++i) + new_hg->edges_[i].edge_prob_ = prob_t::One(); + g = new_hg.get(); + } - if (edges || !map_instead_of_viterbi) { - for (int i = 0; i < edge_posteriors.size(); ++i) - edge_posteriors[i] = prob_t::One(); - } else { - SparseVector posts; - const prob_t z = InsideOutside, TransitionEventWeightFunction>(*g, &posts); - for (int i = 0; i < edge_posteriors.size(); ++i) - edge_posteriors[i] = posts[i] / z; - } - vector > src_cov(g->edges_.size()); - vector > trg_cov(g->edges_.size()); - TargetEdgeCoveragesUsingTree(*g, &trg_cov); + vector edge_posteriors(g->edges_.size(), prob_t::Zero()); + vector trg_sent; + vector src_sent; + if (fix_up_src_spans) { + ViterbiESentence(*g, &src_sent); + } else { + src_sent.resize(src_lattice.size()); + for (int i = 0; i < src_sent.size(); ++i) + src_sent[i] = src_lattice[i][0].label; + } + + ViterbiFSentence(*g, &trg_sent); + + if (edges || !map_instead_of_viterbi) { + for (int i = 0; i < edge_posteriors.size(); ++i) + edge_posteriors[i] = prob_t::One(); + } else { + SparseVector posts; + const prob_t z = InsideOutside, TransitionEventWeightFunction>(*g, &posts); + for (int i = 0; i < edge_posteriors.size(); ++i) + edge_posteriors[i] = posts[i] / z; + } + vector > src_cov(g->edges_.size()); + vector > trg_cov(g->edges_.size()); + TargetEdgeCoveragesUsingTree(*g, &trg_cov); - if (fix_up_src_spans) - SourceEdgeCoveragesUsingTree(*g, &src_cov); - else - SourceEdgeCoveragesUsingParseIndices(*g, &src_cov); + if (fix_up_src_spans) + SourceEdgeCoveragesUsingTree(*g, &src_cov); + else + SourceEdgeCoveragesUsingParseIndices(*g, &src_cov); - // figure out the src and reference size; - int src_size = src_sent.size(); - int ref_size = trg_sent.size(); - Array2D align(src_size + 1, ref_size, prob_t::Zero()); - for (int c = 0; c < g->edges_.size(); ++c) { - const prob_t& p = edge_posteriors[c]; - const set& srcs = src_cov[c]; - const set& trgs = trg_cov[c]; - for (set::const_iterator si = srcs.begin(); - si != srcs.end(); ++si) { - for (set::const_iterator ti = trgs.begin(); - ti != trgs.end(); ++ti) { - align(*si + 1, *ti) += p; + // figure out the src and reference size; + int src_size = src_sent.size(); + int ref_size = trg_sent.size(); + Array2D align(src_size + 1, ref_size, prob_t::Zero()); + for (int c = 0; c < g->edges_.size(); ++c) { + const prob_t& p = edge_posteriors[c]; + const set& srcs = src_cov[c]; + const set& trgs = trg_cov[c]; + for (set::const_iterator si = srcs.begin(); + si != srcs.end(); ++si) { + for (set::const_iterator ti = trgs.begin(); + ti != trgs.end(); ++ti) { + align(*si + 1, *ti) += p; + } } } - } - new_hg.reset(); - //if (g != &in_g) { g.reset(); } + new_hg.reset(); + //if (g != &in_g) { g.reset(); } - prob_t threshold(0.9); - const bool use_soft_threshold = true; // TODO configure + prob_t threshold(0.9); + const bool use_soft_threshold = true; // TODO configure - Array2D grid(src_size, ref_size, false); - for (int j = 0; j < ref_size; ++j) { - if (use_soft_threshold) { - threshold = prob_t::Zero(); - for (int i = 0; i <= src_size; ++i) - if (align(i, j) > threshold) threshold = align(i, j); - //threshold *= prob_t(0.99); + Array2D grid(src_size, ref_size, false); + for (int j = 0; j < ref_size; ++j) { + if (use_soft_threshold) { + threshold = prob_t::Zero(); + for (int i = 0; i <= src_size; ++i) + if (align(i, j) > threshold) threshold = align(i, j); + //threshold *= prob_t(0.99); + } + for (int i = 0; i < src_size; ++i) + grid(i, j) = align(i+1, j) >= threshold; } - for (int i = 0; i < src_size; ++i) - grid(i, j) = align(i+1, j) >= threshold; - } - if (out == &cout) { - // TODO need to do some sort of verbose flag - WriteProbGrid(align, &cerr); - cerr << grid << endl; + if (out == &cout && k_best < 2) { + // TODO need to do some sort of verbose flag + WriteProbGrid(align, &cerr); + cerr << grid << endl; + } + (*out) << TD::GetString(src_sent) << " ||| " << TD::GetString(trg_sent) << " ||| "; + AlignmentPharaoh::SerializePharaohFormat(grid, out); } - (*out) << TD::GetString(src_sent) << " ||| " << TD::GetString(trg_sent) << " ||| "; - AlignmentPharaoh::SerializePharaohFormat(grid, out); }; -- cgit v1.2.3