summaryrefslogtreecommitdiff
path: root/extractor/rule_factory_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'extractor/rule_factory_test.cc')
-rw-r--r--extractor/rule_factory_test.cc54
1 files changed, 51 insertions, 3 deletions
diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc
index d6fbab74..d329382a 100644
--- a/extractor/rule_factory_test.cc
+++ b/extractor/rule_factory_test.cc
@@ -5,6 +5,7 @@
#include <vector>
#include "grammar.h"
+#include "mocks/mock_fast_intersector.h"
#include "mocks/mock_intersector.h"
#include "mocks/mock_matchings_finder.h"
#include "mocks/mock_rule_extractor.h"
@@ -25,6 +26,7 @@ class RuleFactoryTest : public Test {
virtual void SetUp() {
finder = make_shared<MockMatchingsFinder>();
intersector = make_shared<MockIntersector>();
+ fast_intersector = make_shared<MockFastIntersector>();
vocabulary = make_shared<MockVocabulary>();
EXPECT_CALL(*vocabulary, GetTerminalValue(2)).WillRepeatedly(Return("a"));
@@ -49,14 +51,12 @@ class RuleFactoryTest : public Test {
extractor = make_shared<MockRuleExtractor>();
EXPECT_CALL(*extractor, ExtractRules(_, _))
.WillRepeatedly(Return(rules));
-
- factory = make_shared<HieroCachingRuleFactory>(finder, intersector,
- phrase_builder, extractor, vocabulary, sampler, scorer, 1, 10, 2, 3, 5);
}
vector<string> feature_names;
shared_ptr<MockMatchingsFinder> finder;
shared_ptr<MockIntersector> intersector;
+ shared_ptr<MockFastIntersector> fast_intersector;
shared_ptr<MockVocabulary> vocabulary;
shared_ptr<PhraseBuilder> phrase_builder;
shared_ptr<MockScorer> scorer;
@@ -66,6 +66,10 @@ class RuleFactoryTest : public Test {
};
TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) {
+ factory = make_shared<HieroCachingRuleFactory>(finder, intersector,
+ fast_intersector, phrase_builder, extractor, vocabulary, sampler,
+ scorer, 1, 10, 2, 3, 5, false);
+
EXPECT_CALL(*finder, Find(_, _, _))
.Times(6)
.WillRepeatedly(Return(PhraseLocation(0, 1)));
@@ -73,14 +77,37 @@ TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) {
EXPECT_CALL(*intersector, Intersect(_, _, _, _, _))
.Times(1)
.WillRepeatedly(Return(PhraseLocation(0, 1)));
+ EXPECT_CALL(*fast_intersector, Intersect(_, _, _)).Times(0);
vector<int> word_ids = {2, 3, 4};
Grammar grammar = factory->GetGrammar(word_ids);
EXPECT_EQ(feature_names, grammar.GetFeatureNames());
EXPECT_EQ(7, grammar.GetRules().size());
+
+ // Test for fast intersector.
+ factory = make_shared<HieroCachingRuleFactory>(finder, intersector,
+ fast_intersector, phrase_builder, extractor, vocabulary, sampler,
+ scorer, 1, 10, 2, 3, 5, true);
+
+ EXPECT_CALL(*finder, Find(_, _, _))
+ .Times(6)
+ .WillRepeatedly(Return(PhraseLocation(0, 1)));
+
+ EXPECT_CALL(*fast_intersector, Intersect(_, _, _))
+ .Times(1)
+ .WillRepeatedly(Return(PhraseLocation(0, 1)));
+ EXPECT_CALL(*intersector, Intersect(_, _, _, _, _)).Times(0);
+
+ grammar = factory->GetGrammar(word_ids);
+ EXPECT_EQ(feature_names, grammar.GetFeatureNames());
+ EXPECT_EQ(7, grammar.GetRules().size());
}
TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) {
+ factory = make_shared<HieroCachingRuleFactory>(finder, intersector,
+ fast_intersector, phrase_builder, extractor, vocabulary, sampler,
+ scorer, 1, 10, 2, 3, 5, false);
+
EXPECT_CALL(*finder, Find(_, _, _))
.Times(12)
.WillRepeatedly(Return(PhraseLocation(0, 1)));
@@ -89,10 +116,31 @@ TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) {
.Times(16)
.WillRepeatedly(Return(PhraseLocation(0, 1)));
+ EXPECT_CALL(*fast_intersector, Intersect(_, _, _)).Times(0);
+
vector<int> word_ids = {2, 3, 4, 2, 3};
Grammar grammar = factory->GetGrammar(word_ids);
EXPECT_EQ(feature_names, grammar.GetFeatureNames());
EXPECT_EQ(28, grammar.GetRules().size());
+
+ // Test for fast intersector.
+ factory = make_shared<HieroCachingRuleFactory>(finder, intersector,
+ fast_intersector, phrase_builder, extractor, vocabulary, sampler,
+ scorer, 1, 10, 2, 3, 5, true);
+
+ EXPECT_CALL(*finder, Find(_, _, _))
+ .Times(12)
+ .WillRepeatedly(Return(PhraseLocation(0, 1)));
+
+ EXPECT_CALL(*fast_intersector, Intersect(_, _, _))
+ .Times(16)
+ .WillRepeatedly(Return(PhraseLocation(0, 1)));
+
+ EXPECT_CALL(*intersector, Intersect(_, _, _, _, _)).Times(0);
+
+ grammar = factory->GetGrammar(word_ids);
+ EXPECT_EQ(feature_names, grammar.GetFeatureNames());
+ EXPECT_EQ(28, grammar.GetRules().size());
}
} // namespace