summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Chris Dyer <cdyer@cs.cmu.edu> 2011-07-07 18:39:38 -0400
committer: Chris Dyer <cdyer@cs.cmu.edu> 2011-07-07 18:39:38 -0400
commit: 71daf4bf0b91a247d0d1663ae7850a3db85a378d (patch)
tree: d47ca0bc9d5ca96deedbbd2c27426be7819ebaea
parent: 75b814cb246052746134f32c723cf6d278b148df (diff)
support for extracting k-best derivation trees
-rw-r--r-- decoder/decoder.cc 12
-rwxr-xr-x decoder/oracle_bleu.h 22
2 files changed, 24 insertions, 10 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index ff068be9..2c3a06de 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -416,6 +416,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("csplit_output_plf", "(Compound splitter) Output lattice in PLF format")
("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice")
("extract_rules", po::value<string>(), "Extract the rules used in translation (de-duped) to this file")
+ ("show_derivations", po::value<string>(), "Directory to print the derivation structures to")
("graphviz","Show (constrained) translation forest in GraphViz format")
("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart")
("max_translation_sample,X", po::value<int>(), "Sample the max translation from the chart")
@@ -426,6 +427,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")
("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)")
("forest_output,O",po::value<string>(),"Directory to write forests to");
+
// ob.AddOptions(&opts);
#ifdef FSA_RESCORING
po::options_description cfgo(cfg_options.description());
@@ -677,6 +679,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
kbest = conf.count("k_best");
unique_kbest = conf.count("unique_k_best");
get_oracle_forest = conf.count("get_oracle_forest");
+ oracle.show_derivation=conf.count("show_derivations");
#ifdef FSA_RESCORING
cfg_options.Validate();
@@ -938,7 +941,8 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
} else {
if (kbest && !has_ref) {
//TODO: does this work properly?
- oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-");
+ const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-";
+ oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname);
} else if (csplit_output_plf) {
cout << HypergraphIO::AsPLF(forest, false) << endl;
} else {
@@ -1055,8 +1059,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
}
}
if (conf.count("graphviz")) forest.PrintGraphviz();
- if (kbest)
- oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-");
+ if (kbest) {
+ const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-";
+ oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname);
+ }
if (conf.count("show_conditional_prob")) {
const prob_t ref_z = Inside<prob_t, EdgeProb>(forest);
cout << (log(ref_z) - log(first_z)) << endl << flush;
diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h
index 15d48588..b603e27a 100755
--- a/decoder/oracle_bleu.h
+++ b/decoder/oracle_bleu.h
@@ -272,23 +272,31 @@ struct OracleBleu {
}
kbest_out<<endl<<flush;
if (show_derivation) {
- deriv_out<<"\nsent_id="<<sent_id<<"\n";
+ deriv_out<<"\nsent_id="<<sent_id<<"."<<i<<" ||| "; //where i is candidate #/k
+ deriv_out<<log(d->score)<<"\n";
deriv_out<<kbest.derivation_tree(*d,true);
- deriv_out<<flush;
+ deriv_out<<"\n"<<flush;
}
}
}
// TODO decoder output should probably be moved to another file - how about oracle_bleu.h
- void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const &kbest_out_filename_) {
+ void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const &kbest_out_filename_, std::string const &deriv_out_filename_) {
WriteFile ko(kbest_out_filename_);
- std::cerr << "Output kbest to " << kbest_out_filename_<<std::endl;
+ std::cerr << "Output kbest to " << kbest_out_filename_ <<std::endl;
+ std::ostringstream sderiv;
+ sderiv << deriv_out_filename_;
+ if (show_derivation) {
+ sderiv << "/derivs." << sent_id;
+ std::cerr << "Output derivations to " << deriv_out_filename_ << std::endl;
+ }
+ WriteFile oderiv(sderiv.str());
if (!unique)
- kbest<KBest::NoFilter<std::vector<WordID> > >(sent_id,forest,k,ko.get(),std::cerr);
+ kbest<KBest::NoFilter<std::vector<WordID> > >(sent_id,forest,k,ko.get(),oderiv.get());
else {
- kbest<KBest::FilterUnique>(sent_id,forest,k,ko.get(),std::cerr);
+ kbest<KBest::FilterUnique>(sent_id,forest,k,ko.get(),oderiv.get());
}
}
@@ -296,7 +304,7 @@ void DumpKBest(std::string const& suffix,const int sent_id, const Hypergraph& fo
{
std::ostringstream kbest_string_stream;
kbest_string_stream << forest_output << "/kbest_"<<suffix<< "." << sent_id;
- DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str());
+ DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str(), "-");
}
};