diff options
author | Paul Baltescu <pauldb89@gmail.com> | 2013-06-25 15:13:30 +0100 |
---|---|---|
committer | Paul Baltescu <pauldb89@gmail.com> | 2013-06-25 15:13:30 +0100 |
commit | 9a0a9582d38315fd83628112144077b35b5f1367 (patch) | |
tree | 27267f38981291742665f08e64204eb9b42671ef /extractor | |
parent | 23e89686849d290e8b64875a0bdf77cbdb70d2df (diff) |
Reduce memory used by precomputation.
Diffstat (limited to 'extractor')
-rw-r--r-- | extractor/fast_intersector.cc | 11 | ||||
-rw-r--r-- | extractor/fast_intersector_test.cc | 8 | ||||
-rw-r--r-- | extractor/mocks/mock_precomputation.h | 2 | ||||
-rw-r--r-- | extractor/precomputation.cc | 138 | ||||
-rw-r--r-- | extractor/precomputation.h | 49 | ||||
-rw-r--r-- | extractor/precomputation_test.cc | 143 |
6 files changed, 199 insertions, 152 deletions
diff --git a/extractor/fast_intersector.cc b/extractor/fast_intersector.cc index a8591a72..5360c1da 100644 --- a/extractor/fast_intersector.cc +++ b/extractor/fast_intersector.cc @@ -20,10 +20,13 @@ FastIntersector::FastIntersector(shared_ptr<SuffixArray> suffix_array, vocabulary(vocabulary), max_rule_span(max_rule_span), min_gap_size(min_gap_size) { - Index precomputed_collocations = precomputation->GetCollocations(); - for (pair<vector<int>, vector<int>> entry: precomputed_collocations) { - vector<int> phrase = ConvertPhrase(entry.first); - collocations[phrase] = entry.second; + auto precomputed_collocations = precomputation->GetCollocations(); + for (auto item: precomputed_collocations) { + vector<int> phrase = ConvertPhrase(item.first); + vector<int> location = item.second; + vector<int>& phrase_collocations = collocations[phrase]; + phrase_collocations.insert(phrase_collocations.end(), location.begin(), + location.end()); } } diff --git a/extractor/fast_intersector_test.cc b/extractor/fast_intersector_test.cc index 76c3aaea..2e618b63 100644 --- a/extractor/fast_intersector_test.cc +++ b/extractor/fast_intersector_test.cc @@ -60,14 +60,14 @@ class FastIntersectorTest : public Test { precomputation = make_shared<MockPrecomputation>(); EXPECT_CALL(*precomputation, GetCollocations()) - .WillRepeatedly(ReturnRef(collocations)); + .WillRepeatedly(Return(collocations)); phrase_builder = make_shared<PhraseBuilder>(vocabulary); intersector = make_shared<FastIntersector>(suffix_array, precomputation, vocabulary, 15, 1); } - Index collocations; + Collocations collocations; shared_ptr<MockDataArray> data_array; shared_ptr<MockSuffixArray> suffix_array; shared_ptr<MockPrecomputation> precomputation; @@ -82,9 +82,9 @@ TEST_F(FastIntersectorTest, TestCachedCollocation) { Phrase phrase = phrase_builder->Build(symbols); PhraseLocation prefix_location(15, 16), suffix_location(16, 17); - collocations[symbols] = expected_location; + collocations.push_back(make_pair(symbols, expected_location)); EXPECT_CALL(*precomputation, GetCollocations()) - .WillRepeatedly(ReturnRef(collocations)); + .WillRepeatedly(Return(collocations)); intersector = make_shared<FastIntersector>(suffix_array, precomputation, vocabulary, 15, 1); diff --git a/extractor/mocks/mock_precomputation.h b/extractor/mocks/mock_precomputation.h index 8753343e..86f4ce27 100644 --- a/extractor/mocks/mock_precomputation.h +++ b/extractor/mocks/mock_precomputation.h @@ -6,7 +6,7 @@ namespace extractor { class MockPrecomputation : public Precomputation { public: - MOCK_CONST_METHOD0(GetCollocations, const Index&()); + MOCK_CONST_METHOD0(GetCollocations, Collocations()); }; } // namespace extractor diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc index 3b8aed69..37dbf7b7 100644 --- a/extractor/precomputation.cc +++ b/extractor/precomputation.cc @@ -14,63 +14,65 @@ int Precomputation::FIRST_NONTERMINAL = -1; int Precomputation::SECOND_NONTERMINAL = -2; Precomputation::Precomputation( - shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns, - int num_super_frequent_patterns, int max_rule_span, + shared_ptr<SuffixArray> suffix_array, int num_frequent_phrases, + int num_super_frequent_phrases, int max_rule_span, int max_rule_symbols, int min_gap_size, int max_frequent_phrase_len, int min_frequency) { vector<int> data = suffix_array->GetData()->GetData(); - vector<vector<int>> frequent_patterns = FindMostFrequentPatterns( - suffix_array, data, num_frequent_patterns, max_frequent_phrase_len, + vector<vector<int>> frequent_phrases = FindMostFrequentPhrases( + suffix_array, data, num_frequent_phrases, max_frequent_phrase_len, min_frequency); // Construct sets containing the frequent and superfrequent contiguous // collocations. - unordered_set<vector<int>, VectorHash> frequent_patterns_set; - unordered_set<vector<int>, VectorHash> super_frequent_patterns_set; - for (size_t i = 0; i < frequent_patterns.size(); ++i) { - frequent_patterns_set.insert(frequent_patterns[i]); - if (i < num_super_frequent_patterns) { - super_frequent_patterns_set.insert(frequent_patterns[i]); + unordered_set<vector<int>, VectorHash> frequent_phrases_set; + unordered_set<vector<int>, VectorHash> super_frequent_phrases_set; + for (size_t i = 0; i < frequent_phrases.size(); ++i) { + frequent_phrases_set.insert(frequent_phrases[i]); + if (i < num_super_frequent_phrases) { + super_frequent_phrases_set.insert(frequent_phrases[i]); } } - vector<tuple<int, int, int>> matchings; + vector<tuple<int, int, int>> locations; for (size_t i = 0; i < data.size(); ++i) { - // If the sentence is over, add all the discontiguous frequent patterns to - // the index. + // If the sentence is over, add all the discontiguous frequent phrases to + // the list. if (data[i] == DataArray::END_OF_LINE) { - AddCollocations(matchings, data, max_rule_span, min_gap_size, + AddCollocations(locations, data, max_rule_span, min_gap_size, max_rule_symbols); - matchings.clear(); + locations.clear(); continue; } - vector<int> pattern; - // Find all the contiguous frequent patterns starting at position i. + vector<int> phrase; + // Find all the contiguous frequent phrases starting at position i. for (int j = 1; j <= max_frequent_phrase_len && i + j <= data.size(); ++j) { - pattern.push_back(data[i + j - 1]); - if (frequent_patterns_set.count(pattern)) { - int is_super_frequent = super_frequent_patterns_set.count(pattern); - matchings.push_back(make_tuple(i, j, is_super_frequent)); + phrase.push_back(data[i + j - 1]); + if (frequent_phrases_set.count(phrase)) { + int is_super_frequent = super_frequent_phrases_set.count(phrase); + locations.push_back(make_tuple(i, j, is_super_frequent)); } else { - // If the current pattern is not frequent, any longer pattern having the - // current pattern as prefix will not be frequent. + // If the current phrase is not frequent, any longer phrase having the + // current phrase as prefix will not be frequent. break; } } } + + collocations.shrink_to_fit(); } Precomputation::Precomputation() {} Precomputation::~Precomputation() {} -vector<vector<int>> Precomputation::FindMostFrequentPatterns( +vector<vector<int>> Precomputation::FindMostFrequentPhrases( shared_ptr<SuffixArray> suffix_array, const vector<int>& data, - int num_frequent_patterns, int max_frequent_phrase_len, int min_frequency) { + int num_frequent_phrases, int max_frequent_phrase_len, int min_frequency) { vector<int> lcp = suffix_array->BuildLCPArray(); vector<int> run_start(max_frequent_phrase_len); - // Find all the patterns occurring at least min_frequency times. + // Find all the phrases occurring at least min_frequency times. priority_queue<pair<int, pair<int, int>>> heap; for (size_t i = 1; i < lcp.size(); ++i) { for (int len = lcp[i]; len < max_frequent_phrase_len; ++len) { @@ -83,34 +85,34 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns( } } - // Extract the most frequent patterns. - vector<vector<int>> frequent_patterns; - while (frequent_patterns.size() < num_frequent_patterns && !heap.empty()) { + // Extract the most frequent phrases. + vector<vector<int>> frequent_phrases; + while (frequent_phrases.size() < num_frequent_phrases && !heap.empty()) { int start = heap.top().second.first; int len = heap.top().second.second; heap.pop(); - vector<int> pattern(data.begin() + start, data.begin() + start + len); - if (find(pattern.begin(), pattern.end(), DataArray::END_OF_LINE) == - pattern.end()) { - frequent_patterns.push_back(pattern); + vector<int> phrase(data.begin() + start, data.begin() + start + len); + if (find(phrase.begin(), phrase.end(), DataArray::END_OF_LINE) == + phrase.end()) { + frequent_phrases.push_back(phrase); } } - return frequent_patterns; + return frequent_phrases; } void Precomputation::AddCollocations( - const vector<tuple<int, int, int>>& matchings, const vector<int>& data, + const vector<tuple<int, int, int>>& locations, const vector<int>& data, int max_rule_span, int min_gap_size, int max_rule_symbols) { - // Select the leftmost subpattern. - for (size_t i = 0; i < matchings.size(); ++i) { + // Select the leftmost subphrase. + for (size_t i = 0; i < locations.size(); ++i) { int start1, size1, is_super1; - tie(start1, size1, is_super1) = matchings[i]; + tie(start1, size1, is_super1) = locations[i]; - // Select the second (middle) subpattern - for (size_t j = i + 1; j < matchings.size(); ++j) { + // Select the second (middle) subphrase + for (size_t j = i + 1; j < locations.size(); ++j) { int start2, size2, is_super2; - tie(start2, size2, is_super2) = matchings[j]; + tie(start2, size2, is_super2) = locations[j]; if (start2 - start1 >= max_rule_span) { break; } @@ -118,20 +120,21 @@ void Precomputation::AddCollocations( if (start2 - start1 - size1 >= min_gap_size && start2 + size2 - start1 <= max_rule_span && size1 + size2 + 1 <= max_rule_symbols) { - vector<int> pattern(data.begin() + start1, + vector<int> collocation(data.begin() + start1, data.begin() + start1 + size1); - pattern.push_back(Precomputation::FIRST_NONTERMINAL); - pattern.insert(pattern.end(), data.begin() + start2, + collocation.push_back(Precomputation::FIRST_NONTERMINAL); + collocation.insert(collocation.end(), data.begin() + start2, data.begin() + start2 + size2); - AddStartPositions(collocations[pattern], start1, start2); + + AddCollocation(collocation, GetLocation(start1, start2)); // Try extending the binary collocation to a ternary collocation. if (is_super2) { - pattern.push_back(Precomputation::SECOND_NONTERMINAL); - // Select the rightmost subpattern. - for (size_t k = j + 1; k < matchings.size(); ++k) { + collocation.push_back(Precomputation::SECOND_NONTERMINAL); + // Select the rightmost subphrase. + for (size_t k = j + 1; k < locations.size(); ++k) { int start3, size3, is_super3; - tie(start3, size3, is_super3) = matchings[k]; + tie(start3, size3, is_super3) = locations[k]; if (start3 - start1 >= max_rule_span) { break; } @@ -140,10 +143,12 @@ void Precomputation::AddCollocations( && start3 + size3 - start1 <= max_rule_span && size1 + size2 + size3 + 2 <= max_rule_symbols && (is_super1 || is_super3)) { - pattern.insert(pattern.end(), data.begin() + start3, + collocation.insert(collocation.end(), data.begin() + start3, data.begin() + start3 + size3); - AddStartPositions(collocations[pattern], start1, start2, start3); - pattern.erase(pattern.end() - size3); + + AddCollocation(collocation, GetLocation(start1, start2, start3)); + + collocation.erase(collocation.end() - size3); } } } @@ -152,20 +157,29 @@ void Precomputation::AddCollocations( } } -void Precomputation::AddStartPositions( - vector<int>& positions, int pos1, int pos2) { - positions.push_back(pos1); - positions.push_back(pos2); +vector<int> Precomputation::GetLocation(int pos1, int pos2) { + vector<int> location; + location.push_back(pos1); + location.push_back(pos2); + return location; +} + +vector<int> Precomputation::GetLocation(int pos1, int pos2, int pos3) { + vector<int> location; + location.push_back(pos1); + location.push_back(pos2); + location.push_back(pos3); + return location; } -void Precomputation::AddStartPositions( - vector<int>& positions, int pos1, int pos2, int pos3) { - positions.push_back(pos1); - positions.push_back(pos2); - positions.push_back(pos3); +void Precomputation::AddCollocation(vector<int> collocation, + vector<int> location) { + collocation.shrink_to_fit(); + location.shrink_to_fit(); + collocations.push_back(make_pair(collocation, location)); } -const Index& Precomputation::GetCollocations() const { +Collocations Precomputation::GetCollocations() const { return collocations; } diff --git a/extractor/precomputation.h b/extractor/precomputation.h index 9f0c9424..0a06349b 100644 --- a/extractor/precomputation.h +++ b/extractor/precomputation.h @@ -19,16 +19,18 @@ using namespace std; namespace extractor { typedef boost::hash<vector<int>> VectorHash; -typedef unordered_map<vector<int>, vector<int>, VectorHash> Index; +typedef vector<pair<vector<int>, vector<int>>> Collocations; class SuffixArray; /** - * Data structure wrapping an index with all the occurrences of the most - * frequent discontiguous collocations in the source data. + * Data structure containing all the data needed for constructing an index with + * all the occurrences of the most frequent discontiguous collocations in the + * source data. * - * Let a, b, c be contiguous collocations. The index will contain an entry for - * every collocation of the form: + * Let a, b, c be contiguous phrases. The data structure will contain the + * locations in the source data where every collocation of the following forms + * occurs: * - aXb, where a and b are frequent * - aXbXc, where a and b are super-frequent and c is frequent or * b and c are super-frequent and a is frequent. @@ -37,8 +39,8 @@ class Precomputation { public: // Constructs the index using the suffix array. Precomputation( - shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns, - int num_super_frequent_patterns, int max_rule_span, + shared_ptr<SuffixArray> suffix_array, int num_frequent_phrases, + int num_super_frequent_phrases, int max_rule_span, int max_rule_symbols, int min_gap_size, int max_frequent_phrase_len, int min_frequency); @@ -47,8 +49,9 @@ class Precomputation { virtual ~Precomputation(); - // Returns a reference to the index. - virtual const Index& GetCollocations() const; + // Returns the list of the locations of the most frequent collocations in the + // source data. + virtual Collocations GetCollocations() const; bool operator==(const Precomputation& other) const; @@ -57,23 +60,29 @@ class Precomputation { private: // Finds the most frequent contiguous collocations. - vector<vector<int>> FindMostFrequentPatterns( + vector<vector<int>> FindMostFrequentPhrases( shared_ptr<SuffixArray> suffix_array, const vector<int>& data, - int num_frequent_patterns, int max_frequent_phrase_len, + int num_frequent_phrases, int max_frequent_phrase_len, int min_frequency); // Given the locations of the frequent contiguous collocations in a sentence, // it adds new entries to the index for each discontiguous collocation // matching the criteria specified in the class description. - void AddCollocations( - const vector<std::tuple<int, int, int>>& matchings, const vector<int>& data, - int max_rule_span, int min_gap_size, int max_rule_symbols); + void AddCollocations(const vector<std::tuple<int, int, int>>& locations, + const vector<int>& data, int max_rule_span, + int min_gap_size, int max_rule_symbols); - // Adds an occurrence of a binary collocation. - void AddStartPositions(vector<int>& positions, int pos1, int pos2); + // Creates a vector representation for the location of a binary collocation + // containing the starting points of each subpattern. + vector<int> GetLocation(int pos1, int pos2); - // Adds an occurrence of a ternary collocation. - void AddStartPositions(vector<int>& positions, int pos1, int pos2, int pos3); + // Creates a vector representation for the location of a ternary collocation + // containing the starting points of each subpattern. + vector<int> GetLocation(int pos1, int pos2, int pos3); + + // Appends a collocation to the list of collocations after shrinking the + // vectors to avoid unnecessary memory usage. + void AddCollocation(vector<int> collocation, vector<int> location); friend class boost::serialization::access; @@ -91,13 +100,13 @@ class Precomputation { for (size_t i = 0; i < num_entries; ++i) { pair<vector<int>, vector<int>> entry; ar >> entry; - collocations.insert(entry); + collocations.push_back(entry); } } BOOST_SERIALIZATION_SPLIT_MEMBER(); - Index collocations; + Collocations collocations; }; } // namespace extractor diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc index e81ece5d..c6e457fd 100644 --- a/extractor/precomputation_test.cc +++ b/extractor/precomputation_test.cc @@ -38,6 +38,23 @@ class PrecomputationTest : public Test { precomputation = Precomputation(suffix_array, 3, 3, 10, 5, 1, 4, 2); } + void CheckCollocation(const Collocations& collocations, + const vector<int>& collocation, + const vector<vector<int>>& locations) { + for (auto location: locations) { + auto item = make_pair(collocation, location); + EXPECT_FALSE(find(collocations.begin(), collocations.end(), item) == + collocations.end()); + } + } + + void CheckIllegalCollocation(const Collocations& collocations, + const vector<int>& collocation) { + for (auto item: collocations) { + EXPECT_FALSE(collocation == item.first); + } + } + vector<int> data; shared_ptr<MockDataArray> data_array; shared_ptr<MockSuffixArray> suffix_array; @@ -45,67 +62,71 @@ class PrecomputationTest : public Test { }; TEST_F(PrecomputationTest, TestCollocations) { - Index collocations = precomputation.GetCollocations(); - - vector<int> key = {2, 3, -1, 2}; - vector<int> expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11}; - EXPECT_EQ(expected_value, collocations[key]); - key = {2, 3, -1, 2, 3}; - expected_value = {1, 5, 1, 8, 5, 8}; - EXPECT_EQ(expected_value, collocations[key]); - key = {2, 3, -1, 3}; - expected_value = {1, 6, 1, 9, 5, 9}; - EXPECT_EQ(expected_value, collocations[key]); - key = {3, -1, 2}; - expected_value = {2, 5, 2, 8, 2, 11, 6, 8, 6, 11, 9, 11}; - EXPECT_EQ(expected_value, collocations[key]); - key = {3, -1, 3}; - expected_value = {2, 6, 2, 9, 6, 9}; - EXPECT_EQ(expected_value, collocations[key]); - key = {3, -1, 2, 3}; - expected_value = {2, 5, 2, 8, 6, 8}; - EXPECT_EQ(expected_value, collocations[key]); - key = {2, -1, 2}; - expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11}; - EXPECT_EQ(expected_value, collocations[key]); - key = {2, -1, 2, 3}; - expected_value = {1, 5, 1, 8, 5, 8}; - EXPECT_EQ(expected_value, collocations[key]); - key = {2, -1, 3}; - expected_value = {1, 6, 1, 9, 5, 9}; - EXPECT_EQ(expected_value, collocations[key]); - - key = {2, -1, 2, -2, 2}; - expected_value = {1, 5, 8, 5, 8, 11}; - EXPECT_EQ(expected_value, collocations[key]); - key = {2, -1, 2, -2, 3}; - expected_value = {1, 5, 9}; - EXPECT_EQ(expected_value, collocations[key]); - key = {2, -1, 3, -2, 2}; - expected_value = {1, 6, 8, 5, 9, 11}; - EXPECT_EQ(expected_value, collocations[key]); - key = {2, -1, 3, -2, 3}; - expected_value = {1, 6, 9}; - EXPECT_EQ(expected_value, collocations[key]); - key = {3, -1, 2, -2, 2}; - expected_value = {2, 5, 8, 2, 5, 11, 2, 8, 11, 6, 8, 11}; - EXPECT_EQ(expected_value, collocations[key]); - key = {3, -1, 2, -2, 3}; - expected_value = {2, 5, 9}; - EXPECT_EQ(expected_value, collocations[key]); - key = {3, -1, 3, -2, 2}; - expected_value = {2, 6, 8, 2, 6, 11, 2, 9, 11, 6, 9, 11}; - EXPECT_EQ(expected_value, collocations[key]); - key = {3, -1, 3, -2, 3}; - expected_value = {2, 6, 9}; - EXPECT_EQ(expected_value, collocations[key]); - - // Exceeds max_rule_symbols. - key = {2, -1, 2, -2, 2, 3}; - EXPECT_EQ(0, collocations.count(key)); - // Contains non frequent pattern. - key = {2, -1, 5}; - EXPECT_EQ(0, collocations.count(key)); + Collocations collocations = precomputation.GetCollocations(); + + EXPECT_EQ(50, collocations.size()); + + vector<int> collocation = {2, 3, -1, 2}; + vector<vector<int>> locations = {{1, 5}, {1, 8}, {5, 8}, {5, 11}, {8, 11}}; + CheckCollocation(collocations, collocation, locations); + + collocation = {2, 3, -1, 2, 3}; + locations = {{1, 5}, {1, 8}, {5, 8}}; + CheckCollocation(collocations, collocation, locations); + + collocation = {2, 3, -1, 3}; + locations = {{1, 6}, {1, 9}, {5, 9}}; + CheckCollocation(collocations, collocation, locations); + collocation = {3, -1, 2}; + locations = {{2, 5}, {2, 8}, {2, 11}, {6, 8}, {6, 11}, {9, 11}}; + CheckCollocation(collocations, collocation, locations); + collocation = {3, -1, 3}; + locations = {{2, 6}, {2, 9}, {6, 9}}; + CheckCollocation(collocations, collocation, locations); + collocation = {3, -1, 2, 3}; + locations = {{2, 5}, {2, 8}, {6, 8}}; + CheckCollocation(collocations, collocation, locations); + collocation = {2, -1, 2}; + locations = {{1, 5}, {1, 8}, {5, 8}, {5, 11}, {8, 11}}; + CheckCollocation(collocations, collocation, locations); + collocation = {2, -1, 2, 3}; + locations = {{1, 5}, {1, 8}, {5, 8}}; + CheckCollocation(collocations, collocation, locations); + collocation = {2, -1, 3}; + locations = {{1, 6}, {1, 9}, {5, 9}}; + CheckCollocation(collocations, collocation, locations); + + collocation = {2, -1, 2, -2, 2}; + locations = {{1, 5, 8}, {5, 8, 11}}; + CheckCollocation(collocations, collocation, locations); + collocation = {2, -1, 2, -2, 3}; + locations = {{1, 5, 9}}; + CheckCollocation(collocations, collocation, locations); + collocation = {2, -1, 3, -2, 2}; + locations = {{1, 6, 8}, {5, 9, 11}}; + CheckCollocation(collocations, collocation, locations); + collocation = {2, -1, 3, -2, 3}; + locations = {{1, 6, 9}}; + CheckCollocation(collocations, collocation, locations); + collocation = {3, -1, 2, -2, 2}; + locations = {{2, 5, 8}, {2, 5, 11}, {2, 8, 11}, {6, 8, 11}}; + CheckCollocation(collocations, collocation, locations); + collocation = {3, -1, 2, -2, 3}; + locations = {{2, 5, 9}}; + CheckCollocation(collocations, collocation, locations); + collocation = {3, -1, 3, -2, 2}; + locations = {{2, 6, 8}, {2, 6, 11}, {2, 9, 11}, {6, 9, 11}}; + CheckCollocation(collocations, collocation, locations); + collocation = {3, -1, 3, -2, 3}; + locations = {{2, 6, 9}}; + CheckCollocation(collocations, collocation, locations); + + // Collocation exceeds max_rule_symbols. + collocation = {2, -1, 2, -2, 2, 3}; + CheckIllegalCollocation(collocations, collocation); + // Collocation contains non frequent pattern. + collocation = {2, -1, 5}; + CheckIllegalCollocation(collocations, collocation); } TEST_F(PrecomputationTest, TestSerialization) { |