summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorWu, Ke <wuke@cs.umd.edu>2014-10-15 11:42:21 -0400
committerWu, Ke <wuke@cs.umd.edu>2014-10-15 11:42:21 -0400
commit7d0990b52cd757f15d190c1fee122f262c4731a7 (patch)
tree047c32f65a0a9ef37be6335f036d9543a711dbfd /decoder
parentd88186af251ecae60974b20395ce75807bfdda35 (diff)
Add --show_derivations_mask to control what to print when --show_derivations
Diffstat (limited to 'decoder')
-rw-r--r--decoder/decoder.cc9
-rw-r--r--decoder/oracle_bleu.h4
2 files changed, 8 insertions, 5 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index c384c33f..77afddd8 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -195,7 +195,7 @@ struct DecoderImpl {
}
forest.PruneInsideOutside(beam_prune,density_prune,pm,false,1);
if (!forestname.empty()) forestname=" "+forestname;
- if (!SILENT) {
+ if (!SILENT) {
forest_stats(forest," Pruned "+forestname+" forest",false,false);
cerr << " Pruned "<<forestname<<" forest portion of edges kept: "<<forest.edges_.size()/presize<<endl;
}
@@ -261,7 +261,7 @@ struct DecoderImpl {
assert(ref);
LatticeTools::ConvertTextOrPLF(sref, ref);
}
- }
+ }
// used to construct the suffix string to get the name of arguments for multiple passes
// e.g., the "2" in --weights2
@@ -284,7 +284,7 @@ struct DecoderImpl {
boost::shared_ptr<RandomNumberGenerator<boost::mt19937> > rng;
int sample_max_trans;
bool aligner_mode;
- bool graphviz;
+ bool graphviz;
bool joshua_viz;
bool encode_b64;
bool kbest;
@@ -404,6 +404,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice")
("extract_rules", po::value<string>(), "Extract the rules used in translation (not de-duped!) to a file in this directory")
("show_derivations", po::value<string>(), "Directory to print the derivation structures to")
+ ("show_derivations_mask", po::value<int>()->default_value(Hypergraph::SPAN|Hypergraph::RULE), "Bit-mask for what to print in derivation structures")
("graphviz","Show (constrained) translation forest in GraphViz format")
("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart")
("max_translation_sample,X", po::value<int>(), "Sample the max translation from the chart")
@@ -665,6 +666,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
unique_kbest = conf.count("unique_k_best");
get_oracle_forest = conf.count("get_oracle_forest");
oracle.show_derivation=conf.count("show_derivations");
+ oracle.show_derivation_mask=conf["show_derivations_mask"].as<int>();
remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations");
combine_size = conf["combine_size"].as<int>();
@@ -1098,4 +1100,3 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
o->NotifyDecodingComplete(smeta);
return true;
}
-
diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h
index d2c4715c..893e36ca 100644
--- a/decoder/oracle_bleu.h
+++ b/decoder/oracle_bleu.h
@@ -252,6 +252,8 @@ struct OracleBleu {
}
bool show_derivation;
+ int show_derivation_mask;
+
template <class Filter>
void kbest(int sent_id,Hypergraph const& forest,int k,std::ostream &kbest_out=std::cout,std::ostream &deriv_out=std::cerr) {
using namespace std;
@@ -275,7 +277,7 @@ struct OracleBleu {
if (show_derivation) {
deriv_out<<"\nsent_id="<<sent_id<<"."<<i<<" ||| "; //where i is candidate #/k
deriv_out<<log(d->score)<<"\n";
- deriv_out<<kbest.derivation_tree(*d,true);
+ deriv_out<<kbest.derivation_tree(*d,true, show_derivation_mask);
deriv_out<<"\n"<<flush;
}
}