diff options
Diffstat (limited to 'extractor/rule_factory_test.cc')
-rw-r--r-- | extractor/rule_factory_test.cc | 54 |
1 files changed, 51 insertions, 3 deletions
diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc index d6fbab74..d329382a 100644 --- a/extractor/rule_factory_test.cc +++ b/extractor/rule_factory_test.cc @@ -5,6 +5,7 @@ #include <vector> #include "grammar.h" +#include "mocks/mock_fast_intersector.h" #include "mocks/mock_intersector.h" #include "mocks/mock_matchings_finder.h" #include "mocks/mock_rule_extractor.h" @@ -25,6 +26,7 @@ class RuleFactoryTest : public Test { virtual void SetUp() { finder = make_shared<MockMatchingsFinder>(); intersector = make_shared<MockIntersector>(); + fast_intersector = make_shared<MockFastIntersector>(); vocabulary = make_shared<MockVocabulary>(); EXPECT_CALL(*vocabulary, GetTerminalValue(2)).WillRepeatedly(Return("a")); @@ -49,14 +51,12 @@ class RuleFactoryTest : public Test { extractor = make_shared<MockRuleExtractor>(); EXPECT_CALL(*extractor, ExtractRules(_, _)) .WillRepeatedly(Return(rules)); - - factory = make_shared<HieroCachingRuleFactory>(finder, intersector, - phrase_builder, extractor, vocabulary, sampler, scorer, 1, 10, 2, 3, 5); } vector<string> feature_names; shared_ptr<MockMatchingsFinder> finder; shared_ptr<MockIntersector> intersector; + shared_ptr<MockFastIntersector> fast_intersector; shared_ptr<MockVocabulary> vocabulary; shared_ptr<PhraseBuilder> phrase_builder; shared_ptr<MockScorer> scorer; @@ -66,6 +66,10 @@ class RuleFactoryTest : public Test { }; TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) { + factory = make_shared<HieroCachingRuleFactory>(finder, intersector, + fast_intersector, phrase_builder, extractor, vocabulary, sampler, + scorer, 1, 10, 2, 3, 5, false); + EXPECT_CALL(*finder, Find(_, _, _)) .Times(6) .WillRepeatedly(Return(PhraseLocation(0, 1))); @@ -73,14 +77,37 @@ TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) { EXPECT_CALL(*intersector, Intersect(_, _, _, _, _)) .Times(1) .WillRepeatedly(Return(PhraseLocation(0, 1))); + EXPECT_CALL(*fast_intersector, Intersect(_, _, _)).Times(0); vector<int> word_ids = {2, 3, 4}; Grammar grammar = factory->GetGrammar(word_ids); EXPECT_EQ(feature_names, grammar.GetFeatureNames()); EXPECT_EQ(7, grammar.GetRules().size()); + + // Test for fast intersector. + factory = make_shared<HieroCachingRuleFactory>(finder, intersector, + fast_intersector, phrase_builder, extractor, vocabulary, sampler, + scorer, 1, 10, 2, 3, 5, true); + + EXPECT_CALL(*finder, Find(_, _, _)) + .Times(6) + .WillRepeatedly(Return(PhraseLocation(0, 1))); + + EXPECT_CALL(*fast_intersector, Intersect(_, _, _)) + .Times(1) + .WillRepeatedly(Return(PhraseLocation(0, 1))); + EXPECT_CALL(*intersector, Intersect(_, _, _, _, _)).Times(0); + + grammar = factory->GetGrammar(word_ids); + EXPECT_EQ(feature_names, grammar.GetFeatureNames()); + EXPECT_EQ(7, grammar.GetRules().size()); } TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) { + factory = make_shared<HieroCachingRuleFactory>(finder, intersector, + fast_intersector, phrase_builder, extractor, vocabulary, sampler, + scorer, 1, 10, 2, 3, 5, false); + EXPECT_CALL(*finder, Find(_, _, _)) .Times(12) .WillRepeatedly(Return(PhraseLocation(0, 1))); @@ -89,10 +116,31 @@ TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) { .Times(16) .WillRepeatedly(Return(PhraseLocation(0, 1))); + EXPECT_CALL(*fast_intersector, Intersect(_, _, _)).Times(0); + vector<int> word_ids = {2, 3, 4, 2, 3}; Grammar grammar = factory->GetGrammar(word_ids); EXPECT_EQ(feature_names, grammar.GetFeatureNames()); EXPECT_EQ(28, grammar.GetRules().size()); + + // Test for fast intersector. + factory = make_shared<HieroCachingRuleFactory>(finder, intersector, + fast_intersector, phrase_builder, extractor, vocabulary, sampler, + scorer, 1, 10, 2, 3, 5, true); + + EXPECT_CALL(*finder, Find(_, _, _)) + .Times(12) + .WillRepeatedly(Return(PhraseLocation(0, 1))); + + EXPECT_CALL(*fast_intersector, Intersect(_, _, _)) + .Times(16) + .WillRepeatedly(Return(PhraseLocation(0, 1))); + + EXPECT_CALL(*intersector, Intersect(_, _, _, _, _)).Times(0); + + grammar = factory->GetGrammar(word_ids); + EXPECT_EQ(feature_names, grammar.GetFeatureNames()); + EXPECT_EQ(28, grammar.GetRules().size()); } } // namespace |