From ee1520c5095ea8648617a3658b20eedfd4dd2007 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Mon, 18 Jun 2012 17:26:33 +0200 Subject: extract_rules cdec param --- decoder/decoder.cc | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) (limited to 'decoder/decoder.cc') diff --git a/decoder/decoder.cc b/decoder/decoder.cc index cbb97a0d..333f0fb6 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -3,6 +3,7 @@ #include #include #include +#include #include "program_options.h" #include "stringlib.h" @@ -187,8 +188,8 @@ struct DecoderImpl { } void SetId(int next_sent_id) { sent_id = next_sent_id - 1; } - void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_deriv=false) { - cerr << viterbi_stats(forest,name,true,show_tree,show_deriv); + void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_deriv=false, bool extract_rules=false, boost::shared_ptr extract_file = boost::make_shared()) { + cerr << viterbi_stats(forest,name,true,show_tree,show_deriv,extract_rules, extract_file); cerr << endl; } @@ -424,7 +425,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("tagger_tagset,t", po::value(), "(Tagger) file containing tag set") ("csplit_output_plf", "(Compound splitter) Output lattice in PLF format") ("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice") - ("extract_rules", po::value(), "Extract the rules used in translation (de-duped) to this file") + ("extract_rules", po::value(), "Extract the rules used in translation (not de-duped!) to a file in this directory") ("show_derivations", po::value(), "Directory to print the derivation structures to") ("graphviz","Show (constrained) translation forest in GraphViz format") ("max_translation_beam,x", po::value(), "Beam approximation to get max translation from the chart") @@ -570,6 +571,11 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream // cube pruning pop-limit: we may want to configure this on a per-pass basis pop_limit = conf["cubepruning_pop_limit"].as(); + if (conf.count("extract_rules")) { + if (!DirectoryExists(conf["extract_rules"].as())) + MkDirP(conf["extract_rules"].as()); + } + // determine the number of rescoring/pruning/weighting passes configured const int MAX_PASSES = 3; for (int pass = 0; pass < MAX_PASSES; ++pass) { @@ -712,9 +718,11 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream cfg_options.Validate(); #endif - if (conf.count("extract_rules")) - extract_file.reset(new WriteFile(str("extract_rules",conf))); - + if (conf.count("extract_rules")) { + stringstream ss; + ss << sent_id; + extract_file.reset(new WriteFile(str("extract_rules",conf)+"/"+ss.str())); + } combine_size = conf["combine_size"].as(); if (combine_size < 1) combine_size = 1; sent_id = -1; @@ -851,7 +859,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { #endif forest.swap(rescored_forest); forest.Reweight(cur_weights); - if (!SILENT) forest_stats(forest," " + passtr +" forest",show_tree_structure,oracle.show_derivation); + if (!SILENT) forest_stats(forest," " + passtr +" forest",show_tree_structure,oracle.show_derivation, conf.count("extract_rules"), extract_file); } if (conf.count("show_partition")) { -- cgit v1.2.3