summaryrefslogtreecommitdiff
path: root/extractor
diff options
context:
space:
mode:
authorPaul Baltescu <pauldb89@gmail.com>2013-11-25 23:56:31 +0000
committerPaul Baltescu <pauldb89@gmail.com>2013-11-25 23:56:31 +0000
commit6bdb362473cf0ee1c636ca0c3f4cca63d82a5573 (patch)
tree1b00bd9647bcbc37838170af7ebde3911672047b /extractor
parent0febed366e6c3fedd288bc1bddfcd65f247cab81 (diff)
Clean up leave-one-out sampling.
Diffstat (limited to 'extractor')
-rw-r--r--extractor/grammar_extractor.cc6
-rw-r--r--extractor/grammar_extractor.h4
-rw-r--r--extractor/grammar_extractor_test.cc4
-rw-r--r--extractor/mocks/mock_rule_factory.h6
-rw-r--r--extractor/mocks/mock_sampler.h4
-rw-r--r--extractor/rule_factory.cc7
-rw-r--r--extractor/rule_factory.h3
-rw-r--r--extractor/rule_factory_test.cc8
-rw-r--r--extractor/run_extractor.cc3
-rw-r--r--extractor/sampler.cc12
-rw-r--r--extractor/sampler.h4
-rw-r--r--extractor/sampler_test.cc30
12 files changed, 58 insertions, 33 deletions
diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc
index 4d0738f7..1dc94c25 100644
--- a/extractor/grammar_extractor.cc
+++ b/extractor/grammar_extractor.cc
@@ -35,10 +35,12 @@ GrammarExtractor::GrammarExtractor(
vocabulary(vocabulary),
rule_factory(rule_factory) {}
-Grammar GrammarExtractor::GetGrammar(const string& sentence, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) {
+Grammar GrammarExtractor::GetGrammar(
+ const string& sentence,
+ const unordered_set<int>& blacklisted_sentence_ids) {
vector<string> words = TokenizeSentence(sentence);
vector<int> word_ids = AnnotateWords(words);
- return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
+ return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids);
}
vector<string> GrammarExtractor::TokenizeSentence(const string& sentence) {
diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h
index 8f570df2..eb79f53c 100644
--- a/extractor/grammar_extractor.h
+++ b/extractor/grammar_extractor.h
@@ -46,7 +46,9 @@ class GrammarExtractor {
// Converts the sentence to a vector of word ids and uses the RuleFactory to
// extract the SCFG rules which may be used to decode the sentence.
- Grammar GetGrammar(const string& sentence, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array);
+ Grammar GetGrammar(
+ const string& sentence,
+ const unordered_set<int>& blacklisted_sentence_ids);
private:
// Splits the sentence in a vector of words.
diff --git a/extractor/grammar_extractor_test.cc b/extractor/grammar_extractor_test.cc
index f32a9599..719e90ff 100644
--- a/extractor/grammar_extractor_test.cc
+++ b/extractor/grammar_extractor_test.cc
@@ -41,13 +41,13 @@ TEST(GrammarExtractorTest, TestAnnotatingWords) {
Grammar grammar(rules, feature_names);
unordered_set<int> blacklisted_sentence_ids;
shared_ptr<DataArray> source_data_array;
- EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array))
+ EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids))
.WillOnce(Return(grammar));
GrammarExtractor extractor(vocabulary, factory);
string sentence = "Anna has many many apples .";
- extractor.GetGrammar(sentence, blacklisted_sentence_ids, source_data_array);
+ extractor.GetGrammar(sentence, blacklisted_sentence_ids);
}
} // namespace
diff --git a/extractor/mocks/mock_rule_factory.h b/extractor/mocks/mock_rule_factory.h
index 6b7b6586..53eb5022 100644
--- a/extractor/mocks/mock_rule_factory.h
+++ b/extractor/mocks/mock_rule_factory.h
@@ -7,9 +7,9 @@ namespace extractor {
class MockHieroCachingRuleFactory : public HieroCachingRuleFactory {
public:
- MOCK_METHOD3(GetGrammar, Grammar(const vector<int>& word_ids, const
- unordered_set<int>& blacklisted_sentence_ids,
- const shared_ptr<DataArray> source_data_array));
+ MOCK_METHOD2(GetGrammar, Grammar(
+ const vector<int>& word_ids,
+ const unordered_set<int>& blacklisted_sentence_ids));
};
} // namespace extractor
diff --git a/extractor/mocks/mock_sampler.h b/extractor/mocks/mock_sampler.h
index 75c43c27..b2742f62 100644
--- a/extractor/mocks/mock_sampler.h
+++ b/extractor/mocks/mock_sampler.h
@@ -7,7 +7,9 @@ namespace extractor {
class MockSampler : public Sampler {
public:
- MOCK_CONST_METHOD1(Sample, PhraseLocation(const PhraseLocation& location));
+ MOCK_CONST_METHOD2(Sample, PhraseLocation(
+ const PhraseLocation& location,
+ const unordered_set<int>& blacklisted_sentence_ids));
};
} // namespace extractor
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc
index 6ae2d792..5b66f685 100644
--- a/extractor/rule_factory.cc
+++ b/extractor/rule_factory.cc
@@ -101,7 +101,9 @@ HieroCachingRuleFactory::HieroCachingRuleFactory() {}
HieroCachingRuleFactory::~HieroCachingRuleFactory() {}
-Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) {
+Grammar HieroCachingRuleFactory::GetGrammar(
+ const vector<int>& word_ids,
+ const unordered_set<int>& blacklisted_sentence_ids) {
Clock::time_point start_time = Clock::now();
double total_extract_time = 0;
double total_intersect_time = 0;
@@ -193,7 +195,8 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids, const u
Clock::time_point extract_start = Clock::now();
if (!state.starts_with_x) {
// Extract rules for the sampled set of occurrences.
- PhraseLocation sample = sampler->Sample(next_node->matchings, blacklisted_sentence_ids, source_data_array);
+ PhraseLocation sample = sampler->Sample(
+ next_node->matchings, blacklisted_sentence_ids);
vector<Rule> new_rules =
rule_extractor->ExtractRules(next_phrase, sample);
rules.insert(rules.end(), new_rules.begin(), new_rules.end());
diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h
index a1ff76e4..1a9fa2af 100644
--- a/extractor/rule_factory.h
+++ b/extractor/rule_factory.h
@@ -74,8 +74,7 @@ class HieroCachingRuleFactory {
// (See class description for more details.)
virtual Grammar GetGrammar(
const vector<int>& word_ids,
- const unordered_set<int>& blacklisted_sentence_ids,
- const shared_ptr<DataArray> source_data_array);
+ const unordered_set<int>& blacklisted_sentence_ids);
protected:
HieroCachingRuleFactory();
diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc
index f26cc567..332c5959 100644
--- a/extractor/rule_factory_test.cc
+++ b/extractor/rule_factory_test.cc
@@ -40,7 +40,7 @@ class RuleFactoryTest : public Test {
.WillRepeatedly(Return(feature_names));
sampler = make_shared<MockSampler>();
- EXPECT_CALL(*sampler, Sample(_))
+ EXPECT_CALL(*sampler, Sample(_, _))
.WillRepeatedly(Return(PhraseLocation(0, 1)));
Phrase phrase;
@@ -77,8 +77,7 @@ TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) {
vector<int> word_ids = {2, 3, 4};
unordered_set<int> blacklisted_sentence_ids;
- shared_ptr<DataArray> source_data_array;
- Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
+ Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids);
EXPECT_EQ(feature_names, grammar.GetFeatureNames());
EXPECT_EQ(7, grammar.GetRules().size());
}
@@ -97,8 +96,7 @@ TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) {
vector<int> word_ids = {2, 3, 4, 2, 3};
unordered_set<int> blacklisted_sentence_ids;
- shared_ptr<DataArray> source_data_array;
- Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
+ Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids);
EXPECT_EQ(feature_names, grammar.GetFeatureNames());
EXPECT_EQ(28, grammar.GetRules().size());
}
diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc
index 85c8a422..6b22a302 100644
--- a/extractor/run_extractor.cc
+++ b/extractor/run_extractor.cc
@@ -237,7 +237,8 @@ int main(int argc, char** argv) {
unordered_set<int> blacklisted_sentence_ids;
if (leave_one_out) blacklisted_sentence_ids.insert(i);
- Grammar grammar = extractor.GetGrammar(sentences[i], blacklisted_sentence_ids, source_data_array);
+ Grammar grammar = extractor.GetGrammar(
+ sentences[i], blacklisted_sentence_ids);
ofstream output(GetGrammarFilePath(grammar_path, i).c_str());
output << grammar;
}
diff --git a/extractor/sampler.cc b/extractor/sampler.cc
index 963afa7a..fc386ed1 100644
--- a/extractor/sampler.cc
+++ b/extractor/sampler.cc
@@ -12,7 +12,9 @@ Sampler::Sampler() {}
Sampler::~Sampler() {}
-PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const {
+PhraseLocation Sampler::Sample(
+ const PhraseLocation& location,
+ const unordered_set<int>& blacklisted_sentence_ids) const {
vector<int> sample;
int num_subpatterns;
if (location.matchings == NULL) {
@@ -22,10 +24,11 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_s
double step = max(1.0, (double) (high - low) / max_samples);
double i = low, last = i;
bool found;
+ shared_ptr<DataArray> source_data_array = suffix_array->GetData();
while (sample.size() < max_samples && i < high) {
int x = suffix_array->GetSuffix(Round(i));
int id = source_data_array->GetSentenceId(x);
- if (find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) != blacklisted_sentence_ids.end()) {
+ if (blacklisted_sentence_ids.count(id)) {
found = false;
double backoff_step = 1;
while (true) {
@@ -33,13 +36,14 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_s
double j = i - backoff_step;
x = suffix_array->GetSuffix(Round(j));
id = source_data_array->GetSentenceId(x);
- if (x >= 0 && j > last && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) {
+ if (x >= 0 && j > last && !blacklisted_sentence_ids.count(id)) {
found = true; last = i; break;
}
double k = i + backoff_step;
x = suffix_array->GetSuffix(Round(k));
id = source_data_array->GetSentenceId(x);
- if (k < min(i+step, (double)high) && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) {
+ if (k < min(i+step, (double)high) &&
+ !blacklisted_sentence_ids.count(id)) {
found = true; last = k; break;
}
if (j <= last && k >= high) break;
diff --git a/extractor/sampler.h b/extractor/sampler.h
index de450c48..bd8a5876 100644
--- a/extractor/sampler.h
+++ b/extractor/sampler.h
@@ -23,7 +23,9 @@ class Sampler {
virtual ~Sampler();
// Samples uniformly at most max_samples phrase occurrences.
- virtual PhraseLocation Sample(const PhraseLocation& location, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const;
+ virtual PhraseLocation Sample(
+ const PhraseLocation& location,
+ const unordered_set<int>& blacklisted_sentence_ids) const;
protected:
Sampler();
diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc
index 965567ba..14e72780 100644
--- a/extractor/sampler_test.cc
+++ b/extractor/sampler_test.cc
@@ -19,6 +19,8 @@ class SamplerTest : public Test {
source_data_array = make_shared<MockDataArray>();
EXPECT_CALL(*source_data_array, GetSentenceId(_)).WillRepeatedly(Return(9999));
suffix_array = make_shared<MockSuffixArray>();
+ EXPECT_CALL(*suffix_array, GetData())
+ .WillRepeatedly(Return(source_data_array));
for (int i = 0; i < 10; ++i) {
EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i));
}
@@ -35,23 +37,29 @@ TEST_F(SamplerTest, TestSuffixArrayRange) {
sampler = make_shared<Sampler>(suffix_array, 1);
vector<int> expected_locations = {0};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler->Sample(location, blacklist));
+ return;
sampler = make_shared<Sampler>(suffix_array, 2);
expected_locations = {0, 5};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler->Sample(location, blacklist));
sampler = make_shared<Sampler>(suffix_array, 3);
expected_locations = {0, 3, 7};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler->Sample(location, blacklist));
sampler = make_shared<Sampler>(suffix_array, 4);
expected_locations = {0, 3, 5, 8};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler->Sample(location, blacklist));
sampler = make_shared<Sampler>(suffix_array, 100);
expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler->Sample(location, blacklist));
}
TEST_F(SamplerTest, TestSubstringsSample) {
@@ -61,19 +69,23 @@ TEST_F(SamplerTest, TestSubstringsSample) {
sampler = make_shared<Sampler>(suffix_array, 1);
vector<int> expected_locations = {0, 1};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklist));
sampler = make_shared<Sampler>(suffix_array, 2);
expected_locations = {0, 1, 6, 7};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklist));
sampler = make_shared<Sampler>(suffix_array, 3);
expected_locations = {0, 1, 4, 5, 6, 7};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklist));
sampler = make_shared<Sampler>(suffix_array, 7);
expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklist));
}
} // namespace