diff options
Diffstat (limited to 'extractor')
| -rw-r--r-- | extractor/Makefile.am | 24 | ||||
| -rw-r--r-- | extractor/backoff_sampler.cc | 66 | ||||
| -rw-r--r-- | extractor/backoff_sampler.h | 41 | ||||
| -rw-r--r-- | extractor/matchings_sampler.cc | 38 | ||||
| -rw-r--r-- | extractor/matchings_sampler.h | 31 | ||||
| -rw-r--r-- | extractor/matchings_sampler_test.cc | 118 | ||||
| -rw-r--r-- | extractor/mocks/mock_matchings_sampler.h | 15 | ||||
| -rw-r--r-- | extractor/mocks/mock_suffix_array_sampler.h | 15 | ||||
| -rw-r--r-- | extractor/phrase_location.cc | 2 | ||||
| -rw-r--r-- | extractor/phrase_location_sampler.cc | 34 | ||||
| -rw-r--r-- | extractor/phrase_location_sampler.h | 35 | ||||
| -rw-r--r-- | extractor/phrase_location_sampler_test.cc | 50 | ||||
| -rw-r--r-- | extractor/precomputation.cc | 3 | ||||
| -rw-r--r-- | extractor/precomputation_test.cc | 2 | ||||
| -rw-r--r-- | extractor/rule_factory.cc | 4 | ||||
| -rw-r--r-- | extractor/sampler.cc | 78 | ||||
| -rw-r--r-- | extractor/sampler.h | 22 | ||||
| -rw-r--r-- | extractor/sampler_test.cc | 92 | ||||
| -rw-r--r-- | extractor/sampler_test_blacklist.cc | 102 | ||||
| -rw-r--r-- | extractor/suffix_array_sampler.cc | 40 | ||||
| -rw-r--r-- | extractor/suffix_array_sampler.h | 34 | ||||
| -rw-r--r-- | extractor/suffix_array_sampler_test.cc | 114 | 
22 files changed, 657 insertions, 303 deletions
| diff --git a/extractor/Makefile.am b/extractor/Makefile.am index 7825012c..e5b439f9 100644 --- a/extractor/Makefile.am +++ b/extractor/Makefile.am @@ -15,13 +15,15 @@ EXTRA_PROGRAMS = alignment_test \      feature_target_given_source_coherent_test \      grammar_extractor_test \      matchings_finder_test \ +    matchings_sampler_test \ +    phrase_location_sampler_test \      phrase_test \      precomputation_test \      rule_extractor_helper_test \      rule_extractor_test \      rule_factory_test \ -    sampler_test \      scorer_test \ +    suffix_array_sampler_test \      suffix_array_test \      target_phrase_extractor_test \      translation_table_test \ @@ -40,13 +42,15 @@ if HAVE_GTEST      feature_target_given_source_coherent_test \      grammar_extractor_test \      matchings_finder_test \ +    matchings_sampler_test \ +    phrase_location_sampler_test \      phrase_test \      precomputation_test \      rule_extractor_helper_test \      rule_extractor_test \      rule_factory_test \ -    sampler_test \      scorer_test \ +    suffix_array_sampler_test \      suffix_array_test \      target_phrase_extractor_test \      translation_table_test \ @@ -55,8 +59,7 @@ endif  noinst_PROGRAMS = $(RUNNABLE_TESTS) -# TESTS = $(RUNNABLE_TESTS) -TESTS = vocabulary_test +TESTS = $(RUNNABLE_TESTS)  alignment_test_SOURCES = alignment_test.cc  alignment_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a @@ -82,6 +85,10 @@ grammar_extractor_test_SOURCES = grammar_extractor_test.cc  grammar_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a  matchings_finder_test_SOURCES = matchings_finder_test.cc  matchings_finder_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +matchings_sampler_test_SOURCES = matchings_sampler_test.cc +matchings_sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +phrase_location_sampler_test_SOURCES = phrase_location_sampler_test.cc +phrase_location_sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a  phrase_test_SOURCES = phrase_test.cc  phrase_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a  precomputation_test_SOURCES = precomputation_test.cc @@ -92,10 +99,10 @@ rule_extractor_test_SOURCES = rule_extractor_test.cc  rule_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a  rule_factory_test_SOURCES = rule_factory_test.cc  rule_factory_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a -sampler_test_SOURCES = sampler_test.cc -sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a  scorer_test_SOURCES = scorer_test.cc  scorer_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +suffix_array_sampler_test_SOURCES = suffix_array_sampler_test.cc +suffix_array_sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a  suffix_array_test_SOURCES = suffix_array_test.cc  suffix_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a  target_phrase_extractor_test_SOURCES = target_phrase_extractor_test.cc @@ -116,6 +123,7 @@ extract_LDADD = libextractor.a  libextractor_a_SOURCES = \    alignment.cc \ +  backoff_sampler.cc \    data_array.cc \    fast_intersector.cc \    features/count_source_target.cc \ @@ -129,18 +137,20 @@ libextractor_a_SOURCES = \    grammar.cc \    grammar_extractor.cc \    matchings_finder.cc \ +  matchings_sampler.cc \    matchings_trie.cc \    phrase.cc \    phrase_builder.cc \    phrase_location.cc \ +  phrase_location_sampler.cc \    precomputation.cc \    rule.cc \    rule_extractor.cc \    rule_extractor_helper.cc \    rule_factory.cc \ -  sampler.cc \    scorer.cc \    suffix_array.cc \ +  suffix_array_sampler.cc \    target_phrase_extractor.cc \    time_util.cc \    translation_table.cc \ diff --git a/extractor/backoff_sampler.cc b/extractor/backoff_sampler.cc new file mode 100644 index 00000000..28b12909 --- /dev/null +++ b/extractor/backoff_sampler.cc @@ -0,0 +1,66 @@ +#include "backoff_sampler.h" + +#include "data_array.h" +#include "phrase_location.h" + +namespace extractor { + +BackoffSampler::BackoffSampler( +    shared_ptr<DataArray> source_data_array, int max_samples) : +    source_data_array(source_data_array), max_samples(max_samples) {} + +BackoffSampler::BackoffSampler() {} + +PhraseLocation BackoffSampler::Sample( +    const PhraseLocation& location, +    const unordered_set<int>& blacklisted_sentence_ids) const { +  vector<int> samples; +  int low = GetRangeLow(location), high = GetRangeHigh(location); +  int last_position = low - 1; +  double step = max(1.0, (double) (high - low) / max_samples); +  for (double num_samples = 0, i = low; +       num_samples < max_samples && i < high; +       ++num_samples, i += step) { +    int position = GetPosition(location, round(i)); +    int sentence_id = source_data_array->GetSentenceId(position); +    bool found = false; +    if (last_position >= position || +        blacklisted_sentence_ids.count(sentence_id)) { +      for (double backoff_step = 1; backoff_step < step; ++backoff_step) { +        double j = i - backoff_step; +        if (round(j) >= 0) { +          position = GetPosition(location, round(j)); +          sentence_id = source_data_array->GetSentenceId(position); +          if (position > last_position && +              !blacklisted_sentence_ids.count(sentence_id)) { +            found = true; +            last_position = position; +            break; +          } +        } + +        double k = i + backoff_step; +        if (round(k) < high) { +          position = GetPosition(location, round(k)); +          sentence_id = source_data_array->GetSentenceId(position); +          if (!blacklisted_sentence_ids.count(sentence_id)) { +            found = true; +            last_position = position; +            break; +          } +        } +      } +    } else { +      found = true; +      last_position = position; +    } + +    if (found) { +      AppendMatching(samples, position, location); +    } +  } + +  return PhraseLocation(samples, GetNumSubpatterns(location)); +} + +} // namespace extractor diff --git a/extractor/backoff_sampler.h b/extractor/backoff_sampler.h new file mode 100644 index 00000000..5c244105 --- /dev/null +++ b/extractor/backoff_sampler.h @@ -0,0 +1,41 @@ +#ifndef _BACKOFF_SAMPLER_H_ +#define _BACKOFF_SAMPLER_H_ + +#include <vector> + +#include "sampler.h" + +namespace extractor { + +class DataArray; +class PhraseLocation; + +class BackoffSampler : public Sampler { + public: +  BackoffSampler(shared_ptr<DataArray> source_data_array, int max_samples); + +  BackoffSampler(); + +  PhraseLocation Sample( +      const PhraseLocation& location, +      const unordered_set<int>& blacklisted_sentence_ids) const; + + private: +  virtual int GetNumSubpatterns(const PhraseLocation& location) const = 0; + +  virtual int GetRangeLow(const PhraseLocation& location) const = 0; + +  virtual int GetRangeHigh(const PhraseLocation& location) const = 0; + +  virtual int GetPosition(const PhraseLocation& location, int index) const = 0; + +  virtual void AppendMatching(vector<int>& samples, int index, +                              const PhraseLocation& location) const = 0; + +  shared_ptr<DataArray> source_data_array; +  int max_samples; +}; + +} // namespace extractor + +#endif diff --git a/extractor/matchings_sampler.cc b/extractor/matchings_sampler.cc new file mode 100644 index 00000000..bb916e49 --- /dev/null +++ b/extractor/matchings_sampler.cc @@ -0,0 +1,38 @@ +#include "matchings_sampler.h" + +#include "data_array.h" +#include "phrase_location.h" + +namespace extractor { + +MatchingsSampler::MatchingsSampler( +    shared_ptr<DataArray> data_array, int max_samples) : +    BackoffSampler(data_array, max_samples) {} + +MatchingsSampler::MatchingsSampler() {} + +int MatchingsSampler::GetNumSubpatterns(const PhraseLocation& location) const { +  return location.num_subpatterns; +} + +int MatchingsSampler::GetRangeLow(const PhraseLocation&) const { +  return 0; +} + +int MatchingsSampler::GetRangeHigh(const PhraseLocation& location) const { +  return location.matchings->size() / location.num_subpatterns; +} + +int MatchingsSampler::GetPosition(const PhraseLocation& location, +                                  int index) const { +  return (*location.matchings)[index * location.num_subpatterns]; +} + +void MatchingsSampler::AppendMatching(vector<int>& samples, int index, +                                      const PhraseLocation& location) const { +  copy(location.matchings->begin() + index, +       location.matchings->begin() + index + location.num_subpatterns, +       back_inserter(samples)); +} + +} // namespace extractor diff --git a/extractor/matchings_sampler.h b/extractor/matchings_sampler.h new file mode 100644 index 00000000..ca4fce93 --- /dev/null +++ b/extractor/matchings_sampler.h @@ -0,0 +1,31 @@ +#ifndef _MATCHINGS_SAMPLER_H_ +#define _MATCHINGS_SAMPLER_H_ + +#include "backoff_sampler.h" + +namespace extractor { + +class DataArray; + +class MatchingsSampler : public BackoffSampler { + public: +  MatchingsSampler(shared_ptr<DataArray> data_array, int max_samples); + +  MatchingsSampler(); + + private: +   int GetNumSubpatterns(const PhraseLocation& location) const; + +   int GetRangeLow(const PhraseLocation& location) const; + +   int GetRangeHigh(const PhraseLocation& location) const; + +   int GetPosition(const PhraseLocation& location, int index) const; + +   void AppendMatching(vector<int>& samples, int index, +                       const PhraseLocation& location) const; +}; + +} // namespace extractor + +#endif diff --git a/extractor/matchings_sampler_test.cc b/extractor/matchings_sampler_test.cc new file mode 100644 index 00000000..bc927152 --- /dev/null +++ b/extractor/matchings_sampler_test.cc @@ -0,0 +1,118 @@ +#include <gtest/gtest.h> + +#include <memory> + +#include "mocks/mock_data_array.h" +#include "matchings_sampler.h" +#include "phrase_location.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class MatchingsSamplerTest : public Test { + protected: +  virtual void SetUp() { +    vector<int> locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; +    location = PhraseLocation(locations, 2); + +    data_array = make_shared<MockDataArray>(); +    for (int i = 0; i < 10; ++i) { +      EXPECT_CALL(*data_array, GetSentenceId(i)).WillRepeatedly(Return(i / 2)); +    } +  } + +  unordered_set<int> blacklisted_sentence_ids; +  PhraseLocation location; +  shared_ptr<MockDataArray> data_array; +  shared_ptr<MatchingsSampler> sampler; +}; + +TEST_F(MatchingsSamplerTest, TestSample) { +  sampler = make_shared<MatchingsSampler>(data_array, 1); +  vector<int> expected_locations = {0, 1}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  sampler = make_shared<MatchingsSampler>(data_array, 2); +  expected_locations = {0, 1, 6, 7}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  sampler = make_shared<MatchingsSampler>(data_array, 3); +  expected_locations = {0, 1, 4, 5, 6, 7}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  sampler = make_shared<MatchingsSampler>(data_array, 7); +  expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); +} + +TEST_F(MatchingsSamplerTest, TestBackoffSample) { +  sampler = make_shared<MatchingsSampler>(data_array, 1); +  blacklisted_sentence_ids = {0}; +  vector<int> expected_locations = {2, 3}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  blacklisted_sentence_ids = {0, 1, 2, 3}; +  expected_locations = {8, 9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  blacklisted_sentence_ids = {0, 1, 2, 3, 4}; +  expected_locations = {}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  sampler = make_shared<MatchingsSampler>(data_array, 2); +  blacklisted_sentence_ids = {0, 3}; +  expected_locations = {2, 3, 4, 5}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  sampler = make_shared<MatchingsSampler>(data_array, 3); +  blacklisted_sentence_ids = {0, 3}; +  expected_locations = {2, 3, 4, 5, 8, 9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  blacklisted_sentence_ids = {0, 2, 3}; +  expected_locations = {2, 3, 8, 9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  sampler = make_shared<MatchingsSampler>(data_array, 4); +  blacklisted_sentence_ids = {0, 1, 2, 3}; +  expected_locations = {8, 9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  blacklisted_sentence_ids = {1, 3}; +  expected_locations = {0, 1, 4, 5, 8, 9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  sampler = make_shared<MatchingsSampler>(data_array, 7); +  blacklisted_sentence_ids = {0, 1, 2, 3, 4}; +  expected_locations = {}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  blacklisted_sentence_ids = {0, 2, 4}; +  expected_locations = {2, 3, 6, 7}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  blacklisted_sentence_ids = {1, 3}; +  expected_locations = {0, 1, 4, 5, 8, 9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); +} + +} +} // namespace extractor diff --git a/extractor/mocks/mock_matchings_sampler.h b/extractor/mocks/mock_matchings_sampler.h new file mode 100644 index 00000000..de2009c3 --- /dev/null +++ b/extractor/mocks/mock_matchings_sampler.h @@ -0,0 +1,15 @@ +#include <gmock/gmock.h> + +#include "phrase_location.h" +#include "matchings_sampler.h" + +namespace extractor { + +class MockMatchingsSampler : public MatchingsSampler { + public: +  MOCK_CONST_METHOD2(Sample, PhraseLocation( +      const PhraseLocation& location, +      const unordered_set<int>& blacklisted_sentence_ids)); +}; + +} // namespace extractor diff --git a/extractor/mocks/mock_suffix_array_sampler.h b/extractor/mocks/mock_suffix_array_sampler.h new file mode 100644 index 00000000..d799b969 --- /dev/null +++ b/extractor/mocks/mock_suffix_array_sampler.h @@ -0,0 +1,15 @@ +#include <gmock/gmock.h> + +#include "phrase_location.h" +#include "suffix_array_sampler.h" + +namespace extractor { + +class MockSuffixArraySampler : public SuffixArrayRangeSampler { + public: +  MOCK_CONST_METHOD2(Sample, PhraseLocation( +      const PhraseLocation& location, +      const unordered_set<int>& blacklisted_sentence_ids)); +}; + +} // namespace extractor diff --git a/extractor/phrase_location.cc b/extractor/phrase_location.cc index 13140cac..2c367893 100644 --- a/extractor/phrase_location.cc +++ b/extractor/phrase_location.cc @@ -1,5 +1,7 @@  #include "phrase_location.h" +#include <iostream> +  namespace extractor {  PhraseLocation::PhraseLocation(int sa_low, int sa_high) : diff --git a/extractor/phrase_location_sampler.cc b/extractor/phrase_location_sampler.cc new file mode 100644 index 00000000..a2eec105 --- /dev/null +++ b/extractor/phrase_location_sampler.cc @@ -0,0 +1,34 @@ +#include "phrase_location_sampler.h" + +#include "matchings_sampler.h" +#include "phrase_location.h" +#include "suffix_array.h" +#include "suffix_array_sampler.h" + +namespace extractor { + +PhraseLocationSampler::PhraseLocationSampler( +    shared_ptr<SuffixArray> suffix_array, int max_samples) { +  matchings_sampler = make_shared<MatchingsSampler>( +      suffix_array->GetData(), max_samples); +  suffix_array_sampler = make_shared<SuffixArrayRangeSampler>( +      suffix_array, max_samples); +} + +PhraseLocationSampler::PhraseLocationSampler( +    shared_ptr<MatchingsSampler> matchings_sampler, +    shared_ptr<SuffixArrayRangeSampler> suffix_array_sampler) : +    matchings_sampler(matchings_sampler), +    suffix_array_sampler(suffix_array_sampler) {} + +PhraseLocation PhraseLocationSampler::Sample( +    const PhraseLocation& location, +    const unordered_set<int>& blacklisted_sentence_ids) const { +  if (location.matchings == NULL) { +    return suffix_array_sampler->Sample(location, blacklisted_sentence_ids); +  } else { +    return matchings_sampler->Sample(location, blacklisted_sentence_ids); +  } +} + +} // namespace extractor diff --git a/extractor/phrase_location_sampler.h b/extractor/phrase_location_sampler.h new file mode 100644 index 00000000..0e88335e --- /dev/null +++ b/extractor/phrase_location_sampler.h @@ -0,0 +1,35 @@ +#ifndef _PHRASE_LOCATION_SAMPLER_H_ +#define _PHRASE_LOCATION_SAMPLER_H_ + +#include <memory> + +#include "sampler.h" + +namespace extractor { + +class MatchingsSampler; +class PhraseLocation; +class SuffixArray; +class SuffixArrayRangeSampler; + +class PhraseLocationSampler : public Sampler { + public: +  PhraseLocationSampler(shared_ptr<SuffixArray> suffix_array, int max_samples); + +  // For testing only. +  PhraseLocationSampler( +      shared_ptr<MatchingsSampler> matchings_sampler, +      shared_ptr<SuffixArrayRangeSampler> suffix_array_sampler); + +  PhraseLocation Sample( +      const PhraseLocation& location, +      const unordered_set<int>& blacklisted_sentence_ids) const; + + private: +  shared_ptr<MatchingsSampler> matchings_sampler; +  shared_ptr<SuffixArrayRangeSampler> suffix_array_sampler; +}; + +} // namespace extractor + +#endif diff --git a/extractor/phrase_location_sampler_test.cc b/extractor/phrase_location_sampler_test.cc new file mode 100644 index 00000000..e7520ce7 --- /dev/null +++ b/extractor/phrase_location_sampler_test.cc @@ -0,0 +1,50 @@ +#include <gtest/gtest.h> + +#include <memory> + +#include "mocks/mock_matchings_sampler.h" +#include "mocks/mock_suffix_array_sampler.h" +#include "phrase_location.h" +#include "phrase_location_sampler.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class MatchingsSamplerTest : public Test { + protected: +  virtual void SetUp() { +    matchings_sampler = make_shared<MockMatchingsSampler>(); +    suffix_array_sampler = make_shared<MockSuffixArraySampler>(); + +    sampler = make_shared<PhraseLocationSampler>( +        matchings_sampler, suffix_array_sampler); +  } + +  shared_ptr<MockMatchingsSampler> matchings_sampler; +  shared_ptr<MockSuffixArraySampler> suffix_array_sampler; +  shared_ptr<PhraseLocationSampler> sampler; +}; + +TEST_F(MatchingsSamplerTest, TestSuffixArrayRange) { +  vector<int> locations = {0, 1, 2, 3}; +  PhraseLocation location(0, 3), result(locations, 2); +  unordered_set<int> blacklisted_sentence_ids; +  EXPECT_CALL(*suffix_array_sampler, Sample(location, blacklisted_sentence_ids)) +      .WillOnce(Return(result)); +  EXPECT_EQ(result, sampler->Sample(location, blacklisted_sentence_ids)); +} + +TEST_F(MatchingsSamplerTest, TestMatchings) { +  vector<int> locations = {0, 1, 2, 3}; +  PhraseLocation location(locations, 2), result(locations, 2); +  unordered_set<int> blacklisted_sentence_ids; +  EXPECT_CALL(*matchings_sampler, Sample(location, blacklisted_sentence_ids)) +      .WillOnce(Return(result)); +  EXPECT_EQ(result, sampler->Sample(location, blacklisted_sentence_ids)); +} + +} +} // namespace extractor diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc index b79daae3..3e58e2a9 100644 --- a/extractor/precomputation.cc +++ b/extractor/precomputation.cc @@ -91,7 +91,6 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns(      }    } -  shared_ptr<DataArray> data_array = suffix_array->GetData();    // Extract the most frequent patterns.    vector<vector<int>> frequent_patterns;    while (frequent_patterns.size() < num_frequent_patterns && !heap.empty()) { @@ -99,7 +98,7 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns(      int len = heap.top().second.second;      heap.pop(); -    vector<int> pattern = data_array->GetWordIds(start, len); +    vector<int> pattern(data.begin() + start, data.begin() + start + len);      if (find(pattern.begin(), pattern.end(), DataArray::END_OF_LINE) ==          pattern.end()) {        frequent_patterns.push_back(pattern); diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc index d5f5ef63..3a98ce05 100644 --- a/extractor/precomputation_test.cc +++ b/extractor/precomputation_test.cc @@ -94,7 +94,7 @@ TEST_F(PrecomputationTest, TestCollocations) {    EXPECT_TRUE(precomputation.Contains(key));    EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); -  key = {2, -1, 2, -1, 2}; +  key = {2, -1, 2, -2, 2};    expected_value = {1, 5, 8, 5, 8, 11};    EXPECT_TRUE(precomputation.Contains(key));    EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index 5b66f685..18a60695 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -12,6 +12,7 @@  #include "phrase_builder.h"  #include "rule.h"  #include "rule_extractor.h" +#include "phrase_location_sampler.h"  #include "sampler.h"  #include "scorer.h"  #include "suffix_array.h" @@ -68,7 +69,8 @@ HieroCachingRuleFactory::HieroCachingRuleFactory(        target_data_array, alignment, phrase_builder, scorer, vocabulary,        max_rule_span, min_gap_size, max_nonterminals, max_rule_symbols, true,        false, require_tight_phrases); -  sampler = make_shared<Sampler>(source_suffix_array, max_samples); +  sampler = make_shared<PhraseLocationSampler>( +      source_suffix_array, max_samples);  }  HieroCachingRuleFactory::HieroCachingRuleFactory( diff --git a/extractor/sampler.cc b/extractor/sampler.cc deleted file mode 100644 index 887aaec1..00000000 --- a/extractor/sampler.cc +++ /dev/null @@ -1,78 +0,0 @@ -#include "sampler.h" - -#include "phrase_location.h" -#include "suffix_array.h" - -namespace extractor { - -Sampler::Sampler(shared_ptr<SuffixArray> suffix_array, int max_samples) : -    suffix_array(suffix_array), max_samples(max_samples) {} - -Sampler::Sampler() {} - -Sampler::~Sampler() {} - -PhraseLocation Sampler::Sample( -    const PhraseLocation& location, -    const unordered_set<int>& blacklisted_sentence_ids) const { -  shared_ptr<DataArray> source_data_array = suffix_array->GetData(); -  vector<int> sample; -  int num_subpatterns; -  if (location.matchings == NULL) { -    // Sample suffix array range. -    num_subpatterns = 1; -    int low = location.sa_low, high = location.sa_high; -    double step = max(1.0, (double) (high - low) / max_samples); -    double i = low, last = i - 1; -    while (sample.size() < max_samples && i < high) { -      int x = suffix_array->GetSuffix(Round(i)); -      int id = source_data_array->GetSentenceId(x); -      bool found = false; -      if (blacklisted_sentence_ids.count(id)) { -        for (int backoff_step = 1; backoff_step <= step; ++backoff_step) { -          double j = i - backoff_step; -          x = suffix_array->GetSuffix(Round(j)); -          id = source_data_array->GetSentenceId(x); -          if (x >= 0 && j > last && !blacklisted_sentence_ids.count(id)) { -            found = true; -            last = i; -            break; -          } -          double k = i + backoff_step; -          x = suffix_array->GetSuffix(Round(k)); -          id = source_data_array->GetSentenceId(x); -          if (k < min(i+step, (double) high) && -              !blacklisted_sentence_ids.count(id)) { -            found = true; -            last = k; -            break; -          } -        } -      } else { -        found = true; -        last = i; -      } -      if (found) sample.push_back(x); -      i += step; -    } -  } else { -    // Sample vector of occurrences. -    num_subpatterns = location.num_subpatterns; -    int num_matchings = location.matchings->size() / num_subpatterns; -    double step = max(1.0, (double) num_matchings / max_samples); -    for (double i = 0, num_samples = 0; -         i < num_matchings && num_samples < max_samples; -         i += step, ++num_samples) { -      int start = Round(i) * num_subpatterns; -      sample.insert(sample.end(), location.matchings->begin() + start, -          location.matchings->begin() + start + num_subpatterns); -    } -  } -  return PhraseLocation(sample, num_subpatterns); -} - -int Sampler::Round(double x) const { -  return x + 0.5; -} - -} // namespace extractor diff --git a/extractor/sampler.h b/extractor/sampler.h index bd8a5876..3c4e37f1 100644 --- a/extractor/sampler.h +++ b/extractor/sampler.h @@ -4,38 +4,20 @@  #include <memory>  #include <unordered_set> -#include "data_array.h" -  using namespace std;  namespace extractor {  class PhraseLocation; -class SuffixArray;  /** - * Provides uniform sampling for a PhraseLocation. + * Base sampler class.   */  class Sampler {   public: -  Sampler(shared_ptr<SuffixArray> suffix_array, int max_samples); - -  virtual ~Sampler(); - -  // Samples uniformly at most max_samples phrase occurrences.    virtual PhraseLocation Sample(        const PhraseLocation& location, -      const unordered_set<int>& blacklisted_sentence_ids) const; - - protected: -  Sampler(); - - private: -  // Round floating point number to the nearest integer. -  int Round(double x) const; - -  shared_ptr<SuffixArray> suffix_array; -  int max_samples; +      const unordered_set<int>& blacklisted_sentence_ids) const = 0;  };  } // namespace extractor diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc deleted file mode 100644 index 14e72780..00000000 --- a/extractor/sampler_test.cc +++ /dev/null @@ -1,92 +0,0 @@ -#include <gtest/gtest.h> - -#include <memory> - -#include "mocks/mock_suffix_array.h" -#include "mocks/mock_data_array.h" -#include "phrase_location.h" -#include "sampler.h" - -using namespace std; -using namespace ::testing; - -namespace extractor { -namespace { - -class SamplerTest : public Test { - protected: -  virtual void SetUp() { -    source_data_array = make_shared<MockDataArray>(); -    EXPECT_CALL(*source_data_array, GetSentenceId(_)).WillRepeatedly(Return(9999)); -    suffix_array = make_shared<MockSuffixArray>(); -    EXPECT_CALL(*suffix_array, GetData()) -        .WillRepeatedly(Return(source_data_array)); -    for (int i = 0; i < 10; ++i) { -      EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i)); -    } -  } - -  shared_ptr<MockSuffixArray> suffix_array; -  shared_ptr<Sampler> sampler; -  shared_ptr<MockDataArray> source_data_array; -}; - -TEST_F(SamplerTest, TestSuffixArrayRange) { -  PhraseLocation location(0, 10); -  unordered_set<int> blacklist; - -  sampler = make_shared<Sampler>(suffix_array, 1); -  vector<int> expected_locations = {0}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), -            sampler->Sample(location, blacklist)); -  return; - -  sampler = make_shared<Sampler>(suffix_array, 2); -  expected_locations = {0, 5}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), -            sampler->Sample(location, blacklist)); - -  sampler = make_shared<Sampler>(suffix_array, 3); -  expected_locations = {0, 3, 7}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), -            sampler->Sample(location, blacklist)); - -  sampler = make_shared<Sampler>(suffix_array, 4); -  expected_locations = {0, 3, 5, 8}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), -            sampler->Sample(location, blacklist)); - -  sampler = make_shared<Sampler>(suffix_array, 100); -  expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), -            sampler->Sample(location, blacklist)); -} - -TEST_F(SamplerTest, TestSubstringsSample) { -  vector<int> locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; -  unordered_set<int> blacklist; -  PhraseLocation location(locations, 2); - -  sampler = make_shared<Sampler>(suffix_array, 1); -  vector<int> expected_locations = {0, 1}; -  EXPECT_EQ(PhraseLocation(expected_locations, 2), -            sampler->Sample(location, blacklist)); - -  sampler = make_shared<Sampler>(suffix_array, 2); -  expected_locations = {0, 1, 6, 7}; -  EXPECT_EQ(PhraseLocation(expected_locations, 2), -            sampler->Sample(location, blacklist)); - -  sampler = make_shared<Sampler>(suffix_array, 3); -  expected_locations = {0, 1, 4, 5, 6, 7}; -  EXPECT_EQ(PhraseLocation(expected_locations, 2), -            sampler->Sample(location, blacklist)); - -  sampler = make_shared<Sampler>(suffix_array, 7); -  expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; -  EXPECT_EQ(PhraseLocation(expected_locations, 2), -            sampler->Sample(location, blacklist)); -} - -} // namespace -} // namespace extractor diff --git a/extractor/sampler_test_blacklist.cc b/extractor/sampler_test_blacklist.cc deleted file mode 100644 index 3305b990..00000000 --- a/extractor/sampler_test_blacklist.cc +++ /dev/null @@ -1,102 +0,0 @@ -#include <gtest/gtest.h> - -#include <memory> - -#include "mocks/mock_suffix_array.h" -#include "mocks/mock_data_array.h" -#include "phrase_location.h" -#include "sampler.h" - -using namespace std; -using namespace ::testing; - -namespace extractor { -namespace { - -class SamplerTestBlacklist : public Test { - protected: -  virtual void SetUp() { -    source_data_array = make_shared<MockDataArray>(); -    for (int i = 0; i < 10; ++i) { -      EXPECT_CALL(*source_data_array, GetSentenceId(i)).WillRepeatedly(Return(i)); -    } -    for (int i = -10; i < 0; ++i) { -      EXPECT_CALL(*source_data_array, GetSentenceId(i)).WillRepeatedly(Return(0)); -    } -    suffix_array = make_shared<MockSuffixArray>(); -    for (int i = -10; i < 10; ++i) { -      EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i)); -    } -  } - -  shared_ptr<MockSuffixArray> suffix_array; -  shared_ptr<Sampler> sampler; -  shared_ptr<MockDataArray> source_data_array; -}; - -TEST_F(SamplerTestBlacklist, TestSuffixArrayRange) { -  PhraseLocation location(0, 10); -  unordered_set<int> blacklist; -  vector<int> expected_locations; -    -  blacklist.insert(0); -  sampler = make_shared<Sampler>(suffix_array, 1); -  expected_locations = {1}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); -  blacklist.clear(); -   -  for (int i = 0; i < 9; i++) { -    blacklist.insert(i); -  } -  sampler = make_shared<Sampler>(suffix_array, 1); -  expected_locations = {9}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); -  blacklist.clear(); - -  blacklist.insert(0); -  blacklist.insert(5); -  sampler = make_shared<Sampler>(suffix_array, 2); -  expected_locations = {1, 4}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); -  blacklist.clear(); - -  blacklist.insert(0); -  blacklist.insert(1); -  blacklist.insert(2); -  blacklist.insert(3); -  sampler = make_shared<Sampler>(suffix_array, 2); -  expected_locations = {4, 5}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); -  blacklist.clear(); - -  blacklist.insert(0); -  blacklist.insert(3); -  blacklist.insert(7); -  sampler = make_shared<Sampler>(suffix_array, 3); -  expected_locations = {1, 2, 6}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); -  blacklist.clear(); - -  blacklist.insert(0); -  blacklist.insert(3); -  blacklist.insert(5); -  blacklist.insert(8); -  sampler = make_shared<Sampler>(suffix_array, 4); -  expected_locations = {1, 2, 4, 7}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); -  blacklist.clear(); -   -  blacklist.insert(0); -  sampler = make_shared<Sampler>(suffix_array, 100); -  expected_locations = {1, 2, 3, 4, 5, 6, 7, 8, 9}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); -  blacklist.clear(); -  -  blacklist.insert(9); -  sampler = make_shared<Sampler>(suffix_array, 100); -  expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); -} - -} // namespace -} // namespace extractor diff --git a/extractor/suffix_array_sampler.cc b/extractor/suffix_array_sampler.cc new file mode 100644 index 00000000..4a4ced34 --- /dev/null +++ b/extractor/suffix_array_sampler.cc @@ -0,0 +1,40 @@ +#include "suffix_array_sampler.h" + +#include "data_array.h" +#include "phrase_location.h" +#include "suffix_array.h" + +namespace extractor { + +SuffixArrayRangeSampler::SuffixArrayRangeSampler( +    shared_ptr<SuffixArray> source_suffix_array, int max_samples) : +    BackoffSampler(source_suffix_array->GetData(), max_samples), +    source_suffix_array(source_suffix_array) {} + +SuffixArrayRangeSampler::SuffixArrayRangeSampler() {} + +int SuffixArrayRangeSampler::GetNumSubpatterns(const PhraseLocation&) const { +  return 1; +} + +int SuffixArrayRangeSampler::GetRangeLow( +    const PhraseLocation& location) const { +  return location.sa_low; +} + +int SuffixArrayRangeSampler::GetRangeHigh( +    const PhraseLocation& location) const { +  return location.sa_high; +} + +int SuffixArrayRangeSampler::GetPosition( +    const PhraseLocation&, int position) const { +  return source_suffix_array->GetSuffix(position); +} + +void SuffixArrayRangeSampler::AppendMatching( +    vector<int>& samples, int index, const PhraseLocation&) const { +  samples.push_back(source_suffix_array->GetSuffix(index)); +} + +} // namespace extractor diff --git a/extractor/suffix_array_sampler.h b/extractor/suffix_array_sampler.h new file mode 100644 index 00000000..bb3c2653 --- /dev/null +++ b/extractor/suffix_array_sampler.h @@ -0,0 +1,34 @@ +#ifndef _SUFFIX_ARRAY_SAMPLER_H_ +#define _SUFFIX_ARRAY_SAMPLER_H_ + +#include "backoff_sampler.h" + +namespace extractor { + +class SuffixArray; + +class SuffixArrayRangeSampler : public BackoffSampler { + public: +  SuffixArrayRangeSampler(shared_ptr<SuffixArray> suffix_array, +                          int max_samples); + +  SuffixArrayRangeSampler(); + + private: +  int GetNumSubpatterns(const PhraseLocation& location) const; + +  int GetRangeLow(const PhraseLocation& location) const; + +  int GetRangeHigh(const PhraseLocation& location) const; + +  int GetPosition(const PhraseLocation& location, int index) const; + +  void AppendMatching(vector<int>& samples, int index, +                      const PhraseLocation& location) const; + +  shared_ptr<SuffixArray> source_suffix_array; +}; + +} // namespace extractor + +#endif diff --git a/extractor/suffix_array_sampler_test.cc b/extractor/suffix_array_sampler_test.cc new file mode 100644 index 00000000..4b88c027 --- /dev/null +++ b/extractor/suffix_array_sampler_test.cc @@ -0,0 +1,114 @@ +#include <gtest/gtest.h> + +#include <memory> + +#include "mocks/mock_data_array.h" +#include "mocks/mock_suffix_array.h" +#include "suffix_array_sampler.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class SuffixArraySamplerTest : public Test { + protected: +  virtual void SetUp() { +    data_array = make_shared<MockDataArray>(); +    for (int i = 0; i < 10; ++i) { +      EXPECT_CALL(*data_array, GetSentenceId(i)).WillRepeatedly(Return(i)); +    } + +    suffix_array = make_shared<MockSuffixArray>(); +    EXPECT_CALL(*suffix_array, GetData()).WillRepeatedly(Return(data_array)); +    for (int i = 0; i < 10; ++i) { +      EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i)); +    } +  } + +  shared_ptr<MockDataArray> data_array; +  shared_ptr<MockSuffixArray> suffix_array; +}; + +TEST_F(SuffixArraySamplerTest, TestSample) { +  PhraseLocation location(0, 10); +  unordered_set<int> blacklisted_sentence_ids; + +  SuffixArrayRangeSampler sampler(suffix_array, 1); +  vector<int> expected_locations = {0}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); + +  sampler = SuffixArrayRangeSampler(suffix_array, 2); +  expected_locations = {0, 5}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); + +  sampler = SuffixArrayRangeSampler(suffix_array, 3); +  expected_locations = {0, 3, 7}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); + +  sampler = SuffixArrayRangeSampler(suffix_array, 4); +  expected_locations = {0, 3, 5, 8}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); + +  sampler = SuffixArrayRangeSampler(suffix_array, 100); +  expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); +} + +TEST_F(SuffixArraySamplerTest, TestBackoffSample) { +  PhraseLocation location(0, 10); + +  SuffixArrayRangeSampler sampler(suffix_array, 1); +  unordered_set<int> blacklisted_sentence_ids = {0}; +  vector<int> expected_locations = {1}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); + +  blacklisted_sentence_ids = {0, 1, 2, 3, 4, 5, 6, 7, 8}; +  expected_locations = {9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); + +  sampler = SuffixArrayRangeSampler(suffix_array, 2); +  blacklisted_sentence_ids = {0, 5}; +  expected_locations = {1, 4}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); + +  blacklisted_sentence_ids = {0, 1, 2, 3}; +  expected_locations = {4, 5}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); + +  sampler = SuffixArrayRangeSampler(suffix_array, 3); +  blacklisted_sentence_ids = {0, 3, 7}; +  expected_locations = {1, 2, 6}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); + +  sampler = SuffixArrayRangeSampler(suffix_array, 4); +  blacklisted_sentence_ids = {0, 3, 5, 8}; +  expected_locations = {1, 2, 4, 7}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); + +  sampler = SuffixArrayRangeSampler(suffix_array, 100); +  blacklisted_sentence_ids = {0}; +  expected_locations = {1, 2, 3, 4, 5, 6, 7, 8, 9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); + +  blacklisted_sentence_ids = {9}; +  expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); +} + +} +} // namespace extractor | 
