From 63b30ed9c8510da8c8e2f6a456576424fddacc0e Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Thu, 14 Feb 2013 23:17:15 +0000 Subject: Working version of the grammar extractor. --- extractor/features/count_source_target_test.cc | 32 ++++++++++ extractor/features/feature.cc | 2 + extractor/features/feature.h | 10 ++- extractor/features/is_source_singleton.cc | 2 +- extractor/features/is_source_singleton_test.cc | 35 ++++++++++ extractor/features/is_source_target_singleton.cc | 2 +- .../features/is_source_target_singleton_test.cc | 35 ++++++++++ extractor/features/max_lex_source_given_target.cc | 5 +- .../features/max_lex_source_given_target_test.cc | 74 ++++++++++++++++++++++ extractor/features/max_lex_target_given_source.cc | 5 +- .../features/max_lex_target_given_source_test.cc | 74 ++++++++++++++++++++++ extractor/features/sample_source_count.cc | 2 +- extractor/features/sample_source_count_test.cc | 36 +++++++++++ extractor/features/target_given_source_coherent.cc | 4 +- .../features/target_given_source_coherent_test.cc | 35 ++++++++++ 15 files changed, 341 insertions(+), 12 deletions(-) create mode 100644 extractor/features/count_source_target_test.cc create mode 100644 extractor/features/is_source_singleton_test.cc create mode 100644 extractor/features/is_source_target_singleton_test.cc create mode 100644 extractor/features/max_lex_source_given_target_test.cc create mode 100644 extractor/features/max_lex_target_given_source_test.cc create mode 100644 extractor/features/sample_source_count_test.cc create mode 100644 extractor/features/target_given_source_coherent_test.cc (limited to 'extractor/features') diff --git a/extractor/features/count_source_target_test.cc b/extractor/features/count_source_target_test.cc new file mode 100644 index 00000000..22633bb6 --- /dev/null +++ b/extractor/features/count_source_target_test.cc @@ -0,0 +1,32 @@ +#include + +#include +#include + +#include "count_source_target.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class CountSourceTargetTest : public Test { + protected: + virtual void SetUp() { + feature = make_shared(); + } + + shared_ptr feature; +}; + +TEST_F(CountSourceTargetTest, TestGetName) { + EXPECT_EQ("CountEF", feature->GetName()); +} + +TEST_F(CountSourceTargetTest, TestScore) { + Phrase phrase; + FeatureContext context(phrase, phrase, 0.5, 9, 13); + EXPECT_EQ(1.0, feature->Score(context)); +} + +} // namespace diff --git a/extractor/features/feature.cc b/extractor/features/feature.cc index 7381c35a..876f5f8f 100644 --- a/extractor/features/feature.cc +++ b/extractor/features/feature.cc @@ -1,3 +1,5 @@ #include "feature.h" const double Feature::MAX_SCORE = 99.0; + +Feature::~Feature() {} diff --git a/extractor/features/feature.h b/extractor/features/feature.h index ad22d3e7..aca58401 100644 --- a/extractor/features/feature.h +++ b/extractor/features/feature.h @@ -10,14 +10,16 @@ using namespace std; struct FeatureContext { FeatureContext(const Phrase& source_phrase, const Phrase& target_phrase, - double sample_source_count, int pair_count) : + double source_phrase_count, int pair_count, int num_samples) : source_phrase(source_phrase), target_phrase(target_phrase), - sample_source_count(sample_source_count), pair_count(pair_count) {} + source_phrase_count(source_phrase_count), pair_count(pair_count), + num_samples(num_samples) {} Phrase source_phrase; Phrase target_phrase; - double sample_source_count; + double source_phrase_count; int pair_count; + int num_samples; }; class Feature { @@ -26,6 +28,8 @@ class Feature { virtual string GetName() const = 0; + virtual ~Feature(); + static const double MAX_SCORE; }; diff --git a/extractor/features/is_source_singleton.cc b/extractor/features/is_source_singleton.cc index 754df3bf..98d4e5fe 100644 --- a/extractor/features/is_source_singleton.cc +++ b/extractor/features/is_source_singleton.cc @@ -3,7 +3,7 @@ #include double IsSourceSingleton::Score(const FeatureContext& context) const { - return context.sample_source_count == 1; + return context.source_phrase_count == 1; } string IsSourceSingleton::GetName() const { diff --git a/extractor/features/is_source_singleton_test.cc b/extractor/features/is_source_singleton_test.cc new file mode 100644 index 00000000..8c71e593 --- /dev/null +++ b/extractor/features/is_source_singleton_test.cc @@ -0,0 +1,35 @@ +#include + +#include +#include + +#include "is_source_singleton.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class IsSourceSingletonTest : public Test { + protected: + virtual void SetUp() { + feature = make_shared(); + } + + shared_ptr feature; +}; + +TEST_F(IsSourceSingletonTest, TestGetName) { + EXPECT_EQ("IsSingletonF", feature->GetName()); +} + +TEST_F(IsSourceSingletonTest, TestScore) { + Phrase phrase; + FeatureContext context(phrase, phrase, 0.5, 3, 31); + EXPECT_EQ(0, feature->Score(context)); + + context = FeatureContext(phrase, phrase, 1, 3, 25); + EXPECT_EQ(1, feature->Score(context)); +} + +} // namespace diff --git a/extractor/features/is_source_target_singleton.cc b/extractor/features/is_source_target_singleton.cc index ec816509..31d36532 100644 --- a/extractor/features/is_source_target_singleton.cc +++ b/extractor/features/is_source_target_singleton.cc @@ -7,5 +7,5 @@ double IsSourceTargetSingleton::Score(const FeatureContext& context) const { } string IsSourceTargetSingleton::GetName() const { - return "IsSingletonEF"; + return "IsSingletonFE"; } diff --git a/extractor/features/is_source_target_singleton_test.cc b/extractor/features/is_source_target_singleton_test.cc new file mode 100644 index 00000000..a51f77c9 --- /dev/null +++ b/extractor/features/is_source_target_singleton_test.cc @@ -0,0 +1,35 @@ +#include + +#include +#include + +#include "is_source_target_singleton.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class IsSourceTargetSingletonTest : public Test { + protected: + virtual void SetUp() { + feature = make_shared(); + } + + shared_ptr feature; +}; + +TEST_F(IsSourceTargetSingletonTest, TestGetName) { + EXPECT_EQ("IsSingletonFE", feature->GetName()); +} + +TEST_F(IsSourceTargetSingletonTest, TestScore) { + Phrase phrase; + FeatureContext context(phrase, phrase, 0.5, 3, 7); + EXPECT_EQ(0, feature->Score(context)); + + context = FeatureContext(phrase, phrase, 2.3, 1, 28); + EXPECT_EQ(1, feature->Score(context)); +} + +} // namespace diff --git a/extractor/features/max_lex_source_given_target.cc b/extractor/features/max_lex_source_given_target.cc index c4792d49..21f5c76a 100644 --- a/extractor/features/max_lex_source_given_target.cc +++ b/extractor/features/max_lex_source_given_target.cc @@ -2,6 +2,7 @@ #include +#include "../data_array.h" #include "../translation_table.h" MaxLexSourceGivenTarget::MaxLexSourceGivenTarget( @@ -10,8 +11,8 @@ MaxLexSourceGivenTarget::MaxLexSourceGivenTarget( double MaxLexSourceGivenTarget::Score(const FeatureContext& context) const { vector source_words = context.source_phrase.GetWords(); - // TODO(pauldb): Add NULL to target_words, after fixing translation table. vector target_words = context.target_phrase.GetWords(); + target_words.push_back(DataArray::NULL_WORD_STR); double score = 0; for (string source_word: source_words) { @@ -26,5 +27,5 @@ double MaxLexSourceGivenTarget::Score(const FeatureContext& context) const { } string MaxLexSourceGivenTarget::GetName() const { - return "MaxLexFGivenE"; + return "MaxLexFgivenE"; } diff --git a/extractor/features/max_lex_source_given_target_test.cc b/extractor/features/max_lex_source_given_target_test.cc new file mode 100644 index 00000000..5fd41f8b --- /dev/null +++ b/extractor/features/max_lex_source_given_target_test.cc @@ -0,0 +1,74 @@ +#include + +#include +#include +#include + +#include "../mocks/mock_translation_table.h" +#include "../mocks/mock_vocabulary.h" +#include "../data_array.h" +#include "../phrase_builder.h" +#include "max_lex_source_given_target.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class MaxLexSourceGivenTargetTest : public Test { + protected: + virtual void SetUp() { + vector source_words = {"f1", "f2", "f3"}; + vector target_words = {"e1", "e2", "e3"}; + + vocabulary = make_shared(); + for (size_t i = 0; i < source_words.size(); ++i) { + EXPECT_CALL(*vocabulary, GetTerminalValue(i)) + .WillRepeatedly(Return(source_words[i])); + } + for (size_t i = 0; i < target_words.size(); ++i) { + EXPECT_CALL(*vocabulary, GetTerminalValue(i + source_words.size())) + .WillRepeatedly(Return(target_words[i])); + } + + phrase_builder = make_shared(vocabulary); + + table = make_shared(); + for (size_t i = 0; i < source_words.size(); ++i) { + for (size_t j = 0; j < target_words.size(); ++j) { + int value = i - j; + EXPECT_CALL(*table, GetSourceGivenTargetScore( + source_words[i], target_words[j])).WillRepeatedly(Return(value)); + } + } + + for (size_t i = 0; i < source_words.size(); ++i) { + int value = i * 3; + EXPECT_CALL(*table, GetSourceGivenTargetScore( + source_words[i], DataArray::NULL_WORD_STR)) + .WillRepeatedly(Return(value)); + } + + feature = make_shared(table); + } + + shared_ptr vocabulary; + shared_ptr phrase_builder; + shared_ptr table; + shared_ptr feature; +}; + +TEST_F(MaxLexSourceGivenTargetTest, TestGetName) { + EXPECT_EQ("MaxLexFgivenE", feature->GetName()); +} + +TEST_F(MaxLexSourceGivenTargetTest, TestScore) { + vector source_symbols = {0, 1, 2}; + Phrase source_phrase = phrase_builder->Build(source_symbols); + vector target_symbols = {3, 4, 5}; + Phrase target_phrase = phrase_builder->Build(target_symbols); + FeatureContext context(source_phrase, target_phrase, 0.3, 7, 11); + EXPECT_EQ(99 - log10(18), feature->Score(context)); +} + +} // namespace diff --git a/extractor/features/max_lex_target_given_source.cc b/extractor/features/max_lex_target_given_source.cc index d82182fe..f2bc2474 100644 --- a/extractor/features/max_lex_target_given_source.cc +++ b/extractor/features/max_lex_target_given_source.cc @@ -2,6 +2,7 @@ #include +#include "../data_array.h" #include "../translation_table.h" MaxLexTargetGivenSource::MaxLexTargetGivenSource( @@ -9,8 +10,8 @@ MaxLexTargetGivenSource::MaxLexTargetGivenSource( table(table) {} double MaxLexTargetGivenSource::Score(const FeatureContext& context) const { - // TODO(pauldb): Add NULL to source_words, after fixing translation table. vector source_words = context.source_phrase.GetWords(); + source_words.push_back(DataArray::NULL_WORD_STR); vector target_words = context.target_phrase.GetWords(); double score = 0; @@ -26,5 +27,5 @@ double MaxLexTargetGivenSource::Score(const FeatureContext& context) const { } string MaxLexTargetGivenSource::GetName() const { - return "MaxLexEGivenF"; + return "MaxLexEgivenF"; } diff --git a/extractor/features/max_lex_target_given_source_test.cc b/extractor/features/max_lex_target_given_source_test.cc new file mode 100644 index 00000000..c8701bf7 --- /dev/null +++ b/extractor/features/max_lex_target_given_source_test.cc @@ -0,0 +1,74 @@ +#include + +#include +#include +#include + +#include "../mocks/mock_translation_table.h" +#include "../mocks/mock_vocabulary.h" +#include "../data_array.h" +#include "../phrase_builder.h" +#include "max_lex_target_given_source.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class MaxLexTargetGivenSourceTest : public Test { + protected: + virtual void SetUp() { + vector source_words = {"f1", "f2", "f3"}; + vector target_words = {"e1", "e2", "e3"}; + + vocabulary = make_shared(); + for (size_t i = 0; i < source_words.size(); ++i) { + EXPECT_CALL(*vocabulary, GetTerminalValue(i)) + .WillRepeatedly(Return(source_words[i])); + } + for (size_t i = 0; i < target_words.size(); ++i) { + EXPECT_CALL(*vocabulary, GetTerminalValue(i + source_words.size())) + .WillRepeatedly(Return(target_words[i])); + } + + phrase_builder = make_shared(vocabulary); + + table = make_shared(); + for (size_t i = 0; i < source_words.size(); ++i) { + for (size_t j = 0; j < target_words.size(); ++j) { + int value = i - j; + EXPECT_CALL(*table, GetTargetGivenSourceScore( + source_words[i], target_words[j])).WillRepeatedly(Return(value)); + } + } + + for (size_t i = 0; i < target_words.size(); ++i) { + int value = i * 3; + EXPECT_CALL(*table, GetTargetGivenSourceScore( + DataArray::NULL_WORD_STR, target_words[i])) + .WillRepeatedly(Return(value)); + } + + feature = make_shared(table); + } + + shared_ptr vocabulary; + shared_ptr phrase_builder; + shared_ptr table; + shared_ptr feature; +}; + +TEST_F(MaxLexTargetGivenSourceTest, TestGetName) { + EXPECT_EQ("MaxLexEgivenF", feature->GetName()); +} + +TEST_F(MaxLexTargetGivenSourceTest, TestScore) { + vector source_symbols = {0, 1, 2}; + Phrase source_phrase = phrase_builder->Build(source_symbols); + vector target_symbols = {3, 4, 5}; + Phrase target_phrase = phrase_builder->Build(target_symbols); + FeatureContext context(source_phrase, target_phrase, 0.3, 7, 19); + EXPECT_EQ(-log10(36), feature->Score(context)); +} + +} // namespace diff --git a/extractor/features/sample_source_count.cc b/extractor/features/sample_source_count.cc index c8124cfb..88b645b1 100644 --- a/extractor/features/sample_source_count.cc +++ b/extractor/features/sample_source_count.cc @@ -3,7 +3,7 @@ #include double SampleSourceCount::Score(const FeatureContext& context) const { - return log10(1 + context.sample_source_count); + return log10(1 + context.num_samples); } string SampleSourceCount::GetName() const { diff --git a/extractor/features/sample_source_count_test.cc b/extractor/features/sample_source_count_test.cc new file mode 100644 index 00000000..7d226104 --- /dev/null +++ b/extractor/features/sample_source_count_test.cc @@ -0,0 +1,36 @@ +#include + +#include +#include +#include + +#include "sample_source_count.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class SampleSourceCountTest : public Test { + protected: + virtual void SetUp() { + feature = make_shared(); + } + + shared_ptr feature; +}; + +TEST_F(SampleSourceCountTest, TestGetName) { + EXPECT_EQ("SampleCountF", feature->GetName()); +} + +TEST_F(SampleSourceCountTest, TestScore) { + Phrase phrase; + FeatureContext context(phrase, phrase, 0, 3, 1); + EXPECT_EQ(log10(2), feature->Score(context)); + + context = FeatureContext(phrase, phrase, 3.2, 3, 9); + EXPECT_EQ(1.0, feature->Score(context)); +} + +} // namespace diff --git a/extractor/features/target_given_source_coherent.cc b/extractor/features/target_given_source_coherent.cc index 748413c3..274b3364 100644 --- a/extractor/features/target_given_source_coherent.cc +++ b/extractor/features/target_given_source_coherent.cc @@ -3,10 +3,10 @@ #include double TargetGivenSourceCoherent::Score(const FeatureContext& context) const { - double prob = context.pair_count / context.sample_source_count; + double prob = (double) context.pair_count / context.num_samples; return prob > 0 ? -log10(prob) : MAX_SCORE; } string TargetGivenSourceCoherent::GetName() const { - return "EGivenFCoherent"; + return "EgivenFCoherent"; } diff --git a/extractor/features/target_given_source_coherent_test.cc b/extractor/features/target_given_source_coherent_test.cc new file mode 100644 index 00000000..c54c06c2 --- /dev/null +++ b/extractor/features/target_given_source_coherent_test.cc @@ -0,0 +1,35 @@ +#include + +#include +#include + +#include "target_given_source_coherent.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class TargetGivenSourceCoherentTest : public Test { + protected: + virtual void SetUp() { + feature = make_shared(); + } + + shared_ptr feature; +}; + +TEST_F(TargetGivenSourceCoherentTest, TestGetName) { + EXPECT_EQ("EgivenFCoherent", feature->GetName()); +} + +TEST_F(TargetGivenSourceCoherentTest, TestScore) { + Phrase phrase; + FeatureContext context(phrase, phrase, 0.3, 2, 20); + EXPECT_EQ(1.0, feature->Score(context)); + + context = FeatureContext(phrase, phrase, 1.9, 0, 1); + EXPECT_EQ(99.0, feature->Score(context)); +} + +} // namespace -- cgit v1.2.3