#ifndef LM_SEARCH_TRIE_H
#define LM_SEARCH_TRIE_H

#include "lm/config.hh"
#include "lm/model_type.hh"
#include "lm/return.hh"
#include "lm/trie.hh"
#include "lm/weights.hh"

#include "util/file.hh"
#include "util/file_piece.hh"

#include <vector>
#include <cstdlib>

#include <assert.h>

namespace lm {
namespace ngram {
class BinaryFormat;
class SortedVocabulary;
namespace trie {

template <class Quant, class Bhiksha> class TrieSearch;
class SortedFiles;
template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing);

template <class Quant, class Bhiksha> class TrieSearch {
  public:
    typedef NodeRange Node;

    typedef ::lm::ngram::trie::UnigramPointer UnigramPointer;
    typedef typename Quant::MiddlePointer MiddlePointer;
    typedef typename Quant::LongestPointer LongestPointer;

    static const bool kDifferentRest = false;

    static const ModelType kModelType = static_cast<ModelType>(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd);

    static const unsigned int kVersion = 1;

    static void UpdateConfigFromBinary(const BinaryFormat &file, const std::vector<uint64_t> &counts, uint64_t offset, Config &config) {
      Quant::UpdateConfigFromBinary(file, offset, config);
      // Currently the unigram pointers are not compresssed, so there will only be a header for order > 2.
      if (counts.size() > 2)
        Bhiksha::UpdateConfigFromBinary(file, offset + Quant::Size(counts.size(), config) + Unigram::Size(counts[0]), config);
    }

    static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
      uint64_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]);
      for (unsigned char i = 1; i < counts.size() - 1; ++i) {
        ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config);
      }
      return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
    }

    TrieSearch() : middle_begin_(NULL), middle_end_(NULL) {}

    ~TrieSearch() { FreeMiddles(); }

    uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);

    void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing);

    unsigned char Order() const {
      return middle_end_ - middle_begin_ + 2;
    }

    ProbBackoff &UnknownUnigram() { return unigram_.Unknown(); }

    UnigramPointer LookupUnigram(WordIndex word, Node &next, bool &independent_left, uint64_t &extend_left) const {
      extend_left = static_cast<uint64_t>(word);
      UnigramPointer ret(unigram_.Find(word, next));
      independent_left = (next.begin == next.end);
      return ret;
    }

    MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const {
      return MiddlePointer(quant_, extend_length - 2, middle_begin_[extend_length - 2].ReadEntry(extend_pointer, node));
    }

    MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_left) const {
      util::BitAddress address(middle_begin_[order_minus_2].Find(word, node, extend_left));
      independent_left = (address.base == NULL) || (node.begin == node.end);
      return MiddlePointer(quant_, order_minus_2, address);
    }

    LongestPointer LookupLongest(WordIndex word, const Node &node) const {
      return LongestPointer(quant_, longest_.Find(word, node));
    }

    bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
      assert(begin != end);
      bool independent_left;
      uint64_t ignored;
      LookupUnigram(*begin, node, independent_left, ignored);
      for (const WordIndex *i = begin + 1; i < end; ++i) {
        if (independent_left || !LookupMiddle(i - begin - 1, *i, node, independent_left, ignored).Found()) return false;
      }
      return true;
    }

  private:
    friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing);

    // Middles are managed manually so we can delay construction and they don't have to be copyable.
    void FreeMiddles() {
      for (const Middle *i = middle_begin_; i != middle_end_; ++i) {
        i->~Middle();
      }
      std::free(middle_begin_);
    }

    typedef trie::BitPackedMiddle<Bhiksha> Middle;

    typedef trie::BitPackedLongest Longest;
    Longest longest_;

    Middle *middle_begin_, *middle_end_;
    Quant quant_;

    typedef ::lm::ngram::trie::Unigram Unigram;
    Unigram unigram_;
};

} // namespace trie
} // namespace ngram
} // namespace lm

#endif // LM_SEARCH_TRIE_H