From f528ac27dab11770f01595b043675dba2947a263 Mon Sep 17 00:00:00 2001
From: Paul Baltescu <pauldb89@gmail.com>
Date: Sun, 24 Nov 2013 13:19:28 +0000
Subject: Reduce memory overhead for constructing the intersector.

---
 extractor/compile.cc                  |   4 ++
 extractor/data_array.cc               |  14 ++++-
 extractor/data_array.h                |  10 +++-
 extractor/data_array_test.cc          |  12 ++++
 extractor/fast_intersector.cc         |  40 ++++---------
 extractor/fast_intersector.h          |   8 +--
 extractor/fast_intersector_test.cc    |  10 ++--
 extractor/grammar_extractor.cc        |   5 +-
 extractor/grammar_extractor.h         |   1 +
 extractor/mocks/mock_data_array.h     |   4 +-
 extractor/mocks/mock_precomputation.h |   3 +-
 extractor/precomputation.cc           |  96 +++++++++++++++--------------
 extractor/precomputation.h            |  45 +++++++-------
 extractor/precomputation_test.cc      | 110 ++++++++++++++++++++++++----------
 extractor/run_extractor.cc            |   5 ++
 extractor/suffix_array_test.cc        |   2 +-
 extractor/translation_table_test.cc   |   4 +-
 17 files changed, 225 insertions(+), 148 deletions(-)

(limited to 'extractor')

diff --git a/extractor/compile.cc b/extractor/compile.cc
index 65fdd509..0d62757e 100644
--- a/extractor/compile.cc
+++ b/extractor/compile.cc
@@ -13,6 +13,7 @@
 #include "suffix_array.h"
 #include "time_util.h"
 #include "translation_table.h"
+#include "vocabulary.h"
 
 namespace ar = boost::archive;
 namespace fs = boost::filesystem;
@@ -125,9 +126,12 @@ int main(int argc, char** argv) {
   cerr << "Reading alignment took "
        << GetDuration(start_time, stop_time) << " seconds" << endl;
 
+  shared_ptr<Vocabulary> vocabulary;
+
   start_time = Clock::now();
   cerr << "Precomputing collocations..." << endl;
   Precomputation precomputation(
+      vocabulary,
       source_suffix_array,
       vm["frequent"].as<int>(),
       vm["super_frequent"].as<int>(),
diff --git a/extractor/data_array.cc b/extractor/data_array.cc
index 82efcd51..dacc4283 100644
--- a/extractor/data_array.cc
+++ b/extractor/data_array.cc
@@ -78,7 +78,7 @@ void DataArray::CreateDataArray(const vector<string>& lines) {
 
 DataArray::~DataArray() {}
 
-const vector<int>& DataArray::GetData() const {
+vector<int> DataArray::GetData() const {
   return data;
 }
 
@@ -90,6 +90,18 @@ string DataArray::GetWordAtIndex(int index) const {
   return id2word[data[index]];
 }
 
+vector<int> DataArray::GetWordIds(int index, int size) const {
+  return vector<int>(data.begin() + index, data.begin() + index + size);
+}
+
+vector<string> DataArray::GetWords(int start_index, int size) const {
+  vector<string> words;
+  for (int word_id: GetWordIds(start_index, size)) {
+    words.push_back(id2word[word_id]);
+  }
+  return words;
+}
+
 int DataArray::GetSize() const {
   return data.size();
 }
diff --git a/extractor/data_array.h b/extractor/data_array.h
index 5207366d..e3823d18 100644
--- a/extractor/data_array.h
+++ b/extractor/data_array.h
@@ -51,7 +51,7 @@ class DataArray {
   virtual ~DataArray();
 
   // Returns a vector containing the word ids.
-  virtual const vector<int>& GetData() const;
+  virtual vector<int> GetData() const;
 
   // Returns the word id at the specified position.
   virtual int AtIndex(int index) const;
@@ -59,6 +59,14 @@ class DataArray {
   // Returns the original word at the specified position.
   virtual string GetWordAtIndex(int index) const;
 
+  // Returns the substring of word ids starting at the specified position and
+  // having the specified length.
+  virtual vector<int> GetWordIds(int start_index, int size) const;
+
+  // Returns the substring of words starting at the specified position and
+  // having the specified length.
+  virtual vector<string> GetWords(int start_index, int size) const;
+
   // Returns the size of the data array.
   virtual int GetSize() const;
 
diff --git a/extractor/data_array_test.cc b/extractor/data_array_test.cc
index 6c329e34..7b085cd9 100644
--- a/extractor/data_array_test.cc
+++ b/extractor/data_array_test.cc
@@ -56,6 +56,18 @@ TEST_F(DataArrayTest, TestGetData) {
   }
 }
 
+TEST_F(DataArrayTest, TestSubstrings) {
+  vector<int> expected_word_ids = {3, 4, 5};
+  vector<string> expected_words = {"are", "mere", "."};
+  EXPECT_EQ(expected_word_ids, source_data.GetWordIds(1, 3));
+  EXPECT_EQ(expected_words, source_data.GetWords(1, 3));
+
+  expected_word_ids = {7, 8};
+  expected_words = {"a", "lot"};
+  EXPECT_EQ(expected_word_ids, target_data.GetWordIds(7, 2));
+  EXPECT_EQ(expected_words, target_data.GetWords(7, 2));
+}
+
 TEST_F(DataArrayTest, TestVocabulary) {
   EXPECT_EQ(9, source_data.GetVocabularySize());
   EXPECT_TRUE(source_data.HasWord("mere"));
diff --git a/extractor/fast_intersector.cc b/extractor/fast_intersector.cc
index a8591a72..0d1fa6d8 100644
--- a/extractor/fast_intersector.cc
+++ b/extractor/fast_intersector.cc
@@ -11,41 +11,22 @@
 
 namespace extractor {
 
-FastIntersector::FastIntersector(shared_ptr<SuffixArray> suffix_array,
-                                 shared_ptr<Precomputation> precomputation,
-                                 shared_ptr<Vocabulary> vocabulary,
-                                 int max_rule_span,
-                                 int min_gap_size) :
+FastIntersector::FastIntersector(
+    shared_ptr<SuffixArray> suffix_array,
+    shared_ptr<Precomputation> precomputation,
+    shared_ptr<Vocabulary> vocabulary,
+    int max_rule_span,
+    int min_gap_size) :
     suffix_array(suffix_array),
+    precomputation(precomputation),
     vocabulary(vocabulary),
     max_rule_span(max_rule_span),
-    min_gap_size(min_gap_size) {
-  Index precomputed_collocations = precomputation->GetCollocations();
-  for (pair<vector<int>, vector<int>> entry: precomputed_collocations) {
-    vector<int> phrase = ConvertPhrase(entry.first);
-    collocations[phrase] = entry.second;
-  }
-}
+    min_gap_size(min_gap_size) {}
 
 FastIntersector::FastIntersector() {}
 
 FastIntersector::~FastIntersector() {}
 
-vector<int> FastIntersector::ConvertPhrase(const vector<int>& old_phrase) {
-  vector<int> new_phrase;
-  new_phrase.reserve(old_phrase.size());
-  shared_ptr<DataArray> data_array = suffix_array->GetData();
-  for (int word_id: old_phrase) {
-    if (word_id < 0) {
-      new_phrase.push_back(word_id);
-    } else {
-      new_phrase.push_back(
-          vocabulary->GetTerminalIndex(data_array->GetWord(word_id)));
-    }
-  }
-  return new_phrase;
-}
-
 PhraseLocation FastIntersector::Intersect(
     PhraseLocation& prefix_location,
     PhraseLocation& suffix_location,
@@ -59,8 +40,9 @@ PhraseLocation FastIntersector::Intersect(
   assert(vocabulary->IsTerminal(symbols.front())
       && vocabulary->IsTerminal(symbols.back()));
 
-  if (collocations.count(symbols)) {
-    return PhraseLocation(collocations[symbols], phrase.Arity() + 1);
+  if (precomputation->Contains(symbols)) {
+    return PhraseLocation(precomputation->GetCollocations(symbols),
+                          phrase.Arity() + 1);
   }
 
   bool prefix_ends_with_x =
diff --git a/extractor/fast_intersector.h b/extractor/fast_intersector.h
index 2819d239..305373dc 100644
--- a/extractor/fast_intersector.h
+++ b/extractor/fast_intersector.h
@@ -12,7 +12,6 @@ using namespace std;
 namespace extractor {
 
 typedef boost::hash<vector<int>> VectorHash;
-typedef unordered_map<vector<int>, vector<int>, VectorHash> Index;
 
 class Phrase;
 class PhraseLocation;
@@ -52,11 +51,6 @@ class FastIntersector {
   FastIntersector();
 
  private:
-  // Uses the vocabulary to convert the phrase from the numberized format
-  // specified by the source data array to the numberized format given by the
-  // vocabulary.
-  vector<int> ConvertPhrase(const vector<int>& old_phrase);
-
   // Estimates the number of computations needed if the prefix/suffix is
   // extended. If the last/first symbol is separated from the rest of the phrase
   // by a nonterminal, then for each occurrence of the prefix/suffix we need to
@@ -85,10 +79,10 @@ class FastIntersector {
   pair<int, int> GetSearchRange(bool has_marginal_x) const;
 
   shared_ptr<SuffixArray> suffix_array;
+  shared_ptr<Precomputation> precomputation;
   shared_ptr<Vocabulary> vocabulary;
   int max_rule_span;
   int min_gap_size;
-  Index collocations;
 };
 
 } // namespace extractor
diff --git a/extractor/fast_intersector_test.cc b/extractor/fast_intersector_test.cc
index 76c3aaea..f2a26ba1 100644
--- a/extractor/fast_intersector_test.cc
+++ b/extractor/fast_intersector_test.cc
@@ -59,15 +59,13 @@ class FastIntersectorTest : public Test {
     }
 
     precomputation = make_shared<MockPrecomputation>();
-    EXPECT_CALL(*precomputation, GetCollocations())
-        .WillRepeatedly(ReturnRef(collocations));
+    EXPECT_CALL(*precomputation, Contains(_)).WillRepeatedly(Return(false));
 
     phrase_builder = make_shared<PhraseBuilder>(vocabulary);
     intersector = make_shared<FastIntersector>(suffix_array, precomputation,
                                                vocabulary, 15, 1);
   }
 
-  Index collocations;
   shared_ptr<MockDataArray> data_array;
   shared_ptr<MockSuffixArray> suffix_array;
   shared_ptr<MockPrecomputation> precomputation;
@@ -82,9 +80,9 @@ TEST_F(FastIntersectorTest, TestCachedCollocation) {
   Phrase phrase = phrase_builder->Build(symbols);
   PhraseLocation prefix_location(15, 16), suffix_location(16, 17);
 
-  collocations[symbols] = expected_location;
-  EXPECT_CALL(*precomputation, GetCollocations())
-      .WillRepeatedly(ReturnRef(collocations));
+  EXPECT_CALL(*precomputation, Contains(symbols)).WillRepeatedly(Return(true));
+  EXPECT_CALL(*precomputation, GetCollocations(symbols)).
+      WillRepeatedly(Return(expected_location));
   intersector = make_shared<FastIntersector>(suffix_array, precomputation,
                                              vocabulary, 15, 1);
 
diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc
index 487abcaf..4d0738f7 100644
--- a/extractor/grammar_extractor.cc
+++ b/extractor/grammar_extractor.cc
@@ -19,10 +19,11 @@ GrammarExtractor::GrammarExtractor(
     shared_ptr<SuffixArray> source_suffix_array,
     shared_ptr<DataArray> target_data_array,
     shared_ptr<Alignment> alignment, shared_ptr<Precomputation> precomputation,
-    shared_ptr<Scorer> scorer, int min_gap_size, int max_rule_span,
+    shared_ptr<Scorer> scorer, shared_ptr<Vocabulary> vocabulary,
+    int min_gap_size, int max_rule_span,
     int max_nonterminals, int max_rule_symbols, int max_samples,
     bool require_tight_phrases) :
-    vocabulary(make_shared<Vocabulary>()),
+    vocabulary(vocabulary),
     rule_factory(make_shared<HieroCachingRuleFactory>(
         source_suffix_array, target_data_array, alignment, vocabulary,
         precomputation, scorer, min_gap_size, max_rule_span, max_nonterminals,
diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h
index ae407b47..8f570df2 100644
--- a/extractor/grammar_extractor.h
+++ b/extractor/grammar_extractor.h
@@ -32,6 +32,7 @@ class GrammarExtractor {
       shared_ptr<Alignment> alignment,
       shared_ptr<Precomputation> precomputation,
       shared_ptr<Scorer> scorer,
+      shared_ptr<Vocabulary> vocabulary,
       int min_gap_size,
       int max_rule_span,
       int max_nonterminals,
diff --git a/extractor/mocks/mock_data_array.h b/extractor/mocks/mock_data_array.h
index 6f85abb4..4bdcf21f 100644
--- a/extractor/mocks/mock_data_array.h
+++ b/extractor/mocks/mock_data_array.h
@@ -6,9 +6,11 @@ namespace extractor {
 
 class MockDataArray : public DataArray {
  public:
-  MOCK_CONST_METHOD0(GetData, const vector<int>&());
+  MOCK_CONST_METHOD0(GetData, vector<int>());
   MOCK_CONST_METHOD1(AtIndex, int(int index));
   MOCK_CONST_METHOD1(GetWordAtIndex, string(int index));
+  MOCK_CONST_METHOD2(GetWordIds, vector<int>(int start_index, int size));
+  MOCK_CONST_METHOD2(GetWords, vector<string>(int start_index, int size));
   MOCK_CONST_METHOD0(GetSize, int());
   MOCK_CONST_METHOD0(GetVocabularySize, int());
   MOCK_CONST_METHOD1(HasWord, bool(const string& word));
diff --git a/extractor/mocks/mock_precomputation.h b/extractor/mocks/mock_precomputation.h
index 8753343e..5f7aa999 100644
--- a/extractor/mocks/mock_precomputation.h
+++ b/extractor/mocks/mock_precomputation.h
@@ -6,7 +6,8 @@ namespace extractor {
 
 class MockPrecomputation : public Precomputation {
  public:
-  MOCK_CONST_METHOD0(GetCollocations, const Index&());
+  MOCK_CONST_METHOD1(Contains, bool(const vector<int>& pattern));
+  MOCK_CONST_METHOD1(GetCollocations, vector<int>(const vector<int>& pattern));
 };
 
 } // namespace extractor
diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc
index 3b8aed69..38d8f489 100644
--- a/extractor/precomputation.cc
+++ b/extractor/precomputation.cc
@@ -5,22 +5,21 @@
 
 #include "data_array.h"
 #include "suffix_array.h"
+#include "vocabulary.h"
 
 using namespace std;
 
 namespace extractor {
 
-int Precomputation::FIRST_NONTERMINAL = -1;
-int Precomputation::SECOND_NONTERMINAL = -2;
+int Precomputation::NONTERMINAL = -1;
 
 Precomputation::Precomputation(
-    shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns,
-    int num_super_frequent_patterns, int max_rule_span,
-    int max_rule_symbols, int min_gap_size,
+    shared_ptr<Vocabulary> vocabulary, shared_ptr<SuffixArray> suffix_array,
+    int num_frequent_patterns, int num_super_frequent_patterns,
+    int max_rule_span, int max_rule_symbols, int min_gap_size,
     int max_frequent_phrase_len, int min_frequency) {
-  vector<int> data = suffix_array->GetData()->GetData();
   vector<vector<int>> frequent_patterns = FindMostFrequentPatterns(
-      suffix_array, data, num_frequent_patterns, max_frequent_phrase_len,
+      suffix_array, num_frequent_patterns, max_frequent_phrase_len,
       min_frequency);
 
   // Construct sets containing the frequent and superfrequent contiguous
@@ -34,28 +33,30 @@ Precomputation::Precomputation(
     }
   }
 
+  shared_ptr<DataArray> data_array = suffix_array->GetData();
   vector<tuple<int, int, int>> matchings;
-  for (size_t i = 0; i < data.size(); ++i) {
+  for (size_t i = 0; i < data_array->GetSize(); ++i) {
     // If the sentence is over, add all the discontiguous frequent patterns to
     // the index.
-    if (data[i] == DataArray::END_OF_LINE) {
-      AddCollocations(matchings, data, max_rule_span, min_gap_size,
-          max_rule_symbols);
+    if (data_array->AtIndex(i) == DataArray::END_OF_LINE) {
+      UpdateIndex(data_array, vocabulary, matchings, max_rule_span,
+                  min_gap_size, max_rule_symbols);
       matchings.clear();
       continue;
     }
-    vector<int> pattern;
     // Find all the contiguous frequent patterns starting at position i.
-    for (int j = 1; j <= max_frequent_phrase_len && i + j <= data.size(); ++j) {
-      pattern.push_back(data[i + j - 1]);
-      if (frequent_patterns_set.count(pattern)) {
-        int is_super_frequent = super_frequent_patterns_set.count(pattern);
-        matchings.push_back(make_tuple(i, j, is_super_frequent));
-      } else {
+    vector<int> pattern;
+    for (int j = 1;
+         j <= max_frequent_phrase_len && i + j <= data_array->GetSize();
+         ++j) {
+      pattern.push_back(data_array->AtIndex(i + j - 1));
+      if (!frequent_patterns_set.count(pattern)) {
         // If the current pattern is not frequent, any longer pattern having the
         // current pattern as prefix will not be frequent.
         break;
       }
+      int is_super_frequent = super_frequent_patterns_set.count(pattern);
+      matchings.push_back(make_tuple(i, j, is_super_frequent));
     }
   }
 }
@@ -65,8 +66,8 @@ Precomputation::Precomputation() {}
 Precomputation::~Precomputation() {}
 
 vector<vector<int>> Precomputation::FindMostFrequentPatterns(
-    shared_ptr<SuffixArray> suffix_array, const vector<int>& data,
-    int num_frequent_patterns, int max_frequent_phrase_len, int min_frequency) {
+    shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns,
+    int max_frequent_phrase_len, int min_frequency) {
   vector<int> lcp = suffix_array->BuildLCPArray();
   vector<int> run_start(max_frequent_phrase_len);
 
@@ -83,6 +84,7 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns(
     }
   }
 
+  shared_ptr<DataArray> data_array = suffix_array->GetData();
   // Extract the most frequent patterns.
   vector<vector<int>> frequent_patterns;
   while (frequent_patterns.size() < num_frequent_patterns && !heap.empty()) {
@@ -90,7 +92,7 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns(
     int len = heap.top().second.second;
     heap.pop();
 
-    vector<int> pattern(data.begin() + start, data.begin() + start + len);
+    vector<int> pattern = data_array->GetWordIds(start, len);
     if (find(pattern.begin(), pattern.end(), DataArray::END_OF_LINE) ==
         pattern.end()) {
       frequent_patterns.push_back(pattern);
@@ -99,8 +101,9 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns(
   return frequent_patterns;
 }
 
-void Precomputation::AddCollocations(
-    const vector<tuple<int, int, int>>& matchings, const vector<int>& data,
+void Precomputation::UpdateIndex(
+    shared_ptr<DataArray> data_array, shared_ptr<Vocabulary> vocabulary,
+    const vector<tuple<int, int, int>>& matchings,
     int max_rule_span, int min_gap_size, int max_rule_symbols) {
   // Select the leftmost subpattern.
   for (size_t i = 0; i < matchings.size(); ++i) {
@@ -118,16 +121,15 @@ void Precomputation::AddCollocations(
       if (start2 - start1 - size1 >= min_gap_size
           && start2 + size2 - start1 <= max_rule_span
           && size1 + size2 + 1 <= max_rule_symbols) {
-        vector<int> pattern(data.begin() + start1,
-            data.begin() + start1 + size1);
-        pattern.push_back(Precomputation::FIRST_NONTERMINAL);
-        pattern.insert(pattern.end(), data.begin() + start2,
-            data.begin() + start2 + size2);
-        AddStartPositions(collocations[pattern], start1, start2);
+        vector<int> pattern;
+        AppendSubpattern(pattern, data_array, vocabulary, start1, size1);
+        pattern.push_back(Precomputation::NONTERMINAL);
+        AppendSubpattern(pattern, data_array, vocabulary, start2, size2);
+        AppendCollocation(index[pattern], {start1, start2});
 
         // Try extending the binary collocation to a ternary collocation.
         if (is_super2) {
-          pattern.push_back(Precomputation::SECOND_NONTERMINAL);
+          pattern.push_back(Precomputation::NONTERMINAL);
           // Select the rightmost subpattern.
           for (size_t k = j + 1; k < matchings.size(); ++k) {
             int start3, size3, is_super3;
@@ -140,9 +142,8 @@ void Precomputation::AddCollocations(
                 && start3 + size3 - start1 <= max_rule_span
                 && size1 + size2 + size3 + 2 <= max_rule_symbols
                 && (is_super1 || is_super3)) {
-              pattern.insert(pattern.end(), data.begin() + start3,
-                  data.begin() + start3 + size3);
-              AddStartPositions(collocations[pattern], start1, start2, start3);
+              AppendSubpattern(pattern, data_array, vocabulary, start3, size3);
+              AppendCollocation(index[pattern], {start1, start2, start3});
               pattern.erase(pattern.end() - size3);
             }
           }
@@ -152,25 +153,30 @@ void Precomputation::AddCollocations(
   }
 }
 
-void Precomputation::AddStartPositions(
-    vector<int>& positions, int pos1, int pos2) {
-  positions.push_back(pos1);
-  positions.push_back(pos2);
+void Precomputation::AppendSubpattern(
+    vector<int>& pattern, shared_ptr<DataArray> data_array,
+    shared_ptr<Vocabulary> vocabulary, int start, int size) {
+  vector<string> words = data_array->GetWords(start, size);
+  for (const string& word: words) {
+    pattern.push_back(vocabulary->GetTerminalIndex(word));
+  }
+}
+
+void Precomputation::AppendCollocation(
+    vector<int>& collocations, const vector<int>& collocation) {
+  copy(collocation.begin(), collocation.end(), back_inserter(collocations));
 }
 
-void Precomputation::AddStartPositions(
-    vector<int>& positions, int pos1, int pos2, int pos3) {
-  positions.push_back(pos1);
-  positions.push_back(pos2);
-  positions.push_back(pos3);
+bool Precomputation::Contains(const vector<int>& pattern) const {
+  return index.count(pattern);
 }
 
-const Index& Precomputation::GetCollocations() const {
-  return collocations;
+vector<int> Precomputation::GetCollocations(const vector<int>& pattern) const {
+  return index.at(pattern);
 }
 
 bool Precomputation::operator==(const Precomputation& other) const {
-  return collocations == other.collocations;
+  return index == other.index;
 }
 
 } // namespace extractor
diff --git a/extractor/precomputation.h b/extractor/precomputation.h
index e5fa3e37..6ade58df 100644
--- a/extractor/precomputation.h
+++ b/extractor/precomputation.h
@@ -19,7 +19,9 @@ namespace extractor {
 typedef boost::hash<vector<int>> VectorHash;
 typedef unordered_map<vector<int>, vector<int>, VectorHash> Index;
 
+class DataArray;
 class SuffixArray;
+class Vocabulary;
 
 /**
  * Data structure wrapping an index with all the occurrences of the most
@@ -35,9 +37,9 @@ class Precomputation {
  public:
   // Constructs the index using the suffix array.
   Precomputation(
-      shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns,
-      int num_super_frequent_patterns, int max_rule_span,
-      int max_rule_symbols, int min_gap_size,
+      shared_ptr<Vocabulary> vocabulary, shared_ptr<SuffixArray> suffix_array,
+      int num_frequent_patterns, int num_super_frequent_patterns,
+      int max_rule_span, int max_rule_symbols, int min_gap_size,
       int max_frequent_phrase_len, int min_frequency);
 
   // Creates empty precomputation data structure.
@@ -45,40 +47,43 @@ class Precomputation {
 
   virtual ~Precomputation();
 
-  // Returns a reference to the index.
-  virtual const Index& GetCollocations() const;
+  // Returns whether a pattern is contained in the index of collocations.
+  virtual bool Contains(const vector<int>& pattern) const;
+
+  // Returns the list of collocations for a given pattern.
+  virtual vector<int> GetCollocations(const vector<int>& pattern) const;
 
   bool operator==(const Precomputation& other) const;
 
-  static int FIRST_NONTERMINAL;
-  static int SECOND_NONTERMINAL;
+  static int NONTERMINAL;
 
  private:
   // Finds the most frequent contiguous collocations.
   vector<vector<int>> FindMostFrequentPatterns(
-      shared_ptr<SuffixArray> suffix_array, const vector<int>& data,
-      int num_frequent_patterns, int max_frequent_phrase_len,
-      int min_frequency);
+      shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns,
+      int max_frequent_phrase_len, int min_frequency);
 
   // Given the locations of the frequent contiguous collocations in a sentence,
   // it adds new entries to the index for each discontiguous collocation
   // matching the criteria specified in the class description.
-  void AddCollocations(
-      const vector<std::tuple<int, int, int>>& matchings, const vector<int>& data,
+  void UpdateIndex(
+      shared_ptr<DataArray> data_array, shared_ptr<Vocabulary> vocabulary,
+      const vector<tuple<int, int, int>>& matchings,
       int max_rule_span, int min_gap_size, int max_rule_symbols);
 
-  // Adds an occurrence of a binary collocation.
-  void AddStartPositions(vector<int>& positions, int pos1, int pos2);
+  void AppendSubpattern(
+      vector<int>& pattern, shared_ptr<DataArray> data_array,
+      shared_ptr<Vocabulary> vocabulary, int start, int size);
 
-  // Adds an occurrence of a ternary collocation.
-  void AddStartPositions(vector<int>& positions, int pos1, int pos2, int pos3);
+  // Adds an occurrence of a collocation.
+  void AppendCollocation(vector<int>& collocations, const vector<int>& collocation);
 
   friend class boost::serialization::access;
 
   template<class Archive> void save(Archive& ar, unsigned int) const {
-    int num_entries = collocations.size();
+    int num_entries = index.size();
     ar << num_entries;
-    for (pair<vector<int>, vector<int>> entry: collocations) {
+    for (pair<vector<int>, vector<int>> entry: index) {
       ar << entry;
     }
   }
@@ -89,13 +94,13 @@ class Precomputation {
     for (size_t i = 0; i < num_entries; ++i) {
       pair<vector<int>, vector<int>> entry;
       ar >> entry;
-      collocations.insert(entry);
+      index.insert(entry);
     }
   }
 
   BOOST_SERIALIZATION_SPLIT_MEMBER();
 
-  Index collocations;
+  Index index;
 };
 
 } // namespace extractor
diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc
index e81ece5d..fd85fcf8 100644
--- a/extractor/precomputation_test.cc
+++ b/extractor/precomputation_test.cc
@@ -9,6 +9,7 @@
 
 #include "mocks/mock_data_array.h"
 #include "mocks/mock_suffix_array.h"
+#include "mocks/mock_vocabulary.h"
 #include "precomputation.h"
 
 using namespace std;
@@ -23,7 +24,31 @@ class PrecomputationTest : public Test {
   virtual void SetUp() {
     data = {4, 2, 3, 5, 7, 2, 3, 5, 2, 3, 4, 2, 1};
     data_array = make_shared<MockDataArray>();
-    EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data));
+    EXPECT_CALL(*data_array, GetSize()).WillRepeatedly(Return(data.size()));
+    for (size_t i = 0; i < data.size(); ++i) {
+      EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i]));
+    }
+    vector<pair<int, int>> expected_calls = {{8, 1}, {8, 2}, {6, 1}};
+    for (const auto& call: expected_calls) {
+      int start = call.first;
+      int size = call.second;
+      vector<int> word_ids(data.begin() + start, data.begin() + start + size);
+      EXPECT_CALL(*data_array, GetWordIds(start, size))
+          .WillRepeatedly(Return(word_ids));
+    }
+
+    expected_calls = {{1, 1}, {5, 1}, {8, 1}, {9, 1}, {5, 2},
+                      {6, 1}, {8, 2}, {1, 2}, {2, 1}, {11, 1}};
+    for (const auto& call: expected_calls) {
+      int start = call.first;
+      int size = call.second;
+      vector<string> words;
+      for (size_t j = start; j < start + size; ++j) {
+        words.push_back(to_string(data[j]));
+      }
+      EXPECT_CALL(*data_array, GetWords(start, size))
+          .WillRepeatedly(Return(words));
+    }
 
     vector<int> suffixes{12, 8, 5, 1, 9, 6, 2, 0, 10, 7, 3, 4, 13};
     vector<int> lcp{-1, 0, 2, 3, 1, 0, 1, 2, 0, 2, 0, 1, 0, 0};
@@ -35,77 +60,98 @@ class PrecomputationTest : public Test {
     }
     EXPECT_CALL(*suffix_array, BuildLCPArray()).WillRepeatedly(Return(lcp));
 
-    precomputation = Precomputation(suffix_array, 3, 3, 10, 5, 1, 4, 2);
+    vocabulary = make_shared<MockVocabulary>();
+    EXPECT_CALL(*vocabulary, GetTerminalIndex("2")).WillRepeatedly(Return(2));
+    EXPECT_CALL(*vocabulary, GetTerminalIndex("3")).WillRepeatedly(Return(3));
+
+    precomputation = Precomputation(vocabulary, suffix_array,
+                                    3, 3, 10, 5, 1, 4, 2);
   }
 
   vector<int> data;
   shared_ptr<MockDataArray> data_array;
   shared_ptr<MockSuffixArray> suffix_array;
+  shared_ptr<MockVocabulary> vocabulary;
   Precomputation precomputation;
 };
 
 TEST_F(PrecomputationTest, TestCollocations) {
-  Index collocations = precomputation.GetCollocations();
-
   vector<int> key = {2, 3, -1, 2};
   vector<int> expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11};
-  EXPECT_EQ(expected_value, collocations[key]);
+  EXPECT_TRUE(precomputation.Contains(key));
+  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
   key = {2, 3, -1, 2, 3};
   expected_value = {1, 5, 1, 8, 5, 8};
-  EXPECT_EQ(expected_value, collocations[key]);
+  EXPECT_TRUE(precomputation.Contains(key));
+  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
   key = {2, 3, -1, 3};
   expected_value = {1, 6, 1, 9, 5, 9};
-  EXPECT_EQ(expected_value, collocations[key]);
+  EXPECT_TRUE(precomputation.Contains(key));
+  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
   key = {3, -1, 2};
   expected_value = {2, 5, 2, 8, 2, 11, 6, 8, 6, 11, 9, 11};
-  EXPECT_EQ(expected_value, collocations[key]);
+  EXPECT_TRUE(precomputation.Contains(key));
+  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
   key = {3, -1, 3};
   expected_value = {2, 6, 2, 9, 6, 9};
-  EXPECT_EQ(expected_value, collocations[key]);
+  EXPECT_TRUE(precomputation.Contains(key));
+  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
   key = {3, -1, 2, 3};
   expected_value = {2, 5, 2, 8, 6, 8};
-  EXPECT_EQ(expected_value, collocations[key]);
+  EXPECT_TRUE(precomputation.Contains(key));
+  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
   key = {2, -1, 2};
   expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11};
-  EXPECT_EQ(expected_value, collocations[key]);
+  EXPECT_TRUE(precomputation.Contains(key));
+  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
   key = {2, -1, 2, 3};
   expected_value = {1, 5, 1, 8, 5, 8};
-  EXPECT_EQ(expected_value, collocations[key]);
+  EXPECT_TRUE(precomputation.Contains(key));
+  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
   key = {2, -1, 3};
   expected_value = {1, 6, 1, 9, 5, 9};
-  EXPECT_EQ(expected_value, collocations[key]);
+  EXPECT_TRUE(precomputation.Contains(key));
+  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
 
-  key = {2, -1, 2, -2, 2};
+  key = {2, -1, 2, -1, 2};
   expected_value = {1, 5, 8, 5, 8, 11};
-  EXPECT_EQ(expected_value, collocations[key]);
-  key = {2, -1, 2, -2, 3};
+  EXPECT_TRUE(precomputation.Contains(key));
+  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
+  key = {2, -1, 2, -1, 3};
   expected_value = {1, 5, 9};
-  EXPECT_EQ(expected_value, collocations[key]);
-  key = {2, -1, 3, -2, 2};
+  EXPECT_TRUE(precomputation.Contains(key));
+  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
+  key = {2, -1, 3, -1, 2};
   expected_value = {1, 6, 8, 5, 9, 11};
-  EXPECT_EQ(expected_value, collocations[key]);
-  key = {2, -1, 3, -2, 3};
+  EXPECT_TRUE(precomputation.Contains(key));
+  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
+  key = {2, -1, 3, -1, 3};
   expected_value = {1, 6, 9};
-  EXPECT_EQ(expected_value, collocations[key]);
-  key = {3, -1, 2, -2, 2};
+  EXPECT_TRUE(precomputation.Contains(key));
+  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
+  key = {3, -1, 2, -1, 2};
   expected_value = {2, 5, 8, 2, 5, 11, 2, 8, 11, 6, 8, 11};
-  EXPECT_EQ(expected_value, collocations[key]);
-  key = {3, -1, 2, -2, 3};
+  EXPECT_TRUE(precomputation.Contains(key));
+  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
+  key = {3, -1, 2, -1, 3};
   expected_value = {2, 5, 9};
-  EXPECT_EQ(expected_value, collocations[key]);
-  key = {3, -1, 3, -2, 2};
+  EXPECT_TRUE(precomputation.Contains(key));
+  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
+  key = {3, -1, 3, -1, 2};
   expected_value = {2, 6, 8, 2, 6, 11, 2, 9, 11, 6, 9, 11};
-  EXPECT_EQ(expected_value, collocations[key]);
-  key = {3, -1, 3, -2, 3};
+  EXPECT_TRUE(precomputation.Contains(key));
+  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
+  key = {3, -1, 3, -1, 3};
   expected_value = {2, 6, 9};
-  EXPECT_EQ(expected_value, collocations[key]);
+  EXPECT_TRUE(precomputation.Contains(key));
+  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
 
   // Exceeds max_rule_symbols.
-  key = {2, -1, 2, -2, 2, 3};
-  EXPECT_EQ(0, collocations.count(key));
+  key = {2, -1, 2, -1, 2, 3};
+  EXPECT_FALSE(precomputation.Contains(key));
   // Contains non frequent pattern.
   key = {2, -1, 5};
-  EXPECT_EQ(0, collocations.count(key));
+  EXPECT_FALSE(precomputation.Contains(key));
 }
 
 TEST_F(PrecomputationTest, TestSerialization) {
diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc
index 6eb55073..85c8a422 100644
--- a/extractor/run_extractor.cc
+++ b/extractor/run_extractor.cc
@@ -28,6 +28,7 @@
 #include "suffix_array.h"
 #include "time_util.h"
 #include "translation_table.h"
+#include "vocabulary.h"
 
 namespace fs = boost::filesystem;
 namespace po = boost::program_options;
@@ -142,11 +143,14 @@ int main(int argc, char** argv) {
   cerr << "Reading alignment took "
        << GetDuration(start_time, stop_time) << " seconds" << endl;
 
+  shared_ptr<Vocabulary> vocabulary = make_shared<Vocabulary>();
+
   // Constructs an index storing the occurrences in the source data for each
   // frequent collocation.
   start_time = Clock::now();
   cerr << "Precomputing collocations..." << endl;
   shared_ptr<Precomputation> precomputation = make_shared<Precomputation>(
+      vocabulary,
       source_suffix_array,
       vm["frequent"].as<int>(),
       vm["super_frequent"].as<int>(),
@@ -194,6 +198,7 @@ int main(int argc, char** argv) {
       alignment,
       precomputation,
       scorer,
+      vocabulary,
       vm["min_gap_size"].as<int>(),
       vm["max_rule_span"].as<int>(),
       vm["max_nonterminals"].as<int>(),
diff --git a/extractor/suffix_array_test.cc b/extractor/suffix_array_test.cc
index ba0dbcc3..a9fd1eab 100644
--- a/extractor/suffix_array_test.cc
+++ b/extractor/suffix_array_test.cc
@@ -21,7 +21,7 @@ class SuffixArrayTest : public Test {
   virtual void SetUp() {
     data = {6, 4, 1, 2, 4, 5, 3, 4, 6, 6, 4, 1, 2};
     data_array = make_shared<MockDataArray>();
-    EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data));
+    EXPECT_CALL(*data_array, GetData()).WillRepeatedly(Return(data));
     EXPECT_CALL(*data_array, GetVocabularySize()).WillRepeatedly(Return(7));
     EXPECT_CALL(*data_array, GetSize()).WillRepeatedly(Return(13));
     suffix_array = SuffixArray(data_array);
diff --git a/extractor/translation_table_test.cc b/extractor/translation_table_test.cc
index 606777bd..72551a12 100644
--- a/extractor/translation_table_test.cc
+++ b/extractor/translation_table_test.cc
@@ -28,7 +28,7 @@ class TranslationTableTest : public Test {
     vector<int> source_sentence_start = {0, 6, 10, 14};
     shared_ptr<MockDataArray> source_data_array = make_shared<MockDataArray>();
     EXPECT_CALL(*source_data_array, GetData())
-        .WillRepeatedly(ReturnRef(source_data));
+        .WillRepeatedly(Return(source_data));
     EXPECT_CALL(*source_data_array, GetNumSentences())
         .WillRepeatedly(Return(3));
     for (size_t i = 0; i < source_sentence_start.size(); ++i) {
@@ -48,7 +48,7 @@ class TranslationTableTest : public Test {
     vector<int> target_sentence_start = {0, 7, 10, 13};
     shared_ptr<MockDataArray> target_data_array = make_shared<MockDataArray>();
     EXPECT_CALL(*target_data_array, GetData())
-        .WillRepeatedly(ReturnRef(target_data));
+        .WillRepeatedly(Return(target_data));
     for (size_t i = 0; i < target_sentence_start.size(); ++i) {
       EXPECT_CALL(*target_data_array, GetSentenceStart(i))
           .WillRepeatedly(Return(target_sentence_start[i]));
-- 
cgit v1.2.3