summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cab.ark.cs.cmu.edu>2012-06-19 00:05:18 -0400
committerChris Dyer <cdyer@cab.ark.cs.cmu.edu>2012-06-19 00:05:18 -0400
commit5975dcaa50adb5ce7a05b83583b8f9ddc45f3f0a (patch)
tree2bc2eb4e17576e0726d7a2fa7f20eac9061c311d /decoder
parent78cc819168b2a550e52e9cac06dbbed41a3b04b2 (diff)
parentee1520c5095ea8648617a3658b20eedfd4dd2007 (diff)
Merge branch 'master' of https://github.com/pks/cdec-dtrain
Diffstat (limited to 'decoder')
-rw-r--r--decoder/decoder.cc22
-rw-r--r--decoder/viterbi.cc12
-rw-r--r--decoder/viterbi.h5
3 files changed, 30 insertions, 9 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index cbb97a0d..333f0fb6 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -3,6 +3,7 @@
#include <tr1/unordered_map>
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
+#include <boost/make_shared.hpp>
#include "program_options.h"
#include "stringlib.h"
@@ -187,8 +188,8 @@ struct DecoderImpl {
}
void SetId(int next_sent_id) { sent_id = next_sent_id - 1; }
- void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_deriv=false) {
- cerr << viterbi_stats(forest,name,true,show_tree,show_deriv);
+ void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_deriv=false, bool extract_rules=false, boost::shared_ptr<WriteFile> extract_file = boost::make_shared<WriteFile>()) {
+ cerr << viterbi_stats(forest,name,true,show_tree,show_deriv,extract_rules, extract_file);
cerr << endl;
}
@@ -424,7 +425,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("tagger_tagset,t", po::value<string>(), "(Tagger) file containing tag set")
("csplit_output_plf", "(Compound splitter) Output lattice in PLF format")
("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice")
- ("extract_rules", po::value<string>(), "Extract the rules used in translation (de-duped) to this file")
+ ("extract_rules", po::value<string>(), "Extract the rules used in translation (not de-duped!) to a file in this directory")
("show_derivations", po::value<string>(), "Directory to print the derivation structures to")
("graphviz","Show (constrained) translation forest in GraphViz format")
("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart")
@@ -570,6 +571,11 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
// cube pruning pop-limit: we may want to configure this on a per-pass basis
pop_limit = conf["cubepruning_pop_limit"].as<int>();
+ if (conf.count("extract_rules")) {
+ if (!DirectoryExists(conf["extract_rules"].as<string>()))
+ MkDirP(conf["extract_rules"].as<string>());
+ }
+
// determine the number of rescoring/pruning/weighting passes configured
const int MAX_PASSES = 3;
for (int pass = 0; pass < MAX_PASSES; ++pass) {
@@ -712,9 +718,11 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
cfg_options.Validate();
#endif
- if (conf.count("extract_rules"))
- extract_file.reset(new WriteFile(str("extract_rules",conf)));
-
+ if (conf.count("extract_rules")) {
+ stringstream ss;
+ ss << sent_id;
+ extract_file.reset(new WriteFile(str("extract_rules",conf)+"/"+ss.str()));
+ }
combine_size = conf["combine_size"].as<int>();
if (combine_size < 1) combine_size = 1;
sent_id = -1;
@@ -851,7 +859,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
#endif
forest.swap(rescored_forest);
forest.Reweight(cur_weights);
- if (!SILENT) forest_stats(forest," " + passtr +" forest",show_tree_structure,oracle.show_derivation);
+ if (!SILENT) forest_stats(forest," " + passtr +" forest",show_tree_structure,oracle.show_derivation, conf.count("extract_rules"), extract_file);
}
if (conf.count("show_partition")) {
diff --git a/decoder/viterbi.cc b/decoder/viterbi.cc
index 9d19914b..1b9c6665 100644
--- a/decoder/viterbi.cc
+++ b/decoder/viterbi.cc
@@ -5,11 +5,12 @@
#include <vector>
#include "hg.h"
+
//#define DEBUG_VITERBI_SORT
using namespace std;
-std::string viterbi_stats(Hypergraph const& hg, std::string const& name, bool estring, bool etree,bool show_derivation)
+std::string viterbi_stats(Hypergraph const& hg, std::string const& name, bool estring, bool etree,bool show_derivation, bool extract_rules, boost::shared_ptr<WriteFile> extract_file)
{
ostringstream o;
o << hg.stats(name);
@@ -22,6 +23,9 @@ std::string viterbi_stats(Hypergraph const& hg, std::string const& name, bool es
if (etree) {
o<<name<<" tree: "<<ViterbiETree(hg)<<endl;
}
+ if (extract_rules) {
+ ViterbiRules(hg, extract_file->stream());
+ }
if (show_derivation) {
o<<name<<" derivation: ";
o << hg.show_viterbi_tree(false); // last item should be goal (or at least depend on prev items). TODO: this doesn't actually reorder the nodes in hg.
@@ -36,6 +40,12 @@ std::string viterbi_stats(Hypergraph const& hg, std::string const& name, bool es
return o.str();
}
+void ViterbiRules(const Hypergraph& hg, ostream* o) {
+ vector<Hypergraph::Edge const*> edges;
+ Viterbi<ViterbiPathTraversal>(hg, &edges);
+ for (unsigned i = 0; i < edges.size(); i++)
+ (*o) << edges[i]->rule_->AsString(true) << endl;
+}
string ViterbiETree(const Hypergraph& hg) {
vector<WordID> tmp;
diff --git a/decoder/viterbi.h b/decoder/viterbi.h
index 3092f6da..03e961a2 100644
--- a/decoder/viterbi.h
+++ b/decoder/viterbi.h
@@ -5,8 +5,10 @@
#include "prob.h"
#include "hg.h"
#include "tdict.h"
+#include "filelib.h"
+#include <boost/make_shared.hpp>
-std::string viterbi_stats(Hypergraph const& hg, std::string const& name="forest", bool estring=true, bool etree=false, bool derivation_tree=false);
+std::string viterbi_stats(Hypergraph const& hg, std::string const& name="forest", bool estring=true, bool etree=false, bool derivation_tree=false, bool extract_rules=false, boost::shared_ptr<WriteFile> extract_file = boost::make_shared<WriteFile>());
/// computes for each hg node the best (according to WeightType/WeightFunction) derivation, and some homomorphism (bottom up expression tree applied through Traversal) of it. T is the "return type" of Traversal, which is called only once for the best edge for a node's result (i.e. result will start default constructed)
//TODO: make T a typename inside Traversal and WeightType a typename inside WeightFunction?
@@ -201,6 +203,7 @@ struct FeatureVectorTraversal {
std::string JoshuaVisualizationString(const Hypergraph& hg);
prob_t ViterbiESentence(const Hypergraph& hg, std::vector<WordID>* result);
std::string ViterbiETree(const Hypergraph& hg);
+void ViterbiRules(const Hypergraph& hg, std::ostream* s);
prob_t ViterbiFSentence(const Hypergraph& hg, std::vector<WordID>* result);
std::string ViterbiFTree(const Hypergraph& hg);
int ViterbiELength(const Hypergraph& hg);