diff options
Diffstat (limited to 'extractor')
| -rw-r--r-- | extractor/grammar_extractor.cc | 6 | ||||
| -rw-r--r-- | extractor/grammar_extractor.h | 4 | ||||
| -rw-r--r-- | extractor/grammar_extractor_test.cc | 4 | ||||
| -rw-r--r-- | extractor/mocks/mock_rule_factory.h | 6 | ||||
| -rw-r--r-- | extractor/mocks/mock_sampler.h | 4 | ||||
| -rw-r--r-- | extractor/rule_factory.cc | 7 | ||||
| -rw-r--r-- | extractor/rule_factory.h | 3 | ||||
| -rw-r--r-- | extractor/rule_factory_test.cc | 8 | ||||
| -rw-r--r-- | extractor/run_extractor.cc | 3 | ||||
| -rw-r--r-- | extractor/sampler.cc | 12 | ||||
| -rw-r--r-- | extractor/sampler.h | 4 | ||||
| -rw-r--r-- | extractor/sampler_test.cc | 30 | 
12 files changed, 58 insertions, 33 deletions
diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc index 4d0738f7..1dc94c25 100644 --- a/extractor/grammar_extractor.cc +++ b/extractor/grammar_extractor.cc @@ -35,10 +35,12 @@ GrammarExtractor::GrammarExtractor(      vocabulary(vocabulary),      rule_factory(rule_factory) {} -Grammar GrammarExtractor::GetGrammar(const string& sentence, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) { +Grammar GrammarExtractor::GetGrammar( +    const string& sentence, +    const unordered_set<int>& blacklisted_sentence_ids) {    vector<string> words = TokenizeSentence(sentence);    vector<int> word_ids = AnnotateWords(words); -  return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array); +  return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids);  }  vector<string> GrammarExtractor::TokenizeSentence(const string& sentence) { diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h index 8f570df2..eb79f53c 100644 --- a/extractor/grammar_extractor.h +++ b/extractor/grammar_extractor.h @@ -46,7 +46,9 @@ class GrammarExtractor {    // Converts the sentence to a vector of word ids and uses the RuleFactory to    // extract the SCFG rules which may be used to decode the sentence. -  Grammar GetGrammar(const string& sentence, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array); +  Grammar GetGrammar( +      const string& sentence, +      const unordered_set<int>& blacklisted_sentence_ids);   private:    // Splits the sentence in a vector of words. diff --git a/extractor/grammar_extractor_test.cc b/extractor/grammar_extractor_test.cc index f32a9599..719e90ff 100644 --- a/extractor/grammar_extractor_test.cc +++ b/extractor/grammar_extractor_test.cc @@ -41,13 +41,13 @@ TEST(GrammarExtractorTest, TestAnnotatingWords) {    Grammar grammar(rules, feature_names);    unordered_set<int> blacklisted_sentence_ids;    shared_ptr<DataArray> source_data_array; -  EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array)) +  EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids))        .WillOnce(Return(grammar));    GrammarExtractor extractor(vocabulary, factory);    string sentence = "Anna has many many apples ."; -  extractor.GetGrammar(sentence, blacklisted_sentence_ids, source_data_array); +  extractor.GetGrammar(sentence, blacklisted_sentence_ids);  }  } // namespace diff --git a/extractor/mocks/mock_rule_factory.h b/extractor/mocks/mock_rule_factory.h index 6b7b6586..53eb5022 100644 --- a/extractor/mocks/mock_rule_factory.h +++ b/extractor/mocks/mock_rule_factory.h @@ -7,9 +7,9 @@ namespace extractor {  class MockHieroCachingRuleFactory : public HieroCachingRuleFactory {   public: -  MOCK_METHOD3(GetGrammar, Grammar(const vector<int>& word_ids, const -      unordered_set<int>& blacklisted_sentence_ids, -      const shared_ptr<DataArray> source_data_array)); +  MOCK_METHOD2(GetGrammar, Grammar( +      const vector<int>& word_ids, +      const unordered_set<int>& blacklisted_sentence_ids));  };  } // namespace extractor diff --git a/extractor/mocks/mock_sampler.h b/extractor/mocks/mock_sampler.h index 75c43c27..b2742f62 100644 --- a/extractor/mocks/mock_sampler.h +++ b/extractor/mocks/mock_sampler.h @@ -7,7 +7,9 @@ namespace extractor {  class MockSampler : public Sampler {   public: -  MOCK_CONST_METHOD1(Sample, PhraseLocation(const PhraseLocation& location)); +  MOCK_CONST_METHOD2(Sample, PhraseLocation( +      const PhraseLocation& location, +      const unordered_set<int>& blacklisted_sentence_ids));  };  } // namespace extractor diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index 6ae2d792..5b66f685 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -101,7 +101,9 @@ HieroCachingRuleFactory::HieroCachingRuleFactory() {}  HieroCachingRuleFactory::~HieroCachingRuleFactory() {} -Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) { +Grammar HieroCachingRuleFactory::GetGrammar( +    const vector<int>& word_ids, +    const unordered_set<int>& blacklisted_sentence_ids) {    Clock::time_point start_time = Clock::now();    double total_extract_time = 0;    double total_intersect_time = 0; @@ -193,7 +195,8 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids, const u        Clock::time_point extract_start = Clock::now();        if (!state.starts_with_x) {          // Extract rules for the sampled set of occurrences. -        PhraseLocation sample = sampler->Sample(next_node->matchings, blacklisted_sentence_ids, source_data_array); +        PhraseLocation sample = sampler->Sample( +            next_node->matchings, blacklisted_sentence_ids);          vector<Rule> new_rules =              rule_extractor->ExtractRules(next_phrase, sample);          rules.insert(rules.end(), new_rules.begin(), new_rules.end()); diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h index a1ff76e4..1a9fa2af 100644 --- a/extractor/rule_factory.h +++ b/extractor/rule_factory.h @@ -74,8 +74,7 @@ class HieroCachingRuleFactory {    // (See class description for more details.)    virtual Grammar GetGrammar(        const vector<int>& word_ids, -      const unordered_set<int>& blacklisted_sentence_ids, -      const shared_ptr<DataArray> source_data_array); +      const unordered_set<int>& blacklisted_sentence_ids);   protected:    HieroCachingRuleFactory(); diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc index f26cc567..332c5959 100644 --- a/extractor/rule_factory_test.cc +++ b/extractor/rule_factory_test.cc @@ -40,7 +40,7 @@ class RuleFactoryTest : public Test {          .WillRepeatedly(Return(feature_names));      sampler = make_shared<MockSampler>(); -    EXPECT_CALL(*sampler, Sample(_)) +    EXPECT_CALL(*sampler, Sample(_, _))          .WillRepeatedly(Return(PhraseLocation(0, 1)));      Phrase phrase; @@ -77,8 +77,7 @@ TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) {    vector<int> word_ids = {2, 3, 4};    unordered_set<int> blacklisted_sentence_ids; -  shared_ptr<DataArray> source_data_array; -  Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array); +  Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids);    EXPECT_EQ(feature_names, grammar.GetFeatureNames());    EXPECT_EQ(7, grammar.GetRules().size());  } @@ -97,8 +96,7 @@ TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) {    vector<int> word_ids = {2, 3, 4, 2, 3};    unordered_set<int> blacklisted_sentence_ids; -  shared_ptr<DataArray> source_data_array; -  Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array); +  Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids);    EXPECT_EQ(feature_names, grammar.GetFeatureNames());    EXPECT_EQ(28, grammar.GetRules().size());  } diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index 85c8a422..6b22a302 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -237,7 +237,8 @@ int main(int argc, char** argv) {      unordered_set<int> blacklisted_sentence_ids;      if (leave_one_out) blacklisted_sentence_ids.insert(i); -    Grammar grammar = extractor.GetGrammar(sentences[i], blacklisted_sentence_ids, source_data_array); +    Grammar grammar = extractor.GetGrammar( +        sentences[i], blacklisted_sentence_ids);      ofstream output(GetGrammarFilePath(grammar_path, i).c_str());      output << grammar;    } diff --git a/extractor/sampler.cc b/extractor/sampler.cc index 963afa7a..fc386ed1 100644 --- a/extractor/sampler.cc +++ b/extractor/sampler.cc @@ -12,7 +12,9 @@ Sampler::Sampler() {}  Sampler::~Sampler() {} -PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const { +PhraseLocation Sampler::Sample( +    const PhraseLocation& location, +    const unordered_set<int>& blacklisted_sentence_ids) const {    vector<int> sample;    int num_subpatterns;    if (location.matchings == NULL) { @@ -22,10 +24,11 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_s      double step = max(1.0, (double) (high - low) / max_samples);      double i = low, last = i;      bool found; +    shared_ptr<DataArray> source_data_array = suffix_array->GetData();      while (sample.size() < max_samples && i < high) {        int x = suffix_array->GetSuffix(Round(i));        int id = source_data_array->GetSentenceId(x); -      if (find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) != blacklisted_sentence_ids.end()) { +      if (blacklisted_sentence_ids.count(id)) {          found = false;          double backoff_step = 1;          while (true) { @@ -33,13 +36,14 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_s            double j = i - backoff_step;            x = suffix_array->GetSuffix(Round(j));            id = source_data_array->GetSentenceId(x); -          if (x >= 0 && j > last && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) { +          if (x >= 0 && j > last && !blacklisted_sentence_ids.count(id)) {              found = true; last = i; break;            }            double k = i + backoff_step;            x = suffix_array->GetSuffix(Round(k));            id = source_data_array->GetSentenceId(x); -          if (k < min(i+step, (double)high) && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) { +          if (k < min(i+step, (double)high) && +              !blacklisted_sentence_ids.count(id)) {              found = true; last = k; break;            }            if (j <= last && k >= high) break; diff --git a/extractor/sampler.h b/extractor/sampler.h index de450c48..bd8a5876 100644 --- a/extractor/sampler.h +++ b/extractor/sampler.h @@ -23,7 +23,9 @@ class Sampler {    virtual ~Sampler();    // Samples uniformly at most max_samples phrase occurrences. -  virtual PhraseLocation Sample(const PhraseLocation& location, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const; +  virtual PhraseLocation Sample( +      const PhraseLocation& location, +      const unordered_set<int>& blacklisted_sentence_ids) const;   protected:    Sampler(); diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc index 965567ba..14e72780 100644 --- a/extractor/sampler_test.cc +++ b/extractor/sampler_test.cc @@ -19,6 +19,8 @@ class SamplerTest : public Test {      source_data_array = make_shared<MockDataArray>();      EXPECT_CALL(*source_data_array, GetSentenceId(_)).WillRepeatedly(Return(9999));      suffix_array = make_shared<MockSuffixArray>(); +    EXPECT_CALL(*suffix_array, GetData()) +        .WillRepeatedly(Return(source_data_array));      for (int i = 0; i < 10; ++i) {        EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i));      } @@ -35,23 +37,29 @@ TEST_F(SamplerTest, TestSuffixArrayRange) {    sampler = make_shared<Sampler>(suffix_array, 1);    vector<int> expected_locations = {0}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler->Sample(location, blacklist)); +  return;    sampler = make_shared<Sampler>(suffix_array, 2);    expected_locations = {0, 5}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler->Sample(location, blacklist));    sampler = make_shared<Sampler>(suffix_array, 3);    expected_locations = {0, 3, 7}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler->Sample(location, blacklist));    sampler = make_shared<Sampler>(suffix_array, 4);    expected_locations = {0, 3, 5, 8}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler->Sample(location, blacklist));    sampler = make_shared<Sampler>(suffix_array, 100);    expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler->Sample(location, blacklist));  }  TEST_F(SamplerTest, TestSubstringsSample) { @@ -61,19 +69,23 @@ TEST_F(SamplerTest, TestSubstringsSample) {    sampler = make_shared<Sampler>(suffix_array, 1);    vector<int> expected_locations = {0, 1}; -  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklist));    sampler = make_shared<Sampler>(suffix_array, 2);    expected_locations = {0, 1, 6, 7}; -  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklist));    sampler = make_shared<Sampler>(suffix_array, 3);    expected_locations = {0, 1, 4, 5, 6, 7}; -  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklist));    sampler = make_shared<Sampler>(suffix_array, 7);    expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; -  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklist));  }  } // namespace  | 
