From 3c73e472444ff0cd436b12f3679440a6969cbf2d Mon Sep 17 00:00:00 2001
From: Paul Baltescu <pauldb89@gmail.com>
Date: Mon, 25 Nov 2013 23:56:31 +0000
Subject: Clean up leave-one-out sampling.

---
 extractor/grammar_extractor.cc      |  6 ++++--
 extractor/grammar_extractor.h       |  4 +++-
 extractor/grammar_extractor_test.cc |  4 ++--
 extractor/mocks/mock_rule_factory.h |  6 +++---
 extractor/mocks/mock_sampler.h      |  4 +++-
 extractor/rule_factory.cc           |  7 +++++--
 extractor/rule_factory.h            |  3 +--
 extractor/rule_factory_test.cc      |  8 +++-----
 extractor/run_extractor.cc          |  3 ++-
 extractor/sampler.cc                | 12 ++++++++----
 extractor/sampler.h                 |  4 +++-
 extractor/sampler_test.cc           | 30 +++++++++++++++++++++---------
 12 files changed, 58 insertions(+), 33 deletions(-)

diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc
index 4d0738f7..1dc94c25 100644
--- a/extractor/grammar_extractor.cc
+++ b/extractor/grammar_extractor.cc
@@ -35,10 +35,12 @@ GrammarExtractor::GrammarExtractor(
     vocabulary(vocabulary),
     rule_factory(rule_factory) {}
 
-Grammar GrammarExtractor::GetGrammar(const string& sentence, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) {
+Grammar GrammarExtractor::GetGrammar(
+    const string& sentence,
+    const unordered_set<int>& blacklisted_sentence_ids) {
   vector<string> words = TokenizeSentence(sentence);
   vector<int> word_ids = AnnotateWords(words);
-  return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
+  return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids);
 }
 
 vector<string> GrammarExtractor::TokenizeSentence(const string& sentence) {
diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h
index 8f570df2..eb79f53c 100644
--- a/extractor/grammar_extractor.h
+++ b/extractor/grammar_extractor.h
@@ -46,7 +46,9 @@ class GrammarExtractor {
 
   // Converts the sentence to a vector of word ids and uses the RuleFactory to
   // extract the SCFG rules which may be used to decode the sentence.
-  Grammar GetGrammar(const string& sentence, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array);
+  Grammar GetGrammar(
+      const string& sentence,
+      const unordered_set<int>& blacklisted_sentence_ids);
 
  private:
   // Splits the sentence in a vector of words.
diff --git a/extractor/grammar_extractor_test.cc b/extractor/grammar_extractor_test.cc
index f32a9599..719e90ff 100644
--- a/extractor/grammar_extractor_test.cc
+++ b/extractor/grammar_extractor_test.cc
@@ -41,13 +41,13 @@ TEST(GrammarExtractorTest, TestAnnotatingWords) {
   Grammar grammar(rules, feature_names);
   unordered_set<int> blacklisted_sentence_ids;
   shared_ptr<DataArray> source_data_array;
-  EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array))
+  EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids))
       .WillOnce(Return(grammar));
 
   GrammarExtractor extractor(vocabulary, factory);
   string sentence = "Anna has many many apples .";
 
-  extractor.GetGrammar(sentence, blacklisted_sentence_ids, source_data_array);
+  extractor.GetGrammar(sentence, blacklisted_sentence_ids);
 }
 
 } // namespace
diff --git a/extractor/mocks/mock_rule_factory.h b/extractor/mocks/mock_rule_factory.h
index 6b7b6586..53eb5022 100644
--- a/extractor/mocks/mock_rule_factory.h
+++ b/extractor/mocks/mock_rule_factory.h
@@ -7,9 +7,9 @@ namespace extractor {
 
 class MockHieroCachingRuleFactory : public HieroCachingRuleFactory {
  public:
-  MOCK_METHOD3(GetGrammar, Grammar(const vector<int>& word_ids, const
-      unordered_set<int>& blacklisted_sentence_ids,
-      const shared_ptr<DataArray> source_data_array));
+  MOCK_METHOD2(GetGrammar, Grammar(
+      const vector<int>& word_ids,
+      const unordered_set<int>& blacklisted_sentence_ids));
 };
 
 } // namespace extractor
diff --git a/extractor/mocks/mock_sampler.h b/extractor/mocks/mock_sampler.h
index 75c43c27..b2742f62 100644
--- a/extractor/mocks/mock_sampler.h
+++ b/extractor/mocks/mock_sampler.h
@@ -7,7 +7,9 @@ namespace extractor {
 
 class MockSampler : public Sampler {
  public:
-  MOCK_CONST_METHOD1(Sample, PhraseLocation(const PhraseLocation& location));
+  MOCK_CONST_METHOD2(Sample, PhraseLocation(
+      const PhraseLocation& location,
+      const unordered_set<int>& blacklisted_sentence_ids));
 };
 
 } // namespace extractor
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc
index 6ae2d792..5b66f685 100644
--- a/extractor/rule_factory.cc
+++ b/extractor/rule_factory.cc
@@ -101,7 +101,9 @@ HieroCachingRuleFactory::HieroCachingRuleFactory() {}
 
 HieroCachingRuleFactory::~HieroCachingRuleFactory() {}
 
-Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) {
+Grammar HieroCachingRuleFactory::GetGrammar(
+    const vector<int>& word_ids,
+    const unordered_set<int>& blacklisted_sentence_ids) {
   Clock::time_point start_time = Clock::now();
   double total_extract_time = 0;
   double total_intersect_time = 0;
@@ -193,7 +195,8 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids, const u
       Clock::time_point extract_start = Clock::now();
       if (!state.starts_with_x) {
         // Extract rules for the sampled set of occurrences.
-        PhraseLocation sample = sampler->Sample(next_node->matchings, blacklisted_sentence_ids, source_data_array);
+        PhraseLocation sample = sampler->Sample(
+            next_node->matchings, blacklisted_sentence_ids);
         vector<Rule> new_rules =
             rule_extractor->ExtractRules(next_phrase, sample);
         rules.insert(rules.end(), new_rules.begin(), new_rules.end());
diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h
index a1ff76e4..1a9fa2af 100644
--- a/extractor/rule_factory.h
+++ b/extractor/rule_factory.h
@@ -74,8 +74,7 @@ class HieroCachingRuleFactory {
   // (See class description for more details.)
   virtual Grammar GetGrammar(
       const vector<int>& word_ids,
-      const unordered_set<int>& blacklisted_sentence_ids,
-      const shared_ptr<DataArray> source_data_array);
+      const unordered_set<int>& blacklisted_sentence_ids);
 
  protected:
   HieroCachingRuleFactory();
diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc
index f26cc567..332c5959 100644
--- a/extractor/rule_factory_test.cc
+++ b/extractor/rule_factory_test.cc
@@ -40,7 +40,7 @@ class RuleFactoryTest : public Test {
         .WillRepeatedly(Return(feature_names));
 
     sampler = make_shared<MockSampler>();
-    EXPECT_CALL(*sampler, Sample(_))
+    EXPECT_CALL(*sampler, Sample(_, _))
         .WillRepeatedly(Return(PhraseLocation(0, 1)));
 
     Phrase phrase;
@@ -77,8 +77,7 @@ TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) {
 
   vector<int> word_ids = {2, 3, 4};
   unordered_set<int> blacklisted_sentence_ids;
-  shared_ptr<DataArray> source_data_array;
-  Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
+  Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids);
   EXPECT_EQ(feature_names, grammar.GetFeatureNames());
   EXPECT_EQ(7, grammar.GetRules().size());
 }
@@ -97,8 +96,7 @@ TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) {
 
   vector<int> word_ids = {2, 3, 4, 2, 3};
   unordered_set<int> blacklisted_sentence_ids;
-  shared_ptr<DataArray> source_data_array;
-  Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
+  Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids);
   EXPECT_EQ(feature_names, grammar.GetFeatureNames());
   EXPECT_EQ(28, grammar.GetRules().size());
 }
diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc
index 85c8a422..6b22a302 100644
--- a/extractor/run_extractor.cc
+++ b/extractor/run_extractor.cc
@@ -237,7 +237,8 @@ int main(int argc, char** argv) {
 
     unordered_set<int> blacklisted_sentence_ids;
     if (leave_one_out) blacklisted_sentence_ids.insert(i);
-    Grammar grammar = extractor.GetGrammar(sentences[i], blacklisted_sentence_ids, source_data_array);
+    Grammar grammar = extractor.GetGrammar(
+        sentences[i], blacklisted_sentence_ids);
     ofstream output(GetGrammarFilePath(grammar_path, i).c_str());
     output << grammar;
   }
diff --git a/extractor/sampler.cc b/extractor/sampler.cc
index 963afa7a..fc386ed1 100644
--- a/extractor/sampler.cc
+++ b/extractor/sampler.cc
@@ -12,7 +12,9 @@ Sampler::Sampler() {}
 
 Sampler::~Sampler() {}
 
-PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const {
+PhraseLocation Sampler::Sample(
+    const PhraseLocation& location,
+    const unordered_set<int>& blacklisted_sentence_ids) const {
   vector<int> sample;
   int num_subpatterns;
   if (location.matchings == NULL) {
@@ -22,10 +24,11 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_s
     double step = max(1.0, (double) (high - low) / max_samples);
     double i = low, last = i;
     bool found;
+    shared_ptr<DataArray> source_data_array = suffix_array->GetData();
     while (sample.size() < max_samples && i < high) {
       int x = suffix_array->GetSuffix(Round(i));
       int id = source_data_array->GetSentenceId(x);
-      if (find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) != blacklisted_sentence_ids.end()) {
+      if (blacklisted_sentence_ids.count(id)) {
         found = false;
         double backoff_step = 1;
         while (true) {
@@ -33,13 +36,14 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_s
           double j = i - backoff_step;
           x = suffix_array->GetSuffix(Round(j));
           id = source_data_array->GetSentenceId(x);
-          if (x >= 0 && j > last && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) {
+          if (x >= 0 && j > last && !blacklisted_sentence_ids.count(id)) {
             found = true; last = i; break;
           }
           double k = i + backoff_step;
           x = suffix_array->GetSuffix(Round(k));
           id = source_data_array->GetSentenceId(x);
-          if (k < min(i+step, (double)high) && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) {
+          if (k < min(i+step, (double)high) &&
+              !blacklisted_sentence_ids.count(id)) {
             found = true; last = k; break;
           }
           if (j <= last && k >= high) break;
diff --git a/extractor/sampler.h b/extractor/sampler.h
index de450c48..bd8a5876 100644
--- a/extractor/sampler.h
+++ b/extractor/sampler.h
@@ -23,7 +23,9 @@ class Sampler {
   virtual ~Sampler();
 
   // Samples uniformly at most max_samples phrase occurrences.
-  virtual PhraseLocation Sample(const PhraseLocation& location, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const;
+  virtual PhraseLocation Sample(
+      const PhraseLocation& location,
+      const unordered_set<int>& blacklisted_sentence_ids) const;
 
  protected:
   Sampler();
diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc
index 965567ba..14e72780 100644
--- a/extractor/sampler_test.cc
+++ b/extractor/sampler_test.cc
@@ -19,6 +19,8 @@ class SamplerTest : public Test {
     source_data_array = make_shared<MockDataArray>();
     EXPECT_CALL(*source_data_array, GetSentenceId(_)).WillRepeatedly(Return(9999));
     suffix_array = make_shared<MockSuffixArray>();
+    EXPECT_CALL(*suffix_array, GetData())
+        .WillRepeatedly(Return(source_data_array));
     for (int i = 0; i < 10; ++i) {
       EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i));
     }
@@ -35,23 +37,29 @@ TEST_F(SamplerTest, TestSuffixArrayRange) {
 
   sampler = make_shared<Sampler>(suffix_array, 1);
   vector<int> expected_locations = {0};
-  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+  EXPECT_EQ(PhraseLocation(expected_locations, 1),
+            sampler->Sample(location, blacklist));
+  return;
 
   sampler = make_shared<Sampler>(suffix_array, 2);
   expected_locations = {0, 5};
-  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+  EXPECT_EQ(PhraseLocation(expected_locations, 1),
+            sampler->Sample(location, blacklist));
 
   sampler = make_shared<Sampler>(suffix_array, 3);
   expected_locations = {0, 3, 7};
-  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+  EXPECT_EQ(PhraseLocation(expected_locations, 1),
+            sampler->Sample(location, blacklist));
 
   sampler = make_shared<Sampler>(suffix_array, 4);
   expected_locations = {0, 3, 5, 8};
-  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+  EXPECT_EQ(PhraseLocation(expected_locations, 1),
+            sampler->Sample(location, blacklist));
 
   sampler = make_shared<Sampler>(suffix_array, 100);
   expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
-  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+  EXPECT_EQ(PhraseLocation(expected_locations, 1),
+            sampler->Sample(location, blacklist));
 }
 
 TEST_F(SamplerTest, TestSubstringsSample) {
@@ -61,19 +69,23 @@ TEST_F(SamplerTest, TestSubstringsSample) {
 
   sampler = make_shared<Sampler>(suffix_array, 1);
   vector<int> expected_locations = {0, 1};
-  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
+  EXPECT_EQ(PhraseLocation(expected_locations, 2),
+            sampler->Sample(location, blacklist));
 
   sampler = make_shared<Sampler>(suffix_array, 2);
   expected_locations = {0, 1, 6, 7};
-  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
+  EXPECT_EQ(PhraseLocation(expected_locations, 2),
+            sampler->Sample(location, blacklist));
 
   sampler = make_shared<Sampler>(suffix_array, 3);
   expected_locations = {0, 1, 4, 5, 6, 7};
-  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
+  EXPECT_EQ(PhraseLocation(expected_locations, 2),
+            sampler->Sample(location, blacklist));
 
   sampler = make_shared<Sampler>(suffix_array, 7);
   expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
-  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
+  EXPECT_EQ(PhraseLocation(expected_locations, 2),
+            sampler->Sample(location, blacklist));
 }
 
 } // namespace
-- 
cgit v1.2.3