summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
Diffstat (limited to 'decoder')
-rw-r--r--decoder/cdec_ff.cc3
-rw-r--r--decoder/ff_wordalign.cc256
-rw-r--r--decoder/ff_wordalign.h55
-rw-r--r--decoder/lextrans.cc32
-rw-r--r--decoder/trule.cc20
5 files changed, 348 insertions, 18 deletions
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index 3953118c..d6cf4572 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -51,6 +51,8 @@ void register_feature_functions() {
ff_registry.Register("RuleShape", new FFFactory<RuleShapeFeatures>);
ff_registry.Register("RelativeSentencePosition", new FFFactory<RelativeSentencePosition>);
ff_registry.Register("Model2BinaryFeatures", new FFFactory<Model2BinaryFeatures>);
+ ff_registry.Register("LexNullJump", new FFFactory<LexNullJump>);
+ ff_registry.Register("NewJump", new FFFactory<NewJump>);
ff_registry.Register("MarkovJump", new FFFactory<MarkovJump>);
ff_registry.Register("MarkovJumpFClass", new FFFactory<MarkovJumpFClass>);
ff_registry.Register("SourceBigram", new FFFactory<SourceBigram>);
@@ -64,6 +66,7 @@ void register_feature_functions() {
ff_registry.Register("OutputIdentity", new FFFactory<OutputIdentity>);
ff_registry.Register("InputIdentity", new FFFactory<InputIdentity>);
ff_registry.Register("LexicalTranslationTrigger", new FFFactory<LexicalTranslationTrigger>);
+ ff_registry.Register("WordPairFeatures", new FFFactory<WordPairFeatures>);
ff_registry.Register("WordSet", new FFFactory<WordSet>);
#ifdef HAVE_GLC
ff_registry.Register("ContextCRF", new FFFactory<Model1Features>);
diff --git a/decoder/ff_wordalign.cc b/decoder/ff_wordalign.cc
index 5f42b438..980c64ad 100644
--- a/decoder/ff_wordalign.cc
+++ b/decoder/ff_wordalign.cc
@@ -1,10 +1,13 @@
#include "ff_wordalign.h"
+#include <algorithm>
+#include <iterator>
#include <set>
#include <sstream>
#include <string>
#include <cmath>
+#include "verbose.h"
#include "alignment_pharaoh.h"
#include "stringlib.h"
#include "sentence_metadata.h"
@@ -20,6 +23,8 @@ static const int kNULL_i = 255; // -1 as an unsigned char
using namespace std;
+// TODO new feature: if a word is translated as itself and there is a transition back to the same word, fire a feature
+
Model2BinaryFeatures::Model2BinaryFeatures(const string& ) :
fids_(boost::extents[MAX_SENTENCE_SIZE][MAX_SENTENCE_SIZE][MAX_SENTENCE_SIZE]) {
for (int i = 1; i < MAX_SENTENCE_SIZE; ++i) {
@@ -195,6 +200,45 @@ void MarkovJumpFClass::TraversalFeaturesImpl(const SentenceMetadata& smeta,
}
}
+LexNullJump::LexNullJump(const string& param) :
+ FeatureFunction(1),
+ fid_lex_null_(FD::Convert("JumpLexNull")),
+ fid_null_lex_(FD::Convert("JumpNullLex")),
+ fid_null_null_(FD::Convert("JumpNullNull")),
+ fid_lex_lex_(FD::Convert("JumpLexLex")) {}
+
+void LexNullJump::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_states,
+ SparseVector<double>* features,
+ SparseVector<double>* /* estimated_features */,
+ void* state) const {
+ char& dpstate = *((char*)state);
+ if (edge.Arity() == 0) {
+ // dpstate is 'N' = null or 'L' = lex
+ if (edge.i_ < 0) { dpstate = 'N'; } else { dpstate = 'L'; }
+ } else if (edge.Arity() == 1) {
+ dpstate = *((unsigned char*)ant_states[0]);
+ } else if (edge.Arity() == 2) {
+ char left = *((char*)ant_states[0]);
+ char right = *((char*)ant_states[1]);
+ dpstate = right;
+ if (left == 'N') {
+ if (right == 'N')
+ features->set_value(fid_null_null_, 1.0);
+ else
+ features->set_value(fid_null_lex_, 1.0);
+ } else { // left == 'L'
+ if (right == 'N')
+ features->set_value(fid_lex_null_, 1.0);
+ else
+ features->set_value(fid_lex_lex_, 1.0);
+ }
+ } else {
+ assert(!"something really unexpected is happening");
+ }
+}
+
MarkovJump::MarkovJump(const string& param) :
FeatureFunction(1),
fid_(FD::Convert("MarkovJump")),
@@ -287,6 +331,100 @@ void MarkovJump::TraversalFeaturesImpl(const SentenceMetadata& smeta,
}
}
+NewJump::NewJump(const string& param) :
+ FeatureFunction(1) {
+ cerr << " NewJump";
+ vector<string> argv;
+ int argc = SplitOnWhitespace(param, &argv);
+ set<string> config;
+ for (int i = 0; i < argc; ++i) config.insert(argv[i]);
+ cerr << endl;
+ use_binned_log_lengths_ = config.count("use_binned_log_lengths") > 0;
+}
+
+// do a log transform on the length (of a sentence, a jump, etc)
+// this basically means that large distances that are close to each other
+// are put into the same bin
+int BinnedLogLength(int len) {
+ int res = static_cast<int>(log(len+1) / log(1.3));
+ if (res > 16) res = 16;
+ return res;
+}
+
+void NewJump::FireFeature(const SentenceMetadata& smeta,
+ const int prev_src_index,
+ const int cur_src_index,
+ SparseVector<double>* features) const {
+ const int src_len = smeta.GetSourceLength();
+ const int raw_jump = cur_src_index - prev_src_index;
+ char jtype = 0;
+ int jump_magnitude = raw_jump;
+ if (raw_jump > 0) { jtype = 'R'; } // Right
+ else if (raw_jump == 0) { jtype = 'S'; } // Stay
+ else { jtype = 'L'; jump_magnitude = raw_jump * -1; } // Left
+ int effective_length = src_len;
+ if (use_binned_log_lengths_) {
+ jump_magnitude = BinnedLogLength(jump_magnitude);
+ effective_length = BinnedLogLength(src_len);
+ }
+
+ if (true) {
+ static map<int, map<int, int> > len2jump2fid;
+ int& fid = len2jump2fid[src_len][raw_jump];
+ if (!fid) {
+ ostringstream os;
+ os << fid_str_ << ":FLen" << effective_length << ":" << jtype << jump_magnitude;
+ fid = FD::Convert(os.str());
+ }
+ features->set_value(fid, 1.0);
+ }
+}
+
+void NewJump::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_states,
+ SparseVector<double>* features,
+ SparseVector<double>* /* estimated_features */,
+ void* state) const {
+ unsigned char& dpstate = *((unsigned char*)state);
+ const int flen = smeta.GetSourceLength();
+ if (edge.Arity() == 0) {
+ dpstate = static_cast<unsigned int>(edge.i_);
+ if (edge.prev_i_ == 0) { // first target word in sentence
+ if (edge.i_ >= 0) { // generated from non-Null token?
+ FireFeature(smeta,
+ -1, // previous src = beginning of sentence index
+ edge.i_, // current src
+ features);
+ }
+ } else if (edge.prev_i_ == smeta.GetTargetLength() - 1) { // last word
+ if (edge.i_ >= 0) { // generated from non-Null token?
+ FireFeature(smeta,
+ edge.i_, // previous src = last word position
+ flen, // current src
+ features);
+ }
+ }
+ } else if (edge.Arity() == 1) {
+ dpstate = *((unsigned char*)ant_states[0]);
+ } else if (edge.Arity() == 2) {
+ int left_index = *((unsigned char*)ant_states[0]);
+ int right_index = *((unsigned char*)ant_states[1]);
+ if (right_index == -1)
+ dpstate = static_cast<unsigned int>(left_index);
+ else
+ dpstate = static_cast<unsigned int>(right_index);
+ if (left_index != kNULL_i && right_index != kNULL_i) {
+ FireFeature(smeta,
+ left_index, // previous src index
+ right_index, // current src index
+ features);
+ }
+ } else {
+ assert(!"something really unexpected is happening");
+ }
+}
+
SourceBigram::SourceBigram(const std::string& param) :
FeatureFunction(sizeof(WordID) + sizeof(int)) {
}
@@ -626,6 +764,122 @@ void InputIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta,
}
}
+WordPairFeatures::WordPairFeatures(const string& param) {
+ vector<string> argv;
+ int argc = SplitOnWhitespace(param, &argv);
+ if (argc != 1) {
+ cerr << "WordPairFeature /path/to/feature_values.table\n";
+ abort();
+ }
+ set<WordID> all_srcs;
+ {
+ ReadFile rf(argv[0]);
+ istream& in = *rf.stream();
+ string buf;
+ while (in) {
+ getline(in, buf);
+ if (buf.empty()) continue;
+ int start = 0;
+ while(start < buf.size() && buf[start] == ' ') ++start;
+ int end = start;
+ while(end < buf.size() && buf[end] != ' ') ++end;
+ const WordID src = TD::Convert(buf.substr(start, end - start));
+ all_srcs.insert(src);
+ }
+ }
+ if (all_srcs.empty()) {
+ cerr << "WordPairFeature " << param << " loaded empty file!\n";
+ return;
+ }
+ fkeys_.reserve(all_srcs.size());
+ copy(all_srcs.begin(), all_srcs.end(), back_inserter(fkeys_));
+ values_.resize(all_srcs.size());
+ if (!SILENT) { cerr << "WordPairFeature: " << all_srcs.size() << " sources\n"; }
+ ReadFile rf(argv[0]);
+ istream& in = *rf.stream();
+ string buf;
+ double val = 0;
+ WordID cur_src = 0;
+ map<WordID, SparseVector<float> > *pv = NULL;
+ const WordID kBARRIER = TD::Convert("|||");
+ while (in) {
+ getline(in, buf);
+ if (buf.size() == 0) continue;
+ int start = 0;
+ while(start < buf.size() && buf[start] == ' ') ++start;
+ int end = start;
+ while(end < buf.size() && buf[end] != ' ') ++end;
+ const WordID src = TD::Convert(buf.substr(start, end - start));
+ if (cur_src != src) {
+ cur_src = src;
+ size_t ind = distance(fkeys_.begin(), lower_bound(fkeys_.begin(), fkeys_.end(), cur_src));
+ pv = &values_[ind];
+ }
+ end += 1;
+ start = end;
+ while(end < buf.size() && buf[end] != ' ') ++end;
+ WordID x = TD::Convert(buf.substr(start, end - start));
+ if (x != kBARRIER) {
+ cerr << "1 Format error: " << buf << endl;
+ abort();
+ }
+ start = end + 1;
+ end = start + 1;
+ while(end < buf.size() && buf[end] != ' ') ++end;
+ WordID trg = TD::Convert(buf.substr(start, end - start));
+ if (trg == kBARRIER) {
+ cerr << "2 Format error: " << buf << endl;
+ abort();
+ }
+ start = end + 1;
+ end = start + 1;
+ while(end < buf.size() && buf[end] != ' ') ++end;
+ WordID x2 = TD::Convert(buf.substr(start, end - start));
+ if (x2 != kBARRIER) {
+ cerr << "3 Format error: " << buf << endl;
+ abort();
+ }
+ start = end + 1;
+
+ SparseVector<float>& v = (*pv)[trg];
+ while(start < buf.size()) {
+ end = start + 1;
+ while(end < buf.size() && buf[end] != '=' && buf[end] != ' ') ++end;
+ if (end == buf.size() || buf[end] != '=') { cerr << "4 Format error: " << buf << endl; abort(); }
+ const int fid = FD::Convert(buf.substr(start, end - start));
+ start = end + 1;
+ while(start < buf.size() && buf[start] == ' ') ++start;
+ end = start + 1;
+ while(end < buf.size() && buf[end] != ' ') ++end;
+ assert(end > start);
+ if (end < buf.size()) buf[end] = 0;
+ val = strtod(&buf.c_str()[start], NULL);
+ v.set_value(fid, val);
+ start = end + 1;
+ }
+ }
+}
-
+void WordPairFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const {
+ if (edge.Arity() == 0) {
+ assert(edge.rule_->EWords() == 1);
+ assert(edge.rule_->FWords() == 1);
+ const WordID trg = edge.rule_->e()[0];
+ const WordID src = edge.rule_->f()[0];
+ size_t ind = distance(fkeys_.begin(), lower_bound(fkeys_.begin(), fkeys_.end(), src));
+ if (ind == fkeys_.size() || fkeys_[ind] != src) {
+ cerr << "WordPairFeatures no source entries for " << TD::Convert(src) << endl;
+ abort();
+ }
+ const map<WordID, SparseVector<float> >::const_iterator it = values_[ind].find(trg);
+ // TODO optional strict flag to make sure there are features for all pairs?
+ if (it != values_[ind].end())
+ (*features) += it->second;
+ }
+}
diff --git a/decoder/ff_wordalign.h b/decoder/ff_wordalign.h
index 0714229c..418c8768 100644
--- a/decoder/ff_wordalign.h
+++ b/decoder/ff_wordalign.h
@@ -103,6 +103,43 @@ class SourceBigram : public FeatureFunction {
mutable Class2Class2FID fmap_;
};
+class LexNullJump : public FeatureFunction {
+ public:
+ LexNullJump(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* out_context) const;
+ private:
+ const int fid_lex_null_;
+ const int fid_null_lex_;
+ const int fid_null_null_;
+ const int fid_lex_lex_;
+};
+
+class NewJump : public FeatureFunction {
+ public:
+ NewJump(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* out_context) const;
+ private:
+ void FireFeature(const SentenceMetadata& smeta,
+ const int prev_src_index,
+ const int cur_src_index,
+ SparseVector<double>* features) const;
+
+ bool use_binned_log_lengths_;
+ std::string fid_str_; // identifies configuration uniquely
+};
+
class SourcePOSBigram : public FeatureFunction {
public:
SourcePOSBigram(const std::string& param);
@@ -238,6 +275,24 @@ class BlunsomSynchronousParseHack : public FeatureFunction {
mutable std::vector<std::vector<WordID> > refs_;
};
+// association feature type look up a pair (e,f) in a table and return a vector
+// of feature values
+class WordPairFeatures : public FeatureFunction {
+ public:
+ WordPairFeatures(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+
+ private:
+ std::vector<WordID> fkeys_; // parallel to values_
+ std::vector<std::map<WordID, SparseVector<float> > > values_; // fkeys_index -> e -> value
+};
+
class InputIdentity : public FeatureFunction {
public:
InputIdentity(const std::string& param);
diff --git a/decoder/lextrans.cc b/decoder/lextrans.cc
index 4476fe63..35d2d15d 100644
--- a/decoder/lextrans.cc
+++ b/decoder/lextrans.cc
@@ -76,13 +76,13 @@ struct LexicalTransImpl {
// hack to tell the feature function system how big the sentence pair is
const int f_start = (use_null ? -1 : 0);
int prev_node_id = -1;
- set<WordID> target_vocab; // only set for alignment_only mode
- if (align_only_) {
- const Lattice& ref = smeta.GetReference();
- for (int i = 0; i < ref.size(); ++i) {
- target_vocab.insert(ref[i][0].label);
- }
+ set<WordID> target_vocab;
+ const Lattice& ref = smeta.GetReference();
+ for (int i = 0; i < ref.size(); ++i) {
+ target_vocab.insert(ref[i][0].label);
}
+ bool all_sources_to_all_targets_ = true;
+ set<WordID> trgs_used;
for (int i = 0; i < e_len; ++i) { // for each word in the *target*
Hypergraph::Node* node = forest->AddNode(kXCAT);
const int new_node_id = node->id_;
@@ -101,10 +101,13 @@ struct LexicalTransImpl {
assert(rb);
for (int k = 0; k < rb->GetNumRules(); ++k) {
TRulePtr rule = rb->GetIthRule(k);
+ const WordID trg_word = rule->e_[0];
if (align_only_) {
- if (target_vocab.count(rule->e_[0]) == 0)
+ if (target_vocab.count(trg_word) == 0)
continue;
}
+ if (all_sources_to_all_targets_ && (target_vocab.count(trg_word) > 0))
+ trgs_used.insert(trg_word);
Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector());
edge->i_ = j;
edge->j_ = j+1;
@@ -113,6 +116,21 @@ struct LexicalTransImpl {
edge->feature_values_ += edge->rule_->GetFeatureValues();
forest->ConnectEdgeToHeadNode(edge->id_, new_node_id);
}
+ if (all_sources_to_all_targets_) {
+ for (set<WordID>::iterator it = target_vocab.begin(); it != target_vocab.end(); ++it) {
+ if (trgs_used.count(*it)) continue;
+ const WordID ungenerated_trg_word = *it;
+ TRulePtr rule;
+ rule.reset(TRule::CreateLexicalRule(src_sym, ungenerated_trg_word));
+ Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector());
+ edge->i_ = j;
+ edge->j_ = j+1;
+ edge->prev_i_ = i;
+ edge->prev_j_ = i+1;
+ forest->ConnectEdgeToHeadNode(edge->id_, new_node_id);
+ }
+ trgs_used.clear();
+ }
}
if (prev_node_id >= 0) {
const int comb_node_id = forest->AddNode(kXCAT)->id_;
diff --git a/decoder/trule.cc b/decoder/trule.cc
index a40c4e14..eedf8f30 100644
--- a/decoder/trule.cc
+++ b/decoder/trule.cc
@@ -246,18 +246,18 @@ string TRule::AsString(bool verbose) const {
int idx = 0;
if (lhs_ && verbose) {
os << '[' << TD::Convert(lhs_ * -1) << "] |||";
- for (int i = 0; i < f_.size(); ++i) {
- const WordID& w = f_[i];
- if (w < 0) {
- int wi = w * -1;
- ++idx;
- os << " [" << TD::Convert(wi) << ',' << idx << ']';
- } else {
- os << ' ' << TD::Convert(w);
- }
+ }
+ for (int i = 0; i < f_.size(); ++i) {
+ const WordID& w = f_[i];
+ if (w < 0) {
+ int wi = w * -1;
+ ++idx;
+ os << " [" << TD::Convert(wi) << ',' << idx << ']';
+ } else {
+ os << ' ' << TD::Convert(w);
}
- os << " ||| ";
}
+ os << " ||| ";
if (idx > 9) {
cerr << "Too many non-terminals!\n partial: " << os.str() << endl;
exit(1);