diff options
-rw-r--r-- | extractor/data_array.cc | 12 | ||||
-rw-r--r-- | extractor/data_array.h | 8 | ||||
-rw-r--r-- | extractor/data_array_test.cc | 12 | ||||
-rw-r--r-- | extractor/mocks/mock_data_array.h | 2 | ||||
-rw-r--r-- | extractor/precomputation.cc | 3 | ||||
-rw-r--r-- | extractor/precomputation_test.cc | 2 |
6 files changed, 37 insertions, 2 deletions
diff --git a/extractor/data_array.cc b/extractor/data_array.cc index ac0493fd..9612aa8a 100644 --- a/extractor/data_array.cc +++ b/extractor/data_array.cc @@ -90,6 +90,18 @@ string DataArray::GetWordAtIndex(int index) const { return id2word[data[index]]; } +vector<int> DataArray::GetWordIds(int index, int size) const { + return vector<int>(data.begin() + index, data.begin() + index + size); +} + +vector<string> DataArray::GetWords(int start_index, int size) const { + vector<string> words; + for (int word_id: GetWordIds(start_index, size)) { + words.push_back(id2word[word_id]); + } + return words; +} + int DataArray::GetSize() const { return data.size(); } diff --git a/extractor/data_array.h b/extractor/data_array.h index c5dc8a26..b96901d1 100644 --- a/extractor/data_array.h +++ b/extractor/data_array.h @@ -59,6 +59,14 @@ class DataArray { // Returns the original word at the specified position. virtual string GetWordAtIndex(int index) const; + // Returns the substring of word ids starting at the specified position and + // having the specified length. + virtual vector<int> GetWordIds(int start_index, int size) const; + + // Returns the substring of words starting at the specified position and + // having the specified length. + virtual vector<string> GetWords(int start_index, int size) const; + // Returns the size of the data array. virtual int GetSize() const; diff --git a/extractor/data_array_test.cc b/extractor/data_array_test.cc index b6b56561..99f79d91 100644 --- a/extractor/data_array_test.cc +++ b/extractor/data_array_test.cc @@ -56,6 +56,18 @@ TEST_F(DataArrayTest, TestGetData) { } } +TEST_F(DataArrayTest, TestSubstrings) { + vector<int> expected_word_ids = {3, 4, 5}; + vector<string> expected_words = {"are", "mere", "."}; + EXPECT_EQ(expected_word_ids, source_data.GetWordIds(1, 3)); + EXPECT_EQ(expected_words, source_data.GetWords(1, 3)); + + expected_word_ids = {7, 8}; + expected_words = {"a", "lot"}; + EXPECT_EQ(expected_word_ids, target_data.GetWordIds(7, 2)); + EXPECT_EQ(expected_words, target_data.GetWords(7, 2)); +} + TEST_F(DataArrayTest, TestVocabulary) { EXPECT_EQ(9, source_data.GetVocabularySize()); EXPECT_EQ(4, source_data.GetWordId("mere")); diff --git a/extractor/mocks/mock_data_array.h b/extractor/mocks/mock_data_array.h index edc525fa..98e711d2 100644 --- a/extractor/mocks/mock_data_array.h +++ b/extractor/mocks/mock_data_array.h @@ -9,6 +9,8 @@ class MockDataArray : public DataArray { MOCK_CONST_METHOD0(GetData, vector<int>()); MOCK_CONST_METHOD1(AtIndex, int(int index)); MOCK_CONST_METHOD1(GetWordAtIndex, string(int index)); + MOCK_CONST_METHOD2(GetWordIds, vector<int>(int start_index, int size)); + MOCK_CONST_METHOD2(GetWords, vector<string>(int start_index, int size)); MOCK_CONST_METHOD0(GetSize, int()); MOCK_CONST_METHOD0(GetVocabularySize, int()); MOCK_CONST_METHOD1(GetWordId, int(const string& word)); diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc index 3e58e2a9..b79daae3 100644 --- a/extractor/precomputation.cc +++ b/extractor/precomputation.cc @@ -91,6 +91,7 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns( } } + shared_ptr<DataArray> data_array = suffix_array->GetData(); // Extract the most frequent patterns. vector<vector<int>> frequent_patterns; while (frequent_patterns.size() < num_frequent_patterns && !heap.empty()) { @@ -98,7 +99,7 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns( int len = heap.top().second.second; heap.pop(); - vector<int> pattern(data.begin() + start, data.begin() + start + len); + vector<int> pattern = data_array->GetWordIds(start, len); if (find(pattern.begin(), pattern.end(), DataArray::END_OF_LINE) == pattern.end()) { frequent_patterns.push_back(pattern); diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc index 3a98ce05..d5f5ef63 100644 --- a/extractor/precomputation_test.cc +++ b/extractor/precomputation_test.cc @@ -94,7 +94,7 @@ TEST_F(PrecomputationTest, TestCollocations) { EXPECT_TRUE(precomputation.Contains(key)); EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); - key = {2, -1, 2, -2, 2}; + key = {2, -1, 2, -1, 2}; expected_value = {1, 5, 8, 5, 8, 11}; EXPECT_TRUE(precomputation.Contains(key)); EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); |