diff options
80 files changed, 3007 insertions, 700 deletions
diff --git a/extractor/Makefile.am b/extractor/Makefile.am index ded06239..c82fc1ae 100644 --- a/extractor/Makefile.am +++ b/extractor/Makefile.am @@ -1,8 +1,17 @@ bin_PROGRAMS = compile run_extractor noinst_PROGRAMS = \ + alignment_test \ binary_search_merger_test \ data_array_test \ + feature_count_source_target_test \ + feature_is_source_singleton_test \ + feature_is_source_target_singleton_test \ + feature_max_lex_source_given_target_test \ + feature_max_lex_target_given_source_test \ + feature_sample_source_count_test \ + feature_target_given_source_coherent_test \ + grammar_extractor_test \ intersector_test \ linear_merger_test \ matching_comparator_test \ @@ -10,27 +19,66 @@ noinst_PROGRAMS = \ matchings_finder_test \ phrase_test \ precomputation_test \ + rule_extractor_helper_test \ + rule_extractor_test \ + rule_factory_test \ sampler_test \ + scorer_test \ suffix_array_test \ + target_phrase_extractor_test \ + translation_table_test \ veb_test -TESTS = sampler_test -#TESTS = binary_search_merger_test \ -# data_array_test \ -# intersector_test \ -# linear_merger_test \ -# matching_comparator_test \ -# matching_test \ -# matchings_finder_test \ -# phrase_test \ -# precomputation_test \ -# suffix_array_test \ -# veb_test +TESTS = alignment_test \ + binary_search_merger_test \ + data_array_test \ + feature_count_source_target_test \ + feature_is_source_singleton_test \ + feature_is_source_target_singleton_test \ + feature_max_lex_source_given_target_test \ + feature_max_lex_target_given_source_test \ + feature_sample_source_count_test \ + feature_target_given_source_coherent_test \ + grammar_extractor_test \ + intersector_test \ + linear_merger_test \ + matching_comparator_test \ + matching_test \ + matchings_finder_test \ + phrase_test \ + precomputation_test \ + rule_extractor_helper_test \ + rule_extractor_test \ + rule_factory_test \ + sampler_test \ + scorer_test \ + suffix_array_test \ + target_phrase_extractor_test \ + translation_table_test \ + veb_test +alignment_test_SOURCES = alignment_test.cc +alignment_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a binary_search_merger_test_SOURCES = binary_search_merger_test.cc binary_search_merger_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a data_array_test_SOURCES = data_array_test.cc data_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +feature_count_source_target_test_SOURCES = features/count_source_target_test.cc +feature_count_source_target_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +feature_is_source_singleton_test_SOURCES = features/is_source_singleton_test.cc +feature_is_source_singleton_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +feature_is_source_target_singleton_test_SOURCES = features/is_source_target_singleton_test.cc +feature_is_source_target_singleton_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +feature_max_lex_source_given_target_test_SOURCES = features/max_lex_source_given_target_test.cc +feature_max_lex_source_given_target_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +feature_max_lex_target_given_source_test_SOURCES = features/max_lex_target_given_source_test.cc +feature_max_lex_target_given_source_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +feature_sample_source_count_test_SOURCES = features/sample_source_count_test.cc +feature_sample_source_count_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +feature_target_given_source_coherent_test_SOURCES = features/target_given_source_coherent_test.cc +feature_target_given_source_coherent_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +grammar_extractor_test_SOURCES = grammar_extractor_test.cc +grammar_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a intersector_test_SOURCES = intersector_test.cc intersector_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a linear_merger_test_SOURCES = linear_merger_test.cc @@ -45,10 +93,22 @@ phrase_test_SOURCES = phrase_test.cc phrase_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a precomputation_test_SOURCES = precomputation_test.cc precomputation_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a -suffix_array_test_SOURCES = suffix_array_test.cc -suffix_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +rule_extractor_helper_test_SOURCES = rule_extractor_helper_test.cc +rule_extractor_helper_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +rule_extractor_test_SOURCES = rule_extractor_test.cc +rule_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +rule_factory_test_SOURCES = rule_factory_test.cc +rule_factory_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a sampler_test_SOURCES = sampler_test.cc sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +scorer_test_SOURCES = scorer_test.cc +scorer_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +suffix_array_test_SOURCES = suffix_array_test.cc +suffix_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +target_phrase_extractor_test_SOURCES = target_phrase_extractor_test.cc +target_phrase_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +translation_table_test_SOURCES = translation_table_test.cc +translation_table_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a veb_test_SOURCES = veb_test.cc veb_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a @@ -93,10 +153,12 @@ libextractor_a_SOURCES = \ precomputation.cc \ rule.cc \ rule_extractor.cc \ + rule_extractor_helper.cc \ rule_factory.cc \ sampler.cc \ scorer.cc \ suffix_array.cc \ + target_phrase_extractor.cc \ translation_table.cc \ veb.cc \ veb_bitset.cc \ diff --git a/extractor/alignment.cc b/extractor/alignment.cc index 2fa0abac..ff39d484 100644 --- a/extractor/alignment.cc +++ b/extractor/alignment.cc @@ -31,7 +31,11 @@ Alignment::Alignment(const string& filename) { alignments.shrink_to_fit(); } -const vector<pair<int, int> >& Alignment::GetLinks(int sentence_index) const { +Alignment::Alignment() {} + +Alignment::~Alignment() {} + +vector<pair<int, int> > Alignment::GetLinks(int sentence_index) const { return alignments[sentence_index]; } diff --git a/extractor/alignment.h b/extractor/alignment.h index 290d6015..f7e79585 100644 --- a/extractor/alignment.h +++ b/extractor/alignment.h @@ -13,10 +13,15 @@ class Alignment { public: Alignment(const string& filename); - const vector<pair<int, int> >& GetLinks(int sentence_index) const; + virtual vector<pair<int, int> > GetLinks(int sentence_index) const; void WriteBinary(const fs::path& filepath); + virtual ~Alignment(); + + protected: + Alignment(); + private: vector<vector<pair<int, int> > > alignments; }; diff --git a/extractor/alignment_test.cc b/extractor/alignment_test.cc new file mode 100644 index 00000000..1bc51a56 --- /dev/null +++ b/extractor/alignment_test.cc @@ -0,0 +1,31 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> + +#include "alignment.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class AlignmentTest : public Test { + protected: + virtual void SetUp() { + alignment = make_shared<Alignment>("sample_alignment.txt"); + } + + shared_ptr<Alignment> alignment; +}; + +TEST_F(AlignmentTest, TestGetLinks) { + vector<pair<int, int> > expected_links = { + make_pair(0, 0), make_pair(1, 1), make_pair(2, 2) + }; + EXPECT_EQ(expected_links, alignment->GetLinks(0)); + expected_links = {make_pair(1, 0), make_pair(2, 1)}; + EXPECT_EQ(expected_links, alignment->GetLinks(1)); +} + +} // namespace diff --git a/extractor/binary_search_merger.cc b/extractor/binary_search_merger.cc index 43d2f734..c1b86a77 100644 --- a/extractor/binary_search_merger.cc +++ b/extractor/binary_search_merger.cc @@ -25,8 +25,10 @@ BinarySearchMerger::~BinarySearchMerger() {} void BinarySearchMerger::Merge( vector<int>& locations, const Phrase& phrase, const Phrase& suffix, - vector<int>::iterator prefix_start, vector<int>::iterator prefix_end, - vector<int>::iterator suffix_start, vector<int>::iterator suffix_end, + const vector<int>::iterator& prefix_start, + const vector<int>::iterator& prefix_end, + const vector<int>::iterator& suffix_start, + const vector<int>::iterator& suffix_end, int prefix_subpatterns, int suffix_subpatterns) const { if (IsIntersectionVoid(prefix_start, prefix_end, suffix_start, suffix_end, prefix_subpatterns, suffix_subpatterns, suffix)) { diff --git a/extractor/binary_search_merger.h b/extractor/binary_search_merger.h index ffa47c8e..c887e012 100644 --- a/extractor/binary_search_merger.h +++ b/extractor/binary_search_merger.h @@ -24,8 +24,10 @@ class BinarySearchMerger { virtual void Merge( vector<int>& locations, const Phrase& phrase, const Phrase& suffix, - vector<int>::iterator prefix_start, vector<int>::iterator prefix_end, - vector<int>::iterator suffix_start, vector<int>::iterator suffix_end, + const vector<int>::iterator& prefix_start, + const vector<int>::iterator& prefix_end, + const vector<int>::iterator& suffix_start, + const vector<int>::iterator& suffix_end, int prefix_subpatterns, int suffix_subpatterns) const; static double BAEZA_YATES_FACTOR; diff --git a/extractor/data_array.cc b/extractor/data_array.cc index 383b08a7..1097caf3 100644 --- a/extractor/data_array.cc +++ b/extractor/data_array.cc @@ -10,9 +10,9 @@ namespace fs = boost::filesystem; using namespace std; -int DataArray::END_OF_FILE = 0; +int DataArray::NULL_WORD = 0; int DataArray::END_OF_LINE = 1; -string DataArray::END_OF_FILE_STR = "__END_OF_FILE__"; +string DataArray::NULL_WORD_STR = "__NULL__"; string DataArray::END_OF_LINE_STR = "__END_OF_LINE__"; DataArray::DataArray() { @@ -47,9 +47,9 @@ DataArray::DataArray(const string& filename, const Side& side) { } void DataArray::InitializeDataArray() { - word2id[END_OF_FILE_STR] = END_OF_FILE; - id2word.push_back(END_OF_FILE_STR); - word2id[END_OF_LINE_STR] = END_OF_FILE; + word2id[NULL_WORD_STR] = NULL_WORD; + id2word.push_back(NULL_WORD_STR); + word2id[END_OF_LINE_STR] = END_OF_LINE; id2word.push_back(END_OF_LINE_STR); } @@ -87,6 +87,10 @@ int DataArray::AtIndex(int index) const { return data[index]; } +string DataArray::GetWordAtIndex(int index) const { + return id2word[data[index]]; +} + int DataArray::GetSize() const { return data.size(); } @@ -103,6 +107,11 @@ int DataArray::GetSentenceStart(int position) const { return sentence_start[position]; } +int DataArray::GetSentenceLength(int sentence_id) const { + // Ignore end of line markers. + return sentence_start[sentence_id + 1] - sentence_start[sentence_id] - 1; +} + int DataArray::GetSentenceId(int position) const { return sentence_id[position]; } diff --git a/extractor/data_array.h b/extractor/data_array.h index 19fbff88..7c120b3c 100644 --- a/extractor/data_array.h +++ b/extractor/data_array.h @@ -2,14 +2,13 @@ #define _DATA_ARRAY_H_ #include <string> -#include <tr1/unordered_map> +#include <unordered_map> #include <vector> #include <boost/filesystem.hpp> namespace fs = boost::filesystem; using namespace std; -using namespace tr1; enum Side { SOURCE, @@ -18,9 +17,9 @@ enum Side { class DataArray { public: - static int END_OF_FILE; + static int NULL_WORD; static int END_OF_LINE; - static string END_OF_FILE_STR; + static string NULL_WORD_STR; static string END_OF_LINE_STR; DataArray(const string& filename); @@ -33,6 +32,8 @@ class DataArray { virtual int AtIndex(int index) const; + virtual string GetWordAtIndex(int index) const; + virtual int GetSize() const; virtual int GetVocabularySize() const; @@ -43,9 +44,12 @@ class DataArray { virtual string GetWord(int word_id) const; - int GetNumSentences() const; + virtual int GetNumSentences() const; + + virtual int GetSentenceStart(int position) const; - int GetSentenceStart(int position) const; + //TODO(pauldb): Add unit tests. + virtual int GetSentenceLength(int sentence_id) const; virtual int GetSentenceId(int position) const; diff --git a/extractor/features/count_source_target_test.cc b/extractor/features/count_source_target_test.cc new file mode 100644 index 00000000..22633bb6 --- /dev/null +++ b/extractor/features/count_source_target_test.cc @@ -0,0 +1,32 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> + +#include "count_source_target.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class CountSourceTargetTest : public Test { + protected: + virtual void SetUp() { + feature = make_shared<CountSourceTarget>(); + } + + shared_ptr<CountSourceTarget> feature; +}; + +TEST_F(CountSourceTargetTest, TestGetName) { + EXPECT_EQ("CountEF", feature->GetName()); +} + +TEST_F(CountSourceTargetTest, TestScore) { + Phrase phrase; + FeatureContext context(phrase, phrase, 0.5, 9, 13); + EXPECT_EQ(1.0, feature->Score(context)); +} + +} // namespace diff --git a/extractor/features/feature.cc b/extractor/features/feature.cc index 7381c35a..876f5f8f 100644 --- a/extractor/features/feature.cc +++ b/extractor/features/feature.cc @@ -1,3 +1,5 @@ #include "feature.h" const double Feature::MAX_SCORE = 99.0; + +Feature::~Feature() {} diff --git a/extractor/features/feature.h b/extractor/features/feature.h index ad22d3e7..aca58401 100644 --- a/extractor/features/feature.h +++ b/extractor/features/feature.h @@ -10,14 +10,16 @@ using namespace std; struct FeatureContext { FeatureContext(const Phrase& source_phrase, const Phrase& target_phrase, - double sample_source_count, int pair_count) : + double source_phrase_count, int pair_count, int num_samples) : source_phrase(source_phrase), target_phrase(target_phrase), - sample_source_count(sample_source_count), pair_count(pair_count) {} + source_phrase_count(source_phrase_count), pair_count(pair_count), + num_samples(num_samples) {} Phrase source_phrase; Phrase target_phrase; - double sample_source_count; + double source_phrase_count; int pair_count; + int num_samples; }; class Feature { @@ -26,6 +28,8 @@ class Feature { virtual string GetName() const = 0; + virtual ~Feature(); + static const double MAX_SCORE; }; diff --git a/extractor/features/is_source_singleton.cc b/extractor/features/is_source_singleton.cc index 754df3bf..98d4e5fe 100644 --- a/extractor/features/is_source_singleton.cc +++ b/extractor/features/is_source_singleton.cc @@ -3,7 +3,7 @@ #include <cmath> double IsSourceSingleton::Score(const FeatureContext& context) const { - return context.sample_source_count == 1; + return context.source_phrase_count == 1; } string IsSourceSingleton::GetName() const { diff --git a/extractor/features/is_source_singleton_test.cc b/extractor/features/is_source_singleton_test.cc new file mode 100644 index 00000000..8c71e593 --- /dev/null +++ b/extractor/features/is_source_singleton_test.cc @@ -0,0 +1,35 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> + +#include "is_source_singleton.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class IsSourceSingletonTest : public Test { + protected: + virtual void SetUp() { + feature = make_shared<IsSourceSingleton>(); + } + + shared_ptr<IsSourceSingleton> feature; +}; + +TEST_F(IsSourceSingletonTest, TestGetName) { + EXPECT_EQ("IsSingletonF", feature->GetName()); +} + +TEST_F(IsSourceSingletonTest, TestScore) { + Phrase phrase; + FeatureContext context(phrase, phrase, 0.5, 3, 31); + EXPECT_EQ(0, feature->Score(context)); + + context = FeatureContext(phrase, phrase, 1, 3, 25); + EXPECT_EQ(1, feature->Score(context)); +} + +} // namespace diff --git a/extractor/features/is_source_target_singleton.cc b/extractor/features/is_source_target_singleton.cc index ec816509..31d36532 100644 --- a/extractor/features/is_source_target_singleton.cc +++ b/extractor/features/is_source_target_singleton.cc @@ -7,5 +7,5 @@ double IsSourceTargetSingleton::Score(const FeatureContext& context) const { } string IsSourceTargetSingleton::GetName() const { - return "IsSingletonEF"; + return "IsSingletonFE"; } diff --git a/extractor/features/is_source_target_singleton_test.cc b/extractor/features/is_source_target_singleton_test.cc new file mode 100644 index 00000000..a51f77c9 --- /dev/null +++ b/extractor/features/is_source_target_singleton_test.cc @@ -0,0 +1,35 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> + +#include "is_source_target_singleton.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class IsSourceTargetSingletonTest : public Test { + protected: + virtual void SetUp() { + feature = make_shared<IsSourceTargetSingleton>(); + } + + shared_ptr<IsSourceTargetSingleton> feature; +}; + +TEST_F(IsSourceTargetSingletonTest, TestGetName) { + EXPECT_EQ("IsSingletonFE", feature->GetName()); +} + +TEST_F(IsSourceTargetSingletonTest, TestScore) { + Phrase phrase; + FeatureContext context(phrase, phrase, 0.5, 3, 7); + EXPECT_EQ(0, feature->Score(context)); + + context = FeatureContext(phrase, phrase, 2.3, 1, 28); + EXPECT_EQ(1, feature->Score(context)); +} + +} // namespace diff --git a/extractor/features/max_lex_source_given_target.cc b/extractor/features/max_lex_source_given_target.cc index c4792d49..21f5c76a 100644 --- a/extractor/features/max_lex_source_given_target.cc +++ b/extractor/features/max_lex_source_given_target.cc @@ -2,6 +2,7 @@ #include <cmath> +#include "../data_array.h" #include "../translation_table.h" MaxLexSourceGivenTarget::MaxLexSourceGivenTarget( @@ -10,8 +11,8 @@ MaxLexSourceGivenTarget::MaxLexSourceGivenTarget( double MaxLexSourceGivenTarget::Score(const FeatureContext& context) const { vector<string> source_words = context.source_phrase.GetWords(); - // TODO(pauldb): Add NULL to target_words, after fixing translation table. vector<string> target_words = context.target_phrase.GetWords(); + target_words.push_back(DataArray::NULL_WORD_STR); double score = 0; for (string source_word: source_words) { @@ -26,5 +27,5 @@ double MaxLexSourceGivenTarget::Score(const FeatureContext& context) const { } string MaxLexSourceGivenTarget::GetName() const { - return "MaxLexFGivenE"; + return "MaxLexFgivenE"; } diff --git a/extractor/features/max_lex_source_given_target_test.cc b/extractor/features/max_lex_source_given_target_test.cc new file mode 100644 index 00000000..5fd41f8b --- /dev/null +++ b/extractor/features/max_lex_source_given_target_test.cc @@ -0,0 +1,74 @@ +#include <gtest/gtest.h> + +#include <cmath> +#include <memory> +#include <string> + +#include "../mocks/mock_translation_table.h" +#include "../mocks/mock_vocabulary.h" +#include "../data_array.h" +#include "../phrase_builder.h" +#include "max_lex_source_given_target.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class MaxLexSourceGivenTargetTest : public Test { + protected: + virtual void SetUp() { + vector<string> source_words = {"f1", "f2", "f3"}; + vector<string> target_words = {"e1", "e2", "e3"}; + + vocabulary = make_shared<MockVocabulary>(); + for (size_t i = 0; i < source_words.size(); ++i) { + EXPECT_CALL(*vocabulary, GetTerminalValue(i)) + .WillRepeatedly(Return(source_words[i])); + } + for (size_t i = 0; i < target_words.size(); ++i) { + EXPECT_CALL(*vocabulary, GetTerminalValue(i + source_words.size())) + .WillRepeatedly(Return(target_words[i])); + } + + phrase_builder = make_shared<PhraseBuilder>(vocabulary); + + table = make_shared<MockTranslationTable>(); + for (size_t i = 0; i < source_words.size(); ++i) { + for (size_t j = 0; j < target_words.size(); ++j) { + int value = i - j; + EXPECT_CALL(*table, GetSourceGivenTargetScore( + source_words[i], target_words[j])).WillRepeatedly(Return(value)); + } + } + + for (size_t i = 0; i < source_words.size(); ++i) { + int value = i * 3; + EXPECT_CALL(*table, GetSourceGivenTargetScore( + source_words[i], DataArray::NULL_WORD_STR)) + .WillRepeatedly(Return(value)); + } + + feature = make_shared<MaxLexSourceGivenTarget>(table); + } + + shared_ptr<MockVocabulary> vocabulary; + shared_ptr<PhraseBuilder> phrase_builder; + shared_ptr<MockTranslationTable> table; + shared_ptr<MaxLexSourceGivenTarget> feature; +}; + +TEST_F(MaxLexSourceGivenTargetTest, TestGetName) { + EXPECT_EQ("MaxLexFgivenE", feature->GetName()); +} + +TEST_F(MaxLexSourceGivenTargetTest, TestScore) { + vector<int> source_symbols = {0, 1, 2}; + Phrase source_phrase = phrase_builder->Build(source_symbols); + vector<int> target_symbols = {3, 4, 5}; + Phrase target_phrase = phrase_builder->Build(target_symbols); + FeatureContext context(source_phrase, target_phrase, 0.3, 7, 11); + EXPECT_EQ(99 - log10(18), feature->Score(context)); +} + +} // namespace diff --git a/extractor/features/max_lex_target_given_source.cc b/extractor/features/max_lex_target_given_source.cc index d82182fe..f2bc2474 100644 --- a/extractor/features/max_lex_target_given_source.cc +++ b/extractor/features/max_lex_target_given_source.cc @@ -2,6 +2,7 @@ #include <cmath> +#include "../data_array.h" #include "../translation_table.h" MaxLexTargetGivenSource::MaxLexTargetGivenSource( @@ -9,8 +10,8 @@ MaxLexTargetGivenSource::MaxLexTargetGivenSource( table(table) {} double MaxLexTargetGivenSource::Score(const FeatureContext& context) const { - // TODO(pauldb): Add NULL to source_words, after fixing translation table. vector<string> source_words = context.source_phrase.GetWords(); + source_words.push_back(DataArray::NULL_WORD_STR); vector<string> target_words = context.target_phrase.GetWords(); double score = 0; @@ -26,5 +27,5 @@ double MaxLexTargetGivenSource::Score(const FeatureContext& context) const { } string MaxLexTargetGivenSource::GetName() const { - return "MaxLexEGivenF"; + return "MaxLexEgivenF"; } diff --git a/extractor/features/max_lex_target_given_source_test.cc b/extractor/features/max_lex_target_given_source_test.cc new file mode 100644 index 00000000..c8701bf7 --- /dev/null +++ b/extractor/features/max_lex_target_given_source_test.cc @@ -0,0 +1,74 @@ +#include <gtest/gtest.h> + +#include <cmath> +#include <memory> +#include <string> + +#include "../mocks/mock_translation_table.h" +#include "../mocks/mock_vocabulary.h" +#include "../data_array.h" +#include "../phrase_builder.h" +#include "max_lex_target_given_source.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class MaxLexTargetGivenSourceTest : public Test { + protected: + virtual void SetUp() { + vector<string> source_words = {"f1", "f2", "f3"}; + vector<string> target_words = {"e1", "e2", "e3"}; + + vocabulary = make_shared<MockVocabulary>(); + for (size_t i = 0; i < source_words.size(); ++i) { + EXPECT_CALL(*vocabulary, GetTerminalValue(i)) + .WillRepeatedly(Return(source_words[i])); + } + for (size_t i = 0; i < target_words.size(); ++i) { + EXPECT_CALL(*vocabulary, GetTerminalValue(i + source_words.size())) + .WillRepeatedly(Return(target_words[i])); + } + + phrase_builder = make_shared<PhraseBuilder>(vocabulary); + + table = make_shared<MockTranslationTable>(); + for (size_t i = 0; i < source_words.size(); ++i) { + for (size_t j = 0; j < target_words.size(); ++j) { + int value = i - j; + EXPECT_CALL(*table, GetTargetGivenSourceScore( + source_words[i], target_words[j])).WillRepeatedly(Return(value)); + } + } + + for (size_t i = 0; i < target_words.size(); ++i) { + int value = i * 3; + EXPECT_CALL(*table, GetTargetGivenSourceScore( + DataArray::NULL_WORD_STR, target_words[i])) + .WillRepeatedly(Return(value)); + } + + feature = make_shared<MaxLexTargetGivenSource>(table); + } + + shared_ptr<MockVocabulary> vocabulary; + shared_ptr<PhraseBuilder> phrase_builder; + shared_ptr<MockTranslationTable> table; + shared_ptr<MaxLexTargetGivenSource> feature; +}; + +TEST_F(MaxLexTargetGivenSourceTest, TestGetName) { + EXPECT_EQ("MaxLexEgivenF", feature->GetName()); +} + +TEST_F(MaxLexTargetGivenSourceTest, TestScore) { + vector<int> source_symbols = {0, 1, 2}; + Phrase source_phrase = phrase_builder->Build(source_symbols); + vector<int> target_symbols = {3, 4, 5}; + Phrase target_phrase = phrase_builder->Build(target_symbols); + FeatureContext context(source_phrase, target_phrase, 0.3, 7, 19); + EXPECT_EQ(-log10(36), feature->Score(context)); +} + +} // namespace diff --git a/extractor/features/sample_source_count.cc b/extractor/features/sample_source_count.cc index c8124cfb..88b645b1 100644 --- a/extractor/features/sample_source_count.cc +++ b/extractor/features/sample_source_count.cc @@ -3,7 +3,7 @@ #include <cmath> double SampleSourceCount::Score(const FeatureContext& context) const { - return log10(1 + context.sample_source_count); + return log10(1 + context.num_samples); } string SampleSourceCount::GetName() const { diff --git a/extractor/features/sample_source_count_test.cc b/extractor/features/sample_source_count_test.cc new file mode 100644 index 00000000..7d226104 --- /dev/null +++ b/extractor/features/sample_source_count_test.cc @@ -0,0 +1,36 @@ +#include <gtest/gtest.h> + +#include <cmath> +#include <memory> +#include <string> + +#include "sample_source_count.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class SampleSourceCountTest : public Test { + protected: + virtual void SetUp() { + feature = make_shared<SampleSourceCount>(); + } + + shared_ptr<SampleSourceCount> feature; +}; + +TEST_F(SampleSourceCountTest, TestGetName) { + EXPECT_EQ("SampleCountF", feature->GetName()); +} + +TEST_F(SampleSourceCountTest, TestScore) { + Phrase phrase; + FeatureContext context(phrase, phrase, 0, 3, 1); + EXPECT_EQ(log10(2), feature->Score(context)); + + context = FeatureContext(phrase, phrase, 3.2, 3, 9); + EXPECT_EQ(1.0, feature->Score(context)); +} + +} // namespace diff --git a/extractor/features/target_given_source_coherent.cc b/extractor/features/target_given_source_coherent.cc index 748413c3..274b3364 100644 --- a/extractor/features/target_given_source_coherent.cc +++ b/extractor/features/target_given_source_coherent.cc @@ -3,10 +3,10 @@ #include <cmath> double TargetGivenSourceCoherent::Score(const FeatureContext& context) const { - double prob = context.pair_count / context.sample_source_count; + double prob = (double) context.pair_count / context.num_samples; return prob > 0 ? -log10(prob) : MAX_SCORE; } string TargetGivenSourceCoherent::GetName() const { - return "EGivenFCoherent"; + return "EgivenFCoherent"; } diff --git a/extractor/features/target_given_source_coherent_test.cc b/extractor/features/target_given_source_coherent_test.cc new file mode 100644 index 00000000..c54c06c2 --- /dev/null +++ b/extractor/features/target_given_source_coherent_test.cc @@ -0,0 +1,35 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> + +#include "target_given_source_coherent.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class TargetGivenSourceCoherentTest : public Test { + protected: + virtual void SetUp() { + feature = make_shared<TargetGivenSourceCoherent>(); + } + + shared_ptr<TargetGivenSourceCoherent> feature; +}; + +TEST_F(TargetGivenSourceCoherentTest, TestGetName) { + EXPECT_EQ("EgivenFCoherent", feature->GetName()); +} + +TEST_F(TargetGivenSourceCoherentTest, TestScore) { + Phrase phrase; + FeatureContext context(phrase, phrase, 0.3, 2, 20); + EXPECT_EQ(1.0, feature->Score(context)); + + context = FeatureContext(phrase, phrase, 1.9, 0, 1); + EXPECT_EQ(99.0, feature->Score(context)); +} + +} // namespace diff --git a/extractor/grammar.cc b/extractor/grammar.cc index 79a0541d..8124a804 100644 --- a/extractor/grammar.cc +++ b/extractor/grammar.cc @@ -1,17 +1,32 @@ #include "grammar.h" +#include <iomanip> + #include "rule.h" +using namespace std; + Grammar::Grammar(const vector<Rule>& rules, const vector<string>& feature_names) : rules(rules), feature_names(feature_names) {} +vector<Rule> Grammar::GetRules() const { + return rules; +} + +vector<string> Grammar::GetFeatureNames() const { + return feature_names; +} + ostream& operator<<(ostream& os, const Grammar& grammar) { - for (Rule rule: grammar.rules) { + vector<Rule> rules = grammar.GetRules(); + vector<string> feature_names = grammar.GetFeatureNames(); + os << setprecision(12); + for (Rule rule: rules) { os << "[X] ||| " << rule.source_phrase << " ||| " << rule.target_phrase << " |||"; for (size_t i = 0; i < rule.scores.size(); ++i) { - os << " " << grammar.feature_names[i] << "=" << rule.scores[i]; + os << " " << feature_names[i] << "=" << rule.scores[i]; } os << " |||"; for (auto link: rule.alignment) { diff --git a/extractor/grammar.h b/extractor/grammar.h index db15fa7e..889cc2f3 100644 --- a/extractor/grammar.h +++ b/extractor/grammar.h @@ -13,6 +13,10 @@ class Grammar { public: Grammar(const vector<Rule>& rules, const vector<string>& feature_names); + vector<Rule> GetRules() const; + + vector<string> GetFeatureNames() const; + friend ostream& operator<<(ostream& os, const Grammar& grammar); private: diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc index 15268165..2f008026 100644 --- a/extractor/grammar_extractor.cc +++ b/extractor/grammar_extractor.cc @@ -10,19 +10,6 @@ using namespace std; -vector<string> Tokenize(const string& sentence) { - vector<string> result; - result.push_back("<s>"); - - istringstream buffer(sentence); - copy(istream_iterator<string>(buffer), - istream_iterator<string>(), - back_inserter(result)); - - result.push_back("</s>"); - return result; -} - GrammarExtractor::GrammarExtractor( shared_ptr<SuffixArray> source_suffix_array, shared_ptr<DataArray> target_data_array, @@ -31,15 +18,35 @@ GrammarExtractor::GrammarExtractor( int max_nonterminals, int max_rule_symbols, int max_samples, bool use_baeza_yates, bool require_tight_phrases) : vocabulary(make_shared<Vocabulary>()), - rule_factory(source_suffix_array, target_data_array, alignment, - vocabulary, precomputation, scorer, min_gap_size, max_rule_span, - max_nonterminals, max_rule_symbols, max_samples, use_baeza_yates, - require_tight_phrases) {} + rule_factory(make_shared<HieroCachingRuleFactory>( + source_suffix_array, target_data_array, alignment, vocabulary, + precomputation, scorer, min_gap_size, max_rule_span, max_nonterminals, + max_rule_symbols, max_samples, use_baeza_yates, + require_tight_phrases)) {} + +GrammarExtractor::GrammarExtractor( + shared_ptr<Vocabulary> vocabulary, + shared_ptr<HieroCachingRuleFactory> rule_factory) : + vocabulary(vocabulary), + rule_factory(rule_factory) {} Grammar GrammarExtractor::GetGrammar(const string& sentence) { - vector<string> words = Tokenize(sentence); + vector<string> words = TokenizeSentence(sentence); vector<int> word_ids = AnnotateWords(words); - return rule_factory.GetGrammar(word_ids); + return rule_factory->GetGrammar(word_ids); +} + +vector<string> GrammarExtractor::TokenizeSentence(const string& sentence) { + vector<string> result; + result.push_back("<s>"); + + istringstream buffer(sentence); + copy(istream_iterator<string>(buffer), + istream_iterator<string>(), + back_inserter(result)); + + result.push_back("</s>"); + return result; } vector<int> GrammarExtractor::AnnotateWords(const vector<string>& words) { diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h index 243f33cf..5f87faa7 100644 --- a/extractor/grammar_extractor.h +++ b/extractor/grammar_extractor.h @@ -32,13 +32,19 @@ class GrammarExtractor { bool use_baeza_yates, bool require_tight_phrases); + // For testing only. + GrammarExtractor(shared_ptr<Vocabulary> vocabulary, + shared_ptr<HieroCachingRuleFactory> rule_factory); + Grammar GetGrammar(const string& sentence); private: + vector<string> TokenizeSentence(const string& sentence); + vector<int> AnnotateWords(const vector<string>& words); shared_ptr<Vocabulary> vocabulary; - HieroCachingRuleFactory rule_factory; + shared_ptr<HieroCachingRuleFactory> rule_factory; }; #endif diff --git a/extractor/grammar_extractor_test.cc b/extractor/grammar_extractor_test.cc new file mode 100644 index 00000000..d4ed7d4f --- /dev/null +++ b/extractor/grammar_extractor_test.cc @@ -0,0 +1,49 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> +#include <vector> + +#include "grammar.h" +#include "grammar_extractor.h" +#include "mocks/mock_rule_factory.h" +#include "mocks/mock_vocabulary.h" +#include "rule.h" + +using namespace std; +using namespace ::testing; + +namespace { + +TEST(GrammarExtractorTest, TestAnnotatingWords) { + shared_ptr<MockVocabulary> vocabulary = make_shared<MockVocabulary>(); + EXPECT_CALL(*vocabulary, GetTerminalIndex("<s>")) + .WillRepeatedly(Return(0)); + EXPECT_CALL(*vocabulary, GetTerminalIndex("Anna")) + .WillRepeatedly(Return(1)); + EXPECT_CALL(*vocabulary, GetTerminalIndex("has")) + .WillRepeatedly(Return(2)); + EXPECT_CALL(*vocabulary, GetTerminalIndex("many")) + .WillRepeatedly(Return(3)); + EXPECT_CALL(*vocabulary, GetTerminalIndex("apples")) + .WillRepeatedly(Return(4)); + EXPECT_CALL(*vocabulary, GetTerminalIndex(".")) + .WillRepeatedly(Return(5)); + EXPECT_CALL(*vocabulary, GetTerminalIndex("</s>")) + .WillRepeatedly(Return(6)); + + shared_ptr<MockHieroCachingRuleFactory> factory = + make_shared<MockHieroCachingRuleFactory>(); + vector<int> word_ids = {0, 1, 2, 3, 3, 4, 5, 6}; + vector<Rule> rules; + vector<string> feature_names; + Grammar grammar(rules, feature_names); + EXPECT_CALL(*factory, GetGrammar(word_ids)) + .WillOnce(Return(grammar)); + + GrammarExtractor extractor(vocabulary, factory); + string sentence = "Anna has many many apples ."; + extractor.GetGrammar(sentence); +} + +} // namespace diff --git a/extractor/intersector.cc b/extractor/intersector.cc index b53479af..cf42f630 100644 --- a/extractor/intersector.cc +++ b/extractor/intersector.cc @@ -1,5 +1,7 @@ #include "intersector.h" +#include <chrono> + #include "data_array.h" #include "matching_comparator.h" #include "phrase.h" @@ -9,6 +11,10 @@ #include "veb.h" #include "vocabulary.h" +using namespace std::chrono; + +typedef high_resolution_clock Clock; + Intersector::Intersector(shared_ptr<Vocabulary> vocabulary, shared_ptr<Precomputation> precomputation, shared_ptr<SuffixArray> suffix_array, @@ -38,12 +44,22 @@ Intersector::Intersector(shared_ptr<Vocabulary> vocabulary, ConvertIndexes(precomputation, suffix_array->GetData()); } +Intersector::Intersector() {} + +Intersector::~Intersector() {} + void Intersector::ConvertIndexes(shared_ptr<Precomputation> precomputation, shared_ptr<DataArray> data_array) { const Index& precomputed_index = precomputation->GetInvertedIndex(); for (pair<vector<int>, vector<int> > entry: precomputed_index) { vector<int> phrase = ConvertPhrase(entry.first, data_array); inverted_index[phrase] = entry.second; + + phrase.push_back(vocabulary->GetNonterminalIndex(1)); + inverted_index[phrase] = entry.second; + phrase.pop_back(); + phrase.insert(phrase.begin(), vocabulary->GetNonterminalIndex(1)); + inverted_index[phrase] = entry.second; } const Index& precomputed_collocations = precomputation->GetCollocations(); @@ -76,6 +92,9 @@ PhraseLocation Intersector::Intersect( const Phrase& prefix, PhraseLocation& prefix_location, const Phrase& suffix, PhraseLocation& suffix_location, const Phrase& phrase) { + if (linear_merge_time == 0) { + linear_merger->linear_merge_time = 0; + } vector<int> symbols = phrase.Get(); // We should never attempt to do an intersect query for a pattern starting or @@ -95,17 +114,23 @@ PhraseLocation Intersector::Intersect( shared_ptr<vector<int> > prefix_matchings = prefix_location.matchings; shared_ptr<vector<int> > suffix_matchings = suffix_location.matchings; int prefix_subpatterns = prefix_location.num_subpatterns; - int suffix_subpatterns = prefix_location.num_subpatterns; + int suffix_subpatterns = suffix_location.num_subpatterns; if (use_baeza_yates) { + double prev_linear_merge_time = linear_merger->linear_merge_time; + Clock::time_point start = Clock::now(); binary_search_merger->Merge(locations, phrase, suffix, prefix_matchings->begin(), prefix_matchings->end(), suffix_matchings->begin(), suffix_matchings->end(), prefix_subpatterns, suffix_subpatterns); + Clock::time_point stop = Clock::now(); + binary_merge_time += duration_cast<milliseconds>(stop - start).count() - + (linear_merger->linear_merge_time - prev_linear_merge_time); } else { linear_merger->Merge(locations, phrase, suffix, prefix_matchings->begin(), prefix_matchings->end(), suffix_matchings->begin(), suffix_matchings->end(), prefix_subpatterns, suffix_subpatterns); } + linear_merge_time = linear_merger->linear_merge_time; return PhraseLocation(locations, phrase.Arity() + 1); } @@ -116,6 +141,8 @@ void Intersector::ExtendPhraseLocation( return; } + Clock::time_point sort_start = Clock::now(); + phrase_location.num_subpatterns = 1; phrase_location.sa_low = phrase_location.sa_high = 0; @@ -140,4 +167,6 @@ void Intersector::ExtendPhraseLocation( } phrase_location.matchings = make_shared<vector<int> >(matchings); + Clock::time_point sort_stop = Clock::now(); + sort_time += duration_cast<milliseconds>(sort_stop - sort_start).count(); } diff --git a/extractor/intersector.h b/extractor/intersector.h index f023cc96..8b159f17 100644 --- a/extractor/intersector.h +++ b/extractor/intersector.h @@ -2,7 +2,7 @@ #define _INTERSECTOR_H_ #include <memory> -#include <tr1/unordered_map> +#include <unordered_map> #include <vector> #include <boost/functional/hash.hpp> @@ -11,7 +11,6 @@ #include "linear_merger.h" using namespace std; -using namespace tr1; typedef boost::hash<vector<int> > VectorHash; typedef unordered_map<vector<int>, vector<int>, VectorHash> Index; @@ -42,11 +41,16 @@ class Intersector { shared_ptr<BinarySearchMerger> binary_search_merger, bool use_baeza_yates); - PhraseLocation Intersect( + virtual ~Intersector(); + + virtual PhraseLocation Intersect( const Phrase& prefix, PhraseLocation& prefix_location, const Phrase& suffix, PhraseLocation& suffix_location, const Phrase& phrase); + protected: + Intersector(); + private: void ConvertIndexes(shared_ptr<Precomputation> precomputation, shared_ptr<DataArray> data_array); @@ -64,6 +68,12 @@ class Intersector { Index inverted_index; Index collocations; bool use_baeza_yates; + + // TODO(pauldb): Don't forget to remove these. + public: + double sort_time; + double linear_merge_time; + double binary_merge_time; }; #endif diff --git a/extractor/intersector_test.cc b/extractor/intersector_test.cc index a3756902..ec318362 100644 --- a/extractor/intersector_test.cc +++ b/extractor/intersector_test.cc @@ -34,7 +34,7 @@ class IntersectorTest : public Test { .WillRepeatedly(Return(words[i])); } - vector<int> suffixes = {0, 1, 3, 5, 2, 4, 6}; + vector<int> suffixes = {6, 0, 5, 3, 1, 4, 2}; suffix_array = make_shared<MockSuffixArray>(); EXPECT_CALL(*suffix_array, GetData()) .WillRepeatedly(Return(data_array)); @@ -103,7 +103,7 @@ TEST_F(IntersectorTest, TestLinearMergeaXb) { Phrase suffix = phrase_builder->Build(suffix_symbols); vector<int> symbols = {3, -1, 4}; Phrase phrase = phrase_builder->Build(symbols); - PhraseLocation prefix_locs(1, 4), suffix_locs(4, 6); + PhraseLocation prefix_locs(2, 5), suffix_locs(5, 7); vector<int> ex_prefix_locs = {1, 3, 5}; PhraseLocation extended_prefix_locs(ex_prefix_locs, 1); @@ -135,7 +135,7 @@ TEST_F(IntersectorTest, TestBinarySearchMergeaXb) { Phrase suffix = phrase_builder->Build(suffix_symbols); vector<int> symbols = {3, -1, 4}; Phrase phrase = phrase_builder->Build(symbols); - PhraseLocation prefix_locs(1, 4), suffix_locs(4, 6); + PhraseLocation prefix_locs(2, 5), suffix_locs(5, 7); vector<int> ex_prefix_locs = {1, 3, 5}; PhraseLocation extended_prefix_locs(ex_prefix_locs, 1); diff --git a/extractor/linear_merger.cc b/extractor/linear_merger.cc index 666f8d87..7233f945 100644 --- a/extractor/linear_merger.cc +++ b/extractor/linear_merger.cc @@ -1,5 +1,6 @@ #include "linear_merger.h" +#include <chrono> #include <cmath> #include "data_array.h" @@ -9,6 +10,10 @@ #include "phrase_location.h" #include "vocabulary.h" +using namespace std::chrono; + +typedef high_resolution_clock Clock; + LinearMerger::LinearMerger(shared_ptr<Vocabulary> vocabulary, shared_ptr<DataArray> data_array, shared_ptr<MatchingComparator> comparator) : @@ -22,7 +27,9 @@ void LinearMerger::Merge( vector<int>& locations, const Phrase& phrase, const Phrase& suffix, vector<int>::iterator prefix_start, vector<int>::iterator prefix_end, vector<int>::iterator suffix_start, vector<int>::iterator suffix_end, - int prefix_subpatterns, int suffix_subpatterns) const { + int prefix_subpatterns, int suffix_subpatterns) { + Clock::time_point start = Clock::now(); + int last_chunk_len = suffix.GetChunkLen(suffix.Arity()); bool offset = !vocabulary->IsTerminal(suffix.GetSymbol(0)); @@ -62,4 +69,7 @@ void LinearMerger::Merge( prefix_start += prefix_subpatterns; } } + + Clock::time_point stop = Clock::now(); + linear_merge_time += duration_cast<milliseconds>(stop - start).count(); } diff --git a/extractor/linear_merger.h b/extractor/linear_merger.h index 6a69b804..25692b15 100644 --- a/extractor/linear_merger.h +++ b/extractor/linear_merger.h @@ -24,7 +24,7 @@ class LinearMerger { vector<int>& locations, const Phrase& phrase, const Phrase& suffix, vector<int>::iterator prefix_start, vector<int>::iterator prefix_end, vector<int>::iterator suffix_start, vector<int>::iterator suffix_end, - int prefix_subpatterns, int suffix_subpatterns) const; + int prefix_subpatterns, int suffix_subpatterns); protected: LinearMerger(); @@ -33,6 +33,10 @@ class LinearMerger { shared_ptr<Vocabulary> vocabulary; shared_ptr<DataArray> data_array; shared_ptr<MatchingComparator> comparator; + + // TODO(pauldb): Remove this eventually. + public: + double linear_merge_time; }; #endif diff --git a/extractor/matchings_finder.cc b/extractor/matchings_finder.cc index ba4edab1..eaf493b2 100644 --- a/extractor/matchings_finder.cc +++ b/extractor/matchings_finder.cc @@ -6,6 +6,10 @@ MatchingsFinder::MatchingsFinder(shared_ptr<SuffixArray> suffix_array) : suffix_array(suffix_array) {} +MatchingsFinder::MatchingsFinder() {} + +MatchingsFinder::~MatchingsFinder() {} + PhraseLocation MatchingsFinder::Find(PhraseLocation& location, const string& word, int offset) { if (location.sa_low == -1 && location.sa_high == -1) { diff --git a/extractor/matchings_finder.h b/extractor/matchings_finder.h index 0458a4d8..ed04d8b8 100644 --- a/extractor/matchings_finder.h +++ b/extractor/matchings_finder.h @@ -13,7 +13,13 @@ class MatchingsFinder { public: MatchingsFinder(shared_ptr<SuffixArray> suffix_array); - PhraseLocation Find(PhraseLocation& location, const string& word, int offset); + virtual ~MatchingsFinder(); + + virtual PhraseLocation Find(PhraseLocation& location, const string& word, + int offset); + + protected: + MatchingsFinder(); private: shared_ptr<SuffixArray> suffix_array; diff --git a/extractor/matchings_trie.cc b/extractor/matchings_trie.cc index 851d4596..921ec582 100644 --- a/extractor/matchings_trie.cc +++ b/extractor/matchings_trie.cc @@ -1,11 +1,19 @@ #include "matchings_trie.h" void MatchingsTrie::Reset() { - // TODO(pauldb): This is probably memory leaking because of the suffix links. - // Check if it's true and free the memory properly. - root.reset(new TrieNode()); + ResetTree(root); + root = make_shared<TrieNode>(); } shared_ptr<TrieNode> MatchingsTrie::GetRoot() const { return root; } + +void MatchingsTrie::ResetTree(shared_ptr<TrieNode> root) { + if (root != NULL) { + for (auto child: root->children) { + ResetTree(child.second); + } + root.reset(); + } +} diff --git a/extractor/matchings_trie.h b/extractor/matchings_trie.h index f935d1a9..6e72b2db 100644 --- a/extractor/matchings_trie.h +++ b/extractor/matchings_trie.h @@ -2,13 +2,12 @@ #define _MATCHINGS_TRIE_ #include <memory> -#include <tr1/unordered_map> +#include <unordered_map> #include "phrase.h" #include "phrase_location.h" using namespace std; -using namespace tr1; struct TrieNode { TrieNode(shared_ptr<TrieNode> suffix_link = shared_ptr<TrieNode>(), @@ -40,6 +39,8 @@ class MatchingsTrie { shared_ptr<TrieNode> GetRoot() const; private: + void ResetTree(shared_ptr<TrieNode> root); + shared_ptr<TrieNode> root; }; diff --git a/extractor/mocks/mock_alignment.h b/extractor/mocks/mock_alignment.h new file mode 100644 index 00000000..4a5077ad --- /dev/null +++ b/extractor/mocks/mock_alignment.h @@ -0,0 +1,10 @@ +#include <gmock/gmock.h> + +#include "../alignment.h" + +typedef vector<pair<int, int> > SentenceLinks; + +class MockAlignment : public Alignment { + public: + MOCK_CONST_METHOD1(GetLinks, SentenceLinks(int sentence_id)); +}; diff --git a/extractor/mocks/mock_binary_search_merger.h b/extractor/mocks/mock_binary_search_merger.h index e1375ee3..e23386f0 100644 --- a/extractor/mocks/mock_binary_search_merger.h +++ b/extractor/mocks/mock_binary_search_merger.h @@ -10,6 +10,6 @@ using namespace std; class MockBinarySearchMerger: public BinarySearchMerger { public: MOCK_CONST_METHOD9(Merge, void(vector<int>&, const Phrase&, const Phrase&, - vector<int>::iterator, vector<int>::iterator, vector<int>::iterator, - vector<int>::iterator, int, int)); + const vector<int>::iterator&, const vector<int>::iterator&, + const vector<int>::iterator&, const vector<int>::iterator&, int, int)); }; diff --git a/extractor/mocks/mock_data_array.h b/extractor/mocks/mock_data_array.h index 54497cf5..004e8906 100644 --- a/extractor/mocks/mock_data_array.h +++ b/extractor/mocks/mock_data_array.h @@ -6,10 +6,14 @@ class MockDataArray : public DataArray { public: MOCK_CONST_METHOD0(GetData, const vector<int>&()); MOCK_CONST_METHOD1(AtIndex, int(int index)); + MOCK_CONST_METHOD1(GetWordAtIndex, string(int index)); MOCK_CONST_METHOD0(GetSize, int()); MOCK_CONST_METHOD0(GetVocabularySize, int()); MOCK_CONST_METHOD1(HasWord, bool(const string& word)); MOCK_CONST_METHOD1(GetWordId, int(const string& word)); MOCK_CONST_METHOD1(GetWord, string(int word_id)); + MOCK_CONST_METHOD1(GetSentenceLength, int(int sentence_id)); + MOCK_CONST_METHOD0(GetNumSentences, int()); + MOCK_CONST_METHOD1(GetSentenceStart, int(int sentence_id)); MOCK_CONST_METHOD1(GetSentenceId, int(int position)); }; diff --git a/extractor/mocks/mock_feature.h b/extractor/mocks/mock_feature.h new file mode 100644 index 00000000..d2137629 --- /dev/null +++ b/extractor/mocks/mock_feature.h @@ -0,0 +1,9 @@ +#include <gmock/gmock.h> + +#include "../features/feature.h" + +class MockFeature : public Feature { + public: + MOCK_CONST_METHOD1(Score, double(const FeatureContext& context)); + MOCK_CONST_METHOD0(GetName, string()); +}; diff --git a/extractor/mocks/mock_intersector.h b/extractor/mocks/mock_intersector.h new file mode 100644 index 00000000..372fa7ea --- /dev/null +++ b/extractor/mocks/mock_intersector.h @@ -0,0 +1,11 @@ +#include <gmock/gmock.h> + +#include "../intersector.h" +#include "../phrase.h" +#include "../phrase_location.h" + +class MockIntersector : public Intersector { + public: + MOCK_METHOD5(Intersect, PhraseLocation(const Phrase&, PhraseLocation&, + const Phrase&, PhraseLocation&, const Phrase&)); +}; diff --git a/extractor/mocks/mock_linear_merger.h b/extractor/mocks/mock_linear_merger.h index 82243428..522c1f31 100644 --- a/extractor/mocks/mock_linear_merger.h +++ b/extractor/mocks/mock_linear_merger.h @@ -9,7 +9,7 @@ using namespace std; class MockLinearMerger: public LinearMerger { public: - MOCK_CONST_METHOD9(Merge, void(vector<int>&, const Phrase&, const Phrase&, + MOCK_METHOD9(Merge, void(vector<int>&, const Phrase&, const Phrase&, vector<int>::iterator, vector<int>::iterator, vector<int>::iterator, vector<int>::iterator, int, int)); }; diff --git a/extractor/mocks/mock_matchings_finder.h b/extractor/mocks/mock_matchings_finder.h new file mode 100644 index 00000000..3e80d266 --- /dev/null +++ b/extractor/mocks/mock_matchings_finder.h @@ -0,0 +1,9 @@ +#include <gmock/gmock.h> + +#include "../matchings_finder.h" +#include "../phrase_location.h" + +class MockMatchingsFinder : public MatchingsFinder { + public: + MOCK_METHOD3(Find, PhraseLocation(PhraseLocation&, const string&, int)); +}; diff --git a/extractor/mocks/mock_rule_extractor.h b/extractor/mocks/mock_rule_extractor.h new file mode 100644 index 00000000..f18e009a --- /dev/null +++ b/extractor/mocks/mock_rule_extractor.h @@ -0,0 +1,12 @@ +#include <gmock/gmock.h> + +#include "../phrase.h" +#include "../phrase_builder.h" +#include "../rule.h" +#include "../rule_extractor.h" + +class MockRuleExtractor : public RuleExtractor { + public: + MOCK_CONST_METHOD2(ExtractRules, vector<Rule>(const Phrase&, + const PhraseLocation&)); +}; diff --git a/extractor/mocks/mock_rule_extractor_helper.h b/extractor/mocks/mock_rule_extractor_helper.h new file mode 100644 index 00000000..63ff1048 --- /dev/null +++ b/extractor/mocks/mock_rule_extractor_helper.h @@ -0,0 +1,78 @@ +#include <gmock/gmock.h> + +#include <vector> + +#include "../rule_extractor_helper.h" + +using namespace std; + +typedef unordered_map<int, int> Indexes; + +class MockRuleExtractorHelper : public RuleExtractorHelper { + public: + MOCK_CONST_METHOD5(GetLinksSpans, void(vector<int>&, vector<int>&, + vector<int>&, vector<int>&, int)); + MOCK_CONST_METHOD3(CheckAlignedTerminals, bool(const vector<int>&, + const vector<int>&, const vector<int>&)); + MOCK_CONST_METHOD3(CheckTightPhrases, bool(const vector<int>&, + const vector<int>&, const vector<int>&)); + MOCK_CONST_METHOD1(GetGapOrder, vector<int>(const vector<pair<int, int> >&)); + MOCK_CONST_METHOD3(GetSourceIndexes, Indexes(const vector<int>&, + const vector<int>&, int)); + + // We need to implement these methods, because Google Mock doesn't support + // methods with more than 10 arguments. + bool FindFixPoint( + int, int, const vector<int>&, const vector<int>&, int& target_phrase_low, + int& target_phrase_high, const vector<int>&, const vector<int>&, + int& source_back_low, int& source_back_high, int, int, int, int, bool, + bool, bool) const { + target_phrase_low = this->target_phrase_low; + target_phrase_high = this->target_phrase_high; + source_back_low = this->source_back_low; + source_back_high = this->source_back_high; + return find_fix_point; + } + + bool GetGaps(vector<pair<int, int> >& source_gaps, + vector<pair<int, int> >& target_gaps, + const vector<int>&, const vector<int>&, const vector<int>&, + const vector<int>&, const vector<int>&, const vector<int>&, + int, int, int, int, int& num_symbols, + bool& met_constraints) const { + source_gaps = this->source_gaps; + target_gaps = this->target_gaps; + num_symbols = this->num_symbols; + met_constraints = this->met_constraints; + return get_gaps; + } + + void SetUp( + int target_phrase_low, int target_phrase_high, int source_back_low, + int source_back_high, bool find_fix_point, + vector<pair<int, int> > source_gaps, vector<pair<int, int> > target_gaps, + int num_symbols, bool met_constraints, bool get_gaps) { + this->target_phrase_low = target_phrase_low; + this->target_phrase_high = target_phrase_high; + this->source_back_low = source_back_low; + this->source_back_high = source_back_high; + this->find_fix_point = find_fix_point; + this->source_gaps = source_gaps; + this->target_gaps = target_gaps; + this->num_symbols = num_symbols; + this->met_constraints = met_constraints; + this->get_gaps = get_gaps; + } + + private: + int target_phrase_low; + int target_phrase_high; + int source_back_low; + int source_back_high; + bool find_fix_point; + vector<pair<int, int> > source_gaps; + vector<pair<int, int> > target_gaps; + int num_symbols; + bool met_constraints; + bool get_gaps; +}; diff --git a/extractor/mocks/mock_rule_factory.h b/extractor/mocks/mock_rule_factory.h new file mode 100644 index 00000000..2a96be93 --- /dev/null +++ b/extractor/mocks/mock_rule_factory.h @@ -0,0 +1,9 @@ +#include <gmock/gmock.h> + +#include "../grammar.h" +#include "../rule_factory.h" + +class MockHieroCachingRuleFactory : public HieroCachingRuleFactory { + public: + MOCK_METHOD1(GetGrammar, Grammar(const vector<int>& word_ids)); +}; diff --git a/extractor/mocks/mock_sampler.h b/extractor/mocks/mock_sampler.h new file mode 100644 index 00000000..b2306109 --- /dev/null +++ b/extractor/mocks/mock_sampler.h @@ -0,0 +1,9 @@ +#include <gmock/gmock.h> + +#include "../phrase_location.h" +#include "../sampler.h" + +class MockSampler : public Sampler { + public: + MOCK_CONST_METHOD1(Sample, PhraseLocation(const PhraseLocation& location)); +}; diff --git a/extractor/mocks/mock_scorer.h b/extractor/mocks/mock_scorer.h new file mode 100644 index 00000000..48115ef4 --- /dev/null +++ b/extractor/mocks/mock_scorer.h @@ -0,0 +1,10 @@ +#include <gmock/gmock.h> + +#include "../scorer.h" +#include "../features/feature.h" + +class MockScorer : public Scorer { + public: + MOCK_CONST_METHOD1(Score, vector<double>(const FeatureContext& context)); + MOCK_CONST_METHOD0(GetFeatureNames, vector<string>()); +}; diff --git a/extractor/mocks/mock_target_phrase_extractor.h b/extractor/mocks/mock_target_phrase_extractor.h new file mode 100644 index 00000000..6dc6bba6 --- /dev/null +++ b/extractor/mocks/mock_target_phrase_extractor.h @@ -0,0 +1,12 @@ +#include <gmock/gmock.h> + +#include "../target_phrase_extractor.h" + +typedef pair<Phrase, PhraseAlignment> PhraseExtract; + +class MockTargetPhraseExtractor : public TargetPhraseExtractor { + public: + MOCK_CONST_METHOD6(ExtractPhrases, vector<PhraseExtract>( + const vector<pair<int, int> > &, const vector<int>&, int, int, + const unordered_map<int, int>&, int)); +}; diff --git a/extractor/mocks/mock_translation_table.h b/extractor/mocks/mock_translation_table.h new file mode 100644 index 00000000..a35c9327 --- /dev/null +++ b/extractor/mocks/mock_translation_table.h @@ -0,0 +1,9 @@ +#include <gmock/gmock.h> + +#include "../translation_table.h" + +class MockTranslationTable : public TranslationTable { + public: + MOCK_METHOD2(GetSourceGivenTargetScore, double(const string&, const string&)); + MOCK_METHOD2(GetTargetGivenSourceScore, double(const string&, const string&)); +}; diff --git a/extractor/phrase_builder.cc b/extractor/phrase_builder.cc index c4e0c2ed..4325390c 100644 --- a/extractor/phrase_builder.cc +++ b/extractor/phrase_builder.cc @@ -9,10 +9,9 @@ PhraseBuilder::PhraseBuilder(shared_ptr<Vocabulary> vocabulary) : Phrase PhraseBuilder::Build(const vector<int>& symbols) { Phrase phrase; phrase.symbols = symbols; - phrase.words.resize(symbols.size()); for (size_t i = 0; i < symbols.size(); ++i) { if (vocabulary->IsTerminal(symbols[i])) { - phrase.words[i] = vocabulary->GetTerminalValue(symbols[i]); + phrase.words.push_back(vocabulary->GetTerminalValue(symbols[i])); } else { phrase.var_pos.push_back(i); } @@ -30,7 +29,7 @@ Phrase PhraseBuilder::Extend(const Phrase& phrase, bool start_x, bool end_x) { } for (size_t i = start_x; i < symbols.size(); ++i) { - if (vocabulary->IsTerminal(symbols[i])) { + if (!vocabulary->IsTerminal(symbols[i])) { ++num_nonterminals; symbols[i] = vocabulary->GetNonterminalIndex(num_nonterminals); } diff --git a/extractor/phrase_location.cc b/extractor/phrase_location.cc index 984407c5..62f1e714 100644 --- a/extractor/phrase_location.cc +++ b/extractor/phrase_location.cc @@ -10,7 +10,11 @@ PhraseLocation::PhraseLocation(const vector<int>& matchings, num_subpatterns(num_subpatterns) {} bool PhraseLocation::IsEmpty() { - return sa_low >= sa_high || (num_subpatterns > 0 && matchings->size() == 0); + if (num_subpatterns > 0) { + return matchings->size() == 0; + } else { + return sa_low >= sa_high; + } } bool operator==(const PhraseLocation& a, const PhraseLocation& b) { diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc index 9a167976..8a76beb1 100644 --- a/extractor/precomputation.cc +++ b/extractor/precomputation.cc @@ -7,7 +7,6 @@ #include "suffix_array.h" using namespace std; -using namespace tr1; int Precomputation::NON_TERMINAL = -1; @@ -79,13 +78,16 @@ vector<vector<int> > Precomputation::FindMostFrequentPatterns( } vector<vector<int> > frequent_patterns; - for (size_t i = 0; i < num_frequent_patterns && !heap.empty(); ++i) { + while (frequent_patterns.size() < num_frequent_patterns && !heap.empty()) { int start = heap.top().second.first; int len = heap.top().second.second; heap.pop(); vector<int> pattern(data.begin() + start, data.begin() + start + len); - frequent_patterns.push_back(pattern); + if (find(pattern.begin(), pattern.end(), DataArray::END_OF_LINE) == + pattern.end()) { + frequent_patterns.push_back(pattern); + } } return frequent_patterns; } diff --git a/extractor/precomputation.h b/extractor/precomputation.h index 428505d8..28426bfa 100644 --- a/extractor/precomputation.h +++ b/extractor/precomputation.h @@ -2,8 +2,8 @@ #define _PRECOMPUTATION_H_ #include <memory> -#include <tr1/unordered_map> -#include <tr1/unordered_set> +#include <unordered_map> +#include <unordered_set> #include <tuple> #include <vector> @@ -12,7 +12,6 @@ namespace fs = boost::filesystem; using namespace std; -using namespace tr1; class SuffixArray; diff --git a/extractor/rule_extractor.cc b/extractor/rule_extractor.cc index 9460020f..92343241 100644 --- a/extractor/rule_extractor.cc +++ b/extractor/rule_extractor.cc @@ -1,7 +1,6 @@ #include "rule_extractor.h" #include <map> -#include <tr1/unordered_set> #include "alignment.h" #include "data_array.h" @@ -9,11 +8,11 @@ #include "phrase_builder.h" #include "phrase_location.h" #include "rule.h" +#include "rule_extractor_helper.h" #include "scorer.h" -#include "vocabulary.h" +#include "target_phrase_extractor.h" using namespace std; -using namespace tr1; RuleExtractor::RuleExtractor( shared_ptr<DataArray> source_data_array, @@ -29,20 +28,50 @@ RuleExtractor::RuleExtractor( bool require_aligned_terminal, bool require_aligned_chunks, bool require_tight_phrases) : - source_data_array(source_data_array), target_data_array(target_data_array), - alignment(alignment), + source_data_array(source_data_array), phrase_builder(phrase_builder), scorer(scorer), - vocabulary(vocabulary), max_rule_span(max_rule_span), min_gap_size(min_gap_size), max_nonterminals(max_nonterminals), max_rule_symbols(max_rule_symbols), - require_aligned_terminal(require_aligned_terminal), - require_aligned_chunks(require_aligned_chunks), + require_tight_phrases(require_tight_phrases) { + helper = make_shared<RuleExtractorHelper>( + source_data_array, target_data_array, alignment, max_rule_span, + max_rule_symbols, require_aligned_terminal, require_aligned_chunks, + require_tight_phrases); + target_phrase_extractor = make_shared<TargetPhraseExtractor>( + target_data_array, alignment, phrase_builder, helper, vocabulary, + max_rule_span, require_tight_phrases); +} + +RuleExtractor::RuleExtractor( + shared_ptr<DataArray> source_data_array, + shared_ptr<PhraseBuilder> phrase_builder, + shared_ptr<Scorer> scorer, + shared_ptr<TargetPhraseExtractor> target_phrase_extractor, + shared_ptr<RuleExtractorHelper> helper, + int max_rule_span, + int min_gap_size, + int max_nonterminals, + int max_rule_symbols, + bool require_tight_phrases) : + source_data_array(source_data_array), + phrase_builder(phrase_builder), + scorer(scorer), + target_phrase_extractor(target_phrase_extractor), + helper(helper), + max_rule_span(max_rule_span), + min_gap_size(min_gap_size), + max_nonterminals(max_nonterminals), + max_rule_symbols(max_rule_symbols), require_tight_phrases(require_tight_phrases) {} +RuleExtractor::RuleExtractor() {} + +RuleExtractor::~RuleExtractor() {} + vector<Rule> RuleExtractor::ExtractRules(const Phrase& phrase, const PhraseLocation& location) const { int num_subpatterns = location.num_subpatterns; @@ -60,6 +89,7 @@ vector<Rule> RuleExtractor::ExtractRules(const Phrase& phrase, } } + int num_samples = matchings.size() / num_subpatterns; vector<Rule> rules; for (auto source_phrase_entry: alignments_counter) { Phrase source_phrase = source_phrase_entry.first; @@ -77,7 +107,7 @@ vector<Rule> RuleExtractor::ExtractRules(const Phrase& phrase, } FeatureContext context(source_phrase, target_phrase, - source_phrase_counter[source_phrase], num_locations); + source_phrase_counter[source_phrase], num_locations, num_samples); vector<double> scores = scorer->Score(context); rules.push_back(Rule(source_phrase, target_phrase, scores, most_frequent_alignment)); @@ -93,7 +123,8 @@ vector<Extract> RuleExtractor::ExtractAlignments( int source_sent_start = source_data_array->GetSentenceStart(sentence_id); vector<int> source_low, source_high, target_low, target_high; - GetLinksSpans(source_low, source_high, target_low, target_high, sentence_id); + helper->GetLinksSpans(source_low, source_high, target_low, target_high, + sentence_id); int num_subpatterns = matching.size(); vector<int> chunklen(num_subpatterns); @@ -101,39 +132,44 @@ vector<Extract> RuleExtractor::ExtractAlignments( chunklen[i] = phrase.GetChunkLen(i); } - if (!CheckAlignedTerminals(matching, chunklen, source_low) || - !CheckTightPhrases(matching, chunklen, source_low)) { + if (!helper->CheckAlignedTerminals(matching, chunklen, source_low) || + !helper->CheckTightPhrases(matching, chunklen, source_low)) { return extracts; } int source_back_low = -1, source_back_high = -1; int source_phrase_low = matching[0] - source_sent_start; - int source_phrase_high = matching.back() + chunklen.back() - source_sent_start; + int source_phrase_high = matching.back() + chunklen.back() - + source_sent_start; int target_phrase_low = -1, target_phrase_high = -1; - if (!FindFixPoint(source_phrase_low, source_phrase_high, source_low, - source_high, target_phrase_low, target_phrase_high, - target_low, target_high, source_back_low, source_back_high, - sentence_id, min_gap_size, 0, - max_nonterminals - matching.size() + 1, 1, 1, false)) { + if (!helper->FindFixPoint(source_phrase_low, source_phrase_high, source_low, + source_high, target_phrase_low, target_phrase_high, + target_low, target_high, source_back_low, + source_back_high, sentence_id, min_gap_size, 0, + max_nonterminals - matching.size() + 1, true, true, + false)) { return extracts; } bool met_constraints = true; int num_symbols = phrase.GetNumSymbols(); vector<pair<int, int> > source_gaps, target_gaps; - if (!CheckGaps(source_gaps, target_gaps, matching, chunklen, source_low, - source_high, target_low, target_high, source_phrase_low, - source_phrase_high, source_back_low, source_back_high, - num_symbols, met_constraints)) { + if (!helper->GetGaps(source_gaps, target_gaps, matching, chunklen, source_low, + source_high, target_low, target_high, source_phrase_low, + source_phrase_high, source_back_low, source_back_high, + num_symbols, met_constraints)) { return extracts; } - bool start_x = source_back_low != source_phrase_low; - bool end_x = source_back_high != source_phrase_high; - Phrase source_phrase = phrase_builder->Extend(phrase, start_x, end_x); + bool starts_with_x = source_back_low != source_phrase_low; + bool ends_with_x = source_back_high != source_phrase_high; + Phrase source_phrase = phrase_builder->Extend( + phrase, starts_with_x, ends_with_x); + unordered_map<int, int> source_indexes = helper->GetSourceIndexes( + matching, chunklen, starts_with_x); if (met_constraints) { - AddExtracts(extracts, source_phrase, target_gaps, target_low, - target_phrase_low, target_phrase_high, sentence_id); + AddExtracts(extracts, source_phrase, source_indexes, target_gaps, + target_low, target_phrase_low, target_phrase_high, sentence_id); } if (source_gaps.size() >= max_nonterminals || @@ -145,317 +181,24 @@ vector<Extract> RuleExtractor::ExtractAlignments( for (int i = 0; i < 2; ++i) { for (int j = 1 - i; j < 2; ++j) { - AddNonterminalExtremities(extracts, source_phrase, source_phrase_low, - source_phrase_high, source_back_low, source_back_high, source_low, - source_high, target_low, target_high, target_gaps, sentence_id, i, j); + AddNonterminalExtremities(extracts, matching, chunklen, source_phrase, + source_back_low, source_back_high, source_low, source_high, + target_low, target_high, target_gaps, sentence_id, starts_with_x, + ends_with_x, i, j); } } return extracts; } -void RuleExtractor::GetLinksSpans( - vector<int>& source_low, vector<int>& source_high, - vector<int>& target_low, vector<int>& target_high, int sentence_id) const { - // Ignore end of line markers. - int source_sent_len = source_data_array->GetSentenceStart(sentence_id + 1) - - source_data_array->GetSentenceStart(sentence_id) - 1; - int target_sent_len = target_data_array->GetSentenceStart(sentence_id + 1) - - target_data_array->GetSentenceStart(sentence_id) - 1; - source_low = vector<int>(source_sent_len, -1); - source_high = vector<int>(source_sent_len, -1); - - // TODO(pauldb): Adam Lopez claims this part is really inefficient. See if we - // can speed it up. - target_low = vector<int>(target_sent_len, -1); - target_high = vector<int>(target_sent_len, -1); - const vector<pair<int, int> >& links = alignment->GetLinks(sentence_id); - for (auto link: links) { - if (source_low[link.first] == -1 || source_low[link.first] > link.second) { - source_low[link.first] = link.second; - } - source_high[link.first] = max(source_high[link.first], link.second + 1); - - if (target_low[link.second] == -1 || target_low[link.second] > link.first) { - target_low[link.second] = link.first; - } - target_high[link.second] = max(target_high[link.second], link.first + 1); - } -} - -bool RuleExtractor::CheckAlignedTerminals(const vector<int>& matching, - const vector<int>& chunklen, - const vector<int>& source_low) const { - if (!require_aligned_terminal) { - return true; - } - - int sentence_id = source_data_array->GetSentenceId(matching[0]); - int source_sent_start = source_data_array->GetSentenceStart(sentence_id); - - int num_aligned_chunks = 0; - for (size_t i = 0; i < chunklen.size(); ++i) { - for (size_t j = 0; j < chunklen[i]; ++j) { - int sent_index = matching[i] - source_sent_start + j; - if (source_low[sent_index] != -1) { - ++num_aligned_chunks; - break; - } - } - } - - if (num_aligned_chunks == 0) { - return false; - } - - return !require_aligned_chunks || num_aligned_chunks == chunklen.size(); -} - -bool RuleExtractor::CheckTightPhrases(const vector<int>& matching, - const vector<int>& chunklen, - const vector<int>& source_low) const { - if (!require_tight_phrases) { - return true; - } - - int sentence_id = source_data_array->GetSentenceId(matching[0]); - int source_sent_start = source_data_array->GetSentenceStart(sentence_id); - for (size_t i = 0; i + 1 < chunklen.size(); ++i) { - int gap_start = matching[i] + chunklen[i] - source_sent_start; - int gap_end = matching[i + 1] - 1 - source_sent_start; - if (source_low[gap_start] == -1 || source_low[gap_end] == -1) { - return false; - } - } - - return true; -} - -bool RuleExtractor::FindFixPoint( - int source_phrase_low, int source_phrase_high, - const vector<int>& source_low, const vector<int>& source_high, - int& target_phrase_low, int& target_phrase_high, - const vector<int>& target_low, const vector<int>& target_high, - int& source_back_low, int& source_back_high, int sentence_id, - int min_source_gap_size, int min_target_gap_size, - int max_new_x, int max_low_x, int max_high_x, - bool allow_arbitrary_expansion) const { - int source_sent_len = source_data_array->GetSentenceStart(sentence_id + 1) - - source_data_array->GetSentenceStart(sentence_id) - 1; - int target_sent_len = target_data_array->GetSentenceStart(sentence_id + 1) - - target_data_array->GetSentenceStart(sentence_id) - 1; - - int prev_target_low = target_phrase_low; - int prev_target_high = target_phrase_high; - FindProjection(source_phrase_low, source_phrase_high, source_low, - source_high, target_phrase_low, target_phrase_high); - - if (target_phrase_low == -1) { - // TODO(pauldb): Low priority corner case inherited from Adam's code: - // If w is unaligned, but we don't require aligned terminals, returning an - // error here prevents the extraction of the allowed rule - // X -> X_1 w X_2 / X_1 X_2 - return false; - } - - if (prev_target_low != -1 && target_phrase_low != prev_target_low) { - if (prev_target_low - target_phrase_low < min_target_gap_size) { - target_phrase_low = prev_target_low - min_target_gap_size; - if (target_phrase_low < 0) { - return false; - } - } - } - - if (prev_target_high != -1 && target_phrase_high != prev_target_high) { - if (target_phrase_high - prev_target_high < min_target_gap_size) { - target_phrase_high = prev_target_high + min_target_gap_size; - if (target_phrase_high > target_sent_len) { - return false; - } - } - } - - if (target_phrase_high - target_phrase_low > max_rule_span) { - return false; - } - - source_back_low = source_back_high = -1; - FindProjection(target_phrase_low, target_phrase_high, target_low, target_high, - source_back_low, source_back_high); - int new_x = 0, new_low_x = 0, new_high_x = 0; - - while (true) { - source_back_low = min(source_back_low, source_phrase_low); - source_back_high = max(source_back_high, source_phrase_high); - - if (source_back_low == source_phrase_low && - source_back_high == source_phrase_high) { - return true; - } - - if (new_low_x >= max_low_x && source_back_low < source_phrase_low) { - // Extension on the left side not allowed. - return false; - } - if (new_high_x >= max_high_x && source_back_high > source_phrase_high) { - // Extension on the right side not allowed. - return false; - } - - // Extend left side. - if (source_back_low < source_phrase_low) { - if (new_x >= max_new_x) { - return false; - } - ++new_x; ++new_low_x; - if (source_phrase_low - source_back_low < min_source_gap_size) { - source_back_low = source_phrase_low - min_source_gap_size; - if (source_back_low < 0) { - return false; - } - } - } - - // Extend right side. - if (source_back_high > source_phrase_high) { - if (new_x >= max_new_x) { - return false; - } - ++new_x; ++new_high_x; - if (source_back_high - source_phrase_high < min_source_gap_size) { - source_back_high = source_phrase_high + min_source_gap_size; - if (source_back_high > source_sent_len) { - return false; - } - } - } - - if (source_back_high - source_back_low > max_rule_span) { - // Rule span too wide. - return false; - } - - prev_target_low = target_phrase_low; - prev_target_high = target_phrase_high; - FindProjection(source_back_low, source_phrase_low, source_low, source_high, - target_phrase_low, target_phrase_high); - FindProjection(source_phrase_high, source_back_high, source_low, - source_high, target_phrase_low, target_phrase_high); - if (prev_target_low == target_phrase_low && - prev_target_high == target_phrase_high) { - return true; - } - - if (!allow_arbitrary_expansion) { - // Arbitrary expansion not allowed. - return false; - } - if (target_phrase_high - target_phrase_low > max_rule_span) { - // Target side too wide. - return false; - } - - source_phrase_low = source_back_low; - source_phrase_high = source_back_high; - FindProjection(target_phrase_low, prev_target_low, target_low, target_high, - source_back_low, source_back_high); - FindProjection(prev_target_high, target_phrase_high, target_low, - target_high, source_back_low, source_back_high); - } - - return false; -} - -void RuleExtractor::FindProjection( - int source_phrase_low, int source_phrase_high, - const vector<int>& source_low, const vector<int>& source_high, - int& target_phrase_low, int& target_phrase_high) const { - for (size_t i = source_phrase_low; i < source_phrase_high; ++i) { - if (source_low[i] != -1) { - if (target_phrase_low == -1 || source_low[i] < target_phrase_low) { - target_phrase_low = source_low[i]; - } - target_phrase_high = max(target_phrase_high, source_high[i]); - } - } -} - -bool RuleExtractor::CheckGaps( - vector<pair<int, int> >& source_gaps, vector<pair<int, int> >& target_gaps, - const vector<int>& matching, const vector<int>& chunklen, - const vector<int>& source_low, const vector<int>& source_high, - const vector<int>& target_low, const vector<int>& target_high, - int source_phrase_low, int source_phrase_high, int source_back_low, - int source_back_high, int& num_symbols, bool& met_constraints) const { - int sentence_id = source_data_array->GetSentenceId(matching[0]); - int source_sent_start = source_data_array->GetSentenceStart(sentence_id); - - if (source_back_low < source_phrase_low) { - source_gaps.push_back(make_pair(source_back_low, source_phrase_low)); - if (num_symbols >= max_rule_symbols) { - // Source side contains too many symbols. - return false; - } - ++num_symbols; - if (require_tight_phrases && (source_low[source_back_low] == -1 || - source_low[source_phrase_low - 1] == -1)) { - // Inside edges of preceding gap are not tight. - return false; - } - } else if (require_tight_phrases && source_low[source_phrase_low] == -1) { - // This is not a hard error. We can't extract this phrase, but we might - // still be able to extract a superphrase. - met_constraints = false; - } - - for (size_t i = 0; i + 1 < chunklen.size(); ++i) { - int gap_start = matching[i] + chunklen[i] - source_sent_start; - int gap_end = matching[i + 1] - source_sent_start; - source_gaps.push_back(make_pair(gap_start, gap_end)); - } - - if (source_phrase_high < source_back_high) { - source_gaps.push_back(make_pair(source_phrase_high, source_back_high)); - if (num_symbols >= max_rule_symbols) { - // Source side contains too many symbols. - return false; - } - ++num_symbols; - if (require_tight_phrases && (source_low[source_phrase_high] == -1 || - source_low[source_back_high - 1] == -1)) { - // Inside edges of following gap are not tight. - return false; - } - } else if (require_tight_phrases && - source_low[source_phrase_high - 1] == -1) { - // This is not a hard error. We can't extract this phrase, but we might - // still be able to extract a superphrase. - met_constraints = false; - } - - target_gaps.resize(source_gaps.size(), make_pair(-1, -1)); - for (size_t i = 0; i < source_gaps.size(); ++i) { - if (!FindFixPoint(source_gaps[i].first, source_gaps[i].second, source_low, - source_high, target_gaps[i].first, target_gaps[i].second, - target_low, target_high, source_gaps[i].first, - source_gaps[i].second, sentence_id, 0, 0, 0, 0, 0, - false)) { - // Gap fails integrity check. - return false; - } - } - - return true; -} - void RuleExtractor::AddExtracts( vector<Extract>& extracts, const Phrase& source_phrase, + const unordered_map<int, int>& source_indexes, const vector<pair<int, int> >& target_gaps, const vector<int>& target_low, int target_phrase_low, int target_phrase_high, int sentence_id) const { - vector<pair<Phrase, PhraseAlignment> > target_phrases = ExtractTargetPhrases( + auto target_phrases = target_phrase_extractor->ExtractPhrases( target_gaps, target_low, target_phrase_low, target_phrase_high, - sentence_id); + source_indexes, sentence_id); if (target_phrases.size() > 0) { double pairs_count = 1.0 / target_phrases.size(); @@ -466,147 +209,29 @@ void RuleExtractor::AddExtracts( } } -vector<pair<Phrase, PhraseAlignment> > RuleExtractor::ExtractTargetPhrases( - const vector<pair<int, int> >& target_gaps, const vector<int>& target_low, - int target_phrase_low, int target_phrase_high, int sentence_id) const { - int target_sent_len = target_data_array->GetSentenceStart(sentence_id + 1) - - target_data_array->GetSentenceStart(sentence_id) - 1; - - vector<int> target_gap_order(target_gaps.size()); - for (size_t i = 0; i < target_gap_order.size(); ++i) { - for (size_t j = 0; j < i; ++j) { - if (target_gaps[target_gap_order[j]] < target_gaps[i]) { - ++target_gap_order[i]; - } else { - ++target_gap_order[j]; - } - } - } - - int target_x_low = target_phrase_low, target_x_high = target_phrase_high; - if (!require_tight_phrases) { - while (target_x_low > 0 && - target_phrase_high - target_x_low < max_rule_span && - target_low[target_x_low - 1] == -1) { - --target_x_low; - } - while (target_x_high + 1 < target_sent_len && - target_x_high - target_phrase_low < max_rule_span && - target_low[target_x_high + 1] == -1) { - ++target_x_high; - } - } - - vector<pair<int, int> > gaps(target_gaps.size()); - for (size_t i = 0; i < gaps.size(); ++i) { - gaps[i] = target_gaps[target_gap_order[i]]; - if (!require_tight_phrases) { - while (gaps[i].first > target_x_low && - target_low[gaps[i].first] == -1) { - --gaps[i].first; - } - while (gaps[i].second < target_x_high && - target_low[gaps[i].second] == -1) { - ++gaps[i].second; - } - } - } - - vector<pair<int, int> > ranges(2 * gaps.size() + 2); - ranges.front() = make_pair(target_x_low, target_phrase_low); - ranges.back() = make_pair(target_phrase_high, target_x_high); - for (size_t i = 0; i < gaps.size(); ++i) { - ranges[i * 2 + 1] = make_pair(gaps[i].first, target_gaps[i].first); - ranges[i * 2 + 2] = make_pair(target_gaps[i].second, gaps[i].second); - } - - vector<pair<Phrase, PhraseAlignment> > target_phrases; - vector<int> subpatterns(ranges.size()); - GeneratePhrases(target_phrases, ranges, 0, subpatterns, target_gap_order, - target_phrase_low, target_phrase_high, sentence_id); - return target_phrases; -} - -void RuleExtractor::GeneratePhrases( - vector<pair<Phrase, PhraseAlignment> >& target_phrases, - const vector<pair<int, int> >& ranges, int index, vector<int>& subpatterns, - const vector<int>& target_gap_order, int target_phrase_low, - int target_phrase_high, int sentence_id) const { - if (index >= ranges.size()) { - if (subpatterns.back() - subpatterns.front() > max_rule_span) { +void RuleExtractor::AddNonterminalExtremities( + vector<Extract>& extracts, const vector<int>& matching, + const vector<int>& chunklen, const Phrase& source_phrase, + int source_back_low, int source_back_high, const vector<int>& source_low, + const vector<int>& source_high, const vector<int>& target_low, + const vector<int>& target_high, vector<pair<int, int> > target_gaps, + int sentence_id, int starts_with_x, int ends_with_x, int extend_left, + int extend_right) const { + int source_x_low = source_back_low, source_x_high = source_back_high; + + if (require_tight_phrases) { + if (source_low[source_back_low - extend_left] == -1 || + source_low[source_back_high + extend_right - 1] == -1) { return; } - - vector<int> symbols; - unordered_set<int> target_indexes; - int offset = 1; - if (subpatterns.front() != target_phrase_low) { - offset = 2; - symbols.push_back(vocabulary->GetNonterminalIndex(1)); - } - - int target_sent_start = target_data_array->GetSentenceStart(sentence_id); - for (size_t i = 0; i * 2 < subpatterns.size(); ++i) { - for (size_t j = subpatterns[i * 2]; j < subpatterns[i * 2 + 1]; ++j) { - symbols.push_back(target_data_array->AtIndex(target_sent_start + j)); - target_indexes.insert(j); - } - if (i < target_gap_order.size()) { - symbols.push_back(vocabulary->GetNonterminalIndex( - target_gap_order[i] + offset)); - } - } - - if (subpatterns.back() != target_phrase_high) { - symbols.push_back(target_gap_order.size() + offset); - } - - const vector<pair<int, int> >& links = alignment->GetLinks(sentence_id); - vector<pair<int, int> > alignment; - for (pair<int, int> link: links) { - if (target_indexes.count(link.second)) { - alignment.push_back(link); - } - } - - target_phrases.push_back(make_pair(phrase_builder->Build(symbols), - alignment)); - return; - } - - subpatterns[index] = ranges[index].first; - if (index > 0) { - subpatterns[index] = max(subpatterns[index], subpatterns[index - 1]); } - while (subpatterns[index] <= ranges[index].second) { - subpatterns[index + 1] = max(subpatterns[index], ranges[index + 1].first); - while (subpatterns[index + 1] <= ranges[index + 1].second) { - GeneratePhrases(target_phrases, ranges, index + 2, subpatterns, - target_gap_order, target_phrase_low, target_phrase_high, - sentence_id); - ++subpatterns[index + 1]; - } - ++subpatterns[index]; - } -} -void RuleExtractor::AddNonterminalExtremities( - vector<Extract>& extracts, const Phrase& source_phrase, - int source_phrase_low, int source_phrase_high, int source_back_low, - int source_back_high, const vector<int>& source_low, - const vector<int>& source_high, const vector<int>& target_low, - const vector<int>& target_high, const vector<pair<int, int> >& target_gaps, - int sentence_id, int extend_left, int extend_right) const { - int source_x_low = source_phrase_low, source_x_high = source_phrase_high; if (extend_left) { - if (source_back_low != source_phrase_low || - source_phrase_low < min_gap_size || - (require_tight_phrases && (source_low[source_phrase_low - 1] == -1 || - source_low[source_back_high - 1] == -1))) { + if (starts_with_x || source_back_low < min_gap_size) { return; } - source_x_low = source_phrase_low - min_gap_size; + source_x_low = source_back_low - min_gap_size; if (require_tight_phrases) { while (source_x_low >= 0 && source_low[source_x_low] == -1) { --source_x_low; @@ -618,15 +243,11 @@ void RuleExtractor::AddNonterminalExtremities( } if (extend_right) { - int source_sent_len = source_data_array->GetSentenceStart(sentence_id + 1) - - source_data_array->GetSentenceStart(sentence_id) - 1; - if (source_back_high != source_phrase_high || - source_phrase_high + min_gap_size > source_sent_len || - (require_tight_phrases && (source_low[source_phrase_low] == -1 || - source_low[source_phrase_high] == -1))) { + int source_sent_len = source_data_array->GetSentenceLength(sentence_id); + if (ends_with_x || source_back_high + min_gap_size > source_sent_len) { return; } - source_x_high = source_phrase_high + min_gap_size; + source_x_high = source_back_high + min_gap_size; if (require_tight_phrases) { while (source_x_high <= source_sent_len && source_low[source_x_high - 1] == -1) { @@ -639,41 +260,56 @@ void RuleExtractor::AddNonterminalExtremities( } } + int new_nonterminals = extend_left + extend_right; if (source_x_high - source_x_low > max_rule_span || - target_gaps.size() + extend_left + extend_right > max_nonterminals) { + target_gaps.size() + new_nonterminals > max_nonterminals || + source_phrase.GetNumSymbols() + new_nonterminals > max_rule_symbols) { return; } int target_x_low = -1, target_x_high = -1; - if (!FindFixPoint(source_x_low, source_x_high, source_low, source_high, - target_x_low, target_x_high, target_low, target_high, - source_x_low, source_x_high, sentence_id, 1, 1, - extend_left + extend_right, extend_left, extend_right, - true)) { + if (!helper->FindFixPoint(source_x_low, source_x_high, source_low, + source_high, target_x_low, target_x_high, + target_low, target_high, source_x_low, + source_x_high, sentence_id, 1, 1, + new_nonterminals, extend_left, extend_right, + true)) { return; } - int source_gap_low = -1, source_gap_high = -1, target_gap_low = -1, - target_gap_high = -1; - if (extend_left && - ((require_tight_phrases && source_low[source_x_low] == -1) || - !FindFixPoint(source_x_low, source_phrase_low, source_low, source_high, - target_gap_low, target_gap_high, target_low, target_high, - source_gap_low, source_gap_high, sentence_id, - 0, 0, 0, 0, 0, false))) { - return; + if (extend_left) { + int source_gap_low = -1, source_gap_high = -1; + int target_gap_low = -1, target_gap_high = -1; + if ((require_tight_phrases && source_low[source_x_low] == -1) || + !helper->FindFixPoint(source_x_low, source_back_low, source_low, + source_high, target_gap_low, target_gap_high, + target_low, target_high, source_gap_low, + source_gap_high, sentence_id, 0, 0, 0, false, + false, false)) { + return; + } + target_gaps.insert(target_gaps.begin(), + make_pair(target_gap_low, target_gap_high)); } - if (extend_right && - ((require_tight_phrases && source_low[source_x_high - 1] == -1) || - !FindFixPoint(source_phrase_high, source_x_high, source_low, source_high, - target_gap_low, target_gap_high, target_low, target_high, - source_gap_low, source_gap_high, sentence_id, - 0, 0, 0, 0, 0, false))) { - return; + + if (extend_right) { + int target_gap_low = -1, target_gap_high = -1; + int source_gap_low = -1, source_gap_high = -1; + if ((require_tight_phrases && source_low[source_x_high - 1] == -1) || + !helper->FindFixPoint(source_back_high, source_x_high, source_low, + source_high, target_gap_low, target_gap_high, + target_low, target_high, source_gap_low, + source_gap_high, sentence_id, 0, 0, 0, false, + false, false)) { + return; + } + target_gaps.push_back(make_pair(target_gap_low, target_gap_high)); } Phrase new_source_phrase = phrase_builder->Extend(source_phrase, extend_left, extend_right); - AddExtracts(extracts, new_source_phrase, target_gaps, target_low, - target_x_low, target_x_high, sentence_id); + unordered_map<int, int> source_indexes = helper->GetSourceIndexes( + matching, chunklen, extend_left || starts_with_x); + AddExtracts(extracts, new_source_phrase, source_indexes, target_gaps, + target_low, target_x_low, target_x_high, sentence_id); } diff --git a/extractor/rule_extractor.h b/extractor/rule_extractor.h index f668de24..a087dc6d 100644 --- a/extractor/rule_extractor.h +++ b/extractor/rule_extractor.h @@ -2,6 +2,7 @@ #define _RULE_EXTRACTOR_H_ #include <memory> +#include <unordered_map> #include <vector> #include "phrase.h" @@ -13,8 +14,9 @@ class DataArray; class PhraseBuilder; class PhraseLocation; class Rule; +class RuleExtractorHelper; class Scorer; -class Vocabulary; +class TargetPhraseExtractor; typedef vector<pair<int, int> > PhraseAlignment; @@ -46,84 +48,56 @@ class RuleExtractor { bool require_aligned_chunks, bool require_tight_phrases); - vector<Rule> ExtractRules(const Phrase& phrase, - const PhraseLocation& location) const; + // For testing only. + RuleExtractor(shared_ptr<DataArray> source_data_array, + shared_ptr<PhraseBuilder> phrase_builder, + shared_ptr<Scorer> scorer, + shared_ptr<TargetPhraseExtractor> target_phrase_extractor, + shared_ptr<RuleExtractorHelper> helper, + int max_rule_span, + int min_gap_size, + int max_nonterminals, + int max_rule_symbols, + bool require_tight_phrases); + + virtual ~RuleExtractor(); + + virtual vector<Rule> ExtractRules(const Phrase& phrase, + const PhraseLocation& location) const; + + protected: + RuleExtractor(); private: vector<Extract> ExtractAlignments(const Phrase& phrase, const vector<int>& matching) const; - void GetLinksSpans(vector<int>& source_low, vector<int>& source_high, - vector<int>& target_low, vector<int>& target_high, - int sentence_id) const; - - bool CheckAlignedTerminals(const vector<int>& matching, - const vector<int>& chunklen, - const vector<int>& source_low) const; - - bool CheckTightPhrases(const vector<int>& matching, - const vector<int>& chunklen, - const vector<int>& source_low) const; - - bool FindFixPoint( - int source_phrase_start, int source_phrase_end, - const vector<int>& source_low, const vector<int>& source_high, - int& target_phrase_start, int& target_phrase_end, - const vector<int>& target_low, const vector<int>& target_high, - int& source_back_low, int& source_back_high, int sentence_id, - int min_source_gap_size, int min_target_gap_size, - int max_new_x, int max_low_x, int max_high_x, - bool allow_arbitrary_expansion) const; - - void FindProjection( - int source_phrase_start, int source_phrase_end, - const vector<int>& source_low, const vector<int>& source_high, - int& target_phrase_low, int& target_phrase_end) const; - - bool CheckGaps( - vector<pair<int, int> >& source_gaps, vector<pair<int, int> >& target_gaps, - const vector<int>& matching, const vector<int>& chunklen, - const vector<int>& source_low, const vector<int>& source_high, - const vector<int>& target_low, const vector<int>& target_high, - int source_phrase_low, int source_phrase_high, int source_back_low, - int source_back_high, int& num_symbols, bool& met_constraints) const; - void AddExtracts( vector<Extract>& extracts, const Phrase& source_phrase, + const unordered_map<int, int>& source_indexes, const vector<pair<int, int> >& target_gaps, const vector<int>& target_low, int target_phrase_low, int target_phrase_high, int sentence_id) const; - vector<pair<Phrase, PhraseAlignment> > ExtractTargetPhrases( - const vector<pair<int, int> >& target_gaps, const vector<int>& target_low, - int target_phrase_low, int target_phrase_high, int sentence_id) const; - - void GeneratePhrases( - vector<pair<Phrase, PhraseAlignment> >& target_phrases, - const vector<pair<int, int> >& ranges, int index, - vector<int>& subpatterns, const vector<int>& target_gap_order, - int target_phrase_low, int target_phrase_high, int sentence_id) const; - void AddNonterminalExtremities( - vector<Extract>& extracts, const Phrase& source_phrase, - int source_phrase_low, int source_phrase_high, int source_back_low, - int source_back_high, const vector<int>& source_low, + vector<Extract>& extracts, const vector<int>& matching, + const vector<int>& chunklen, const Phrase& source_phrase, + int source_back_low, int source_back_high, const vector<int>& source_low, const vector<int>& source_high, const vector<int>& target_low, - const vector<int>& target_high, - const vector<pair<int, int> >& target_gaps, int sentence_id, - int extend_left, int extend_right) const; + const vector<int>& target_high, vector<pair<int, int> > target_gaps, + int sentence_id, int starts_with_x, int ends_with_x, int extend_left, + int extend_right) const; - shared_ptr<DataArray> source_data_array; + private: shared_ptr<DataArray> target_data_array; - shared_ptr<Alignment> alignment; + shared_ptr<DataArray> source_data_array; shared_ptr<PhraseBuilder> phrase_builder; shared_ptr<Scorer> scorer; - shared_ptr<Vocabulary> vocabulary; + shared_ptr<TargetPhraseExtractor> target_phrase_extractor; + shared_ptr<RuleExtractorHelper> helper; int max_rule_span; int min_gap_size; int max_nonterminals; int max_rule_symbols; - bool require_aligned_terminal; - bool require_aligned_chunks; bool require_tight_phrases; }; diff --git a/extractor/rule_extractor_helper.cc b/extractor/rule_extractor_helper.cc new file mode 100644 index 00000000..ed6ae3a1 --- /dev/null +++ b/extractor/rule_extractor_helper.cc @@ -0,0 +1,356 @@ +#include "rule_extractor_helper.h" + +#include "data_array.h" +#include "alignment.h" + +RuleExtractorHelper::RuleExtractorHelper( + shared_ptr<DataArray> source_data_array, + shared_ptr<DataArray> target_data_array, + shared_ptr<Alignment> alignment, + int max_rule_span, + int max_rule_symbols, + bool require_aligned_terminal, + bool require_aligned_chunks, + bool require_tight_phrases) : + source_data_array(source_data_array), + target_data_array(target_data_array), + alignment(alignment), + max_rule_span(max_rule_span), + max_rule_symbols(max_rule_symbols), + require_aligned_terminal(require_aligned_terminal), + require_aligned_chunks(require_aligned_chunks), + require_tight_phrases(require_tight_phrases) {} + +RuleExtractorHelper::RuleExtractorHelper() {} + +RuleExtractorHelper::~RuleExtractorHelper() {} + +void RuleExtractorHelper::GetLinksSpans( + vector<int>& source_low, vector<int>& source_high, + vector<int>& target_low, vector<int>& target_high, int sentence_id) const { + int source_sent_len = source_data_array->GetSentenceLength(sentence_id); + int target_sent_len = target_data_array->GetSentenceLength(sentence_id); + source_low = vector<int>(source_sent_len, -1); + source_high = vector<int>(source_sent_len, -1); + + // TODO(pauldb): Adam Lopez claims this part is really inefficient. See if we + // can speed it up. + target_low = vector<int>(target_sent_len, -1); + target_high = vector<int>(target_sent_len, -1); + vector<pair<int, int> > links = alignment->GetLinks(sentence_id); + for (auto link: links) { + if (source_low[link.first] == -1 || source_low[link.first] > link.second) { + source_low[link.first] = link.second; + } + source_high[link.first] = max(source_high[link.first], link.second + 1); + + if (target_low[link.second] == -1 || target_low[link.second] > link.first) { + target_low[link.second] = link.first; + } + target_high[link.second] = max(target_high[link.second], link.first + 1); + } +} + +bool RuleExtractorHelper::CheckAlignedTerminals( + const vector<int>& matching, + const vector<int>& chunklen, + const vector<int>& source_low) const { + if (!require_aligned_terminal) { + return true; + } + + int sentence_id = source_data_array->GetSentenceId(matching[0]); + int source_sent_start = source_data_array->GetSentenceStart(sentence_id); + + int num_aligned_chunks = 0; + for (size_t i = 0; i < chunklen.size(); ++i) { + for (size_t j = 0; j < chunklen[i]; ++j) { + int sent_index = matching[i] - source_sent_start + j; + if (source_low[sent_index] != -1) { + ++num_aligned_chunks; + break; + } + } + } + + if (num_aligned_chunks == 0) { + return false; + } + + return !require_aligned_chunks || num_aligned_chunks == chunklen.size(); +} + +bool RuleExtractorHelper::CheckTightPhrases( + const vector<int>& matching, + const vector<int>& chunklen, + const vector<int>& source_low) const { + if (!require_tight_phrases) { + return true; + } + + int sentence_id = source_data_array->GetSentenceId(matching[0]); + int source_sent_start = source_data_array->GetSentenceStart(sentence_id); + for (size_t i = 0; i + 1 < chunklen.size(); ++i) { + int gap_start = matching[i] + chunklen[i] - source_sent_start; + int gap_end = matching[i + 1] - 1 - source_sent_start; + if (source_low[gap_start] == -1 || source_low[gap_end] == -1) { + return false; + } + } + + return true; +} + +bool RuleExtractorHelper::FindFixPoint( + int source_phrase_low, int source_phrase_high, + const vector<int>& source_low, const vector<int>& source_high, + int& target_phrase_low, int& target_phrase_high, + const vector<int>& target_low, const vector<int>& target_high, + int& source_back_low, int& source_back_high, int sentence_id, + int min_source_gap_size, int min_target_gap_size, + int max_new_x, bool allow_low_x, bool allow_high_x, + bool allow_arbitrary_expansion) const { + int prev_target_low = target_phrase_low; + int prev_target_high = target_phrase_high; + + FindProjection(source_phrase_low, source_phrase_high, source_low, + source_high, target_phrase_low, target_phrase_high); + + if (target_phrase_low == -1) { + // TODO(pauldb): Low priority corner case inherited from Adam's code: + // If w is unaligned, but we don't require aligned terminals, returning an + // error here prevents the extraction of the allowed rule + // X -> X_1 w X_2 / X_1 X_2 + return false; + } + + int source_sent_len = source_data_array->GetSentenceLength(sentence_id); + int target_sent_len = target_data_array->GetSentenceLength(sentence_id); + if (prev_target_low != -1 && target_phrase_low != prev_target_low) { + if (prev_target_low - target_phrase_low < min_target_gap_size) { + target_phrase_low = prev_target_low - min_target_gap_size; + if (target_phrase_low < 0) { + return false; + } + } + } + + if (prev_target_high != -1 && target_phrase_high != prev_target_high) { + if (target_phrase_high - prev_target_high < min_target_gap_size) { + target_phrase_high = prev_target_high + min_target_gap_size; + if (target_phrase_high > target_sent_len) { + return false; + } + } + } + + if (target_phrase_high - target_phrase_low > max_rule_span) { + return false; + } + + source_back_low = source_back_high = -1; + FindProjection(target_phrase_low, target_phrase_high, target_low, target_high, + source_back_low, source_back_high); + int new_x = 0; + bool new_low_x = false, new_high_x = false; + while (true) { + source_back_low = min(source_back_low, source_phrase_low); + source_back_high = max(source_back_high, source_phrase_high); + + if (source_back_low == source_phrase_low && + source_back_high == source_phrase_high) { + return true; + } + + if (!allow_low_x && source_back_low < source_phrase_low) { + // Extension on the left side not allowed. + return false; + } + if (!allow_high_x && source_back_high > source_phrase_high) { + // Extension on the right side not allowed. + return false; + } + + // Extend left side. + if (source_back_low < source_phrase_low) { + if (new_low_x == false) { + if (new_x >= max_new_x) { + return false; + } + new_low_x = true; + ++new_x; + } + if (source_phrase_low - source_back_low < min_source_gap_size) { + source_back_low = source_phrase_low - min_source_gap_size; + if (source_back_low < 0) { + return false; + } + } + } + + // Extend right side. + if (source_back_high > source_phrase_high) { + if (new_high_x == false) { + if (new_x >= max_new_x) { + return false; + } + new_high_x = true; + ++new_x; + } + if (source_back_high - source_phrase_high < min_source_gap_size) { + source_back_high = source_phrase_high + min_source_gap_size; + if (source_back_high > source_sent_len) { + return false; + } + } + } + + if (source_back_high - source_back_low > max_rule_span) { + // Rule span too wide. + return false; + } + + prev_target_low = target_phrase_low; + prev_target_high = target_phrase_high; + FindProjection(source_back_low, source_phrase_low, source_low, source_high, + target_phrase_low, target_phrase_high); + FindProjection(source_phrase_high, source_back_high, source_low, + source_high, target_phrase_low, target_phrase_high); + if (prev_target_low == target_phrase_low && + prev_target_high == target_phrase_high) { + return true; + } + + if (!allow_arbitrary_expansion) { + // Arbitrary expansion not allowed. + return false; + } + if (target_phrase_high - target_phrase_low > max_rule_span) { + // Target side too wide. + return false; + } + + source_phrase_low = source_back_low; + source_phrase_high = source_back_high; + FindProjection(target_phrase_low, prev_target_low, target_low, target_high, + source_back_low, source_back_high); + FindProjection(prev_target_high, target_phrase_high, target_low, + target_high, source_back_low, source_back_high); + } + + return false; +} + +void RuleExtractorHelper::FindProjection( + int source_phrase_low, int source_phrase_high, + const vector<int>& source_low, const vector<int>& source_high, + int& target_phrase_low, int& target_phrase_high) const { + for (size_t i = source_phrase_low; i < source_phrase_high; ++i) { + if (source_low[i] != -1) { + if (target_phrase_low == -1 || source_low[i] < target_phrase_low) { + target_phrase_low = source_low[i]; + } + target_phrase_high = max(target_phrase_high, source_high[i]); + } + } +} + +bool RuleExtractorHelper::GetGaps( + vector<pair<int, int> >& source_gaps, vector<pair<int, int> >& target_gaps, + const vector<int>& matching, const vector<int>& chunklen, + const vector<int>& source_low, const vector<int>& source_high, + const vector<int>& target_low, const vector<int>& target_high, + int source_phrase_low, int source_phrase_high, int source_back_low, + int source_back_high, int& num_symbols, bool& met_constraints) const { + int sentence_id = source_data_array->GetSentenceId(matching[0]); + int source_sent_start = source_data_array->GetSentenceStart(sentence_id); + + if (source_back_low < source_phrase_low) { + source_gaps.push_back(make_pair(source_back_low, source_phrase_low)); + if (num_symbols >= max_rule_symbols) { + // Source side contains too many symbols. + return false; + } + ++num_symbols; + if (require_tight_phrases && (source_low[source_back_low] == -1 || + source_low[source_phrase_low - 1] == -1)) { + // Inside edges of preceding gap are not tight. + return false; + } + } else if (require_tight_phrases && source_low[source_phrase_low] == -1) { + // This is not a hard error. We can't extract this phrase, but we might + // still be able to extract a superphrase. + met_constraints = false; + } + + for (size_t i = 0; i + 1 < chunklen.size(); ++i) { + int gap_start = matching[i] + chunklen[i] - source_sent_start; + int gap_end = matching[i + 1] - source_sent_start; + source_gaps.push_back(make_pair(gap_start, gap_end)); + } + + if (source_phrase_high < source_back_high) { + source_gaps.push_back(make_pair(source_phrase_high, source_back_high)); + if (num_symbols >= max_rule_symbols) { + // Source side contains too many symbols. + return false; + } + ++num_symbols; + if (require_tight_phrases && (source_low[source_phrase_high] == -1 || + source_low[source_back_high - 1] == -1)) { + // Inside edges of following gap are not tight. + return false; + } + } else if (require_tight_phrases && + source_low[source_phrase_high - 1] == -1) { + // This is not a hard error. We can't extract this phrase, but we might + // still be able to extract a superphrase. + met_constraints = false; + } + + target_gaps.resize(source_gaps.size(), make_pair(-1, -1)); + for (size_t i = 0; i < source_gaps.size(); ++i) { + if (!FindFixPoint(source_gaps[i].first, source_gaps[i].second, source_low, + source_high, target_gaps[i].first, target_gaps[i].second, + target_low, target_high, source_gaps[i].first, + source_gaps[i].second, sentence_id, 0, 0, 0, false, false, + false)) { + // Gap fails integrity check. + return false; + } + } + + return true; +} + +vector<int> RuleExtractorHelper::GetGapOrder( + const vector<pair<int, int> >& gaps) const { + vector<int> gap_order(gaps.size()); + for (size_t i = 0; i < gap_order.size(); ++i) { + for (size_t j = 0; j < i; ++j) { + if (gaps[gap_order[j]] < gaps[i]) { + ++gap_order[i]; + } else { + ++gap_order[j]; + } + } + } + return gap_order; +} + +unordered_map<int, int> RuleExtractorHelper::GetSourceIndexes( + const vector<int>& matching, const vector<int>& chunklen, + int starts_with_x) const { + unordered_map<int, int> source_indexes; + int sentence_id = source_data_array->GetSentenceId(matching[0]); + int source_sent_start = source_data_array->GetSentenceStart(sentence_id); + int num_symbols = starts_with_x; + for (size_t i = 0; i < matching.size(); ++i) { + for (size_t j = 0; j < chunklen[i]; ++j) { + source_indexes[matching[i] + j - source_sent_start] = num_symbols; + ++num_symbols; + } + ++num_symbols; + } + return source_indexes; +} diff --git a/extractor/rule_extractor_helper.h b/extractor/rule_extractor_helper.h new file mode 100644 index 00000000..3478bfc8 --- /dev/null +++ b/extractor/rule_extractor_helper.h @@ -0,0 +1,82 @@ +#ifndef _RULE_EXTRACTOR_HELPER_H_ +#define _RULE_EXTRACTOR_HELPER_H_ + +#include <memory> +#include <unordered_map> +#include <vector> + +using namespace std; + +class Alignment; +class DataArray; + +class RuleExtractorHelper { + public: + RuleExtractorHelper(shared_ptr<DataArray> source_data_array, + shared_ptr<DataArray> target_data_array, + shared_ptr<Alignment> alignment, + int max_rule_span, + int max_rule_symbols, + bool require_aligned_terminal, + bool require_aligned_chunks, + bool require_tight_phrases); + + virtual ~RuleExtractorHelper(); + + virtual void GetLinksSpans(vector<int>& source_low, vector<int>& source_high, + vector<int>& target_low, vector<int>& target_high, + int sentence_id) const; + + virtual bool CheckAlignedTerminals(const vector<int>& matching, + const vector<int>& chunklen, + const vector<int>& source_low) const; + + virtual bool CheckTightPhrases(const vector<int>& matching, + const vector<int>& chunklen, + const vector<int>& source_low) const; + + virtual bool FindFixPoint( + int source_phrase_low, int source_phrase_high, + const vector<int>& source_low, const vector<int>& source_high, + int& target_phrase_low, int& target_phrase_high, + const vector<int>& target_low, const vector<int>& target_high, + int& source_back_low, int& source_back_high, int sentence_id, + int min_source_gap_size, int min_target_gap_size, + int max_new_x, bool allow_low_x, bool allow_high_x, + bool allow_arbitrary_expansion) const; + + virtual bool GetGaps( + vector<pair<int, int> >& source_gaps, vector<pair<int, int> >& target_gaps, + const vector<int>& matching, const vector<int>& chunklen, + const vector<int>& source_low, const vector<int>& source_high, + const vector<int>& target_low, const vector<int>& target_high, + int source_phrase_low, int source_phrase_high, int source_back_low, + int source_back_high, int& num_symbols, bool& met_constraints) const; + + virtual vector<int> GetGapOrder(const vector<pair<int, int> >& gaps) const; + + // TODO(pauldb): Add unit tests. + virtual unordered_map<int, int> GetSourceIndexes( + const vector<int>& matching, const vector<int>& chunklen, + int starts_with_x) const; + + protected: + RuleExtractorHelper(); + + private: + void FindProjection( + int source_phrase_low, int source_phrase_high, + const vector<int>& source_low, const vector<int>& source_high, + int& target_phrase_low, int& target_phrase_high) const; + + shared_ptr<DataArray> source_data_array; + shared_ptr<DataArray> target_data_array; + shared_ptr<Alignment> alignment; + int max_rule_span; + int max_rule_symbols; + bool require_aligned_terminal; + bool require_aligned_chunks; + bool require_tight_phrases; +}; + +#endif diff --git a/extractor/rule_extractor_helper_test.cc b/extractor/rule_extractor_helper_test.cc new file mode 100644 index 00000000..29213312 --- /dev/null +++ b/extractor/rule_extractor_helper_test.cc @@ -0,0 +1,622 @@ +#include <gtest/gtest.h> + +#include <memory> + +#include "mocks/mock_alignment.h" +#include "mocks/mock_data_array.h" +#include "rule_extractor_helper.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class RuleExtractorHelperTest : public Test { + protected: + virtual void SetUp() { + source_data_array = make_shared<MockDataArray>(); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(12)); + EXPECT_CALL(*source_data_array, GetSentenceId(_)) + .WillRepeatedly(Return(5)); + EXPECT_CALL(*source_data_array, GetSentenceStart(_)) + .WillRepeatedly(Return(10)); + + target_data_array = make_shared<MockDataArray>(); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(12)); + + vector<pair<int, int> > links = { + make_pair(0, 0), make_pair(0, 1), make_pair(2, 2), make_pair(3, 1) + }; + alignment = make_shared<MockAlignment>(); + EXPECT_CALL(*alignment, GetLinks(_)).WillRepeatedly(Return(links)); + } + + shared_ptr<MockDataArray> source_data_array; + shared_ptr<MockDataArray> target_data_array; + shared_ptr<MockAlignment> alignment; + shared_ptr<RuleExtractorHelper> helper; +}; + +TEST_F(RuleExtractorHelperTest, TestGetLinksSpans) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(4)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(3)); + + vector<int> source_low, source_high, target_low, target_high; + helper->GetLinksSpans(source_low, source_high, target_low, target_high, 0); + + vector<int> expected_source_low = {0, -1, 2, 1}; + EXPECT_EQ(expected_source_low, source_low); + vector<int> expected_source_high = {2, -1, 3, 2}; + EXPECT_EQ(expected_source_high, source_high); + vector<int> expected_target_low = {0, 0, 2}; + EXPECT_EQ(expected_target_low, target_low); + vector<int> expected_target_high = {1, 4, 3}; + EXPECT_EQ(expected_target_high, target_high); +} + +TEST_F(RuleExtractorHelperTest, TestCheckAlignedFalse) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, false, false, true); + EXPECT_CALL(*source_data_array, GetSentenceId(_)).Times(0); + EXPECT_CALL(*source_data_array, GetSentenceStart(_)).Times(0); + + vector<int> matching, chunklen, source_low; + EXPECT_TRUE(helper->CheckAlignedTerminals(matching, chunklen, source_low)); +} + +TEST_F(RuleExtractorHelperTest, TestCheckAlignedTerminal) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, false, true); + + vector<int> matching = {10, 12}; + vector<int> chunklen = {1, 3}; + vector<int> source_low = {-1, 1, -1, 3, -1}; + EXPECT_TRUE(helper->CheckAlignedTerminals(matching, chunklen, source_low)); + source_low = {-1, 1, -1, -1, -1}; + EXPECT_FALSE(helper->CheckAlignedTerminals(matching, chunklen, source_low)); +} + +TEST_F(RuleExtractorHelperTest, TestCheckAlignedChunks) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + + vector<int> matching = {10, 12}; + vector<int> chunklen = {1, 3}; + vector<int> source_low = {2, 1, -1, 3, -1}; + EXPECT_TRUE(helper->CheckAlignedTerminals(matching, chunklen, source_low)); + source_low = {-1, 1, -1, 3, -1}; + EXPECT_FALSE(helper->CheckAlignedTerminals(matching, chunklen, source_low)); + source_low = {2, 1, -1, -1, -1}; + EXPECT_FALSE(helper->CheckAlignedTerminals(matching, chunklen, source_low)); +} + + +TEST_F(RuleExtractorHelperTest, TestCheckTightPhrasesFalse) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, false); + EXPECT_CALL(*source_data_array, GetSentenceId(_)).Times(0); + EXPECT_CALL(*source_data_array, GetSentenceStart(_)).Times(0); + + vector<int> matching, chunklen, source_low; + EXPECT_TRUE(helper->CheckTightPhrases(matching, chunklen, source_low)); +} + +TEST_F(RuleExtractorHelperTest, TestCheckTightPhrases) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + + vector<int> matching = {10, 14, 18}; + vector<int> chunklen = {2, 3, 1}; + // No missing links. + vector<int> source_low = {0, 1, 2, 3, 4, 5, 6, 7, 8}; + EXPECT_TRUE(helper->CheckTightPhrases(matching, chunklen, source_low)); + + // Missing link at the beginning or ending of a gap. + source_low = {0, 1, -1, 3, 4, 5, 6, 7, 8}; + EXPECT_FALSE(helper->CheckTightPhrases(matching, chunklen, source_low)); + source_low = {0, 1, 2, -1, 4, 5, 6, 7, 8}; + EXPECT_FALSE(helper->CheckTightPhrases(matching, chunklen, source_low)); + source_low = {0, 1, 2, 3, 4, 5, 6, -1, 8}; + EXPECT_FALSE(helper->CheckTightPhrases(matching, chunklen, source_low)); + + // Missing link inside the gap. + chunklen = {1, 3, 1}; + source_low = {0, 1, -1, 3, 4, 5, 6, 7, 8}; + EXPECT_TRUE(helper->CheckTightPhrases(matching, chunklen, source_low)); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPointBadEdgeCase) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + + vector<int> source_low = {0, -1, 2}; + vector<int> source_high = {1, -1, 3}; + vector<int> target_low = {0, -1, 2}; + vector<int> target_high = {1, -1, 3}; + int source_phrase_low = 1, source_phrase_high = 2; + int source_back_low, source_back_high; + int target_phrase_low = -1, target_phrase_high = 1; + + // This should be in fact true. See comment about the inherited bug. + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 0, 0, + 0, false, false, false)); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPointTargetSentenceOutOfBounds) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(3)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(3)); + + vector<int> source_low = {0, 0, 2}; + vector<int> source_high = {1, 2, 3}; + vector<int> target_low = {0, 1, 2}; + vector<int> target_high = {2, 2, 3}; + int source_phrase_low = 1, source_phrase_high = 2; + int source_back_low, source_back_high; + int target_phrase_low = 1, target_phrase_high = 2; + + // Extend out of sentence to left. + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 2, 2, + 0, false, false, false)); + source_low = {0, 1, 2}; + source_high = {1, 3, 3}; + target_low = {0, 1, 1}; + target_high = {1, 2, 3}; + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 2, 2, + 0, false, false, false)); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPointTargetTooWide) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 5, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + + vector<int> source_low = {0, 0, 0, 0, 0, 0, 0}; + vector<int> source_high = {7, 7, 7, 7, 7, 7, 7}; + vector<int> target_low = {0, -1, -1, -1, -1, -1, 0}; + vector<int> target_high = {7, -1, -1, -1, -1, -1, 7}; + int source_phrase_low = 2, source_phrase_high = 5; + int source_back_low, source_back_high; + int target_phrase_low = -1, target_phrase_high = -1; + + // Projection is too wide. + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 1, 1, + 0, false, false, false)); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPoint) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + + vector<int> source_low = {1, 1, 1, 3, 4, 5, 5}; + vector<int> source_high = {2, 2, 3, 4, 6, 6, 6}; + vector<int> target_low = {-1, 0, 2, 3, 4, 4, -1}; + vector<int> target_high = {-1, 3, 3, 4, 5, 7, -1}; + int source_phrase_low = 2, source_phrase_high = 5; + int source_back_low, source_back_high; + int target_phrase_low = 2, target_phrase_high = 5; + + EXPECT_TRUE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 1, 1, 1, + 2, true, true, false)); + EXPECT_EQ(1, target_phrase_low); + EXPECT_EQ(6, target_phrase_high); + EXPECT_EQ(0, source_back_low); + EXPECT_EQ(7, source_back_high); + + source_low = {0, -1, 1, 3, 4, -1, 6}; + source_high = {1, -1, 3, 4, 6, -1, 7}; + target_low = {0, 2, 2, 3, 4, 4, 6}; + target_high = {1, 3, 3, 4, 5, 5, 7}; + source_phrase_low = 2, source_phrase_high = 5; + target_phrase_low = -1, target_phrase_high = -1; + EXPECT_TRUE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 1, 1, 1, + 2, true, true, false)); + EXPECT_EQ(1, target_phrase_low); + EXPECT_EQ(6, target_phrase_high); + EXPECT_EQ(2, source_back_low); + EXPECT_EQ(5, source_back_high); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPointExtensionsNotAllowed) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(3)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(3)); + + vector<int> source_low = {0, 0, 2}; + vector<int> source_high = {1, 2, 3}; + vector<int> target_low = {0, 1, 2}; + vector<int> target_high = {2, 2, 3}; + int source_phrase_low = 1, source_phrase_high = 2; + int source_back_low, source_back_high; + int target_phrase_low = -1, target_phrase_high = -1; + + // Extension on the left side not allowed. + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 1, 1, + 1, false, true, false)); + // Extension on the left side is allowed, but we can't add anymore X. + target_phrase_low = -1, target_phrase_high = -1; + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 1, 1, + 0, true, true, false)); + source_low = {0, 1, 2}; + source_high = {1, 3, 3}; + target_low = {0, 1, 1}; + target_high = {1, 2, 3}; + // Extension on the right side not allowed. + target_phrase_low = -1, target_phrase_high = -1; + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 1, 1, + 1, true, false, false)); + // Extension on the right side is allowed, but we can't add anymore X. + target_phrase_low = -1, target_phrase_high = -1; + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 1, 1, + 0, true, true, false)); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPointSourceSentenceOutOfBounds) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(3)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(3)); + + vector<int> source_low = {0, 0, 2}; + vector<int> source_high = {1, 2, 3}; + vector<int> target_low = {0, 1, 2}; + vector<int> target_high = {2, 2, 3}; + int source_phrase_low = 1, source_phrase_high = 2; + int source_back_low, source_back_high; + int target_phrase_low = 1, target_phrase_high = 2; + // Extend out of sentence to left. + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 2, 1, + 1, true, true, false)); + source_low = {0, 1, 2}; + source_high = {1, 3, 3}; + target_low = {0, 1, 1}; + target_high = {1, 2, 3}; + target_phrase_low = 1, target_phrase_high = 2; + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 2, 1, + 1, true, true, false)); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPointTargetSourceWide) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 5, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + + vector<int> source_low = {2, -1, 2, 3, 4, -1, 4}; + vector<int> source_high = {3, -1, 3, 4, 5, -1, 5}; + vector<int> target_low = {-1, -1, 0, 3, 4, -1, -1}; + vector<int> target_high = {-1, -1, 3, 4, 7, -1, -1}; + int source_phrase_low = 2, source_phrase_high = 5; + int source_back_low, source_back_high; + int target_phrase_low = -1, target_phrase_high = -1; + + // Second projection (on source side) is too wide. + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 1, 1, + 2, true, true, false)); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPointArbitraryExpansion) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 20, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(11)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(11)); + + vector<int> source_low = {1, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9}; + vector<int> source_high = {2, 3, 4, 5, 5, 6, 7, 8, 9, 10, 10}; + vector<int> target_low = {-1, 0, 1, 2, 3, 5, 6, 7, 8, 9, -1}; + vector<int> target_high = {-1, 2, 3, 4, 5, 6, 8, 9, 10, 11, -1}; + int source_phrase_low = 4, source_phrase_high = 7; + int source_back_low, source_back_high; + int target_phrase_low = -1, target_phrase_high = -1; + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 1, 1, + 10, true, true, false)); + + source_phrase_low = 4, source_phrase_high = 7; + target_phrase_low = -1, target_phrase_high = -1; + EXPECT_TRUE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 1, 1, + 10, true, true, true)); +} + +TEST_F(RuleExtractorHelperTest, TestGetGapOrder) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + + vector<pair<int, int> > gaps = + {make_pair(0, 3), make_pair(5, 8), make_pair(11, 12), make_pair(15, 17)}; + vector<int> expected_gap_order = {0, 1, 2, 3}; + EXPECT_EQ(expected_gap_order, helper->GetGapOrder(gaps)); + + gaps = {make_pair(15, 17), make_pair(8, 9), make_pair(5, 6), make_pair(0, 3)}; + expected_gap_order = {3, 2, 1, 0}; + EXPECT_EQ(expected_gap_order, helper->GetGapOrder(gaps)); + + gaps = {make_pair(8, 9), make_pair(5, 6), make_pair(0, 3), make_pair(15, 17)}; + expected_gap_order = {2, 1, 0, 3}; + EXPECT_EQ(expected_gap_order, helper->GetGapOrder(gaps)); +} + +TEST_F(RuleExtractorHelperTest, TestGetGapsExceedNumSymbols) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + + bool met_constraints = true; + vector<int> source_low = {1, 1, 2, 3, 4, 5, 6}; + vector<int> source_high = {2, 2, 3, 4, 5, 6, 7}; + vector<int> target_low = {-1, 0, 2, 3, 4, 5, 6}; + vector<int> target_high = {-1, 2, 3, 4, 5, 6, 7}; + int source_phrase_low = 1, source_phrase_high = 6; + int source_back_low = 0, source_back_high = 6; + vector<int> matching = {11, 13, 15}; + vector<int> chunklen = {1, 1, 1}; + vector<pair<int, int> > source_gaps, target_gaps; + int num_symbols = 5; + EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, + source_low, source_high, target_low, target_high, + source_phrase_low, source_phrase_high, + source_back_low, source_back_high, num_symbols, + met_constraints)); + + source_low = {0, 1, 2, 3, 4, 5, 5}; + source_high = {1, 2, 3, 4, 5, 6, 6}; + target_low = {0, 1, 2, 3, 4, 5, -1}; + target_high = {1, 2, 3, 4, 5, 7, -1}; + source_phrase_low = 1, source_phrase_high = 6; + source_back_low = 1, source_back_high = 7; + num_symbols = 5; + EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, + source_low, source_high, target_low, target_high, + source_phrase_low, source_phrase_high, + source_back_low, source_back_high, num_symbols, + met_constraints)); +} + +TEST_F(RuleExtractorHelperTest, TestGetGapsExtensionsNotTight) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 7, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + + bool met_constraints = true; + vector<int> source_low = {-1, 1, 2, 3, 4, 5, -1}; + vector<int> source_high = {-1, 2, 3, 4, 5, 6, -1}; + vector<int> target_low = {-1, 1, 2, 3, 4, 5, -1}; + vector<int> target_high = {-1, 2, 3, 4, 5, 6, -1}; + int source_phrase_low = 1, source_phrase_high = 6; + int source_back_low = 0, source_back_high = 6; + vector<int> matching = {11, 13, 15}; + vector<int> chunklen = {1, 1, 1}; + vector<pair<int, int> > source_gaps, target_gaps; + int num_symbols = 5; + EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, + source_low, source_high, target_low, target_high, + source_phrase_low, source_phrase_high, + source_back_low, source_back_high, num_symbols, + met_constraints)); + + source_phrase_low = 1, source_phrase_high = 6; + source_back_low = 1, source_back_high = 7; + num_symbols = 5; + EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, + source_low, source_high, target_low, target_high, + source_phrase_low, source_phrase_high, + source_back_low, source_back_high, num_symbols, + met_constraints)); +} + +TEST_F(RuleExtractorHelperTest, TestGetGapsNotTightExtremities) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 7, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + + bool met_constraints = true; + vector<int> source_low = {-1, -1, 2, 3, 4, 5, 6}; + vector<int> source_high = {-1, -1, 3, 4, 5, 6, 7}; + vector<int> target_low = {-1, -1, 2, 3, 4, 5, 6}; + vector<int> target_high = {-1, -1, 3, 4, 5, 6, 7}; + int source_phrase_low = 1, source_phrase_high = 6; + int source_back_low = 1, source_back_high = 6; + vector<int> matching = {11, 13, 15}; + vector<int> chunklen = {1, 1, 1}; + vector<pair<int, int> > source_gaps, target_gaps; + int num_symbols = 5; + EXPECT_TRUE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, + source_low, source_high, target_low, target_high, + source_phrase_low, source_phrase_high, + source_back_low, source_back_high, num_symbols, + met_constraints)); + EXPECT_FALSE(met_constraints); + vector<pair<int, int> > expected_gaps = {make_pair(2, 3), make_pair(4, 5)}; + EXPECT_EQ(expected_gaps, source_gaps); + EXPECT_EQ(expected_gaps, target_gaps); + + source_low = {-1, 1, 2, 3, 4, -1, 6}; + source_high = {-1, 2, 3, 4, 5, -1, 7}; + target_low = {-1, 1, 2, 3, 4, -1, 6}; + target_high = {-1, 2, 3, 4, 5, -1, 7}; + met_constraints = true; + source_gaps.clear(); + target_gaps.clear(); + EXPECT_TRUE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, + source_low, source_high, target_low, target_high, + source_phrase_low, source_phrase_high, + source_back_low, source_back_high, num_symbols, + met_constraints)); + EXPECT_FALSE(met_constraints); + EXPECT_EQ(expected_gaps, source_gaps); + EXPECT_EQ(expected_gaps, target_gaps); +} + +TEST_F(RuleExtractorHelperTest, TestGetGapsWithExtensions) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + + bool met_constraints = true; + vector<int> source_low = {-1, 5, 2, 3, 4, 1, -1}; + vector<int> source_high = {-1, 6, 3, 4, 5, 2, -1}; + vector<int> target_low = {-1, 5, 2, 3, 4, 1, -1}; + vector<int> target_high = {-1, 6, 3, 4, 5, 2, -1}; + int source_phrase_low = 2, source_phrase_high = 5; + int source_back_low = 1, source_back_high = 6; + vector<int> matching = {12, 14}; + vector<int> chunklen = {1, 1}; + vector<pair<int, int> > source_gaps, target_gaps; + int num_symbols = 3; + EXPECT_TRUE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, + source_low, source_high, target_low, target_high, + source_phrase_low, source_phrase_high, + source_back_low, source_back_high, num_symbols, + met_constraints)); + vector<pair<int, int> > expected_source_gaps = { + make_pair(1, 2), make_pair(3, 4), make_pair(5, 6) + }; + EXPECT_EQ(expected_source_gaps, source_gaps); + vector<pair<int, int> > expected_target_gaps = { + make_pair(5, 6), make_pair(3, 4), make_pair(1, 2) + }; + EXPECT_EQ(expected_target_gaps, target_gaps); +} + +TEST_F(RuleExtractorHelperTest, TestGetGaps) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + + bool met_constraints = true; + vector<int> source_low = {-1, 1, 4, 3, 2, 5, -1}; + vector<int> source_high = {-1, 2, 5, 4, 3, 6, -1}; + vector<int> target_low = {-1, 1, 4, 3, 2, 5, -1}; + vector<int> target_high = {-1, 2, 5, 4, 3, 6, -1}; + int source_phrase_low = 1, source_phrase_high = 6; + int source_back_low = 1, source_back_high = 6; + vector<int> matching = {11, 13, 15}; + vector<int> chunklen = {1, 1, 1}; + vector<pair<int, int> > source_gaps, target_gaps; + int num_symbols = 5; + EXPECT_TRUE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, + source_low, source_high, target_low, target_high, + source_phrase_low, source_phrase_high, + source_back_low, source_back_high, num_symbols, + met_constraints)); + vector<pair<int, int> > expected_source_gaps = { + make_pair(2, 3), make_pair(4, 5) + }; + EXPECT_EQ(expected_source_gaps, source_gaps); + vector<pair<int, int> > expected_target_gaps = { + make_pair(4, 5), make_pair(2, 3) + }; + EXPECT_EQ(expected_target_gaps, target_gaps); +} + +TEST_F(RuleExtractorHelperTest, TestGetGapIntegrityChecksFailed) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + + bool met_constraints = true; + vector<int> source_low = {-1, 3, 2, 3, 4, 3, -1}; + vector<int> source_high = {-1, 4, 3, 4, 5, 4, -1}; + vector<int> target_low = {-1, -1, 2, 1, 4, -1, -1}; + vector<int> target_high = {-1, -1, 3, 6, 5, -1, -1}; + int source_phrase_low = 2, source_phrase_high = 5; + int source_back_low = 2, source_back_high = 5; + vector<int> matching = {12, 14}; + vector<int> chunklen = {1, 1}; + vector<pair<int, int> > source_gaps, target_gaps; + int num_symbols = 3; + EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, + source_low, source_high, target_low, target_high, + source_phrase_low, source_phrase_high, + source_back_low, source_back_high, num_symbols, + met_constraints)); +} + +} // namespace diff --git a/extractor/rule_extractor_test.cc b/extractor/rule_extractor_test.cc new file mode 100644 index 00000000..0be44d4d --- /dev/null +++ b/extractor/rule_extractor_test.cc @@ -0,0 +1,166 @@ +#include <gtest/gtest.h> + +#include <memory> + +#include "mocks/mock_alignment.h" +#include "mocks/mock_data_array.h" +#include "mocks/mock_rule_extractor_helper.h" +#include "mocks/mock_scorer.h" +#include "mocks/mock_target_phrase_extractor.h" +#include "mocks/mock_vocabulary.h" +#include "phrase.h" +#include "phrase_builder.h" +#include "phrase_location.h" +#include "rule_extractor.h" +#include "rule.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class RuleExtractorTest : public Test { + protected: + virtual void SetUp() { + source_data_array = make_shared<MockDataArray>(); + EXPECT_CALL(*source_data_array, GetSentenceId(_)) + .WillRepeatedly(Return(0)); + EXPECT_CALL(*source_data_array, GetSentenceStart(_)) + .WillRepeatedly(Return(0)); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(10)); + + helper = make_shared<MockRuleExtractorHelper>(); + EXPECT_CALL(*helper, CheckAlignedTerminals(_, _, _)) + .WillRepeatedly(Return(true)); + EXPECT_CALL(*helper, CheckTightPhrases(_, _, _)) + .WillRepeatedly(Return(true)); + unordered_map<int, int> source_indexes; + EXPECT_CALL(*helper, GetSourceIndexes(_, _, _)) + .WillRepeatedly(Return(source_indexes)); + + vocabulary = make_shared<MockVocabulary>(); + EXPECT_CALL(*vocabulary, GetTerminalValue(87)) + .WillRepeatedly(Return("a")); + phrase_builder = make_shared<PhraseBuilder>(vocabulary); + vector<int> symbols = {87}; + Phrase target_phrase = phrase_builder->Build(symbols); + PhraseAlignment phrase_alignment = {make_pair(0, 0)}; + + target_phrase_extractor = make_shared<MockTargetPhraseExtractor>(); + vector<pair<Phrase, PhraseAlignment> > target_phrases = { + make_pair(target_phrase, phrase_alignment) + }; + EXPECT_CALL(*target_phrase_extractor, ExtractPhrases(_, _, _, _, _, _)) + .WillRepeatedly(Return(target_phrases)); + + scorer = make_shared<MockScorer>(); + vector<double> scores = {0.3, 7.2}; + EXPECT_CALL(*scorer, Score(_)).WillRepeatedly(Return(scores)); + + extractor = make_shared<RuleExtractor>(source_data_array, phrase_builder, + scorer, target_phrase_extractor, helper, 10, 1, 3, 5, false); + } + + shared_ptr<MockDataArray> source_data_array; + shared_ptr<MockVocabulary> vocabulary; + shared_ptr<PhraseBuilder> phrase_builder; + shared_ptr<MockRuleExtractorHelper> helper; + shared_ptr<MockScorer> scorer; + shared_ptr<MockTargetPhraseExtractor> target_phrase_extractor; + shared_ptr<RuleExtractor> extractor; +}; + +TEST_F(RuleExtractorTest, TestExtractRulesAlignedTerminalsFail) { + vector<int> symbols = {87}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> matching = {2}; + PhraseLocation phrase_location(matching, 1); + EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); + EXPECT_CALL(*helper, CheckAlignedTerminals(_, _, _)) + .WillRepeatedly(Return(false)); + vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); + EXPECT_EQ(0, rules.size()); +} + +TEST_F(RuleExtractorTest, TestExtractRulesTightPhrasesFail) { + vector<int> symbols = {87}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> matching = {2}; + PhraseLocation phrase_location(matching, 1); + EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); + EXPECT_CALL(*helper, CheckTightPhrases(_, _, _)) + .WillRepeatedly(Return(false)); + vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); + EXPECT_EQ(0, rules.size()); +} + +TEST_F(RuleExtractorTest, TestExtractRulesNoFixPoint) { + vector<int> symbols = {87}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> matching = {2}; + PhraseLocation phrase_location(matching, 1); + + EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); + // Set FindFixPoint to return false. + vector<pair<int, int> > gaps; + helper->SetUp(0, 0, 0, 0, false, gaps, gaps, 0, true, true); + + vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); + EXPECT_EQ(0, rules.size()); +} + +TEST_F(RuleExtractorTest, TestExtractRulesGapsFail) { + vector<int> symbols = {87}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> matching = {2}; + PhraseLocation phrase_location(matching, 1); + + EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); + // Set CheckGaps to return false. + vector<pair<int, int> > gaps; + helper->SetUp(0, 0, 0, 0, true, gaps, gaps, 0, true, false); + + vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); + EXPECT_EQ(0, rules.size()); +} + +TEST_F(RuleExtractorTest, TestExtractRulesNoExtremities) { + vector<int> symbols = {87}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> matching = {2}; + PhraseLocation phrase_location(matching, 1); + + EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); + vector<pair<int, int> > gaps(3); + // Set FindFixPoint to return true. The number of gaps equals the number of + // nonterminals, so we won't add any extremities. + helper->SetUp(0, 0, 0, 0, true, gaps, gaps, 0, true, true); + + vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); + EXPECT_EQ(1, rules.size()); +} + +TEST_F(RuleExtractorTest, TestExtractRulesAddExtremities) { + vector<int> symbols = {87}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> matching = {2}; + PhraseLocation phrase_location(matching, 1); + + vector<int> links(10, -1); + EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).WillOnce(DoAll( + SetArgReferee<0>(links), + SetArgReferee<1>(links), + SetArgReferee<2>(links), + SetArgReferee<3>(links))); + + vector<pair<int, int> > gaps; + // Set FindFixPoint to return true. The number of gaps equals the number of + // nonterminals, so we won't add any extremities. + helper->SetUp(0, 0, 2, 3, true, gaps, gaps, 0, true, true); + + vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); + EXPECT_EQ(4, rules.size()); +} + +} // namespace diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index c22f9b48..374a0db1 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -1,6 +1,6 @@ #include "rule_factory.h" -#include <cassert> +#include <chrono> #include <memory> #include <queue> #include <vector> @@ -18,7 +18,9 @@ #include "vocabulary.h" using namespace std; -using namespace tr1; +using namespace std::chrono; + +typedef high_resolution_clock Clock; struct State { State(int start, int end, const vector<int>& phrase, @@ -68,8 +70,44 @@ HieroCachingRuleFactory::HieroCachingRuleFactory( sampler = make_shared<Sampler>(source_suffix_array, max_samples); } +HieroCachingRuleFactory::HieroCachingRuleFactory( + shared_ptr<MatchingsFinder> finder, + shared_ptr<Intersector> intersector, + shared_ptr<PhraseBuilder> phrase_builder, + shared_ptr<RuleExtractor> rule_extractor, + shared_ptr<Vocabulary> vocabulary, + shared_ptr<Sampler> sampler, + shared_ptr<Scorer> scorer, + int min_gap_size, + int max_rule_span, + int max_nonterminals, + int max_chunks, + int max_rule_symbols) : + matchings_finder(finder), + intersector(intersector), + phrase_builder(phrase_builder), + rule_extractor(rule_extractor), + vocabulary(vocabulary), + sampler(sampler), + scorer(scorer), + min_gap_size(min_gap_size), + max_rule_span(max_rule_span), + max_nonterminals(max_nonterminals), + max_chunks(max_chunks), + max_rule_symbols(max_rule_symbols) {} + +HieroCachingRuleFactory::HieroCachingRuleFactory() {} + +HieroCachingRuleFactory::~HieroCachingRuleFactory() {} Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { + intersector->binary_merge_time = 0; + intersector->linear_merge_time = 0; + intersector->sort_time = 0; + Clock::time_point start_time = Clock::now(); + double total_extract_time = 0; + double total_intersect_time = 0; + double total_lookup_time = 0; // Clear cache for every new sentence. trie.Reset(); shared_ptr<TrieNode> root = trie.GetRoot(); @@ -107,34 +145,42 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { } if (RequiresLookup(node, word_id)) { - shared_ptr<TrieNode> next_suffix_link = - node->suffix_link->GetChild(word_id); + shared_ptr<TrieNode> next_suffix_link = node->suffix_link == NULL ? + trie.GetRoot() : node->suffix_link->GetChild(word_id); if (state.starts_with_x) { // If the phrase starts with a non terminal, we simply use the matchings // from the suffix link. - next_node = shared_ptr<TrieNode>(new TrieNode( - next_suffix_link, next_phrase, next_suffix_link->matchings)); + next_node = make_shared<TrieNode>( + next_suffix_link, next_phrase, next_suffix_link->matchings); } else { PhraseLocation phrase_location; if (next_phrase.Arity() > 0) { + Clock::time_point intersect_start_time = Clock::now(); phrase_location = intersector->Intersect( node->phrase, node->matchings, next_suffix_link->phrase, next_suffix_link->matchings, next_phrase); + Clock::time_point intersect_stop_time = Clock::now(); + total_intersect_time += duration_cast<milliseconds>( + intersect_stop_time - intersect_start_time).count(); } else { + Clock::time_point lookup_start_time = Clock::now(); phrase_location = matchings_finder->Find( node->matchings, vocabulary->GetTerminalValue(word_id), state.phrase.size()); + Clock::time_point lookup_stop_time = Clock::now(); + total_lookup_time += duration_cast<milliseconds>( + lookup_stop_time - lookup_start_time).count(); } if (phrase_location.IsEmpty()) { continue; } - next_node = shared_ptr<TrieNode>(new TrieNode( - next_suffix_link, next_phrase, phrase_location)); + next_node = make_shared<TrieNode>( + next_suffix_link, next_phrase, phrase_location); } node->AddChild(word_id, next_node); @@ -143,12 +189,16 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { AddTrailingNonterminal(phrase, next_phrase, next_node, state.starts_with_x); + Clock::time_point extract_start_time = Clock::now(); if (!state.starts_with_x) { PhraseLocation sample = sampler->Sample(next_node->matchings); vector<Rule> new_rules = rule_extractor->ExtractRules(next_phrase, sample); rules.insert(rules.end(), new_rules.begin(), new_rules.end()); } + Clock::time_point extract_stop_time = Clock::now(); + total_extract_time += duration_cast<milliseconds>( + extract_stop_time - extract_start_time).count(); } else { next_node = node->GetChild(word_id); } @@ -160,6 +210,16 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { } } + Clock::time_point stop_time = Clock::now(); + milliseconds ms = duration_cast<milliseconds>(stop_time - start_time); + cerr << "Total time for rule lookup, extraction, and scoring = " + << ms.count() / 1000.0 << endl; + cerr << "Extract time = " << total_extract_time / 1000.0 << endl; + cerr << "Intersect time = " << total_intersect_time / 1000.0 << endl; + cerr << "Sort time = " << intersector->sort_time / 1000.0 << endl; + cerr << "Linear merge time = " << intersector->linear_merge_time / 1000.0 << endl; + cerr << "Binary merge time = " << intersector->binary_merge_time / 1000.0 << endl; + // cerr << "Lookup time = " << total_lookup_time / 1000.0 << endl; return Grammar(rules, scorer->GetFeatureNames()); } @@ -192,12 +252,12 @@ void HieroCachingRuleFactory::AddTrailingNonterminal( Phrase var_phrase = phrase_builder->Build(symbols); int suffix_var_id = vocabulary->GetNonterminalIndex( - prefix.Arity() + starts_with_x == 0); + prefix.Arity() + (starts_with_x == 0)); shared_ptr<TrieNode> var_suffix_link = prefix_node->suffix_link->GetChild(suffix_var_id); - prefix_node->AddChild(var_id, shared_ptr<TrieNode>(new TrieNode( - var_suffix_link, var_phrase, prefix_node->matchings))); + prefix_node->AddChild(var_id, make_shared<TrieNode>( + var_suffix_link, var_phrase, prefix_node->matchings)); } vector<State> HieroCachingRuleFactory::ExtendState( @@ -216,7 +276,7 @@ vector<State> HieroCachingRuleFactory::ExtendState( new_states.push_back(State(state.start, state.end + 1, symbols, state.subpatterns_start, node, state.starts_with_x)); - int num_subpatterns = phrase.Arity() + state.starts_with_x == 0; + int num_subpatterns = phrase.Arity() + (state.starts_with_x == 0); if (symbols.size() + 1 >= max_rule_symbols || phrase.Arity() >= max_nonterminals || num_subpatterns >= max_chunks) { diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h index a47b6d16..cf344667 100644 --- a/extractor/rule_factory.h +++ b/extractor/rule_factory.h @@ -40,7 +40,27 @@ class HieroCachingRuleFactory { bool use_beaza_yates, bool require_tight_phrases); - Grammar GetGrammar(const vector<int>& word_ids); + // For testing only. + HieroCachingRuleFactory( + shared_ptr<MatchingsFinder> finder, + shared_ptr<Intersector> intersector, + shared_ptr<PhraseBuilder> phrase_builder, + shared_ptr<RuleExtractor> rule_extractor, + shared_ptr<Vocabulary> vocabulary, + shared_ptr<Sampler> sampler, + shared_ptr<Scorer> scorer, + int min_gap_size, + int max_rule_span, + int max_nonterminals, + int max_chunks, + int max_rule_symbols); + + virtual ~HieroCachingRuleFactory(); + + virtual Grammar GetGrammar(const vector<int>& word_ids); + + protected: + HieroCachingRuleFactory(); private: bool CannotHaveMatchings(shared_ptr<TrieNode> node, int word_id); diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc new file mode 100644 index 00000000..d6fbab74 --- /dev/null +++ b/extractor/rule_factory_test.cc @@ -0,0 +1,98 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> +#include <vector> + +#include "grammar.h" +#include "mocks/mock_intersector.h" +#include "mocks/mock_matchings_finder.h" +#include "mocks/mock_rule_extractor.h" +#include "mocks/mock_sampler.h" +#include "mocks/mock_scorer.h" +#include "mocks/mock_vocabulary.h" +#include "phrase_builder.h" +#include "phrase_location.h" +#include "rule_factory.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class RuleFactoryTest : public Test { + protected: + virtual void SetUp() { + finder = make_shared<MockMatchingsFinder>(); + intersector = make_shared<MockIntersector>(); + + vocabulary = make_shared<MockVocabulary>(); + EXPECT_CALL(*vocabulary, GetTerminalValue(2)).WillRepeatedly(Return("a")); + EXPECT_CALL(*vocabulary, GetTerminalValue(3)).WillRepeatedly(Return("b")); + EXPECT_CALL(*vocabulary, GetTerminalValue(4)).WillRepeatedly(Return("c")); + + phrase_builder = make_shared<PhraseBuilder>(vocabulary); + + scorer = make_shared<MockScorer>(); + feature_names = {"f1"}; + EXPECT_CALL(*scorer, GetFeatureNames()) + .WillRepeatedly(Return(feature_names)); + + sampler = make_shared<MockSampler>(); + EXPECT_CALL(*sampler, Sample(_)) + .WillRepeatedly(Return(PhraseLocation(0, 1))); + + Phrase phrase; + vector<double> scores = {0.5}; + vector<pair<int, int> > phrase_alignment = {make_pair(0, 0)}; + vector<Rule> rules = {Rule(phrase, phrase, scores, phrase_alignment)}; + extractor = make_shared<MockRuleExtractor>(); + EXPECT_CALL(*extractor, ExtractRules(_, _)) + .WillRepeatedly(Return(rules)); + + factory = make_shared<HieroCachingRuleFactory>(finder, intersector, + phrase_builder, extractor, vocabulary, sampler, scorer, 1, 10, 2, 3, 5); + } + + vector<string> feature_names; + shared_ptr<MockMatchingsFinder> finder; + shared_ptr<MockIntersector> intersector; + shared_ptr<MockVocabulary> vocabulary; + shared_ptr<PhraseBuilder> phrase_builder; + shared_ptr<MockScorer> scorer; + shared_ptr<MockSampler> sampler; + shared_ptr<MockRuleExtractor> extractor; + shared_ptr<HieroCachingRuleFactory> factory; +}; + +TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) { + EXPECT_CALL(*finder, Find(_, _, _)) + .Times(6) + .WillRepeatedly(Return(PhraseLocation(0, 1))); + + EXPECT_CALL(*intersector, Intersect(_, _, _, _, _)) + .Times(1) + .WillRepeatedly(Return(PhraseLocation(0, 1))); + + vector<int> word_ids = {2, 3, 4}; + Grammar grammar = factory->GetGrammar(word_ids); + EXPECT_EQ(feature_names, grammar.GetFeatureNames()); + EXPECT_EQ(7, grammar.GetRules().size()); +} + +TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) { + EXPECT_CALL(*finder, Find(_, _, _)) + .Times(12) + .WillRepeatedly(Return(PhraseLocation(0, 1))); + + EXPECT_CALL(*intersector, Intersect(_, _, _, _, _)) + .Times(16) + .WillRepeatedly(Return(PhraseLocation(0, 1))); + + vector<int> word_ids = {2, 3, 4, 2, 3}; + Grammar grammar = factory->GetGrammar(word_ids); + EXPECT_EQ(feature_names, grammar.GetFeatureNames()); + EXPECT_EQ(28, grammar.GetRules().size()); +} + +} // namespace diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index 37a9cba0..ed30e6fe 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -114,8 +114,8 @@ int main(int argc, char** argv) { make_shared<TargetGivenSourceCoherent>(), make_shared<SampleSourceCount>(), make_shared<CountSourceTarget>(), - make_shared<MaxLexTargetGivenSource>(table), make_shared<MaxLexSourceGivenTarget>(table), + make_shared<MaxLexTargetGivenSource>(table), make_shared<IsSourceSingleton>(), make_shared<IsSourceTargetSingleton>() }; @@ -138,6 +138,10 @@ int main(int argc, char** argv) { int grammar_id = 0; fs::path grammar_path = vm["grammars"].as<string>(); + if (!fs::is_directory(grammar_path)) { + fs::create_directory(grammar_path); + } + string sentence, delimiter = "|||"; while (getline(cin, sentence)) { string suffix = ""; @@ -148,7 +152,8 @@ int main(int argc, char** argv) { } Grammar grammar = extractor.GetGrammar(sentence); - fs::path grammar_file = grammar_path / to_string(grammar_id); + string file_name = "grammar." + to_string(grammar_id); + fs::path grammar_file = grammar_path / file_name; ofstream output(grammar_file.c_str()); output << grammar; diff --git a/extractor/sample_alignment.txt b/extractor/sample_alignment.txt new file mode 100644 index 00000000..80b446a4 --- /dev/null +++ b/extractor/sample_alignment.txt @@ -0,0 +1,2 @@ +0-0 1-1 2-2 +1-0 2-1 diff --git a/extractor/sampler.cc b/extractor/sampler.cc index d8e0f49e..5067ca8a 100644 --- a/extractor/sampler.cc +++ b/extractor/sampler.cc @@ -6,6 +6,10 @@ Sampler::Sampler(shared_ptr<SuffixArray> suffix_array, int max_samples) : suffix_array(suffix_array), max_samples(max_samples) {} +Sampler::Sampler() {} + +Sampler::~Sampler() {} + PhraseLocation Sampler::Sample(const PhraseLocation& location) const { vector<int> sample; int num_subpatterns; @@ -32,5 +36,6 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location) const { } int Sampler::Round(double x) const { - return x + 0.5; + // TODO(pauldb): Remove EPS. + return x + 0.5 + 1e-8; } diff --git a/extractor/sampler.h b/extractor/sampler.h index 3b3e3a4d..9cf321fb 100644 --- a/extractor/sampler.h +++ b/extractor/sampler.h @@ -12,7 +12,12 @@ class Sampler { public: Sampler(shared_ptr<SuffixArray> suffix_array, int max_samples); - PhraseLocation Sample(const PhraseLocation& location) const; + virtual ~Sampler(); + + virtual PhraseLocation Sample(const PhraseLocation& location) const; + + protected: + Sampler(); private: int Round(double x) const; diff --git a/extractor/scorer.cc b/extractor/scorer.cc index c87e179d..f28b3181 100644 --- a/extractor/scorer.cc +++ b/extractor/scorer.cc @@ -5,6 +5,10 @@ Scorer::Scorer(const vector<shared_ptr<Feature> >& features) : features(features) {} +Scorer::Scorer() {} + +Scorer::~Scorer() {} + vector<double> Scorer::Score(const FeatureContext& context) const { vector<double> scores; for (auto feature: features) { diff --git a/extractor/scorer.h b/extractor/scorer.h index 5b328fb4..ba71a6ee 100644 --- a/extractor/scorer.h +++ b/extractor/scorer.h @@ -14,9 +14,14 @@ class Scorer { public: Scorer(const vector<shared_ptr<Feature> >& features); - vector<double> Score(const FeatureContext& context) const; + virtual ~Scorer(); - vector<string> GetFeatureNames() const; + virtual vector<double> Score(const FeatureContext& context) const; + + virtual vector<string> GetFeatureNames() const; + + protected: + Scorer(); private: vector<shared_ptr<Feature> > features; diff --git a/extractor/scorer_test.cc b/extractor/scorer_test.cc new file mode 100644 index 00000000..56a85762 --- /dev/null +++ b/extractor/scorer_test.cc @@ -0,0 +1,47 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> +#include <vector> + +#include "mocks/mock_feature.h" +#include "scorer.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class ScorerTest : public Test { + protected: + virtual void SetUp() { + feature1 = make_shared<MockFeature>(); + EXPECT_CALL(*feature1, Score(_)).WillRepeatedly(Return(0.5)); + EXPECT_CALL(*feature1, GetName()).WillRepeatedly(Return("f1")); + + feature2 = make_shared<MockFeature>(); + EXPECT_CALL(*feature2, Score(_)).WillRepeatedly(Return(-1.3)); + EXPECT_CALL(*feature2, GetName()).WillRepeatedly(Return("f2")); + + vector<shared_ptr<Feature> > features = {feature1, feature2}; + scorer = make_shared<Scorer>(features); + } + + shared_ptr<MockFeature> feature1; + shared_ptr<MockFeature> feature2; + shared_ptr<Scorer> scorer; +}; + +TEST_F(ScorerTest, TestScore) { + vector<double> expected_scores = {0.5, -1.3}; + Phrase phrase; + FeatureContext context(phrase, phrase, 0.3, 2, 11); + EXPECT_EQ(expected_scores, scorer->Score(context)); +} + +TEST_F(ScorerTest, TestGetNames) { + vector<string> expected_names = {"f1", "f2"}; + EXPECT_EQ(expected_names, scorer->GetFeatureNames()); +} + +} // namespace diff --git a/extractor/suffix_array.cc b/extractor/suffix_array.cc index d13eacd5..9815996f 100644 --- a/extractor/suffix_array.cc +++ b/extractor/suffix_array.cc @@ -22,9 +22,9 @@ SuffixArray::~SuffixArray() {} void SuffixArray::BuildSuffixArray() { vector<int> groups = data_array->GetData(); groups.reserve(groups.size() + 1); - groups.push_back(data_array->GetVocabularySize()); + groups.push_back(DataArray::NULL_WORD); suffix_array.resize(groups.size()); - word_start.resize(data_array->GetVocabularySize() + 2); + word_start.resize(data_array->GetVocabularySize() + 1); InitialBucketSort(groups); @@ -112,6 +112,8 @@ void SuffixArray::TernaryQuicksort(int left, int right, int step, } } + TernaryQuicksort(left, mid_left - 1, step, groups); + if (mid_left == mid_right) { groups[suffix_array[mid_left]] = mid_left; suffix_array[mid_left] = -1; @@ -121,7 +123,6 @@ void SuffixArray::TernaryQuicksort(int left, int right, int step, } } - TernaryQuicksort(left, mid_left - 1, step, groups); TernaryQuicksort(mid_right + 1, right, step, groups); } @@ -201,7 +202,7 @@ int SuffixArray::LookupRangeStart(int low, int high, int word_id, int result = high; while (low < high) { int middle = low + (high - low) / 2; - if (suffix_array[middle] + offset < data_array->GetSize() && + if (suffix_array[middle] + offset >= data_array->GetSize() || data_array->AtIndex(suffix_array[middle] + offset) < word_id) { low = middle + 1; } else { diff --git a/extractor/suffix_array_test.cc b/extractor/suffix_array_test.cc index d891933c..60295567 100644 --- a/extractor/suffix_array_test.cc +++ b/extractor/suffix_array_test.cc @@ -14,10 +14,10 @@ namespace { class SuffixArrayTest : public Test { protected: virtual void SetUp() { - data = vector<int>{5, 3, 0, 1, 3, 4, 2, 3, 5, 5, 3, 0, 1}; + data = {6, 4, 1, 2, 4, 5, 3, 4, 6, 6, 4, 1, 2}; data_array = make_shared<MockDataArray>(); EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data)); - EXPECT_CALL(*data_array, GetVocabularySize()).WillRepeatedly(Return(6)); + EXPECT_CALL(*data_array, GetVocabularySize()).WillRepeatedly(Return(7)); EXPECT_CALL(*data_array, GetSize()).WillRepeatedly(Return(13)); suffix_array = make_shared<SuffixArray>(data_array); } @@ -33,14 +33,15 @@ TEST_F(SuffixArrayTest, TestData) { } TEST_F(SuffixArrayTest, TestBuildSuffixArray) { - vector<int> expected_suffix_array{2, 11, 3, 12, 6, 1, 10, 4, 7, 5, 0, 9, 8}; + vector<int> expected_suffix_array = + {13, 11, 2, 12, 3, 6, 10, 1, 4, 7, 5, 9, 0, 8}; for (size_t i = 0; i < expected_suffix_array.size(); ++i) { EXPECT_EQ(expected_suffix_array[i], suffix_array->GetSuffix(i)); } } TEST_F(SuffixArrayTest, TestBuildLCP) { - vector<int> expected_lcp{-1, 2, 0, 1, 0, 0, 3, 1, 1, 0, 0, 4, 1, 0}; + vector<int> expected_lcp = {-1, 0, 2, 0, 1, 0, 0, 3, 1, 1, 0, 0, 4, 1}; EXPECT_EQ(expected_lcp, suffix_array->BuildLCPArray()); } @@ -50,26 +51,26 @@ TEST_F(SuffixArrayTest, TestLookup) { } EXPECT_CALL(*data_array, HasWord("word1")).WillRepeatedly(Return(true)); - EXPECT_CALL(*data_array, GetWordId("word1")).WillRepeatedly(Return(5)); - EXPECT_EQ(PhraseLocation(10, 13), suffix_array->Lookup(0, 14, "word1", 0)); + EXPECT_CALL(*data_array, GetWordId("word1")).WillRepeatedly(Return(6)); + EXPECT_EQ(PhraseLocation(11, 14), suffix_array->Lookup(0, 14, "word1", 0)); EXPECT_CALL(*data_array, HasWord("word2")).WillRepeatedly(Return(false)); EXPECT_EQ(PhraseLocation(0, 0), suffix_array->Lookup(0, 14, "word2", 0)); EXPECT_CALL(*data_array, HasWord("word3")).WillRepeatedly(Return(true)); - EXPECT_CALL(*data_array, GetWordId("word3")).WillRepeatedly(Return(3)); - EXPECT_EQ(PhraseLocation(10, 12), suffix_array->Lookup(10, 13, "word3", 1)); + EXPECT_CALL(*data_array, GetWordId("word3")).WillRepeatedly(Return(4)); + EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 14, "word3", 1)); EXPECT_CALL(*data_array, HasWord("word4")).WillRepeatedly(Return(true)); - EXPECT_CALL(*data_array, GetWordId("word4")).WillRepeatedly(Return(0)); - EXPECT_EQ(PhraseLocation(10, 12), suffix_array->Lookup(10, 12, "word4", 2)); + EXPECT_CALL(*data_array, GetWordId("word4")).WillRepeatedly(Return(1)); + EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 13, "word4", 2)); EXPECT_CALL(*data_array, HasWord("word5")).WillRepeatedly(Return(true)); - EXPECT_CALL(*data_array, GetWordId("word5")).WillRepeatedly(Return(1)); - EXPECT_EQ(PhraseLocation(10, 12), suffix_array->Lookup(10, 12, "word5", 3)); + EXPECT_CALL(*data_array, GetWordId("word5")).WillRepeatedly(Return(2)); + EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 13, "word5", 3)); - EXPECT_EQ(PhraseLocation(10, 11), suffix_array->Lookup(10, 12, "word3", 4)); - EXPECT_EQ(PhraseLocation(10, 10), suffix_array->Lookup(10, 12, "word5", 1)); + EXPECT_EQ(PhraseLocation(12, 13), suffix_array->Lookup(11, 13, "word3", 4)); + EXPECT_EQ(PhraseLocation(11, 11), suffix_array->Lookup(11, 13, "word5", 1)); } } // namespace diff --git a/extractor/target_phrase_extractor.cc b/extractor/target_phrase_extractor.cc new file mode 100644 index 00000000..ac583953 --- /dev/null +++ b/extractor/target_phrase_extractor.cc @@ -0,0 +1,144 @@ +#include "target_phrase_extractor.h" + +#include <unordered_set> + +#include "alignment.h" +#include "data_array.h" +#include "phrase.h" +#include "phrase_builder.h" +#include "rule_extractor_helper.h" +#include "vocabulary.h" + +using namespace std; + +TargetPhraseExtractor::TargetPhraseExtractor( + shared_ptr<DataArray> target_data_array, + shared_ptr<Alignment> alignment, + shared_ptr<PhraseBuilder> phrase_builder, + shared_ptr<RuleExtractorHelper> helper, + shared_ptr<Vocabulary> vocabulary, + int max_rule_span, + bool require_tight_phrases) : + target_data_array(target_data_array), + alignment(alignment), + phrase_builder(phrase_builder), + helper(helper), + vocabulary(vocabulary), + max_rule_span(max_rule_span), + require_tight_phrases(require_tight_phrases) {} + +TargetPhraseExtractor::TargetPhraseExtractor() {} + +TargetPhraseExtractor::~TargetPhraseExtractor() {} + +vector<pair<Phrase, PhraseAlignment> > TargetPhraseExtractor::ExtractPhrases( + const vector<pair<int, int> >& target_gaps, const vector<int>& target_low, + int target_phrase_low, int target_phrase_high, + const unordered_map<int, int>& source_indexes, int sentence_id) const { + int target_sent_len = target_data_array->GetSentenceLength(sentence_id); + + vector<int> target_gap_order = helper->GetGapOrder(target_gaps); + + int target_x_low = target_phrase_low, target_x_high = target_phrase_high; + if (!require_tight_phrases) { + while (target_x_low > 0 && + target_phrase_high - target_x_low < max_rule_span && + target_low[target_x_low - 1] == -1) { + --target_x_low; + } + while (target_x_high < target_sent_len && + target_x_high - target_phrase_low < max_rule_span && + target_low[target_x_high] == -1) { + ++target_x_high; + } + } + + vector<pair<int, int> > gaps(target_gaps.size()); + for (size_t i = 0; i < gaps.size(); ++i) { + gaps[i] = target_gaps[target_gap_order[i]]; + if (!require_tight_phrases) { + while (gaps[i].first > target_x_low && + target_low[gaps[i].first - 1] == -1) { + --gaps[i].first; + } + while (gaps[i].second < target_x_high && + target_low[gaps[i].second] == -1) { + ++gaps[i].second; + } + } + } + + vector<pair<int, int> > ranges(2 * gaps.size() + 2); + ranges.front() = make_pair(target_x_low, target_phrase_low); + ranges.back() = make_pair(target_phrase_high, target_x_high); + for (size_t i = 0; i < gaps.size(); ++i) { + int j = target_gap_order[i]; + ranges[i * 2 + 1] = make_pair(gaps[i].first, target_gaps[j].first); + ranges[i * 2 + 2] = make_pair(target_gaps[j].second, gaps[i].second); + } + + vector<pair<Phrase, PhraseAlignment> > target_phrases; + vector<int> subpatterns(ranges.size()); + GeneratePhrases(target_phrases, ranges, 0, subpatterns, target_gap_order, + target_phrase_low, target_phrase_high, source_indexes, + sentence_id); + return target_phrases; +} + +void TargetPhraseExtractor::GeneratePhrases( + vector<pair<Phrase, PhraseAlignment> >& target_phrases, + const vector<pair<int, int> >& ranges, int index, vector<int>& subpatterns, + const vector<int>& target_gap_order, int target_phrase_low, + int target_phrase_high, const unordered_map<int, int>& source_indexes, + int sentence_id) const { + if (index >= ranges.size()) { + if (subpatterns.back() - subpatterns.front() > max_rule_span) { + return; + } + + vector<int> symbols; + unordered_map<int, int> target_indexes; + + int target_sent_start = target_data_array->GetSentenceStart(sentence_id); + for (size_t i = 0; i * 2 < subpatterns.size(); ++i) { + for (size_t j = subpatterns[i * 2]; j < subpatterns[i * 2 + 1]; ++j) { + target_indexes[j] = symbols.size(); + string target_word = target_data_array->GetWordAtIndex( + target_sent_start + j); + symbols.push_back(vocabulary->GetTerminalIndex(target_word)); + } + if (i < target_gap_order.size()) { + symbols.push_back(vocabulary->GetNonterminalIndex( + target_gap_order[i] + 1)); + } + } + + vector<pair<int, int> > links = alignment->GetLinks(sentence_id); + vector<pair<int, int> > alignment; + for (pair<int, int> link: links) { + if (target_indexes.count(link.second)) { + alignment.push_back(make_pair(source_indexes.find(link.first)->second, + target_indexes[link.second])); + } + } + + Phrase target_phrase = phrase_builder->Build(symbols); + target_phrases.push_back(make_pair(target_phrase, alignment)); + return; + } + + subpatterns[index] = ranges[index].first; + if (index > 0) { + subpatterns[index] = max(subpatterns[index], subpatterns[index - 1]); + } + while (subpatterns[index] <= ranges[index].second) { + subpatterns[index + 1] = max(subpatterns[index], ranges[index + 1].first); + while (subpatterns[index + 1] <= ranges[index + 1].second) { + GeneratePhrases(target_phrases, ranges, index + 2, subpatterns, + target_gap_order, target_phrase_low, target_phrase_high, + source_indexes, sentence_id); + ++subpatterns[index + 1]; + } + ++subpatterns[index]; + } +} diff --git a/extractor/target_phrase_extractor.h b/extractor/target_phrase_extractor.h new file mode 100644 index 00000000..134f24cc --- /dev/null +++ b/extractor/target_phrase_extractor.h @@ -0,0 +1,56 @@ +#ifndef _TARGET_PHRASE_EXTRACTOR_H_ +#define _TARGET_PHRASE_EXTRACTOR_H_ + +#include <memory> +#include <unordered_map> +#include <vector> + +using namespace std; + +class Alignment; +class DataArray; +class Phrase; +class PhraseBuilder; +class RuleExtractorHelper; +class Vocabulary; + +typedef vector<pair<int, int> > PhraseAlignment; + +class TargetPhraseExtractor { + public: + TargetPhraseExtractor(shared_ptr<DataArray> target_data_array, + shared_ptr<Alignment> alignment, + shared_ptr<PhraseBuilder> phrase_builder, + shared_ptr<RuleExtractorHelper> helper, + shared_ptr<Vocabulary> vocabulary, + int max_rule_span, + bool require_tight_phrases); + + virtual ~TargetPhraseExtractor(); + + virtual vector<pair<Phrase, PhraseAlignment> > ExtractPhrases( + const vector<pair<int, int> >& target_gaps, const vector<int>& target_low, + int target_phrase_low, int target_phrase_high, + const unordered_map<int, int>& source_indexes, int sentence_id) const; + + protected: + TargetPhraseExtractor(); + + private: + void GeneratePhrases( + vector<pair<Phrase, PhraseAlignment> >& target_phrases, + const vector<pair<int, int> >& ranges, int index, + vector<int>& subpatterns, const vector<int>& target_gap_order, + int target_phrase_low, int target_phrase_high, + const unordered_map<int, int>& source_indexes, int sentence_id) const; + + shared_ptr<DataArray> target_data_array; + shared_ptr<Alignment> alignment; + shared_ptr<PhraseBuilder> phrase_builder; + shared_ptr<RuleExtractorHelper> helper; + shared_ptr<Vocabulary> vocabulary; + int max_rule_span; + bool require_tight_phrases; +}; + +#endif diff --git a/extractor/target_phrase_extractor_test.cc b/extractor/target_phrase_extractor_test.cc new file mode 100644 index 00000000..7394f4d9 --- /dev/null +++ b/extractor/target_phrase_extractor_test.cc @@ -0,0 +1,116 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <vector> + +#include "mocks/mock_alignment.h" +#include "mocks/mock_data_array.h" +#include "mocks/mock_rule_extractor_helper.h" +#include "mocks/mock_vocabulary.h" +#include "phrase.h" +#include "phrase_builder.h" +#include "target_phrase_extractor.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class TargetPhraseExtractorTest : public Test { + protected: + virtual void SetUp() { + data_array = make_shared<MockDataArray>(); + alignment = make_shared<MockAlignment>(); + vocabulary = make_shared<MockVocabulary>(); + phrase_builder = make_shared<PhraseBuilder>(vocabulary); + helper = make_shared<MockRuleExtractorHelper>(); + } + + shared_ptr<MockDataArray> data_array; + shared_ptr<MockAlignment> alignment; + shared_ptr<MockVocabulary> vocabulary; + shared_ptr<PhraseBuilder> phrase_builder; + shared_ptr<MockRuleExtractorHelper> helper; + shared_ptr<TargetPhraseExtractor> extractor; +}; + +TEST_F(TargetPhraseExtractorTest, TestExtractTightPhrasesTrue) { + EXPECT_CALL(*data_array, GetSentenceLength(1)).WillRepeatedly(Return(5)); + EXPECT_CALL(*data_array, GetSentenceStart(1)).WillRepeatedly(Return(3)); + + vector<string> target_words = {"a", "b", "c", "d", "e"}; + vector<int> target_symbols = {20, 21, 22, 23, 24}; + for (size_t i = 0; i < target_words.size(); ++i) { + EXPECT_CALL(*data_array, GetWordAtIndex(i + 3)) + .WillRepeatedly(Return(target_words[i])); + EXPECT_CALL(*vocabulary, GetTerminalIndex(target_words[i])) + .WillRepeatedly(Return(target_symbols[i])); + EXPECT_CALL(*vocabulary, GetTerminalValue(target_symbols[i])) + .WillRepeatedly(Return(target_words[i])); + } + + vector<pair<int, int> > links = { + make_pair(0, 0), make_pair(1, 3), make_pair(2, 2), make_pair(3, 1), + make_pair(4, 4) + }; + EXPECT_CALL(*alignment, GetLinks(1)).WillRepeatedly(Return(links)); + + vector<int> gap_order = {1, 0}; + EXPECT_CALL(*helper, GetGapOrder(_)).WillRepeatedly(Return(gap_order)); + + extractor = make_shared<TargetPhraseExtractor>( + data_array, alignment, phrase_builder, helper, vocabulary, 10, true); + + vector<pair<int, int> > target_gaps = {make_pair(3, 4), make_pair(1, 2)}; + vector<int> target_low = {0, 3, 2, 1, 4}; + unordered_map<int, int> source_indexes = {{0, 0}, {2, 2}, {4, 4}}; + + vector<pair<Phrase, PhraseAlignment> > results = extractor->ExtractPhrases( + target_gaps, target_low, 0, 5, source_indexes, 1); + EXPECT_EQ(1, results.size()); + vector<int> expected_symbols = {20, -2, 22, -1, 24}; + EXPECT_EQ(expected_symbols, results[0].first.Get()); + vector<string> expected_words = {"a", "c", "e"}; + EXPECT_EQ(expected_words, results[0].first.GetWords()); + vector<pair<int, int> > expected_alignment = { + make_pair(0, 0), make_pair(2, 2), make_pair(4, 4) + }; + EXPECT_EQ(expected_alignment, results[0].second); +} + +TEST_F(TargetPhraseExtractorTest, TestExtractPhrasesTightPhrasesFalse) { + vector<string> target_words = {"a", "b", "c", "d", "e", "f"}; + vector<int> target_symbols = {20, 21, 22, 23, 24, 25, 26}; + EXPECT_CALL(*data_array, GetSentenceLength(0)).WillRepeatedly(Return(6)); + EXPECT_CALL(*data_array, GetSentenceStart(0)).WillRepeatedly(Return(0)); + + for (size_t i = 0; i < target_words.size(); ++i) { + EXPECT_CALL(*data_array, GetWordAtIndex(i)) + .WillRepeatedly(Return(target_words[i])); + EXPECT_CALL(*vocabulary, GetTerminalIndex(target_words[i])) + .WillRepeatedly(Return(target_symbols[i])); + EXPECT_CALL(*vocabulary, GetTerminalValue(target_symbols[i])) + .WillRepeatedly(Return(target_words[i])); + } + + vector<pair<int, int> > links = {make_pair(1, 1)}; + EXPECT_CALL(*alignment, GetLinks(0)).WillRepeatedly(Return(links)); + + vector<int> gap_order = {0}; + EXPECT_CALL(*helper, GetGapOrder(_)).WillRepeatedly(Return(gap_order)); + + extractor = make_shared<TargetPhraseExtractor>( + data_array, alignment, phrase_builder, helper, vocabulary, 10, false); + + vector<pair<int, int> > target_gaps = {make_pair(2, 4)}; + vector<int> target_low = {-1, 1, -1, -1, -1, -1}; + unordered_map<int, int> source_indexes = {{1, 1}}; + + vector<pair<Phrase, PhraseAlignment> > results = extractor->ExtractPhrases( + target_gaps, target_low, 1, 5, source_indexes, 0); + EXPECT_EQ(10, results.size()); + // TODO(pauldb): Finish unit test once it's clear how these alignments should + // look like. +} + +} // namespace diff --git a/extractor/translation_table.cc b/extractor/translation_table.cc index 10f1b9ed..a48c0657 100644 --- a/extractor/translation_table.cc +++ b/extractor/translation_table.cc @@ -9,7 +9,6 @@ #include "data_array.h" using namespace std; -using namespace tr1; TranslationTable::TranslationTable(shared_ptr<DataArray> source_data_array, shared_ptr<DataArray> target_data_array, @@ -20,14 +19,15 @@ TranslationTable::TranslationTable(shared_ptr<DataArray> source_data_array, unordered_map<int, int> source_links_count; unordered_map<int, int> target_links_count; - unordered_map<pair<int, int>, int, PairHash > links_count; + unordered_map<pair<int, int>, int, PairHash> links_count; for (size_t i = 0; i < source_data_array->GetNumSentences(); ++i) { - const vector<pair<int, int> >& links = alignment->GetLinks(i); + vector<pair<int, int> > links = alignment->GetLinks(i); int source_start = source_data_array->GetSentenceStart(i); - int next_source_start = source_data_array->GetSentenceStart(i + 1); int target_start = target_data_array->GetSentenceStart(i); - int next_target_start = target_data_array->GetSentenceStart(i + 1); + // Ignore END_OF_LINE markers. + int next_source_start = source_data_array->GetSentenceStart(i + 1) - 1; + int next_target_start = target_data_array->GetSentenceStart(i + 1) - 1; vector<int> source_sentence(source_data.begin() + source_start, source_data.begin() + next_source_start); vector<int> target_sentence(target_data.begin() + target_start, @@ -38,15 +38,23 @@ TranslationTable::TranslationTable(shared_ptr<DataArray> source_data_array, for (pair<int, int> link: links) { source_linked_words[link.first] = 1; target_linked_words[link.second] = 1; - int source_word = source_sentence[link.first]; - int target_word = target_sentence[link.second]; + IncreaseLinksCount(source_links_count, target_links_count, links_count, + source_sentence[link.first], target_sentence[link.second]); + } - ++source_links_count[source_word]; - ++target_links_count[target_word]; - ++links_count[make_pair(source_word, target_word)]; + for (size_t i = 0; i < source_sentence.size(); ++i) { + if (!source_linked_words[i]) { + IncreaseLinksCount(source_links_count, target_links_count, links_count, + source_sentence[i], DataArray::NULL_WORD); + } } - // TODO(pauldb): Something seems wrong here. No NULL word? + for (size_t i = 0; i < target_sentence.size(); ++i) { + if (!target_linked_words[i]) { + IncreaseLinksCount(source_links_count, target_links_count, links_count, + DataArray::NULL_WORD, target_sentence[i]); + } + } } for (pair<pair<int, int>, int> link_count: links_count) { @@ -58,6 +66,21 @@ TranslationTable::TranslationTable(shared_ptr<DataArray> source_data_array, } } +TranslationTable::TranslationTable() {} + +TranslationTable::~TranslationTable() {} + +void TranslationTable::IncreaseLinksCount( + unordered_map<int, int>& source_links_count, + unordered_map<int, int>& target_links_count, + unordered_map<pair<int, int>, int, PairHash>& links_count, + int source_word_id, + int target_word_id) const { + ++source_links_count[source_word_id]; + ++target_links_count[target_word_id]; + ++links_count[make_pair(source_word_id, target_word_id)]; +} + double TranslationTable::GetTargetGivenSourceScore( const string& source_word, const string& target_word) { if (!source_data_array->HasWord(source_word) || @@ -73,7 +96,7 @@ double TranslationTable::GetTargetGivenSourceScore( double TranslationTable::GetSourceGivenTargetScore( const string& source_word, const string& target_word) { if (!source_data_array->HasWord(source_word) || - !target_data_array->HasWord(target_word) == 0) { + !target_data_array->HasWord(target_word)) { return -1; } diff --git a/extractor/translation_table.h b/extractor/translation_table.h index acf94af7..157ad3af 100644 --- a/extractor/translation_table.h +++ b/extractor/translation_table.h @@ -3,13 +3,12 @@ #include <memory> #include <string> -#include <tr1/unordered_map> +#include <unordered_map> #include <boost/filesystem.hpp> #include <boost/functional/hash.hpp> using namespace std; -using namespace tr1; namespace fs = boost::filesystem; class Alignment; @@ -24,15 +23,27 @@ class TranslationTable { shared_ptr<DataArray> target_data_array, shared_ptr<Alignment> alignment); - double GetTargetGivenSourceScore(const string& source_word, - const string& target_word); + virtual ~TranslationTable(); - double GetSourceGivenTargetScore(const string& source_word, - const string& target_word); + virtual double GetTargetGivenSourceScore(const string& source_word, + const string& target_word); + + virtual double GetSourceGivenTargetScore(const string& source_word, + const string& target_word); void WriteBinary(const fs::path& filepath) const; + protected: + TranslationTable(); + private: + void IncreaseLinksCount( + unordered_map<int, int>& source_links_count, + unordered_map<int, int>& target_links_count, + unordered_map<pair<int, int>, int, PairHash>& links_count, + int source_word_id, + int target_word_id) const; + shared_ptr<DataArray> source_data_array; shared_ptr<DataArray> target_data_array; unordered_map<pair<int, int>, pair<double, double>, PairHash> diff --git a/extractor/translation_table_test.cc b/extractor/translation_table_test.cc new file mode 100644 index 00000000..c99f3f93 --- /dev/null +++ b/extractor/translation_table_test.cc @@ -0,0 +1,82 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> +#include <vector> + +#include "mocks/mock_alignment.h" +#include "mocks/mock_data_array.h" +#include "translation_table.h" + +using namespace std; +using namespace ::testing; + +namespace { + +TEST(TranslationTableTest, TestScores) { + vector<string> words = {"a", "b", "c"}; + + vector<int> source_data = {2, 3, 2, 3, 4, 0, 2, 3, 6, 0, 2, 3, 6, 0}; + vector<int> source_sentence_start = {0, 6, 10, 14}; + shared_ptr<MockDataArray> source_data_array = make_shared<MockDataArray>(); + EXPECT_CALL(*source_data_array, GetData()) + .WillRepeatedly(ReturnRef(source_data)); + EXPECT_CALL(*source_data_array, GetNumSentences()) + .WillRepeatedly(Return(3)); + for (size_t i = 0; i < source_sentence_start.size(); ++i) { + EXPECT_CALL(*source_data_array, GetSentenceStart(i)) + .WillRepeatedly(Return(source_sentence_start[i])); + } + for (size_t i = 0; i < words.size(); ++i) { + EXPECT_CALL(*source_data_array, HasWord(words[i])) + .WillRepeatedly(Return(true)); + EXPECT_CALL(*source_data_array, GetWordId(words[i])) + .WillRepeatedly(Return(i + 2)); + } + EXPECT_CALL(*source_data_array, HasWord("d")) + .WillRepeatedly(Return(false)); + + vector<int> target_data = {2, 3, 2, 3, 4, 5, 0, 3, 6, 0, 2, 7, 0}; + vector<int> target_sentence_start = {0, 7, 10, 13}; + shared_ptr<MockDataArray> target_data_array = make_shared<MockDataArray>(); + EXPECT_CALL(*target_data_array, GetData()) + .WillRepeatedly(ReturnRef(target_data)); + for (size_t i = 0; i < target_sentence_start.size(); ++i) { + EXPECT_CALL(*target_data_array, GetSentenceStart(i)) + .WillRepeatedly(Return(target_sentence_start[i])); + } + for (size_t i = 0; i < words.size(); ++i) { + EXPECT_CALL(*target_data_array, HasWord(words[i])) + .WillRepeatedly(Return(true)); + EXPECT_CALL(*target_data_array, GetWordId(words[i])) + .WillRepeatedly(Return(i + 2)); + } + EXPECT_CALL(*target_data_array, HasWord("d")) + .WillRepeatedly(Return(false)); + + vector<pair<int, int> > links1 = { + make_pair(0, 0), make_pair(1, 1), make_pair(2, 2), make_pair(3, 3), + make_pair(4, 4), make_pair(4, 5) + }; + vector<pair<int, int> > links2 = {make_pair(1, 0), make_pair(2, 1)}; + vector<pair<int, int> > links3 = {make_pair(0, 0), make_pair(2, 1)}; + shared_ptr<MockAlignment> alignment = make_shared<MockAlignment>(); + EXPECT_CALL(*alignment, GetLinks(0)).WillRepeatedly(Return(links1)); + EXPECT_CALL(*alignment, GetLinks(1)).WillRepeatedly(Return(links2)); + EXPECT_CALL(*alignment, GetLinks(2)).WillRepeatedly(Return(links3)); + + shared_ptr<TranslationTable> table = make_shared<TranslationTable>( + source_data_array, target_data_array, alignment); + + EXPECT_EQ(0.75, table->GetTargetGivenSourceScore("a", "a")); + EXPECT_EQ(0, table->GetTargetGivenSourceScore("a", "b")); + EXPECT_EQ(0.5, table->GetTargetGivenSourceScore("c", "c")); + EXPECT_EQ(-1, table->GetTargetGivenSourceScore("c", "d")); + + EXPECT_EQ(1, table->GetSourceGivenTargetScore("a", "a")); + EXPECT_EQ(0, table->GetSourceGivenTargetScore("a", "b")); + EXPECT_EQ(1, table->GetSourceGivenTargetScore("c", "c")); + EXPECT_EQ(-1, table->GetSourceGivenTargetScore("c", "d")); +} + +} // namespace diff --git a/extractor/vocabulary.h b/extractor/vocabulary.h index ed55e5e4..c6a8b3e8 100644 --- a/extractor/vocabulary.h +++ b/extractor/vocabulary.h @@ -2,11 +2,10 @@ #define _VOCABULARY_H_ #include <string> -#include <tr1/unordered_map> +#include <unordered_map> #include <vector> using namespace std; -using namespace tr1; class Vocabulary { public: |