summaryrefslogtreecommitdiff
path: root/extractor/rule_extractor_test.cc
diff options
context:
space:
mode:
authorChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2013-04-23 19:35:18 -0400
committerChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2013-04-23 19:35:18 -0400
commit6d347f1ce078dede3da0e1498f75e357351c6543 (patch)
tree8e872b8747c530e741e55e25e9917c1bd8b32c5b /extractor/rule_extractor_test.cc
parentd11b76def6899790161c47a73018146311356d8b (diff)
parent5e9605b65202f4e5fc59843b197d88c4774f0ac8 (diff)
merge paul's extractor code
Diffstat (limited to 'extractor/rule_extractor_test.cc')
-rw-r--r--extractor/rule_extractor_test.cc168
1 files changed, 168 insertions, 0 deletions
diff --git a/extractor/rule_extractor_test.cc b/extractor/rule_extractor_test.cc
new file mode 100644
index 00000000..5c1501c7
--- /dev/null
+++ b/extractor/rule_extractor_test.cc
@@ -0,0 +1,168 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "mocks/mock_alignment.h"
+#include "mocks/mock_data_array.h"
+#include "mocks/mock_rule_extractor_helper.h"
+#include "mocks/mock_scorer.h"
+#include "mocks/mock_target_phrase_extractor.h"
+#include "mocks/mock_vocabulary.h"
+#include "phrase.h"
+#include "phrase_builder.h"
+#include "phrase_location.h"
+#include "rule_extractor.h"
+#include "rule.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+class RuleExtractorTest : public Test {
+ protected:
+ virtual void SetUp() {
+ source_data_array = make_shared<MockDataArray>();
+ EXPECT_CALL(*source_data_array, GetSentenceId(_))
+ .WillRepeatedly(Return(0));
+ EXPECT_CALL(*source_data_array, GetSentenceStart(_))
+ .WillRepeatedly(Return(0));
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(10));
+
+ helper = make_shared<MockRuleExtractorHelper>();
+ EXPECT_CALL(*helper, CheckAlignedTerminals(_, _, _, _))
+ .WillRepeatedly(Return(true));
+ EXPECT_CALL(*helper, CheckTightPhrases(_, _, _, _))
+ .WillRepeatedly(Return(true));
+ unordered_map<int, int> source_indexes;
+ EXPECT_CALL(*helper, GetSourceIndexes(_, _, _, _))
+ .WillRepeatedly(Return(source_indexes));
+
+ vocabulary = make_shared<MockVocabulary>();
+ EXPECT_CALL(*vocabulary, GetTerminalValue(87))
+ .WillRepeatedly(Return("a"));
+ phrase_builder = make_shared<PhraseBuilder>(vocabulary);
+ vector<int> symbols = {87};
+ Phrase target_phrase = phrase_builder->Build(symbols);
+ PhraseAlignment phrase_alignment = {make_pair(0, 0)};
+
+ target_phrase_extractor = make_shared<MockTargetPhraseExtractor>();
+ vector<pair<Phrase, PhraseAlignment> > target_phrases = {
+ make_pair(target_phrase, phrase_alignment)
+ };
+ EXPECT_CALL(*target_phrase_extractor, ExtractPhrases(_, _, _, _, _, _))
+ .WillRepeatedly(Return(target_phrases));
+
+ scorer = make_shared<MockScorer>();
+ vector<double> scores = {0.3, 7.2};
+ EXPECT_CALL(*scorer, Score(_)).WillRepeatedly(Return(scores));
+
+ extractor = make_shared<RuleExtractor>(source_data_array, phrase_builder,
+ scorer, target_phrase_extractor, helper, 10, 1, 3, 5, false);
+ }
+
+ shared_ptr<MockDataArray> source_data_array;
+ shared_ptr<MockVocabulary> vocabulary;
+ shared_ptr<PhraseBuilder> phrase_builder;
+ shared_ptr<MockRuleExtractorHelper> helper;
+ shared_ptr<MockScorer> scorer;
+ shared_ptr<MockTargetPhraseExtractor> target_phrase_extractor;
+ shared_ptr<RuleExtractor> extractor;
+};
+
+TEST_F(RuleExtractorTest, TestExtractRulesAlignedTerminalsFail) {
+ vector<int> symbols = {87};
+ Phrase phrase = phrase_builder->Build(symbols);
+ vector<int> matching = {2};
+ PhraseLocation phrase_location(matching, 1);
+ EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1);
+ EXPECT_CALL(*helper, CheckAlignedTerminals(_, _, _, _))
+ .WillRepeatedly(Return(false));
+ vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
+ EXPECT_EQ(0, rules.size());
+}
+
+TEST_F(RuleExtractorTest, TestExtractRulesTightPhrasesFail) {
+ vector<int> symbols = {87};
+ Phrase phrase = phrase_builder->Build(symbols);
+ vector<int> matching = {2};
+ PhraseLocation phrase_location(matching, 1);
+ EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1);
+ EXPECT_CALL(*helper, CheckTightPhrases(_, _, _, _))
+ .WillRepeatedly(Return(false));
+ vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
+ EXPECT_EQ(0, rules.size());
+}
+
+TEST_F(RuleExtractorTest, TestExtractRulesNoFixPoint) {
+ vector<int> symbols = {87};
+ Phrase phrase = phrase_builder->Build(symbols);
+ vector<int> matching = {2};
+ PhraseLocation phrase_location(matching, 1);
+
+ EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1);
+ // Set FindFixPoint to return false.
+ vector<pair<int, int> > gaps;
+ helper->SetUp(0, 0, 0, 0, false, gaps, gaps, 0, true, true);
+
+ vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
+ EXPECT_EQ(0, rules.size());
+}
+
+TEST_F(RuleExtractorTest, TestExtractRulesGapsFail) {
+ vector<int> symbols = {87};
+ Phrase phrase = phrase_builder->Build(symbols);
+ vector<int> matching = {2};
+ PhraseLocation phrase_location(matching, 1);
+
+ EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1);
+ // Set CheckGaps to return false.
+ vector<pair<int, int> > gaps;
+ helper->SetUp(0, 0, 0, 0, true, gaps, gaps, 0, true, false);
+
+ vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
+ EXPECT_EQ(0, rules.size());
+}
+
+TEST_F(RuleExtractorTest, TestExtractRulesNoExtremities) {
+ vector<int> symbols = {87};
+ Phrase phrase = phrase_builder->Build(symbols);
+ vector<int> matching = {2};
+ PhraseLocation phrase_location(matching, 1);
+
+ EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1);
+ vector<pair<int, int> > gaps(3);
+ // Set FindFixPoint to return true. The number of gaps equals the number of
+ // nonterminals, so we won't add any extremities.
+ helper->SetUp(0, 0, 0, 0, true, gaps, gaps, 0, true, true);
+
+ vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
+ EXPECT_EQ(1, rules.size());
+}
+
+TEST_F(RuleExtractorTest, TestExtractRulesAddExtremities) {
+ vector<int> symbols = {87};
+ Phrase phrase = phrase_builder->Build(symbols);
+ vector<int> matching = {2};
+ PhraseLocation phrase_location(matching, 1);
+
+ vector<int> links(10, -1);
+ EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).WillOnce(DoAll(
+ SetArgReferee<0>(links),
+ SetArgReferee<1>(links),
+ SetArgReferee<2>(links),
+ SetArgReferee<3>(links)));
+
+ vector<pair<int, int> > gaps;
+ // Set FindFixPoint to return true. The number of gaps equals the number of
+ // nonterminals, so we won't add any extremities.
+ helper->SetUp(0, 0, 2, 3, true, gaps, gaps, 0, true, true);
+
+ vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
+ EXPECT_EQ(4, rules.size());
+}
+
+} // namespace
+} // namespace extractor