diff options
author | Waleed Ammar <wammar@cs.cmu.edu> | 2012-06-16 05:35:49 -0400 |
---|---|---|
committer | Waleed Ammar <wammar@cs.cmu.edu> | 2012-06-16 05:35:49 -0400 |
commit | 21bed034628ae5e0e51a5dae6db0427ec1fd2e47 (patch) | |
tree | 527fab1b2677e8db3210692eb8bd8e56b2ffa371 /decoder/ff_context.cc | |
parent | 3acdf1e4b37637d6df86a7b54fb0f1b0464c172b (diff) |
enable regex-based feature templates
Diffstat (limited to 'decoder/ff_context.cc')
-rw-r--r-- | decoder/ff_context.cc | 262 |
1 files changed, 196 insertions, 66 deletions
diff --git a/decoder/ff_context.cc b/decoder/ff_context.cc index 19f9a413..9de4d737 100644 --- a/decoder/ff_context.cc +++ b/decoder/ff_context.cc @@ -1,5 +1,6 @@ #include "ff_context.h" +#include <stdlib.h> #include <sstream> #include <cassert> #include <cmath> @@ -11,24 +12,150 @@ #include "fdict.h" #include "verbose.h" -using namespace std; +RuleContextFeatures::RuleContextFeatures(const string& param) { + // cerr << "initializing RuleContextFeatures with parameters: " << param; + kSOS = TD::Convert("<s>"); + kEOS = TD::Convert("</s>"); + macro_regex = sregex::compile("%([xy])\\[(-[1-9][0-9]*|0|[1-9][1-9]*)]"); + ParseArgs(param); +} -namespace { - string Escape(const string& x) { - string y = x; - for (int i = 0; i < y.size(); ++i) { - if (y[i] == '=') y[i]='_'; - if (y[i] == ';') y[i]='_'; - } - return y; +string RuleContextFeatures::Escape(const string& x) const { + string y = x; + for (int i = 0; i < y.size(); ++i) { + if (y[i] == '=') y[i]='_'; + if (y[i] == ';') y[i]='_'; } + return y; } -RuleContextFeatures::RuleContextFeatures(const std::string& param) { - kSOS = TD::Convert("<s>"); - kEOS = TD::Convert("</s>"); +// replace %x[relative_location] or %y[relative_location] with actual_token +// within feature_instance +void RuleContextFeatures::ReplaceMacroWithString( + string& feature_instance, bool token_vs_label, int relative_location, + const string& actual_token) const { + + stringstream macro; + if (token_vs_label) { + macro << "%x["; + } else { + macro << "%y["; + } + macro << relative_location << "]"; + int macro_index = feature_instance.find(macro.str()); + if (macro_index == string::npos) { + cerr << "Can't find macro " << macro << " in feature template " + << feature_instance; + abort(); + } + feature_instance.replace(macro_index, macro.str().size(), actual_token); +} + +void RuleContextFeatures::ReplaceTokenMacroWithString( + string& feature_instance, int relative_location, + const string& actual_token) const { + + ReplaceMacroWithString(feature_instance, true, relative_location, + actual_token); +} - // TODO param lets you pass in a string from the cdec.ini file +void RuleContextFeatures::ReplaceLabelMacroWithString( + string& feature_instance, int relative_location, + const string& actual_token) const { + + ReplaceMacroWithString(feature_instance, false, relative_location, + actual_token); +} + +void RuleContextFeatures::Error(const string& error_message) const { + cerr << "Error: " << error_message << "\n\n" + + << "RuleContextFeatures Usage: \n" + << " feature_function=RuleContextFeatures -t <TEMPLATE>\n\n" + + << "Example <TEMPLATE>: U1:%x[-1]_%x[0]|%y[0]\n\n" + + << "%x[k] and %y[k] are macros to be instantiated with an input\n" + << "token (for x) or a label (for y). k specifies the relative\n" + << "location of the input token or label with respect to the current\n" + << "position. For x, k is an integer value. For y, k must be 0 (to\n" + << "be extended).\n\n"; + + abort(); +} + +void RuleContextFeatures::ParseArgs(const string& in) { + vector<string> const& argv = SplitOnWhitespace(in); + for (vector<string>::const_iterator i = argv.begin(); i != argv.end(); ++i) { + string const& s = *i; + if (s[0] == '-') { + if (s.size() > 2) { + stringstream msg; + msg << s << " is an invalid option for RuleContextFeatures."; + Error(msg.str()); + } + + switch (s[1]) { + + // feature template + case 't': { + if (++i == argv.end()) { + Error("Can't find template."); + } + feature_template = *i; + string::const_iterator start = feature_template.begin(); + string::const_iterator end = feature_template.end(); + smatch macro_match; + + // parse the template + while (regex_search(start, end, macro_match, macro_regex)) { + // get the relative location + string relative_location_str(macro_match[2].first, + macro_match[2].second); + int relative_location = atoi(relative_location_str.c_str()); + // add it to the list of relative locations for token or label + // (i.e. x or y) + bool valid_location = true; + if (*macro_match[1].first == 'x') { + // add it to token locations + token_relative_locations.push_back(relative_location); + } else { + if (relative_location != 0) { valid_location = false; } + // add it to label locations + label_relative_locations.push_back(relative_location); + } + if (!valid_location) { + stringstream msg; + msg << "Relative location " << relative_location + << " in feature template " << feature_template + << " is invalid."; + Error(msg.str()); + } + start = macro_match[0].second; + } + break; + } + + // TODO: arguments to specify kSOS and kEOS + + default: { + stringstream msg; + msg << "Invalid option on RuleContextFeatures: " << s; + Error(msg.str()); + break; + } + } // end of switch + } // end of if (token starts with hyphen) + } // end of for loop (over arguments) + + // the -t (i.e. template) option is mandatory in this feature function + if (label_relative_locations.size() == 0 || + token_relative_locations.size() == 0) { + stringstream msg; + msg << "Feature template must specify at least one" + << "token macro (e.g. x[-1]) and one label macro (e.g. y[0])."; + Error(msg.str()); + } } void RuleContextFeatures::PrepareForInput(const SentenceMetadata& smeta) { @@ -36,64 +163,67 @@ void RuleContextFeatures::PrepareForInput(const SentenceMetadata& smeta) { current_input.resize(sl.size()); for (unsigned i = 0; i < sl.size(); ++i) { if (sl[i].size() != 1) { - cerr << "Context features not supported with lattice inputs!\nid=" << smeta.GetSentenceId() << endl; - abort(); + stringstream msg; + msg << "RuleContextFeatures don't support lattice inputs!\nid=" + << smeta.GetSentenceId() << endl; + Error(msg.str()); } current_input[i] = sl[i][0].label; } } -void RuleContextFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, - const vector<const void*>& ant_contexts, - SparseVector<double>* features, - SparseVector<double>* estimated_features, - void* context) const { +void RuleContextFeatures::TraversalFeaturesImpl( + const SentenceMetadata& smeta, const Hypergraph::Edge& edge, + const vector<const void*>& ant_contexts, SparseVector<double>* features, + SparseVector<double>* estimated_features, void* context) const { + const TRule& rule = *edge.rule_; + // arity = 0, no nonterminals + // size = 1, predicted label is a single token + if (rule.Arity() != 0 || + rule.e_.size() != 1) { + return; + } - if (rule.Arity() != 0 || // arity = 0, no nonterminals - rule.e_.size() != 1) return; // size = 1, predicted label is a single token - - - // you can see the current label "for free" - const WordID cur_label = rule.e_[0]; - // (if you want to see more labels, you have to be very careful, and muck - // about with contexts and ant_contexts) - - // but... you can look at as much of the source as you want! - const int from_src_index = edge.i_; // start of the span in the input being labeled - const int to_src_index = edge.j_; // end of the span in the input - // (note: in the case of tagging the size of the spans being labeled will - // always be 1, but in other formalisms, you can have bigger spans.) - - // this is the current token being labeled: - const WordID cur_input = current_input[from_src_index]; - - // let's get the previous token in the input (may be to the left of the start - // of the sentence!) - WordID prev_input = kSOS; - if (from_src_index > 0) { prev_input = current_input[from_src_index - 1]; } - // let's get the next token (may be to the left of the start of the sentence!) - WordID next_input = kEOS; - if (to_src_index < current_input.size()) { next_input = current_input[to_src_index]; } - - // now, build a feature string - ostringstream os; - // TD::Convert converts from the internal integer representation of a token - // to the actual token - os << "C1:" << TD::Convert(prev_input) << '_' - << TD::Convert(cur_input) << '|' << TD::Convert(cur_label); - // C1 is just to prevent a name clash - - // pick a value - double fval = 1.0; // can be any real value - - // add it to the feature vector FD::Convert converts the feature string to a - // feature int, Escape makes sure the feature string doesn't have any bad - // symbols that could confuse a parser somewhere - features->add_value(FD::Convert(Escape(os.str())), fval); - // that's it! - - // create more features if you like... -} + // replace label macros with actual label strings + // NOTE: currently, this feature function doesn't allow any label + // macros except %y[0]. but you can look at as much of the source as you want + const WordID y0 = rule.e_[0]; + string y0_str = TD::Convert(y0); + + // start of the span in the input being labeled + const int from_src_index = edge.i_; + // end of the span in the input + const int to_src_index = edge.j_; + + // in the case of tagging the size of the spans being labeled will + // always be 1, but in other formalisms, you can have bigger spans + if (to_src_index - from_src_index != 1) { + cerr << "RuleContextFeatures doesn't support input spans of length != 1"; + abort(); + } + string feature_instance = feature_template; + // replace token macros with actual token strings + for (unsigned i = 0; i < token_relative_locations.size(); ++i) { + int loc = token_relative_locations[i]; + WordID x = loc < 0? kSOS: kEOS; + if(from_src_index + loc >= 0 && + from_src_index + loc < current_input.size()) { + x = current_input[from_src_index + loc]; + } + string x_str = TD::Convert(x); + ReplaceTokenMacroWithString(feature_instance, loc, x_str); + } + + ReplaceLabelMacroWithString(feature_instance, 0, y0_str); + + // pick a real value for this feature + double fval = 1.0; + + // add it to the feature vector + // FD::Convert converts the feature string to a feature int + // Escape makes sure the feature string doesn't have any bad + // symbols that could confuse a parser somewhere + features->add_value(FD::Convert(Escape(feature_instance)), fval); +} |