summaryrefslogtreecommitdiff
path: root/decoder/ff_context.cc
blob: e56f6f1f9024a29213a7cecf1d40329716b792c0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
#include "ff_context.h"

#include <stdlib.h>
#include <sstream>
#include <cassert>
#include <cmath>

#include "hg.h"
#include "filelib.h"
#include "stringlib.h"
#include "sentence_metadata.h"
#include "lattice.h"
#include "fdict.h"
#include "verbose.h"
#include "tdict.h"

RuleContextFeatures::RuleContextFeatures(const string& param) {
  //  cerr << "initializing RuleContextFeatures with parameters: " << param;
  kSOS = TD::Convert("<s>");
  kEOS = TD::Convert("</s>");
  macro_regex = sregex::compile("%([xy])\\[(-[1-9][0-9]*|0|[1-9][1-9]*)]");
  ParseArgs(param);
}

string RuleContextFeatures::Escape(const string& x) const {
  string y = x;
  for (int i = 0; i < y.size(); ++i) {
    if (y[i] == '=') y[i]='_';
    if (y[i] == ';') y[i]='_';
  }
  return y;
}

// replace %x[relative_location] or %y[relative_location] with actual_token
// within feature_instance
void RuleContextFeatures::ReplaceMacroWithString(
  string& feature_instance, bool token_vs_label, int relative_location, 
  const string& actual_token) const {

  stringstream macro;
  if (token_vs_label) {
    macro << "%x[";
  } else {
    macro << "%y[";
  }
  macro << relative_location << "]";
  int macro_index = feature_instance.find(macro.str());
  if (macro_index == string::npos) {
    cerr << "Can't find macro " << macro.str() << " in feature template " 
	 << feature_instance;
    abort();
  }
  feature_instance.replace(macro_index, macro.str().size(), actual_token);
}

void RuleContextFeatures::ReplaceTokenMacroWithString(
  string& feature_instance, int relative_location, 
  const string& actual_token) const {

  ReplaceMacroWithString(feature_instance, true, relative_location, 
				actual_token);
}

void RuleContextFeatures::ReplaceLabelMacroWithString(
  string& feature_instance, int relative_location, 
  const string& actual_token) const {

  ReplaceMacroWithString(feature_instance, false, relative_location, 
				actual_token);
}

void RuleContextFeatures::Error(const string& error_message) const {
  cerr << "Error: " << error_message << "\n\n"

       << "RuleContextFeatures Usage: \n"			     
       << "  feature_function=RuleContextFeatures -t <TEMPLATE>\n\n" 

       << "Example <TEMPLATE>: U1:%x[-1]_%x[0]|%y[0]\n\n" 

       << "%x[k] and %y[k] are macros to be instantiated with an input\n" 
       << "token (for x) or a label (for y). k specifies the relative\n" 
       << "location of the input token or label with respect to the current\n"
       << "position. For x, k is an integer value. For y, k must be 0 (to\n" 
       << "be extended).\n\n";

  abort();
}

void RuleContextFeatures::ParseArgs(const string& in) {
  vector<string> const& argv = SplitOnWhitespace(in);
  for (vector<string>::const_iterator i = argv.begin(); i != argv.end(); ++i) {
    string const& s = *i;
    if (s[0] == '-') {
      if (s.size() > 2) { 
	stringstream msg;
	msg << s << " is an invalid option for RuleContextFeatures.";
	Error(msg.str());
      }

      switch (s[1]) {

      // feature template
      case 't': {
	if (++i == argv.end()) { 
	  Error("Can't find template."); 
	}
	feature_template = *i;
	string::const_iterator start = feature_template.begin();
	string::const_iterator end = feature_template.end();
	smatch macro_match;

        // parse the template
	while (regex_search(start, end, macro_match, macro_regex)) {
	  // get the relative location
	  string relative_location_str(macro_match[2].first, 
				       macro_match[2].second);
	  int relative_location = atoi(relative_location_str.c_str());
	  // add it to the list of relative locations for token or label
	  // (i.e. x or y)
	  bool valid_location = true;
          if (*macro_match[1].first == 'x') {
	    // add it to token locations
	    token_relative_locations.push_back(relative_location);
	  } else {
	    if (relative_location != 0) { valid_location = false; }
	    // add it to label locations
	    label_relative_locations.push_back(relative_location);
	  }
	  if (!valid_location) {
	    stringstream msg;
	    msg << "Relative location " << relative_location
		<< " in feature template " << feature_template
		<< " is invalid.";
	    Error(msg.str());
	  }
	  start = macro_match[0].second;
	}
	break;
      }

      // TODO: arguments to specify kSOS and kEOS

      default: {
	stringstream msg;
	msg << "Invalid option on RuleContextFeatures: " << s;
	Error(msg.str());
	break;
      }
      } // end of switch
    } // end of if (token starts with hyphen)
  } // end of for loop (over arguments)

  // the -t (i.e. template) option is mandatory in this feature function
  if (label_relative_locations.size() == 0 || 
      token_relative_locations.size() == 0) {
    stringstream msg;
    msg << "Feature template must specify at least one" 
	 << "token macro (e.g. x[-1]) and one label macro (e.g. y[0]).";
    Error(msg.str());
  }
}

void RuleContextFeatures::PrepareForInput(const SentenceMetadata& smeta) {
  const Lattice& sl = smeta.GetSourceLattice();
  current_input.resize(sl.size());
  for (unsigned i = 0; i < sl.size(); ++i) {
    if (sl[i].size() != 1) {
      stringstream msg;
      msg << "RuleContextFeatures don't support lattice inputs!\nid=" 
	   << smeta.GetSentenceId() << endl;
      Error(msg.str());
    }
    current_input[i] = sl[i][0].label;
  }
}

void RuleContextFeatures::TraversalFeaturesImpl(
  const SentenceMetadata& smeta, const Hypergraph::Edge& edge,
  const vector<const void*>& ant_contexts, SparseVector<double>* features,
  SparseVector<double>* estimated_features, void* context) const {

  const TRule& rule = *edge.rule_;
  // arity = 0, no nonterminals
  // size = 1, predicted label is a single token
  if (rule.Arity() != 0 ||
      rule.e_.size() != 1) { 
    return; 
  }

  // replace label macros with actual label strings
  // NOTE: currently, this feature function doesn't allow any label
  // macros except %y[0]. but you can look at as much of the source as you want
  const WordID y0 = rule.e_[0];
  string y0_str = TD::Convert(y0);

  // start of the span in the input being labeled
  const int from_src_index = edge.i_;   
  // end of the span in the input
  const int to_src_index = edge.j_;

  // in the case of tagging the size of the spans being labeled will
  //  always be 1, but in other formalisms, you can have bigger spans
  if (to_src_index - from_src_index != 1) {
    cerr << "RuleContextFeatures doesn't support input spans of length != 1";
    abort();
  }

  string feature_instance = feature_template;
  // replace token macros with actual token strings
  for (unsigned i = 0; i < token_relative_locations.size(); ++i) {
    int loc = token_relative_locations[i];
    WordID x = loc < 0? kSOS: kEOS;
    if(from_src_index + loc >= 0 && 
       from_src_index + loc < current_input.size()) {
      x = current_input[from_src_index + loc];
    }
    string x_str = TD::Convert(x);
    ReplaceTokenMacroWithString(feature_instance, loc, x_str);
  }
    
  ReplaceLabelMacroWithString(feature_instance, 0, y0_str);

  // pick a real value for this feature
  double fval = 1.0; 

  // add it to the feature vector
  //   FD::Convert converts the feature string to a feature int
  //   Escape makes sure the feature string doesn't have any bad
  //     symbols that could confuse a parser somewhere
  features->add_value(FD::Convert(Escape(feature_instance)), fval);
}