diff options
author | Chris Dyer <redpony@gmail.com> | 2009-12-07 13:01:21 -0500 |
---|---|---|
committer | Chris Dyer <redpony@gmail.com> | 2009-12-07 13:01:21 -0500 |
commit | 476d09e1df52cba0be8e5f50d52bf5f32795288f (patch) | |
tree | 849b10b690bcf762aeeabf114595862742a811ca /src/cdec.cc | |
parent | ec7edcc7e398bdb040d810094b8416ad9f279d98 (diff) |
add support for generating pruned lattices when in compound splitting mode
Diffstat (limited to 'src/cdec.cc')
-rw-r--r-- | src/cdec.cc | 46 |
1 files changed, 37 insertions, 9 deletions
diff --git a/src/cdec.cc b/src/cdec.cc index 7bdf7bcc..f9634a7d 100644 --- a/src/cdec.cc +++ b/src/cdec.cc @@ -13,6 +13,7 @@ #include "aligner.h" #include "stringlib.h" #include "forest_writer.h" +#include "hg_io.h" #include "filelib.h" #include "sampler.h" #include "sparse_vector.h" @@ -47,7 +48,7 @@ void ShowBanner() { void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() - ("formalism,f",po::value<string>()->default_value("scfg"),"Translation formalism; values include SCFG, FST, PB, LexCRF (lexical translation model), CSPLIT (compound splitting)") + ("formalism,f",po::value<string>(),"Translation formalism; values include SCFG, FST, PB, LexCRF (lexical translation model), CSplit (compound splitting)") ("input,i",po::value<string>()->default_value("-"),"Source file") ("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)") ("weights,w",po::value<string>(),"Feature weights file") @@ -66,6 +67,9 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("show_tree_structure,T", "Show the Viterbi derivation structure") ("show_expected_length", "Show the expected translation length under the model") ("show_partition,z", "Compute and show the partition (inside score)") + ("beam_prune", po::value<double>(), "Prune paths from +LM forest") + ("csplit_output_plf", "(Compound splitter) Output lattice in PLF format") + ("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice") ("extract_rules", po::value<string>(), "Extract the rules used in translation (de-duped) to this file") ("graphviz","Show (constrained) translation forest in GraphViz format") ("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart") @@ -226,6 +230,19 @@ int main(int argc, char** argv) { boost::shared_ptr<Translator> translator; const string 
formalism = LowercaseString(conf["formalism"].as<string>()); + const bool csplit_preserve_full_word = conf.count("csplit_preserve_full_word"); + if (csplit_preserve_full_word && + (formalism != "csplit" || !conf.count("beam_prune"))) { + cerr << "--csplit_preserve_full_word should only be " + << "used with csplit AND --beam_prune!\n"; + exit(1); + } + const bool csplit_output_plf = conf.count("csplit_output_plf"); + if (csplit_output_plf && formalism != "csplit") { + cerr << "--csplit_output_plf should only be used with csplit!\n"; + exit(1); + } + if (formalism == "scfg") translator.reset(new SCFGTranslator(conf)); else if (formalism == "fst") @@ -239,12 +256,12 @@ int main(int argc, char** argv) { else assert(!"error"); - vector<double> wv; + vector<double> feature_weights; Weights w; if (conf.count("weights")) { w.InitFromFile(conf["weights"].as<string>()); - wv.resize(FD::NumFeats()); - w.InitVector(&wv); + feature_weights.resize(FD::NumFeats()); + w.InitVector(&feature_weights); } // set up additional scoring features @@ -255,6 +272,7 @@ int main(int argc, char** argv) { for (int i = 0; i < add_ffs.size(); ++i) { string ff, param; SplitCommandAndParam(add_ffs[i], &ff, &param); + cerr << "Feature: " << ff; if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n"; else cerr << " (no config parameters)\n"; shared_ptr<FeatureFunction> pff = global_ff_registry->Create(ff, param); @@ -264,7 +282,7 @@ int main(int argc, char** argv) { late_ffs.push_back(pff.get()); } } - ModelSet late_models(wv, late_ffs); + ModelSet late_models(feature_weights, late_ffs); const int sample_max_trans = conf.count("max_translation_sample") ? 
conf["max_translation_sample"].as<int>() : 0; @@ -321,7 +339,7 @@ int main(int argc, char** argv) { const bool hadoop_counters = (write_gradient); Hypergraph forest; // -LM forest Timer t("Translation"); - if (!translator->Translate(to_translate, &smeta, wv, &forest)) { + if (!translator->Translate(to_translate, &smeta, feature_weights, &forest)) { cerr << " NO PARSE FOUND.\n"; if (hadoop_counters) cerr << "reporter:counter:UserCounters,FParseFailed,1" << endl; @@ -351,7 +369,7 @@ int main(int argc, char** argv) { bool has_late_models = !late_models.empty(); if (has_late_models) { - forest.Reweight(wv); + forest.Reweight(feature_weights); forest.SortInEdgesByEdgeWeights(); Hypergraph lm_forest; int cubepruning_pop_limit = conf["cubepruning_pop_limit"].as<int>(); @@ -361,13 +379,21 @@ int main(int argc, char** argv) { PruningConfiguration(cubepruning_pop_limit), &lm_forest); forest.swap(lm_forest); - forest.Reweight(wv); + forest.Reweight(feature_weights); trans.clear(); ViterbiESentence(forest, &trans); cerr << " +LM forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; cerr << " +LM forest (paths): " << forest.NumberOfPaths() << endl; cerr << " +LM Viterbi: " << TD::GetString(trans) << endl; } + if (conf.count("beam_prune")) { + vector<bool> preserve_mask(forest.edges_.size(), false); + if (csplit_preserve_full_word) + preserve_mask[CompoundSplit::GetFullWordEdgeIndex(forest)] = true; + forest.BeamPruneInsideOutside(1.0, false, conf["beam_prune"].as<double>(), &preserve_mask); + cerr << " Pruned forest (paths): " << forest.NumberOfPaths() << endl; + } + if (conf.count("forest_output") && !has_ref) { ForestWriter writer(conf["forest_output"].as<string>(), sent_id); assert(writer.Write(forest, minimal_forests)); @@ -378,6 +404,8 @@ int main(int argc, char** argv) { } else { if (kbest) { DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest); + } else if (csplit_output_plf) { + cout << HypergraphIO::AsPLF(forest, false) 
<< endl; } else { if (!graphviz && !has_ref) { cout << TD::GetString(trans) << endl << flush; @@ -405,7 +433,7 @@ int main(int argc, char** argv) { if (HG::Intersect(ref, &forest)) { cerr << " Constr. forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; cerr << " Constr. forest (paths): " << forest.NumberOfPaths() << endl; - forest.Reweight(wv); + forest.Reweight(feature_weights); cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl; if (hadoop_counters) cerr << "reporter:counter:UserCounters,SentencePairsParsed,1" << endl; |