From 03b8a687ff9537a75931d79657d9bb57aefb4bc6 Mon Sep 17 00:00:00 2001 From: graehl Date: Thu, 15 Jul 2010 04:05:17 +0000 Subject: move kbest to oracle_bleu git-svn-id: https://ws10smt.googlecode.com/svn/trunk@262 ec762483-ff6d-05da-a07a-a48fb63a330f --- decoder/oracle_bleu.h | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) (limited to 'decoder/oracle_bleu.h') diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index 273e14b8..32525466 100755 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -1,6 +1,7 @@ #ifndef ORACLE_BLEU_H #define ORACLE_BLEU_H +#include #include #include #include @@ -14,6 +15,7 @@ #include "viterbi.h" #include "sentence_metadata.h" #include "apply_models.h" +#include "kbest.h" //TODO: put function impls into .cc //TODO: disentangle @@ -148,6 +150,54 @@ struct OracleBleu { // dest_forest->SortInEdgesByEdgeWeights(); } +// TODO decoder output should probably be moved to another file - how about oracle_bleu.h + void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const &kbest_out_filename_) { + using namespace std; + using namespace boost; + cerr << "In kbest\n"; + + ofstream kbest_out; + kbest_out.open(kbest_out_filename_.c_str()); + cerr << "Output kbest to " << kbest_out_filename_; + + //add length (f side) src length of this sentence to the psuedo-doc src length count + float curr_src_length = doc_src_length + tmp_src_length; + + if (unique) { + KBest::KBestDerivations, ESentenceTraversal, KBest::FilterUnique> kbest(forest, k); + for (int i = 0; i < k; ++i) { + const KBest::KBestDerivations, ESentenceTraversal, KBest::FilterUnique>::Derivation* d = + kbest.LazyKthBest(forest.nodes_.size() - 1, i); + if (!d) break; + //calculate score in context of psuedo-doc + Score* sentscore = ds[sent_id]->ScoreCandidate(d->yield); + sentscore->PlusEquals(*doc_score,float(1)); + float bleu = curr_src_length * sentscore->ComputeScore(); + kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " + << d->feature_values << " ||| " << log(d->score) << " ||| " << bleu << endl; + // cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " + // << d->feature_values << " ||| " << log(d->score) << endl; + } + } else { + KBest::KBestDerivations, ESentenceTraversal> kbest(forest, k); + for (int i = 0; i < k; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(forest.nodes_.size() - 1, i); + if (!d) break; + cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " + << d->feature_values << " ||| " << log(d->score) << endl; + } + } + } + + void DumpKBest(boost::program_options::variables_map const& conf,std::string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique) + { + std::ostringstream kbest_string_stream; + kbest_string_stream << conf["forest_output"].as() << "/kbest_"<