From 2931396900c89eb19a50407955574960c364d0ee Mon Sep 17 00:00:00 2001 From: "Wu, Ke" Date: Sun, 12 Oct 2014 16:30:02 -0400 Subject: Cherry picked Mr.MIRA compatibility mode code --- decoder/oracle_bleu.h | 37 +++++++++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 8 deletions(-) (limited to 'decoder/oracle_bleu.h') diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index d2c4715c..75db61e8 100644 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -21,6 +21,7 @@ #include "kbest.h" #include "timing_stats.h" #include "sentences.h" +#include "b64featvector.h" //TODO: put function impls into .cc //TODO: move Translation into its own .h and use in cdec @@ -253,18 +254,28 @@ struct OracleBleu { bool show_derivation; template - void kbest(int sent_id,Hypergraph const& forest,int k,std::ostream &kbest_out=std::cout,std::ostream &deriv_out=std::cerr) { + void kbest(int sent_id, Hypergraph const& forest, int k, bool mr_mira_compat, + int src_len, std::ostream& kbest_out = std::cout, + std::ostream& deriv_out = std::cerr) { using namespace std; using namespace boost; typedef KBest::KBestDerivations K; K kbest(forest,k); //add length (f side) src length of this sentence to the psuedo-doc src length count float curr_src_length = doc_src_length + tmp_src_length; - for (int i = 0; i < k; ++i) { + if (mr_mira_compat) kbest_out << k << "\n"; + int i = 0; + for (; i < k; ++i) { typename K::Derivation *d = kbest.LazyKthBest(forest.nodes_.size() - 1, i); if (!d) break; - kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " - << d->feature_values << " ||| " << log(d->score); + kbest_out << sent_id << " ||| "; + if (mr_mira_compat) kbest_out << src_len << " ||| "; + kbest_out << TD::GetString(d->yield) << " ||| "; + if (mr_mira_compat) + kbest_out << EncodeFeatureVector(d->feature_values); + else + kbest_out << d->feature_values; + kbest_out << " ||| " << log(d->score); if (!refs.empty()) { ScoreP sentscore = GetScore(d->yield,sent_id); sentscore->PlusEquals(*doc_score,float(1)); @@ -279,10 +290,17 @@ struct OracleBleu { deriv_out<<"\n"< > >(sent_id,forest,k,ko.get(),oderiv.get()); + kbest > >( + sent_id, forest, k, mr_mira_compat, src_len, ko.get(), oderiv.get()); else { - kbest(sent_id,forest,k,ko.get(),oderiv.get()); + kbest(sent_id, forest, k, mr_mira_compat, src_len, + ko.get(), oderiv.get()); } } @@ -305,7 +325,8 @@ void DumpKBest(std::string const& suffix,const int sent_id, const Hypergraph& fo { std::ostringstream kbest_string_stream; kbest_string_stream << forest_output << "/kbest_"< Date: Wed, 15 Oct 2014 11:42:21 -0400 Subject: Add --show_derivations_mask to control what to print when --show_derivations --- decoder/decoder.cc | 9 +++++---- decoder/oracle_bleu.h | 4 +++- 2 files changed, 8 insertions(+), 5 deletions(-) (limited to 'decoder/oracle_bleu.h') diff --git a/decoder/decoder.cc b/decoder/decoder.cc index c384c33f..77afddd8 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -195,7 +195,7 @@ struct DecoderImpl { } forest.PruneInsideOutside(beam_prune,density_prune,pm,false,1); if (!forestname.empty()) forestname=" "+forestname; - if (!SILENT) { + if (!SILENT) { forest_stats(forest," Pruned "+forestname+" forest",false,false); cerr << " Pruned "< > rng; int sample_max_trans; bool aligner_mode; - bool graphviz; + bool graphviz; bool joshua_viz; bool encode_b64; bool kbest; @@ -404,6 +404,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice") ("extract_rules", po::value(), "Extract the rules used in translation (not de-duped!) to a file in this directory") ("show_derivations", po::value(), "Directory to print the derivation structures to") + ("show_derivations_mask", po::value()->default_value(Hypergraph::SPAN|Hypergraph::RULE), "Bit-mask for what to print in derivation structures") ("graphviz","Show (constrained) translation forest in GraphViz format") ("max_translation_beam,x", po::value(), "Beam approximation to get max translation from the chart") ("max_translation_sample,X", po::value(), "Sample the max translation from the chart") @@ -665,6 +666,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream unique_kbest = conf.count("unique_k_best"); get_oracle_forest = conf.count("get_oracle_forest"); oracle.show_derivation=conf.count("show_derivations"); + oracle.show_derivation_mask=conf["show_derivations_mask"].as(); remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations"); combine_size = conf["combine_size"].as(); @@ -1098,4 +1100,3 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { o->NotifyDecodingComplete(smeta); return true; } - diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index d2c4715c..893e36ca 100644 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -252,6 +252,8 @@ struct OracleBleu { } bool show_derivation; + int show_derivation_mask; + template void kbest(int sent_id,Hypergraph const& forest,int k,std::ostream &kbest_out=std::cout,std::ostream &deriv_out=std::cerr) { using namespace std; @@ -275,7 +277,7 @@ struct OracleBleu { if (show_derivation) { deriv_out<<"\nsent_id="<score)<<"\n"; - deriv_out<