author     Chris Dyer <cdyer@cs.cmu.edu>    2012-06-18 18:34:21 -0400
committer  Chris Dyer <cdyer@cs.cmu.edu>    2012-06-18 18:34:21 -0400
commit     953ec50e659084c13433ea311f6a07e7e1b292f8 (patch)
tree       e04361bc661740a08a8bf045aa8c4e5bcd691d85 /decoder
parent     d9d602cf26e3696a5e575f314547b823254dba32 (diff)
parent     681ba488e82b8e2079f2d6c2af63014fa86834ac (diff)
Merge branch 'master' of github.com:redpony/cdec
Diffstat (limited to 'decoder')
-rw-r--r--  decoder/ff_context.cc  262
-rw-r--r--  decoder/ff_context.h    29
2 files changed, 222 insertions(+), 69 deletions(-)
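The change below makes RuleContextFeatures configurable through a feature template passed with -t. Based on the usage message added in ff_context.cc, enabling it in cdec.ini would look roughly like the following (the U1 prefix and the template are just the example from that usage text): %x[k] expands to the input token at offset k from the span being labeled, and %y[0] to the predicted label.

  feature_function=RuleContextFeatures -t U1:%x[-1]_%x[0]|%y[0]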
diff --git a/decoder/ff_context.cc b/decoder/ff_context.cc
index 19f9a413..9de4d737 100644
--- a/decoder/ff_context.cc
+++ b/decoder/ff_context.cc
@@ -1,5 +1,6 @@
#include "ff_context.h"
+#include <stdlib.h>
#include <sstream>
#include <cassert>
#include <cmath>
@@ -11,24 +12,150 @@
#include "fdict.h"
#include "verbose.h"
-using namespace std;
+RuleContextFeatures::RuleContextFeatures(const string& param) {
+ // cerr << "initializing RuleContextFeatures with parameters: " << param;
+ kSOS = TD::Convert("<s>");
+ kEOS = TD::Convert("</s>");
+ macro_regex = sregex::compile("%([xy])\\[(-[1-9][0-9]*|0|[1-9][0-9]*)]");
+ ParseArgs(param);
+}
-namespace {
- string Escape(const string& x) {
- string y = x;
- for (int i = 0; i < y.size(); ++i) {
- if (y[i] == '=') y[i]='_';
- if (y[i] == ';') y[i]='_';
- }
- return y;
+string RuleContextFeatures::Escape(const string& x) const {
+ string y = x;
+ for (int i = 0; i < y.size(); ++i) {
+ if (y[i] == '=') y[i]='_';
+ if (y[i] == ';') y[i]='_';
}
+ return y;
}
-RuleContextFeatures::RuleContextFeatures(const std::string& param) {
- kSOS = TD::Convert("<s>");
- kEOS = TD::Convert("</s>");
+// replace %x[relative_location] or %y[relative_location] with actual_token
+// within feature_instance
+void RuleContextFeatures::ReplaceMacroWithString(
+ string& feature_instance, bool token_vs_label, int relative_location,
+ const string& actual_token) const {
+
+ stringstream macro;
+ if (token_vs_label) {
+ macro << "%x[";
+ } else {
+ macro << "%y[";
+ }
+ macro << relative_location << "]";
+ size_t macro_index = feature_instance.find(macro.str());
+ if (macro_index == string::npos) {
+ cerr << "Can't find macro " << macro.str() << " in feature template "
+ << feature_instance << endl;
+ abort();
+ }
+ feature_instance.replace(macro_index, macro.str().size(), actual_token);
+}
+
+void RuleContextFeatures::ReplaceTokenMacroWithString(
+ string& feature_instance, int relative_location,
+ const string& actual_token) const {
+
+ ReplaceMacroWithString(feature_instance, true, relative_location,
+ actual_token);
+}
- // TODO param lets you pass in a string from the cdec.ini file
+void RuleContextFeatures::ReplaceLabelMacroWithString(
+ string& feature_instance, int relative_location,
+ const string& actual_token) const {
+
+ ReplaceMacroWithString(feature_instance, false, relative_location,
+ actual_token);
+}
+
+void RuleContextFeatures::Error(const string& error_message) const {
+ cerr << "Error: " << error_message << "\n\n"
+
+ << "RuleContextFeatures Usage: \n"
+ << " feature_function=RuleContextFeatures -t <TEMPLATE>\n\n"
+
+ << "Example <TEMPLATE>: U1:%x[-1]_%x[0]|%y[0]\n\n"
+
+ << "%x[k] and %y[k] are macros to be instantiated with an input\n"
+ << "token (for x) or a label (for y). k specifies the relative\n"
+ << "location of the input token or label with respect to the current\n"
+ << "position. For x, k is an integer value. For y, k must be 0 (to\n"
+ << "be extended).\n\n";
+
+ abort();
+}
+
+void RuleContextFeatures::ParseArgs(const string& in) {
+ vector<string> const& argv = SplitOnWhitespace(in);
+ for (vector<string>::const_iterator i = argv.begin(); i != argv.end(); ++i) {
+ string const& s = *i;
+ if (s[0] == '-') {
+ if (s.size() > 2) {
+ stringstream msg;
+ msg << s << " is an invalid option for RuleContextFeatures.";
+ Error(msg.str());
+ }
+
+ switch (s[1]) {
+
+ // feature template
+ case 't': {
+ if (++i == argv.end()) {
+ Error("Can't find template.");
+ }
+ feature_template = *i;
+ string::const_iterator start = feature_template.begin();
+ string::const_iterator end = feature_template.end();
+ smatch macro_match;
+
+ // parse the template
+ while (regex_search(start, end, macro_match, macro_regex)) {
+ // get the relative location
+ string relative_location_str(macro_match[2].first,
+ macro_match[2].second);
+ int relative_location = atoi(relative_location_str.c_str());
+ // add it to the list of relative locations for token or label
+ // (i.e. x or y)
+ bool valid_location = true;
+ if (*macro_match[1].first == 'x') {
+ // add it to token locations
+ token_relative_locations.push_back(relative_location);
+ } else {
+ if (relative_location != 0) { valid_location = false; }
+ // add it to label locations
+ label_relative_locations.push_back(relative_location);
+ }
+ if (!valid_location) {
+ stringstream msg;
+ msg << "Relative location " << relative_location
+ << " in feature template " << feature_template
+ << " is invalid.";
+ Error(msg.str());
+ }
+ start = macro_match[0].second;
+ }
+ break;
+ }
+
+ // TODO: arguments to specify kSOS and kEOS
+
+ default: {
+ stringstream msg;
+ msg << "Invalid option on RuleContextFeatures: " << s;
+ Error(msg.str());
+ break;
+ }
+ } // end of switch
+ } // end of if (token starts with hyphen)
+ } // end of for loop (over arguments)
+
+ // the -t (i.e. template) option is mandatory in this feature function
+ if (label_relative_locations.size() == 0 ||
+ token_relative_locations.size() == 0) {
+ stringstream msg;
+ msg << "Feature template must specify at least one"
+ << "token macro (e.g. x[-1]) and one label macro (e.g. y[0]).";
+ Error(msg.str());
+ }
}
void RuleContextFeatures::PrepareForInput(const SentenceMetadata& smeta) {
@@ -36,64 +163,67 @@ void RuleContextFeatures::PrepareForInput(const SentenceMetadata& smeta) {
current_input.resize(sl.size());
for (unsigned i = 0; i < sl.size(); ++i) {
if (sl[i].size() != 1) {
- cerr << "Context features not supported with lattice inputs!\nid=" << smeta.GetSentenceId() << endl;
- abort();
+ stringstream msg;
+ msg << "RuleContextFeatures don't support lattice inputs!\nid="
+ << smeta.GetSentenceId() << endl;
+ Error(msg.str());
}
current_input[i] = sl[i][0].label;
}
}
-void RuleContextFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
- const Hypergraph::Edge& edge,
- const vector<const void*>& ant_contexts,
- SparseVector<double>* features,
- SparseVector<double>* estimated_features,
- void* context) const {
+void RuleContextFeatures::TraversalFeaturesImpl(
+ const SentenceMetadata& smeta, const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_contexts, SparseVector<double>* features,
+ SparseVector<double>* estimated_features, void* context) const {
+
const TRule& rule = *edge.rule_;
+ // arity = 0, no nonterminals
+ // size = 1, predicted label is a single token
+ if (rule.Arity() != 0 ||
+ rule.e_.size() != 1) {
+ return;
+ }
- if (rule.Arity() != 0 || // arity = 0, no nonterminals
- rule.e_.size() != 1) return; // size = 1, predicted label is a single token
-
-
- // you can see the current label "for free"
- const WordID cur_label = rule.e_[0];
- // (if you want to see more labels, you have to be very careful, and muck
- // about with contexts and ant_contexts)
-
- // but... you can look at as much of the source as you want!
- const int from_src_index = edge.i_; // start of the span in the input being labeled
- const int to_src_index = edge.j_; // end of the span in the input
- // (note: in the case of tagging the size of the spans being labeled will
- // always be 1, but in other formalisms, you can have bigger spans.)
-
- // this is the current token being labeled:
- const WordID cur_input = current_input[from_src_index];
-
- // let's get the previous token in the input (may be to the left of the start
- // of the sentence!)
- WordID prev_input = kSOS;
- if (from_src_index > 0) { prev_input = current_input[from_src_index - 1]; }
- // let's get the next token (may be to the left of the start of the sentence!)
- WordID next_input = kEOS;
- if (to_src_index < current_input.size()) { next_input = current_input[to_src_index]; }
-
- // now, build a feature string
- ostringstream os;
- // TD::Convert converts from the internal integer representation of a token
- // to the actual token
- os << "C1:" << TD::Convert(prev_input) << '_'
- << TD::Convert(cur_input) << '|' << TD::Convert(cur_label);
- // C1 is just to prevent a name clash
-
- // pick a value
- double fval = 1.0; // can be any real value
-
- // add it to the feature vector FD::Convert converts the feature string to a
- // feature int, Escape makes sure the feature string doesn't have any bad
- // symbols that could confuse a parser somewhere
- features->add_value(FD::Convert(Escape(os.str())), fval);
- // that's it!
-
- // create more features if you like...
-}
+ // replace label macros with actual label strings
+ // NOTE: currently, this feature function doesn't allow any label
+ // macros except %y[0], but you can look at as much of the source as you want.
+ const WordID y0 = rule.e_[0];
+ string y0_str = TD::Convert(y0);
+
+ // start of the span in the input being labeled
+ const int from_src_index = edge.i_;
+ // end of the span in the input
+ const int to_src_index = edge.j_;
+
+ // in the case of tagging the size of the spans being labeled will
+ // always be 1, but in other formalisms, you can have bigger spans
+ if (to_src_index - from_src_index != 1) {
+ cerr << "RuleContextFeatures doesn't support input spans of length != 1";
+ abort();
+ }
+ string feature_instance = feature_template;
+ // replace token macros with actual token strings
+ for (unsigned i = 0; i < token_relative_locations.size(); ++i) {
+ int loc = token_relative_locations[i];
+ WordID x = loc < 0 ? kSOS : kEOS;
+ if (from_src_index + loc >= 0 &&
+ from_src_index + loc < current_input.size()) {
+ x = current_input[from_src_index + loc];
+ }
+ string x_str = TD::Convert(x);
+ ReplaceTokenMacroWithString(feature_instance, loc, x_str);
+ }
+
+ ReplaceLabelMacroWithString(feature_instance, 0, y0_str);
+
+ // pick a real value for this feature
+ double fval = 1.0;
+
+ // add it to the feature vector
+ // FD::Convert converts the feature string to a feature int
+ // Escape makes sure the feature string doesn't have any bad
+ // symbols that could confuse a parser somewhere
+ features->add_value(FD::Convert(Escape(feature_instance)), fval);
+}
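For intuition, here is a minimal standalone sketch (not part of this commit) of the substitution that ReplaceMacroWithString performs on the example template; the helper name Replace and the sample tokens are hypothetical.

// Standalone sketch: instantiate the example template with made-up
// tokens to show what a final feature string looks like.
#include <iostream>
#include <sstream>
#include <string>

// Replace one %x[k] or %y[k] macro in 'feature' with 'token'.
static void Replace(std::string& feature, char xy, int k,
                    const std::string& token) {
  std::ostringstream macro;
  macro << '%' << xy << '[' << k << ']';
  std::string::size_type pos = feature.find(macro.str());
  if (pos != std::string::npos)
    feature.replace(pos, macro.str().size(), token);
}

int main() {
  std::string feature = "U1:%x[-1]_%x[0]|%y[0]";
  Replace(feature, 'x', -1, "<s>");   // token to the left of the span
  Replace(feature, 'x',  0, "John");  // token being labeled
  Replace(feature, 'y',  0, "NNP");   // predicted label
  std::cout << feature << std::endl;  // prints U1:<s>_John|NNP
  return 0;
}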
diff --git a/decoder/ff_context.h b/decoder/ff_context.h
index 0d22b027..89bcb557 100644
--- a/decoder/ff_context.h
+++ b/decoder/ff_context.h
@@ -1,23 +1,46 @@
+
#ifndef _FF_CONTEXT_H_
#define _FF_CONTEXT_H_
#include <vector>
+#include <boost/xpressive/xpressive.hpp>
#include "ff.h"
+using namespace boost::xpressive;
+using namespace std;
+
class RuleContextFeatures : public FeatureFunction {
public:
- RuleContextFeatures(const std::string& param);
+ RuleContextFeatures(const string& param);
protected:
virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
const Hypergraph::Edge& edge,
- const std::vector<const void*>& ant_contexts,
+ const vector<const void*>& ant_contexts,
SparseVector<double>* features,
SparseVector<double>* estimated_features,
void* context) const;
virtual void PrepareForInput(const SentenceMetadata& smeta);
+ virtual void ParseArgs(const string& in);
+ virtual string Escape(const string& x) const;
+ virtual void ReplaceMacroWithString(string& feature_instance,
+ bool token_vs_label,
+ int relative_location,
+ const string& actual_token) const;
+ virtual void ReplaceTokenMacroWithString(string& feature_instance,
+ int relative_location,
+ const string& actual_token) const;
+ virtual void ReplaceLabelMacroWithString(string& feature_instance,
+ int relative_location,
+ const string& actual_token) const;
+ virtual void Error(const string&) const;
+
private:
- std::vector<WordID> current_input;
+ vector<int> token_relative_locations, label_relative_locations;
+ string feature_template;
+ vector<WordID> current_input;
WordID kSOS, kEOS;
+ sregex macro_regex;
+
};
#endif
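As a sanity check of the template parsing above, this standalone sketch (not part of this commit) runs the same boost::xpressive pattern as macro_regex over the example template and prints each macro it finds.

// Standalone sketch: enumerate the %x[k]/%y[k] macros in the example
// template using the same pattern as RuleContextFeatures::macro_regex.
#include <iostream>
#include <string>
#include <boost/xpressive/xpressive.hpp>

int main() {
  using namespace boost::xpressive;
  sregex macro_regex =
      sregex::compile("%([xy])\\[(-[1-9][0-9]*|0|[1-9][0-9]*)]");
  std::string tmpl = "U1:%x[-1]_%x[0]|%y[0]";
  std::string::const_iterator start = tmpl.begin(), end = tmpl.end();
  smatch m;
  while (regex_search(start, end, m, macro_regex)) {
    // m[1] is 'x' or 'y'; m[2] is the relative offset k
    std::cout << m[1].str() << " at offset " << m[2].str() << std::endl;
    start = m[0].second;
  }
  return 0;  // prints: x at offset -1, x at offset 0, y at offset 0
}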