diff options
| author | Patrick Simianer <p@simianer.de> | 2013-05-02 09:09:59 +0200 | 
|---|---|---|
| committer | Patrick Simianer <p@simianer.de> | 2013-05-02 09:09:59 +0200 | 
| commit | 9e50f0237413180fba11b500c9dce5c600e3c157 (patch) | |
| tree | 556fc31d231353c853a864afffddd43dc525549a /extractor/rule_extractor_test.cc | |
| parent | d18024a41cbc1b54db88d499571349a6234b6db8 (diff) | |
| parent | 14ed53426726202813a8e82d706b44266f015fe1 (diff) | |
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'extractor/rule_extractor_test.cc')
| -rw-r--r-- | extractor/rule_extractor_test.cc | 168 | 
1 files changed, 168 insertions, 0 deletions
| diff --git a/extractor/rule_extractor_test.cc b/extractor/rule_extractor_test.cc new file mode 100644 index 00000000..5c1501c7 --- /dev/null +++ b/extractor/rule_extractor_test.cc @@ -0,0 +1,168 @@ +#include <gtest/gtest.h> + +#include <memory> + +#include "mocks/mock_alignment.h" +#include "mocks/mock_data_array.h" +#include "mocks/mock_rule_extractor_helper.h" +#include "mocks/mock_scorer.h" +#include "mocks/mock_target_phrase_extractor.h" +#include "mocks/mock_vocabulary.h" +#include "phrase.h" +#include "phrase_builder.h" +#include "phrase_location.h" +#include "rule_extractor.h" +#include "rule.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class RuleExtractorTest : public Test { + protected: +  virtual void SetUp() { +    source_data_array = make_shared<MockDataArray>(); +    EXPECT_CALL(*source_data_array, GetSentenceId(_)) +        .WillRepeatedly(Return(0)); +    EXPECT_CALL(*source_data_array, GetSentenceStart(_)) +        .WillRepeatedly(Return(0)); +    EXPECT_CALL(*source_data_array, GetSentenceLength(_)) +        .WillRepeatedly(Return(10)); + +    helper = make_shared<MockRuleExtractorHelper>(); +    EXPECT_CALL(*helper, CheckAlignedTerminals(_, _, _, _)) +        .WillRepeatedly(Return(true)); +    EXPECT_CALL(*helper, CheckTightPhrases(_, _, _, _)) +        .WillRepeatedly(Return(true)); +    unordered_map<int, int> source_indexes; +    EXPECT_CALL(*helper, GetSourceIndexes(_, _, _, _)) +        .WillRepeatedly(Return(source_indexes)); + +    vocabulary = make_shared<MockVocabulary>(); +    EXPECT_CALL(*vocabulary, GetTerminalValue(87)) +        .WillRepeatedly(Return("a")); +    phrase_builder = make_shared<PhraseBuilder>(vocabulary); +    vector<int> symbols = {87}; +    Phrase target_phrase = phrase_builder->Build(symbols); +    PhraseAlignment phrase_alignment = {make_pair(0, 0)}; + +    target_phrase_extractor = make_shared<MockTargetPhraseExtractor>(); +    vector<pair<Phrase, 
PhraseAlignment> > target_phrases = { +      make_pair(target_phrase, phrase_alignment) +    }; +    EXPECT_CALL(*target_phrase_extractor, ExtractPhrases(_, _, _, _, _, _)) +        .WillRepeatedly(Return(target_phrases)); + +    scorer = make_shared<MockScorer>(); +    vector<double> scores = {0.3, 7.2}; +    EXPECT_CALL(*scorer, Score(_)).WillRepeatedly(Return(scores)); + +    extractor = make_shared<RuleExtractor>(source_data_array, phrase_builder, +        scorer, target_phrase_extractor, helper, 10, 1, 3, 5, false); +  } + +  shared_ptr<MockDataArray> source_data_array; +  shared_ptr<MockVocabulary> vocabulary; +  shared_ptr<PhraseBuilder> phrase_builder; +  shared_ptr<MockRuleExtractorHelper> helper; +  shared_ptr<MockScorer> scorer; +  shared_ptr<MockTargetPhraseExtractor> target_phrase_extractor; +  shared_ptr<RuleExtractor> extractor; +}; + +TEST_F(RuleExtractorTest, TestExtractRulesAlignedTerminalsFail) { +  vector<int> symbols = {87}; +  Phrase phrase = phrase_builder->Build(symbols); +  vector<int> matching = {2}; +  PhraseLocation phrase_location(matching, 1); +  EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); +  EXPECT_CALL(*helper, CheckAlignedTerminals(_, _, _, _)) +      .WillRepeatedly(Return(false)); +  vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); +  EXPECT_EQ(0, rules.size()); +} + +TEST_F(RuleExtractorTest, TestExtractRulesTightPhrasesFail) { +  vector<int> symbols = {87}; +  Phrase phrase = phrase_builder->Build(symbols); +  vector<int> matching = {2}; +  PhraseLocation phrase_location(matching, 1); +  EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); +  EXPECT_CALL(*helper, CheckTightPhrases(_, _, _, _)) +      .WillRepeatedly(Return(false)); +  vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); +  EXPECT_EQ(0, rules.size()); +} + +TEST_F(RuleExtractorTest, TestExtractRulesNoFixPoint) { +  vector<int> symbols = {87}; +  Phrase phrase = phrase_builder->Build(symbols); +  
vector<int> matching = {2}; +  PhraseLocation phrase_location(matching, 1); + +  EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); +  // Set FindFixPoint to return false. +  vector<pair<int, int> > gaps; +  helper->SetUp(0, 0, 0, 0, false, gaps, gaps, 0, true, true); + +  vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); +  EXPECT_EQ(0, rules.size()); +} + +TEST_F(RuleExtractorTest, TestExtractRulesGapsFail) { +  vector<int> symbols = {87}; +  Phrase phrase = phrase_builder->Build(symbols); +  vector<int> matching = {2}; +  PhraseLocation phrase_location(matching, 1); + +  EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); +  // Set CheckGaps to return false. +  vector<pair<int, int> > gaps; +  helper->SetUp(0, 0, 0, 0, true, gaps, gaps, 0, true, false); + +  vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); +  EXPECT_EQ(0, rules.size()); +} + +TEST_F(RuleExtractorTest, TestExtractRulesNoExtremities) { +  vector<int> symbols = {87}; +  Phrase phrase = phrase_builder->Build(symbols); +  vector<int> matching = {2}; +  PhraseLocation phrase_location(matching, 1); + +  EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); +  vector<pair<int, int> > gaps(3); +  // Set FindFixPoint to return true. The number of gaps equals the number of +  // nonterminals, so we won't add any extremities. 
+  helper->SetUp(0, 0, 0, 0, true, gaps, gaps, 0, true, true); + +  vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); +  EXPECT_EQ(1, rules.size()); +} + +TEST_F(RuleExtractorTest, TestExtractRulesAddExtremities) { +  vector<int> symbols = {87}; +  Phrase phrase = phrase_builder->Build(symbols); +  vector<int> matching = {2}; +  PhraseLocation phrase_location(matching, 1); + +  vector<int> links(10, -1); +  EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).WillOnce(DoAll( +      SetArgReferee<0>(links), +      SetArgReferee<1>(links), +      SetArgReferee<2>(links), +      SetArgReferee<3>(links))); + +  vector<pair<int, int> > gaps; +  // Set FindFixPoint to return true. No gaps are passed in, so unlike the +  // previous test, extremities are expected to be added (4 rules in total). +  helper->SetUp(0, 0, 2, 3, true, gaps, gaps, 0, true, true); + +  vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); +  EXPECT_EQ(4, rules.size()); +} + +} // namespace +} // namespace extractor | 
