summaryrefslogtreecommitdiff
path: root/extractor
diff options
context:
space:
mode:
Diffstat (limited to 'extractor')
-rw-r--r--extractor/data_array.cc12
-rw-r--r--extractor/data_array.h8
-rw-r--r--extractor/data_array_test.cc12
-rw-r--r--extractor/mocks/mock_data_array.h2
-rw-r--r--extractor/precomputation.cc3
-rw-r--r--extractor/precomputation_test.cc2
6 files changed, 37 insertions, 2 deletions
diff --git a/extractor/data_array.cc b/extractor/data_array.cc
index ac0493fd..9612aa8a 100644
--- a/extractor/data_array.cc
+++ b/extractor/data_array.cc
@@ -90,6 +90,18 @@ string DataArray::GetWordAtIndex(int index) const {
return id2word[data[index]];
}
+vector<int> DataArray::GetWordIds(int index, int size) const {
+ return vector<int>(data.begin() + index, data.begin() + index + size);
+}
+
+vector<string> DataArray::GetWords(int start_index, int size) const {
+ vector<string> words;
+ for (int word_id: GetWordIds(start_index, size)) {
+ words.push_back(id2word[word_id]);
+ }
+ return words;
+}
+
int DataArray::GetSize() const {
return data.size();
}
diff --git a/extractor/data_array.h b/extractor/data_array.h
index c5dc8a26..b96901d1 100644
--- a/extractor/data_array.h
+++ b/extractor/data_array.h
@@ -59,6 +59,14 @@ class DataArray {
// Returns the original word at the specified position.
virtual string GetWordAtIndex(int index) const;
+ // Returns the substring of word ids starting at the specified position and
+ // having the specified length.
+ virtual vector<int> GetWordIds(int start_index, int size) const;
+
+ // Returns the substring of words starting at the specified position and
+ // having the specified length.
+ virtual vector<string> GetWords(int start_index, int size) const;
+
// Returns the size of the data array.
virtual int GetSize() const;
diff --git a/extractor/data_array_test.cc b/extractor/data_array_test.cc
index b6b56561..99f79d91 100644
--- a/extractor/data_array_test.cc
+++ b/extractor/data_array_test.cc
@@ -56,6 +56,18 @@ TEST_F(DataArrayTest, TestGetData) {
}
}
+TEST_F(DataArrayTest, TestSubstrings) {
+ vector<int> expected_word_ids = {3, 4, 5};
+ vector<string> expected_words = {"are", "mere", "."};
+ EXPECT_EQ(expected_word_ids, source_data.GetWordIds(1, 3));
+ EXPECT_EQ(expected_words, source_data.GetWords(1, 3));
+
+ expected_word_ids = {7, 8};
+ expected_words = {"a", "lot"};
+ EXPECT_EQ(expected_word_ids, target_data.GetWordIds(7, 2));
+ EXPECT_EQ(expected_words, target_data.GetWords(7, 2));
+}
+
TEST_F(DataArrayTest, TestVocabulary) {
EXPECT_EQ(9, source_data.GetVocabularySize());
EXPECT_EQ(4, source_data.GetWordId("mere"));
diff --git a/extractor/mocks/mock_data_array.h b/extractor/mocks/mock_data_array.h
index edc525fa..98e711d2 100644
--- a/extractor/mocks/mock_data_array.h
+++ b/extractor/mocks/mock_data_array.h
@@ -9,6 +9,8 @@ class MockDataArray : public DataArray {
MOCK_CONST_METHOD0(GetData, vector<int>());
MOCK_CONST_METHOD1(AtIndex, int(int index));
MOCK_CONST_METHOD1(GetWordAtIndex, string(int index));
+ MOCK_CONST_METHOD2(GetWordIds, vector<int>(int start_index, int size));
+ MOCK_CONST_METHOD2(GetWords, vector<string>(int start_index, int size));
MOCK_CONST_METHOD0(GetSize, int());
MOCK_CONST_METHOD0(GetVocabularySize, int());
MOCK_CONST_METHOD1(GetWordId, int(const string& word));
diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc
index 3e58e2a9..b79daae3 100644
--- a/extractor/precomputation.cc
+++ b/extractor/precomputation.cc
@@ -91,6 +91,7 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns(
}
}
+ shared_ptr<DataArray> data_array = suffix_array->GetData();
// Extract the most frequent patterns.
vector<vector<int>> frequent_patterns;
while (frequent_patterns.size() < num_frequent_patterns && !heap.empty()) {
@@ -98,7 +99,7 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns(
int len = heap.top().second.second;
heap.pop();
- vector<int> pattern(data.begin() + start, data.begin() + start + len);
+ vector<int> pattern = data_array->GetWordIds(start, len);
if (find(pattern.begin(), pattern.end(), DataArray::END_OF_LINE) ==
pattern.end()) {
frequent_patterns.push_back(pattern);
diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc
index 3a98ce05..d5f5ef63 100644
--- a/extractor/precomputation_test.cc
+++ b/extractor/precomputation_test.cc
@@ -94,7 +94,7 @@ TEST_F(PrecomputationTest, TestCollocations) {
EXPECT_TRUE(precomputation.Contains(key));
EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
- key = {2, -1, 2, -2, 2};
+ key = {2, -1, 2, -1, 2};
expected_value = {1, 5, 8, 5, 8, 11};
EXPECT_TRUE(precomputation.Contains(key));
EXPECT_EQ(expected_value, precomputation.GetCollocations(key));