From 70114f58324ff60de3d77a42d299e3bcf2c37450 Mon Sep 17 00:00:00 2001 From: graehl Date: Thu, 1 Jul 2010 23:08:34 +0000 Subject: factor forest stats (show size post pruning, and portion kept) and cdec --scale_prune_srclen git-svn-id: https://ws10smt.googlecode.com/svn/trunk@96 ec762483-ff6d-05da-a07a-a48fb63a330f --- decoder/cdec.cc | 50 ++++++++++++++++++++++++++++++-------------------- decoder/hg.cc | 10 ++++++++++ decoder/hg.h | 3 +++ decoder/stringlib.h | 7 +++++++ decoder/viterbi.cc | 17 +++++++++++++++++ decoder/viterbi.h | 4 +++- 6 files changed, 70 insertions(+), 21 deletions(-) diff --git a/decoder/cdec.cc b/decoder/cdec.cc index 5d0ac8b2..0a4593ef 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -82,6 +82,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("show_cfg_search_space", "Show the search space as a CFG") ("prelm_beam_prune", po::value(), "Prune paths from -LM forest before LM rescoring, keeping paths within exp(alpha>=0)") ("beam_prune", po::value(), "Prune paths from +LM forest, keep paths within exp(alpha>=0)") + ("scale_prune_srclen", "scale beams by the input length (in # of tokens; may not be what you want for lattices") ("lexalign_use_null", "Support source-side null words in lexical translation") ("tagger_tagset,t", po::value(), "(Tagger) file containing tag set") ("csplit_output_plf", "(Compound splitter) Output lattice in PLF format") @@ -234,6 +235,15 @@ static void ExtractRulesDedupe(const Hypergraph& hg, ostream* os) { void register_feature_functions(); +bool beam_param(po::variables_map const& conf,char const* name,double *val,bool scale_srclen=false,double srclen=1) +{ + if (conf.count(name)) { + *val=conf[name].as()*(scale_srclen?srclen:1); + return true; + } + return false; +} + int main(int argc, char** argv) { global_ff_registry.reset(new FFRegistry); register_feature_functions(); @@ -257,6 +267,7 @@ int main(int argc, char** argv) { << "used with csplit AND --beam_prune!\n"; exit(1); } + const bool scale_prune_srclen=conf.count("scale_prune_srclen"); const bool csplit_output_plf = conf.count("csplit_output_plf"); if (csplit_output_plf && formalism != "csplit") { cerr << "--csplit_output_plf should only be used with csplit!\n"; @@ -373,6 +384,8 @@ int main(int argc, char** argv) { string to_translate; Lattice ref; ParseTranslatorInputLattice(buf, &to_translate, &ref); + const unsigned srclen=NTokens(to_translate,' '); +//FIXME: should get the avg. or max source length of the input lattice (like Lattice::dist_(start,end)); but this is only used to scale beam parameters (optionally) anyway so fidelity isn't important. const bool has_ref = ref.size() > 0; SentenceMetadata smeta(sent_id, ref); const bool hadoop_counters = (write_gradient); @@ -389,8 +402,8 @@ int main(int argc, char** argv) { cout << endl << flush; continue; } - cerr << " -LM forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; - cerr << " -LM forest (paths): " << forest.NumberOfPaths() << endl; + const bool show_tree_structure=conf.count("show_tree_structure"); + cerr << viterbi_stats(forest," -LM forest",true,show_tree_structure); if (conf.count("show_expected_length")) { const PRPair res = Inside, @@ -403,16 +416,13 @@ int main(int argc, char** argv) { } if (extract_file) ExtractRulesDedupe(forest, extract_file->stream()); - vector trans; - const prob_t vs = ViterbiESentence(forest, &trans); - cerr << " -LM Viterbi: " << TD::GetString(trans) << endl; - if (conf.count("show_tree_structure")) - cerr << " -LM tree: " << ViterbiETree(forest) << endl;; - cerr << " -LM Viterbi: " << log(vs) << endl; - - if (conf.count("prelm_beam_prune")) { - forest.BeamPruneInsideOutside(1.0, false, conf["prelm_beam_prune"].as(), NULL); - cerr << " Pruned -LM forest (paths): " << forest.NumberOfPaths() << endl; + + double prelm_beam_prune; + if (beam_param(conf,"prelm_beam_prune",&prelm_beam_prune,scale_prune_srclen,srclen)) { + double presize=forest.edges_.size(); + forest.BeamPruneInsideOutside(1.0, false, prelm_beam_prune, NULL); + cerr << viterbi_stats(forest," Pruned -LM forest",false,false); + cerr << " Pruned -LM forest (beam="< preserve_mask(forest.edges_.size(), false); if (csplit_preserve_full_word) preserve_mask[CompoundSplit::GetFullWordEdgeIndex(forest)] = true; - forest.BeamPruneInsideOutside(1.0, false, conf["beam_prune"].as(), &preserve_mask); - cerr << " Pruned forest (paths): " << forest.NumberOfPaths() << endl; + forest.BeamPruneInsideOutside(1.0, false, beam_prune, &preserve_mask); + cerr << viterbi_stats(forest," Pruned forest",false,false); } if (conf.count("forest_output") && !has_ref) { @@ -464,6 +471,9 @@ int main(int argc, char** argv) { if (sample_max_trans) { MaxTranslationSample(&forest, sample_max_trans, conf.count("k_best") ? conf["k_best"].as() : 0); } else { + vector trans; + ViterbiESentence(forest, &trans); + if (kbest) { DumpKBest(sent_id, forest, conf["k_best"].as(), unique_kbest); } else if (csplit_output_plf) { diff --git a/decoder/hg.cc b/decoder/hg.cc index 025feb7c..4da0beb3 100644 --- a/decoder/hg.cc +++ b/decoder/hg.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include "viterbi.h" #include "inside_outside.h" @@ -13,6 +14,15 @@ using namespace std; +std::string Hypergraph::stats(std::string const& name) const +{ + ostringstream o; + o<(*this); } diff --git a/decoder/hg.h b/decoder/hg.h index 50c9048a..a632dc1c 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -178,6 +178,9 @@ class Hypergraph { void BeamPruneInsideOutside(const double scale, const bool use_sum_prod_semiring, const double alpha, const std::vector* preserve_mask = NULL); + // report nodes, edges, paths + std::string stats(std::string const& name="forest") const; + void clear() { nodes_.clear(); edges_.clear(); diff --git a/decoder/stringlib.h b/decoder/stringlib.h index 76efee8f..22863945 100644 --- a/decoder/stringlib.h +++ b/decoder/stringlib.h @@ -36,6 +36,13 @@ inline void Tokenize(const std::string& str, char delimiter, std::vectorpush_back(&s[last]); } +inline unsigned NTokens(const std::string& str, char delimiter) +{ + std::vector r; + Tokenize(str,delimiter,&r); + return r.size(); +} + inline std::string LowercaseString(const std::string& in) { std::string res(in.size(),' '); for (int i = 0; i < in.size(); ++i) diff --git a/decoder/viterbi.cc b/decoder/viterbi.cc index 582dc5b2..7f52d08c 100644 --- a/decoder/viterbi.cc +++ b/decoder/viterbi.cc @@ -6,6 +6,23 @@ using namespace std; +std::string viterbi_stats(Hypergraph const& hg, std::string const& name, bool estring, bool etree) +{ + ostringstream o; + o << hg.stats(name); + if (estring) { + vector trans; + const prob_t vs = ViterbiESentence(hg, &trans); + o< tmp; const prob_t p = Viterbi, ETreeTraversal, prob_t, EdgeProb>(hg, &tmp); diff --git a/decoder/viterbi.h b/decoder/viterbi.h index dd54752a..d4a97516 100644 --- a/decoder/viterbi.h +++ b/decoder/viterbi.h @@ -6,6 +6,8 @@ #include "hg.h" #include "tdict.h" +std::string viterbi_stats(Hypergraph const& hg, std::string const& name="forest", bool estring=true, bool etree=false); + // V must implement: // void operator()(const vector& ants, T* result); template @@ -21,7 +23,7 @@ WeightType Viterbi(const Hypergraph& hg, const Hypergraph::Node& cur_node = hg.nodes_[i]; WeightType* const cur_node_best_weight = &vit_weight[i]; T* const cur_node_best_result = &vit_result[i]; - + const int num_in_edges = cur_node.in_edges_.size(); if (num_in_edges == 0) { *cur_node_best_weight = WeightType(1); -- cgit v1.2.3