#include #include #include "mocks/mock_alignment.h" #include "mocks/mock_data_array.h" #include "mocks/mock_rule_extractor_helper.h" #include "mocks/mock_scorer.h" #include "mocks/mock_target_phrase_extractor.h" #include "mocks/mock_vocabulary.h" #include "phrase.h" #include "phrase_builder.h" #include "phrase_location.h" #include "rule_extractor.h" #include "rule.h" using namespace std; using namespace ::testing; namespace extractor { namespace { class RuleExtractorTest : public Test { protected: virtual void SetUp() { source_data_array = make_shared(); EXPECT_CALL(*source_data_array, GetSentenceId(_)) .WillRepeatedly(Return(0)); EXPECT_CALL(*source_data_array, GetSentenceStart(_)) .WillRepeatedly(Return(0)); EXPECT_CALL(*source_data_array, GetSentenceLength(_)) .WillRepeatedly(Return(10)); helper = make_shared(); EXPECT_CALL(*helper, CheckAlignedTerminals(_, _, _, _)) .WillRepeatedly(Return(true)); EXPECT_CALL(*helper, CheckTightPhrases(_, _, _, _)) .WillRepeatedly(Return(true)); unordered_map source_indexes; EXPECT_CALL(*helper, GetSourceIndexes(_, _, _, _)) .WillRepeatedly(Return(source_indexes)); vocabulary = make_shared(); EXPECT_CALL(*vocabulary, GetTerminalValue(87)) .WillRepeatedly(Return("a")); phrase_builder = make_shared(vocabulary); vector symbols = {87}; Phrase target_phrase = phrase_builder->Build(symbols); PhraseAlignment phrase_alignment = {make_pair(0, 0)}; target_phrase_extractor = make_shared(); vector> target_phrases = { make_pair(target_phrase, phrase_alignment) }; EXPECT_CALL(*target_phrase_extractor, ExtractPhrases(_, _, _, _, _, _)) .WillRepeatedly(Return(target_phrases)); scorer = make_shared(); vector scores = {0.3, 7.2}; EXPECT_CALL(*scorer, Score(_)).WillRepeatedly(Return(scores)); extractor = make_shared(source_data_array, phrase_builder, scorer, target_phrase_extractor, helper, 10, 1, 3, 5, false); } shared_ptr source_data_array; shared_ptr vocabulary; shared_ptr phrase_builder; shared_ptr helper; shared_ptr scorer; shared_ptr target_phrase_extractor; shared_ptr extractor; }; TEST_F(RuleExtractorTest, TestExtractRulesAlignedTerminalsFail) { vector symbols = {87}; Phrase phrase = phrase_builder->Build(symbols); vector matching = {2}; PhraseLocation phrase_location(matching, 1); EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); EXPECT_CALL(*helper, CheckAlignedTerminals(_, _, _, _)) .WillRepeatedly(Return(false)); vector rules = extractor->ExtractRules(phrase, phrase_location); EXPECT_EQ(0, rules.size()); } TEST_F(RuleExtractorTest, TestExtractRulesTightPhrasesFail) { vector symbols = {87}; Phrase phrase = phrase_builder->Build(symbols); vector matching = {2}; PhraseLocation phrase_location(matching, 1); EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); EXPECT_CALL(*helper, CheckTightPhrases(_, _, _, _)) .WillRepeatedly(Return(false)); vector rules = extractor->ExtractRules(phrase, phrase_location); EXPECT_EQ(0, rules.size()); } TEST_F(RuleExtractorTest, TestExtractRulesNoFixPoint) { vector symbols = {87}; Phrase phrase = phrase_builder->Build(symbols); vector matching = {2}; PhraseLocation phrase_location(matching, 1); EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); // Set FindFixPoint to return false. vector> gaps; helper->SetUp(0, 0, 0, 0, false, gaps, gaps, 0, true, true); vector rules = extractor->ExtractRules(phrase, phrase_location); EXPECT_EQ(0, rules.size()); } TEST_F(RuleExtractorTest, TestExtractRulesGapsFail) { vector symbols = {87}; Phrase phrase = phrase_builder->Build(symbols); vector matching = {2}; PhraseLocation phrase_location(matching, 1); EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); // Set CheckGaps to return false. vector> gaps; helper->SetUp(0, 0, 0, 0, true, gaps, gaps, 0, true, false); vector rules = extractor->ExtractRules(phrase, phrase_location); EXPECT_EQ(0, rules.size()); } TEST_F(RuleExtractorTest, TestExtractRulesNoExtremities) { vector symbols = {87}; Phrase phrase = phrase_builder->Build(symbols); vector matching = {2}; PhraseLocation phrase_location(matching, 1); EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); vector> gaps(3); // Set FindFixPoint to return true. The number of gaps equals the number of // nonterminals, so we won't add any extremities. helper->SetUp(0, 0, 0, 0, true, gaps, gaps, 0, true, true); vector rules = extractor->ExtractRules(phrase, phrase_location); EXPECT_EQ(1, rules.size()); } TEST_F(RuleExtractorTest, TestExtractRulesAddExtremities) { vector symbols = {87}; Phrase phrase = phrase_builder->Build(symbols); vector matching = {2}; PhraseLocation phrase_location(matching, 1); vector links(10, -1); EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).WillOnce(DoAll( SetArgReferee<0>(links), SetArgReferee<1>(links), SetArgReferee<2>(links), SetArgReferee<3>(links))); vector> gaps; // Set FindFixPoint to return true. The number of gaps equals the number of // nonterminals, so we won't add any extremities. helper->SetUp(0, 0, 2, 3, true, gaps, gaps, 0, true, true); vector rules = extractor->ExtractRules(phrase, phrase_location); EXPECT_EQ(4, rules.size()); } } // namespace } // namespace extractor