#include "ff_context.h"

#include <stdlib.h>
#include <sstream>
#include <cassert>
#include <cmath>

#include "hg.h"
#include "filelib.h"
#include "stringlib.h"
#include "sentence_metadata.h"
#include "lattice.h"
#include "fdict.h"
#include "verbose.h"
#include "tdict.h"

RuleContextFeatures::RuleContextFeatures(const string& param) {
  //  cerr << "initializing RuleContextFeatures with parameters: " << param;
  kSOS = TD::Convert("<s>");
  kEOS = TD::Convert("</s>");
  macro_regex = sregex::compile("%([xy])\\[(-[1-9][0-9]*|0|[1-9][1-9]*)]");
  ParseArgs(param);
}

string RuleContextFeatures::Escape(const string& x) const {
  string y = x;
  for (int i = 0; i < y.size(); ++i) {
    if (y[i] == '=') y[i]='_';
    if (y[i] == ';') y[i]='_';
  }
  return y;
}

// replace %x[relative_location] or %y[relative_location] with actual_token
// within feature_instance
void RuleContextFeatures::ReplaceMacroWithString(
  string& feature_instance, bool token_vs_label, int relative_location, 
  const string& actual_token) const {

  stringstream macro;
  if (token_vs_label) {
    macro << "%x[";
  } else {
    macro << "%y[";
  }
  macro << relative_location << "]";
  int macro_index = feature_instance.find(macro.str());
  if (macro_index == string::npos) {
    cerr << "Can't find macro " << macro << " in feature template " 
	 << feature_instance;
    abort();
  }
  feature_instance.replace(macro_index, macro.str().size(), actual_token);
}

void RuleContextFeatures::ReplaceTokenMacroWithString(
  string& feature_instance, int relative_location, 
  const string& actual_token) const {

  ReplaceMacroWithString(feature_instance, true, relative_location, 
				actual_token);
}

void RuleContextFeatures::ReplaceLabelMacroWithString(
  string& feature_instance, int relative_location, 
  const string& actual_token) const {

  ReplaceMacroWithString(feature_instance, false, relative_location, 
				actual_token);
}

void RuleContextFeatures::Error(const string& error_message) const {
  cerr << "Error: " << error_message << "\n\n"

       << "RuleContextFeatures Usage: \n"			     
       << "  feature_function=RuleContextFeatures -t <TEMPLATE>\n\n" 

       << "Example <TEMPLATE>: U1:%x[-1]_%x[0]|%y[0]\n\n" 

       << "%x[k] and %y[k] are macros to be instantiated with an input\n" 
       << "token (for x) or a label (for y). k specifies the relative\n" 
       << "location of the input token or label with respect to the current\n"
       << "position. For x, k is an integer value. For y, k must be 0 (to\n" 
       << "be extended).\n\n";

  abort();
}

void RuleContextFeatures::ParseArgs(const string& in) {
  vector<string> const& argv = SplitOnWhitespace(in);
  for (vector<string>::const_iterator i = argv.begin(); i != argv.end(); ++i) {
    string const& s = *i;
    if (s[0] == '-') {
      if (s.size() > 2) { 
	stringstream msg;
	msg << s << " is an invalid option for RuleContextFeatures.";
	Error(msg.str());
      }

      switch (s[1]) {

      // feature template
      case 't': {
	if (++i == argv.end()) { 
	  Error("Can't find template."); 
	}
	feature_template = *i;
	string::const_iterator start = feature_template.begin();
	string::const_iterator end = feature_template.end();
	smatch macro_match;

        // parse the template
	while (regex_search(start, end, macro_match, macro_regex)) {
	  // get the relative location
	  string relative_location_str(macro_match[2].first, 
				       macro_match[2].second);
	  int relative_location = atoi(relative_location_str.c_str());
	  // add it to the list of relative locations for token or label
	  // (i.e. x or y)
	  bool valid_location = true;
          if (*macro_match[1].first == 'x') {
	    // add it to token locations
	    token_relative_locations.push_back(relative_location);
	  } else {
	    if (relative_location != 0) { valid_location = false; }
	    // add it to label locations
	    label_relative_locations.push_back(relative_location);
	  }
	  if (!valid_location) {
	    stringstream msg;
	    msg << "Relative location " << relative_location
		<< " in feature template " << feature_template
		<< " is invalid.";
	    Error(msg.str());
	  }
	  start = macro_match[0].second;
	}
	break;
      }

      // TODO: arguments to specify kSOS and kEOS

      default: {
	stringstream msg;
	msg << "Invalid option on RuleContextFeatures: " << s;
	Error(msg.str());
	break;
      }
      } // end of switch
    } // end of if (token starts with hyphen)
  } // end of for loop (over arguments)

  // the -t (i.e. template) option is mandatory in this feature function
  if (label_relative_locations.size() == 0 || 
      token_relative_locations.size() == 0) {
    stringstream msg;
    msg << "Feature template must specify at least one" 
	 << "token macro (e.g. x[-1]) and one label macro (e.g. y[0]).";
    Error(msg.str());
  }
}

void RuleContextFeatures::PrepareForInput(const SentenceMetadata& smeta) {
  const Lattice& sl = smeta.GetSourceLattice();
  current_input.resize(sl.size());
  for (unsigned i = 0; i < sl.size(); ++i) {
    if (sl[i].size() != 1) {
      stringstream msg;
      msg << "RuleContextFeatures don't support lattice inputs!\nid=" 
	   << smeta.GetSentenceId() << endl;
      Error(msg.str());
    }
    current_input[i] = sl[i][0].label;
  }
}

void RuleContextFeatures::TraversalFeaturesImpl(
  const SentenceMetadata& smeta, const Hypergraph::Edge& edge,
  const vector<const void*>& ant_contexts, SparseVector<double>* features,
  SparseVector<double>* estimated_features, void* context) const {

  const TRule& rule = *edge.rule_;
  // arity = 0, no nonterminals
  // size = 1, predicted label is a single token
  if (rule.Arity() != 0 ||
      rule.e_.size() != 1) { 
    return; 
  }

  // replace label macros with actual label strings
  // NOTE: currently, this feature function doesn't allow any label
  // macros except %y[0]. but you can look at as much of the source as you want
  const WordID y0 = rule.e_[0];
  string y0_str = TD::Convert(y0);

  // start of the span in the input being labeled
  const int from_src_index = edge.i_;   
  // end of the span in the input
  const int to_src_index = edge.j_;

  // in the case of tagging the size of the spans being labeled will
  //  always be 1, but in other formalisms, you can have bigger spans
  if (to_src_index - from_src_index != 1) {
    cerr << "RuleContextFeatures doesn't support input spans of length != 1";
    abort();
  }

  string feature_instance = feature_template;
  // replace token macros with actual token strings
  for (unsigned i = 0; i < token_relative_locations.size(); ++i) {
    int loc = token_relative_locations[i];
    WordID x = loc < 0? kSOS: kEOS;
    if(from_src_index + loc >= 0 && 
       from_src_index + loc < current_input.size()) {
      x = current_input[from_src_index + loc];
    }
    string x_str = TD::Convert(x);
    ReplaceTokenMacroWithString(feature_instance, loc, x_str);
  }
    
  ReplaceLabelMacroWithString(feature_instance, 0, y0_str);

  // pick a real value for this feature
  double fval = 1.0; 

  // add it to the feature vector
  //   FD::Convert converts the feature string to a feature int
  //   Escape makes sure the feature string doesn't have any bad
  //     symbols that could confuse a parser somewhere
  features->add_value(FD::Convert(Escape(feature_instance)), fval);
}