summaryrefslogtreecommitdiff
path: root/src/cdec.cc
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2009-12-07 13:01:21 -0500
committerChris Dyer <redpony@gmail.com>2009-12-07 13:01:21 -0500
commit476d09e1df52cba0be8e5f50d52bf5f32795288f (patch)
tree849b10b690bcf762aeeabf114595862742a811ca /src/cdec.cc
parentec7edcc7e398bdb040d810094b8416ad9f279d98 (diff)
add support for generating pruned lattices when in compound splitting mode
Diffstat (limited to 'src/cdec.cc')
-rw-r--r--src/cdec.cc46
1 files changed, 37 insertions, 9 deletions
diff --git a/src/cdec.cc b/src/cdec.cc
index 7bdf7bcc..f9634a7d 100644
--- a/src/cdec.cc
+++ b/src/cdec.cc
@@ -13,6 +13,7 @@
#include "aligner.h"
#include "stringlib.h"
#include "forest_writer.h"
+#include "hg_io.h"
#include "filelib.h"
#include "sampler.h"
#include "sparse_vector.h"
@@ -47,7 +48,7 @@ void ShowBanner() {
void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
po::options_description opts("Configuration options");
opts.add_options()
- ("formalism,f",po::value<string>()->default_value("scfg"),"Translation formalism; values include SCFG, FST, PB, LexCRF (lexical translation model), CSPLIT (compound splitting)")
+ ("formalism,f",po::value<string>(),"Translation formalism; values include SCFG, FST, PB, LexCRF (lexical translation model), CSplit (compound splitting)")
("input,i",po::value<string>()->default_value("-"),"Source file")
("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)")
("weights,w",po::value<string>(),"Feature weights file")
@@ -66,6 +67,9 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("show_tree_structure,T", "Show the Viterbi derivation structure")
("show_expected_length", "Show the expected translation length under the model")
("show_partition,z", "Compute and show the partition (inside score)")
+ ("beam_prune", po::value<double>(), "Prune paths from +LM forest")
+ ("csplit_output_plf", "(Compound splitter) Output lattice in PLF format")
+ ("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice")
("extract_rules", po::value<string>(), "Extract the rules used in translation (de-duped) to this file")
("graphviz","Show (constrained) translation forest in GraphViz format")
("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart")
@@ -226,6 +230,19 @@ int main(int argc, char** argv) {
boost::shared_ptr<Translator> translator;
const string formalism = LowercaseString(conf["formalism"].as<string>());
+ const bool csplit_preserve_full_word = conf.count("csplit_preserve_full_word");
+ if (csplit_preserve_full_word &&
+ (formalism != "csplit" || !conf.count("beam_prune"))) {
+ cerr << "--csplit_preserve_full_word should only be "
+ << "used with csplit AND --beam_prune!\n";
+ exit(1);
+ }
+ const bool csplit_output_plf = conf.count("csplit_output_plf");
+ if (csplit_output_plf && formalism != "csplit") {
+ cerr << "--csplit_output_plf should only be used with csplit!\n";
+ exit(1);
+ }
+
if (formalism == "scfg")
translator.reset(new SCFGTranslator(conf));
else if (formalism == "fst")
@@ -239,12 +256,12 @@ int main(int argc, char** argv) {
else
assert(!"error");
- vector<double> wv;
+ vector<double> feature_weights;
Weights w;
if (conf.count("weights")) {
w.InitFromFile(conf["weights"].as<string>());
- wv.resize(FD::NumFeats());
- w.InitVector(&wv);
+ feature_weights.resize(FD::NumFeats());
+ w.InitVector(&feature_weights);
}
// set up additional scoring features
@@ -255,6 +272,7 @@ int main(int argc, char** argv) {
for (int i = 0; i < add_ffs.size(); ++i) {
string ff, param;
SplitCommandAndParam(add_ffs[i], &ff, &param);
+ cerr << "Feature: " << ff;
if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n";
else cerr << " (no config parameters)\n";
shared_ptr<FeatureFunction> pff = global_ff_registry->Create(ff, param);
@@ -264,7 +282,7 @@ int main(int argc, char** argv) {
late_ffs.push_back(pff.get());
}
}
- ModelSet late_models(wv, late_ffs);
+ ModelSet late_models(feature_weights, late_ffs);
const int sample_max_trans = conf.count("max_translation_sample") ?
conf["max_translation_sample"].as<int>() : 0;
@@ -321,7 +339,7 @@ int main(int argc, char** argv) {
const bool hadoop_counters = (write_gradient);
Hypergraph forest; // -LM forest
Timer t("Translation");
- if (!translator->Translate(to_translate, &smeta, wv, &forest)) {
+ if (!translator->Translate(to_translate, &smeta, feature_weights, &forest)) {
cerr << " NO PARSE FOUND.\n";
if (hadoop_counters)
cerr << "reporter:counter:UserCounters,FParseFailed,1" << endl;
@@ -351,7 +369,7 @@ int main(int argc, char** argv) {
bool has_late_models = !late_models.empty();
if (has_late_models) {
- forest.Reweight(wv);
+ forest.Reweight(feature_weights);
forest.SortInEdgesByEdgeWeights();
Hypergraph lm_forest;
int cubepruning_pop_limit = conf["cubepruning_pop_limit"].as<int>();
@@ -361,13 +379,21 @@ int main(int argc, char** argv) {
PruningConfiguration(cubepruning_pop_limit),
&lm_forest);
forest.swap(lm_forest);
- forest.Reweight(wv);
+ forest.Reweight(feature_weights);
trans.clear();
ViterbiESentence(forest, &trans);
cerr << " +LM forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
cerr << " +LM forest (paths): " << forest.NumberOfPaths() << endl;
cerr << " +LM Viterbi: " << TD::GetString(trans) << endl;
}
+ if (conf.count("beam_prune")) {
+ vector<bool> preserve_mask(forest.edges_.size(), false);
+ if (csplit_preserve_full_word)
+ preserve_mask[CompoundSplit::GetFullWordEdgeIndex(forest)] = true;
+ forest.BeamPruneInsideOutside(1.0, false, conf["beam_prune"].as<double>(), &preserve_mask);
+ cerr << " Pruned forest (paths): " << forest.NumberOfPaths() << endl;
+ }
+
if (conf.count("forest_output") && !has_ref) {
ForestWriter writer(conf["forest_output"].as<string>(), sent_id);
assert(writer.Write(forest, minimal_forests));
@@ -378,6 +404,8 @@ int main(int argc, char** argv) {
} else {
if (kbest) {
DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest);
+ } else if (csplit_output_plf) {
+ cout << HypergraphIO::AsPLF(forest, false) << endl;
} else {
if (!graphviz && !has_ref) {
cout << TD::GetString(trans) << endl << flush;
@@ -405,7 +433,7 @@ int main(int argc, char** argv) {
if (HG::Intersect(ref, &forest)) {
cerr << " Constr. forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
cerr << " Constr. forest (paths): " << forest.NumberOfPaths() << endl;
- forest.Reweight(wv);
+ forest.Reweight(feature_weights);
cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl;
if (hadoop_counters)
cerr << "reporter:counter:UserCounters,SentencePairsParsed,1" << endl;