summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Baltescu <pauldb89@gmail.com>2013-11-27 14:33:36 +0000
committerPaul Baltescu <pauldb89@gmail.com>2013-11-27 14:33:36 +0000
commita6e6a369f40d8fb6a191fd7f74fc5efa8bfae2a0 (patch)
treeab2ea6c2b00adb438929cf34dc334c11f2bc6396
parent8f65daa5bdaddaac24cea4df70049757536d6080 (diff)
Unify sampling backoff strategy.
-rw-r--r--extractor/Makefile.am24
-rw-r--r--extractor/backoff_sampler.cc66
-rw-r--r--extractor/backoff_sampler.h41
-rw-r--r--extractor/matchings_sampler.cc38
-rw-r--r--extractor/matchings_sampler.h31
-rw-r--r--extractor/matchings_sampler_test.cc118
-rw-r--r--extractor/mocks/mock_matchings_sampler.h15
-rw-r--r--extractor/mocks/mock_suffix_array_sampler.h15
-rw-r--r--extractor/phrase_location.cc2
-rw-r--r--extractor/phrase_location_sampler.cc34
-rw-r--r--extractor/phrase_location_sampler.h35
-rw-r--r--extractor/phrase_location_sampler_test.cc50
-rw-r--r--extractor/precomputation.cc3
-rw-r--r--extractor/precomputation_test.cc2
-rw-r--r--extractor/rule_factory.cc4
-rw-r--r--extractor/sampler.cc78
-rw-r--r--extractor/sampler.h22
-rw-r--r--extractor/sampler_test.cc92
-rw-r--r--extractor/sampler_test_blacklist.cc102
-rw-r--r--extractor/suffix_array_sampler.cc40
-rw-r--r--extractor/suffix_array_sampler.h34
-rw-r--r--extractor/suffix_array_sampler_test.cc114
22 files changed, 657 insertions, 303 deletions
diff --git a/extractor/Makefile.am b/extractor/Makefile.am
index 7825012c..e5b439f9 100644
--- a/extractor/Makefile.am
+++ b/extractor/Makefile.am
@@ -15,13 +15,15 @@ EXTRA_PROGRAMS = alignment_test \
feature_target_given_source_coherent_test \
grammar_extractor_test \
matchings_finder_test \
+ matchings_sampler_test \
+ phrase_location_sampler_test \
phrase_test \
precomputation_test \
rule_extractor_helper_test \
rule_extractor_test \
rule_factory_test \
- sampler_test \
scorer_test \
+ suffix_array_sampler_test \
suffix_array_test \
target_phrase_extractor_test \
translation_table_test \
@@ -40,13 +42,15 @@ if HAVE_GTEST
feature_target_given_source_coherent_test \
grammar_extractor_test \
matchings_finder_test \
+ matchings_sampler_test \
+ phrase_location_sampler_test \
phrase_test \
precomputation_test \
rule_extractor_helper_test \
rule_extractor_test \
rule_factory_test \
- sampler_test \
scorer_test \
+ suffix_array_sampler_test \
suffix_array_test \
target_phrase_extractor_test \
translation_table_test \
@@ -55,8 +59,7 @@ endif
noinst_PROGRAMS = $(RUNNABLE_TESTS)
-# TESTS = $(RUNNABLE_TESTS)
-TESTS = vocabulary_test
+TESTS = $(RUNNABLE_TESTS)
alignment_test_SOURCES = alignment_test.cc
alignment_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
@@ -82,6 +85,10 @@ grammar_extractor_test_SOURCES = grammar_extractor_test.cc
grammar_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
matchings_finder_test_SOURCES = matchings_finder_test.cc
matchings_finder_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+matchings_sampler_test_SOURCES = matchings_sampler_test.cc
+matchings_sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+phrase_location_sampler_test_SOURCES = phrase_location_sampler_test.cc
+phrase_location_sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
phrase_test_SOURCES = phrase_test.cc
phrase_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
precomputation_test_SOURCES = precomputation_test.cc
@@ -92,10 +99,10 @@ rule_extractor_test_SOURCES = rule_extractor_test.cc
rule_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
rule_factory_test_SOURCES = rule_factory_test.cc
rule_factory_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
-sampler_test_SOURCES = sampler_test.cc
-sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
scorer_test_SOURCES = scorer_test.cc
scorer_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+suffix_array_sampler_test_SOURCES = suffix_array_sampler_test.cc
+suffix_array_sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
suffix_array_test_SOURCES = suffix_array_test.cc
suffix_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
target_phrase_extractor_test_SOURCES = target_phrase_extractor_test.cc
@@ -116,6 +123,7 @@ extract_LDADD = libextractor.a
libextractor_a_SOURCES = \
alignment.cc \
+ backoff_sampler.cc \
data_array.cc \
fast_intersector.cc \
features/count_source_target.cc \
@@ -129,18 +137,20 @@ libextractor_a_SOURCES = \
grammar.cc \
grammar_extractor.cc \
matchings_finder.cc \
+ matchings_sampler.cc \
matchings_trie.cc \
phrase.cc \
phrase_builder.cc \
phrase_location.cc \
+ phrase_location_sampler.cc \
precomputation.cc \
rule.cc \
rule_extractor.cc \
rule_extractor_helper.cc \
rule_factory.cc \
- sampler.cc \
scorer.cc \
suffix_array.cc \
+ suffix_array_sampler.cc \
target_phrase_extractor.cc \
time_util.cc \
translation_table.cc \
diff --git a/extractor/backoff_sampler.cc b/extractor/backoff_sampler.cc
new file mode 100644
index 00000000..28b12909
--- /dev/null
+++ b/extractor/backoff_sampler.cc
@@ -0,0 +1,66 @@
+#include "backoff_sampler.h"
+
+#include "data_array.h"
+#include "phrase_location.h"
+
+namespace extractor {
+
+BackoffSampler::BackoffSampler(
+ shared_ptr<DataArray> source_data_array, int max_samples) :
+ source_data_array(source_data_array), max_samples(max_samples) {}
+
+BackoffSampler::BackoffSampler() {}
+
+PhraseLocation BackoffSampler::Sample(
+ const PhraseLocation& location,
+ const unordered_set<int>& blacklisted_sentence_ids) const {
+ vector<int> samples;
+ int low = GetRangeLow(location), high = GetRangeHigh(location);
+ int last_position = low - 1;
+ double step = max(1.0, (double) (high - low) / max_samples);
+ for (double num_samples = 0, i = low;
+ num_samples < max_samples && i < high;
+ ++num_samples, i += step) {
+ int position = GetPosition(location, round(i));
+ int sentence_id = source_data_array->GetSentenceId(position);
+ bool found = false;
+ if (last_position >= position ||
+ blacklisted_sentence_ids.count(sentence_id)) {
+ for (double backoff_step = 1; backoff_step < step; ++backoff_step) {
+ double j = i - backoff_step;
+ if (round(j) >= 0) {
+ position = GetPosition(location, round(j));
+ sentence_id = source_data_array->GetSentenceId(position);
+ if (position > last_position &&
+ !blacklisted_sentence_ids.count(sentence_id)) {
+ found = true;
+ last_position = position;
+ break;
+ }
+ }
+
+ double k = i + backoff_step;
+ if (round(k) < high) {
+ position = GetPosition(location, round(k));
+ sentence_id = source_data_array->GetSentenceId(position);
+ if (!blacklisted_sentence_ids.count(sentence_id)) {
+ found = true;
+ last_position = position;
+ break;
+ }
+ }
+ }
+ } else {
+ found = true;
+ last_position = position;
+ }
+
+ if (found) {
+ AppendMatching(samples, position, location);
+ }
+ }
+
+ return PhraseLocation(samples, GetNumSubpatterns(location));
+}
+
+} // namespace extractor
diff --git a/extractor/backoff_sampler.h b/extractor/backoff_sampler.h
new file mode 100644
index 00000000..5c244105
--- /dev/null
+++ b/extractor/backoff_sampler.h
@@ -0,0 +1,41 @@
+#ifndef _BACKOFF_SAMPLER_H_
+#define _BACKOFF_SAMPLER_H_
+
+#include <vector>
+
+#include "sampler.h"
+
+namespace extractor {
+
+class DataArray;
+class PhraseLocation;
+
+class BackoffSampler : public Sampler {
+ public:
+ BackoffSampler(shared_ptr<DataArray> source_data_array, int max_samples);
+
+ BackoffSampler();
+
+ PhraseLocation Sample(
+ const PhraseLocation& location,
+ const unordered_set<int>& blacklisted_sentence_ids) const;
+
+ private:
+ virtual int GetNumSubpatterns(const PhraseLocation& location) const = 0;
+
+ virtual int GetRangeLow(const PhraseLocation& location) const = 0;
+
+ virtual int GetRangeHigh(const PhraseLocation& location) const = 0;
+
+ virtual int GetPosition(const PhraseLocation& location, int index) const = 0;
+
+ virtual void AppendMatching(vector<int>& samples, int index,
+ const PhraseLocation& location) const = 0;
+
+ shared_ptr<DataArray> source_data_array;
+ int max_samples;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/matchings_sampler.cc b/extractor/matchings_sampler.cc
new file mode 100644
index 00000000..bb916e49
--- /dev/null
+++ b/extractor/matchings_sampler.cc
@@ -0,0 +1,38 @@
+#include "matchings_sampler.h"
+
+#include "data_array.h"
+#include "phrase_location.h"
+
+namespace extractor {
+
+MatchingsSampler::MatchingsSampler(
+ shared_ptr<DataArray> data_array, int max_samples) :
+ BackoffSampler(data_array, max_samples) {}
+
+MatchingsSampler::MatchingsSampler() {}
+
+int MatchingsSampler::GetNumSubpatterns(const PhraseLocation& location) const {
+ return location.num_subpatterns;
+}
+
+int MatchingsSampler::GetRangeLow(const PhraseLocation&) const {
+ return 0;
+}
+
+int MatchingsSampler::GetRangeHigh(const PhraseLocation& location) const {
+ return location.matchings->size() / location.num_subpatterns;
+}
+
+int MatchingsSampler::GetPosition(const PhraseLocation& location,
+ int index) const {
+ return (*location.matchings)[index * location.num_subpatterns];
+}
+
+void MatchingsSampler::AppendMatching(vector<int>& samples, int index,
+ const PhraseLocation& location) const {
+ copy(location.matchings->begin() + index,
+ location.matchings->begin() + index + location.num_subpatterns,
+ back_inserter(samples));
+}
+
+} // namespace extractor
diff --git a/extractor/matchings_sampler.h b/extractor/matchings_sampler.h
new file mode 100644
index 00000000..ca4fce93
--- /dev/null
+++ b/extractor/matchings_sampler.h
@@ -0,0 +1,31 @@
+#ifndef _MATCHINGS_SAMPLER_H_
+#define _MATCHINGS_SAMPLER_H_
+
+#include "backoff_sampler.h"
+
+namespace extractor {
+
+class DataArray;
+
+class MatchingsSampler : public BackoffSampler {
+ public:
+ MatchingsSampler(shared_ptr<DataArray> data_array, int max_samples);
+
+ MatchingsSampler();
+
+ private:
+ int GetNumSubpatterns(const PhraseLocation& location) const;
+
+ int GetRangeLow(const PhraseLocation& location) const;
+
+ int GetRangeHigh(const PhraseLocation& location) const;
+
+ int GetPosition(const PhraseLocation& location, int index) const;
+
+ void AppendMatching(vector<int>& samples, int index,
+ const PhraseLocation& location) const;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/matchings_sampler_test.cc b/extractor/matchings_sampler_test.cc
new file mode 100644
index 00000000..bc927152
--- /dev/null
+++ b/extractor/matchings_sampler_test.cc
@@ -0,0 +1,118 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "mocks/mock_data_array.h"
+#include "matchings_sampler.h"
+#include "phrase_location.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+class MatchingsSamplerTest : public Test {
+ protected:
+ virtual void SetUp() {
+ vector<int> locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ location = PhraseLocation(locations, 2);
+
+ data_array = make_shared<MockDataArray>();
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_CALL(*data_array, GetSentenceId(i)).WillRepeatedly(Return(i / 2));
+ }
+ }
+
+ unordered_set<int> blacklisted_sentence_ids;
+ PhraseLocation location;
+ shared_ptr<MockDataArray> data_array;
+ shared_ptr<MatchingsSampler> sampler;
+};
+
+TEST_F(MatchingsSamplerTest, TestSample) {
+ sampler = make_shared<MatchingsSampler>(data_array, 1);
+ vector<int> expected_locations = {0, 1};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ sampler = make_shared<MatchingsSampler>(data_array, 2);
+ expected_locations = {0, 1, 6, 7};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ sampler = make_shared<MatchingsSampler>(data_array, 3);
+ expected_locations = {0, 1, 4, 5, 6, 7};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ sampler = make_shared<MatchingsSampler>(data_array, 7);
+ expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+}
+
+TEST_F(MatchingsSamplerTest, TestBackoffSample) {
+ sampler = make_shared<MatchingsSampler>(data_array, 1);
+ blacklisted_sentence_ids = {0};
+ vector<int> expected_locations = {2, 3};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ blacklisted_sentence_ids = {0, 1, 2, 3};
+ expected_locations = {8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ blacklisted_sentence_ids = {0, 1, 2, 3, 4};
+ expected_locations = {};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ sampler = make_shared<MatchingsSampler>(data_array, 2);
+ blacklisted_sentence_ids = {0, 3};
+ expected_locations = {2, 3, 4, 5};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ sampler = make_shared<MatchingsSampler>(data_array, 3);
+ blacklisted_sentence_ids = {0, 3};
+ expected_locations = {2, 3, 4, 5, 8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ blacklisted_sentence_ids = {0, 2, 3};
+ expected_locations = {2, 3, 8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ sampler = make_shared<MatchingsSampler>(data_array, 4);
+ blacklisted_sentence_ids = {0, 1, 2, 3};
+ expected_locations = {8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ blacklisted_sentence_ids = {1, 3};
+ expected_locations = {0, 1, 4, 5, 8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ sampler = make_shared<MatchingsSampler>(data_array, 7);
+ blacklisted_sentence_ids = {0, 1, 2, 3, 4};
+ expected_locations = {};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ blacklisted_sentence_ids = {0, 2, 4};
+ expected_locations = {2, 3, 6, 7};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ blacklisted_sentence_ids = {1, 3};
+ expected_locations = {0, 1, 4, 5, 8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+}
+
+}
+} // namespace extractor
diff --git a/extractor/mocks/mock_matchings_sampler.h b/extractor/mocks/mock_matchings_sampler.h
new file mode 100644
index 00000000..de2009c3
--- /dev/null
+++ b/extractor/mocks/mock_matchings_sampler.h
@@ -0,0 +1,15 @@
+#include <gmock/gmock.h>
+
+#include "phrase_location.h"
+#include "matchings_sampler.h"
+
+namespace extractor {
+
+class MockMatchingsSampler : public MatchingsSampler {
+ public:
+ MOCK_CONST_METHOD2(Sample, PhraseLocation(
+ const PhraseLocation& location,
+ const unordered_set<int>& blacklisted_sentence_ids));
+};
+
+} // namespace extractor
diff --git a/extractor/mocks/mock_suffix_array_sampler.h b/extractor/mocks/mock_suffix_array_sampler.h
new file mode 100644
index 00000000..d799b969
--- /dev/null
+++ b/extractor/mocks/mock_suffix_array_sampler.h
@@ -0,0 +1,15 @@
+#include <gmock/gmock.h>
+
+#include "phrase_location.h"
+#include "suffix_array_sampler.h"
+
+namespace extractor {
+
+class MockSuffixArraySampler : public SuffixArrayRangeSampler {
+ public:
+ MOCK_CONST_METHOD2(Sample, PhraseLocation(
+ const PhraseLocation& location,
+ const unordered_set<int>& blacklisted_sentence_ids));
+};
+
+} // namespace extractor
diff --git a/extractor/phrase_location.cc b/extractor/phrase_location.cc
index 13140cac..2c367893 100644
--- a/extractor/phrase_location.cc
+++ b/extractor/phrase_location.cc
@@ -1,5 +1,7 @@
#include "phrase_location.h"
+#include <iostream>
+
namespace extractor {
PhraseLocation::PhraseLocation(int sa_low, int sa_high) :
diff --git a/extractor/phrase_location_sampler.cc b/extractor/phrase_location_sampler.cc
new file mode 100644
index 00000000..a2eec105
--- /dev/null
+++ b/extractor/phrase_location_sampler.cc
@@ -0,0 +1,34 @@
+#include "phrase_location_sampler.h"
+
+#include "matchings_sampler.h"
+#include "phrase_location.h"
+#include "suffix_array.h"
+#include "suffix_array_sampler.h"
+
+namespace extractor {
+
+PhraseLocationSampler::PhraseLocationSampler(
+ shared_ptr<SuffixArray> suffix_array, int max_samples) {
+ matchings_sampler = make_shared<MatchingsSampler>(
+ suffix_array->GetData(), max_samples);
+ suffix_array_sampler = make_shared<SuffixArrayRangeSampler>(
+ suffix_array, max_samples);
+}
+
+PhraseLocationSampler::PhraseLocationSampler(
+ shared_ptr<MatchingsSampler> matchings_sampler,
+ shared_ptr<SuffixArrayRangeSampler> suffix_array_sampler) :
+ matchings_sampler(matchings_sampler),
+ suffix_array_sampler(suffix_array_sampler) {}
+
+PhraseLocation PhraseLocationSampler::Sample(
+ const PhraseLocation& location,
+ const unordered_set<int>& blacklisted_sentence_ids) const {
+ if (location.matchings == NULL) {
+ return suffix_array_sampler->Sample(location, blacklisted_sentence_ids);
+ } else {
+ return matchings_sampler->Sample(location, blacklisted_sentence_ids);
+ }
+}
+
+} // namespace extractor
diff --git a/extractor/phrase_location_sampler.h b/extractor/phrase_location_sampler.h
new file mode 100644
index 00000000..0e88335e
--- /dev/null
+++ b/extractor/phrase_location_sampler.h
@@ -0,0 +1,35 @@
+#ifndef _PHRASE_LOCATION_SAMPLER_H_
+#define _PHRASE_LOCATION_SAMPLER_H_
+
+#include <memory>
+
+#include "sampler.h"
+
+namespace extractor {
+
+class MatchingsSampler;
+class PhraseLocation;
+class SuffixArray;
+class SuffixArrayRangeSampler;
+
+class PhraseLocationSampler : public Sampler {
+ public:
+ PhraseLocationSampler(shared_ptr<SuffixArray> suffix_array, int max_samples);
+
+ // For testing only.
+ PhraseLocationSampler(
+ shared_ptr<MatchingsSampler> matchings_sampler,
+ shared_ptr<SuffixArrayRangeSampler> suffix_array_sampler);
+
+ PhraseLocation Sample(
+ const PhraseLocation& location,
+ const unordered_set<int>& blacklisted_sentence_ids) const;
+
+ private:
+ shared_ptr<MatchingsSampler> matchings_sampler;
+ shared_ptr<SuffixArrayRangeSampler> suffix_array_sampler;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/phrase_location_sampler_test.cc b/extractor/phrase_location_sampler_test.cc
new file mode 100644
index 00000000..e7520ce7
--- /dev/null
+++ b/extractor/phrase_location_sampler_test.cc
@@ -0,0 +1,50 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "mocks/mock_matchings_sampler.h"
+#include "mocks/mock_suffix_array_sampler.h"
+#include "phrase_location.h"
+#include "phrase_location_sampler.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+class MatchingsSamplerTest : public Test {
+ protected:
+ virtual void SetUp() {
+ matchings_sampler = make_shared<MockMatchingsSampler>();
+ suffix_array_sampler = make_shared<MockSuffixArraySampler>();
+
+ sampler = make_shared<PhraseLocationSampler>(
+ matchings_sampler, suffix_array_sampler);
+ }
+
+ shared_ptr<MockMatchingsSampler> matchings_sampler;
+ shared_ptr<MockSuffixArraySampler> suffix_array_sampler;
+ shared_ptr<PhraseLocationSampler> sampler;
+};
+
+TEST_F(MatchingsSamplerTest, TestSuffixArrayRange) {
+ vector<int> locations = {0, 1, 2, 3};
+ PhraseLocation location(0, 3), result(locations, 2);
+ unordered_set<int> blacklisted_sentence_ids;
+ EXPECT_CALL(*suffix_array_sampler, Sample(location, blacklisted_sentence_ids))
+ .WillOnce(Return(result));
+ EXPECT_EQ(result, sampler->Sample(location, blacklisted_sentence_ids));
+}
+
+TEST_F(MatchingsSamplerTest, TestMatchings) {
+ vector<int> locations = {0, 1, 2, 3};
+ PhraseLocation location(locations, 2), result(locations, 2);
+ unordered_set<int> blacklisted_sentence_ids;
+ EXPECT_CALL(*matchings_sampler, Sample(location, blacklisted_sentence_ids))
+ .WillOnce(Return(result));
+ EXPECT_EQ(result, sampler->Sample(location, blacklisted_sentence_ids));
+}
+
+}
+} // namespace extractor
diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc
index b79daae3..3e58e2a9 100644
--- a/extractor/precomputation.cc
+++ b/extractor/precomputation.cc
@@ -91,7 +91,6 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns(
}
}
- shared_ptr<DataArray> data_array = suffix_array->GetData();
// Extract the most frequent patterns.
vector<vector<int>> frequent_patterns;
while (frequent_patterns.size() < num_frequent_patterns && !heap.empty()) {
@@ -99,7 +98,7 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns(
int len = heap.top().second.second;
heap.pop();
- vector<int> pattern = data_array->GetWordIds(start, len);
+ vector<int> pattern(data.begin() + start, data.begin() + start + len);
if (find(pattern.begin(), pattern.end(), DataArray::END_OF_LINE) ==
pattern.end()) {
frequent_patterns.push_back(pattern);
diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc
index d5f5ef63..3a98ce05 100644
--- a/extractor/precomputation_test.cc
+++ b/extractor/precomputation_test.cc
@@ -94,7 +94,7 @@ TEST_F(PrecomputationTest, TestCollocations) {
EXPECT_TRUE(precomputation.Contains(key));
EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
- key = {2, -1, 2, -1, 2};
+ key = {2, -1, 2, -2, 2};
expected_value = {1, 5, 8, 5, 8, 11};
EXPECT_TRUE(precomputation.Contains(key));
EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc
index 5b66f685..18a60695 100644
--- a/extractor/rule_factory.cc
+++ b/extractor/rule_factory.cc
@@ -12,6 +12,7 @@
#include "phrase_builder.h"
#include "rule.h"
#include "rule_extractor.h"
+#include "phrase_location_sampler.h"
#include "sampler.h"
#include "scorer.h"
#include "suffix_array.h"
@@ -68,7 +69,8 @@ HieroCachingRuleFactory::HieroCachingRuleFactory(
target_data_array, alignment, phrase_builder, scorer, vocabulary,
max_rule_span, min_gap_size, max_nonterminals, max_rule_symbols, true,
false, require_tight_phrases);
- sampler = make_shared<Sampler>(source_suffix_array, max_samples);
+ sampler = make_shared<PhraseLocationSampler>(
+ source_suffix_array, max_samples);
}
HieroCachingRuleFactory::HieroCachingRuleFactory(
diff --git a/extractor/sampler.cc b/extractor/sampler.cc
deleted file mode 100644
index 887aaec1..00000000
--- a/extractor/sampler.cc
+++ /dev/null
@@ -1,78 +0,0 @@
-#include "sampler.h"
-
-#include "phrase_location.h"
-#include "suffix_array.h"
-
-namespace extractor {
-
-Sampler::Sampler(shared_ptr<SuffixArray> suffix_array, int max_samples) :
- suffix_array(suffix_array), max_samples(max_samples) {}
-
-Sampler::Sampler() {}
-
-Sampler::~Sampler() {}
-
-PhraseLocation Sampler::Sample(
- const PhraseLocation& location,
- const unordered_set<int>& blacklisted_sentence_ids) const {
- shared_ptr<DataArray> source_data_array = suffix_array->GetData();
- vector<int> sample;
- int num_subpatterns;
- if (location.matchings == NULL) {
- // Sample suffix array range.
- num_subpatterns = 1;
- int low = location.sa_low, high = location.sa_high;
- double step = max(1.0, (double) (high - low) / max_samples);
- double i = low, last = i - 1;
- while (sample.size() < max_samples && i < high) {
- int x = suffix_array->GetSuffix(Round(i));
- int id = source_data_array->GetSentenceId(x);
- bool found = false;
- if (blacklisted_sentence_ids.count(id)) {
- for (int backoff_step = 1; backoff_step <= step; ++backoff_step) {
- double j = i - backoff_step;
- x = suffix_array->GetSuffix(Round(j));
- id = source_data_array->GetSentenceId(x);
- if (x >= 0 && j > last && !blacklisted_sentence_ids.count(id)) {
- found = true;
- last = i;
- break;
- }
- double k = i + backoff_step;
- x = suffix_array->GetSuffix(Round(k));
- id = source_data_array->GetSentenceId(x);
- if (k < min(i+step, (double) high) &&
- !blacklisted_sentence_ids.count(id)) {
- found = true;
- last = k;
- break;
- }
- }
- } else {
- found = true;
- last = i;
- }
- if (found) sample.push_back(x);
- i += step;
- }
- } else {
- // Sample vector of occurrences.
- num_subpatterns = location.num_subpatterns;
- int num_matchings = location.matchings->size() / num_subpatterns;
- double step = max(1.0, (double) num_matchings / max_samples);
- for (double i = 0, num_samples = 0;
- i < num_matchings && num_samples < max_samples;
- i += step, ++num_samples) {
- int start = Round(i) * num_subpatterns;
- sample.insert(sample.end(), location.matchings->begin() + start,
- location.matchings->begin() + start + num_subpatterns);
- }
- }
- return PhraseLocation(sample, num_subpatterns);
-}
-
-int Sampler::Round(double x) const {
- return x + 0.5;
-}
-
-} // namespace extractor
diff --git a/extractor/sampler.h b/extractor/sampler.h
index bd8a5876..3c4e37f1 100644
--- a/extractor/sampler.h
+++ b/extractor/sampler.h
@@ -4,38 +4,20 @@
#include <memory>
#include <unordered_set>
-#include "data_array.h"
-
using namespace std;
namespace extractor {
class PhraseLocation;
-class SuffixArray;
/**
- * Provides uniform sampling for a PhraseLocation.
+ * Base sampler class.
*/
class Sampler {
public:
- Sampler(shared_ptr<SuffixArray> suffix_array, int max_samples);
-
- virtual ~Sampler();
-
- // Samples uniformly at most max_samples phrase occurrences.
virtual PhraseLocation Sample(
const PhraseLocation& location,
- const unordered_set<int>& blacklisted_sentence_ids) const;
-
- protected:
- Sampler();
-
- private:
- // Round floating point number to the nearest integer.
- int Round(double x) const;
-
- shared_ptr<SuffixArray> suffix_array;
- int max_samples;
+ const unordered_set<int>& blacklisted_sentence_ids) const = 0;
};
} // namespace extractor
diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc
deleted file mode 100644
index 14e72780..00000000
--- a/extractor/sampler_test.cc
+++ /dev/null
@@ -1,92 +0,0 @@
-#include <gtest/gtest.h>
-
-#include <memory>
-
-#include "mocks/mock_suffix_array.h"
-#include "mocks/mock_data_array.h"
-#include "phrase_location.h"
-#include "sampler.h"
-
-using namespace std;
-using namespace ::testing;
-
-namespace extractor {
-namespace {
-
-class SamplerTest : public Test {
- protected:
- virtual void SetUp() {
- source_data_array = make_shared<MockDataArray>();
- EXPECT_CALL(*source_data_array, GetSentenceId(_)).WillRepeatedly(Return(9999));
- suffix_array = make_shared<MockSuffixArray>();
- EXPECT_CALL(*suffix_array, GetData())
- .WillRepeatedly(Return(source_data_array));
- for (int i = 0; i < 10; ++i) {
- EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i));
- }
- }
-
- shared_ptr<MockSuffixArray> suffix_array;
- shared_ptr<Sampler> sampler;
- shared_ptr<MockDataArray> source_data_array;
-};
-
-TEST_F(SamplerTest, TestSuffixArrayRange) {
- PhraseLocation location(0, 10);
- unordered_set<int> blacklist;
-
- sampler = make_shared<Sampler>(suffix_array, 1);
- vector<int> expected_locations = {0};
- EXPECT_EQ(PhraseLocation(expected_locations, 1),
- sampler->Sample(location, blacklist));
- return;
-
- sampler = make_shared<Sampler>(suffix_array, 2);
- expected_locations = {0, 5};
- EXPECT_EQ(PhraseLocation(expected_locations, 1),
- sampler->Sample(location, blacklist));
-
- sampler = make_shared<Sampler>(suffix_array, 3);
- expected_locations = {0, 3, 7};
- EXPECT_EQ(PhraseLocation(expected_locations, 1),
- sampler->Sample(location, blacklist));
-
- sampler = make_shared<Sampler>(suffix_array, 4);
- expected_locations = {0, 3, 5, 8};
- EXPECT_EQ(PhraseLocation(expected_locations, 1),
- sampler->Sample(location, blacklist));
-
- sampler = make_shared<Sampler>(suffix_array, 100);
- expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
- EXPECT_EQ(PhraseLocation(expected_locations, 1),
- sampler->Sample(location, blacklist));
-}
-
-TEST_F(SamplerTest, TestSubstringsSample) {
- vector<int> locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
- unordered_set<int> blacklist;
- PhraseLocation location(locations, 2);
-
- sampler = make_shared<Sampler>(suffix_array, 1);
- vector<int> expected_locations = {0, 1};
- EXPECT_EQ(PhraseLocation(expected_locations, 2),
- sampler->Sample(location, blacklist));
-
- sampler = make_shared<Sampler>(suffix_array, 2);
- expected_locations = {0, 1, 6, 7};
- EXPECT_EQ(PhraseLocation(expected_locations, 2),
- sampler->Sample(location, blacklist));
-
- sampler = make_shared<Sampler>(suffix_array, 3);
- expected_locations = {0, 1, 4, 5, 6, 7};
- EXPECT_EQ(PhraseLocation(expected_locations, 2),
- sampler->Sample(location, blacklist));
-
- sampler = make_shared<Sampler>(suffix_array, 7);
- expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
- EXPECT_EQ(PhraseLocation(expected_locations, 2),
- sampler->Sample(location, blacklist));
-}
-
-} // namespace
-} // namespace extractor
diff --git a/extractor/sampler_test_blacklist.cc b/extractor/sampler_test_blacklist.cc
deleted file mode 100644
index 3305b990..00000000
--- a/extractor/sampler_test_blacklist.cc
+++ /dev/null
@@ -1,102 +0,0 @@
-#include <gtest/gtest.h>
-
-#include <memory>
-
-#include "mocks/mock_suffix_array.h"
-#include "mocks/mock_data_array.h"
-#include "phrase_location.h"
-#include "sampler.h"
-
-using namespace std;
-using namespace ::testing;
-
-namespace extractor {
-namespace {
-
-class SamplerTestBlacklist : public Test {
- protected:
- virtual void SetUp() {
- source_data_array = make_shared<MockDataArray>();
- for (int i = 0; i < 10; ++i) {
- EXPECT_CALL(*source_data_array, GetSentenceId(i)).WillRepeatedly(Return(i));
- }
- for (int i = -10; i < 0; ++i) {
- EXPECT_CALL(*source_data_array, GetSentenceId(i)).WillRepeatedly(Return(0));
- }
- suffix_array = make_shared<MockSuffixArray>();
- for (int i = -10; i < 10; ++i) {
- EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i));
- }
- }
-
- shared_ptr<MockSuffixArray> suffix_array;
- shared_ptr<Sampler> sampler;
- shared_ptr<MockDataArray> source_data_array;
-};
-
-TEST_F(SamplerTestBlacklist, TestSuffixArrayRange) {
- PhraseLocation location(0, 10);
- unordered_set<int> blacklist;
- vector<int> expected_locations;
-
- blacklist.insert(0);
- sampler = make_shared<Sampler>(suffix_array, 1);
- expected_locations = {1};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
- blacklist.clear();
-
- for (int i = 0; i < 9; i++) {
- blacklist.insert(i);
- }
- sampler = make_shared<Sampler>(suffix_array, 1);
- expected_locations = {9};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
- blacklist.clear();
-
- blacklist.insert(0);
- blacklist.insert(5);
- sampler = make_shared<Sampler>(suffix_array, 2);
- expected_locations = {1, 4};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
- blacklist.clear();
-
- blacklist.insert(0);
- blacklist.insert(1);
- blacklist.insert(2);
- blacklist.insert(3);
- sampler = make_shared<Sampler>(suffix_array, 2);
- expected_locations = {4, 5};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
- blacklist.clear();
-
- blacklist.insert(0);
- blacklist.insert(3);
- blacklist.insert(7);
- sampler = make_shared<Sampler>(suffix_array, 3);
- expected_locations = {1, 2, 6};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
- blacklist.clear();
-
- blacklist.insert(0);
- blacklist.insert(3);
- blacklist.insert(5);
- blacklist.insert(8);
- sampler = make_shared<Sampler>(suffix_array, 4);
- expected_locations = {1, 2, 4, 7};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
- blacklist.clear();
-
- blacklist.insert(0);
- sampler = make_shared<Sampler>(suffix_array, 100);
- expected_locations = {1, 2, 3, 4, 5, 6, 7, 8, 9};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
- blacklist.clear();
-
- blacklist.insert(9);
- sampler = make_shared<Sampler>(suffix_array, 100);
- expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
-}
-
-} // namespace
-} // namespace extractor
diff --git a/extractor/suffix_array_sampler.cc b/extractor/suffix_array_sampler.cc
new file mode 100644
index 00000000..4a4ced34
--- /dev/null
+++ b/extractor/suffix_array_sampler.cc
@@ -0,0 +1,40 @@
+#include "suffix_array_sampler.h"
+
+#include "data_array.h"
+#include "phrase_location.h"
+#include "suffix_array.h"
+
+namespace extractor {
+
+SuffixArrayRangeSampler::SuffixArrayRangeSampler(
+ shared_ptr<SuffixArray> source_suffix_array, int max_samples) :
+ BackoffSampler(source_suffix_array->GetData(), max_samples),
+ source_suffix_array(source_suffix_array) {}
+
+SuffixArrayRangeSampler::SuffixArrayRangeSampler() {}
+
+int SuffixArrayRangeSampler::GetNumSubpatterns(const PhraseLocation&) const {
+ return 1;
+}
+
+int SuffixArrayRangeSampler::GetRangeLow(
+ const PhraseLocation& location) const {
+ return location.sa_low;
+}
+
+int SuffixArrayRangeSampler::GetRangeHigh(
+ const PhraseLocation& location) const {
+ return location.sa_high;
+}
+
+int SuffixArrayRangeSampler::GetPosition(
+ const PhraseLocation&, int position) const {
+ return source_suffix_array->GetSuffix(position);
+}
+
+void SuffixArrayRangeSampler::AppendMatching(
+ vector<int>& samples, int index, const PhraseLocation&) const {
+ samples.push_back(source_suffix_array->GetSuffix(index));
+}
+
+} // namespace extractor
diff --git a/extractor/suffix_array_sampler.h b/extractor/suffix_array_sampler.h
new file mode 100644
index 00000000..bb3c2653
--- /dev/null
+++ b/extractor/suffix_array_sampler.h
@@ -0,0 +1,34 @@
+#ifndef _SUFFIX_ARRAY_SAMPLER_H_
+#define _SUFFIX_ARRAY_SAMPLER_H_
+
+#include "backoff_sampler.h"
+
+namespace extractor {
+
+class SuffixArray;
+
+class SuffixArrayRangeSampler : public BackoffSampler {
+ public:
+ SuffixArrayRangeSampler(shared_ptr<SuffixArray> suffix_array,
+ int max_samples);
+
+ SuffixArrayRangeSampler();
+
+ private:
+ int GetNumSubpatterns(const PhraseLocation& location) const;
+
+ int GetRangeLow(const PhraseLocation& location) const;
+
+ int GetRangeHigh(const PhraseLocation& location) const;
+
+ int GetPosition(const PhraseLocation& location, int index) const;
+
+ void AppendMatching(vector<int>& samples, int index,
+ const PhraseLocation& location) const;
+
+ shared_ptr<SuffixArray> source_suffix_array;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/suffix_array_sampler_test.cc b/extractor/suffix_array_sampler_test.cc
new file mode 100644
index 00000000..4b88c027
--- /dev/null
+++ b/extractor/suffix_array_sampler_test.cc
@@ -0,0 +1,114 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "mocks/mock_data_array.h"
+#include "mocks/mock_suffix_array.h"
+#include "suffix_array_sampler.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+class SuffixArraySamplerTest : public Test {
+ protected:
+ virtual void SetUp() {
+ data_array = make_shared<MockDataArray>();
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_CALL(*data_array, GetSentenceId(i)).WillRepeatedly(Return(i));
+ }
+
+ suffix_array = make_shared<MockSuffixArray>();
+ EXPECT_CALL(*suffix_array, GetData()).WillRepeatedly(Return(data_array));
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i));
+ }
+ }
+
+ shared_ptr<MockDataArray> data_array;
+ shared_ptr<MockSuffixArray> suffix_array;
+};
+
+TEST_F(SuffixArraySamplerTest, TestSample) {
+ PhraseLocation location(0, 10);
+ unordered_set<int> blacklisted_sentence_ids;
+
+ SuffixArrayRangeSampler sampler(suffix_array, 1);
+ vector<int> expected_locations = {0};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+
+ sampler = SuffixArrayRangeSampler(suffix_array, 2);
+ expected_locations = {0, 5};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+
+ sampler = SuffixArrayRangeSampler(suffix_array, 3);
+ expected_locations = {0, 3, 7};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+
+ sampler = SuffixArrayRangeSampler(suffix_array, 4);
+ expected_locations = {0, 3, 5, 8};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+
+ sampler = SuffixArrayRangeSampler(suffix_array, 100);
+ expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+}
+
+TEST_F(SuffixArraySamplerTest, TestBackoffSample) {
+ PhraseLocation location(0, 10);
+
+ SuffixArrayRangeSampler sampler(suffix_array, 1);
+ unordered_set<int> blacklisted_sentence_ids = {0};
+ vector<int> expected_locations = {1};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+
+ blacklisted_sentence_ids = {0, 1, 2, 3, 4, 5, 6, 7, 8};
+ expected_locations = {9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+
+ sampler = SuffixArrayRangeSampler(suffix_array, 2);
+ blacklisted_sentence_ids = {0, 5};
+ expected_locations = {1, 4};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+
+ blacklisted_sentence_ids = {0, 1, 2, 3};
+ expected_locations = {4, 5};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+
+ sampler = SuffixArrayRangeSampler(suffix_array, 3);
+ blacklisted_sentence_ids = {0, 3, 7};
+ expected_locations = {1, 2, 6};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+
+ sampler = SuffixArrayRangeSampler(suffix_array, 4);
+ blacklisted_sentence_ids = {0, 3, 5, 8};
+ expected_locations = {1, 2, 4, 7};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+
+ sampler = SuffixArrayRangeSampler(suffix_array, 100);
+ blacklisted_sentence_ids = {0};
+ expected_locations = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+
+ blacklisted_sentence_ids = {9};
+ expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+}
+
+}
+} // namespace extractor