#include "rule_factory.h"

#include <chrono>
#include <iostream>
#include <memory>
#include <queue>
#include <vector>

#include "grammar.h"
#include "fast_intersector.h"
#include "matchings_finder.h"
#include "phrase.h"
#include "phrase_builder.h"
#include "rule.h"
#include "rule_extractor.h"
#include "phrase_location_sampler.h"
#include "sampler.h"
#include "scorer.h"
#include "suffix_array.h"
#include "time_util.h"
#include "vocabulary.h"
#include "data_array.h"

using namespace std;
using namespace chrono;

namespace extractor {

typedef high_resolution_clock Clock;

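// Bookkeeping for a state in the breadth-first search: the span of the
// phrase in the sentence (end marks the position of the next symbol to be
// added), the symbols collected so far, the starting positions of the
// phrase's contiguous subpatterns, the trie node caching the phrase's
// matchings, and a flag set for phrases beginning with a nonterminal.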
struct State {
  State(int start, int end, const vector<int>& phrase,
      const vector<int>& subpatterns_start, shared_ptr<TrieNode> node,
      bool starts_with_x) :
      start(start), end(end), phrase(phrase),
      subpatterns_start(subpatterns_start), node(node),
      starts_with_x(starts_with_x) {}

  int start, end;
  vector<int> phrase, subpatterns_start;
  shared_ptr<TrieNode> node;
  bool starts_with_x;
};

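// Sets up the rule factory together with its collaborators: a matchings
// finder and fast intersector over the source suffix array, a phrase
// builder, a rule extractor and a phrase location sampler.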
HieroCachingRuleFactory::HieroCachingRuleFactory(
    shared_ptr<SuffixArray> source_suffix_array,
    shared_ptr<DataArray> target_data_array,
    shared_ptr<Alignment> alignment,
    const shared_ptr<Vocabulary>& vocabulary,
    shared_ptr<Precomputation> precomputation,
    shared_ptr<Scorer> scorer,
    int min_gap_size,
    int max_rule_span,
    int max_nonterminals,
    int max_rule_symbols,
    int max_samples,
    bool require_tight_phrases) :
    vocabulary(vocabulary),
    scorer(scorer),
    min_gap_size(min_gap_size),
    max_rule_span(max_rule_span),
    max_nonterminals(max_nonterminals),
    max_chunks(max_nonterminals + 1),
    max_rule_symbols(max_rule_symbols) {
  matchings_finder = make_shared<MatchingsFinder>(source_suffix_array);
  fast_intersector = make_shared<FastIntersector>(source_suffix_array,
      precomputation, vocabulary, max_rule_span, min_gap_size);
  phrase_builder = make_shared<PhraseBuilder>(vocabulary);
  rule_extractor = make_shared<RuleExtractor>(source_suffix_array->GetData(),
      target_data_array, alignment, phrase_builder, scorer, vocabulary,
      max_rule_span, min_gap_size, max_nonterminals, max_rule_symbols, true,
      false, require_tight_phrases);
  sampler = make_shared<PhraseLocationSampler>(
      source_suffix_array, max_samples);
}

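// Constructor which takes all collaborators as arguments, e.g. for
// injecting mocks in tests.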
HieroCachingRuleFactory::HieroCachingRuleFactory(
    shared_ptr<MatchingsFinder> finder,
    shared_ptr<FastIntersector> fast_intersector,
    shared_ptr<PhraseBuilder> phrase_builder,
    shared_ptr<RuleExtractor> rule_extractor,
    shared_ptr<Vocabulary> vocabulary,
    shared_ptr<Sampler> sampler,
    shared_ptr<Scorer> scorer,
    int min_gap_size,
    int max_rule_span,
    int max_nonterminals,
    int max_chunks,
    int max_rule_symbols) :
    matchings_finder(finder),
    fast_intersector(fast_intersector),
    phrase_builder(phrase_builder),
    rule_extractor(rule_extractor),
    vocabulary(vocabulary),
    sampler(sampler),
    scorer(scorer),
    min_gap_size(min_gap_size),
    max_rule_span(max_rule_span),
    max_nonterminals(max_nonterminals),
    max_chunks(max_chunks),
    max_rule_symbols(max_rule_symbols) {}

HieroCachingRuleFactory::HieroCachingRuleFactory() {}

HieroCachingRuleFactory::~HieroCachingRuleFactory() {}

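// Constructs the grammar (the set of hierarchical rules applicable to the
// given sentence) by running a breadth-first search over the phrases of
// the sentence. Phrases are cached in a trie so that common prefixes share
// their matchings; occurrences from sentences in blacklisted_sentence_ids
// are skipped when sampling (e.g. for leave-one-out experiments).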
Grammar HieroCachingRuleFactory::GetGrammar(
    const vector<int>& word_ids,
    const unordered_set<int>& blacklisted_sentence_ids) {
  Clock::time_point start_time = Clock::now();
  double total_extract_time = 0;
  double total_intersect_time = 0;
  double total_lookup_time = 0;

  MatchingsTrie trie;
  shared_ptr<TrieNode> root = trie.GetRoot();

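  // Phrases beginning with a nonterminal get their own subtrie whose
  // suffix link is the root: dropping the leading nonterminal leaves the
  // empty phrase, and the leading gap does not constrain the matchings.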
  int first_x = vocabulary->GetNonterminalIndex(1);
  shared_ptr<TrieNode> x_root(new TrieNode(root));
  root->AddChild(first_x, x_root);

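  // Seed the search with all possible phrase starts: an empty phrase at
  // each position of the sentence and, wherever it fits, a phrase starting
  // with a nonterminal gap of min_gap_size words.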
  queue<State> states;
  for (size_t i = 0; i < word_ids.size(); ++i) {
    states.push(State(i, i, vector<int>(), vector<int>(1, i), root, false));
  }
  for (size_t i = min_gap_size; i < word_ids.size(); ++i) {
    states.push(State(i - min_gap_size, i, vector<int>(1, first_x),
        vector<int>(1, i), x_root, true));
  }

  vector<Rule> rules;
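  // Process the states in breadth-first order, extending each phrase by a
  // single symbol at a time and caching its matchings in the trie.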
  while (!states.empty()) {
    State state = states.front();
    states.pop();

    shared_ptr<TrieNode> node = state.node;
    vector<int> phrase = state.phrase;
    int word_id = word_ids[state.end];
    phrase.push_back(word_id);
    Phrase next_phrase = phrase_builder->Build(phrase);
    shared_ptr<TrieNode> next_node;

    if (CannotHaveMatchings(node, word_id)) {
      if (!node->HasChild(word_id)) {
        node->AddChild(word_id, shared_ptr<TrieNode>());
      }
      continue;
    }

    if (RequiresLookup(node, word_id)) {
      shared_ptr<TrieNode> next_suffix_link = node->suffix_link == nullptr ?
          trie.GetRoot() : node->suffix_link->GetChild(word_id);
      if (state.starts_with_x) {
        // If the phrase starts with a nonterminal, we simply reuse the
        // matchings from the suffix link.
        next_node = make_shared<TrieNode>(
            next_suffix_link, next_phrase, next_suffix_link->matchings);
      } else {
        PhraseLocation phrase_location;
        if (next_phrase.Arity() > 0) {
          // For phrases containing a nonterminal, we use either the occurrences
          // of the prefix or the suffix to determine the occurrences of the
          // phrase.
          Clock::time_point intersect_start = Clock::now();
          phrase_location = fast_intersector->Intersect(
              node->matchings, next_suffix_link->matchings, next_phrase);
          Clock::time_point intersect_stop = Clock::now();
          total_intersect_time += GetDuration(intersect_start, intersect_stop);
        } else {
          // For phrases not containing any nonterminals, we simply query the
          // suffix array using the suffix array range of the prefix as a
          // starting point.
          Clock::time_point lookup_start = Clock::now();
          phrase_location = matchings_finder->Find(
              node->matchings,
              vocabulary->GetTerminalValue(word_id),
              state.phrase.size());
          Clock::time_point lookup_stop = Clock::now();
          total_lookup_time += GetDuration(lookup_start, lookup_stop);
        }

        if (phrase_location.IsEmpty()) {
          continue;
        }

        // Create a new trie node to store the data about the current phrase.
        next_node = make_shared<TrieNode>(
            next_suffix_link, next_phrase, phrase_location);
      }
      // Add the new trie node to the trie cache.
      node->AddChild(word_id, next_node);

      // Automatically add a trailing nonterminal if allowed, simply copying
      // the matchings from the prefix node.
      AddTrailingNonterminal(phrase, next_phrase, next_node,
                             state.starts_with_x);

      Clock::time_point extract_start = Clock::now();
      if (!state.starts_with_x) {
        // Extract rules for the sampled set of occurrences.
        PhraseLocation sample = sampler->Sample(
            next_node->matchings, blacklisted_sentence_ids);
        vector<Rule> new_rules =
            rule_extractor->ExtractRules(next_phrase, sample);
        rules.insert(rules.end(), new_rules.begin(), new_rules.end());
      }
      Clock::time_point extract_stop = Clock::now();
      total_extract_time += GetDuration(extract_start, extract_stop);
    } else {
      next_node = node->GetChild(word_id);
    }

    // Create more states (phrases) to be analyzed.
    vector<State> new_states = ExtendState(word_ids, state, phrase, next_phrase,
                                           next_node);
    for (const State& new_state: new_states) {
      states.push(new_state);
    }
  }

  Clock::time_point stop_time = Clock::now();
  #pragma omp critical (stderr_write)
  {
    cerr << "Total time for rule lookup, extraction, and scoring = "
         << GetDuration(start_time, stop_time) << " seconds" << endl;
    cerr << "Extract time = " << total_extract_time << " seconds" << endl;
    cerr << "Intersect time = " << total_intersect_time << " seconds" << endl;
    cerr << "Lookup time = " << total_lookup_time << " seconds" << endl;
  }
  return Grammar(rules, scorer->GetFeatureNames());
}

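// Returns true if extending the node's phrase with word_id is known to
// produce no occurrences: either the extension was already attempted and
// failed (a null child is cached), or the suffix of the extended phrase
// has no occurrences itself.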
bool HieroCachingRuleFactory::CannotHaveMatchings(
    shared_ptr<TrieNode> node, int word_id) {
  if (node->HasChild(word_id) && node->GetChild(word_id) == nullptr) {
    return true;
  }

  shared_ptr<TrieNode> suffix_link = node->suffix_link;
  return suffix_link != nullptr && suffix_link->GetChild(word_id) == nullptr;
}

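// Returns true if the extended phrase has not been searched for yet, i.e.
// no child is cached for word_id.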
bool HieroCachingRuleFactory::RequiresLookup(
    shared_ptr<TrieNode> node, int word_id) {
  return !node->HasChild(word_id);
}

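// Appends a trailing nonterminal to the phrase if the maximum number of
// nonterminals has not been reached. The new node simply reuses the prefix
// node's matchings, as a trailing gap does not constrain them.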
void HieroCachingRuleFactory::AddTrailingNonterminal(
    vector<int> symbols,
    const Phrase& prefix,
    const shared_ptr<TrieNode>& prefix_node,
    bool starts_with_x) {
  if (prefix.Arity() >= max_nonterminals) {
    return;
  }

  int var_id = vocabulary->GetNonterminalIndex(prefix.Arity() + 1);
  symbols.push_back(var_id);
  Phrase var_phrase = phrase_builder->Build(symbols);

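  // Compute the index of the trailing nonterminal within the suffix
  // phrase: dropping the leading symbol renumbers the nonterminals down by
  // one if that symbol is itself a nonterminal, and leaves the indices
  // unchanged otherwise.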
  int suffix_var_id = vocabulary->GetNonterminalIndex(
      prefix.Arity() + (starts_with_x ? 0 : 1));
  shared_ptr<TrieNode> var_suffix_link =
      prefix_node->suffix_link->GetChild(suffix_var_id);

  prefix_node->AddChild(var_id, make_shared<TrieNode>(
      var_suffix_link, var_phrase, prefix_node->matchings));
}

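// Generates the successor states of the given state: one state appending
// the next word of the sentence and, if the limits on rule symbols,
// nonterminals and chunks allow it, one state for each position where the
// terminal following a new nonterminal gap may start.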
vector<State> HieroCachingRuleFactory::ExtendState(
    const vector<int>& word_ids,
    const State& state,
    vector<int> symbols,
    const Phrase& phrase,
    const shared_ptr<TrieNode>& node) {
  int span = state.end - state.start;
  vector<State> new_states;
  if (symbols.size() >= max_rule_symbols || state.end + 1 >= word_ids.size() ||
      span >= max_rule_span) {
    return new_states;
  }

  // New state for adding the next symbol.
  new_states.push_back(State(state.start, state.end + 1, symbols,
      state.subpatterns_start, node, state.starts_with_x));

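  // A phrase with k nonterminals has up to k + 1 contiguous chunks of
  // terminals; a leading nonterminal means there is no initial chunk.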
  int num_subpatterns = phrase.Arity() + (state.starts_with_x ? 0 : 1);
  if (symbols.size() + 1 >= max_rule_symbols ||
      phrase.Arity() >= max_nonterminals ||
      num_subpatterns >= max_chunks) {
    return new_states;
  }

  // New states for adding a nonterminal followed by a new symbol.
  int var_id = vocabulary->GetNonterminalIndex(phrase.Arity() + 1);
  symbols.push_back(var_id);
  vector<int> subpatterns_start = state.subpatterns_start;
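  // The terminal following the new nonterminal may start at any position
  // that leaves a gap of at least min_gap_size words, provided the total
  // span stays within max_rule_span.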
  size_t i = state.end + 1 + min_gap_size;
  while (i < word_ids.size() && i - state.start <= max_rule_span) {
    subpatterns_start.push_back(i);
    new_states.push_back(State(state.start, i, symbols, subpatterns_start,
        node->GetChild(var_id), state.starts_with_x));
    subpatterns_start.pop_back();
    ++i;
  }

  return new_states;
}

} // namespace extractor