#include <gtest/gtest.h>

#include <memory>

#include "fast_intersector.h"
#include "mocks/mock_data_array.h"
#include "mocks/mock_precomputation.h"
#include "mocks/mock_suffix_array.h"
#include "mocks/mock_vocabulary.h"
#include "phrase.h"
#include "phrase_location.h"
#include "phrase_builder.h"

using namespace std;
using namespace ::testing;

namespace extractor {
namespace {

class FastIntersectorTest : public Test {
 protected:
  virtual void SetUp() {
    vector<string> words = {"EOL", "it", "makes", "him", "and", "mars", ",",
                            "sets", "on", "takes", "off", "."};
    vocabulary = make_shared<MockVocabulary>();
    for (size_t i = 0; i < words.size(); ++i) {
      EXPECT_CALL(*vocabulary, GetTerminalIndex(words[i]))
          .WillRepeatedly(Return(i));
      EXPECT_CALL(*vocabulary, GetTerminalValue(i))
          .WillRepeatedly(Return(words[i]));
    }

    vector<int> data = {1, 2, 3, 4, 1, 5, 3, 6, 1,
                        7, 3, 8, 4, 1, 9, 3, 10, 11, 0};
    data_array = make_shared<MockDataArray>();
    for (size_t i = 0; i < data.size(); ++i) {
      EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i]));
      EXPECT_CALL(*data_array, GetSentenceId(i))
          .WillRepeatedly(Return(0));
    }
    EXPECT_CALL(*data_array, GetSentenceStart(0))
        .WillRepeatedly(Return(0));
    EXPECT_CALL(*data_array, GetSentenceStart(1))
        .WillRepeatedly(Return(19));
    for (size_t i = 0; i < words.size(); ++i) {
      EXPECT_CALL(*data_array, GetWordId(words[i]))
          .WillRepeatedly(Return(i));
      EXPECT_CALL(*data_array, GetWord(i))
          .WillRepeatedly(Return(words[i]));
    }

    vector<int> suffixes = {18, 0, 4, 8, 13, 1, 2, 6, 10, 15, 3, 12, 5, 7, 9,
                            11, 14, 16, 17};
    suffix_array = make_shared<MockSuffixArray>();
    EXPECT_CALL(*suffix_array, GetData()).WillRepeatedly(Return(data_array));
    for (size_t i = 0; i < suffixes.size(); ++i) {
      EXPECT_CALL(*suffix_array, GetSuffix(i)).
          WillRepeatedly(Return(suffixes[i]));
    }

    precomputation = make_shared<MockPrecomputation>();
    EXPECT_CALL(*precomputation, Contains(_)).WillRepeatedly(Return(false));

    phrase_builder = make_shared<PhraseBuilder>(vocabulary);
    intersector = make_shared<FastIntersector>(suffix_array, precomputation,
                                               vocabulary, 15, 1);
  }

  shared_ptr<MockDataArray> data_array;
  shared_ptr<MockSuffixArray> suffix_array;
  shared_ptr<MockPrecomputation> precomputation;
  shared_ptr<MockVocabulary> vocabulary;
  shared_ptr<FastIntersector> intersector;
  shared_ptr<PhraseBuilder> phrase_builder;
};

// When the precomputation already holds the collocations of a pattern, the
// intersector returns them unchanged and leaves both input locations alone.
TEST_F(FastIntersectorTest, TestCachedCollocation) {
  // "on X takes": "on" is at position 11 and "takes" at position 14.
  vector<int> symbols = {8, -1, 9};
  // Two subpatterns per match, so each cached match is a pair of positions.
  vector<int> expected_location = {11, 14};
  Phrase phrase = phrase_builder->Build(symbols);
  // Ranks [15, 16) hold "on"; ranks [16, 17) hold "takes".
  PhraseLocation prefix_location(15, 16), suffix_location(16, 17);

  EXPECT_CALL(*precomputation, Contains(symbols)).WillRepeatedly(Return(true));
  EXPECT_CALL(*precomputation, GetCollocations(symbols))
      .WillRepeatedly(Return(expected_location));
  // Rebuild the intersector after installing the expectations above. This
  // presumably matters if the precomputation is consulted at construction.
  intersector = make_shared<FastIntersector>(suffix_array, precomputation,
                                             vocabulary, 15, 1);

  PhraseLocation result = intersector->Intersect(
      prefix_location, suffix_location, phrase);

  EXPECT_EQ(PhraseLocation(expected_location, 2), result);
  EXPECT_EQ(PhraseLocation(15, 16), prefix_location);
  EXPECT_EQ(PhraseLocation(16, 17), suffix_location);
}

// "it X him X it": the prefix "it X him" has 9 matches and the suffix
// "him X it" only 6, so the suffix matches are the ones extended.
TEST_F(FastIntersectorTest, TestIntersectaXbXcExtendSuffix) {
  vector<int> symbols = {1, -1, 3, -1, 1};
  Phrase phrase = phrase_builder->Build(symbols);
  // (it, him) pairs. "it" is at 0, 4, 8, 13 and "him" at 2, 6, 10, 15.
  // (0, 15) is absent; presumably it exceeds the fixture's span limit.
  vector<int> prefix_locs = {0, 2, 0, 6, 0, 10, 4, 6, 4, 10, 4, 15, 8, 10,
                             8, 15, 13, 15};
  // (him, it) pairs.
  vector<int> suffix_locs = {2, 4, 2, 8, 2, 13, 6, 8, 6, 13, 10, 13};
  PhraseLocation prefix_location(prefix_locs, 2);
  PhraseLocation suffix_location(suffix_locs, 2);

  // Grouped by suffix match, with the nearest preceding "it" listed first.
  vector<int> expected_locs = {0, 2, 4, 0, 2, 8, 0, 2, 13, 4, 6, 8, 0, 6, 8,
                               4, 6, 13, 0, 6, 13, 8, 10, 13, 4, 10, 13,
                               0, 10, 13};
  PhraseLocation result = intersector->Intersect(
      prefix_location, suffix_location, phrase);
  EXPECT_EQ(PhraseLocation(expected_locs, 3), result);
}

// "it X him" given as two suffix-array ranges: the prefix is extended, and
// the prefix range is expected to be materialized into explicit positions.
TEST_F(FastIntersectorTest, TestIntersectaXbExtendPrefix) {
  vector<int> pattern = {1, -1, 3};
  Phrase phrase = phrase_builder->Build(pattern);
  // Ranks [1, 5) hold "it"; ranks [6, 10) hold "him".
  PhraseLocation it_range(1, 5);
  PhraseLocation him_range(6, 10);

  PhraseLocation actual = intersector->Intersect(it_range, him_range, phrase);

  // Every (it, him) pair in order, except (0, 15), which presumably exceeds
  // the fixture's span limit.
  vector<int> want_matches = {0, 2, 0, 6, 0, 10, 4, 6, 4, 10, 4, 15, 8, 10,
                              8, 15, 13, 15};
  EXPECT_EQ(PhraseLocation(want_matches, 2), actual);

  // The prefix range now holds the explicit positions of "it".
  vector<int> want_it_positions = {0, 4, 8, 13};
  EXPECT_EQ(PhraseLocation(want_it_positions, 1), it_range);
}

// Cost estimation for "it X and it". The suffix matches in fewer positions,
// but because it starts with an X it requires more operations. The prefix is
// therefore the side that gets extended.
TEST_F(FastIntersectorTest, TestIntersectCheckEstimates) {
  vector<int> pattern = {1, -1, 4, 1};
  Phrase phrase = phrase_builder->Build(pattern);

  // (it, and) pairs: "it" is at 0, 4, 8, 13 and "and" at 3, 12.
  vector<int> it_and_pairs = {0, 3, 0, 12, 4, 12, 8, 12};
  PhraseLocation prefix_location(it_and_pairs, 2);
  // Ranks [10, 12) hold the two occurrences of "and".
  PhraseLocation suffix_location(10, 12);

  PhraseLocation actual =
      intersector->Intersect(prefix_location, suffix_location, phrase);

  // Both "and"s are followed by "it" (at 4 and 13), so every prefix match
  // extends to a full match.
  vector<int> want = {0, 3, 0, 12, 4, 12, 8, 12};
  EXPECT_EQ(PhraseLocation(want, 2), actual);
  // The suffix range is left untouched.
  EXPECT_EQ(PhraseLocation(10, 12), suffix_location);
}

} // namespace
} // namespace extractor