diff options
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/decoder.cc | 12 | ||||
| -rwxr-xr-x | decoder/oracle_bleu.h | 22 | 
2 files changed, 24 insertions, 10 deletions
| diff --git a/decoder/decoder.cc b/decoder/decoder.cc index ff068be9..2c3a06de 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -416,6 +416,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream          ("csplit_output_plf", "(Compound splitter) Output lattice in PLF format")          ("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice")          ("extract_rules", po::value<string>(), "Extract the rules used in translation (de-duped) to this file") +        ("show_derivations", po::value<string>(), "Directory to print the derivation structures to")          ("graphviz","Show (constrained) translation forest in GraphViz format")          ("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart")          ("max_translation_sample,X", po::value<int>(), "Sample the max translation from the chart") @@ -426,6 +427,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream          ("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")          ("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)")          ("forest_output,O",po::value<string>(),"Directory to write forests to"); +    // ob.AddOptions(&opts);  #ifdef FSA_RESCORING    po::options_description cfgo(cfg_options.description()); @@ -677,6 +679,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream    kbest = conf.count("k_best");    unique_kbest = conf.count("unique_k_best");    get_oracle_forest = conf.count("get_oracle_forest"); +  oracle.show_derivation=conf.count("show_derivations");  #ifdef FSA_RESCORING    cfg_options.Validate(); @@ -938,7 +941,8 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {    } else {      if (kbest && !has_ref) {        //TODO: does this work properly? -      oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-"); +      const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-"; +      oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname);      } else if (csplit_output_plf) {        cout << HypergraphIO::AsPLF(forest, false) << endl;      } else { @@ -1055,8 +1059,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {          }        }        if (conf.count("graphviz")) forest.PrintGraphviz(); -      if (kbest) -        oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-"); +      if (kbest) { +        const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-"; +        oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname); +      }        if (conf.count("show_conditional_prob")) {          const prob_t ref_z = Inside<prob_t, EdgeProb>(forest);          cout << (log(ref_z) - log(first_z)) << endl << flush; diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index 15d48588..b603e27a 100755 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -272,23 +272,31 @@ struct OracleBleu {        }        kbest_out<<endl<<flush;        if (show_derivation) { -        deriv_out<<"\nsent_id="<<sent_id<<"\n"; +        deriv_out<<"\nsent_id="<<sent_id<<"."<<i<<" ||| "; //where i is candidate #/k +        deriv_out<<log(d->score)<<"\n";          deriv_out<<kbest.derivation_tree(*d,true); -        deriv_out<<flush; +        deriv_out<<"\n"<<flush;        }      }    }  // TODO decoder output should probably be moved to another file - how about oracle_bleu.h -  void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const &kbest_out_filename_) { +  void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const &kbest_out_filename_, std::string const &deriv_out_filename_) {      WriteFile ko(kbest_out_filename_); -    std::cerr << "Output kbest to " << kbest_out_filename_<<std::endl; +    std::cerr << "Output kbest to " << kbest_out_filename_ <<std::endl; +    std::ostringstream sderiv; +    sderiv << deriv_out_filename_; +    if (show_derivation) { +      sderiv << "/derivs." << sent_id; +      std::cerr << "Output derivations to " << deriv_out_filename_ << std::endl; +    } +    WriteFile oderiv(sderiv.str());      if (!unique) -      kbest<KBest::NoFilter<std::vector<WordID> > >(sent_id,forest,k,ko.get(),std::cerr); +      kbest<KBest::NoFilter<std::vector<WordID> > >(sent_id,forest,k,ko.get(),oderiv.get());      else { -      kbest<KBest::FilterUnique>(sent_id,forest,k,ko.get(),std::cerr); +      kbest<KBest::FilterUnique>(sent_id,forest,k,ko.get(),oderiv.get());      }    } @@ -296,7 +304,7 @@ void DumpKBest(std::string const& suffix,const int sent_id, const Hypergraph& fo    {      std::ostringstream kbest_string_stream;      kbest_string_stream << forest_output << "/kbest_"<<suffix<< "." << sent_id; -    DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str()); +    DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str(), "-");    }  }; | 
