summaryrefslogtreecommitdiff
path: root/decoder/oracle_bleu.h
diff options
context:
space:
mode:
authorWu, Ke <wuke@cs.umd.edu>2014-10-12 16:30:02 -0400
committerWu, Ke <wuke@cs.umd.edu>2014-10-12 16:30:02 -0400
commit8c0e4a5c1f168a419b3a236a94815c97164bddbc (patch)
tree177b7cd332482f779f04945fea8325a00e9cfa54 /decoder/oracle_bleu.h
parentd88186af251ecae60974b20395ce75807bfdda35 (diff)
Cherry picked Mr.MIRA compatibility mode code
Diffstat (limited to 'decoder/oracle_bleu.h')
-rw-r--r--decoder/oracle_bleu.h37
1 files changed, 29 insertions, 8 deletions
diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h
index d2c4715c..75db61e8 100644
--- a/decoder/oracle_bleu.h
+++ b/decoder/oracle_bleu.h
@@ -21,6 +21,7 @@
#include "kbest.h"
#include "timing_stats.h"
#include "sentences.h"
+#include "b64featvector.h"
//TODO: put function impls into .cc
//TODO: move Translation into its own .h and use in cdec
@@ -253,18 +254,28 @@ struct OracleBleu {
bool show_derivation;
template <class Filter>
- void kbest(int sent_id,Hypergraph const& forest,int k,std::ostream &kbest_out=std::cout,std::ostream &deriv_out=std::cerr) {
+ void kbest(int sent_id, Hypergraph const& forest, int k, bool mr_mira_compat,
+ int src_len, std::ostream& kbest_out = std::cout,
+ std::ostream& deriv_out = std::cerr) {
using namespace std;
using namespace boost;
typedef KBest::KBestDerivations<Sentence, ESentenceTraversal,Filter> K;
K kbest(forest,k);
//add length (f side) src length of this sentence to the psuedo-doc src length count
float curr_src_length = doc_src_length + tmp_src_length;
- for (int i = 0; i < k; ++i) {
+ if (mr_mira_compat) kbest_out << k << "\n";
+ int i = 0;
+ for (; i < k; ++i) {
typename K::Derivation *d = kbest.LazyKthBest(forest.nodes_.size() - 1, i);
if (!d) break;
- kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
- << d->feature_values << " ||| " << log(d->score);
+ kbest_out << sent_id << " ||| ";
+ if (mr_mira_compat) kbest_out << src_len << " ||| ";
+ kbest_out << TD::GetString(d->yield) << " ||| ";
+ if (mr_mira_compat)
+ kbest_out << EncodeFeatureVector(d->feature_values);
+ else
+ kbest_out << d->feature_values;
+ kbest_out << " ||| " << log(d->score);
if (!refs.empty()) {
ScoreP sentscore = GetScore(d->yield,sent_id);
sentscore->PlusEquals(*doc_score,float(1));
@@ -279,10 +290,17 @@ struct OracleBleu {
deriv_out<<"\n"<<flush;
}
}
+ if (mr_mira_compat) {
+ for (; i < k; ++i) kbest_out << "\n";
+ kbest_out << flush;
+ }
}
// TODO decoder output should probably be moved to another file - how about oracle_bleu.h
- void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const &kbest_out_filename_, std::string const &deriv_out_filename_) {
+ void DumpKBest(const int sent_id, const Hypergraph& forest, const int k,
+ const bool unique, const bool mr_mira_compat,
+ const int src_len, std::string const& kbest_out_filename_,
+ std::string const& deriv_out_filename_) {
WriteFile ko(kbest_out_filename_);
std::cerr << "Output kbest to " << kbest_out_filename_ <<std::endl;
@@ -295,9 +313,11 @@ struct OracleBleu {
WriteFile oderiv(sderiv.str());
if (!unique)
- kbest<KBest::NoFilter<std::vector<WordID> > >(sent_id,forest,k,ko.get(),oderiv.get());
+ kbest<KBest::NoFilter<std::vector<WordID> > >(
+ sent_id, forest, k, mr_mira_compat, src_len, ko.get(), oderiv.get());
else {
- kbest<KBest::FilterUnique>(sent_id,forest,k,ko.get(),oderiv.get());
+ kbest<KBest::FilterUnique>(sent_id, forest, k, mr_mira_compat, src_len,
+ ko.get(), oderiv.get());
}
}
@@ -305,7 +325,8 @@ void DumpKBest(std::string const& suffix,const int sent_id, const Hypergraph& fo
{
std::ostringstream kbest_string_stream;
kbest_string_stream << forest_output << "/kbest_"<<suffix<< "." << sent_id;
- DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str(), "-");
+ DumpKBest(sent_id, forest, k, unique, false, -1, kbest_string_stream.str(),
+ "-");
}
};