summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
Diffstat (limited to 'decoder')
-rw-r--r--decoder/Makefile.am1
-rw-r--r--decoder/cdec_ff.cc2
-rw-r--r--decoder/ff_context.cc99
-rw-r--r--decoder/ff_context.h23
4 files changed, 125 insertions, 0 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index 30eaf04d..a00b18af 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -63,6 +63,7 @@ libcdec_a_SOURCES = \
ff.cc \
ff_rules.cc \
ff_wordset.cc \
+ ff_context.cc \
ff_charset.cc \
ff_lm.cc \
ff_klm.cc \
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index 4ce5749e..b516c386 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -1,6 +1,7 @@
#include <boost/shared_ptr.hpp>
#include "ff.h"
+#include "ff_context.h"
#include "ff_spans.h"
#include "ff_lm.h"
#include "ff_klm.h"
@@ -42,6 +43,7 @@ void register_feature_functions() {
#endif
ff_registry.Register("SpanFeatures", new FFFactory<SpanFeatures>());
ff_registry.Register("NgramFeatures", new FFFactory<NgramDetector>());
+ ff_registry.Register("RuleContextFeatures", new FFFactory<RuleContextFeatures>());
ff_registry.Register("RuleIdentityFeatures", new FFFactory<RuleIdentityFeatures>());
ff_registry.Register("SourceSyntaxFeatures", new FFFactory<SourceSyntaxFeatures>);
ff_registry.Register("SourceSpanSizeFeatures", new FFFactory<SourceSpanSizeFeatures>);
diff --git a/decoder/ff_context.cc b/decoder/ff_context.cc
new file mode 100644
index 00000000..19f9a413
--- /dev/null
+++ b/decoder/ff_context.cc
@@ -0,0 +1,99 @@
+#include "ff_context.h"
+
+#include <sstream>
+#include <cassert>
+#include <cmath>
+
+#include "filelib.h"
+#include "stringlib.h"
+#include "sentence_metadata.h"
+#include "lattice.h"
+#include "fdict.h"
+#include "verbose.h"
+
+using namespace std;
+
+namespace {
+ string Escape(const string& x) {
+ string y = x;
+ for (int i = 0; i < y.size(); ++i) {
+ if (y[i] == '=') y[i]='_';
+ if (y[i] == ';') y[i]='_';
+ }
+ return y;
+ }
+}
+
+RuleContextFeatures::RuleContextFeatures(const std::string& param) {
+ kSOS = TD::Convert("<s>");
+ kEOS = TD::Convert("</s>");
+
+ // TODO param lets you pass in a string from the cdec.ini file
+}
+
+void RuleContextFeatures::PrepareForInput(const SentenceMetadata& smeta) {
+ const Lattice& sl = smeta.GetSourceLattice();
+ current_input.resize(sl.size());
+ for (unsigned i = 0; i < sl.size(); ++i) {
+ if (sl[i].size() != 1) {
+ cerr << "Context features not supported with lattice inputs!\nid=" << smeta.GetSentenceId() << endl;
+ abort();
+ }
+ current_input[i] = sl[i][0].label;
+ }
+}
+
+void RuleContextFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const {
+ const TRule& rule = *edge.rule_;
+
+ if (rule.Arity() != 0 || // arity = 0, no nonterminals
+ rule.e_.size() != 1) return; // size = 1, predicted label is a single token
+
+
+ // you can see the current label "for free"
+ const WordID cur_label = rule.e_[0];
+ // (if you want to see more labels, you have to be very careful, and muck
+ // about with contexts and ant_contexts)
+
+ // but... you can look at as much of the source as you want!
+ const int from_src_index = edge.i_; // start of the span in the input being labeled
+ const int to_src_index = edge.j_; // end of the span in the input
+ // (note: in the case of tagging the size of the spans being labeled will
+ // always be 1, but in other formalisms, you can have bigger spans.)
+
+ // this is the current token being labeled:
+ const WordID cur_input = current_input[from_src_index];
+
+ // let's get the previous token in the input (may be to the left of the start
+ // of the sentence!)
+ WordID prev_input = kSOS;
+ if (from_src_index > 0) { prev_input = current_input[from_src_index - 1]; }
+ // let's get the next token (may be to the left of the start of the sentence!)
+ WordID next_input = kEOS;
+ if (to_src_index < current_input.size()) { next_input = current_input[to_src_index]; }
+
+ // now, build a feature string
+ ostringstream os;
+ // TD::Convert converts from the internal integer representation of a token
+ // to the actual token
+ os << "C1:" << TD::Convert(prev_input) << '_'
+ << TD::Convert(cur_input) << '|' << TD::Convert(cur_label);
+ // C1 is just to prevent a name clash
+
+ // pick a value
+ double fval = 1.0; // can be any real value
+
+ // add it to the feature vector FD::Convert converts the feature string to a
+ // feature int, Escape makes sure the feature string doesn't have any bad
+ // symbols that could confuse a parser somewhere
+ features->add_value(FD::Convert(Escape(os.str())), fval);
+ // that's it!
+
+ // create more features if you like...
+}
+
diff --git a/decoder/ff_context.h b/decoder/ff_context.h
new file mode 100644
index 00000000..0d22b027
--- /dev/null
+++ b/decoder/ff_context.h
@@ -0,0 +1,23 @@
+#ifndef _FF_CONTEXT_H_
+#define _FF_CONTEXT_H_
+
+#include <vector>
+#include "ff.h"
+
+class RuleContextFeatures : public FeatureFunction {
+ public:
+ RuleContextFeatures(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ virtual void PrepareForInput(const SentenceMetadata& smeta);
+ private:
+ std::vector<WordID> current_input;
+ WordID kSOS, kEOS;
+};
+
+#endif