diff options
Diffstat (limited to 'decoder/oracle_bleu.h')
-rwxr-xr-x | decoder/oracle_bleu.h | 50 |
1 files changed, 50 insertions, 0 deletions
diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index 273e14b8..32525466 100755 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -1,6 +1,7 @@ #ifndef ORACLE_BLEU_H #define ORACLE_BLEU_H +#include <sstream> #include <iostream> #include <string> #include <vector> @@ -14,6 +15,7 @@ #include "viterbi.h" #include "sentence_metadata.h" #include "apply_models.h" +#include "kbest.h" //TODO: put function impls into .cc //TODO: disentangle @@ -148,6 +150,54 @@ struct OracleBleu { // dest_forest->SortInEdgesByEdgeWeights(); } +// TODO decoder output should probably be moved to another file - how about oracle_bleu.h + void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const &kbest_out_filename_) { + using namespace std; + using namespace boost; + cerr << "In kbest\n"; + + ofstream kbest_out; + kbest_out.open(kbest_out_filename_.c_str()); + cerr << "Output kbest to " << kbest_out_filename_; + + //add length (f side) src length of this sentence to the psuedo-doc src length count + float curr_src_length = doc_src_length + tmp_src_length; + + if (unique) { + KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique> kbest(forest, k); + for (int i = 0; i < k; ++i) { + const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique>::Derivation* d = + kbest.LazyKthBest(forest.nodes_.size() - 1, i); + if (!d) break; + //calculate score in context of psuedo-doc + Score* sentscore = ds[sent_id]->ScoreCandidate(d->yield); + sentscore->PlusEquals(*doc_score,float(1)); + float bleu = curr_src_length * sentscore->ComputeScore(); + kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " + << d->feature_values << " ||| " << log(d->score) << " ||| " << bleu << endl; + // cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " + // << d->feature_values << " ||| " << log(d->score) << endl; + } + } else { + KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k); + for (int i = 0; i < k; ++i) { + const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(forest.nodes_.size() - 1, i); + if (!d) break; + cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " + << d->feature_values << " ||| " << log(d->score) << endl; + } + } + } + + void DumpKBest(boost::program_options::variables_map const& conf,std::string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique) + { + std::ostringstream kbest_string_stream; + kbest_string_stream << conf["forest_output"].as<std::string>() << "/kbest_"<<suffix<< "." << sent_id; + DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str()); + + } + }; |