summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2010-12-22 13:15:42 -0600
committerChris Dyer <cdyer@cs.cmu.edu>2010-12-22 13:15:42 -0600
commit129832e6d12b4c6e54189bdc030a6a31cccbba5c (patch)
treeb0c87af3f29455cd3aa7cd97afd2142346632d4e /decoder
parentb5ca2bd7001a385594af8dc4b9206399c679f8c5 (diff)
fix compound splitter, new features, more training data
Diffstat (limited to 'decoder')
-rw-r--r--decoder/apply_models.cc1
-rw-r--r--decoder/csplit.cc140
-rw-r--r--decoder/decoder.cc25
-rw-r--r--decoder/ff_csplit.cc21
4 files changed, 100 insertions, 87 deletions
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc
index 18460950..9390c809 100644
--- a/decoder/apply_models.cc
+++ b/decoder/apply_models.cc
@@ -177,6 +177,7 @@ public:
void Apply() {
int num_nodes = in.nodes_.size();
+ assert(num_nodes >= 2);
int goal_id = num_nodes - 1;
int pregoal = goal_id - 1;
int every = 1;
diff --git a/decoder/csplit.cc b/decoder/csplit.cc
index 7d50e3af..4a723822 100644
--- a/decoder/csplit.cc
+++ b/decoder/csplit.cc
@@ -13,14 +13,16 @@ using namespace std;
struct CompoundSplitImpl {
CompoundSplitImpl(const boost::program_options::variables_map& conf) :
- fugen_elements_(true), // TODO configure
+ fugen_elements_(true),
min_size_(3),
kXCAT(TD::Convert("X")*-1),
kWORDBREAK_RULE(new TRule("[X] ||| # ||| #")),
kTEMPLATE_RULE(new TRule("[X] ||| [X,1] ? ||| [1] ?")),
kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [1]")),
kFUGEN_S(FD::Convert("FugS")),
- kFUGEN_N(FD::Convert("FugN")) {}
+ kFUGEN_N(FD::Convert("FugN")) {
+ // TODO: use conf to turn fugenelements on and off
+ }
void PasteTogetherStrings(const vector<string>& chars,
const int i,
@@ -40,73 +42,73 @@ struct CompoundSplitImpl {
void BuildTrellis(const vector<string>& chars,
Hypergraph* forest) {
-// vector<int> nodes(chars.size()+1, -1);
-// nodes[0] = forest->AddNode(kXCAT)->id_; // source
-// const int left_rule = forest->AddEdge(kWORDBREAK_RULE, Hypergraph::TailNodeVector())->id_;
-// forest->ConnectEdgeToHeadNode(left_rule, nodes[0]);
-//
-// const int max_split_ = max(static_cast<int>(chars.size()) - min_size_ + 1, 1);
-// cerr << "max: " << max_split_ << " " << " min: " << min_size_ << endl;
-// for (int i = min_size_; i < max_split_; ++i)
-// nodes[i] = forest->AddNode(kXCAT)->id_;
-// assert(nodes.back() == -1);
-// nodes.back() = forest->AddNode(kXCAT)->id_; // sink
-//
-// for (int i = 0; i < max_split_; ++i) {
-// if (nodes[i] < 0) continue;
-// const int start = min(i + min_size_, static_cast<int>(chars.size()));
-// for (int j = start; j <= chars.size(); ++j) {
-// if (nodes[j] < 0) continue;
-// string yield;
-// PasteTogetherStrings(chars, i, j, &yield);
-// // cerr << "[" << i << "," << j << "] " << yield << endl;
-// TRulePtr rule = TRulePtr(new TRule(*kTEMPLATE_RULE));
-// rule->e_[1] = rule->f_[1] = TD::Convert(yield);
-// // cerr << rule->AsString() << endl;
-// int edge = forest->AddEdge(
-// rule,
-// Hypergraph::TailNodeVector(1, nodes[i]))->id_;
-// forest->ConnectEdgeToHeadNode(edge, nodes[j]);
-// forest->edges_[edge].i_ = i;
-// forest->edges_[edge].j_ = j;
-//
-// // handle "fugenelemente" here
-// // don't delete "fugenelemente" at the end of words
-// if (fugen_elements_ && j != chars.size()) {
-// const int len = yield.size();
-// string alt;
-// int fid = 0;
-// if (len > (min_size_ + 2) && yield[len-1] == 's' && yield[len-2] == 'e') {
-// alt = yield.substr(0, len - 2);
-// fid = kFUGEN_S;
-// } else if (len > (min_size_ + 1) && yield[len-1] == 's') {
-// alt = yield.substr(0, len - 1);
-// fid = kFUGEN_S;
-// } else if (len > (min_size_ + 2) && yield[len-2] == 'e' && yield[len-1] == 'n') {
-// alt = yield.substr(0, len - 1);
-// fid = kFUGEN_N;
-// }
-// if (alt.size()) {
-// TRulePtr altrule = TRulePtr(new TRule(*rule));
-// altrule->e_[1] = TD::Convert(alt);
-// // cerr << altrule->AsString() << endl;
-// int edge = forest->AddEdge(
-// altrule,
-// Hypergraph::TailNodeVector(1, nodes[i]))->id_;
-// forest->ConnectEdgeToHeadNode(edge, nodes[j]);
-// forest->edges_[edge].feature_values_.set_value(fid, 1.0);
-// forest->edges_[edge].i_ = i;
-// forest->edges_[edge].j_ = j;
-// }
-// }
-// }
-// }
-//
-// // add goal rule
-// Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1);
-// Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1);
-// Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail);
-// forest->ConnectEdgeToHeadNode(hg_edge, goal);
+ vector<int> nodes(chars.size()+1, -1);
+ nodes[0] = forest->AddNode(kXCAT)->id_; // source
+ const int left_rule = forest->AddEdge(kWORDBREAK_RULE, Hypergraph::TailNodeVector())->id_;
+ forest->ConnectEdgeToHeadNode(left_rule, nodes[0]);
+
+ const int max_split_ = max(static_cast<int>(chars.size()) - min_size_ + 1, 1);
+ // cerr << "max: " << max_split_ << " " << " min: " << min_size_ << endl;
+ for (int i = min_size_; i < max_split_; ++i)
+ nodes[i] = forest->AddNode(kXCAT)->id_;
+ assert(nodes.back() == -1);
+ nodes.back() = forest->AddNode(kXCAT)->id_; // sink
+
+ for (int i = 0; i < max_split_; ++i) {
+ if (nodes[i] < 0) continue;
+ const int start = min(i + min_size_, static_cast<int>(chars.size()));
+ for (int j = start; j <= chars.size(); ++j) {
+ if (nodes[j] < 0) continue;
+ string yield;
+ PasteTogetherStrings(chars, i, j, &yield);
+ // cerr << "[" << i << "," << j << "] " << yield << endl;
+ TRulePtr rule = TRulePtr(new TRule(*kTEMPLATE_RULE));
+ rule->e_[1] = rule->f_[1] = TD::Convert(yield);
+ // cerr << rule->AsString() << endl;
+ int edge = forest->AddEdge(
+ rule,
+ Hypergraph::TailNodeVector(1, nodes[i]))->id_;
+ forest->ConnectEdgeToHeadNode(edge, nodes[j]);
+ forest->edges_[edge].i_ = i;
+ forest->edges_[edge].j_ = j;
+
+ // handle "fugenelemente" here
+ // don't delete "fugenelemente" at the end of words
+ if (fugen_elements_ && j != chars.size()) {
+ const int len = yield.size();
+ string alt;
+ int fid = 0;
+ if (len > (min_size_ + 2) && yield[len-1] == 's' && yield[len-2] == 'e') {
+ alt = yield.substr(0, len - 2);
+ fid = kFUGEN_S;
+ } else if (len > (min_size_ + 1) && yield[len-1] == 's') {
+ alt = yield.substr(0, len - 1);
+ fid = kFUGEN_S;
+ } else if (len > (min_size_ + 2) && yield[len-2] == 'e' && yield[len-1] == 'n') {
+ alt = yield.substr(0, len - 1);
+ fid = kFUGEN_N;
+ }
+ if (alt.size()) {
+ TRulePtr altrule = TRulePtr(new TRule(*rule));
+ altrule->e_[1] = TD::Convert(alt);
+ // cerr << altrule->AsString() << endl;
+ int edge = forest->AddEdge(
+ altrule,
+ Hypergraph::TailNodeVector(1, nodes[i]))->id_;
+ forest->ConnectEdgeToHeadNode(edge, nodes[j]);
+ forest->edges_[edge].feature_values_.set_value(fid, 1.0);
+ forest->edges_[edge].i_ = i;
+ forest->edges_[edge].j_ = j;
+ }
+ }
+ }
+ }
+
+ // add goal rule
+ Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1);
+ Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1);
+ Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail);
+ forest->ConnectEdgeToHeadNode(hg_edge, goal);
}
private:
const bool fugen_elements_;
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 3551b584..e28080aa 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -279,7 +279,6 @@ struct DecoderImpl {
bool encode_b64;
bool kbest;
bool unique_kbest;
- bool crf_uniform_empirical;
bool get_oracle_forest;
shared_ptr<WriteFile> extract_file;
int combine_size;
@@ -379,7 +378,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("max_translation_sample,X", po::value<int>(), "Sample the max translation from the chart")
("pb_max_distortion,D", po::value<int>()->default_value(4), "Phrase-based decoder: maximum distortion")
("cll_gradient,G","Compute conditional log-likelihood gradient and write to STDOUT (src & ref required)")
- ("crf_uniform_empirical", "If there are multple references use (i.e., lattice) a uniform distribution rather than posterior weighting a la EM")
("get_oracle_forest,o", "Calculate rescored hypregraph using approximate BLEU scoring of rules")
("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)")
("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")
@@ -611,7 +609,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
encode_b64 = str("vector_format",conf) == "b64";
kbest = conf.count("k_best");
unique_kbest = conf.count("unique_k_best");
- crf_uniform_empirical = conf.count("crf_uniform_empirical");
get_oracle_forest = conf.count("get_oracle_forest");
cfg_options.Validate();
@@ -842,14 +839,12 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
if (has_ref) {
if (HG::Intersect(ref, &forest)) {
if (!SILENT) forest_stats(forest," Constr. forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation);
- if (crf_uniform_empirical) {
- if (!SILENT) cerr << " USING UNIFORM WEIGHTS\n";
- for (int i = 0; i < forest.edges_.size(); ++i)
- forest.edges_[i].edge_prob_=prob_t::One();
- } else {
- forest.Reweight(feature_weights);
- if (!SILENT) cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl;
- }
+// if (crf_uniform_empirical) {
+// if (!SILENT) cerr << " USING UNIFORM WEIGHTS\n";
+// for (int i = 0; i < forest.edges_.size(); ++i)
+// forest.edges_[i].edge_prob_=prob_t::One(); }
+ forest.Reweight(feature_weights);
+ if (!SILENT) cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl;
if (conf.count("show_partition")) {
const prob_t z = Inside<prob_t, EdgeProb>(forest);
cerr << " Contst. partition log(Z): " << log(z) << endl;
@@ -878,11 +873,9 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
if (write_gradient) {
const prob_t ref_z = InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, EdgeFeaturesAndProbWeightFunction>(forest, &ref_exp);
ref_exp /= ref_z;
- if (crf_uniform_empirical) {
- log_ref_z = ref_exp.dot(feature_weights);
- } else {
- log_ref_z = log(ref_z);
- }
+// if (crf_uniform_empirical)
+// log_ref_z = ref_exp.dot(feature_weights);
+ log_ref_z = log(ref_z);
//cerr << " MODEL LOG Z: " << log_z << endl;
//cerr << " EMPIRICAL LOG Z: " << log_ref_z << endl;
if ((log_z - log_ref_z) < kMINUS_EPSILON) {
diff --git a/decoder/ff_csplit.cc b/decoder/ff_csplit.cc
index 1485009b..204b7ce6 100644
--- a/decoder/ff_csplit.cc
+++ b/decoder/ff_csplit.cc
@@ -22,9 +22,11 @@ struct BasicCSplitFeaturesImpl {
letters_sq_(FD::Convert("LettersSq")),
letters_sqrt_(FD::Convert("LettersSqrt")),
in_dict_(FD::Convert("InDict")),
+ in_dict_sub_word_(FD::Convert("InDictSubWord")),
short_(FD::Convert("Short")),
long_(FD::Convert("Long")),
oov_(FD::Convert("OOV")),
+ oov_sub_word_(FD::Convert("OOVSubWord")),
short_range_(FD::Convert("ShortRange")),
high_freq_(FD::Convert("HighFreq")),
med_freq_(FD::Convert("MedFreq")),
@@ -52,15 +54,18 @@ struct BasicCSplitFeaturesImpl {
}
void TraversalFeaturesImpl(const Hypergraph::Edge& edge,
+ const int src_word_size,
SparseVector<double>* features) const;
const int word_count_;
const int letters_sq_;
const int letters_sqrt_;
const int in_dict_;
+ const int in_dict_sub_word_;
const int short_;
const int long_;
const int oov_;
+ const int oov_sub_word_;
const int short_range_;
const int high_freq_;
const int med_freq_;
@@ -77,7 +82,9 @@ BasicCSplitFeatures::BasicCSplitFeatures(const string& param) :
void BasicCSplitFeaturesImpl::TraversalFeaturesImpl(
const Hypergraph::Edge& edge,
+ const int src_word_length,
SparseVector<double>* features) const {
+ const bool subword = (edge.i_ > 0) || (edge.j_ < src_word_length);
features->set_value(word_count_, 1.0);
features->set_value(letters_sq_, (edge.j_ - edge.i_) * (edge.j_ - edge.i_));
features->set_value(letters_sqrt_, sqrt(edge.j_ - edge.i_));
@@ -108,8 +115,10 @@ void BasicCSplitFeaturesImpl::TraversalFeaturesImpl(
if (freq) {
features->set_value(freq_, freq);
features->set_value(in_dict_, 1.0);
+ if (subword) features->set_value(in_dict_sub_word_, 1.0);
} else {
features->set_value(oov_, 1.0);
+ if (subword) features->set_value(oov_sub_word_, 1.0);
freq = 99.0f;
}
if (bad_words_.count(word) != 0)
@@ -143,7 +152,7 @@ void BasicCSplitFeatures::TraversalFeaturesImpl(
(void) estimated_features;
if (edge.Arity() == 0) return;
if (edge.rule_->EWords() != 1) return;
- pimpl_->TraversalFeaturesImpl(edge, features);
+ pimpl_->TraversalFeaturesImpl(edge, smeta.GetSourceLattice().size(), features);
}
struct ReverseCharLMCSplitFeatureImpl {
@@ -208,9 +217,17 @@ void ReverseCharLMCSplitFeature::TraversalFeaturesImpl(
if (edge.rule_->EWords() != 1) return;
const double lpp = pimpl_->LeftPhonotacticProb(smeta.GetSourceLattice(), edge.i_);
features->set_value(fid_, lpp);
+#if 0
WordID neighbor_word = 0;
const WordID word = edge.rule_->e_[1];
-#if 0
+ const char* sword = TD::Convert(word);
+ const int len = strlen(sword);
+ int cur = 0;
+ int chars = 0;
+ while(cur < len) {
+ cur += UTF8Len(sword[cur]);
+ ++chars;
+ }
if (chars > 4 && (sword[0] == 's' || sword[0] == 'n')) {
neighbor_word = TD::Convert(string(&sword[1]));
}