summaryrefslogtreecommitdiff
path: root/extractor/target_phrase_extractor_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'extractor/target_phrase_extractor_test.cc')
-rw-r--r--extractor/target_phrase_extractor_test.cc116
1 files changed, 116 insertions, 0 deletions
diff --git a/extractor/target_phrase_extractor_test.cc b/extractor/target_phrase_extractor_test.cc
new file mode 100644
index 00000000..7394f4d9
--- /dev/null
+++ b/extractor/target_phrase_extractor_test.cc
@@ -0,0 +1,116 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <vector>
+
+#include "mocks/mock_alignment.h"
+#include "mocks/mock_data_array.h"
+#include "mocks/mock_rule_extractor_helper.h"
+#include "mocks/mock_vocabulary.h"
+#include "phrase.h"
+#include "phrase_builder.h"
+#include "target_phrase_extractor.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace {
+
+class TargetPhraseExtractorTest : public Test {
+ protected:
+ virtual void SetUp() {
+ data_array = make_shared<MockDataArray>();
+ alignment = make_shared<MockAlignment>();
+ vocabulary = make_shared<MockVocabulary>();
+ phrase_builder = make_shared<PhraseBuilder>(vocabulary);
+ helper = make_shared<MockRuleExtractorHelper>();
+ }
+
+ shared_ptr<MockDataArray> data_array;
+ shared_ptr<MockAlignment> alignment;
+ shared_ptr<MockVocabulary> vocabulary;
+ shared_ptr<PhraseBuilder> phrase_builder;
+ shared_ptr<MockRuleExtractorHelper> helper;
+ shared_ptr<TargetPhraseExtractor> extractor;
+};
+
+TEST_F(TargetPhraseExtractorTest, TestExtractTightPhrasesTrue) {
+ EXPECT_CALL(*data_array, GetSentenceLength(1)).WillRepeatedly(Return(5));
+ EXPECT_CALL(*data_array, GetSentenceStart(1)).WillRepeatedly(Return(3));
+
+ vector<string> target_words = {"a", "b", "c", "d", "e"};
+ vector<int> target_symbols = {20, 21, 22, 23, 24};
+ for (size_t i = 0; i < target_words.size(); ++i) {
+ EXPECT_CALL(*data_array, GetWordAtIndex(i + 3))
+ .WillRepeatedly(Return(target_words[i]));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex(target_words[i]))
+ .WillRepeatedly(Return(target_symbols[i]));
+ EXPECT_CALL(*vocabulary, GetTerminalValue(target_symbols[i]))
+ .WillRepeatedly(Return(target_words[i]));
+ }
+
+ vector<pair<int, int> > links = {
+ make_pair(0, 0), make_pair(1, 3), make_pair(2, 2), make_pair(3, 1),
+ make_pair(4, 4)
+ };
+ EXPECT_CALL(*alignment, GetLinks(1)).WillRepeatedly(Return(links));
+
+ vector<int> gap_order = {1, 0};
+ EXPECT_CALL(*helper, GetGapOrder(_)).WillRepeatedly(Return(gap_order));
+
+ extractor = make_shared<TargetPhraseExtractor>(
+ data_array, alignment, phrase_builder, helper, vocabulary, 10, true);
+
+ vector<pair<int, int> > target_gaps = {make_pair(3, 4), make_pair(1, 2)};
+ vector<int> target_low = {0, 3, 2, 1, 4};
+ unordered_map<int, int> source_indexes = {{0, 0}, {2, 2}, {4, 4}};
+
+ vector<pair<Phrase, PhraseAlignment> > results = extractor->ExtractPhrases(
+ target_gaps, target_low, 0, 5, source_indexes, 1);
+ EXPECT_EQ(1, results.size());
+ vector<int> expected_symbols = {20, -2, 22, -1, 24};
+ EXPECT_EQ(expected_symbols, results[0].first.Get());
+ vector<string> expected_words = {"a", "c", "e"};
+ EXPECT_EQ(expected_words, results[0].first.GetWords());
+ vector<pair<int, int> > expected_alignment = {
+ make_pair(0, 0), make_pair(2, 2), make_pair(4, 4)
+ };
+ EXPECT_EQ(expected_alignment, results[0].second);
+}
+
+TEST_F(TargetPhraseExtractorTest, TestExtractPhrasesTightPhrasesFalse) {
+ vector<string> target_words = {"a", "b", "c", "d", "e", "f"};
+ vector<int> target_symbols = {20, 21, 22, 23, 24, 25, 26};
+ EXPECT_CALL(*data_array, GetSentenceLength(0)).WillRepeatedly(Return(6));
+ EXPECT_CALL(*data_array, GetSentenceStart(0)).WillRepeatedly(Return(0));
+
+ for (size_t i = 0; i < target_words.size(); ++i) {
+ EXPECT_CALL(*data_array, GetWordAtIndex(i))
+ .WillRepeatedly(Return(target_words[i]));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex(target_words[i]))
+ .WillRepeatedly(Return(target_symbols[i]));
+ EXPECT_CALL(*vocabulary, GetTerminalValue(target_symbols[i]))
+ .WillRepeatedly(Return(target_words[i]));
+ }
+
+ vector<pair<int, int> > links = {make_pair(1, 1)};
+ EXPECT_CALL(*alignment, GetLinks(0)).WillRepeatedly(Return(links));
+
+ vector<int> gap_order = {0};
+ EXPECT_CALL(*helper, GetGapOrder(_)).WillRepeatedly(Return(gap_order));
+
+ extractor = make_shared<TargetPhraseExtractor>(
+ data_array, alignment, phrase_builder, helper, vocabulary, 10, false);
+
+ vector<pair<int, int> > target_gaps = {make_pair(2, 4)};
+ vector<int> target_low = {-1, 1, -1, -1, -1, -1};
+ unordered_map<int, int> source_indexes = {{1, 1}};
+
+ vector<pair<Phrase, PhraseAlignment> > results = extractor->ExtractPhrases(
+ target_gaps, target_low, 1, 5, source_indexes, 0);
+ EXPECT_EQ(10, results.size());
+ // TODO(pauldb): Finish unit test once it's clear how these alignments should
+ // look like.
+}
+
+} // namespace