diff options
-rw-r--r-- | decoder/decoder.cc | 12 | ||||
-rwxr-xr-x | decoder/oracle_bleu.h | 22 |
2 files changed, 24 insertions, 10 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc index ff068be9..2c3a06de 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -416,6 +416,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("csplit_output_plf", "(Compound splitter) Output lattice in PLF format") ("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice") ("extract_rules", po::value<string>(), "Extract the rules used in translation (de-duped) to this file") + ("show_derivations", po::value<string>(), "Directory to print the derivation structures to") ("graphviz","Show (constrained) translation forest in GraphViz format") ("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart") ("max_translation_sample,X", po::value<int>(), "Sample the max translation from the chart") @@ -426,6 +427,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)") ("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)") ("forest_output,O",po::value<string>(),"Directory to write forests to"); + // ob.AddOptions(&opts); #ifdef FSA_RESCORING po::options_description cfgo(cfg_options.description()); @@ -677,6 +679,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream kbest = conf.count("k_best"); unique_kbest = conf.count("unique_k_best"); get_oracle_forest = conf.count("get_oracle_forest"); + oracle.show_derivation=conf.count("show_derivations"); #ifdef FSA_RESCORING cfg_options.Validate(); @@ -938,7 +941,8 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { } else { if (kbest && !has_ref) { //TODO: does this work properly? - oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-"); + const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-"; + oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname); } else if (csplit_output_plf) { cout << HypergraphIO::AsPLF(forest, false) << endl; } else { @@ -1055,8 +1059,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { } } if (conf.count("graphviz")) forest.PrintGraphviz(); - if (kbest) - oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-"); + if (kbest) { + const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-"; + oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname); + } if (conf.count("show_conditional_prob")) { const prob_t ref_z = Inside<prob_t, EdgeProb>(forest); cout << (log(ref_z) - log(first_z)) << endl << flush; diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index 15d48588..b603e27a 100755 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -272,23 +272,31 @@ struct OracleBleu { } kbest_out<<endl<<flush; if (show_derivation) { - deriv_out<<"\nsent_id="<<sent_id<<"\n"; + deriv_out<<"\nsent_id="<<sent_id<<"."<<i<<" ||| "; //where i is candidate #/k + deriv_out<<log(d->score)<<"\n"; deriv_out<<kbest.derivation_tree(*d,true); - deriv_out<<flush; + deriv_out<<"\n"<<flush; } } } // TODO decoder output should probably be moved to another file - how about oracle_bleu.h - void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const &kbest_out_filename_) { + void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const &kbest_out_filename_, std::string const &deriv_out_filename_) { WriteFile ko(kbest_out_filename_); - std::cerr << "Output kbest to " << kbest_out_filename_<<std::endl; + std::cerr << "Output kbest to " << kbest_out_filename_ <<std::endl; + std::ostringstream sderiv; + sderiv << deriv_out_filename_; + if (show_derivation) { + sderiv << "/derivs." << sent_id; + std::cerr << "Output derivations to " << deriv_out_filename_ << std::endl; + } + WriteFile oderiv(sderiv.str()); if (!unique) - kbest<KBest::NoFilter<std::vector<WordID> > >(sent_id,forest,k,ko.get(),std::cerr); + kbest<KBest::NoFilter<std::vector<WordID> > >(sent_id,forest,k,ko.get(),oderiv.get()); else { - kbest<KBest::FilterUnique>(sent_id,forest,k,ko.get(),std::cerr); + kbest<KBest::FilterUnique>(sent_id,forest,k,ko.get(),oderiv.get()); } } @@ -296,7 +304,7 @@ void DumpKBest(std::string const& suffix,const int sent_id, const Hypergraph& fo { std::ostringstream kbest_string_stream; kbest_string_stream << forest_output << "/kbest_"<<suffix<< "." << sent_id; - DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str()); + DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str(), "-"); } }; |