diff options
author | Paul Baltescu <pauldb89@gmail.com> | 2013-11-25 23:56:31 +0000 |
---|---|---|
committer | Paul Baltescu <pauldb89@gmail.com> | 2013-11-25 23:56:31 +0000 |
commit | 3c73e472444ff0cd436b12f3679440a6969cbf2d (patch) | |
tree | 9ceee03648ea671d7f05215826dc0d0a5890e36b /extractor | |
parent | 2b95390f08d9f556e6207ecff03b4b0fd5ede993 (diff) |
Clean up leave-one-out sampling.
Diffstat (limited to 'extractor')
-rw-r--r-- | extractor/grammar_extractor.cc | 6 | ||||
-rw-r--r-- | extractor/grammar_extractor.h | 4 | ||||
-rw-r--r-- | extractor/grammar_extractor_test.cc | 4 | ||||
-rw-r--r-- | extractor/mocks/mock_rule_factory.h | 6 | ||||
-rw-r--r-- | extractor/mocks/mock_sampler.h | 4 | ||||
-rw-r--r-- | extractor/rule_factory.cc | 7 | ||||
-rw-r--r-- | extractor/rule_factory.h | 3 | ||||
-rw-r--r-- | extractor/rule_factory_test.cc | 8 | ||||
-rw-r--r-- | extractor/run_extractor.cc | 3 | ||||
-rw-r--r-- | extractor/sampler.cc | 12 | ||||
-rw-r--r-- | extractor/sampler.h | 4 | ||||
-rw-r--r-- | extractor/sampler_test.cc | 30 |
12 files changed, 58 insertions, 33 deletions
diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc index 4d0738f7..1dc94c25 100644 --- a/extractor/grammar_extractor.cc +++ b/extractor/grammar_extractor.cc @@ -35,10 +35,12 @@ GrammarExtractor::GrammarExtractor( vocabulary(vocabulary), rule_factory(rule_factory) {} -Grammar GrammarExtractor::GetGrammar(const string& sentence, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) { +Grammar GrammarExtractor::GetGrammar( + const string& sentence, + const unordered_set<int>& blacklisted_sentence_ids) { vector<string> words = TokenizeSentence(sentence); vector<int> word_ids = AnnotateWords(words); - return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array); + return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids); } vector<string> GrammarExtractor::TokenizeSentence(const string& sentence) { diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h index 8f570df2..eb79f53c 100644 --- a/extractor/grammar_extractor.h +++ b/extractor/grammar_extractor.h @@ -46,7 +46,9 @@ class GrammarExtractor { // Converts the sentence to a vector of word ids and uses the RuleFactory to // extract the SCFG rules which may be used to decode the sentence. - Grammar GetGrammar(const string& sentence, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array); + Grammar GetGrammar( + const string& sentence, + const unordered_set<int>& blacklisted_sentence_ids); private: // Splits the sentence in a vector of words. diff --git a/extractor/grammar_extractor_test.cc b/extractor/grammar_extractor_test.cc index f32a9599..719e90ff 100644 --- a/extractor/grammar_extractor_test.cc +++ b/extractor/grammar_extractor_test.cc @@ -41,13 +41,13 @@ TEST(GrammarExtractorTest, TestAnnotatingWords) { Grammar grammar(rules, feature_names); unordered_set<int> blacklisted_sentence_ids; shared_ptr<DataArray> source_data_array; - EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array)) + EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids)) .WillOnce(Return(grammar)); GrammarExtractor extractor(vocabulary, factory); string sentence = "Anna has many many apples ."; - extractor.GetGrammar(sentence, blacklisted_sentence_ids, source_data_array); + extractor.GetGrammar(sentence, blacklisted_sentence_ids); } } // namespace diff --git a/extractor/mocks/mock_rule_factory.h b/extractor/mocks/mock_rule_factory.h index 6b7b6586..53eb5022 100644 --- a/extractor/mocks/mock_rule_factory.h +++ b/extractor/mocks/mock_rule_factory.h @@ -7,9 +7,9 @@ namespace extractor { class MockHieroCachingRuleFactory : public HieroCachingRuleFactory { public: - MOCK_METHOD3(GetGrammar, Grammar(const vector<int>& word_ids, const - unordered_set<int>& blacklisted_sentence_ids, - const shared_ptr<DataArray> source_data_array)); + MOCK_METHOD2(GetGrammar, Grammar( + const vector<int>& word_ids, + const unordered_set<int>& blacklisted_sentence_ids)); }; } // namespace extractor diff --git a/extractor/mocks/mock_sampler.h b/extractor/mocks/mock_sampler.h index 75c43c27..b2742f62 100644 --- a/extractor/mocks/mock_sampler.h +++ b/extractor/mocks/mock_sampler.h @@ -7,7 +7,9 @@ namespace extractor { class MockSampler : public Sampler { public: - MOCK_CONST_METHOD1(Sample, PhraseLocation(const PhraseLocation& location)); + MOCK_CONST_METHOD2(Sample, PhraseLocation( + const PhraseLocation& location, + const unordered_set<int>& blacklisted_sentence_ids)); }; } // namespace extractor diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index 6ae2d792..5b66f685 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -101,7 +101,9 @@ HieroCachingRuleFactory::HieroCachingRuleFactory() {} HieroCachingRuleFactory::~HieroCachingRuleFactory() {} -Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) { +Grammar HieroCachingRuleFactory::GetGrammar( + const vector<int>& word_ids, + const unordered_set<int>& blacklisted_sentence_ids) { Clock::time_point start_time = Clock::now(); double total_extract_time = 0; double total_intersect_time = 0; @@ -193,7 +195,8 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids, const u Clock::time_point extract_start = Clock::now(); if (!state.starts_with_x) { // Extract rules for the sampled set of occurrences. - PhraseLocation sample = sampler->Sample(next_node->matchings, blacklisted_sentence_ids, source_data_array); + PhraseLocation sample = sampler->Sample( + next_node->matchings, blacklisted_sentence_ids); vector<Rule> new_rules = rule_extractor->ExtractRules(next_phrase, sample); rules.insert(rules.end(), new_rules.begin(), new_rules.end()); diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h index a1ff76e4..1a9fa2af 100644 --- a/extractor/rule_factory.h +++ b/extractor/rule_factory.h @@ -74,8 +74,7 @@ class HieroCachingRuleFactory { // (See class description for more details.) virtual Grammar GetGrammar( const vector<int>& word_ids, - const unordered_set<int>& blacklisted_sentence_ids, - const shared_ptr<DataArray> source_data_array); + const unordered_set<int>& blacklisted_sentence_ids); protected: HieroCachingRuleFactory(); diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc index f26cc567..332c5959 100644 --- a/extractor/rule_factory_test.cc +++ b/extractor/rule_factory_test.cc @@ -40,7 +40,7 @@ class RuleFactoryTest : public Test { .WillRepeatedly(Return(feature_names)); sampler = make_shared<MockSampler>(); - EXPECT_CALL(*sampler, Sample(_)) + EXPECT_CALL(*sampler, Sample(_, _)) .WillRepeatedly(Return(PhraseLocation(0, 1))); Phrase phrase; @@ -77,8 +77,7 @@ TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) { vector<int> word_ids = {2, 3, 4}; unordered_set<int> blacklisted_sentence_ids; - shared_ptr<DataArray> source_data_array; - Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array); + Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids); EXPECT_EQ(feature_names, grammar.GetFeatureNames()); EXPECT_EQ(7, grammar.GetRules().size()); } @@ -97,8 +96,7 @@ TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) { vector<int> word_ids = {2, 3, 4, 2, 3}; unordered_set<int> blacklisted_sentence_ids; - shared_ptr<DataArray> source_data_array; - Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array); + Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids); EXPECT_EQ(feature_names, grammar.GetFeatureNames()); EXPECT_EQ(28, grammar.GetRules().size()); } diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index 85c8a422..6b22a302 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -237,7 +237,8 @@ int main(int argc, char** argv) { unordered_set<int> blacklisted_sentence_ids; if (leave_one_out) blacklisted_sentence_ids.insert(i); - Grammar grammar = extractor.GetGrammar(sentences[i], blacklisted_sentence_ids, source_data_array); + Grammar grammar = extractor.GetGrammar( + sentences[i], blacklisted_sentence_ids); ofstream output(GetGrammarFilePath(grammar_path, i).c_str()); output << grammar; } diff --git a/extractor/sampler.cc b/extractor/sampler.cc index 963afa7a..fc386ed1 100644 --- a/extractor/sampler.cc +++ b/extractor/sampler.cc @@ -12,7 +12,9 @@ Sampler::Sampler() {} Sampler::~Sampler() {} -PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const { +PhraseLocation Sampler::Sample( + const PhraseLocation& location, + const unordered_set<int>& blacklisted_sentence_ids) const { vector<int> sample; int num_subpatterns; if (location.matchings == NULL) { @@ -22,10 +24,11 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_s double step = max(1.0, (double) (high - low) / max_samples); double i = low, last = i; bool found; + shared_ptr<DataArray> source_data_array = suffix_array->GetData(); while (sample.size() < max_samples && i < high) { int x = suffix_array->GetSuffix(Round(i)); int id = source_data_array->GetSentenceId(x); - if (find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) != blacklisted_sentence_ids.end()) { + if (blacklisted_sentence_ids.count(id)) { found = false; double backoff_step = 1; while (true) { @@ -33,13 +36,14 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_s double j = i - backoff_step; x = suffix_array->GetSuffix(Round(j)); id = source_data_array->GetSentenceId(x); - if (x >= 0 && j > last && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) { + if (x >= 0 && j > last && !blacklisted_sentence_ids.count(id)) { found = true; last = i; break; } double k = i + backoff_step; x = suffix_array->GetSuffix(Round(k)); id = source_data_array->GetSentenceId(x); - if (k < min(i+step, (double)high) && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) { + if (k < min(i+step, (double)high) && + !blacklisted_sentence_ids.count(id)) { found = true; last = k; break; } if (j <= last && k >= high) break; diff --git a/extractor/sampler.h b/extractor/sampler.h index de450c48..bd8a5876 100644 --- a/extractor/sampler.h +++ b/extractor/sampler.h @@ -23,7 +23,9 @@ class Sampler { virtual ~Sampler(); // Samples uniformly at most max_samples phrase occurrences. - virtual PhraseLocation Sample(const PhraseLocation& location, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const; + virtual PhraseLocation Sample( + const PhraseLocation& location, + const unordered_set<int>& blacklisted_sentence_ids) const; protected: Sampler(); diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc index 965567ba..14e72780 100644 --- a/extractor/sampler_test.cc +++ b/extractor/sampler_test.cc @@ -19,6 +19,8 @@ class SamplerTest : public Test { source_data_array = make_shared<MockDataArray>(); EXPECT_CALL(*source_data_array, GetSentenceId(_)).WillRepeatedly(Return(9999)); suffix_array = make_shared<MockSuffixArray>(); + EXPECT_CALL(*suffix_array, GetData()) + .WillRepeatedly(Return(source_data_array)); for (int i = 0; i < 10; ++i) { EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i)); } @@ -35,23 +37,29 @@ TEST_F(SamplerTest, TestSuffixArrayRange) { sampler = make_shared<Sampler>(suffix_array, 1); vector<int> expected_locations = {0}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler->Sample(location, blacklist)); + return; sampler = make_shared<Sampler>(suffix_array, 2); expected_locations = {0, 5}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler->Sample(location, blacklist)); sampler = make_shared<Sampler>(suffix_array, 3); expected_locations = {0, 3, 7}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler->Sample(location, blacklist)); sampler = make_shared<Sampler>(suffix_array, 4); expected_locations = {0, 3, 5, 8}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler->Sample(location, blacklist)); sampler = make_shared<Sampler>(suffix_array, 100); expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler->Sample(location, blacklist)); } TEST_F(SamplerTest, TestSubstringsSample) { @@ -61,19 +69,23 @@ TEST_F(SamplerTest, TestSubstringsSample) { sampler = make_shared<Sampler>(suffix_array, 1); vector<int> expected_locations = {0, 1}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklist)); sampler = make_shared<Sampler>(suffix_array, 2); expected_locations = {0, 1, 6, 7}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklist)); sampler = make_shared<Sampler>(suffix_array, 3); expected_locations = {0, 1, 4, 5, 6, 7}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklist)); sampler = make_shared<Sampler>(suffix_array, 7); expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklist)); } } // namespace |