#include <gtest/gtest.h>

#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include "mocks/mock_alignment.h"
#include "mocks/mock_data_array.h"
#include "mocks/mock_rule_extractor_helper.h"
#include "mocks/mock_scorer.h"
#include "mocks/mock_target_phrase_extractor.h"
#include "mocks/mock_vocabulary.h"
#include "phrase.h"
#include "phrase_builder.h"
#include "phrase_location.h"
#include "rule_extractor.h"
#include "rule.h"

using namespace std;
using namespace ::testing;

namespace extractor {
namespace {

// Fixture for RuleExtractor tests.
//
// SetUp() wires a RuleExtractor to mock collaborators whose default behaviour
// accepts everything, so a single-terminal source phrase can yield rules.
// Individual tests then override one collaborator (via a newer EXPECT_CALL,
// which gMock prefers over older ones, or via the helper mock's SetUp())
// to drive a specific branch of ExtractRules.
class RuleExtractorTest : public Test {
 protected:
  void SetUp() override {
    // Every source position reports sentence 0, starting at offset 0, with a
    // length of 10 words.
    source_data_array = make_shared<MockDataArray>();
    EXPECT_CALL(*source_data_array, GetSentenceId(_))
        .WillRepeatedly(Return(0));
    EXPECT_CALL(*source_data_array, GetSentenceStart(_))
        .WillRepeatedly(Return(0));
    EXPECT_CALL(*source_data_array, GetSentenceLength(_))
        .WillRepeatedly(Return(10));

    // By default the helper accepts every alignment / tightness check and
    // reports no source indexes.
    helper = make_shared<MockRuleExtractorHelper>();
    EXPECT_CALL(*helper, CheckAlignedTerminals(_, _, _, _))
        .WillRepeatedly(Return(true));
    EXPECT_CALL(*helper, CheckTightPhrases(_, _, _, _))
        .WillRepeatedly(Return(true));
    unordered_map<int, int> source_indexes;
    EXPECT_CALL(*helper, GetSourceIndexes(_, _, _, _))
        .WillRepeatedly(Return(source_indexes));

    // Word id 87 renders as "a".
    vocabulary = make_shared<MockVocabulary>();
    EXPECT_CALL(*vocabulary, GetTerminalValue(87))
        .WillRepeatedly(Return("a"));

    phrase_builder = make_shared<PhraseBuilder>(vocabulary);
    vector<int> symbols = {87};
    Phrase target_phrase = phrase_builder->Build(symbols);
    PhraseAlignment phrase_alignment = {make_pair(0, 0)};

    // The target side always yields the one-word phrase "a", with its only
    // word linked to source position 0.
    target_phrase_extractor = make_shared<MockTargetPhraseExtractor>();
    vector<pair<Phrase, PhraseAlignment>> target_phrases = {
      make_pair(target_phrase, phrase_alignment)
    };
    EXPECT_CALL(*target_phrase_extractor, ExtractPhrases(_, _, _, _, _, _))
        .WillRepeatedly(Return(target_phrases));

    // Every rule receives the same fixed feature scores.
    scorer = make_shared<MockScorer>();
    vector<double> scores = {0.3, 7.2};
    EXPECT_CALL(*scorer, Score(_)).WillRepeatedly(Return(scores));

    // NOTE(review): the trailing arguments are presumably max_rule_span=10,
    // min_gap_size=1, max_nonterminals=3, max_rule_symbols=5 and
    // require_tight_phrases=false; confirm against rule_extractor.h.
    extractor = make_shared<RuleExtractor>(source_data_array, phrase_builder,
        scorer, target_phrase_extractor, helper, 10, 1, 3, 5, false);
  }

  shared_ptr<MockDataArray> source_data_array;
  shared_ptr<MockVocabulary> vocabulary;
  shared_ptr<PhraseBuilder> phrase_builder;
  shared_ptr<MockRuleExtractorHelper> helper;
  shared_ptr<MockScorer> scorer;
  shared_ptr<MockTargetPhraseExtractor> target_phrase_extractor;
  shared_ptr<RuleExtractor> extractor;
};

// When the aligned-terminals check rejects the match, no rule is extracted.
TEST_F(RuleExtractorTest, TestExtractRulesAlignedTerminalsFail) {
  vector<int> symbols = {87};
  Phrase phrase = phrase_builder->Build(symbols);
  vector<int> matching = {2};
  PhraseLocation phrase_location(matching, 1);

  EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1);
  // Newer expectation overrides the accepting default installed in SetUp().
  EXPECT_CALL(*helper, CheckAlignedTerminals(_, _, _, _))
      .WillRepeatedly(Return(false));

  vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
  EXPECT_EQ(0u, rules.size());
}

// When the tight-phrases check rejects the match, no rule is extracted.
TEST_F(RuleExtractorTest, TestExtractRulesTightPhrasesFail) {
  vector<int> symbols = {87};
  Phrase phrase = phrase_builder->Build(symbols);
  vector<int> matching = {2};
  PhraseLocation phrase_location(matching, 1);

  EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1);
  // Newer expectation overrides the accepting default installed in SetUp().
  EXPECT_CALL(*helper, CheckTightPhrases(_, _, _, _))
      .WillRepeatedly(Return(false));

  vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
  EXPECT_EQ(0u, rules.size());
}

// When no consistent fix point is found, no rule is extracted.
TEST_F(RuleExtractorTest, TestExtractRulesNoFixPoint) {
  vector<int> symbols = {87};
  Phrase phrase = phrase_builder->Build(symbols);
  vector<int> matching = {2};
  PhraseLocation phrase_location(matching, 1);

  EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1);
  // Set FindFixPoint to return false.
  vector<pair<int, int>> gaps;
  helper->SetUp(0, 0, 0, 0, false, gaps, gaps, 0, true, true);

  vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
  EXPECT_EQ(0u, rules.size());
}

// When the gap check fails, no rule is extracted.
TEST_F(RuleExtractorTest, TestExtractRulesGapsFail) {
  vector<int> symbols = {87};
  Phrase phrase = phrase_builder->Build(symbols);
  vector<int> matching = {2};
  PhraseLocation phrase_location(matching, 1);

  EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1);
  // Set CheckGaps to return false.
  vector<pair<int, int>> gaps;
  helper->SetUp(0, 0, 0, 0, true, gaps, gaps, 0, true, false);

  vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
  EXPECT_EQ(0u, rules.size());
}

// With the gap budget exhausted, only the base rule is extracted.
TEST_F(RuleExtractorTest, TestExtractRulesNoExtremities) {
  vector<int> symbols = {87};
  Phrase phrase = phrase_builder->Build(symbols);
  vector<int> matching = {2};
  PhraseLocation phrase_location(matching, 1);

  EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1);
  vector<pair<int, int>> gaps(3);
  // Set FindFixPoint to return true. The number of gaps equals the number of
  // nonterminals, so we won't add any extremities.
  helper->SetUp(0, 0, 0, 0, true, gaps, gaps, 0, true, true);

  vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
  EXPECT_EQ(1u, rules.size());
}

// With no gaps and unaligned neighbouring words, the extractor also emits
// rules extended by extremities in addition to the base rule.
TEST_F(RuleExtractorTest, TestExtractRulesAddExtremities) {
  vector<int> symbols = {87};
  Phrase phrase = phrase_builder->Build(symbols);
  vector<int> matching = {2};
  PhraseLocation phrase_location(matching, 1);

  // One entry per word of the 10-word sentence reported by the mock data
  // array; -1 presumably marks a word with no alignment link.
  vector<int> links(10, -1);
  EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).WillOnce(DoAll(
      SetArgReferee<0>(links),
      SetArgReferee<1>(links),
      SetArgReferee<2>(links),
      SetArgReferee<3>(links)));

  vector<pair<int, int>> gaps;
  // Set FindFixPoint to return true. There are no gaps, so the rule stays
  // below the nonterminal limit and extremities may be added. The expected
  // 4 rules are presumably the base rule plus its left-, right- and
  // both-side extensions.
  helper->SetUp(0, 0, 2, 3, true, gaps, gaps, 0, true, true);

  vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
  EXPECT_EQ(4u, rules.size());
}

}  // namespace
}  // namespace extractor