summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
Diffstat (limited to 'decoder')
-rw-r--r--decoder/Makefile.am2
-rw-r--r--decoder/cdec_ff.cc9
-rw-r--r--decoder/ff_klm.cc38
-rw-r--r--decoder/ff_klm.h7
-rw-r--r--decoder/ff_ngrams.cc341
-rw-r--r--decoder/ff_ngrams.h29
-rw-r--r--decoder/ff_rules.cc107
-rw-r--r--decoder/ff_rules.h40
-rw-r--r--decoder/ff_spans.cc74
-rw-r--r--decoder/ff_spans.h15
10 files changed, 587 insertions, 75 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index 244da2de..e5f7505f 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -61,10 +61,12 @@ libcdec_a_SOURCES = \
phrasetable_fst.cc \
trule.cc \
ff.cc \
+ ff_rules.cc \
ff_wordset.cc \
ff_charset.cc \
ff_lm.cc \
ff_klm.cc \
+ ff_ngrams.cc \
ff_spans.cc \
ff_ruleshape.cc \
ff_wordalign.cc \
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index 31f88a4f..588842f1 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -4,10 +4,12 @@
#include "ff_spans.h"
#include "ff_lm.h"
#include "ff_klm.h"
+#include "ff_ngrams.h"
#include "ff_csplit.h"
#include "ff_wordalign.h"
#include "ff_tagger.h"
#include "ff_factory.h"
+#include "ff_rules.h"
#include "ff_ruleshape.h"
#include "ff_bleu.h"
#include "ff_lm_fsa.h"
@@ -51,12 +53,11 @@ void register_feature_functions() {
ff_registry.Register("RandLM", new FFFactory<LanguageModelRandLM>);
#endif
ff_registry.Register("SpanFeatures", new FFFactory<SpanFeatures>());
+ ff_registry.Register("NgramFeatures", new FFFactory<NgramDetector>());
+ ff_registry.Register("RuleIdentityFeatures", new FFFactory<RuleIdentityFeatures>());
ff_registry.Register("RuleNgramFeatures", new FFFactory<RuleNgramFeatures>());
ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory<CMR2008ReorderingFeatures>());
- ff_registry.Register("KLanguageModel", new FFFactory<KLanguageModel<lm::ngram::ProbingModel> >());
- ff_registry.Register("KLanguageModel_Trie", new FFFactory<KLanguageModel<lm::ngram::TrieModel> >());
- ff_registry.Register("KLanguageModel_QuantTrie", new FFFactory<KLanguageModel<lm::ngram::QuantTrieModel> >());
- ff_registry.Register("KLanguageModel_Probing", new FFFactory<KLanguageModel<lm::ngram::ProbingModel> >());
+ ff_registry.Register("KLanguageModel", new KLanguageModelFactory());
ff_registry.Register("NonLatinCount", new FFFactory<NonLatinCount>);
ff_registry.Register("RuleShape", new FFFactory<RuleShapeFeatures>);
ff_registry.Register("RelativeSentencePosition", new FFFactory<RelativeSentencePosition>);
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
index 9b7fe2d3..24dcb9c3 100644
--- a/decoder/ff_klm.cc
+++ b/decoder/ff_klm.cc
@@ -9,6 +9,7 @@
#include "stringlib.h"
#include "hg.h"
#include "tdict.h"
+#include "lm/model.hh"
#include "lm/enumerate_vocab.hh"
using namespace std;
@@ -434,8 +435,37 @@ void KLanguageModel<Model>::FinalTraversalFeatures(const void* ant_state,
features->set_value(oov_fid_, oovs);
}
-// instantiate templates
-template class KLanguageModel<lm::ngram::ProbingModel>;
-template class KLanguageModel<lm::ngram::TrieModel>;
-template class KLanguageModel<lm::ngram::QuantTrieModel>;
+template <class Model> boost::shared_ptr<FeatureFunction> CreateModel(const std::string &param) {
+ KLanguageModel<Model> *ret = new KLanguageModel<Model>(param);
+ ret->Init();
+ return boost::shared_ptr<FeatureFunction>(ret);
+}
+boost::shared_ptr<FeatureFunction> KLanguageModelFactory::Create(std::string param) const {
+ using namespace lm::ngram;
+ std::string filename, ignored_map;
+ bool ignored_markers;
+ std::string ignored_featname;
+ ParseLMArgs(param, &filename, &ignored_map, &ignored_markers, &ignored_featname);
+ ModelType m;
+ if (!RecognizeBinary(filename.c_str(), m)) m = HASH_PROBING;
+
+ switch (m) {
+ case HASH_PROBING:
+ return CreateModel<ProbingModel>(param);
+ case TRIE_SORTED:
+ return CreateModel<TrieModel>(param);
+ case ARRAY_TRIE_SORTED:
+ return CreateModel<ArrayTrieModel>(param);
+ case QUANT_TRIE_SORTED:
+ return CreateModel<QuantTrieModel>(param);
+ case QUANT_ARRAY_TRIE_SORTED:
+ return CreateModel<QuantArrayTrieModel>(param);
+ default:
+ UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m);
+ }
+}
+
+std::string KLanguageModelFactory::usage(bool params,bool verbose) const {
+ return KLanguageModel<lm::ngram::Model>::usage(params, verbose);
+}
diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h
index 5eafe8be..6efe50f6 100644
--- a/decoder/ff_klm.h
+++ b/decoder/ff_klm.h
@@ -4,8 +4,8 @@
#include <vector>
#include <string>
+#include "ff_factory.h"
#include "ff.h"
-#include "lm/model.hh"
template <class Model> struct KLanguageModelImpl;
@@ -34,4 +34,9 @@ class KLanguageModel : public FeatureFunction {
KLanguageModelImpl<Model>* pimpl_;
};
+struct KLanguageModelFactory : public FactoryBase<FeatureFunction> {
+ FP Create(std::string param) const;
+ std::string usage(bool params,bool verbose) const;
+};
+
#endif
diff --git a/decoder/ff_ngrams.cc b/decoder/ff_ngrams.cc
new file mode 100644
index 00000000..04dd1906
--- /dev/null
+++ b/decoder/ff_ngrams.cc
@@ -0,0 +1,341 @@
+#include "ff_ngrams.h"
+
+#include <cstring>
+#include <iostream>
+
+#include <boost/scoped_ptr.hpp>
+
+#include "filelib.h"
+#include "stringlib.h"
+#include "hg.h"
+#include "tdict.h"
+
+using namespace std;
+
+static const unsigned char HAS_FULL_CONTEXT = 1;
+static const unsigned char HAS_EOS_ON_RIGHT = 2;
+static const unsigned char MASK = 7;
+
+namespace {
+template <unsigned MAX_ORDER = 5>
+struct State {
+ explicit State() {
+ memset(state, 0, sizeof(state));
+ }
+ explicit State(int order) {
+ memset(state, 0, (order - 1) * sizeof(WordID));
+ }
+ State<MAX_ORDER>(char order, const WordID* mem) {
+ memcpy(state, mem, (order - 1) * sizeof(WordID));
+ }
+ State(const State<MAX_ORDER>& other) {
+ memcpy(state, other.state, sizeof(state));
+ }
+ const State& operator=(const State<MAX_ORDER>& other) {
+ memcpy(state, other.state, sizeof(state));
+ }
+ explicit State(const State<MAX_ORDER>& other, unsigned order, WordID extend) {
+ char om1 = order - 1;
+ assert(om1 > 0);
+ for (char i = 1; i < om1; ++i) state[i - 1]= other.state[i];
+ state[om1 - 1] = extend;
+ }
+ const WordID& operator[](size_t i) const { return state[i]; }
+ WordID& operator[](size_t i) { return state[i]; }
+ WordID state[MAX_ORDER];
+};
+}
+
+namespace {
+ string Escape(const string& x) {
+ string y = x;
+ for (int i = 0; i < y.size(); ++i) {
+ if (y[i] == '=') y[i]='_';
+ if (y[i] == ';') y[i]='_';
+ }
+ return y;
+ }
+}
+
+class NgramDetectorImpl {
+
+ // returns the number of unscored words at the left edge of a span
+ inline int UnscoredSize(const void* state) const {
+ return *(static_cast<const char*>(state) + unscored_size_offset_);
+ }
+
+ inline void SetUnscoredSize(int size, void* state) const {
+ *(static_cast<char*>(state) + unscored_size_offset_) = size;
+ }
+
+ inline State<5> RemnantLMState(const void* cstate) const {
+ return State<5>(order_, static_cast<const WordID*>(cstate));
+ }
+
+ inline const State<5> BeginSentenceState() const {
+ State<5> state(order_);
+ state.state[0] = kSOS_;
+ return state;
+ }
+
+ inline void SetRemnantLMState(const State<5>& lmstate, void* state) const {
+ // if we were clever, we could use the memory pointed to by state to do all
+ // the work, avoiding this copy
+ memcpy(state, lmstate.state, (order_-1) * sizeof(WordID));
+ }
+
+ WordID IthUnscoredWord(int i, const void* state) const {
+ const WordID* const mem = reinterpret_cast<const WordID*>(static_cast<const char*>(state) + unscored_words_offset_);
+ return mem[i];
+ }
+
+ void SetIthUnscoredWord(int i, const WordID index, void *state) const {
+ WordID* mem = reinterpret_cast<WordID*>(static_cast<char*>(state) + unscored_words_offset_);
+ mem[i] = index;
+ }
+
+ inline bool GetFlag(const void *state, unsigned char flag) const {
+ return (*(static_cast<const char*>(state) + is_complete_offset_) & flag);
+ }
+
+ inline void SetFlag(bool on, unsigned char flag, void *state) const {
+ if (on) {
+ *(static_cast<char*>(state) + is_complete_offset_) |= flag;
+ } else {
+ *(static_cast<char*>(state) + is_complete_offset_) &= (MASK ^ flag);
+ }
+ }
+
+ inline bool HasFullContext(const void *state) const {
+ return GetFlag(state, HAS_FULL_CONTEXT);
+ }
+
+ inline void SetHasFullContext(bool flag, void *state) const {
+ SetFlag(flag, HAS_FULL_CONTEXT, state);
+ }
+
+ void FireFeatures(const State<5>& state, WordID cur, SparseVector<double>* feats) {
+ FidTree* ft = &fidroot_;
+ int n = 0;
+ WordID buf[10];
+ int ci = order_ - 1;
+ WordID curword = cur;
+ while(curword) {
+ buf[n] = curword;
+ int& fid = ft->fids[curword];
+ ++n;
+ if (!fid) {
+ const char* code="_UBT456789"; // prefix code (unigram, bigram, etc.)
+ ostringstream os;
+ os << code[n] << ':';
+ for (int i = n-1; i >= 0; --i) {
+ os << (i != n-1 ? "_" : "");
+ const string& tok = TD::Convert(buf[i]);
+ if (tok.find('=') == string::npos)
+ os << tok;
+ else
+ os << Escape(tok);
+ }
+ fid = FD::Convert(os.str());
+ }
+ feats->set_value(fid, 1);
+ ft = &ft->levels[curword];
+ --ci;
+ if (ci < 0) break;
+ curword = state[ci];
+ }
+ }
+
+ public:
+ void LookupWords(const TRule& rule, const vector<const void*>& ant_states, SparseVector<double>* feats, SparseVector<double>* est_feats, void* remnant) {
+ double sum = 0.0;
+ double est_sum = 0.0;
+ int num_scored = 0;
+ int num_estimated = 0;
+ bool saw_eos = false;
+ bool has_some_history = false;
+ State<5> state;
+ const vector<WordID>& e = rule.e();
+ bool context_complete = false;
+ for (int j = 0; j < e.size(); ++j) {
+ if (e[j] < 1) { // handle non-terminal substitution
+ const void* astate = (ant_states[-e[j]]);
+ int unscored_ant_len = UnscoredSize(astate);
+ for (int k = 0; k < unscored_ant_len; ++k) {
+ const WordID cur_word = IthUnscoredWord(k, astate);
+ const bool is_oov = (cur_word == 0);
+ SparseVector<double> p;
+ if (cur_word == kSOS_) {
+ state = BeginSentenceState();
+ if (has_some_history) { // this is immediately fully scored, and bad
+ p.set_value(FD::Convert("Malformed"), 1.0);
+ context_complete = true;
+ } else { // this might be a real <s>
+ num_scored = max(0, order_ - 2);
+ }
+ } else {
+ FireFeatures(state, cur_word, &p);
+ const State<5> scopy = State<5>(state, order_, cur_word);
+ state = scopy;
+ if (saw_eos) { p.set_value(FD::Convert("Malformed"), 1.0); }
+ saw_eos = (cur_word == kEOS_);
+ }
+ has_some_history = true;
+ ++num_scored;
+ if (!context_complete) {
+ if (num_scored >= order_) context_complete = true;
+ }
+ if (context_complete) {
+ (*feats) += p;
+ } else {
+ if (remnant)
+ SetIthUnscoredWord(num_estimated, cur_word, remnant);
+ ++num_estimated;
+ (*est_feats) += p;
+ }
+ }
+ saw_eos = GetFlag(astate, HAS_EOS_ON_RIGHT);
+ if (HasFullContext(astate)) { // this is equivalent to the "star" in Chiang 2007
+ state = RemnantLMState(astate);
+ context_complete = true;
+ }
+ } else { // handle terminal
+ const WordID cur_word = e[j];
+ SparseVector<double> p;
+ if (cur_word == kSOS_) {
+ state = BeginSentenceState();
+ if (has_some_history) { // this is immediately fully scored, and bad
+ p.set_value(FD::Convert("Malformed"), -100);
+ context_complete = true;
+ } else { // this might be a real <s>
+ num_scored = max(0, order_ - 2);
+ }
+ } else {
+ FireFeatures(state, cur_word, &p);
+ const State<5> scopy = State<5>(state, order_, cur_word);
+ state = scopy;
+ if (saw_eos) { p.set_value(FD::Convert("Malformed"), 1.0); }
+ saw_eos = (cur_word == kEOS_);
+ }
+ has_some_history = true;
+ ++num_scored;
+ if (!context_complete) {
+ if (num_scored >= order_) context_complete = true;
+ }
+ if (context_complete) {
+ (*feats) += p;
+ } else {
+ if (remnant)
+ SetIthUnscoredWord(num_estimated, cur_word, remnant);
+ ++num_estimated;
+ (*est_feats) += p;
+ }
+ }
+ }
+ if (remnant) {
+ SetFlag(saw_eos, HAS_EOS_ON_RIGHT, remnant);
+ SetRemnantLMState(state, remnant);
+ SetUnscoredSize(num_estimated, remnant);
+ SetHasFullContext(context_complete || (num_scored >= order_), remnant);
+ }
+ }
+
+ // this assumes no target words on final unary -> goal rule. is that ok?
+ // for <s> (n-1 left words) and (n-1 right words) </s>
+ void FinalTraversal(const void* state, SparseVector<double>* feats) {
+ if (add_sos_eos_) { // rules do not produce <s> </s>, so do it here
+ SetRemnantLMState(BeginSentenceState(), dummy_state_);
+ SetHasFullContext(1, dummy_state_);
+ SetUnscoredSize(0, dummy_state_);
+ dummy_ants_[1] = state;
+ LookupWords(*dummy_rule_, dummy_ants_, feats, NULL, NULL);
+ } else { // rules DO produce <s> ... </s>
+#if 0
+ double p = 0;
+ if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; }
+ if (UnscoredSize(state) > 0) { // are there unscored words
+ if (kSOS_ != IthUnscoredWord(0, state)) {
+ p -= 100 * UnscoredSize(state);
+ }
+ }
+ return p;
+#endif
+ }
+ }
+
+ public:
+ explicit NgramDetectorImpl(bool explicit_markers) :
+ kCDEC_UNK(TD::Convert("<unk>")) ,
+ add_sos_eos_(!explicit_markers) {
+ order_ = 3;
+ state_size_ = (order_ - 1) * sizeof(WordID) + 2 + (order_ - 1) * sizeof(WordID);
+ unscored_size_offset_ = (order_ - 1) * sizeof(WordID);
+ is_complete_offset_ = unscored_size_offset_ + 1;
+ unscored_words_offset_ = is_complete_offset_ + 1;
+
+ // special handling of beginning / ending sentence markers
+ dummy_state_ = new char[state_size_];
+ memset(dummy_state_, 0, state_size_);
+ dummy_ants_.push_back(dummy_state_);
+ dummy_ants_.push_back(NULL);
+ dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] </s> ||| X=0"));
+ kSOS_ = TD::Convert("<s>");
+ kEOS_ = TD::Convert("</s>");
+ }
+
+ ~NgramDetectorImpl() {
+ delete[] dummy_state_;
+ }
+
+ int ReserveStateSize() const { return state_size_; }
+
+ private:
+ const WordID kCDEC_UNK;
+ WordID kSOS_; // <s> - requires special handling.
+ WordID kEOS_; // </s>
+ const bool add_sos_eos_; // flag indicating whether the hypergraph produces <s> and </s>
+ // if this is true, FinalTransitionFeatures will "add" <s> and </s>
+ // if false, FinalTransitionFeatures will score anything with the
+ // markers in the right place (i.e., the beginning and end of
+ // the sentence) with 0, and anything else with -100
+
+ int order_;
+ int state_size_;
+ int unscored_size_offset_;
+ int is_complete_offset_;
+ int unscored_words_offset_;
+ char* dummy_state_;
+ vector<const void*> dummy_ants_;
+ TRulePtr dummy_rule_;
+ struct FidTree {
+ map<WordID, int> fids;
+ map<WordID, FidTree> levels;
+ };
+ mutable FidTree fidroot_;
+};
+
+NgramDetector::NgramDetector(const string& param) {
+ string filename, mapfile, featname;
+ bool explicit_markers = (param == "-x");
+ pimpl_ = new NgramDetectorImpl(explicit_markers);
+ SetStateSize(pimpl_->ReserveStateSize());
+}
+
+NgramDetector::~NgramDetector() {
+ delete pimpl_;
+}
+
+void NgramDetector::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_states,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* state) const {
+ pimpl_->LookupWords(*edge.rule_, ant_states, features, estimated_features, state);
+}
+
+void NgramDetector::FinalTraversalFeatures(const void* ant_state,
+ SparseVector<double>* features) const {
+ pimpl_->FinalTraversal(ant_state, features);
+}
+
diff --git a/decoder/ff_ngrams.h b/decoder/ff_ngrams.h
new file mode 100644
index 00000000..82f61b33
--- /dev/null
+++ b/decoder/ff_ngrams.h
@@ -0,0 +1,29 @@
+#ifndef _NGRAMS_FF_H_
+#define _NGRAMS_FF_H_
+
+#include <vector>
+#include <map>
+#include <string>
+
+#include "ff.h"
+
+struct NgramDetectorImpl;
+class NgramDetector : public FeatureFunction {
+ public:
+ // param = "filename.lm [-o n]"
+ NgramDetector(const std::string& param);
+ ~NgramDetector();
+ virtual void FinalTraversalFeatures(const void* context,
+ SparseVector<double>* features) const;
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* out_context) const;
+ private:
+ NgramDetectorImpl* pimpl_;
+};
+
+#endif
diff --git a/decoder/ff_rules.cc b/decoder/ff_rules.cc
new file mode 100644
index 00000000..bd4c4cc0
--- /dev/null
+++ b/decoder/ff_rules.cc
@@ -0,0 +1,107 @@
+#include "ff_rules.h"
+
+#include <sstream>
+#include <cassert>
+#include <cmath>
+
+#include "filelib.h"
+#include "stringlib.h"
+#include "sentence_metadata.h"
+#include "lattice.h"
+#include "fdict.h"
+#include "verbose.h"
+
+using namespace std;
+
+namespace {
+ string Escape(const string& x) {
+ string y = x;
+ for (int i = 0; i < y.size(); ++i) {
+ if (y[i] == '=') y[i]='_';
+ if (y[i] == ';') y[i]='_';
+ }
+ return y;
+ }
+}
+
+RuleIdentityFeatures::RuleIdentityFeatures(const std::string& param) {
+}
+
+void RuleIdentityFeatures::PrepareForInput(const SentenceMetadata& smeta) {
+// std::map<const TRule*, SparseVector<double> >
+ rule2_fid_.clear();
+}
+
+void RuleIdentityFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const {
+ map<const TRule*, int>::iterator it = rule2_fid_.find(edge.rule_.get());
+ if (it == rule2_fid_.end()) {
+ const TRule& rule = *edge.rule_;
+ ostringstream os;
+ os << "R:";
+ if (rule.lhs_ < 0) os << TD::Convert(-rule.lhs_) << ':';
+ for (unsigned i = 0; i < rule.f_.size(); ++i) {
+ if (i > 0) os << '_';
+ WordID w = rule.f_[i];
+ if (w < 0) { os << 'N'; w = -w; }
+ assert(w > 0);
+ os << TD::Convert(w);
+ }
+ os << ':';
+ for (unsigned i = 0; i < rule.e_.size(); ++i) {
+ if (i > 0) os << '_';
+ WordID w = rule.e_[i];
+ if (w <= 0) {
+ os << 'N' << (1-w);
+ } else {
+ os << TD::Convert(w);
+ }
+ }
+ it = rule2_fid_.insert(make_pair(&rule, FD::Convert(Escape(os.str())))).first;
+ }
+ features->add_value(it->second, 1);
+}
+
+RuleNgramFeatures::RuleNgramFeatures(const std::string& param) {
+}
+
+void RuleNgramFeatures::PrepareForInput(const SentenceMetadata& smeta) {
+// std::map<const TRule*, SparseVector<double> >
+ rule2_feats_.clear();
+}
+
+void RuleNgramFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const {
+ map<const TRule*, SparseVector<double> >::iterator it = rule2_feats_.find(edge.rule_.get());
+ if (it == rule2_feats_.end()) {
+ const TRule& rule = *edge.rule_;
+ it = rule2_feats_.insert(make_pair(&rule, SparseVector<double>())).first;
+ SparseVector<double>& f = it->second;
+ string prev = "<r>";
+ for (int i = 0; i < rule.f_.size(); ++i) {
+ WordID w = rule.f_[i];
+ if (w < 0) w = -w;
+ assert(w > 0);
+ const string& cur = TD::Convert(w);
+ ostringstream os;
+ os << "RB:" << prev << '_' << cur;
+ const int fid = FD::Convert(Escape(os.str()));
+ if (fid <= 0) return;
+ f.add_value(fid, 1.0);
+ prev = cur;
+ }
+ ostringstream os;
+ os << "RB:" << prev << '_' << "</r>";
+ f.set_value(FD::Convert(Escape(os.str())), 1.0);
+ }
+ (*features) += it->second;
+}
+
diff --git a/decoder/ff_rules.h b/decoder/ff_rules.h
new file mode 100644
index 00000000..48d8bd05
--- /dev/null
+++ b/decoder/ff_rules.h
@@ -0,0 +1,40 @@
+#ifndef _FF_RULES_H_
+#define _FF_RULES_H_
+
+#include <vector>
+#include <map>
+#include "ff.h"
+#include "array2d.h"
+#include "wordid.h"
+
+class RuleIdentityFeatures : public FeatureFunction {
+ public:
+ RuleIdentityFeatures(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ virtual void PrepareForInput(const SentenceMetadata& smeta);
+ private:
+ mutable std::map<const TRule*, int> rule2_fid_;
+};
+
+class RuleNgramFeatures : public FeatureFunction {
+ public:
+ RuleNgramFeatures(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ virtual void PrepareForInput(const SentenceMetadata& smeta);
+ private:
+ mutable std::map<const TRule*, SparseVector<double> > rule2_feats_;
+};
+
+#endif
diff --git a/decoder/ff_spans.cc b/decoder/ff_spans.cc
index e1da088d..0483517b 100644
--- a/decoder/ff_spans.cc
+++ b/decoder/ff_spans.cc
@@ -13,6 +13,17 @@
using namespace std;
+namespace {
+ string Escape(const string& x) {
+ string y = x;
+ for (int i = 0; i < y.size(); ++i) {
+ if (y[i] == '=') y[i]='_';
+ if (y[i] == ';') y[i]='_';
+ }
+ return y;
+ }
+}
+
// log transform to make long spans cluster together
// but preserve differences
int SpanSizeTransform(unsigned span_size) {
@@ -140,19 +151,19 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) {
word = MapIfNecessary(word);
ostringstream sfid;
sfid << "ES:" << TD::Convert(word);
- end_span_ids_[i] = FD::Convert(sfid.str());
+ end_span_ids_[i] = FD::Convert(Escape(sfid.str()));
ostringstream esbiid;
esbiid << "EBI:" << TD::Convert(bword) << "_" << TD::Convert(word);
- end_bigram_ids_[i] = FD::Convert(esbiid.str());
+ end_bigram_ids_[i] = FD::Convert(Escape(esbiid.str()));
ostringstream bsbiid;
bsbiid << "BBI:" << TD::Convert(bword) << "_" << TD::Convert(word);
- beg_bigram_ids_[i] = FD::Convert(bsbiid.str());
+ beg_bigram_ids_[i] = FD::Convert(Escape(bsbiid.str()));
ostringstream bfid;
bfid << "BS:" << TD::Convert(bword);
- beg_span_ids_[i] = FD::Convert(bfid.str());
+ beg_span_ids_[i] = FD::Convert(Escape(bfid.str()));
if (use_collapsed_features_) {
- end_span_vals_[i] = feat2val_[sfid.str()] + feat2val_[esbiid.str()];
- beg_span_vals_[i] = feat2val_[bfid.str()] + feat2val_[bsbiid.str()];
+ end_span_vals_[i] = feat2val_[Escape(sfid.str())] + feat2val_[Escape(esbiid.str())];
+ beg_span_vals_[i] = feat2val_[Escape(bfid.str())] + feat2val_[Escape(bsbiid.str())];
}
}
for (int i = 0; i <= lattice.size(); ++i) {
@@ -167,60 +178,21 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) {
word = MapIfNecessary(word);
ostringstream pf;
pf << "S:" << TD::Convert(bword) << "_" << TD::Convert(word);
- span_feats_(i,j).first = FD::Convert(pf.str());
- span_feats_(i,j).second = FD::Convert("S_" + pf.str());
+ span_feats_(i,j).first = FD::Convert(Escape(pf.str()));
+ span_feats_(i,j).second = FD::Convert(Escape("S_" + pf.str()));
ostringstream lf;
const unsigned span_size = (i < j ? j - i : i - j);
lf << "LS:" << SpanSizeTransform(span_size) << "_" << TD::Convert(bword) << "_" << TD::Convert(word);
- len_span_feats_(i,j).first = FD::Convert(lf.str());
- len_span_feats_(i,j).second = FD::Convert("S_" + lf.str());
+ len_span_feats_(i,j).first = FD::Convert(Escape(lf.str()));
+ len_span_feats_(i,j).second = FD::Convert(Escape("S_" + lf.str()));
if (use_collapsed_features_) {
- span_vals_(i,j).first = feat2val_[pf.str()] + feat2val_[lf.str()];
- span_vals_(i,j).second = feat2val_["S_" + pf.str()] + feat2val_["S_" + lf.str()];
+ span_vals_(i,j).first = feat2val_[Escape(pf.str())] + feat2val_[Escape(lf.str())];
+ span_vals_(i,j).second = feat2val_[Escape("S_" + pf.str())] + feat2val_[Escape("S_" + lf.str())];
}
}
}
}
-RuleNgramFeatures::RuleNgramFeatures(const std::string& param) {
-}
-
-void RuleNgramFeatures::PrepareForInput(const SentenceMetadata& smeta) {
-// std::map<const TRule*, SparseVector<double> >
- rule2_feats_.clear();
-}
-
-void RuleNgramFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
- const Hypergraph::Edge& edge,
- const vector<const void*>& ant_contexts,
- SparseVector<double>* features,
- SparseVector<double>* estimated_features,
- void* context) const {
- map<const TRule*, SparseVector<double> >::iterator it = rule2_feats_.find(edge.rule_.get());
- if (it == rule2_feats_.end()) {
- const TRule& rule = *edge.rule_;
- it = rule2_feats_.insert(make_pair(&rule, SparseVector<double>())).first;
- SparseVector<double>& f = it->second;
- string prev = "<r>";
- for (int i = 0; i < rule.f_.size(); ++i) {
- WordID w = rule.f_[i];
- if (w < 0) w = -w;
- assert(w > 0);
- const string& cur = TD::Convert(w);
- ostringstream os;
- os << "RB:" << prev << '_' << cur;
- const int fid = FD::Convert(os.str());
- if (fid <= 0) return;
- f.add_value(fid, 1.0);
- prev = cur;
- }
- ostringstream os;
- os << "RB:" << prev << '_' << "</r>";
- f.set_value(FD::Convert(os.str()), 1.0);
- }
- (*features) += it->second;
-}
-
inline bool IsArity2RuleReordered(const TRule& rule) {
const vector<WordID>& e = rule.e_;
for (int i = 0; i < e.size(); ++i) {
diff --git a/decoder/ff_spans.h b/decoder/ff_spans.h
index b22c4d03..24e0dede 100644
--- a/decoder/ff_spans.h
+++ b/decoder/ff_spans.h
@@ -44,21 +44,6 @@ class SpanFeatures : public FeatureFunction {
WordID oov_;
};
-class RuleNgramFeatures : public FeatureFunction {
- public:
- RuleNgramFeatures(const std::string& param);
- protected:
- virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
- const Hypergraph::Edge& edge,
- const std::vector<const void*>& ant_contexts,
- SparseVector<double>* features,
- SparseVector<double>* estimated_features,
- void* context) const;
- virtual void PrepareForInput(const SentenceMetadata& smeta);
- private:
- mutable std::map<const TRule*, SparseVector<double> > rule2_feats_;
-};
-
class CMR2008ReorderingFeatures : public FeatureFunction {
public:
CMR2008ReorderingFeatures(const std::string& param);