summaryrefslogtreecommitdiff
path: root/extractor/translation_table_test.cc
diff options
context:
space:
mode:
authorChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2013-04-23 19:35:18 -0400
committerChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2013-04-23 19:35:18 -0400
commit6d347f1ce078dede3da0e1498f75e357351c6543 (patch)
tree8e872b8747c530e741e55e25e9917c1bd8b32c5b /extractor/translation_table_test.cc
parentd11b76def6899790161c47a73018146311356d8b (diff)
parent5e9605b65202f4e5fc59843b197d88c4774f0ac8 (diff)
merge paul's extractor code
Diffstat (limited to 'extractor/translation_table_test.cc')
-rw-r--r--extractor/translation_table_test.cc84
1 files changed, 84 insertions, 0 deletions
diff --git a/extractor/translation_table_test.cc b/extractor/translation_table_test.cc
new file mode 100644
index 00000000..051b5715
--- /dev/null
+++ b/extractor/translation_table_test.cc
@@ -0,0 +1,84 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "mocks/mock_alignment.h"
+#include "mocks/mock_data_array.h"
+#include "translation_table.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+TEST(TranslationTableTest, TestScores) {
+ vector<string> words = {"a", "b", "c"};
+
+ vector<int> source_data = {2, 3, 2, 3, 4, 0, 2, 3, 6, 0, 2, 3, 6, 0};
+ vector<int> source_sentence_start = {0, 6, 10, 14};
+ shared_ptr<MockDataArray> source_data_array = make_shared<MockDataArray>();
+ EXPECT_CALL(*source_data_array, GetData())
+ .WillRepeatedly(ReturnRef(source_data));
+ EXPECT_CALL(*source_data_array, GetNumSentences())
+ .WillRepeatedly(Return(3));
+ for (size_t i = 0; i < source_sentence_start.size(); ++i) {
+ EXPECT_CALL(*source_data_array, GetSentenceStart(i))
+ .WillRepeatedly(Return(source_sentence_start[i]));
+ }
+ for (size_t i = 0; i < words.size(); ++i) {
+ EXPECT_CALL(*source_data_array, HasWord(words[i]))
+ .WillRepeatedly(Return(true));
+ EXPECT_CALL(*source_data_array, GetWordId(words[i]))
+ .WillRepeatedly(Return(i + 2));
+ }
+ EXPECT_CALL(*source_data_array, HasWord("d"))
+ .WillRepeatedly(Return(false));
+
+ vector<int> target_data = {2, 3, 2, 3, 4, 5, 0, 3, 6, 0, 2, 7, 0};
+ vector<int> target_sentence_start = {0, 7, 10, 13};
+ shared_ptr<MockDataArray> target_data_array = make_shared<MockDataArray>();
+ EXPECT_CALL(*target_data_array, GetData())
+ .WillRepeatedly(ReturnRef(target_data));
+ for (size_t i = 0; i < target_sentence_start.size(); ++i) {
+ EXPECT_CALL(*target_data_array, GetSentenceStart(i))
+ .WillRepeatedly(Return(target_sentence_start[i]));
+ }
+ for (size_t i = 0; i < words.size(); ++i) {
+ EXPECT_CALL(*target_data_array, HasWord(words[i]))
+ .WillRepeatedly(Return(true));
+ EXPECT_CALL(*target_data_array, GetWordId(words[i]))
+ .WillRepeatedly(Return(i + 2));
+ }
+ EXPECT_CALL(*target_data_array, HasWord("d"))
+ .WillRepeatedly(Return(false));
+
+ vector<pair<int, int> > links1 = {
+ make_pair(0, 0), make_pair(1, 1), make_pair(2, 2), make_pair(3, 3),
+ make_pair(4, 4), make_pair(4, 5)
+ };
+ vector<pair<int, int> > links2 = {make_pair(1, 0), make_pair(2, 1)};
+ vector<pair<int, int> > links3 = {make_pair(0, 0), make_pair(2, 1)};
+ shared_ptr<MockAlignment> alignment = make_shared<MockAlignment>();
+ EXPECT_CALL(*alignment, GetLinks(0)).WillRepeatedly(Return(links1));
+ EXPECT_CALL(*alignment, GetLinks(1)).WillRepeatedly(Return(links2));
+ EXPECT_CALL(*alignment, GetLinks(2)).WillRepeatedly(Return(links3));
+
+ shared_ptr<TranslationTable> table = make_shared<TranslationTable>(
+ source_data_array, target_data_array, alignment);
+
+ EXPECT_EQ(0.75, table->GetTargetGivenSourceScore("a", "a"));
+ EXPECT_EQ(0, table->GetTargetGivenSourceScore("a", "b"));
+ EXPECT_EQ(0.5, table->GetTargetGivenSourceScore("c", "c"));
+ EXPECT_EQ(-1, table->GetTargetGivenSourceScore("c", "d"));
+
+ EXPECT_EQ(1, table->GetSourceGivenTargetScore("a", "a"));
+ EXPECT_EQ(0, table->GetSourceGivenTargetScore("a", "b"));
+ EXPECT_EQ(1, table->GetSourceGivenTargetScore("c", "c"));
+ EXPECT_EQ(-1, table->GetSourceGivenTargetScore("c", "d"));
+}
+
+} // namespace
+} // namespace extractor