summaryrefslogtreecommitdiff
path: root/extractor/target_phrase_extractor_test.cc
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2013-05-02 09:09:59 +0200
committerPatrick Simianer <p@simianer.de>2013-05-02 09:09:59 +0200
commit0ce66778da6079506896739e9d97dc7dff83cd72 (patch)
treef435457bb23dab0c566c9896f9d38cece9d15885 /extractor/target_phrase_extractor_test.cc
parentb6754386f1109b960b05cdf2eabbc97bdd38e8df (diff)
parentb7ea2615bc9bb69031ff714ddce1539c9f1bda2d (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'extractor/target_phrase_extractor_test.cc')
-rw-r--r--extractor/target_phrase_extractor_test.cc143
1 files changed, 143 insertions, 0 deletions
diff --git a/extractor/target_phrase_extractor_test.cc b/extractor/target_phrase_extractor_test.cc
new file mode 100644
index 00000000..80927dee
--- /dev/null
+++ b/extractor/target_phrase_extractor_test.cc
@@ -0,0 +1,143 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <vector>
+
+#include "mocks/mock_alignment.h"
+#include "mocks/mock_data_array.h"
+#include "mocks/mock_rule_extractor_helper.h"
+#include "mocks/mock_vocabulary.h"
+#include "phrase.h"
+#include "phrase_builder.h"
+#include "target_phrase_extractor.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+class TargetPhraseExtractorTest : public Test {
+ protected:
+ virtual void SetUp() {
+ data_array = make_shared<MockDataArray>();
+ alignment = make_shared<MockAlignment>();
+ vocabulary = make_shared<MockVocabulary>();
+ phrase_builder = make_shared<PhraseBuilder>(vocabulary);
+ helper = make_shared<MockRuleExtractorHelper>();
+ }
+
+ shared_ptr<MockDataArray> data_array;
+ shared_ptr<MockAlignment> alignment;
+ shared_ptr<MockVocabulary> vocabulary;
+ shared_ptr<PhraseBuilder> phrase_builder;
+ shared_ptr<MockRuleExtractorHelper> helper;
+ shared_ptr<TargetPhraseExtractor> extractor;
+};
+
+TEST_F(TargetPhraseExtractorTest, TestExtractTightPhrasesTrue) {
+ EXPECT_CALL(*data_array, GetSentenceLength(1)).WillRepeatedly(Return(5));
+ EXPECT_CALL(*data_array, GetSentenceStart(1)).WillRepeatedly(Return(3));
+
+ vector<string> target_words = {"a", "b", "c", "d", "e"};
+ vector<int> target_symbols = {20, 21, 22, 23, 24};
+ for (size_t i = 0; i < target_words.size(); ++i) {
+ EXPECT_CALL(*data_array, GetWordAtIndex(i + 3))
+ .WillRepeatedly(Return(target_words[i]));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex(target_words[i]))
+ .WillRepeatedly(Return(target_symbols[i]));
+ EXPECT_CALL(*vocabulary, GetTerminalValue(target_symbols[i]))
+ .WillRepeatedly(Return(target_words[i]));
+ }
+
+ vector<pair<int, int> > links = {
+ make_pair(0, 0), make_pair(1, 3), make_pair(2, 2), make_pair(3, 1),
+ make_pair(4, 4)
+ };
+ EXPECT_CALL(*alignment, GetLinks(1)).WillRepeatedly(Return(links));
+
+ vector<int> gap_order = {1, 0};
+ EXPECT_CALL(*helper, GetGapOrder(_)).WillRepeatedly(Return(gap_order));
+
+ extractor = make_shared<TargetPhraseExtractor>(
+ data_array, alignment, phrase_builder, helper, vocabulary, 10, true);
+
+ vector<pair<int, int> > target_gaps = {make_pair(3, 4), make_pair(1, 2)};
+ vector<int> target_low = {0, 3, 2, 1, 4};
+ unordered_map<int, int> source_indexes = {{0, 0}, {2, 2}, {4, 4}};
+
+ vector<pair<Phrase, PhraseAlignment> > results = extractor->ExtractPhrases(
+ target_gaps, target_low, 0, 5, source_indexes, 1);
+ EXPECT_EQ(1, results.size());
+ vector<int> expected_symbols = {20, -2, 22, -1, 24};
+ EXPECT_EQ(expected_symbols, results[0].first.Get());
+ vector<string> expected_words = {"a", "c", "e"};
+ EXPECT_EQ(expected_words, results[0].first.GetWords());
+ vector<pair<int, int> > expected_alignment = {
+ make_pair(0, 0), make_pair(2, 2), make_pair(4, 4)
+ };
+ EXPECT_EQ(expected_alignment, results[0].second);
+}
+
+TEST_F(TargetPhraseExtractorTest, TestExtractPhrasesTightPhrasesFalse) {
+ vector<string> target_words = {"a", "b", "c", "d", "e", "f", "END_OF_LINE"};
+ vector<int> target_symbols = {20, 21, 22, 23, 24, 25, 1};
+ EXPECT_CALL(*data_array, GetSentenceLength(0)).WillRepeatedly(Return(6));
+ EXPECT_CALL(*data_array, GetSentenceStart(0)).WillRepeatedly(Return(0));
+
+ for (size_t i = 0; i < target_words.size(); ++i) {
+ EXPECT_CALL(*data_array, GetWordAtIndex(i))
+ .WillRepeatedly(Return(target_words[i]));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex(target_words[i]))
+ .WillRepeatedly(Return(target_symbols[i]));
+ EXPECT_CALL(*vocabulary, GetTerminalValue(target_symbols[i]))
+ .WillRepeatedly(Return(target_words[i]));
+ }
+
+ vector<pair<int, int> > links = {make_pair(1, 1)};
+ EXPECT_CALL(*alignment, GetLinks(0)).WillRepeatedly(Return(links));
+
+ vector<int> gap_order = {0};
+ EXPECT_CALL(*helper, GetGapOrder(_)).WillRepeatedly(Return(gap_order));
+
+ extractor = make_shared<TargetPhraseExtractor>(
+ data_array, alignment, phrase_builder, helper, vocabulary, 10, false);
+
+ vector<pair<int, int> > target_gaps = {make_pair(2, 4)};
+ vector<int> target_low = {-1, 1, -1, -1, -1, -1};
+ unordered_map<int, int> source_indexes = {{1, 1}};
+
+ vector<pair<Phrase, PhraseAlignment> > results = extractor->ExtractPhrases(
+ target_gaps, target_low, 1, 5, source_indexes, 0);
+ EXPECT_EQ(10, results.size());
+
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 4; j <= 6; ++j) {
+ for (int k = 4; k <= j; ++k) {
+ vector<string> expected_words;
+ for (int l = i; l < 2; ++l) {
+ expected_words.push_back(target_words[l]);
+ }
+ for (int l = k; l < j; ++l) {
+ expected_words.push_back(target_words[l]);
+ }
+
+ PhraseAlignment expected_alignment;
+ expected_alignment.push_back(make_pair(1, 1 - i));
+
+ bool found_expected_pair = false;
+ for (auto result: results) {
+ if (result.first.GetWords() == expected_words &&
+ result.second == expected_alignment) {
+ found_expected_pair = true;
+ }
+ }
+
+ EXPECT_TRUE(found_expected_pair);
+ }
+ }
+ }
+}
+
+} // namespace
+} // namespace extractor