summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--extractor/Makefile.am3
-rw-r--r--extractor/data_array.cc4
-rw-r--r--extractor/data_array.h3
-rw-r--r--extractor/data_array_test.cc4
-rw-r--r--extractor/mocks/mock_data_array.h1
-rw-r--r--extractor/suffix_array.cc4
-rw-r--r--extractor/suffix_array_test.cc6
-rw-r--r--extractor/translation_table.cc14
-rw-r--r--extractor/translation_table_test.cc10
9 files changed, 12 insertions, 37 deletions
diff --git a/extractor/Makefile.am b/extractor/Makefile.am
index faf25d89..65a3d436 100644
--- a/extractor/Makefile.am
+++ b/extractor/Makefile.am
@@ -53,8 +53,7 @@ endif
noinst_PROGRAMS = $(RUNNABLE_TESTS)
-# TESTS = $(RUNNABLE_TESTS)
-TESTS = precomputation_test
+TESTS = $(RUNNABLE_TESTS)
alignment_test_SOURCES = alignment_test.cc
alignment_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
diff --git a/extractor/data_array.cc b/extractor/data_array.cc
index 6757cae7..ac0493fd 100644
--- a/extractor/data_array.cc
+++ b/extractor/data_array.cc
@@ -115,10 +115,6 @@ int DataArray::GetSentenceId(int position) const {
return sentence_id[position];
}
-bool DataArray::HasWord(const string& word) const {
- return word2id.count(word);
-}
-
int DataArray::GetWordId(const string& word) const {
auto result = word2id.find(word);
return result == word2id.end() ? -1 : result->second;
diff --git a/extractor/data_array.h b/extractor/data_array.h
index e9af5bd0..c5dc8a26 100644
--- a/extractor/data_array.h
+++ b/extractor/data_array.h
@@ -65,9 +65,6 @@ class DataArray {
// Returns the number of distinct words in the data array.
virtual int GetVocabularySize() const;
- // Returns whether a word has ever been observed in the data array.
- virtual bool HasWord(const string& word) const;
-
// Returns the word id for a given word or -1 if it the word has never been
// observed.
virtual int GetWordId(const string& word) const;
diff --git a/extractor/data_array_test.cc b/extractor/data_array_test.cc
index 6c329e34..b6b56561 100644
--- a/extractor/data_array_test.cc
+++ b/extractor/data_array_test.cc
@@ -58,16 +58,12 @@ TEST_F(DataArrayTest, TestGetData) {
TEST_F(DataArrayTest, TestVocabulary) {
EXPECT_EQ(9, source_data.GetVocabularySize());
- EXPECT_TRUE(source_data.HasWord("mere"));
EXPECT_EQ(4, source_data.GetWordId("mere"));
EXPECT_EQ("mere", source_data.GetWord(4));
- EXPECT_FALSE(source_data.HasWord("banane"));
EXPECT_EQ(11, target_data.GetVocabularySize());
- EXPECT_TRUE(target_data.HasWord("apples"));
EXPECT_EQ(4, target_data.GetWordId("apples"));
EXPECT_EQ("apples", target_data.GetWord(4));
- EXPECT_FALSE(target_data.HasWord("bananas"));
}
TEST_F(DataArrayTest, TestSentenceData) {
diff --git a/extractor/mocks/mock_data_array.h b/extractor/mocks/mock_data_array.h
index d39cb0c4..edc525fa 100644
--- a/extractor/mocks/mock_data_array.h
+++ b/extractor/mocks/mock_data_array.h
@@ -11,7 +11,6 @@ class MockDataArray : public DataArray {
MOCK_CONST_METHOD1(GetWordAtIndex, string(int index));
MOCK_CONST_METHOD0(GetSize, int());
MOCK_CONST_METHOD0(GetVocabularySize, int());
- MOCK_CONST_METHOD1(HasWord, bool(const string& word));
MOCK_CONST_METHOD1(GetWordId, int(const string& word));
MOCK_CONST_METHOD1(GetWord, string(int word_id));
MOCK_CONST_METHOD1(GetSentenceLength, int(int sentence_id));
diff --git a/extractor/suffix_array.cc b/extractor/suffix_array.cc
index ac230d13..4a514b12 100644
--- a/extractor/suffix_array.cc
+++ b/extractor/suffix_array.cc
@@ -187,12 +187,12 @@ shared_ptr<DataArray> SuffixArray::GetData() const {
PhraseLocation SuffixArray::Lookup(int low, int high, const string& word,
int offset) const {
- if (!data_array->HasWord(word)) {
+ int word_id = data_array->GetWordId(word);
+ if (word_id == -1) {
// Return empty phrase location.
return PhraseLocation(0, 0);
}
- int word_id = data_array->GetWordId(word);
if (offset == 0) {
return PhraseLocation(word_start[word_id], word_start[word_id + 1]);
}
diff --git a/extractor/suffix_array_test.cc b/extractor/suffix_array_test.cc
index a9fd1eab..161edbc0 100644
--- a/extractor/suffix_array_test.cc
+++ b/extractor/suffix_array_test.cc
@@ -55,22 +55,18 @@ TEST_F(SuffixArrayTest, TestLookup) {
EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i]));
}
- EXPECT_CALL(*data_array, HasWord("word1")).WillRepeatedly(Return(true));
EXPECT_CALL(*data_array, GetWordId("word1")).WillRepeatedly(Return(6));
EXPECT_EQ(PhraseLocation(11, 14), suffix_array.Lookup(0, 14, "word1", 0));
- EXPECT_CALL(*data_array, HasWord("word2")).WillRepeatedly(Return(false));
+ EXPECT_CALL(*data_array, GetWordId("word2")).WillRepeatedly(Return(-1));
EXPECT_EQ(PhraseLocation(0, 0), suffix_array.Lookup(0, 14, "word2", 0));
- EXPECT_CALL(*data_array, HasWord("word3")).WillRepeatedly(Return(true));
EXPECT_CALL(*data_array, GetWordId("word3")).WillRepeatedly(Return(4));
EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 14, "word3", 1));
- EXPECT_CALL(*data_array, HasWord("word4")).WillRepeatedly(Return(true));
EXPECT_CALL(*data_array, GetWordId("word4")).WillRepeatedly(Return(1));
EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 13, "word4", 2));
- EXPECT_CALL(*data_array, HasWord("word5")).WillRepeatedly(Return(true));
EXPECT_CALL(*data_array, GetWordId("word5")).WillRepeatedly(Return(2));
EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 13, "word5", 3));
diff --git a/extractor/translation_table.cc b/extractor/translation_table.cc
index 1b1ba112..11e29e1e 100644
--- a/extractor/translation_table.cc
+++ b/extractor/translation_table.cc
@@ -90,13 +90,12 @@ void TranslationTable::IncrementLinksCount(
double TranslationTable::GetTargetGivenSourceScore(
const string& source_word, const string& target_word) {
- if (!source_data_array->HasWord(source_word) ||
- !target_data_array->HasWord(target_word)) {
+ int source_id = source_data_array->GetWordId(source_word);
+ int target_id = target_data_array->GetWordId(target_word);
+ if (source_id == -1 || target_id == -1) {
return -1;
}
- int source_id = source_data_array->GetWordId(source_word);
- int target_id = target_data_array->GetWordId(target_word);
auto entry = make_pair(source_id, target_id);
auto it = translation_probabilities.find(entry);
if (it == translation_probabilities.end()) {
@@ -107,13 +106,12 @@ double TranslationTable::GetTargetGivenSourceScore(
double TranslationTable::GetSourceGivenTargetScore(
const string& source_word, const string& target_word) {
- if (!source_data_array->HasWord(source_word) ||
- !target_data_array->HasWord(target_word)) {
+ int source_id = source_data_array->GetWordId(source_word);
+ int target_id = target_data_array->GetWordId(target_word);
+ if (source_id == -1 || target_id == -1) {
return -1;
}
- int source_id = source_data_array->GetWordId(source_word);
- int target_id = target_data_array->GetWordId(target_word);
auto entry = make_pair(source_id, target_id);
auto it = translation_probabilities.find(entry);
if (it == translation_probabilities.end()) {
diff --git a/extractor/translation_table_test.cc b/extractor/translation_table_test.cc
index 72551a12..3cfc0011 100644
--- a/extractor/translation_table_test.cc
+++ b/extractor/translation_table_test.cc
@@ -36,13 +36,10 @@ class TranslationTableTest : public Test {
.WillRepeatedly(Return(source_sentence_start[i]));
}
for (size_t i = 0; i < words.size(); ++i) {
- EXPECT_CALL(*source_data_array, HasWord(words[i]))
- .WillRepeatedly(Return(true));
EXPECT_CALL(*source_data_array, GetWordId(words[i]))
.WillRepeatedly(Return(i + 2));
}
- EXPECT_CALL(*source_data_array, HasWord("d"))
- .WillRepeatedly(Return(false));
+ EXPECT_CALL(*source_data_array, GetWordId("d")).WillRepeatedly(Return(-1));
vector<int> target_data = {2, 3, 2, 3, 4, 5, 0, 3, 6, 0, 2, 7, 0};
vector<int> target_sentence_start = {0, 7, 10, 13};
@@ -54,13 +51,10 @@ class TranslationTableTest : public Test {
.WillRepeatedly(Return(target_sentence_start[i]));
}
for (size_t i = 0; i < words.size(); ++i) {
- EXPECT_CALL(*target_data_array, HasWord(words[i]))
- .WillRepeatedly(Return(true));
EXPECT_CALL(*target_data_array, GetWordId(words[i]))
.WillRepeatedly(Return(i + 2));
}
- EXPECT_CALL(*target_data_array, HasWord("d"))
- .WillRepeatedly(Return(false));
+ EXPECT_CALL(*target_data_array, GetWordId("d")).WillRepeatedly(Return(-1));
vector<pair<int, int>> links1 = {
make_pair(0, 0), make_pair(1, 1), make_pair(2, 2), make_pair(3, 3),