diff options
author | Chris Dyer <cdyer@cab.ark.cs.cmu.edu> | 2012-06-19 00:05:18 -0400 |
---|---|---|
committer | Chris Dyer <cdyer@cab.ark.cs.cmu.edu> | 2012-06-19 00:05:18 -0400 |
commit | fcd8e74ca9c16fe0e3001906ae2bd0ac0686f813 (patch) | |
tree | 61e3a3b19b65f05d5e74cb91626631c78ba83d59 /decoder | |
parent | 5cd58c1355811caf0941ad6f0340c2deb52cc99c (diff) | |
parent | a47bbc78b3d38ea998b2d484470061140142048d (diff) |
Merge branch 'master' of https://github.com/pks/cdec-dtrain
Diffstat (limited to 'decoder')
-rw-r--r-- | decoder/decoder.cc | 22 | ||||
-rw-r--r-- | decoder/viterbi.cc | 12 | ||||
-rw-r--r-- | decoder/viterbi.h | 5 |
3 files changed, 30 insertions, 9 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc index cbb97a0d..333f0fb6 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -3,6 +3,7 @@ #include <tr1/unordered_map> #include <boost/program_options.hpp> #include <boost/program_options/variables_map.hpp> +#include <boost/make_shared.hpp> #include "program_options.h" #include "stringlib.h" @@ -187,8 +188,8 @@ struct DecoderImpl { } void SetId(int next_sent_id) { sent_id = next_sent_id - 1; } - void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_deriv=false) { - cerr << viterbi_stats(forest,name,true,show_tree,show_deriv); + void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_deriv=false, bool extract_rules=false, boost::shared_ptr<WriteFile> extract_file = boost::make_shared<WriteFile>()) { + cerr << viterbi_stats(forest,name,true,show_tree,show_deriv,extract_rules, extract_file); cerr << endl; } @@ -424,7 +425,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("tagger_tagset,t", po::value<string>(), "(Tagger) file containing tag set") ("csplit_output_plf", "(Compound splitter) Output lattice in PLF format") ("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice") - ("extract_rules", po::value<string>(), "Extract the rules used in translation (de-duped) to this file") + ("extract_rules", po::value<string>(), "Extract the rules used in translation (not de-duped!) to a file in this directory") ("show_derivations", po::value<string>(), "Directory to print the derivation structures to") ("graphviz","Show (constrained) translation forest in GraphViz format") ("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart") @@ -570,6 +571,11 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream // cube pruning pop-limit: we may want to configure this on a per-pass basis pop_limit = conf["cubepruning_pop_limit"].as<int>(); + if (conf.count("extract_rules")) { + if (!DirectoryExists(conf["extract_rules"].as<string>())) + MkDirP(conf["extract_rules"].as<string>()); + } + // determine the number of rescoring/pruning/weighting passes configured const int MAX_PASSES = 3; for (int pass = 0; pass < MAX_PASSES; ++pass) { @@ -712,9 +718,11 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream cfg_options.Validate(); #endif - if (conf.count("extract_rules")) - extract_file.reset(new WriteFile(str("extract_rules",conf))); - + if (conf.count("extract_rules")) { + stringstream ss; + ss << sent_id; + extract_file.reset(new WriteFile(str("extract_rules",conf)+"/"+ss.str())); + } combine_size = conf["combine_size"].as<int>(); if (combine_size < 1) combine_size = 1; sent_id = -1; @@ -851,7 +859,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { #endif forest.swap(rescored_forest); forest.Reweight(cur_weights); - if (!SILENT) forest_stats(forest," " + passtr +" forest",show_tree_structure,oracle.show_derivation); + if (!SILENT) forest_stats(forest," " + passtr +" forest",show_tree_structure,oracle.show_derivation, conf.count("extract_rules"), extract_file); } if (conf.count("show_partition")) { diff --git a/decoder/viterbi.cc b/decoder/viterbi.cc index 9d19914b..1b9c6665 100644 --- a/decoder/viterbi.cc +++ b/decoder/viterbi.cc @@ -5,11 +5,12 @@ #include <vector> #include "hg.h" + //#define DEBUG_VITERBI_SORT using namespace std; -std::string viterbi_stats(Hypergraph const& hg, std::string const& name, bool estring, bool etree,bool show_derivation) +std::string viterbi_stats(Hypergraph const& hg, std::string const& name, bool estring, bool etree,bool show_derivation, bool extract_rules, boost::shared_ptr<WriteFile> extract_file) { ostringstream o; o << hg.stats(name); @@ -22,6 +23,9 @@ std::string viterbi_stats(Hypergraph const& hg, std::string const& name, bool es if (etree) { o<<name<<" tree: "<<ViterbiETree(hg)<<endl; } + if (extract_rules) { + ViterbiRules(hg, extract_file->stream()); + } if (show_derivation) { o<<name<<" derivation: "; o << hg.show_viterbi_tree(false); // last item should be goal (or at least depend on prev items). TODO: this doesn't actually reorder the nodes in hg. @@ -36,6 +40,12 @@ std::string viterbi_stats(Hypergraph const& hg, std::string const& name, bool es return o.str(); } +void ViterbiRules(const Hypergraph& hg, ostream* o) { + vector<Hypergraph::Edge const*> edges; + Viterbi<ViterbiPathTraversal>(hg, &edges); + for (unsigned i = 0; i < edges.size(); i++) + (*o) << edges[i]->rule_->AsString(true) << endl; +} string ViterbiETree(const Hypergraph& hg) { vector<WordID> tmp; diff --git a/decoder/viterbi.h b/decoder/viterbi.h index 3092f6da..03e961a2 100644 --- a/decoder/viterbi.h +++ b/decoder/viterbi.h @@ -5,8 +5,10 @@ #include "prob.h" #include "hg.h" #include "tdict.h" +#include "filelib.h" +#include <boost/make_shared.hpp> -std::string viterbi_stats(Hypergraph const& hg, std::string const& name="forest", bool estring=true, bool etree=false, bool derivation_tree=false); +std::string viterbi_stats(Hypergraph const& hg, std::string const& name="forest", bool estring=true, bool etree=false, bool derivation_tree=false, bool extract_rules=false, boost::shared_ptr<WriteFile> extract_file = boost::make_shared<WriteFile>()); /// computes for each hg node the best (according to WeightType/WeightFunction) derivation, and some homomorphism (bottom up expression tree applied through Traversal) of it. T is the "return type" of Traversal, which is called only once for the best edge for a node's result (i.e. result will start default constructed) //TODO: make T a typename inside Traversal and WeightType a typename inside WeightFunction? @@ -201,6 +203,7 @@ struct FeatureVectorTraversal { std::string JoshuaVisualizationString(const Hypergraph& hg); prob_t ViterbiESentence(const Hypergraph& hg, std::vector<WordID>* result); std::string ViterbiETree(const Hypergraph& hg); +void ViterbiRules(const Hypergraph& hg, std::ostream* s); prob_t ViterbiFSentence(const Hypergraph& hg, std::vector<WordID>* result); std::string ViterbiFTree(const Hypergraph& hg); int ViterbiELength(const Hypergraph& hg); |