summaryrefslogtreecommitdiff
path: root/decoder/oracle_bleu.h
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/oracle_bleu.h')
-rwxr-xr-xdecoder/oracle_bleu.h50
1 files changed, 50 insertions, 0 deletions
diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h
index 273e14b8..32525466 100755
--- a/decoder/oracle_bleu.h
+++ b/decoder/oracle_bleu.h
@@ -1,6 +1,7 @@
#ifndef ORACLE_BLEU_H
#define ORACLE_BLEU_H
+#include <sstream>
#include <iostream>
#include <string>
#include <vector>
@@ -14,6 +15,7 @@
#include "viterbi.h"
#include "sentence_metadata.h"
#include "apply_models.h"
+#include "kbest.h"
//TODO: put function impls into .cc
//TODO: disentangle
@@ -148,6 +150,54 @@ struct OracleBleu {
// dest_forest->SortInEdgesByEdgeWeights();
}
+// TODO decoder output should probably be moved to another file - how about oracle_bleu.h
+ void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const &kbest_out_filename_) {
+ using namespace std;
+ using namespace boost;
+ cerr << "In kbest\n";
+
+ ofstream kbest_out;
+ kbest_out.open(kbest_out_filename_.c_str());
+ cerr << "Output kbest to " << kbest_out_filename_;
+
+ //add length (f side) src length of this sentence to the psuedo-doc src length count
+ float curr_src_length = doc_src_length + tmp_src_length;
+
+ if (unique) {
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique> kbest(forest, k);
+ for (int i = 0; i < k; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique>::Derivation* d =
+ kbest.LazyKthBest(forest.nodes_.size() - 1, i);
+ if (!d) break;
+ //calculate score in context of psuedo-doc
+ Score* sentscore = ds[sent_id]->ScoreCandidate(d->yield);
+ sentscore->PlusEquals(*doc_score,float(1));
+ float bleu = curr_src_length * sentscore->ComputeScore();
+ kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
+ << d->feature_values << " ||| " << log(d->score) << " ||| " << bleu << endl;
+ // cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
+ // << d->feature_values << " ||| " << log(d->score) << endl;
+ }
+ } else {
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k);
+ for (int i = 0; i < k; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest.LazyKthBest(forest.nodes_.size() - 1, i);
+ if (!d) break;
+ cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
+ << d->feature_values << " ||| " << log(d->score) << endl;
+ }
+ }
+ }
+
+ void DumpKBest(boost::program_options::variables_map const& conf,std::string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique)
+ {
+ std::ostringstream kbest_string_stream;
+ kbest_string_stream << conf["forest_output"].as<std::string>() << "/kbest_"<<suffix<< "." << sent_id;
+ DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str());
+
+ }
+
};