/*
 * tsuruoka_maxent.h
 *
 */

#ifndef TSURUOKA_MAXENT_H_
#define TSURUOKA_MAXENT_H_

#include <assert.h>
#include <stdio.h>
#include <string.h>

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "synutils.h"
#include "stringlib.h"
#include "maxent.h"

// String -> int hash map alias and its mutable iterator type; not used inside
// this header itself (presumably shared with other project files).
using Map = std::unordered_map<std::string, int>;
using Iterator = Map::iterator;

struct Tsuruoka_Maxent {
  Tsuruoka_Maxent(const char* pszModelFName) {
    if (pszModelFName != NULL) {
      m_pModel = new ME_Model();
      m_pModel->load_from_file(pszModelFName);
    } else
      m_pModel = NULL;
  }

  ~Tsuruoka_Maxent() {
    if (m_pModel != NULL) delete m_pModel;
  }

  void fnTrain(const char* pszInstanceFName, const char* pszAlgorithm,
               const char* pszModelFName, int iNumIteration) {
    assert(strcmp(pszAlgorithm, "l1") == 0 || strcmp(pszAlgorithm, "l2") == 0 ||
           strcmp(pszAlgorithm, "sgd") == 0 ||
           strcmp(pszAlgorithm, "SGD") == 0);
    FILE* fpIn = fopen(pszInstanceFName, "r");

    ME_Model* pModel = new ME_Model();

    char* pszLine = new char[100001];
    int iNumInstances = 0;
    int iLen;
    while (!feof(fpIn)) {
      pszLine[0] = '\0';
      fgets(pszLine, 20000, fpIn);
      if (strlen(pszLine) == 0) {
        continue;
      }

      iLen = strlen(pszLine);
      while (iLen > 0 && pszLine[iLen - 1] > 0 && pszLine[iLen - 1] < 33) {
        pszLine[iLen - 1] = '\0';
        iLen--;
      }

      iNumInstances++;

      ME_Sample* pmes = new ME_Sample();

      char* p = strrchr(pszLine, ' ');
      assert(p != NULL);
      p[0] = '\0';
      p++;
      std::vector<std::string> vecContext;
      SplitOnWhitespace(std::string(pszLine), &vecContext);

      pmes->label = std::string(p);
      for (size_t i = 0; i < vecContext.size(); i++)
        pmes->add_feature(vecContext[i]);
      pModel->add_training_sample((*pmes));
      if (iNumInstances % 100000 == 0)
        fprintf(stdout, "......Reading #Instances: %1d\n", iNumInstances);
      delete pmes;
    }
    fprintf(stdout, "......Reading #Instances: %1d\n", iNumInstances);
    fclose(fpIn);

    if (strcmp(pszAlgorithm, "l1") == 0)
      pModel->use_l1_regularizer(1.0);
    else if (strcmp(pszAlgorithm, "l2") == 0)
      pModel->use_l2_regularizer(1.0);
    else
      pModel->use_SGD();

    pModel->train();
    pModel->save_to_file(pszModelFName);

    delete pModel;
    fprintf(stdout, "......Finished Training\n");
    fprintf(stdout, "......Model saved as %s\n", pszModelFName);
    delete[] pszLine;
  }

  double fnEval(const char* pszContext, const char* pszOutcome) const {
    std::vector<std::string> vecContext;
    ME_Sample* pmes = new ME_Sample();
    SplitOnWhitespace(std::string(pszContext), &vecContext);

    for (size_t i = 0; i < vecContext.size(); i++)
      pmes->add_feature(vecContext[i]);
    std::vector<double> vecProb = m_pModel->classify(*pmes);
    delete pmes;
    int iLableID = m_pModel->get_class_id(pszOutcome);
    return vecProb[iLableID];
  }
  void fnEval(const char* pszContext,
              std::vector<std::pair<std::string, double> >& vecOutput) const {
    std::vector<std::string> vecContext;
    ME_Sample* pmes = new ME_Sample();
    SplitOnWhitespace(std::string(pszContext), &vecContext);

    vecOutput.clear();

    for (size_t i = 0; i < vecContext.size(); i++)
      pmes->add_feature(vecContext[i]);
    std::vector<double> vecProb = m_pModel->classify(*pmes);

    for (size_t i = 0; i < vecProb.size(); i++) {
      std::string label = m_pModel->get_class_label(i);
      vecOutput.push_back(make_pair(label, vecProb[i]));
    }
    delete pmes;
  }
  void fnEval(const char* pszContext, std::vector<double>& vecOutput) const {
    std::vector<std::string> vecContext;
    ME_Sample* pmes = new ME_Sample();
    SplitOnWhitespace(std::string(pszContext), &vecContext);

    vecOutput.clear();

    for (size_t i = 0; i < vecContext.size(); i++)
      pmes->add_feature(vecContext[i]);
    std::vector<double> vecProb = m_pModel->classify(*pmes);

    for (size_t i = 0; i < vecProb.size(); i++) {
      std::string label = m_pModel->get_class_label(i);
      vecOutput.push_back(vecProb[i]);
    }
    delete pmes;
  }
  int fnGetClassId(const std::string& strLabel) const {
    return m_pModel->get_class_id(strLabel);
  }

 private:
  ME_Model* m_pModel;
};

#endif /* TSURUOKA_MAXENT_H_ */