summaryrefslogtreecommitdiff
path: root/decoder/decoder.cc
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-07-07 18:39:38 -0400
committerChris Dyer <cdyer@cs.cmu.edu>2011-07-07 18:39:38 -0400
commit71daf4bf0b91a247d0d1663ae7850a3db85a378d (patch)
treed47ca0bc9d5ca96deedbbd2c27426be7819ebaea /decoder/decoder.cc
parent75b814cb246052746134f32c723cf6d278b148df (diff)
support for extracting k-best derivation trees
Diffstat (limited to 'decoder/decoder.cc')
-rw-r--r--decoder/decoder.cc12
1 files changed, 9 insertions, 3 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index ff068be9..2c3a06de 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -416,6 +416,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("csplit_output_plf", "(Compound splitter) Output lattice in PLF format")
("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice")
("extract_rules", po::value<string>(), "Extract the rules used in translation (de-duped) to this file")
+ ("show_derivations", po::value<string>(), "Directory to print the derivation structures to")
("graphviz","Show (constrained) translation forest in GraphViz format")
("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart")
("max_translation_sample,X", po::value<int>(), "Sample the max translation from the chart")
@@ -426,6 +427,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")
("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)")
("forest_output,O",po::value<string>(),"Directory to write forests to");
+
// ob.AddOptions(&opts);
#ifdef FSA_RESCORING
po::options_description cfgo(cfg_options.description());
@@ -677,6 +679,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
kbest = conf.count("k_best");
unique_kbest = conf.count("unique_k_best");
get_oracle_forest = conf.count("get_oracle_forest");
+ oracle.show_derivation=conf.count("show_derivations");
#ifdef FSA_RESCORING
cfg_options.Validate();
@@ -938,7 +941,8 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
} else {
if (kbest && !has_ref) {
//TODO: does this work properly?
- oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-");
+ const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-";
+ oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname);
} else if (csplit_output_plf) {
cout << HypergraphIO::AsPLF(forest, false) << endl;
} else {
@@ -1055,8 +1059,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
}
}
if (conf.count("graphviz")) forest.PrintGraphviz();
- if (kbest)
- oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-");
+ if (kbest) {
+ const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-";
+ oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname);
+ }
if (conf.count("show_conditional_prob")) {
const prob_t ref_z = Inside<prob_t, EdgeProb>(forest);
cout << (log(ref_z) - log(first_z)) << endl << flush;