diff options
-rw-r--r-- | decoder/cdec.cc | 50 | ||||
-rw-r--r-- | decoder/hg.cc | 10 | ||||
-rw-r--r-- | decoder/hg.h | 3 | ||||
-rw-r--r-- | decoder/stringlib.h | 7 | ||||
-rw-r--r-- | decoder/viterbi.cc | 17 | ||||
-rw-r--r-- | decoder/viterbi.h | 4 |
6 files changed, 70 insertions, 21 deletions
diff --git a/decoder/cdec.cc b/decoder/cdec.cc index 5d0ac8b2..0a4593ef 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -82,6 +82,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("show_cfg_search_space", "Show the search space as a CFG") ("prelm_beam_prune", po::value<double>(), "Prune paths from -LM forest before LM rescoring, keeping paths within exp(alpha>=0)") ("beam_prune", po::value<double>(), "Prune paths from +LM forest, keep paths within exp(alpha>=0)") + ("scale_prune_srclen", "scale beams by the input length (in # of tokens; may not be what you want for lattices") ("lexalign_use_null", "Support source-side null words in lexical translation") ("tagger_tagset,t", po::value<string>(), "(Tagger) file containing tag set") ("csplit_output_plf", "(Compound splitter) Output lattice in PLF format") @@ -234,6 +235,15 @@ static void ExtractRulesDedupe(const Hypergraph& hg, ostream* os) { void register_feature_functions(); +bool beam_param(po::variables_map const& conf,char const* name,double *val,bool scale_srclen=false,double srclen=1) +{ + if (conf.count(name)) { + *val=conf[name].as<double>()*(scale_srclen?srclen:1); + return true; + } + return false; +} + int main(int argc, char** argv) { global_ff_registry.reset(new FFRegistry); register_feature_functions(); @@ -257,6 +267,7 @@ int main(int argc, char** argv) { << "used with csplit AND --beam_prune!\n"; exit(1); } + const bool scale_prune_srclen=conf.count("scale_prune_srclen"); const bool csplit_output_plf = conf.count("csplit_output_plf"); if (csplit_output_plf && formalism != "csplit") { cerr << "--csplit_output_plf should only be used with csplit!\n"; @@ -373,6 +384,8 @@ int main(int argc, char** argv) { string to_translate; Lattice ref; ParseTranslatorInputLattice(buf, &to_translate, &ref); + const unsigned srclen=NTokens(to_translate,' '); +//FIXME: should get the avg. or max source length of the input lattice (like Lattice::dist_(start,end)); but this is only used to scale beam parameters (optionally) anyway so fidelity isn't important. const bool has_ref = ref.size() > 0; SentenceMetadata smeta(sent_id, ref); const bool hadoop_counters = (write_gradient); @@ -389,8 +402,8 @@ int main(int argc, char** argv) { cout << endl << flush; continue; } - cerr << " -LM forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; - cerr << " -LM forest (paths): " << forest.NumberOfPaths() << endl; + const bool show_tree_structure=conf.count("show_tree_structure"); + cerr << viterbi_stats(forest," -LM forest",true,show_tree_structure); if (conf.count("show_expected_length")) { const PRPair<double, double> res = Inside<PRPair<double, double>, @@ -403,16 +416,13 @@ int main(int argc, char** argv) { } if (extract_file) ExtractRulesDedupe(forest, extract_file->stream()); - vector<WordID> trans; - const prob_t vs = ViterbiESentence(forest, &trans); - cerr << " -LM Viterbi: " << TD::GetString(trans) << endl; - if (conf.count("show_tree_structure")) - cerr << " -LM tree: " << ViterbiETree(forest) << endl;; - cerr << " -LM Viterbi: " << log(vs) << endl; - - if (conf.count("prelm_beam_prune")) { - forest.BeamPruneInsideOutside(1.0, false, conf["prelm_beam_prune"].as<double>(), NULL); - cerr << " Pruned -LM forest (paths): " << forest.NumberOfPaths() << endl; + + double prelm_beam_prune; + if (beam_param(conf,"prelm_beam_prune",&prelm_beam_prune,scale_prune_srclen,srclen)) { + double presize=forest.edges_.size(); + forest.BeamPruneInsideOutside(1.0, false, prelm_beam_prune, NULL); + cerr << viterbi_stats(forest," Pruned -LM forest",false,false); + cerr << " Pruned -LM forest (beam="<<prelm_beam_prune<<") portion of edges kept: "<<forest.edges_.size()/presize; } bool has_late_models = !late_models.empty(); @@ -428,18 +438,15 @@ int main(int argc, char** argv) { &lm_forest); forest.swap(lm_forest); forest.Reweight(feature_weights); - trans.clear(); - ViterbiESentence(forest, &trans); - cerr << " +LM forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; - cerr << " +LM forest (paths): " << forest.NumberOfPaths() << endl; - cerr << " +LM Viterbi: " << TD::GetString(trans) << endl; + cerr << viterbi_stats(forest," +LM forest",true,show_tree_structure); } - if (conf.count("beam_prune")) { + double beam_prune; + if (beam_param(conf,"beam_prune",&beam_prune,scale_prune_srclen,srclen)) { vector<bool> preserve_mask(forest.edges_.size(), false); if (csplit_preserve_full_word) preserve_mask[CompoundSplit::GetFullWordEdgeIndex(forest)] = true; - forest.BeamPruneInsideOutside(1.0, false, conf["beam_prune"].as<double>(), &preserve_mask); - cerr << " Pruned forest (paths): " << forest.NumberOfPaths() << endl; + forest.BeamPruneInsideOutside(1.0, false, beam_prune, &preserve_mask); + cerr << viterbi_stats(forest," Pruned forest",false,false); } if (conf.count("forest_output") && !has_ref) { @@ -464,6 +471,9 @@ int main(int argc, char** argv) { if (sample_max_trans) { MaxTranslationSample(&forest, sample_max_trans, conf.count("k_best") ? conf["k_best"].as<int>() : 0); } else { + vector<WordID> trans; + ViterbiESentence(forest, &trans); + if (kbest) { DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest); } else if (csplit_output_plf) { diff --git a/decoder/hg.cc b/decoder/hg.cc index 025feb7c..4da0beb3 100644 --- a/decoder/hg.cc +++ b/decoder/hg.cc @@ -6,6 +6,7 @@ #include <set> #include <map> #include <iostream> +#include <sstream> #include "viterbi.h" #include "inside_outside.h" @@ -13,6 +14,15 @@ using namespace std; +std::string Hypergraph::stats(std::string const& name) const +{ + ostringstream o; + o<<name<<" (nodes/edges): "<<nodes_.size()<<'/'<<edges_.size()<<endl; + o<<name<<" (paths): "<<NumberOfPaths()<<endl; + return o.str(); +} + + double Hypergraph::NumberOfPaths() const { return Inside<double, TransitionCountWeightFunction>(*this); } diff --git a/decoder/hg.h b/decoder/hg.h index 50c9048a..a632dc1c 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -178,6 +178,9 @@ class Hypergraph { void BeamPruneInsideOutside(const double scale, const bool use_sum_prod_semiring, const double alpha, const std::vector<bool>* preserve_mask = NULL); + // report nodes, edges, paths + std::string stats(std::string const& name="forest") const; + void clear() { nodes_.clear(); edges_.clear(); diff --git a/decoder/stringlib.h b/decoder/stringlib.h index 76efee8f..22863945 100644 --- a/decoder/stringlib.h +++ b/decoder/stringlib.h @@ -36,6 +36,13 @@ inline void Tokenize(const std::string& str, char delimiter, std::vector<std::st res->push_back(&s[last]); } +inline unsigned NTokens(const std::string& str, char delimiter) +{ + std::vector<std::string> r; + Tokenize(str,delimiter,&r); + return r.size(); +} + inline std::string LowercaseString(const std::string& in) { std::string res(in.size(),' '); for (int i = 0; i < in.size(); ++i) diff --git a/decoder/viterbi.cc b/decoder/viterbi.cc index 582dc5b2..7f52d08c 100644 --- a/decoder/viterbi.cc +++ b/decoder/viterbi.cc @@ -6,6 +6,23 @@ using namespace std; +std::string viterbi_stats(Hypergraph const& hg, std::string const& name, bool estring, bool etree) +{ + ostringstream o; + o << hg.stats(name); + if (estring) { + vector<WordID> trans; + const prob_t vs = ViterbiESentence(hg, &trans); + o<<name<<" Viterbi: "<<log(vs)<<endl; + o<<name<<" Viterbi: "<<TD::GetString(trans)<<endl; + } + if (etree) { + o<<name<<" tree: "<<ViterbiETree(hg)<<endl; + } + return o.str(); +} + + string ViterbiETree(const Hypergraph& hg) { vector<WordID> tmp; const prob_t p = Viterbi<vector<WordID>, ETreeTraversal, prob_t, EdgeProb>(hg, &tmp); diff --git a/decoder/viterbi.h b/decoder/viterbi.h index dd54752a..d4a97516 100644 --- a/decoder/viterbi.h +++ b/decoder/viterbi.h @@ -6,6 +6,8 @@ #include "hg.h" #include "tdict.h" +std::string viterbi_stats(Hypergraph const& hg, std::string const& name="forest", bool estring=true, bool etree=false); + // V must implement: // void operator()(const vector<const T*>& ants, T* result); template<typename T, typename Traversal, typename WeightType, typename WeightFunction> @@ -21,7 +23,7 @@ WeightType Viterbi(const Hypergraph& hg, const Hypergraph::Node& cur_node = hg.nodes_[i]; WeightType* const cur_node_best_weight = &vit_weight[i]; T* const cur_node_best_result = &vit_result[i]; - + const int num_in_edges = cur_node.in_edges_.size(); if (num_in_edges == 0) { *cur_node_best_weight = WeightType(1); |