#include <fstream>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <vector>

#include <boost/archive/binary_iarchive.hpp>
#include <boost/filesystem.hpp>
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
#if HAVE_OPEN_MP
  #include <omp.h>
#else
  // Single-threaded stand-in for OpenMP's omp_get_num_threads(), used when the
  // extractor is built without OpenMP support. It returns int like the real
  // OpenMP API. Internal linkage keeps the stub private to this file.
  static int omp_get_num_threads() { return 1; }
#endif

#include "alignment.h"
#include "data_array.h"
#include "features/count_source_target.h"
#include "features/feature.h"
#include "features/is_source_singleton.h"
#include "features/is_source_target_singleton.h"
#include "features/max_lex_source_given_target.h"
#include "features/max_lex_target_given_source.h"
#include "features/sample_source_count.h"
#include "features/target_given_source_coherent.h"
#include "grammar.h"
#include "grammar_extractor.h"
#include "precomputation.h"
#include "rule.h"
#include "scorer.h"
#include "suffix_array.h"
#include "time_util.h"
#include "translation_table.h"
#include "vocabulary.h"

namespace ar = boost::archive;
namespace fs = boost::filesystem;
namespace po = boost::program_options;
using namespace extractor;
using namespace features;
using namespace std;

// Builds the path of the grammar file for sentence `file_number`: a file named
// "grammar.<file_number>" inside the directory `grammar_path`.
fs::path GetGrammarFilePath(const fs::path& grammar_path, int file_number) {
  fs::path file_path = grammar_path;
  file_path /= "grammar." + to_string(file_number);
  return file_path;
}

int main(int argc, char** argv) {
  po::options_description general_options("General options");
  int max_threads = 1;
  #pragma omp parallel
  max_threads = omp_get_num_threads();
  string threads_option = "Number of threads used for grammar extraction "
                          "max(" + to_string(max_threads) + ")";
  general_options.add_options()
    ("threads,t", po::value<int>()->required()->default_value(1),
     threads_option.c_str())
    ("grammars,g", po::value<string>()->required(), "Grammars output path")
    ("max_rule_span", po::value<int>()->default_value(15),
        "Maximum rule span")
    ("max_rule_symbols", po::value<int>()->default_value(5),
        "Maximum number of symbols (terminals + nontermals) in a rule")
    ("min_gap_size", po::value<int>()->default_value(1), "Minimum gap size")
    ("max_nonterminals", po::value<int>()->default_value(2),
        "Maximum number of nonterminals in a rule")
    ("max_samples", po::value<int>()->default_value(300),
        "Maximum number of samples")
    ("tight_phrases", po::value<bool>()->default_value(true),
        "False if phrases may be loose (better, but slower)")
    ("leave_one_out", po::value<bool>()->zero_tokens(),
        "do leave-one-out estimation of grammars "
        "(e.g. for extracting grammars for the training set");

  po::options_description cmdline_options("Command line options");
  cmdline_options.add_options()
    ("help", "Show available options")
    ("config,c", po::value<string>()->required(), "Path to config file");
  cmdline_options.add(general_options);

  po::options_description config_options("Config file options");
  config_options.add_options()
    ("target", po::value<string>()->required(),
        "Path to target data file in binary format")
    ("source", po::value<string>()->required(),
        "Path to source suffix array file in binary format")
    ("alignment", po::value<string>()->required(),
        "Path to alignment file in binary format")
    ("precomputation", po::value<string>()->required(),
        "Path to precomputation file in binary format")
    ("vocabulary", po::value<string>()->required(),
        "Path to vocabulary file in binary format")
    ("ttable", po::value<string>()->required(),
        "Path to translation table in binary format");
  config_options.add(general_options);

  po::variables_map vm;
  po::store(po::parse_command_line(argc, argv, cmdline_options), vm);
  if (vm.count("help")) {
    po::options_description all_options;
    all_options.add(cmdline_options).add(config_options);
    cout << all_options << endl;
    return 0;
  }

  po::notify(vm);

  ifstream config_stream(vm["config"].as<string>());
  po::store(po::parse_config_file(config_stream, config_options), vm);
  po::notify(vm);

  int num_threads = vm["threads"].as<int>();
  cerr << "Grammar extraction will use " << num_threads << " threads." << endl;

  Clock::time_point read_start_time = Clock::now();

  Clock::time_point start_time = Clock::now();
  cerr << "Reading target data in binary format..." << endl;
  shared_ptr<DataArray> target_data_array = make_shared<DataArray>();
  ifstream target_fstream(vm["target"].as<string>());
  ar::binary_iarchive target_stream(target_fstream);
  target_stream >> *target_data_array;
  Clock::time_point end_time = Clock::now();
  cerr << "Reading target data took " << GetDuration(start_time, end_time)
       << " seconds" << endl;

  start_time = Clock::now();
  cerr << "Reading source suffix array in binary format..." << endl;
  shared_ptr<SuffixArray> source_suffix_array = make_shared<SuffixArray>();
  ifstream source_fstream(vm["source"].as<string>());
  ar::binary_iarchive source_stream(source_fstream);
  source_stream >> *source_suffix_array;
  end_time = Clock::now();
  cerr << "Reading source suffix array took "
       << GetDuration(start_time, end_time) << " seconds" << endl;

  start_time = Clock::now();
  cerr << "Reading alignment in binary format..." << endl;
  shared_ptr<Alignment> alignment = make_shared<Alignment>();
  ifstream alignment_fstream(vm["alignment"].as<string>());
  ar::binary_iarchive alignment_stream(alignment_fstream);
  alignment_stream >> *alignment;
  end_time = Clock::now();
  cerr << "Reading alignment took " << GetDuration(start_time, end_time)
       << " seconds" << endl;

  start_time = Clock::now();
  cerr << "Reading precomputation in binary format..." << endl;
  shared_ptr<Precomputation> precomputation = make_shared<Precomputation>();
  ifstream precomputation_fstream(vm["precomputation"].as<string>());
  ar::binary_iarchive precomputation_stream(precomputation_fstream);
  precomputation_stream >> *precomputation;
  end_time = Clock::now();
  cerr << "Reading precomputation took " << GetDuration(start_time, end_time)
       << " seconds" << endl;

  start_time = Clock::now();
  cerr << "Reading vocabulary in binary format..." << endl;
  shared_ptr<Vocabulary> vocabulary = make_shared<Vocabulary>();
  ifstream vocabulary_fstream(vm["vocabulary"].as<string>());
  ar::binary_iarchive vocabulary_stream(vocabulary_fstream);
  vocabulary_stream >> *vocabulary;
  end_time = Clock::now();
  cerr << "Reading vocabulary took " << GetDuration(start_time, end_time)
       << " seconds" << endl;

  start_time = Clock::now();
  cerr << "Reading translation table in binary format..." << endl;
  shared_ptr<TranslationTable> table = make_shared<TranslationTable>();
  ifstream ttable_fstream(vm["ttable"].as<string>());
  ar::binary_iarchive ttable_stream(ttable_fstream);
  ttable_stream >> *table;
  end_time = Clock::now();
  cerr << "Reading translation table took " << GetDuration(start_time, end_time)
       << " seconds" << endl;

  Clock::time_point read_end_time = Clock::now();
  cerr << "Total time spent loading data structures into memory: "
       << GetDuration(read_start_time, read_end_time) << " seconds" << endl;

  Clock::time_point extraction_start_time = Clock::now();
  // Features used to score each grammar rule.
  vector<shared_ptr<Feature>> features = {
      make_shared<TargetGivenSourceCoherent>(),
      make_shared<SampleSourceCount>(),
      make_shared<CountSourceTarget>(),
      make_shared<MaxLexSourceGivenTarget>(table),
      make_shared<MaxLexTargetGivenSource>(table),
      make_shared<IsSourceSingleton>(),
      make_shared<IsSourceTargetSingleton>()
  };
  shared_ptr<Scorer> scorer = make_shared<Scorer>(features);

  GrammarExtractor extractor(
      source_suffix_array,
      target_data_array,
      alignment,
      precomputation,
      scorer,
      vocabulary,
      vm["min_gap_size"].as<int>(),
      vm["max_rule_span"].as<int>(),
      vm["max_nonterminals"].as<int>(),
      vm["max_rule_symbols"].as<int>(),
      vm["max_samples"].as<int>(),
      vm["tight_phrases"].as<bool>());

  // Creates the grammars directory if it doesn't exist.
  fs::path grammar_path = vm["grammars"].as<string>();
  if (!fs::is_directory(grammar_path)) {
    fs::create_directory(grammar_path);
  }

  // Reads all sentences for which we extract grammar rules (the paralellization
  // is simplified if we read all sentences upfront).
  string sentence;
  vector<string> sentences;
  while (getline(cin, sentence)) {
    sentences.push_back(sentence);
  }

  // Extracts the grammar for each sentence and saves it to a file.
  vector<string> suffixes(sentences.size());
  bool leave_one_out = vm.count("leave_one_out");
  #pragma omp parallel for schedule(dynamic) num_threads(num_threads)
  for (size_t i = 0; i < sentences.size(); ++i) {
    string suffix;
    int position = sentences[i].find("|||");
    if (position != sentences[i].npos) {
      suffix = sentences[i].substr(position);
      sentences[i] = sentences[i].substr(0, position);
    }
    suffixes[i] = suffix;

    unordered_set<int> blacklisted_sentence_ids;
    if (leave_one_out) {
      blacklisted_sentence_ids.insert(i);
    }
    Grammar grammar = extractor.GetGrammar(
        sentences[i], blacklisted_sentence_ids);
    ofstream output(GetGrammarFilePath(grammar_path, i).c_str());
    output << grammar;
  }

  for (size_t i = 0; i < sentences.size(); ++i) {
    cout << "<seg grammar=" << GetGrammarFilePath(grammar_path, i) << " id=\""
         << i << "\"> " << sentences[i] << " </seg> " << suffixes[i] << endl;
  }

  Clock::time_point extraction_stop_time = Clock::now();
  cerr << "Overall extraction step took "
       << GetDuration(extraction_start_time, extraction_stop_time)
       << " seconds" << endl;

  return 0;
}