summary refs log tree commit diff
diff options
context:
space:
mode:
author Chris Dyer <redpony@gmail.com> 2009-12-07 13:01:21 -0500
committer Chris Dyer <redpony@gmail.com> 2009-12-07 13:01:21 -0500
commit 476d09e1df52cba0be8e5f50d52bf5f32795288f (patch)
tree 849b10b690bcf762aeeabf114595862742a811ca
parent ec7edcc7e398bdb040d810094b8416ad9f279d98 (diff)
add support for generating pruned lattices when in compound splitting mode
-rw-r--r-- src/apply_models.cc 2
-rw-r--r-- src/cdec.cc 46
-rw-r--r-- src/csplit.cc 23
-rw-r--r-- src/csplit.h 5
-rw-r--r-- src/ff_csplit.cc 6
-rw-r--r-- src/hg.cc 5
-rw-r--r-- src/hg_io.cc 25
7 files changed, 90 insertions, 22 deletions
diff --git a/src/apply_models.cc b/src/apply_models.cc
index 8efb331b..b1d002f4 100644
--- a/src/apply_models.cc
+++ b/src/apply_models.cc
@@ -159,7 +159,7 @@ public:
out(*o),
D(in.nodes_.size()),
pop_limit_(pop_limit) {
- cerr << " Rescoring forest (cube pruning, pop_limit = " << pop_limit_ << ')' << endl;
+ cerr << " Applying feature functions (cube pruning, pop_limit = " << pop_limit_ << ')' << endl;
}
void Apply() {
diff --git a/src/cdec.cc b/src/cdec.cc
index 7bdf7bcc..f9634a7d 100644
--- a/src/cdec.cc
+++ b/src/cdec.cc
@@ -13,6 +13,7 @@
#include "aligner.h"
#include "stringlib.h"
#include "forest_writer.h"
+#include "hg_io.h"
#include "filelib.h"
#include "sampler.h"
#include "sparse_vector.h"
@@ -47,7 +48,7 @@ void ShowBanner() {
void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
po::options_description opts("Configuration options");
opts.add_options()
- ("formalism,f",po::value<string>()->default_value("scfg"),"Translation formalism; values include SCFG, FST, PB, LexCRF (lexical translation model), CSPLIT (compound splitting)")
+ ("formalism,f",po::value<string>(),"Translation formalism; values include SCFG, FST, PB, LexCRF (lexical translation model), CSplit (compound splitting)")
("input,i",po::value<string>()->default_value("-"),"Source file")
("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)")
("weights,w",po::value<string>(),"Feature weights file")
@@ -66,6 +67,9 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("show_tree_structure,T", "Show the Viterbi derivation structure")
("show_expected_length", "Show the expected translation length under the model")
("show_partition,z", "Compute and show the partition (inside score)")
+ ("beam_prune", po::value<double>(), "Prune paths from +LM forest")
+ ("csplit_output_plf", "(Compound splitter) Output lattice in PLF format")
+ ("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice")
("extract_rules", po::value<string>(), "Extract the rules used in translation (de-duped) to this file")
("graphviz","Show (constrained) translation forest in GraphViz format")
("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart")
@@ -226,6 +230,19 @@ int main(int argc, char** argv) {
boost::shared_ptr<Translator> translator;
const string formalism = LowercaseString(conf["formalism"].as<string>());
+ const bool csplit_preserve_full_word = conf.count("csplit_preserve_full_word");
+ if (csplit_preserve_full_word &&
+ (formalism != "csplit" || !conf.count("beam_prune"))) {
+ cerr << "--csplit_preserve_full_word should only be "
+ << "used with csplit AND --beam_prune!\n";
+ exit(1);
+ }
+ const bool csplit_output_plf = conf.count("csplit_output_plf");
+ if (csplit_output_plf && formalism != "csplit") {
+ cerr << "--csplit_output_plf should only be used with csplit!\n";
+ exit(1);
+ }
+
if (formalism == "scfg")
translator.reset(new SCFGTranslator(conf));
else if (formalism == "fst")
@@ -239,12 +256,12 @@ int main(int argc, char** argv) {
else
assert(!"error");
- vector<double> wv;
+ vector<double> feature_weights;
Weights w;
if (conf.count("weights")) {
w.InitFromFile(conf["weights"].as<string>());
- wv.resize(FD::NumFeats());
- w.InitVector(&wv);
+ feature_weights.resize(FD::NumFeats());
+ w.InitVector(&feature_weights);
}
// set up additional scoring features
@@ -255,6 +272,7 @@ int main(int argc, char** argv) {
for (int i = 0; i < add_ffs.size(); ++i) {
string ff, param;
SplitCommandAndParam(add_ffs[i], &ff, &param);
+ cerr << "Feature: " << ff;
if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n";
else cerr << " (no config parameters)\n";
shared_ptr<FeatureFunction> pff = global_ff_registry->Create(ff, param);
@@ -264,7 +282,7 @@ int main(int argc, char** argv) {
late_ffs.push_back(pff.get());
}
}
- ModelSet late_models(wv, late_ffs);
+ ModelSet late_models(feature_weights, late_ffs);
const int sample_max_trans = conf.count("max_translation_sample") ?
conf["max_translation_sample"].as<int>() : 0;
@@ -321,7 +339,7 @@ int main(int argc, char** argv) {
const bool hadoop_counters = (write_gradient);
Hypergraph forest; // -LM forest
Timer t("Translation");
- if (!translator->Translate(to_translate, &smeta, wv, &forest)) {
+ if (!translator->Translate(to_translate, &smeta, feature_weights, &forest)) {
cerr << " NO PARSE FOUND.\n";
if (hadoop_counters)
cerr << "reporter:counter:UserCounters,FParseFailed,1" << endl;
@@ -351,7 +369,7 @@ int main(int argc, char** argv) {
bool has_late_models = !late_models.empty();
if (has_late_models) {
- forest.Reweight(wv);
+ forest.Reweight(feature_weights);
forest.SortInEdgesByEdgeWeights();
Hypergraph lm_forest;
int cubepruning_pop_limit = conf["cubepruning_pop_limit"].as<int>();
@@ -361,13 +379,21 @@ int main(int argc, char** argv) {
PruningConfiguration(cubepruning_pop_limit),
&lm_forest);
forest.swap(lm_forest);
- forest.Reweight(wv);
+ forest.Reweight(feature_weights);
trans.clear();
ViterbiESentence(forest, &trans);
cerr << " +LM forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
cerr << " +LM forest (paths): " << forest.NumberOfPaths() << endl;
cerr << " +LM Viterbi: " << TD::GetString(trans) << endl;
}
+ if (conf.count("beam_prune")) {
+ vector<bool> preserve_mask(forest.edges_.size(), false);
+ if (csplit_preserve_full_word)
+ preserve_mask[CompoundSplit::GetFullWordEdgeIndex(forest)] = true;
+ forest.BeamPruneInsideOutside(1.0, false, conf["beam_prune"].as<double>(), &preserve_mask);
+ cerr << " Pruned forest (paths): " << forest.NumberOfPaths() << endl;
+ }
+
if (conf.count("forest_output") && !has_ref) {
ForestWriter writer(conf["forest_output"].as<string>(), sent_id);
assert(writer.Write(forest, minimal_forests));
@@ -378,6 +404,8 @@ int main(int argc, char** argv) {
} else {
if (kbest) {
DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest);
+ } else if (csplit_output_plf) {
+ cout << HypergraphIO::AsPLF(forest, false) << endl;
} else {
if (!graphviz && !has_ref) {
cout << TD::GetString(trans) << endl << flush;
@@ -405,7 +433,7 @@ int main(int argc, char** argv) {
if (HG::Intersect(ref, &forest)) {
cerr << " Constr. forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
cerr << " Constr. forest (paths): " << forest.NumberOfPaths() << endl;
- forest.Reweight(wv);
+ forest.Reweight(feature_weights);
cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl;
if (hadoop_counters)
cerr << "reporter:counter:UserCounters,SentencePairsParsed,1" << endl;
diff --git a/src/csplit.cc b/src/csplit.cc
index 21e1b711..47197782 100644
--- a/src/csplit.cc
+++ b/src/csplit.cc
@@ -45,7 +45,8 @@ struct CompoundSplitImpl {
const int left_rule = forest->AddEdge(kWORDBREAK_RULE, Hypergraph::TailNodeVector())->id_;
forest->ConnectEdgeToHeadNode(left_rule, nodes[0]);
- const int max_split_ = chars.size() - min_size_ + 1;
+ const int max_split_ = max(static_cast<int>(chars.size()) - min_size_ + 1, 1);
+ cerr << "max: " << max_split_ << " " << " min: " << min_size_ << endl;
for (int i = min_size_; i < max_split_; ++i)
nodes[i] = forest->AddNode(kXCAT)->id_;
assert(nodes.back() == -1);
@@ -53,7 +54,8 @@ struct CompoundSplitImpl {
for (int i = 0; i < max_split_; ++i) {
if (nodes[i] < 0) continue;
- for (int j = i + min_size_; j <= chars.size(); ++j) {
+ const int start = min(i + min_size_, static_cast<int>(chars.size()));
+ for (int j = start; j <= chars.size(); ++j) {
if (nodes[j] < 0) continue;
string yield;
PasteTogetherStrings(chars, i, j, &yield);
@@ -152,3 +154,20 @@ bool CompoundSplit::Translate(const string& input,
return true;
}
+// Returns the id of node 0's out-edge with the largest j_ (earliest edge wins ties).
+int CompoundSplit::GetFullWordEdgeIndex(const Hypergraph& forest) {
+  assert(forest.nodes_.size() > 0);
+  // Bind by const reference so the out-edge list is not copied on every call.
+  const vector<int>& out_edges = forest.nodes_[0].out_edges_;
+  int max_edge = -1;
+  int max_j = -1;
+  for (size_t i = 0; i < out_edges.size(); ++i) {
+    const int j = forest.edges_[out_edges[i]].j_;
+    if (j <= max_j) continue;  // strict '>' in effect: keep the first maximal edge
+    max_j = j;
+    max_edge = out_edges[i];
+  }
+  assert(max_edge >= 0 && max_edge < static_cast<int>(forest.edges_.size()));
+  return max_edge;
+}
+
diff --git a/src/csplit.h b/src/csplit.h
index 54e5329d..ce6295c1 100644
--- a/src/csplit.h
+++ b/src/csplit.h
@@ -18,6 +18,11 @@ struct CompoundSplit : public Translator {
SentenceMetadata* smeta,
const std::vector<double>& weights,
Hypergraph* forest);
+
+ // given a forest generated by CompoundSplit::Translate,
+ // find the edge representing the unsegmented form
+ static int GetFullWordEdgeIndex(const Hypergraph& forest);
+
private:
boost::shared_ptr<CompoundSplitImpl> pimpl_;
};
diff --git a/src/ff_csplit.cc b/src/ff_csplit.cc
index eb106047..cac4bb8e 100644
--- a/src/ff_csplit.cc
+++ b/src/ff_csplit.cc
@@ -19,6 +19,8 @@ using namespace std;
struct BasicCSplitFeaturesImpl {
BasicCSplitFeaturesImpl(const string& param) :
word_count_(FD::Convert("WordCount")),
+ letters_sq_(FD::Convert("LettersSq")),
+ letters_sqrt_(FD::Convert("LettersSqrt")),
in_dict_(FD::Convert("InDict")),
short_(FD::Convert("Short")),
long_(FD::Convert("Long")),
@@ -53,6 +55,8 @@ struct BasicCSplitFeaturesImpl {
SparseVector<double>* features) const;
const int word_count_;
+ const int letters_sq_;
+ const int letters_sqrt_;
const int in_dict_;
const int short_;
const int long_;
@@ -75,6 +79,8 @@ void BasicCSplitFeaturesImpl::TraversalFeaturesImpl(
const Hypergraph::Edge& edge,
SparseVector<double>* features) const {
features->set_value(word_count_, 1.0);
+ features->set_value(letters_sq_, (edge.j_ - edge.i_) * (edge.j_ - edge.i_));
+ features->set_value(letters_sqrt_, sqrt(edge.j_ - edge.i_));
const WordID word = edge.rule_->e_[1];
const char* sword = TD::Convert(word);
const int len = strlen(sword);
diff --git a/src/hg.cc b/src/hg.cc
index dd8f8eba..7bd79394 100644
--- a/src/hg.cc
+++ b/src/hg.cc
@@ -77,6 +77,8 @@ prob_t Hypergraph::ComputeBestPathThroughEdges(vector<prob_t>* post) const {
for (int i = 0; i < in.size(); ++i)
(*post)[i] = in[i] * out[i];
+ // for (int i = 0; i < in.size(); ++i)
+ // cerr << "edge " << i << ": " << log((*post)[i]) << endl;
return ins_sco;
}
@@ -161,6 +163,7 @@ void Hypergraph::BeamPruneInsideOutside(
if (io[i] > best) best = io[i];
const prob_t aprob(exp(-alpha));
const prob_t cutoff = best * aprob;
+ // cerr << "aprob = " << aprob << "\t CUTOFF=" << cutoff << endl;
vector<bool> prune(edges_.size());
//cerr << preserve_mask.size() << " " << edges_.size() << endl;
int pc = 0;
@@ -170,7 +173,7 @@ void Hypergraph::BeamPruneInsideOutside(
prune[i] = (io[i] < cutoff);
if (preserve_mask && (*preserve_mask)[i]) prune[i] = false;
}
- cerr << "Beam pruning " << pc << "/" << io.size() << " edges\n";
+ // cerr << "Beam pruning " << pc << "/" << io.size() << " edges\n";
PruneEdges(prune);
}
diff --git a/src/hg_io.cc b/src/hg_io.cc
index beb96aa6..e21b1714 100644
--- a/src/hg_io.cc
+++ b/src/hg_io.cc
@@ -350,17 +350,24 @@ string HypergraphIO::AsPLF(const Hypergraph& hg, bool include_global_parentheses
if (include_global_parentheses) os << '(';
static const string EPS="*EPS*";
for (int i = 0; i < hg.nodes_.size()-1; ++i) {
- os << '(';
if (hg.nodes_[i].out_edges_.empty()) abort();
- for (int j = 0; j < hg.nodes_[i].out_edges_.size(); ++j) {
- const Hypergraph::Edge& e = hg.edges_[hg.nodes_[i].out_edges_[j]];
- const string output = e.rule_->e_.size() ==2 ? Escape(TD::Convert(e.rule_->e_[1])) : EPS;
- double prob = log(e.edge_prob_);
- if (isinf(prob)) { prob = -9e20; }
- if (isnan(prob)) { prob = 0; }
- os << "('" << output << "'," << prob << "," << e.head_node_ - i << "),";
+ const bool last_node = (i == hg.nodes_.size() - 2);
+ const int out_edges_size = hg.nodes_[i].out_edges_.size();
+ // compound splitter adds an extra goal transition which we suppress with
+ // the following conditional
+ if (!last_node || out_edges_size != 1 ||
+ hg.edges_[hg.nodes_[i].out_edges_[0]].rule_->EWords() == 1) {
+ os << '(';
+ for (int j = 0; j < out_edges_size; ++j) {
+ const Hypergraph::Edge& e = hg.edges_[hg.nodes_[i].out_edges_[j]];
+ const string output = e.rule_->e_.size() ==2 ? Escape(TD::Convert(e.rule_->e_[1])) : EPS;
+ double prob = log(e.edge_prob_);
+ if (isinf(prob)) { prob = -9e20; }
+ if (isnan(prob)) { prob = 0; }
+ os << "('" << output << "'," << prob << "," << e.head_node_ - i << "),";
+ }
+ os << "),";
}
- os << "),";
}
if (include_global_parentheses) os << ')';
return os.str();