#include <iostream>
#include <tr1/memory>

#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
#include <boost/shared_ptr.hpp>

#include "filelib.h"
#include "dict.h"
#include "sampler.h"
#include "ccrp.h"
#include "ccrp_nt.h"

using namespace std;
using namespace std::tr1;
namespace po = boost::program_options;
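
// Trains a nonparametric Bayesian phrase segmentation model by Gibbs sampling:
// each sentence is segmented into phrases, and phrases are drawn from a
// Chinese restaurant process whose base distribution combines a uniform
// distribution over words with a geometric (or, optionally, Poisson) length
// penalty. The final sample's segmentation is written to stdout, and a cdec
// grammar and weights file can optionally be dumped for the phrases it contains.
//
// Example invocation (the binary name is an assumption; use whatever this
// file builds to):
//   ./gibbs_train_plm -i corpus.txt -s 1000 -g grammar.scfg -w weights.txt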

Dict d; // global dictionary

string Join(char joiner, const vector<int>& phrase) {
  ostringstream os;
  for (int i = 0; i < phrase.size(); ++i) {
    if (i > 0) os << joiner;
    os << d.Convert(phrase[i]);
  }
  return os.str();
}

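// Write one segmented sentence to stdout: each phrase is wrapped in {{ }} and
// followed by its label value, e.g. "{{a b}}:1 {{c}}:1". The last word of a
// sentence must carry a label (the final phrase always ends there).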
template <typename BType>
void WriteSeg(const vector<int>& line, const vector<BType>& label, const Dict& d) {
  assert(line.size() == label.size());
  assert(label.back());
  int prev = 0;
  int cur = 0;
  while (cur < line.size()) {
    if (label[cur]) {
      if (prev) cout << ' ';
      cout << "{{";
      for (int i = prev; i <= cur; ++i)
        cout << (i == prev ? "" : " ") << d.Convert(line[i]);
      cout << "}}:" << label[cur];
      prev = cur + 1;
    }
    ++cur;
  }
  cout << endl;
}

ostream& operator<<(ostream& os, const vector<int>& phrase) {
  for (int i = 0; i < phrase.size(); ++i)
    os << (i == 0 ? "" : " ") << d.Convert(phrase[i]);
  return os;
}

// Unigram language model over word ids (not used by main() below). The file
// format is assumed to be one "<word> <log probability>" pair per line;
// adjust the loading code if the actual format differs.
struct UnigramLM {
  explicit UnigramLM(const string& fname) {
    ifstream in(fname.c_str());
    assert(in);
    string word;
    double lp;
    while (in >> word >> lp) {
      const int id = d.Convert(word);
      if (id >= static_cast<int>(freqs_.size()))
        freqs_.resize(id + 1, -1e100);  // effectively log(0) for unseen words
      freqs_[id] = lp;
    }
  }

  // log p(word) under the unigram model
  double logprob(int word) const {
    assert(word < static_cast<int>(freqs_.size()));
    return freqs_[word];
  }

  vector<double> freqs_;  // freqs_[word id] = log probability of that word
};

void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
  po::options_description opts("Configuration options");
  opts.add_options()
        ("samples,s",po::value<unsigned>()->default_value(1000),"Number of samples")
        ("input,i",po::value<string>(),"Read file from")
        ("random_seed,S",po::value<uint32_t>(), "Random seed")
        ("write_cdec_grammar,g", po::value<string>(), "Write cdec grammar to this file")
        ("write_cdec_weights,w", po::value<string>(), "Write cdec weights to this file")
        ("poisson_length,p", "Use a Poisson distribution as the length of a phrase in the base distribuion")
        ("no_hyperparameter_inference,N", "Disable hyperparameter inference");
  po::options_description clo("Command line options");
  clo.add_options()
        ("config", po::value<string>(), "Configuration file")
        ("help,h", "Print this help message and exit");
  po::options_description dconfig_options, dcmdline_options;
  dconfig_options.add(opts);
  dcmdline_options.add(opts).add(clo);
  
  po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
  if (conf->count("config")) {
    ifstream config((*conf)["config"].as<string>().c_str());
    po::store(po::parse_config_file(config, dconfig_options), *conf);
  }
  po::notify(*conf);

  if (conf->count("help") || (conf->count("input") == 0)) {
    cerr << dcmdline_options << endl;
    exit(1);
  }
}

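// Read a whitespace-tokenized corpus into word-id sequences using the global
// dictionary d, collecting the vocabulary as a side effect. A filename of "-"
// reads from STDIN.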
void ReadCorpus(const string& filename, vector<vector<int> >* c, set<int>* vocab) {
  c->clear();
  istream* in;
  if (filename == "-")
    in = &cin;
  else
    in = new ifstream(filename.c_str());
  assert(*in);
  string line;
  while(*in) {
    getline(*in, line);
    if (line.empty() && !*in) break;
    c->push_back(vector<int>());
    vector<int>& v = c->back();
    d.ConvertWhitespaceDelimitedLine(line, &v);
    for (int i = 0; i < v.size(); ++i) vocab->insert(v[i]);
  }
  if (in != &cin) delete in;
}

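// log of the Poisson(lambda) pmf at x:  x*log(lambda) - log(x!) - lambda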
double log_poisson(unsigned x, const double& lambda) {
  assert(lambda > 0.0);
  return log(lambda) * x - lgamma(x + 1) - lambda;
}

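// Unigram phrase language model with Dirichlet-process priors:
//   * gen_ is a CRP over the boolean "generate another phrase / stop"
//     decision, so the number of phrases per sentence is roughly geometric;
//   * phrases_ is a CRP over phrases whose base distribution p0 draws each
//     word uniformly from the vocabulary and penalizes phrase length either
//     geometrically (stop probability p_end_) or with a Poisson(1) prior;
//   * z_[i][j] records whether a phrase boundary follows word j of sentence i.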
struct UniphraseLM {
  UniphraseLM(const vector<vector<int> >& corpus,
              const set<int>& vocab,
              const po::variables_map& conf) :
    phrases_(1,1),
    gen_(1,1),
    corpus_(corpus),
    uniform_word_(1.0 / vocab.size()),
    gen_p0_(0.5),
    p_end_(0.5),
    use_poisson_(conf.count("poisson_length") > 0) {}

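  // Base-distribution probability of a phrase, cached by length. Caching on
  // phrase length alone is safe because log_p0 below depends on the phrase
  // only through its size (every word contributes the same uniform_word_).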
  double p0(const vector<int>& phrase) const {
    static vector<double> p0s(10000, 0.0);
    assert(phrase.size() < 10000);
    double& p = p0s[phrase.size()];
    if (p) return p;
    p = exp(log_p0(phrase));
    if (!p) {
      cerr << "0 prob phrase: " << phrase << "\nAssigning std::numeric_limits<double>::min()\n";
      p = std::numeric_limits<double>::min();
    }
    return p;
  }

  double log_p0(const vector<int>& phrase) const {
    double len_logprob;
    if (use_poisson_)
      len_logprob = log_poisson(phrase.size(), 1.0);
    else
      len_logprob = log(1 - p_end_) * (phrase.size() -1) + log(p_end_);
    return log(uniform_word_) * phrase.size() + len_logprob;
  }

  double llh() const {
    double llh = gen_.log_crp_prob();
    llh += log(gen_p0_) + log(1 - gen_p0_);
    double llhr = phrases_.log_crp_prob();
    for (CCRP_NoTable<vector<int> >::const_iterator it = phrases_.begin(); it != phrases_.end(); ++it) {
      llhr += log_p0(it->first);
      if (!isfinite(llhr)) {  // a zero-probability phrase would drive the LLH to -inf
        cerr << it->first << endl;
        cerr << log_p0(it->first) << endl;
        abort();
      }
    }
    return llh + llhr;
  }

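  // Run the Gibbs sampler: boundaries are initialized at random (probability
  // 1/2 per position, with the last word of every sentence forced to be a
  // boundary), then every interior boundary variable is resampled on each
  // sweep. Every 10 sweeps the log likelihood is reported and, if hyp_inf is
  // set, the CRP concentration parameters are resampled.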
  void Sample(unsigned int samples, bool hyp_inf, MT19937* rng) {
    cerr << "Initializing...\n";
    z_.resize(corpus_.size());
    int tc = 0;
    for (int i = 0; i < corpus_.size(); ++i) {
      const vector<int>& line = corpus_[i];
      const int ls = line.size();
      const int last_pos = ls - 1;
      vector<bool>& z = z_[i];
      z.resize(ls);
      int prev = 0;
      for (int j = 0; j < ls; ++j) {
        z[j] = rng->next() < 0.5;
        if (j == last_pos) z[j] = true;  // break phrase at the end of the sentence
        if (z[j]) {
          const vector<int> p(line.begin() + prev, line.begin() + j + 1);
          phrases_.increment(p);
          //cerr << p << ": " << p0(p) << endl;
          prev = j + 1;
          gen_.increment(false);
          ++tc;  // diagnostic phrase counter
        }
      }
      ++tc;
      gen_.increment(true); // end of utterance
    }
    cerr << "TC: " << tc << endl;
    cerr << "Initial LLH: " << llh() << endl;
    cerr << "Sampling...\n";
    cerr << gen_ << endl;
    for (int s = 1; s < samples; ++s) {
      cerr << '.';
      if (s % 10 == 0) {
        cerr << " [" << s;
        if (hyp_inf) ResampleHyperparameters(rng);
        cerr << " LLH=" << llh() << "]\n";
        //vector<int> z(z_[0].size(), 0);
        //for (int j = 0; j < z.size(); ++j) z[j] = z_[0][j];
        //SegCorpus::Write(corpus_[0], z, d);
      }
      for (int i = 0; i < corpus_.size(); ++i) {
        const vector<int>& line = corpus_[i];
        const int ls = line.size();
        const int last_pos = ls - 1;
        vector<bool>& z = z_[i];
        int prev = 0;
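        // Blocked resampling of the boundary after word j: prev is the first
        // word of the phrase containing j, next is the position of the next
        // boundary at or after j+1. Remove the phrase(s) covering [prev, next]
        // from the CRPs, score both hypotheses, then reinsert accordingly.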
        for (int j = 0; j < last_pos; ++j) { // don't resample last position
          int next = j+1;  while(!z[next]) { ++next; }
          const vector<int> p1p2(line.begin() + prev, line.begin() + next + 1);
          const vector<int> p1(line.begin() + prev, line.begin() + j + 1);
          const vector<int> p2(line.begin() + j + 1, line.begin() + next + 1);

          if (z[j]) {
            phrases_.decrement(p1);
            phrases_.decrement(p2);
            gen_.decrement(false);
            gen_.decrement(false);
          } else {
            phrases_.decrement(p1p2);
            gen_.decrement(false);
          }

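          // d1: weight of "no boundary at j" (the merged phrase p1p2 is a
          // single draw). d2: weight of "boundary at j" (p1 followed by p2);
          // p1 is temporarily added to the CRP before scoring p2 so the second
          // factor is the correct conditional under exchangeability.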
          const double d1 = phrases_.prob(p1p2, p0(p1p2)) * gen_.prob(false, gen_p0_);
          double d2 = phrases_.prob(p1, p0(p1)) * gen_.prob(false, gen_p0_);
          phrases_.increment(p1);
          gen_.increment(false);
          d2 *= phrases_.prob(p2, p0(p2)) * gen_.prob(false, gen_p0_);
          phrases_.decrement(p1);
          gen_.decrement(false);
          z[j] = rng->SelectSample(d1, d2);

          if (z[j]) {
            phrases_.increment(p1);
            phrases_.increment(p2);
            gen_.increment(false);
            gen_.increment(false);
            prev = j + 1;
          } else {
            phrases_.increment(p1p2);
            gen_.increment(false);
          }
        }
      }
    }
//    cerr << endl << endl << gen_ << endl << phrases_ << endl;
    cerr << gen_.prob(false, gen_p0_) << " " << gen_.prob(true, 1 - gen_p0_) << endl;
  }

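  // Dump every phrase type currently in the CRP as a cdec grammar rule of the
  // form "[X] ||| <words> ||| <words_joined_by_underscores> ||| C=1 P=<logprob>",
  // where P is the phrase's log probability under the current sample.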
  void WriteCdecGrammarForCurrentSample(ostream* os) const {
    CCRP_NoTable<vector<int> >::const_iterator it = phrases_.begin();
    for (; it != phrases_.end(); ++it) {
      (*os) << "[X] ||| " << Join(' ', it->first) << " ||| "
                          << Join('_', it->first) << " ||| C=1 P=" 
                          << log(phrases_.prob(it->first, p0(it->first))) << endl;
    }
  }

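  // Log probability of a single-word phrase whose word id (99999999) is
  // assumed never to occur in the corpus, i.e. the model's OOV penalty.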
  double OOVUnigramLogProb() const {
    vector<int> x(1,99999999);
    return log(phrases_.prob(x, p0(x)));
  }

  void ResampleHyperparameters(MT19937* rng) {
    phrases_.resample_hyperparameters(rng);
    gen_.resample_hyperparameters(rng);
    cerr << " " << phrases_.concentration();
  }

  CCRP_NoTable<vector<int> > phrases_;
  CCRP_NoTable<bool> gen_;
  vector<vector<bool> > z_;   // z_[i][j] is true iff there is a phrase boundary after word j of sentence i
  const vector<vector<int> >& corpus_;
  const double uniform_word_;
  const double gen_p0_;
  const double p_end_; // in base length distribution, p of the end of a phrase
  const bool use_poisson_;
};


int main(int argc, char** argv) {
  po::variables_map conf;
  InitCommandLine(argc, argv, &conf);
  boost::shared_ptr<MT19937> prng;
  if (conf.count("random_seed"))
    prng.reset(new MT19937(conf["random_seed"].as<uint32_t>()));
  else
    prng.reset(new MT19937);
  MT19937& rng = *prng;

  vector<vector<int> > corpus;
  set<int> vocab;
  ReadCorpus(conf["input"].as<string>(), &corpus, &vocab);
  cerr << "Corpus size: " << corpus.size() << " sentences\n";
  cerr << "Vocabulary size: " << vocab.size() << " types\n";

  UniphraseLM ulm(corpus, vocab, conf);
  ulm.Sample(conf["samples"].as<unsigned>(), conf.count("no_hyperparameter_inference") == 0, &rng);
  cerr << "OOV unigram prob: " << ulm.OOVUnigramLogProb() << endl;

  for (int i = 0; i < corpus.size(); ++i)
    WriteSeg(corpus[i], ulm.z_[i], d);

  if (conf.count("write_cdec_grammar")) {
    string fname = conf["write_cdec_grammar"].as<string>();
    cerr << "Writing model to " << fname << " ...\n";
    WriteFile wf(fname);
    ulm.WriteCdecGrammarForCurrentSample(wf.stream());
  }

  if (conf.count("write_cdec_weights")) {
    string fname = conf["write_cdec_weights"].as<string>();
    cerr << "Writing weights to " << fname << " .\n";
    WriteFile wf(fname);
    ostream& os = *wf.stream();
    os << "# make C smaller to use more phrases\nP 1\nPassThrough " << ulm.OOVUnigramLogProb() << "\nC -3\n";
  }

  return 0;
}