From 2d2d5eced93d58bc77894d8c328195cd9950b96d Mon Sep 17 00:00:00 2001
From: Patrick Simianer
Date: Wed, 13 Nov 2013 18:00:10 +0100
Subject: unit tests for extractor loo sampling
---
extractor/grammar_extractor_test.cc | 7 ++-
extractor/mocks/mock_rule_factory.h | 2 +-
extractor/rule_factory_test.cc | 8 ++-
extractor/sampler.cc | 18 +++----
extractor/sampler_test.cc | 24 +++++----
extractor/sampler_test_blacklist.cc | 102 ++++++++++++++++++++++++++++++++++++
6 files changed, 138 insertions(+), 23 deletions(-)
create mode 100644 extractor/sampler_test_blacklist.cc
diff --git a/extractor/grammar_extractor_test.cc b/extractor/grammar_extractor_test.cc
index 823bb8b4..f32a9599 100644
--- a/extractor/grammar_extractor_test.cc
+++ b/extractor/grammar_extractor_test.cc
@@ -39,12 +39,15 @@ TEST(GrammarExtractorTest, TestAnnotatingWords) {
vector rules;
vector feature_names;
Grammar grammar(rules, feature_names);
- EXPECT_CALL(*factory, GetGrammar(word_ids))
+ unordered_set blacklisted_sentence_ids;
+ shared_ptr source_data_array;
+ EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array))
.WillOnce(Return(grammar));
GrammarExtractor extractor(vocabulary, factory);
string sentence = "Anna has many many apples .";
- extractor.GetGrammar(sentence);
+
+ extractor.GetGrammar(sentence, blacklisted_sentence_ids, source_data_array);
}
} // namespace
diff --git a/extractor/mocks/mock_rule_factory.h b/extractor/mocks/mock_rule_factory.h
index 7389b396..86a084b5 100644
--- a/extractor/mocks/mock_rule_factory.h
+++ b/extractor/mocks/mock_rule_factory.h
@@ -7,7 +7,7 @@ namespace extractor {
class MockHieroCachingRuleFactory : public HieroCachingRuleFactory {
public:
- MOCK_METHOD1(GetGrammar, Grammar(const vector& word_ids));
+ MOCK_METHOD3(GetGrammar, Grammar(const vector& word_ids, const unordered_set blacklisted_sentence_ids, const shared_ptr source_data_array));
};
} // namespace extractor
diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc
index 08af3dcd..f26cc567 100644
--- a/extractor/rule_factory_test.cc
+++ b/extractor/rule_factory_test.cc
@@ -76,7 +76,9 @@ TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) {
.WillRepeatedly(Return(PhraseLocation(0, 1)));
vector word_ids = {2, 3, 4};
- Grammar grammar = factory->GetGrammar(word_ids);
+ unordered_set blacklisted_sentence_ids;
+ shared_ptr source_data_array;
+ Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
EXPECT_EQ(feature_names, grammar.GetFeatureNames());
EXPECT_EQ(7, grammar.GetRules().size());
}
@@ -94,7 +96,9 @@ TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) {
.WillRepeatedly(Return(PhraseLocation(0, 1)));
vector word_ids = {2, 3, 4, 2, 3};
- Grammar grammar = factory->GetGrammar(word_ids);
+ unordered_set blacklisted_sentence_ids;
+ shared_ptr source_data_array;
+ Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
EXPECT_EQ(feature_names, grammar.GetFeatureNames());
EXPECT_EQ(28, grammar.GetRules().size());
}
diff --git a/extractor/sampler.cc b/extractor/sampler.cc
index cb470962..d332dd90 100644
--- a/extractor/sampler.cc
+++ b/extractor/sampler.cc
@@ -19,25 +19,25 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location, unordered_setGetSuffix(i);
+ int x = suffix_array->GetSuffix(Round(i));
int id = source_data_array->GetSentenceId(x);
if (find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) != blacklisted_sentence_ids.end()) {
found = false;
- int backoff_step = 1;
+ double backoff_step = 1;
while (true) {
if ((double)backoff_step >= step) break;
- int j = i - backoff_step;
- x = suffix_array->GetSuffix(j);
+ double j = i - backoff_step;
+ x = suffix_array->GetSuffix(Round(j));
id = source_data_array->GetSentenceId(x);
- if (j > last && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) {
+ if (x >= 0 && j > last && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) {
found = true; last = i; break;
}
- int k = i + backoff_step;
- x = suffix_array->GetSuffix(k);
+ double k = i + backoff_step;
+ x = suffix_array->GetSuffix(Round(k));
id = source_data_array->GetSentenceId(x);
if (k < min(i+step, (double)high) && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) {
found = true; last = k; break;
diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc
index e9abebfa..965567ba 100644
--- a/extractor/sampler_test.cc
+++ b/extractor/sampler_test.cc
@@ -3,6 +3,7 @@
#include
#include "mocks/mock_suffix_array.h"
+#include "mocks/mock_data_array.h"
#include "phrase_location.h"
#include "sampler.h"
@@ -15,6 +16,8 @@ namespace {
class SamplerTest : public Test {
protected:
virtual void SetUp() {
+ source_data_array = make_shared();
+ EXPECT_CALL(*source_data_array, GetSentenceId(_)).WillRepeatedly(Return(9999));
suffix_array = make_shared();
for (int i = 0; i < 10; ++i) {
EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i));
@@ -23,51 +26,54 @@ class SamplerTest : public Test {
shared_ptr suffix_array;
shared_ptr sampler;
+ shared_ptr source_data_array;
};
TEST_F(SamplerTest, TestSuffixArrayRange) {
PhraseLocation location(0, 10);
+ unordered_set blacklist;
sampler = make_shared(suffix_array, 1);
vector expected_locations = {0};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
sampler = make_shared(suffix_array, 2);
expected_locations = {0, 5};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
sampler = make_shared(suffix_array, 3);
expected_locations = {0, 3, 7};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
sampler = make_shared(suffix_array, 4);
expected_locations = {0, 3, 5, 8};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
sampler = make_shared(suffix_array, 100);
expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
}
TEST_F(SamplerTest, TestSubstringsSample) {
vector locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ unordered_set blacklist;
PhraseLocation location(locations, 2);
sampler = make_shared(suffix_array, 1);
vector expected_locations = {0, 1};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
sampler = make_shared(suffix_array, 2);
expected_locations = {0, 1, 6, 7};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
sampler = make_shared(suffix_array, 3);
expected_locations = {0, 1, 4, 5, 6, 7};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
sampler = make_shared(suffix_array, 7);
expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
}
} // namespace
diff --git a/extractor/sampler_test_blacklist.cc b/extractor/sampler_test_blacklist.cc
new file mode 100644
index 00000000..3305b990
--- /dev/null
+++ b/extractor/sampler_test_blacklist.cc
@@ -0,0 +1,102 @@
+#include
+
+#include
+
+#include "mocks/mock_suffix_array.h"
+#include "mocks/mock_data_array.h"
+#include "phrase_location.h"
+#include "sampler.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+class SamplerTestBlacklist : public Test {
+ protected:
+ virtual void SetUp() {
+ source_data_array = make_shared();
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_CALL(*source_data_array, GetSentenceId(i)).WillRepeatedly(Return(i));
+ }
+ for (int i = -10; i < 0; ++i) {
+ EXPECT_CALL(*source_data_array, GetSentenceId(i)).WillRepeatedly(Return(0));
+ }
+ suffix_array = make_shared();
+ for (int i = -10; i < 10; ++i) {
+ EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i));
+ }
+ }
+
+ shared_ptr suffix_array;
+ shared_ptr sampler;
+ shared_ptr source_data_array;
+};
+
+TEST_F(SamplerTestBlacklist, TestSuffixArrayRange) {
+ PhraseLocation location(0, 10);
+ unordered_set blacklist;
+ vector expected_locations;
+
+ blacklist.insert(0);
+ sampler = make_shared(suffix_array, 1);
+ expected_locations = {1};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ blacklist.clear();
+
+ for (int i = 0; i < 9; i++) {
+ blacklist.insert(i);
+ }
+ sampler = make_shared(suffix_array, 1);
+ expected_locations = {9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ blacklist.clear();
+
+ blacklist.insert(0);
+ blacklist.insert(5);
+ sampler = make_shared(suffix_array, 2);
+ expected_locations = {1, 4};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ blacklist.clear();
+
+ blacklist.insert(0);
+ blacklist.insert(1);
+ blacklist.insert(2);
+ blacklist.insert(3);
+ sampler = make_shared(suffix_array, 2);
+ expected_locations = {4, 5};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ blacklist.clear();
+
+ blacklist.insert(0);
+ blacklist.insert(3);
+ blacklist.insert(7);
+ sampler = make_shared(suffix_array, 3);
+ expected_locations = {1, 2, 6};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ blacklist.clear();
+
+ blacklist.insert(0);
+ blacklist.insert(3);
+ blacklist.insert(5);
+ blacklist.insert(8);
+ sampler = make_shared(suffix_array, 4);
+ expected_locations = {1, 2, 4, 7};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ blacklist.clear();
+
+ blacklist.insert(0);
+ sampler = make_shared(suffix_array, 100);
+ expected_locations = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ blacklist.clear();
+
+ blacklist.insert(9);
+ sampler = make_shared(suffix_array, 100);
+ expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+}
+
+} // namespace
+} // namespace extractor
--
cgit v1.2.3