From 4ab84a0be28fdb6c0c421fe5ba5e09cfa298f2d1 Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Mon, 28 Jan 2013 11:56:31 +0000 Subject: Initial working commit. --- extractor/compile.cc | 98 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 extractor/compile.cc (limited to 'extractor/compile.cc') diff --git a/extractor/compile.cc b/extractor/compile.cc new file mode 100644 index 00000000..c3ea3c8d --- /dev/null +++ b/extractor/compile.cc @@ -0,0 +1,98 @@ +#include +#include + +#include +#include +#include + +#include "alignment.h" +#include "data_array.h" +#include "precomputation.h" +#include "suffix_array.h" +#include "translation_table.h" + +namespace fs = boost::filesystem; +namespace po = boost::program_options; +using namespace std; + +int main(int argc, char** argv) { + po::options_description desc("Command line options"); + desc.add_options() + ("help,h", "Show available options") + ("source,f", po::value(), "Source language corpus") + ("target,e", po::value(), "Target language corpus") + ("bitext,b", po::value(), "Parallel text (source ||| target)") + ("alignment,a", po::value()->required(), "Bitext word alignment") + ("output,o", po::value()->required(), "Output path") + ("frequent", po::value()->default_value(100), + "Number of precomputed frequent patterns") + ("super_frequent", po::value()->default_value(10), + "Number of precomputed super frequent patterns") + ("max_rule_span,s", po::value()->default_value(15), + "Maximum rule span") + ("max_rule_symbols,l", po::value()->default_value(5), + "Maximum number of symbols (terminals + nontermals) in a rule") + ("min_gap_size,g", po::value()->default_value(1), "Minimum gap size") + ("max_phrase_len,p", po::value()->default_value(4), + "Maximum frequent phrase length") + ("min_frequency", po::value()->default_value(1000), + "Minimum number of occurences for a pharse to be considered frequent"); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + + // Check for help argument before notify, so we don't need to pass in the + // required parameters. + if (vm.count("help")) { + cout << desc << endl; + return 0; + } + + po::notify(vm); + + if (!((vm.count("source") && vm.count("target")) || vm.count("bitext"))) { + cerr << "A paralel corpus is required. " + << "Use -f (source) with -e (target) or -b (bitext)." + << endl; + return 1; + } + + fs::path output_dir(vm["output"].as().c_str()); + if (!fs::exists(output_dir)) { + fs::create_directory(output_dir); + } + + shared_ptr source_data_array, target_data_array; + if (vm.count("bitext")) { + source_data_array = make_shared( + vm["bitext"].as(), SOURCE); + target_data_array = make_shared( + vm["bitext"].as(), TARGET); + } else { + source_data_array = make_shared(vm["source"].as()); + target_data_array = make_shared(vm["target"].as()); + } + shared_ptr source_suffix_array = + make_shared(source_data_array); + source_suffix_array->WriteBinary(output_dir / fs::path("f.bin")); + target_data_array->WriteBinary(output_dir / fs::path("e.bin")); + + Alignment alignment(vm["alignment"].as()); + alignment.WriteBinary(output_dir / fs::path("a.bin")); + + Precomputation precomputation( + source_suffix_array, + vm["frequent"].as(), + vm["super_frequent"].as(), + vm["max_rule_span"].as(), + vm["max_rule_symbols"].as(), + vm["min_gap_size"].as(), + vm["max_phrase_len"].as(), + vm["min_frequency"].as()); + precomputation.WriteBinary(output_dir / fs::path("precompute.bin")); + + TranslationTable table(source_data_array, target_data_array, alignment); + table.WriteBinary(output_dir / fs::path("lex.bin")); + + return 0; +} -- cgit v1.2.3 From 252fb164c208ec8f3005f8a652eb3b48c0644e3d Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Fri, 1 Feb 2013 16:11:10 +0000 Subject: Second working commit. --- extractor/Makefile.am | 23 +- extractor/alignment.cc | 2 +- extractor/alignment.h | 2 +- extractor/binary_search_merger.cc | 4 + extractor/binary_search_merger.h | 7 +- extractor/binary_search_merger_test.cc | 4 +- extractor/compile.cc | 5 +- extractor/data_array.h | 7 +- extractor/features/count_source_target.cc | 11 + extractor/features/count_source_target.h | 13 + extractor/features/feature.cc | 3 + extractor/features/feature.h | 32 + extractor/features/is_source_singleton.cc | 11 + extractor/features/is_source_singleton.h | 13 + extractor/features/is_source_target_singleton.cc | 11 + extractor/features/is_source_target_singleton.h | 13 + extractor/features/max_lex_source_given_target.cc | 30 + extractor/features/max_lex_source_given_target.h | 24 + extractor/features/max_lex_target_given_source.cc | 30 + extractor/features/max_lex_target_given_source.h | 24 + extractor/features/sample_source_count.cc | 11 + extractor/features/sample_source_count.h | 13 + extractor/features/target_given_source_coherent.cc | 12 + extractor/features/target_given_source_coherent.h | 13 + extractor/grammar.cc | 24 + extractor/grammar.h | 23 + extractor/grammar_extractor.cc | 20 +- extractor/grammar_extractor.h | 15 +- extractor/intersector.cc | 50 +- extractor/intersector.h | 22 +- extractor/intersector_test.cc | 193 ++++++ extractor/linear_merger.cc | 2 + extractor/linear_merger.h | 3 + extractor/mocks/mock_binary_search_merger.h | 15 + extractor/mocks/mock_data_array.h | 1 + extractor/mocks/mock_linear_merger.h | 10 +- extractor/mocks/mock_precomputation.h | 9 + extractor/mocks/mock_suffix_array.h | 6 +- extractor/mocks/mock_vocabulary.h | 1 + extractor/phrase.cc | 29 + extractor/phrase.h | 12 + extractor/phrase_builder.cc | 24 + extractor/phrase_builder.h | 2 + extractor/phrase_location.cc | 10 +- extractor/phrase_location.h | 2 +- extractor/precomputation.cc | 18 +- extractor/precomputation.h | 15 +- extractor/precomputation_test.cc | 138 +++++ extractor/rule.cc | 10 + extractor/rule.h | 20 + extractor/rule_extractor.cc | 675 ++++++++++++++++++++- extractor/rule_extractor.h | 120 +++- extractor/rule_factory.cc | 56 +- extractor/rule_factory.h | 31 +- extractor/run_extractor.cc | 70 ++- extractor/sampler.cc | 36 ++ extractor/sampler.h | 24 + extractor/sampler_test.cc | 72 +++ extractor/scorer.cc | 21 +- extractor/scorer.h | 16 +- extractor/suffix_array.cc | 2 + extractor/suffix_array.h | 9 +- extractor/translation_table.cc | 10 +- extractor/translation_table.h | 18 +- extractor/vocabulary.h | 2 +- 65 files changed, 2005 insertions(+), 149 deletions(-) create mode 100644 extractor/features/count_source_target.cc create mode 100644 extractor/features/count_source_target.h create mode 100644 extractor/features/feature.cc create mode 100644 extractor/features/feature.h create mode 100644 extractor/features/is_source_singleton.cc create mode 100644 extractor/features/is_source_singleton.h create mode 100644 extractor/features/is_source_target_singleton.cc create mode 100644 extractor/features/is_source_target_singleton.h create mode 100644 extractor/features/max_lex_source_given_target.cc create mode 100644 extractor/features/max_lex_source_given_target.h create mode 100644 extractor/features/max_lex_target_given_source.cc create mode 100644 extractor/features/max_lex_target_given_source.h create mode 100644 extractor/features/sample_source_count.cc create mode 100644 extractor/features/sample_source_count.h create mode 100644 extractor/features/target_given_source_coherent.cc create mode 100644 extractor/features/target_given_source_coherent.h create mode 100644 extractor/grammar.cc create mode 100644 extractor/grammar.h create mode 100644 extractor/intersector_test.cc create mode 100644 extractor/mocks/mock_binary_search_merger.h create mode 100644 extractor/mocks/mock_precomputation.h create mode 100644 extractor/precomputation_test.cc create mode 100644 extractor/rule.cc create mode 100644 extractor/rule.h create mode 100644 extractor/sampler.cc create mode 100644 extractor/sampler.h create mode 100644 extractor/sampler_test.cc (limited to 'extractor/compile.cc') diff --git a/extractor/Makefile.am b/extractor/Makefile.am index 844c0ef3..ded06239 100644 --- a/extractor/Makefile.am +++ b/extractor/Makefile.am @@ -3,22 +3,27 @@ bin_PROGRAMS = compile run_extractor noinst_PROGRAMS = \ binary_search_merger_test \ data_array_test \ + intersector_test \ linear_merger_test \ matching_comparator_test \ matching_test \ matchings_finder_test \ phrase_test \ precomputation_test \ + sampler_test \ suffix_array_test \ veb_test -TESTS = precomputation_test +TESTS = sampler_test #TESTS = binary_search_merger_test \ # data_array_test \ +# intersector_test \ # linear_merger_test \ # matching_comparator_test \ # matching_test \ +# matchings_finder_test \ # phrase_test \ +# precomputation_test \ # suffix_array_test \ # veb_test @@ -26,6 +31,8 @@ binary_search_merger_test_SOURCES = binary_search_merger_test.cc binary_search_merger_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a data_array_test_SOURCES = data_array_test.cc data_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +intersector_test_SOURCES = intersector_test.cc +intersector_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a linear_merger_test_SOURCES = linear_merger_test.cc linear_merger_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a matching_comparator_test_SOURCES = matching_comparator_test.cc @@ -40,6 +47,8 @@ precomputation_test_SOURCES = precomputation_test.cc precomputation_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a suffix_array_test_SOURCES = suffix_array_test.cc suffix_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +sampler_test_SOURCES = sampler_test.cc +sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a veb_test_SOURCES = veb_test.cc veb_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a @@ -62,6 +71,15 @@ libextractor_a_SOURCES = \ alignment.cc \ binary_search_merger.cc \ data_array.cc \ + features/count_source_target.cc \ + features/feature.cc \ + features/is_source_singleton.cc \ + features/is_source_target_singleton.cc \ + features/max_lex_source_given_target.cc \ + features/max_lex_target_given_source.cc \ + features/sample_source_count.cc \ + features/target_given_source_coherent.cc \ + grammar.cc \ grammar_extractor.cc \ matching.cc \ matching_comparator.cc \ @@ -73,8 +91,11 @@ libextractor_a_SOURCES = \ phrase_builder.cc \ phrase_location.cc \ precomputation.cc \ + rule.cc \ rule_extractor.cc \ rule_factory.cc \ + sampler.cc \ + scorer.cc \ suffix_array.cc \ translation_table.cc \ veb.cc \ diff --git a/extractor/alignment.cc b/extractor/alignment.cc index cad28a72..2fa0abac 100644 --- a/extractor/alignment.cc +++ b/extractor/alignment.cc @@ -31,7 +31,7 @@ Alignment::Alignment(const string& filename) { alignments.shrink_to_fit(); } -vector > Alignment::GetLinks(int sentence_index) const { +const vector >& Alignment::GetLinks(int sentence_index) const { return alignments[sentence_index]; } diff --git a/extractor/alignment.h b/extractor/alignment.h index e357e468..290d6015 100644 --- a/extractor/alignment.h +++ b/extractor/alignment.h @@ -13,7 +13,7 @@ class Alignment { public: Alignment(const string& filename); - vector > GetLinks(int sentence_index) const; + const vector >& GetLinks(int sentence_index) const; void WriteBinary(const fs::path& filepath); diff --git a/extractor/binary_search_merger.cc b/extractor/binary_search_merger.cc index 7b018876..43d2f734 100644 --- a/extractor/binary_search_merger.cc +++ b/extractor/binary_search_merger.cc @@ -19,6 +19,10 @@ BinarySearchMerger::BinarySearchMerger( data_array(data_array), comparator(comparator), force_binary_search_merge(force_binary_search_merge) {} +BinarySearchMerger::BinarySearchMerger() {} + +BinarySearchMerger::~BinarySearchMerger() {} + void BinarySearchMerger::Merge( vector& locations, const Phrase& phrase, const Phrase& suffix, vector::iterator prefix_start, vector::iterator prefix_end, diff --git a/extractor/binary_search_merger.h b/extractor/binary_search_merger.h index 0e229b3b..ffa47c8e 100644 --- a/extractor/binary_search_merger.h +++ b/extractor/binary_search_merger.h @@ -20,7 +20,9 @@ class BinarySearchMerger { shared_ptr comparator, bool force_binary_search_merge = false); - void Merge( + virtual ~BinarySearchMerger(); + + virtual void Merge( vector& locations, const Phrase& phrase, const Phrase& suffix, vector::iterator prefix_start, vector::iterator prefix_end, vector::iterator suffix_start, vector::iterator suffix_end, @@ -28,6 +30,9 @@ class BinarySearchMerger { static double BAEZA_YATES_FACTOR; + protected: + BinarySearchMerger(); + private: bool IsIntersectionVoid( vector::iterator prefix_start, vector::iterator prefix_end, diff --git a/extractor/binary_search_merger_test.cc b/extractor/binary_search_merger_test.cc index 20350b1e..b1baa62f 100644 --- a/extractor/binary_search_merger_test.cc +++ b/extractor/binary_search_merger_test.cc @@ -34,8 +34,8 @@ class BinarySearchMergerTest : public Test { // We are going to force the binary_search_merger to do all the work, so we // need to check that the linear_merger never gets called. - shared_ptr linear_merger = make_shared( - vocabulary, data_array, comparator); + shared_ptr linear_merger = + make_shared(); EXPECT_CALL(*linear_merger, Merge(_, _, _, _, _, _, _, _, _)).Times(0); binary_search_merger = make_shared( diff --git a/extractor/compile.cc b/extractor/compile.cc index c3ea3c8d..f5cd41f4 100644 --- a/extractor/compile.cc +++ b/extractor/compile.cc @@ -77,8 +77,9 @@ int main(int argc, char** argv) { source_suffix_array->WriteBinary(output_dir / fs::path("f.bin")); target_data_array->WriteBinary(output_dir / fs::path("e.bin")); - Alignment alignment(vm["alignment"].as()); - alignment.WriteBinary(output_dir / fs::path("a.bin")); + shared_ptr alignment = + make_shared(vm["alignment"].as()); + alignment->WriteBinary(output_dir / fs::path("a.bin")); Precomputation precomputation( source_suffix_array, diff --git a/extractor/data_array.h b/extractor/data_array.h index 6d3e99d5..19fbff88 100644 --- a/extractor/data_array.h +++ b/extractor/data_array.h @@ -23,8 +23,6 @@ class DataArray { static string END_OF_FILE_STR; static string END_OF_LINE_STR; - DataArray(); - DataArray(const string& filename); DataArray(const string& filename, const Side& side); @@ -43,7 +41,7 @@ class DataArray { virtual int GetWordId(const string& word) const; - string GetWord(int word_id) const; + virtual string GetWord(int word_id) const; int GetNumSentences() const; @@ -55,6 +53,9 @@ class DataArray { void WriteBinary(FILE* file) const; + protected: + DataArray(); + private: void InitializeDataArray(); void CreateDataArray(const vector& lines); diff --git a/extractor/features/count_source_target.cc b/extractor/features/count_source_target.cc new file mode 100644 index 00000000..9441b451 --- /dev/null +++ b/extractor/features/count_source_target.cc @@ -0,0 +1,11 @@ +#include "count_source_target.h" + +#include + +double CountSourceTarget::Score(const FeatureContext& context) const { + return log10(1 + context.pair_count); +} + +string CountSourceTarget::GetName() const { + return "CountEF"; +} diff --git a/extractor/features/count_source_target.h b/extractor/features/count_source_target.h new file mode 100644 index 00000000..a2481944 --- /dev/null +++ b/extractor/features/count_source_target.h @@ -0,0 +1,13 @@ +#ifndef _COUNT_SOURCE_TARGET_H_ +#define _COUNT_SOURCE_TARGET_H_ + +#include "feature.h" + +class CountSourceTarget : public Feature { + public: + double Score(const FeatureContext& context) const; + + string GetName() const; +}; + +#endif diff --git a/extractor/features/feature.cc b/extractor/features/feature.cc new file mode 100644 index 00000000..7381c35a --- /dev/null +++ b/extractor/features/feature.cc @@ -0,0 +1,3 @@ +#include "feature.h" + +const double Feature::MAX_SCORE = 99.0; diff --git a/extractor/features/feature.h b/extractor/features/feature.h new file mode 100644 index 00000000..ad22d3e7 --- /dev/null +++ b/extractor/features/feature.h @@ -0,0 +1,32 @@ +#ifndef _FEATURE_H_ +#define _FEATURE_H_ + +#include + +//TODO(pauldb): include headers nicely. +#include "../phrase.h" + +using namespace std; + +struct FeatureContext { + FeatureContext(const Phrase& source_phrase, const Phrase& target_phrase, + double sample_source_count, int pair_count) : + source_phrase(source_phrase), target_phrase(target_phrase), + sample_source_count(sample_source_count), pair_count(pair_count) {} + + Phrase source_phrase; + Phrase target_phrase; + double sample_source_count; + int pair_count; +}; + +class Feature { + public: + virtual double Score(const FeatureContext& context) const = 0; + + virtual string GetName() const = 0; + + static const double MAX_SCORE; +}; + +#endif diff --git a/extractor/features/is_source_singleton.cc b/extractor/features/is_source_singleton.cc new file mode 100644 index 00000000..754df3bf --- /dev/null +++ b/extractor/features/is_source_singleton.cc @@ -0,0 +1,11 @@ +#include "is_source_singleton.h" + +#include + +double IsSourceSingleton::Score(const FeatureContext& context) const { + return context.sample_source_count == 1; +} + +string IsSourceSingleton::GetName() const { + return "IsSingletonF"; +} diff --git a/extractor/features/is_source_singleton.h b/extractor/features/is_source_singleton.h new file mode 100644 index 00000000..7cc72828 --- /dev/null +++ b/extractor/features/is_source_singleton.h @@ -0,0 +1,13 @@ +#ifndef _IS_SOURCE_SINGLETON_H_ +#define _IS_SOURCE_SINGLETON_H_ + +#include "feature.h" + +class IsSourceSingleton : public Feature { + public: + double Score(const FeatureContext& context) const; + + string GetName() const; +}; + +#endif diff --git a/extractor/features/is_source_target_singleton.cc b/extractor/features/is_source_target_singleton.cc new file mode 100644 index 00000000..ec816509 --- /dev/null +++ b/extractor/features/is_source_target_singleton.cc @@ -0,0 +1,11 @@ +#include "is_source_target_singleton.h" + +#include + +double IsSourceTargetSingleton::Score(const FeatureContext& context) const { + return context.pair_count == 1; +} + +string IsSourceTargetSingleton::GetName() const { + return "IsSingletonEF"; +} diff --git a/extractor/features/is_source_target_singleton.h b/extractor/features/is_source_target_singleton.h new file mode 100644 index 00000000..58913b74 --- /dev/null +++ b/extractor/features/is_source_target_singleton.h @@ -0,0 +1,13 @@ +#ifndef _IS_SOURCE_TARGET_SINGLETON_H_ +#define _IS_SOURCE_TARGET_SINGLETON_H_ + +#include "feature.h" + +class IsSourceTargetSingleton : public Feature { + public: + double Score(const FeatureContext& context) const; + + string GetName() const; +}; + +#endif diff --git a/extractor/features/max_lex_source_given_target.cc b/extractor/features/max_lex_source_given_target.cc new file mode 100644 index 00000000..c4792d49 --- /dev/null +++ b/extractor/features/max_lex_source_given_target.cc @@ -0,0 +1,30 @@ +#include "max_lex_source_given_target.h" + +#include + +#include "../translation_table.h" + +MaxLexSourceGivenTarget::MaxLexSourceGivenTarget( + shared_ptr table) : + table(table) {} + +double MaxLexSourceGivenTarget::Score(const FeatureContext& context) const { + vector source_words = context.source_phrase.GetWords(); + // TODO(pauldb): Add NULL to target_words, after fixing translation table. + vector target_words = context.target_phrase.GetWords(); + + double score = 0; + for (string source_word: source_words) { + double max_score = 0; + for (string target_word: target_words) { + max_score = max(max_score, + table->GetSourceGivenTargetScore(source_word, target_word)); + } + score += max_score > 0 ? -log10(max_score) : MAX_SCORE; + } + return score; +} + +string MaxLexSourceGivenTarget::GetName() const { + return "MaxLexFGivenE"; +} diff --git a/extractor/features/max_lex_source_given_target.h b/extractor/features/max_lex_source_given_target.h new file mode 100644 index 00000000..e87c1c8e --- /dev/null +++ b/extractor/features/max_lex_source_given_target.h @@ -0,0 +1,24 @@ +#ifndef _MAX_LEX_SOURCE_GIVEN_TARGET_H_ +#define _MAX_LEX_SOURCE_GIVEN_TARGET_H_ + +#include + +#include "feature.h" + +using namespace std; + +class TranslationTable; + +class MaxLexSourceGivenTarget : public Feature { + public: + MaxLexSourceGivenTarget(shared_ptr table); + + double Score(const FeatureContext& context) const; + + string GetName() const; + + private: + shared_ptr table; +}; + +#endif diff --git a/extractor/features/max_lex_target_given_source.cc b/extractor/features/max_lex_target_given_source.cc new file mode 100644 index 00000000..d82182fe --- /dev/null +++ b/extractor/features/max_lex_target_given_source.cc @@ -0,0 +1,30 @@ +#include "max_lex_target_given_source.h" + +#include + +#include "../translation_table.h" + +MaxLexTargetGivenSource::MaxLexTargetGivenSource( + shared_ptr table) : + table(table) {} + +double MaxLexTargetGivenSource::Score(const FeatureContext& context) const { + // TODO(pauldb): Add NULL to source_words, after fixing translation table. + vector source_words = context.source_phrase.GetWords(); + vector target_words = context.target_phrase.GetWords(); + + double score = 0; + for (string target_word: target_words) { + double max_score = 0; + for (string source_word: source_words) { + max_score = max(max_score, + table->GetTargetGivenSourceScore(source_word, target_word)); + } + score += max_score > 0 ? -log10(max_score) : MAX_SCORE; + } + return score; +} + +string MaxLexTargetGivenSource::GetName() const { + return "MaxLexEGivenF"; +} diff --git a/extractor/features/max_lex_target_given_source.h b/extractor/features/max_lex_target_given_source.h new file mode 100644 index 00000000..9585ff04 --- /dev/null +++ b/extractor/features/max_lex_target_given_source.h @@ -0,0 +1,24 @@ +#ifndef _MAX_LEX_TARGET_GIVEN_SOURCE_H_ +#define _MAX_LEX_TARGET_GIVEN_SOURCE_H_ + +#include + +#include "feature.h" + +using namespace std; + +class TranslationTable; + +class MaxLexTargetGivenSource : public Feature { + public: + MaxLexTargetGivenSource(shared_ptr table); + + double Score(const FeatureContext& context) const; + + string GetName() const; + + private: + shared_ptr table; +}; + +#endif diff --git a/extractor/features/sample_source_count.cc b/extractor/features/sample_source_count.cc new file mode 100644 index 00000000..c8124cfb --- /dev/null +++ b/extractor/features/sample_source_count.cc @@ -0,0 +1,11 @@ +#include "sample_source_count.h" + +#include + +double SampleSourceCount::Score(const FeatureContext& context) const { + return log10(1 + context.sample_source_count); +} + +string SampleSourceCount::GetName() const { + return "SampleCountF"; +} diff --git a/extractor/features/sample_source_count.h b/extractor/features/sample_source_count.h new file mode 100644 index 00000000..62d236c8 --- /dev/null +++ b/extractor/features/sample_source_count.h @@ -0,0 +1,13 @@ +#ifndef _SAMPLE_SOURCE_COUNT_H_ +#define _SAMPLE_SOURCE_COUNT_H_ + +#include "feature.h" + +class SampleSourceCount : public Feature { + public: + double Score(const FeatureContext& context) const; + + string GetName() const; +}; + +#endif diff --git a/extractor/features/target_given_source_coherent.cc b/extractor/features/target_given_source_coherent.cc new file mode 100644 index 00000000..748413c3 --- /dev/null +++ b/extractor/features/target_given_source_coherent.cc @@ -0,0 +1,12 @@ +#include "target_given_source_coherent.h" + +#include + +double TargetGivenSourceCoherent::Score(const FeatureContext& context) const { + double prob = context.pair_count / context.sample_source_count; + return prob > 0 ? -log10(prob) : MAX_SCORE; +} + +string TargetGivenSourceCoherent::GetName() const { + return "EGivenFCoherent"; +} diff --git a/extractor/features/target_given_source_coherent.h b/extractor/features/target_given_source_coherent.h new file mode 100644 index 00000000..09c8edb1 --- /dev/null +++ b/extractor/features/target_given_source_coherent.h @@ -0,0 +1,13 @@ +#ifndef _TARGET_GIVEN_SOURCE_COHERENT_H_ +#define _TARGET_GIVEN_SOURCE_COHERENT_H_ + +#include "feature.h" + +class TargetGivenSourceCoherent : public Feature { + public: + double Score(const FeatureContext& context) const; + + string GetName() const; +}; + +#endif diff --git a/extractor/grammar.cc b/extractor/grammar.cc new file mode 100644 index 00000000..79a0541d --- /dev/null +++ b/extractor/grammar.cc @@ -0,0 +1,24 @@ +#include "grammar.h" + +#include "rule.h" + +Grammar::Grammar(const vector& rules, + const vector& feature_names) : + rules(rules), feature_names(feature_names) {} + +ostream& operator<<(ostream& os, const Grammar& grammar) { + for (Rule rule: grammar.rules) { + os << "[X] ||| " << rule.source_phrase << " ||| " + << rule.target_phrase << " |||"; + for (size_t i = 0; i < rule.scores.size(); ++i) { + os << " " << grammar.feature_names[i] << "=" << rule.scores[i]; + } + os << " |||"; + for (auto link: rule.alignment) { + os << " " << link.first << "-" << link.second; + } + os << endl; + } + + return os; +} diff --git a/extractor/grammar.h b/extractor/grammar.h new file mode 100644 index 00000000..db15fa7e --- /dev/null +++ b/extractor/grammar.h @@ -0,0 +1,23 @@ +#ifndef _GRAMMAR_H_ +#define _GRAMMAR_H_ + +#include +#include +#include + +using namespace std; + +class Rule; + +class Grammar { + public: + Grammar(const vector& rules, const vector& feature_names); + + friend ostream& operator<<(ostream& os, const Grammar& grammar); + + private: + vector rules; + vector feature_names; +}; + +#endif diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc index 3014c2e9..15268165 100644 --- a/extractor/grammar_extractor.cc +++ b/extractor/grammar_extractor.cc @@ -4,6 +4,10 @@ #include #include +#include "grammar.h" +#include "rule.h" +#include "vocabulary.h" + using namespace std; vector Tokenize(const string& sentence) { @@ -22,18 +26,20 @@ vector Tokenize(const string& sentence) { GrammarExtractor::GrammarExtractor( shared_ptr source_suffix_array, shared_ptr target_data_array, - const Alignment& alignment, const Precomputation& precomputation, - int min_gap_size, int max_rule_span, int max_nonterminals, - int max_rule_symbols, bool use_baeza_yates) : + shared_ptr alignment, shared_ptr precomputation, + shared_ptr scorer, int min_gap_size, int max_rule_span, + int max_nonterminals, int max_rule_symbols, int max_samples, + bool use_baeza_yates, bool require_tight_phrases) : vocabulary(make_shared()), rule_factory(source_suffix_array, target_data_array, alignment, - vocabulary, precomputation, min_gap_size, max_rule_span, - max_nonterminals, max_rule_symbols, use_baeza_yates) {} + vocabulary, precomputation, scorer, min_gap_size, max_rule_span, + max_nonterminals, max_rule_symbols, max_samples, use_baeza_yates, + require_tight_phrases) {} -void GrammarExtractor::GetGrammar(const string& sentence) { +Grammar GrammarExtractor::GetGrammar(const string& sentence) { vector words = Tokenize(sentence); vector word_ids = AnnotateWords(words); - rule_factory.GetGrammar(word_ids); + return rule_factory.GetGrammar(word_ids); } vector GrammarExtractor::AnnotateWords(const vector& words) { diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h index 05e153fc..243f33cf 100644 --- a/extractor/grammar_extractor.h +++ b/extractor/grammar_extractor.h @@ -5,29 +5,34 @@ #include #include "rule_factory.h" -#include "vocabulary.h" using namespace std; class Alignment; class DataArray; +class Grammar; class Precomputation; +class Rule; class SuffixArray; +class Vocabulary; class GrammarExtractor { public: GrammarExtractor( shared_ptr source_suffix_array, shared_ptr target_data_array, - const Alignment& alignment, - const Precomputation& precomputation, + shared_ptr alignment, + shared_ptr precomputation, + shared_ptr scorer, int min_gap_size, int max_rule_span, int max_nonterminals, int max_rule_symbols, - bool use_baeza_yates); + int max_samples, + bool use_baeza_yates, + bool require_tight_phrases); - void GetGrammar(const string& sentence); + Grammar GetGrammar(const string& sentence); private: vector AnnotateWords(const vector& words); diff --git a/extractor/intersector.cc b/extractor/intersector.cc index 9d9b54c0..b53479af 100644 --- a/extractor/intersector.cc +++ b/extractor/intersector.cc @@ -10,35 +10,51 @@ #include "vocabulary.h" Intersector::Intersector(shared_ptr vocabulary, - const Precomputation& precomputation, + shared_ptr precomputation, shared_ptr suffix_array, shared_ptr comparator, bool use_baeza_yates) : vocabulary(vocabulary), suffix_array(suffix_array), use_baeza_yates(use_baeza_yates) { - linear_merger = make_shared( - vocabulary, suffix_array->GetData(), comparator); + shared_ptr data_array = suffix_array->GetData(); + linear_merger = make_shared(vocabulary, data_array, comparator); binary_search_merger = make_shared( - vocabulary, linear_merger, suffix_array->GetData(), comparator); + vocabulary, linear_merger, data_array, comparator); + ConvertIndexes(precomputation, data_array); +} - shared_ptr source_data_array = suffix_array->GetData(); +Intersector::Intersector(shared_ptr vocabulary, + shared_ptr precomputation, + shared_ptr suffix_array, + shared_ptr linear_merger, + shared_ptr binary_search_merger, + bool use_baeza_yates) : + vocabulary(vocabulary), + suffix_array(suffix_array), + linear_merger(linear_merger), + binary_search_merger(binary_search_merger), + use_baeza_yates(use_baeza_yates) { + ConvertIndexes(precomputation, suffix_array->GetData()); +} - const Index& precomputed_index = precomputation.GetInvertedIndex(); +void Intersector::ConvertIndexes(shared_ptr precomputation, + shared_ptr data_array) { + const Index& precomputed_index = precomputation->GetInvertedIndex(); for (pair, vector > entry: precomputed_index) { - vector phrase = Convert(entry.first, source_data_array); + vector phrase = ConvertPhrase(entry.first, data_array); inverted_index[phrase] = entry.second; } - const Index& precomputed_collocations = precomputation.GetCollocations(); + const Index& precomputed_collocations = precomputation->GetCollocations(); for (pair, vector > entry: precomputed_collocations) { - vector phrase = Convert(entry.first, source_data_array); + vector phrase = ConvertPhrase(entry.first, data_array); collocations[phrase] = entry.second; } } -vector Intersector::Convert( - const vector& old_phrase, shared_ptr source_data_array) { +vector Intersector::ConvertPhrase(const vector& old_phrase, + shared_ptr data_array) { vector new_phrase; new_phrase.reserve(old_phrase.size()); @@ -49,7 +65,7 @@ vector Intersector::Convert( new_phrase.push_back(vocabulary->GetNonterminalIndex(arity)); } else { new_phrase.push_back( - vocabulary->GetTerminalIndex(source_data_array->GetWord(word_id))); + vocabulary->GetTerminalIndex(data_array->GetWord(word_id))); } } @@ -70,8 +86,7 @@ PhraseLocation Intersector::Intersect( && vocabulary->IsTerminal(symbols.back())); if (collocations.count(symbols)) { - return PhraseLocation(make_shared >(collocations[symbols]), - phrase.Arity()); + return PhraseLocation(collocations[symbols], phrase.Arity() + 1); } vector locations; @@ -91,19 +106,18 @@ PhraseLocation Intersector::Intersect( prefix_matchings->end(), suffix_matchings->begin(), suffix_matchings->end(), prefix_subpatterns, suffix_subpatterns); } - return PhraseLocation(shared_ptr >(new vector(locations)), - phrase.Arity() + 1); + return PhraseLocation(locations, phrase.Arity() + 1); } void Intersector::ExtendPhraseLocation( const Phrase& phrase, PhraseLocation& phrase_location) { int low = phrase_location.sa_low, high = phrase_location.sa_high; - if (phrase.Arity() || phrase_location.num_subpatterns || - phrase_location.IsEmpty()) { + if (phrase_location.matchings != NULL) { return; } phrase_location.num_subpatterns = 1; + phrase_location.sa_low = phrase_location.sa_high = 0; vector symbols = phrase.Get(); if (inverted_index.count(symbols)) { diff --git a/extractor/intersector.h b/extractor/intersector.h index 874ffc1b..f023cc96 100644 --- a/extractor/intersector.h +++ b/extractor/intersector.h @@ -13,8 +13,8 @@ using namespace std; using namespace tr1; -typedef boost::hash > vector_hash; -typedef unordered_map, vector, vector_hash> Index; +typedef boost::hash > VectorHash; +typedef unordered_map, vector, VectorHash> Index; class DataArray; class MatchingComparator; @@ -28,19 +28,31 @@ class Intersector { public: Intersector( shared_ptr vocabulary, - const Precomputation& precomputaiton, + shared_ptr precomputation, shared_ptr source_suffix_array, shared_ptr comparator, bool use_baeza_yates); + // For testing. + Intersector( + shared_ptr vocabulary, + shared_ptr precomputation, + shared_ptr source_suffix_array, + shared_ptr linear_merger, + shared_ptr binary_search_merger, + bool use_baeza_yates); + PhraseLocation Intersect( const Phrase& prefix, PhraseLocation& prefix_location, const Phrase& suffix, PhraseLocation& suffix_location, const Phrase& phrase); private: - vector Convert(const vector& old_phrase, - shared_ptr source_data_array); + void ConvertIndexes(shared_ptr precomputation, + shared_ptr data_array); + + vector ConvertPhrase(const vector& old_phrase, + shared_ptr data_array); void ExtendPhraseLocation(const Phrase& phrase, PhraseLocation& phrase_location); diff --git a/extractor/intersector_test.cc b/extractor/intersector_test.cc new file mode 100644 index 00000000..a3756902 --- /dev/null +++ b/extractor/intersector_test.cc @@ -0,0 +1,193 @@ +#include + +#include +#include + +#include "intersector.h" +#include "mocks/mock_binary_search_merger.h" +#include "mocks/mock_data_array.h" +#include "mocks/mock_linear_merger.h" +#include "mocks/mock_precomputation.h" +#include "mocks/mock_suffix_array.h" +#include "mocks/mock_vocabulary.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class IntersectorTest : public Test { + protected: + virtual void SetUp() { + data = {2, 3, 4, 3, 4, 3}; + vector words = {"a", "b", "c", "b", "c", "b"}; + data_array = make_shared(); + EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data)); + + vocabulary = make_shared(); + for (size_t i = 0; i < data.size(); ++i) { + EXPECT_CALL(*data_array, GetWord(data[i])) + .WillRepeatedly(Return(words[i])); + EXPECT_CALL(*vocabulary, GetTerminalIndex(words[i])) + .WillRepeatedly(Return(data[i])); + EXPECT_CALL(*vocabulary, GetTerminalValue(data[i])) + .WillRepeatedly(Return(words[i])); + } + + vector suffixes = {0, 1, 3, 5, 2, 4, 6}; + suffix_array = make_shared(); + EXPECT_CALL(*suffix_array, GetData()) + .WillRepeatedly(Return(data_array)); + EXPECT_CALL(*suffix_array, GetSize()) + .WillRepeatedly(Return(suffixes.size())); + for (size_t i = 0; i < suffixes.size(); ++i) { + EXPECT_CALL(*suffix_array, GetSuffix(i)) + .WillRepeatedly(Return(suffixes[i])); + } + + vector key = {2, -1, 4}; + vector values = {0, 2}; + collocations[key] = values; + precomputation = make_shared(); + EXPECT_CALL(*precomputation, GetInvertedIndex()) + .WillRepeatedly(ReturnRef(inverted_index)); + EXPECT_CALL(*precomputation, GetCollocations()) + .WillRepeatedly(ReturnRef(collocations)); + + linear_merger = make_shared(); + binary_search_merger = make_shared(); + + phrase_builder = make_shared(vocabulary); + } + + Index inverted_index; + Index collocations; + vector data; + shared_ptr vocabulary; + shared_ptr data_array; + shared_ptr suffix_array; + shared_ptr precomputation; + shared_ptr linear_merger; + shared_ptr binary_search_merger; + shared_ptr phrase_builder; + shared_ptr intersector; +}; + +TEST_F(IntersectorTest, TestCachedCollocation) { + intersector = make_shared(vocabulary, precomputation, + suffix_array, linear_merger, binary_search_merger, false); + + vector prefix_symbols = {2, -1}; + Phrase prefix = phrase_builder->Build(prefix_symbols); + vector suffix_symbols = {-1, 4}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + vector symbols = {2, -1, 4}; + Phrase phrase = phrase_builder->Build(symbols); + PhraseLocation prefix_locs(0, 1), suffix_locs(2, 3); + + PhraseLocation result = intersector->Intersect( + prefix, prefix_locs, suffix, suffix_locs, phrase); + + vector expected_locs = {0, 2}; + PhraseLocation expected_result(expected_locs, 2); + + EXPECT_EQ(expected_result, result); + EXPECT_EQ(PhraseLocation(0, 1), prefix_locs); + EXPECT_EQ(PhraseLocation(2, 3), suffix_locs); +} + +TEST_F(IntersectorTest, TestLinearMergeaXb) { + vector prefix_symbols = {3, -1}; + Phrase prefix = phrase_builder->Build(prefix_symbols); + vector suffix_symbols = {-1, 4}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + vector symbols = {3, -1, 4}; + Phrase phrase = phrase_builder->Build(symbols); + PhraseLocation prefix_locs(1, 4), suffix_locs(4, 6); + + vector ex_prefix_locs = {1, 3, 5}; + PhraseLocation extended_prefix_locs(ex_prefix_locs, 1); + vector ex_suffix_locs = {2, 4}; + PhraseLocation extended_suffix_locs(ex_suffix_locs, 1); + + vector expected_locs = {1, 4}; + EXPECT_CALL(*linear_merger, Merge(_, _, _, _, _, _, _, _, _)) + .Times(1) + .WillOnce(SetArgReferee<0>(expected_locs)); + EXPECT_CALL(*binary_search_merger, Merge(_, _, _, _, _, _, _, _, _)).Times(0); + + intersector = make_shared(vocabulary, precomputation, + suffix_array, linear_merger, binary_search_merger, false); + + PhraseLocation result = intersector->Intersect( + prefix, prefix_locs, suffix, suffix_locs, phrase); + PhraseLocation expected_result(expected_locs, 2); + + EXPECT_EQ(expected_result, result); + EXPECT_EQ(extended_prefix_locs, prefix_locs); + EXPECT_EQ(extended_suffix_locs, suffix_locs); +} + +TEST_F(IntersectorTest, TestBinarySearchMergeaXb) { + vector prefix_symbols = {3, -1}; + Phrase prefix = phrase_builder->Build(prefix_symbols); + vector suffix_symbols = {-1, 4}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + vector symbols = {3, -1, 4}; + Phrase phrase = phrase_builder->Build(symbols); + PhraseLocation prefix_locs(1, 4), suffix_locs(4, 6); + + vector ex_prefix_locs = {1, 3, 5}; + PhraseLocation extended_prefix_locs(ex_prefix_locs, 1); + vector ex_suffix_locs = {2, 4}; + PhraseLocation extended_suffix_locs(ex_suffix_locs, 1); + + vector expected_locs = {1, 4}; + EXPECT_CALL(*binary_search_merger, Merge(_, _, _, _, _, _, _, _, _)) + .Times(1) + .WillOnce(SetArgReferee<0>(expected_locs)); + EXPECT_CALL(*linear_merger, Merge(_, _, _, _, _, _, _, _, _)).Times(0); + + intersector = make_shared(vocabulary, precomputation, + suffix_array, linear_merger, binary_search_merger, true); + + PhraseLocation result = intersector->Intersect( + prefix, prefix_locs, suffix, suffix_locs, phrase); + PhraseLocation expected_result(expected_locs, 2); + + EXPECT_EQ(expected_result, result); + EXPECT_EQ(extended_prefix_locs, prefix_locs); + EXPECT_EQ(extended_suffix_locs, suffix_locs); +} + +TEST_F(IntersectorTest, TestMergeaXbXc) { + vector prefix_symbols = {2, -1, 4, -1}; + Phrase prefix = phrase_builder->Build(prefix_symbols); + vector suffix_symbols = {-1, 4, -1, 4}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + vector symbols = {2, -1, 4, -1, 4}; + Phrase phrase = phrase_builder->Build(symbols); + + vector ex_prefix_locs = {0, 2, 0, 4}; + PhraseLocation extended_prefix_locs(ex_prefix_locs, 2); + vector ex_suffix_locs = {2, 4}; + PhraseLocation extended_suffix_locs(ex_suffix_locs, 2); + vector expected_locs = {0, 2, 4}; + EXPECT_CALL(*linear_merger, Merge(_, _, _, _, _, _, _, _, _)) + .Times(1) + .WillOnce(SetArgReferee<0>(expected_locs)); + EXPECT_CALL(*binary_search_merger, Merge(_, _, _, _, _, _, _, _, _)).Times(0); + + intersector = make_shared(vocabulary, precomputation, + suffix_array, linear_merger, binary_search_merger, false); + + PhraseLocation result = intersector->Intersect( + prefix, extended_prefix_locs, suffix, extended_suffix_locs, phrase); + PhraseLocation expected_result(expected_locs, 3); + + EXPECT_EQ(expected_result, result); + EXPECT_EQ(ex_prefix_locs, *extended_prefix_locs.matchings); + EXPECT_EQ(ex_suffix_locs, *extended_suffix_locs.matchings); +} + +} // namespace diff --git a/extractor/linear_merger.cc b/extractor/linear_merger.cc index 59e5f34c..666f8d87 100644 --- a/extractor/linear_merger.cc +++ b/extractor/linear_merger.cc @@ -14,6 +14,8 @@ LinearMerger::LinearMerger(shared_ptr vocabulary, shared_ptr comparator) : vocabulary(vocabulary), data_array(data_array), comparator(comparator) {} +LinearMerger::LinearMerger() {} + LinearMerger::~LinearMerger() {} void LinearMerger::Merge( diff --git a/extractor/linear_merger.h b/extractor/linear_merger.h index 7bfb9246..6a69b804 100644 --- a/extractor/linear_merger.h +++ b/extractor/linear_merger.h @@ -26,6 +26,9 @@ class LinearMerger { vector::iterator suffix_start, vector::iterator suffix_end, int prefix_subpatterns, int suffix_subpatterns) const; + protected: + LinearMerger(); + private: shared_ptr vocabulary; shared_ptr data_array; diff --git a/extractor/mocks/mock_binary_search_merger.h b/extractor/mocks/mock_binary_search_merger.h new file mode 100644 index 00000000..e1375ee3 --- /dev/null +++ b/extractor/mocks/mock_binary_search_merger.h @@ -0,0 +1,15 @@ +#include + +#include + +#include "../binary_search_merger.h" +#include "../phrase.h" + +using namespace std; + +class MockBinarySearchMerger: public BinarySearchMerger { + public: + MOCK_CONST_METHOD9(Merge, void(vector&, const Phrase&, const Phrase&, + vector::iterator, vector::iterator, vector::iterator, + vector::iterator, int, int)); +}; diff --git a/extractor/mocks/mock_data_array.h b/extractor/mocks/mock_data_array.h index cda8f7a6..54497cf5 100644 --- a/extractor/mocks/mock_data_array.h +++ b/extractor/mocks/mock_data_array.h @@ -10,5 +10,6 @@ class MockDataArray : public DataArray { MOCK_CONST_METHOD0(GetVocabularySize, int()); MOCK_CONST_METHOD1(HasWord, bool(const string& word)); MOCK_CONST_METHOD1(GetWordId, int(const string& word)); + MOCK_CONST_METHOD1(GetWord, string(int word_id)); MOCK_CONST_METHOD1(GetSentenceId, int(int position)); }; diff --git a/extractor/mocks/mock_linear_merger.h b/extractor/mocks/mock_linear_merger.h index 0defa88a..82243428 100644 --- a/extractor/mocks/mock_linear_merger.h +++ b/extractor/mocks/mock_linear_merger.h @@ -2,19 +2,13 @@ #include -#include "linear_merger.h" -#include "phrase.h" +#include "../linear_merger.h" +#include "../phrase.h" using namespace std; class MockLinearMerger: public LinearMerger { public: - MockLinearMerger(shared_ptr vocabulary, - shared_ptr data_array, - shared_ptr comparator) : - LinearMerger(vocabulary, data_array, comparator) {} - - MOCK_CONST_METHOD9(Merge, void(vector&, const Phrase&, const Phrase&, vector::iterator, vector::iterator, vector::iterator, vector::iterator, int, int)); diff --git a/extractor/mocks/mock_precomputation.h b/extractor/mocks/mock_precomputation.h new file mode 100644 index 00000000..987bdb2f --- /dev/null +++ b/extractor/mocks/mock_precomputation.h @@ -0,0 +1,9 @@ +#include + +#include "../precomputation.h" + +class MockPrecomputation : public Precomputation { + public: + MOCK_CONST_METHOD0(GetInvertedIndex, const Index&()); + MOCK_CONST_METHOD0(GetCollocations, const Index&()); +}; diff --git a/extractor/mocks/mock_suffix_array.h b/extractor/mocks/mock_suffix_array.h index 38d8bad6..11a3a443 100644 --- a/extractor/mocks/mock_suffix_array.h +++ b/extractor/mocks/mock_suffix_array.h @@ -1,5 +1,6 @@ #include +#include #include #include "../data_array.h" @@ -10,8 +11,9 @@ using namespace std; class MockSuffixArray : public SuffixArray { public: - MockSuffixArray() : SuffixArray(make_shared()) {} - MOCK_CONST_METHOD0(GetSize, int()); + MOCK_CONST_METHOD0(GetData, shared_ptr()); + MOCK_CONST_METHOD0(BuildLCPArray, vector()); + MOCK_CONST_METHOD1(GetSuffix, int(int)); MOCK_CONST_METHOD4(Lookup, PhraseLocation(int, int, const string& word, int)); }; diff --git a/extractor/mocks/mock_vocabulary.h b/extractor/mocks/mock_vocabulary.h index 06dea10f..e5c191f5 100644 --- a/extractor/mocks/mock_vocabulary.h +++ b/extractor/mocks/mock_vocabulary.h @@ -5,4 +5,5 @@ class MockVocabulary : public Vocabulary { public: MOCK_METHOD1(GetTerminalValue, string(int word_id)); + MOCK_METHOD1(GetTerminalIndex, int(const string& word)); }; diff --git a/extractor/phrase.cc b/extractor/phrase.cc index f9bd9908..6dc242db 100644 --- a/extractor/phrase.cc +++ b/extractor/phrase.cc @@ -23,3 +23,32 @@ vector Phrase::Get() const { int Phrase::GetSymbol(int position) const { return symbols[position]; } + +int Phrase::GetNumSymbols() const { + return symbols.size(); +} + +vector Phrase::GetWords() const { + return words; +} + +int Phrase::operator<(const Phrase& other) const { + return symbols < other.symbols; +} + +ostream& operator<<(ostream& os, const Phrase& phrase) { + int current_word = 0; + for (size_t i = 0; i < phrase.symbols.size(); ++i) { + if (phrase.symbols[i] < 0) { + os << "[X," << -phrase.symbols[i] << "]"; + } else { + os << phrase.words[current_word]; + ++current_word; + } + + if (i + 1 < phrase.symbols.size()) { + os << " "; + } + } + return os; +} diff --git a/extractor/phrase.h b/extractor/phrase.h index 5a5124d9..f40a8169 100644 --- a/extractor/phrase.h +++ b/extractor/phrase.h @@ -1,6 +1,7 @@ #ifndef _PHRASE_H_ #define _PHRASE_H_ +#include #include #include @@ -20,6 +21,17 @@ class Phrase { int GetSymbol(int position) const; + //TODO(pauldb): Unit test this method. + int GetNumSymbols() const; + + //TODO(pauldb): Add unit tests. + vector GetWords() const; + + //TODO(pauldb): Add unit tests. + int operator<(const Phrase& other) const; + + friend ostream& operator<<(ostream& os, const Phrase& phrase); + private: vector symbols; vector var_pos; diff --git a/extractor/phrase_builder.cc b/extractor/phrase_builder.cc index 7f3447e5..c4e0c2ed 100644 --- a/extractor/phrase_builder.cc +++ b/extractor/phrase_builder.cc @@ -19,3 +19,27 @@ Phrase PhraseBuilder::Build(const vector& symbols) { } return phrase; } + +Phrase PhraseBuilder::Extend(const Phrase& phrase, bool start_x, bool end_x) { + vector symbols = phrase.Get(); + int num_nonterminals = 0; + if (start_x) { + num_nonterminals = 1; + symbols.insert(symbols.begin(), + vocabulary->GetNonterminalIndex(num_nonterminals)); + } + + for (size_t i = start_x; i < symbols.size(); ++i) { + if (vocabulary->IsTerminal(symbols[i])) { + ++num_nonterminals; + symbols[i] = vocabulary->GetNonterminalIndex(num_nonterminals); + } + } + + if (end_x) { + ++num_nonterminals; + symbols.push_back(vocabulary->GetNonterminalIndex(num_nonterminals)); + } + + return Build(symbols); +} diff --git a/extractor/phrase_builder.h b/extractor/phrase_builder.h index f01cb23b..a49af457 100644 --- a/extractor/phrase_builder.h +++ b/extractor/phrase_builder.h @@ -15,6 +15,8 @@ class PhraseBuilder { Phrase Build(const vector& symbols); + Phrase Extend(const Phrase& phrase, bool start_x, bool end_x); + private: shared_ptr vocabulary; }; diff --git a/extractor/phrase_location.cc b/extractor/phrase_location.cc index b5b68549..984407c5 100644 --- a/extractor/phrase_location.cc +++ b/extractor/phrase_location.cc @@ -1,16 +1,12 @@ #include "phrase_location.h" -#include - PhraseLocation::PhraseLocation(int sa_low, int sa_high) : - sa_low(sa_low), sa_high(sa_high), - matchings(shared_ptr >()), - num_subpatterns(0) {} + sa_low(sa_low), sa_high(sa_high), num_subpatterns(0) {} -PhraseLocation::PhraseLocation(shared_ptr > matchings, +PhraseLocation::PhraseLocation(const vector& matchings, int num_subpatterns) : sa_high(0), sa_low(0), - matchings(matchings), + matchings(make_shared >(matchings)), num_subpatterns(num_subpatterns) {} bool PhraseLocation::IsEmpty() { diff --git a/extractor/phrase_location.h b/extractor/phrase_location.h index 96004b33..e04d8628 100644 --- a/extractor/phrase_location.h +++ b/extractor/phrase_location.h @@ -9,7 +9,7 @@ using namespace std; struct PhraseLocation { PhraseLocation(int sa_low = -1, int sa_high = -1); - PhraseLocation(shared_ptr > matchings, int num_subpatterns); + PhraseLocation(const vector& matchings, int num_subpatterns); bool IsEmpty(); diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc index 97a70554..9a167976 100644 --- a/extractor/precomputation.cc +++ b/extractor/precomputation.cc @@ -2,11 +2,6 @@ #include #include -#include -#include -#include - -#include #include "data_array.h" #include "suffix_array.h" @@ -26,9 +21,8 @@ Precomputation::Precomputation( suffix_array, data, num_frequent_patterns, max_frequent_phrase_len, min_frequency); - unordered_set, boost::hash > > frequent_patterns_set; - unordered_set, boost::hash > > - super_frequent_patterns_set; + unordered_set, VectorHash> frequent_patterns_set; + unordered_set, VectorHash> super_frequent_patterns_set; for (size_t i = 0; i < frequent_patterns.size(); ++i) { frequent_patterns_set.insert(frequent_patterns[i]); if (i < num_super_frequent_patterns) { @@ -60,6 +54,10 @@ Precomputation::Precomputation( } } +Precomputation::Precomputation() {} + +Precomputation::~Precomputation() {} + vector > Precomputation::FindMostFrequentPatterns( shared_ptr suffix_array, const vector& data, int num_frequent_patterns, int max_frequent_phrase_len, int min_frequency) { @@ -107,7 +105,7 @@ void Precomputation::AddCollocations( } if (start2 - start1 - size1 >= min_gap_size - && start2 + size2 - size1 <= max_rule_span + && start2 + size2 - start1 <= max_rule_span && size1 + size2 + 1 <= max_rule_symbols) { vector pattern(data.begin() + start1, data.begin() + start1 + size1); @@ -126,7 +124,7 @@ void Precomputation::AddCollocations( } if (start3 - start2 - size2 >= min_gap_size - && start3 + size3 - size1 <= max_rule_span + && start3 + size3 - start1 <= max_rule_span && size1 + size2 + size3 + 2 <= max_rule_symbols && (is_super1 || is_super3)) { pattern.insert(pattern.end(), data.begin() + start3, diff --git a/extractor/precomputation.h b/extractor/precomputation.h index 0d1b269f..428505d8 100644 --- a/extractor/precomputation.h +++ b/extractor/precomputation.h @@ -16,8 +16,8 @@ using namespace tr1; class SuffixArray; -typedef boost::hash > vector_hash; -typedef unordered_map, vector, vector_hash> Index; +typedef boost::hash > VectorHash; +typedef unordered_map, vector, VectorHash> Index; class Precomputation { public: @@ -27,20 +27,25 @@ class Precomputation { int max_rule_symbols, int min_gap_size, int max_frequent_phrase_len, int min_frequency); + virtual ~Precomputation(); + void WriteBinary(const fs::path& filepath) const; - const Index& GetInvertedIndex() const; - const Index& GetCollocations() const; + virtual const Index& GetInvertedIndex() const; + virtual const Index& GetCollocations() const; static int NON_TERMINAL; + protected: + Precomputation(); + private: vector > FindMostFrequentPatterns( shared_ptr suffix_array, const vector& data, int num_frequent_patterns, int max_frequent_phrase_len, int min_frequency); void AddCollocations( - const vector >& matchings, const vector& data, + const vector >& matchings, const vector& data, int max_rule_span, int min_gap_size, int max_rule_symbols); void AddStartPositions(vector& positions, int pos1, int pos2); void AddStartPositions(vector& positions, int pos1, int pos2, int pos3); diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc new file mode 100644 index 00000000..9edb29db --- /dev/null +++ b/extractor/precomputation_test.cc @@ -0,0 +1,138 @@ +#include + +#include +#include + +#include "mocks/mock_data_array.h" +#include "mocks/mock_suffix_array.h" +#include "precomputation.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class PrecomputationTest : public Test { + protected: + virtual void SetUp() { + data = {4, 2, 3, 5, 7, 2, 3, 5, 2, 3, 4, 2, 1}; + data_array = make_shared(); + EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data)); + + vector suffixes{12, 8, 5, 1, 9, 6, 2, 0, 10, 7, 3, 4, 13}; + vector lcp{-1, 0, 2, 3, 1, 0, 1, 2, 0, 2, 0, 1, 0, 0}; + suffix_array = make_shared(); + EXPECT_CALL(*suffix_array, GetData()).WillRepeatedly(Return(data_array)); + for (size_t i = 0; i < suffixes.size(); ++i) { + EXPECT_CALL(*suffix_array, + GetSuffix(i)).WillRepeatedly(Return(suffixes[i])); + } + EXPECT_CALL(*suffix_array, BuildLCPArray()).WillRepeatedly(Return(lcp)); + } + + vector data; + shared_ptr data_array; + shared_ptr suffix_array; +}; + +TEST_F(PrecomputationTest, TestInvertedIndex) { + Precomputation precomputation(suffix_array, 100, 3, 10, 5, 1, 4, 2); + Index inverted_index = precomputation.GetInvertedIndex(); + + EXPECT_EQ(8, inverted_index.size()); + vector key = {2}; + vector expected_value = {1, 5, 8, 11}; + EXPECT_EQ(expected_value, inverted_index[key]); + key = {3}; + expected_value = {2, 6, 9}; + EXPECT_EQ(expected_value, inverted_index[key]); + key = {4}; + expected_value = {0, 10}; + EXPECT_EQ(expected_value, inverted_index[key]); + key = {5}; + expected_value = {3, 7}; + EXPECT_EQ(expected_value, inverted_index[key]); + key = {4, 2}; + expected_value = {0, 10}; + EXPECT_EQ(expected_value, inverted_index[key]); + key = {2, 3}; + expected_value = {1, 5, 8}; + EXPECT_EQ(expected_value, inverted_index[key]); + key = {3, 5}; + expected_value = {2, 6}; + EXPECT_EQ(expected_value, inverted_index[key]); + key = {2, 3, 5}; + expected_value = {1, 5}; + EXPECT_EQ(expected_value, inverted_index[key]); + + key = {2, 4}; + EXPECT_EQ(0, inverted_index.count(key)); +} + +TEST_F(PrecomputationTest, TestCollocations) { + Precomputation precomputation(suffix_array, 3, 3, 10, 5, 1, 4, 2); + Index collocations = precomputation.GetCollocations(); + + EXPECT_EQ(-1, precomputation.NON_TERMINAL); + vector key = {2, 3, -1, 2}; + vector expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11}; + EXPECT_EQ(expected_value, collocations[key]); + key = {2, 3, -1, 2, 3}; + expected_value = {1, 5, 1, 8, 5, 8}; + EXPECT_EQ(expected_value, collocations[key]); + key = {2, 3, -1, 3}; + expected_value = {1, 6, 1, 9, 5, 9}; + EXPECT_EQ(expected_value, collocations[key]); + key = {3, -1, 2}; + expected_value = {2, 5, 2, 8, 2, 11, 6, 8, 6, 11, 9, 11}; + EXPECT_EQ(expected_value, collocations[key]); + key = {3, -1, 3}; + expected_value = {2, 6, 2, 9, 6, 9}; + EXPECT_EQ(expected_value, collocations[key]); + key = {3, -1, 2, 3}; + expected_value = {2, 5, 2, 8, 6, 8}; + EXPECT_EQ(expected_value, collocations[key]); + key = {2, -1, 2}; + expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11}; + EXPECT_EQ(expected_value, collocations[key]); + key = {2, -1, 2, 3}; + expected_value = {1, 5, 1, 8, 5, 8}; + EXPECT_EQ(expected_value, collocations[key]); + key = {2, -1, 3}; + expected_value = {1, 6, 1, 9, 5, 9}; + EXPECT_EQ(expected_value, collocations[key]); + + key = {2, -1, 2, -1, 2}; + expected_value = {1, 5, 8, 5, 8, 11}; + EXPECT_EQ(expected_value, collocations[key]); + key = {2, -1, 2, -1, 3}; + expected_value = {1, 5, 9}; + EXPECT_EQ(expected_value, collocations[key]); + key = {2, -1, 3, -1, 2}; + expected_value = {1, 6, 8, 5, 9, 11}; + EXPECT_EQ(expected_value, collocations[key]); + key = {2, -1, 3, -1, 3}; + expected_value = {1, 6, 9}; + EXPECT_EQ(expected_value, collocations[key]); + key = {3, -1, 2, -1, 2}; + expected_value = {2, 5, 8, 2, 5, 11, 2, 8, 11, 6, 8, 11}; + EXPECT_EQ(expected_value, collocations[key]); + key = {3, -1, 2, -1, 3}; + expected_value = {2, 5, 9}; + EXPECT_EQ(expected_value, collocations[key]); + key = {3, -1, 3, -1, 2}; + expected_value = {2, 6, 8, 2, 6, 11, 2, 9, 11, 6, 9, 11}; + EXPECT_EQ(expected_value, collocations[key]); + key = {3, -1, 3, -1, 3}; + expected_value = {2, 6, 9}; + EXPECT_EQ(expected_value, collocations[key]); + + // Exceeds max_rule_symbols. + key = {2, -1, 2, -1, 2, 3}; + EXPECT_EQ(0, collocations.count(key)); + // Contains non frequent pattern. + key = {2, -1, 5}; + EXPECT_EQ(0, collocations.count(key)); +} + +} // namespace diff --git a/extractor/rule.cc b/extractor/rule.cc new file mode 100644 index 00000000..9c7ac9b5 --- /dev/null +++ b/extractor/rule.cc @@ -0,0 +1,10 @@ +#include "rule.h" + +Rule::Rule(const Phrase& source_phrase, + const Phrase& target_phrase, + const vector& scores, + const vector >& alignment) : + source_phrase(source_phrase), + target_phrase(target_phrase), + scores(scores), + alignment(alignment) {} diff --git a/extractor/rule.h b/extractor/rule.h new file mode 100644 index 00000000..64ff8794 --- /dev/null +++ b/extractor/rule.h @@ -0,0 +1,20 @@ +#ifndef _RULE_H_ +#define _RULE_H_ + +#include + +#include "phrase.h" + +using namespace std; + +struct Rule { + Rule(const Phrase& source_phrase, const Phrase& target_phrase, + const vector& scores, const vector >& alignment); + + Phrase source_phrase; + Phrase target_phrase; + vector scores; + vector > alignment; +}; + +#endif diff --git a/extractor/rule_extractor.cc b/extractor/rule_extractor.cc index 48b39b63..9460020f 100644 --- a/extractor/rule_extractor.cc +++ b/extractor/rule_extractor.cc @@ -1,10 +1,679 @@ #include "rule_extractor.h" +#include +#include + +#include "alignment.h" +#include "data_array.h" +#include "features/feature.h" +#include "phrase_builder.h" +#include "phrase_location.h" +#include "rule.h" +#include "scorer.h" +#include "vocabulary.h" + +using namespace std; +using namespace tr1; + RuleExtractor::RuleExtractor( - shared_ptr source_suffix_array, + shared_ptr source_data_array, shared_ptr target_data_array, - const Alignment& alingment) { + shared_ptr alignment, + shared_ptr phrase_builder, + shared_ptr scorer, + shared_ptr vocabulary, + int max_rule_span, + int min_gap_size, + int max_nonterminals, + int max_rule_symbols, + bool require_aligned_terminal, + bool require_aligned_chunks, + bool require_tight_phrases) : + source_data_array(source_data_array), + target_data_array(target_data_array), + alignment(alignment), + phrase_builder(phrase_builder), + scorer(scorer), + vocabulary(vocabulary), + max_rule_span(max_rule_span), + min_gap_size(min_gap_size), + max_nonterminals(max_nonterminals), + max_rule_symbols(max_rule_symbols), + require_aligned_terminal(require_aligned_terminal), + require_aligned_chunks(require_aligned_chunks), + require_tight_phrases(require_tight_phrases) {} + +vector RuleExtractor::ExtractRules(const Phrase& phrase, + const PhraseLocation& location) const { + int num_subpatterns = location.num_subpatterns; + vector matchings = *location.matchings; + + map source_phrase_counter; + map > > alignments_counter; + for (auto i = matchings.begin(); i != matchings.end(); i += num_subpatterns) { + vector matching(i, i + num_subpatterns); + vector extracts = ExtractAlignments(phrase, matching); + + for (Extract e: extracts) { + source_phrase_counter[e.source_phrase] += e.pairs_count; + alignments_counter[e.source_phrase][e.target_phrase][e.alignment] += 1; + } + } + + vector rules; + for (auto source_phrase_entry: alignments_counter) { + Phrase source_phrase = source_phrase_entry.first; + for (auto target_phrase_entry: source_phrase_entry.second) { + Phrase target_phrase = target_phrase_entry.first; + + int max_locations = 0, num_locations = 0; + PhraseAlignment most_frequent_alignment; + for (auto alignment_entry: target_phrase_entry.second) { + num_locations += alignment_entry.second; + if (alignment_entry.second > max_locations) { + most_frequent_alignment = alignment_entry.first; + max_locations = alignment_entry.second; + } + } + + FeatureContext context(source_phrase, target_phrase, + source_phrase_counter[source_phrase], num_locations); + vector scores = scorer->Score(context); + rules.push_back(Rule(source_phrase, target_phrase, scores, + most_frequent_alignment)); + } + } + return rules; +} + +vector RuleExtractor::ExtractAlignments( + const Phrase& phrase, const vector& matching) const { + vector extracts; + int sentence_id = source_data_array->GetSentenceId(matching[0]); + int source_sent_start = source_data_array->GetSentenceStart(sentence_id); + + vector source_low, source_high, target_low, target_high; + GetLinksSpans(source_low, source_high, target_low, target_high, sentence_id); + + int num_subpatterns = matching.size(); + vector chunklen(num_subpatterns); + for (size_t i = 0; i < num_subpatterns; ++i) { + chunklen[i] = phrase.GetChunkLen(i); + } + + if (!CheckAlignedTerminals(matching, chunklen, source_low) || + !CheckTightPhrases(matching, chunklen, source_low)) { + return extracts; + } + + int source_back_low = -1, source_back_high = -1; + int source_phrase_low = matching[0] - source_sent_start; + int source_phrase_high = matching.back() + chunklen.back() - source_sent_start; + int target_phrase_low = -1, target_phrase_high = -1; + if (!FindFixPoint(source_phrase_low, source_phrase_high, source_low, + source_high, target_phrase_low, target_phrase_high, + target_low, target_high, source_back_low, source_back_high, + sentence_id, min_gap_size, 0, + max_nonterminals - matching.size() + 1, 1, 1, false)) { + return extracts; + } + + bool met_constraints = true; + int num_symbols = phrase.GetNumSymbols(); + vector > source_gaps, target_gaps; + if (!CheckGaps(source_gaps, target_gaps, matching, chunklen, source_low, + source_high, target_low, target_high, source_phrase_low, + source_phrase_high, source_back_low, source_back_high, + num_symbols, met_constraints)) { + return extracts; + } + + bool start_x = source_back_low != source_phrase_low; + bool end_x = source_back_high != source_phrase_high; + Phrase source_phrase = phrase_builder->Extend(phrase, start_x, end_x); + if (met_constraints) { + AddExtracts(extracts, source_phrase, target_gaps, target_low, + target_phrase_low, target_phrase_high, sentence_id); + } + + if (source_gaps.size() >= max_nonterminals || + source_phrase.GetNumSymbols() >= max_rule_symbols || + source_back_high - source_back_low + min_gap_size > max_rule_span) { + // Cannot add any more nonterminals. + return extracts; + } + + for (int i = 0; i < 2; ++i) { + for (int j = 1 - i; j < 2; ++j) { + AddNonterminalExtremities(extracts, source_phrase, source_phrase_low, + source_phrase_high, source_back_low, source_back_high, source_low, + source_high, target_low, target_high, target_gaps, sentence_id, i, j); + } + } + + return extracts; +} + +void RuleExtractor::GetLinksSpans( + vector& source_low, vector& source_high, + vector& target_low, vector& target_high, int sentence_id) const { + // Ignore end of line markers. + int source_sent_len = source_data_array->GetSentenceStart(sentence_id + 1) - + source_data_array->GetSentenceStart(sentence_id) - 1; + int target_sent_len = target_data_array->GetSentenceStart(sentence_id + 1) - + target_data_array->GetSentenceStart(sentence_id) - 1; + source_low = vector(source_sent_len, -1); + source_high = vector(source_sent_len, -1); + + // TODO(pauldb): Adam Lopez claims this part is really inefficient. See if we + // can speed it up. + target_low = vector(target_sent_len, -1); + target_high = vector(target_sent_len, -1); + const vector >& links = alignment->GetLinks(sentence_id); + for (auto link: links) { + if (source_low[link.first] == -1 || source_low[link.first] > link.second) { + source_low[link.first] = link.second; + } + source_high[link.first] = max(source_high[link.first], link.second + 1); + + if (target_low[link.second] == -1 || target_low[link.second] > link.first) { + target_low[link.second] = link.first; + } + target_high[link.second] = max(target_high[link.second], link.first + 1); + } +} + +bool RuleExtractor::CheckAlignedTerminals(const vector& matching, + const vector& chunklen, + const vector& source_low) const { + if (!require_aligned_terminal) { + return true; + } + + int sentence_id = source_data_array->GetSentenceId(matching[0]); + int source_sent_start = source_data_array->GetSentenceStart(sentence_id); + + int num_aligned_chunks = 0; + for (size_t i = 0; i < chunklen.size(); ++i) { + for (size_t j = 0; j < chunklen[i]; ++j) { + int sent_index = matching[i] - source_sent_start + j; + if (source_low[sent_index] != -1) { + ++num_aligned_chunks; + break; + } + } + } + + if (num_aligned_chunks == 0) { + return false; + } + + return !require_aligned_chunks || num_aligned_chunks == chunklen.size(); +} + +bool RuleExtractor::CheckTightPhrases(const vector& matching, + const vector& chunklen, + const vector& source_low) const { + if (!require_tight_phrases) { + return true; + } + + int sentence_id = source_data_array->GetSentenceId(matching[0]); + int source_sent_start = source_data_array->GetSentenceStart(sentence_id); + for (size_t i = 0; i + 1 < chunklen.size(); ++i) { + int gap_start = matching[i] + chunklen[i] - source_sent_start; + int gap_end = matching[i + 1] - 1 - source_sent_start; + if (source_low[gap_start] == -1 || source_low[gap_end] == -1) { + return false; + } + } + + return true; +} + +bool RuleExtractor::FindFixPoint( + int source_phrase_low, int source_phrase_high, + const vector& source_low, const vector& source_high, + int& target_phrase_low, int& target_phrase_high, + const vector& target_low, const vector& target_high, + int& source_back_low, int& source_back_high, int sentence_id, + int min_source_gap_size, int min_target_gap_size, + int max_new_x, int max_low_x, int max_high_x, + bool allow_arbitrary_expansion) const { + int source_sent_len = source_data_array->GetSentenceStart(sentence_id + 1) - + source_data_array->GetSentenceStart(sentence_id) - 1; + int target_sent_len = target_data_array->GetSentenceStart(sentence_id + 1) - + target_data_array->GetSentenceStart(sentence_id) - 1; + + int prev_target_low = target_phrase_low; + int prev_target_high = target_phrase_high; + FindProjection(source_phrase_low, source_phrase_high, source_low, + source_high, target_phrase_low, target_phrase_high); + + if (target_phrase_low == -1) { + // TODO(pauldb): Low priority corner case inherited from Adam's code: + // If w is unaligned, but we don't require aligned terminals, returning an + // error here prevents the extraction of the allowed rule + // X -> X_1 w X_2 / X_1 X_2 + return false; + } + + if (prev_target_low != -1 && target_phrase_low != prev_target_low) { + if (prev_target_low - target_phrase_low < min_target_gap_size) { + target_phrase_low = prev_target_low - min_target_gap_size; + if (target_phrase_low < 0) { + return false; + } + } + } + + if (prev_target_high != -1 && target_phrase_high != prev_target_high) { + if (target_phrase_high - prev_target_high < min_target_gap_size) { + target_phrase_high = prev_target_high + min_target_gap_size; + if (target_phrase_high > target_sent_len) { + return false; + } + } + } + + if (target_phrase_high - target_phrase_low > max_rule_span) { + return false; + } + + source_back_low = source_back_high = -1; + FindProjection(target_phrase_low, target_phrase_high, target_low, target_high, + source_back_low, source_back_high); + int new_x = 0, new_low_x = 0, new_high_x = 0; + + while (true) { + source_back_low = min(source_back_low, source_phrase_low); + source_back_high = max(source_back_high, source_phrase_high); + + if (source_back_low == source_phrase_low && + source_back_high == source_phrase_high) { + return true; + } + + if (new_low_x >= max_low_x && source_back_low < source_phrase_low) { + // Extension on the left side not allowed. + return false; + } + if (new_high_x >= max_high_x && source_back_high > source_phrase_high) { + // Extension on the right side not allowed. + return false; + } + + // Extend left side. + if (source_back_low < source_phrase_low) { + if (new_x >= max_new_x) { + return false; + } + ++new_x; ++new_low_x; + if (source_phrase_low - source_back_low < min_source_gap_size) { + source_back_low = source_phrase_low - min_source_gap_size; + if (source_back_low < 0) { + return false; + } + } + } + + // Extend right side. + if (source_back_high > source_phrase_high) { + if (new_x >= max_new_x) { + return false; + } + ++new_x; ++new_high_x; + if (source_back_high - source_phrase_high < min_source_gap_size) { + source_back_high = source_phrase_high + min_source_gap_size; + if (source_back_high > source_sent_len) { + return false; + } + } + } + + if (source_back_high - source_back_low > max_rule_span) { + // Rule span too wide. + return false; + } + + prev_target_low = target_phrase_low; + prev_target_high = target_phrase_high; + FindProjection(source_back_low, source_phrase_low, source_low, source_high, + target_phrase_low, target_phrase_high); + FindProjection(source_phrase_high, source_back_high, source_low, + source_high, target_phrase_low, target_phrase_high); + if (prev_target_low == target_phrase_low && + prev_target_high == target_phrase_high) { + return true; + } + + if (!allow_arbitrary_expansion) { + // Arbitrary expansion not allowed. + return false; + } + if (target_phrase_high - target_phrase_low > max_rule_span) { + // Target side too wide. + return false; + } + + source_phrase_low = source_back_low; + source_phrase_high = source_back_high; + FindProjection(target_phrase_low, prev_target_low, target_low, target_high, + source_back_low, source_back_high); + FindProjection(prev_target_high, target_phrase_high, target_low, + target_high, source_back_low, source_back_high); + } + + return false; +} + +void RuleExtractor::FindProjection( + int source_phrase_low, int source_phrase_high, + const vector& source_low, const vector& source_high, + int& target_phrase_low, int& target_phrase_high) const { + for (size_t i = source_phrase_low; i < source_phrase_high; ++i) { + if (source_low[i] != -1) { + if (target_phrase_low == -1 || source_low[i] < target_phrase_low) { + target_phrase_low = source_low[i]; + } + target_phrase_high = max(target_phrase_high, source_high[i]); + } + } } -void RuleExtractor::ExtractRules() { +bool RuleExtractor::CheckGaps( + vector >& source_gaps, vector >& target_gaps, + const vector& matching, const vector& chunklen, + const vector& source_low, const vector& source_high, + const vector& target_low, const vector& target_high, + int source_phrase_low, int source_phrase_high, int source_back_low, + int source_back_high, int& num_symbols, bool& met_constraints) const { + int sentence_id = source_data_array->GetSentenceId(matching[0]); + int source_sent_start = source_data_array->GetSentenceStart(sentence_id); + + if (source_back_low < source_phrase_low) { + source_gaps.push_back(make_pair(source_back_low, source_phrase_low)); + if (num_symbols >= max_rule_symbols) { + // Source side contains too many symbols. + return false; + } + ++num_symbols; + if (require_tight_phrases && (source_low[source_back_low] == -1 || + source_low[source_phrase_low - 1] == -1)) { + // Inside edges of preceding gap are not tight. + return false; + } + } else if (require_tight_phrases && source_low[source_phrase_low] == -1) { + // This is not a hard error. We can't extract this phrase, but we might + // still be able to extract a superphrase. + met_constraints = false; + } + + for (size_t i = 0; i + 1 < chunklen.size(); ++i) { + int gap_start = matching[i] + chunklen[i] - source_sent_start; + int gap_end = matching[i + 1] - source_sent_start; + source_gaps.push_back(make_pair(gap_start, gap_end)); + } + + if (source_phrase_high < source_back_high) { + source_gaps.push_back(make_pair(source_phrase_high, source_back_high)); + if (num_symbols >= max_rule_symbols) { + // Source side contains too many symbols. + return false; + } + ++num_symbols; + if (require_tight_phrases && (source_low[source_phrase_high] == -1 || + source_low[source_back_high - 1] == -1)) { + // Inside edges of following gap are not tight. + return false; + } + } else if (require_tight_phrases && + source_low[source_phrase_high - 1] == -1) { + // This is not a hard error. We can't extract this phrase, but we might + // still be able to extract a superphrase. + met_constraints = false; + } + + target_gaps.resize(source_gaps.size(), make_pair(-1, -1)); + for (size_t i = 0; i < source_gaps.size(); ++i) { + if (!FindFixPoint(source_gaps[i].first, source_gaps[i].second, source_low, + source_high, target_gaps[i].first, target_gaps[i].second, + target_low, target_high, source_gaps[i].first, + source_gaps[i].second, sentence_id, 0, 0, 0, 0, 0, + false)) { + // Gap fails integrity check. + return false; + } + } + + return true; +} + +void RuleExtractor::AddExtracts( + vector& extracts, const Phrase& source_phrase, + const vector >& target_gaps, const vector& target_low, + int target_phrase_low, int target_phrase_high, int sentence_id) const { + vector > target_phrases = ExtractTargetPhrases( + target_gaps, target_low, target_phrase_low, target_phrase_high, + sentence_id); + + if (target_phrases.size() > 0) { + double pairs_count = 1.0 / target_phrases.size(); + for (auto target_phrase: target_phrases) { + extracts.push_back(Extract(source_phrase, target_phrase.first, + pairs_count, target_phrase.second)); + } + } +} + +vector > RuleExtractor::ExtractTargetPhrases( + const vector >& target_gaps, const vector& target_low, + int target_phrase_low, int target_phrase_high, int sentence_id) const { + int target_sent_len = target_data_array->GetSentenceStart(sentence_id + 1) - + target_data_array->GetSentenceStart(sentence_id) - 1; + + vector target_gap_order(target_gaps.size()); + for (size_t i = 0; i < target_gap_order.size(); ++i) { + for (size_t j = 0; j < i; ++j) { + if (target_gaps[target_gap_order[j]] < target_gaps[i]) { + ++target_gap_order[i]; + } else { + ++target_gap_order[j]; + } + } + } + + int target_x_low = target_phrase_low, target_x_high = target_phrase_high; + if (!require_tight_phrases) { + while (target_x_low > 0 && + target_phrase_high - target_x_low < max_rule_span && + target_low[target_x_low - 1] == -1) { + --target_x_low; + } + while (target_x_high + 1 < target_sent_len && + target_x_high - target_phrase_low < max_rule_span && + target_low[target_x_high + 1] == -1) { + ++target_x_high; + } + } + + vector > gaps(target_gaps.size()); + for (size_t i = 0; i < gaps.size(); ++i) { + gaps[i] = target_gaps[target_gap_order[i]]; + if (!require_tight_phrases) { + while (gaps[i].first > target_x_low && + target_low[gaps[i].first] == -1) { + --gaps[i].first; + } + while (gaps[i].second < target_x_high && + target_low[gaps[i].second] == -1) { + ++gaps[i].second; + } + } + } + + vector > ranges(2 * gaps.size() + 2); + ranges.front() = make_pair(target_x_low, target_phrase_low); + ranges.back() = make_pair(target_phrase_high, target_x_high); + for (size_t i = 0; i < gaps.size(); ++i) { + ranges[i * 2 + 1] = make_pair(gaps[i].first, target_gaps[i].first); + ranges[i * 2 + 2] = make_pair(target_gaps[i].second, gaps[i].second); + } + + vector > target_phrases; + vector subpatterns(ranges.size()); + GeneratePhrases(target_phrases, ranges, 0, subpatterns, target_gap_order, + target_phrase_low, target_phrase_high, sentence_id); + return target_phrases; +} + +void RuleExtractor::GeneratePhrases( + vector >& target_phrases, + const vector >& ranges, int index, vector& subpatterns, + const vector& target_gap_order, int target_phrase_low, + int target_phrase_high, int sentence_id) const { + if (index >= ranges.size()) { + if (subpatterns.back() - subpatterns.front() > max_rule_span) { + return; + } + + vector symbols; + unordered_set target_indexes; + int offset = 1; + if (subpatterns.front() != target_phrase_low) { + offset = 2; + symbols.push_back(vocabulary->GetNonterminalIndex(1)); + } + + int target_sent_start = target_data_array->GetSentenceStart(sentence_id); + for (size_t i = 0; i * 2 < subpatterns.size(); ++i) { + for (size_t j = subpatterns[i * 2]; j < subpatterns[i * 2 + 1]; ++j) { + symbols.push_back(target_data_array->AtIndex(target_sent_start + j)); + target_indexes.insert(j); + } + if (i < target_gap_order.size()) { + symbols.push_back(vocabulary->GetNonterminalIndex( + target_gap_order[i] + offset)); + } + } + + if (subpatterns.back() != target_phrase_high) { + symbols.push_back(target_gap_order.size() + offset); + } + + const vector >& links = alignment->GetLinks(sentence_id); + vector > alignment; + for (pair link: links) { + if (target_indexes.count(link.second)) { + alignment.push_back(link); + } + } + + target_phrases.push_back(make_pair(phrase_builder->Build(symbols), + alignment)); + return; + } + + subpatterns[index] = ranges[index].first; + if (index > 0) { + subpatterns[index] = max(subpatterns[index], subpatterns[index - 1]); + } + while (subpatterns[index] <= ranges[index].second) { + subpatterns[index + 1] = max(subpatterns[index], ranges[index + 1].first); + while (subpatterns[index + 1] <= ranges[index + 1].second) { + GeneratePhrases(target_phrases, ranges, index + 2, subpatterns, + target_gap_order, target_phrase_low, target_phrase_high, + sentence_id); + ++subpatterns[index + 1]; + } + ++subpatterns[index]; + } +} + +void RuleExtractor::AddNonterminalExtremities( + vector& extracts, const Phrase& source_phrase, + int source_phrase_low, int source_phrase_high, int source_back_low, + int source_back_high, const vector& source_low, + const vector& source_high, const vector& target_low, + const vector& target_high, const vector >& target_gaps, + int sentence_id, int extend_left, int extend_right) const { + int source_x_low = source_phrase_low, source_x_high = source_phrase_high; + if (extend_left) { + if (source_back_low != source_phrase_low || + source_phrase_low < min_gap_size || + (require_tight_phrases && (source_low[source_phrase_low - 1] == -1 || + source_low[source_back_high - 1] == -1))) { + return; + } + + source_x_low = source_phrase_low - min_gap_size; + if (require_tight_phrases) { + while (source_x_low >= 0 && source_low[source_x_low] == -1) { + --source_x_low; + } + } + if (source_x_low < 0) { + return; + } + } + + if (extend_right) { + int source_sent_len = source_data_array->GetSentenceStart(sentence_id + 1) - + source_data_array->GetSentenceStart(sentence_id) - 1; + if (source_back_high != source_phrase_high || + source_phrase_high + min_gap_size > source_sent_len || + (require_tight_phrases && (source_low[source_phrase_low] == -1 || + source_low[source_phrase_high] == -1))) { + return; + } + source_x_high = source_phrase_high + min_gap_size; + if (require_tight_phrases) { + while (source_x_high <= source_sent_len && + source_low[source_x_high - 1] == -1) { + ++source_x_high; + } + } + + if (source_x_high > source_sent_len) { + return; + } + } + + if (source_x_high - source_x_low > max_rule_span || + target_gaps.size() + extend_left + extend_right > max_nonterminals) { + return; + } + + int target_x_low = -1, target_x_high = -1; + if (!FindFixPoint(source_x_low, source_x_high, source_low, source_high, + target_x_low, target_x_high, target_low, target_high, + source_x_low, source_x_high, sentence_id, 1, 1, + extend_left + extend_right, extend_left, extend_right, + true)) { + return; + } + + int source_gap_low = -1, source_gap_high = -1, target_gap_low = -1, + target_gap_high = -1; + if (extend_left && + ((require_tight_phrases && source_low[source_x_low] == -1) || + !FindFixPoint(source_x_low, source_phrase_low, source_low, source_high, + target_gap_low, target_gap_high, target_low, target_high, + source_gap_low, source_gap_high, sentence_id, + 0, 0, 0, 0, 0, false))) { + return; + } + if (extend_right && + ((require_tight_phrases && source_low[source_x_high - 1] == -1) || + !FindFixPoint(source_phrase_high, source_x_high, source_low, source_high, + target_gap_low, target_gap_high, target_low, target_high, + source_gap_low, source_gap_high, sentence_id, + 0, 0, 0, 0, 0, false))) { + return; + } + + Phrase new_source_phrase = phrase_builder->Extend(source_phrase, extend_left, + extend_right); + AddExtracts(extracts, new_source_phrase, target_gaps, target_low, + target_x_low, target_x_high, sentence_id); } diff --git a/extractor/rule_extractor.h b/extractor/rule_extractor.h index 13b5447a..f668de24 100644 --- a/extractor/rule_extractor.h +++ b/extractor/rule_extractor.h @@ -2,21 +2,129 @@ #define _RULE_EXTRACTOR_H_ #include +#include + +#include "phrase.h" using namespace std; class Alignment; class DataArray; -class SuffixArray; +class PhraseBuilder; +class PhraseLocation; +class Rule; +class Scorer; +class Vocabulary; + +typedef vector > PhraseAlignment; + +struct Extract { + Extract(const Phrase& source_phrase, const Phrase& target_phrase, + double pairs_count, const PhraseAlignment& alignment) : + source_phrase(source_phrase), target_phrase(target_phrase), + pairs_count(pairs_count), alignment(alignment) {} + + Phrase source_phrase; + Phrase target_phrase; + double pairs_count; + PhraseAlignment alignment; +}; class RuleExtractor { public: - RuleExtractor( - shared_ptr source_suffix_array, - shared_ptr target_data_array, - const Alignment& alingment); + RuleExtractor(shared_ptr source_data_array, + shared_ptr target_data_array, + shared_ptr alingment, + shared_ptr phrase_builder, + shared_ptr scorer, + shared_ptr vocabulary, + int min_gap_size, + int max_rule_span, + int max_nonterminals, + int max_rule_symbols, + bool require_aligned_terminal, + bool require_aligned_chunks, + bool require_tight_phrases); + + vector ExtractRules(const Phrase& phrase, + const PhraseLocation& location) const; + + private: + vector ExtractAlignments(const Phrase& phrase, + const vector& matching) const; + + void GetLinksSpans(vector& source_low, vector& source_high, + vector& target_low, vector& target_high, + int sentence_id) const; + + bool CheckAlignedTerminals(const vector& matching, + const vector& chunklen, + const vector& source_low) const; + + bool CheckTightPhrases(const vector& matching, + const vector& chunklen, + const vector& source_low) const; + + bool FindFixPoint( + int source_phrase_start, int source_phrase_end, + const vector& source_low, const vector& source_high, + int& target_phrase_start, int& target_phrase_end, + const vector& target_low, const vector& target_high, + int& source_back_low, int& source_back_high, int sentence_id, + int min_source_gap_size, int min_target_gap_size, + int max_new_x, int max_low_x, int max_high_x, + bool allow_arbitrary_expansion) const; + + void FindProjection( + int source_phrase_start, int source_phrase_end, + const vector& source_low, const vector& source_high, + int& target_phrase_low, int& target_phrase_end) const; + + bool CheckGaps( + vector >& source_gaps, vector >& target_gaps, + const vector& matching, const vector& chunklen, + const vector& source_low, const vector& source_high, + const vector& target_low, const vector& target_high, + int source_phrase_low, int source_phrase_high, int source_back_low, + int source_back_high, int& num_symbols, bool& met_constraints) const; + + void AddExtracts( + vector& extracts, const Phrase& source_phrase, + const vector >& target_gaps, const vector& target_low, + int target_phrase_low, int target_phrase_high, int sentence_id) const; + + vector > ExtractTargetPhrases( + const vector >& target_gaps, const vector& target_low, + int target_phrase_low, int target_phrase_high, int sentence_id) const; + + void GeneratePhrases( + vector >& target_phrases, + const vector >& ranges, int index, + vector& subpatterns, const vector& target_gap_order, + int target_phrase_low, int target_phrase_high, int sentence_id) const; + + void AddNonterminalExtremities( + vector& extracts, const Phrase& source_phrase, + int source_phrase_low, int source_phrase_high, int source_back_low, + int source_back_high, const vector& source_low, + const vector& source_high, const vector& target_low, + const vector& target_high, + const vector >& target_gaps, int sentence_id, + int extend_left, int extend_right) const; - void ExtractRules(); + shared_ptr source_data_array; + shared_ptr target_data_array; + shared_ptr alignment; + shared_ptr phrase_builder; + shared_ptr scorer; + shared_ptr vocabulary; + int max_rule_span; + int min_gap_size; + int max_nonterminals; + int max_rule_symbols; + bool require_aligned_terminal; + bool require_aligned_chunks; + bool require_tight_phrases; }; #endif diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index 7a8356b8..c22f9b48 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -5,8 +5,15 @@ #include #include +#include "grammar.h" +#include "intersector.h" +#include "matchings_finder.h" #include "matching_comparator.h" #include "phrase.h" +#include "rule.h" +#include "rule_extractor.h" +#include "sampler.h" +#include "scorer.h" #include "suffix_array.h" #include "vocabulary.h" @@ -30,28 +37,39 @@ struct State { HieroCachingRuleFactory::HieroCachingRuleFactory( shared_ptr source_suffix_array, shared_ptr target_data_array, - const Alignment& alignment, + shared_ptr alignment, const shared_ptr& vocabulary, - const Precomputation& precomputation, + shared_ptr precomputation, + shared_ptr scorer, int min_gap_size, int max_rule_span, int max_nonterminals, int max_rule_symbols, - bool use_baeza_yates) : - matchings_finder(source_suffix_array), - intersector(vocabulary, precomputation, source_suffix_array, - make_shared(min_gap_size, max_rule_span), - use_baeza_yates), - phrase_builder(vocabulary), - rule_extractor(source_suffix_array, target_data_array, alignment), + int max_samples, + bool use_baeza_yates, + bool require_tight_phrases) : vocabulary(vocabulary), + scorer(scorer), min_gap_size(min_gap_size), max_rule_span(max_rule_span), max_nonterminals(max_nonterminals), max_chunks(max_nonterminals + 1), - max_rule_symbols(max_rule_symbols) {} + max_rule_symbols(max_rule_symbols) { + matchings_finder = make_shared(source_suffix_array); + shared_ptr comparator = + make_shared(min_gap_size, max_rule_span); + intersector = make_shared(vocabulary, precomputation, + source_suffix_array, comparator, use_baeza_yates); + phrase_builder = make_shared(vocabulary); + rule_extractor = make_shared(source_suffix_array->GetData(), + target_data_array, alignment, phrase_builder, scorer, vocabulary, + max_rule_span, min_gap_size, max_nonterminals, max_rule_symbols, true, + false, require_tight_phrases); + sampler = make_shared(source_suffix_array, max_samples); +} + -void HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { +Grammar HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { // Clear cache for every new sentence. trie.Reset(); shared_ptr root = trie.GetRoot(); @@ -69,6 +87,7 @@ void HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { vector(1, i), x_root, true)); } + vector rules; while (!states.empty()) { State state = states.front(); states.pop(); @@ -77,7 +96,7 @@ void HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { vector phrase = state.phrase; int word_id = word_ids[state.end]; phrase.push_back(word_id); - Phrase next_phrase = phrase_builder.Build(phrase); + Phrase next_phrase = phrase_builder->Build(phrase); shared_ptr next_node; if (CannotHaveMatchings(node, word_id)) { @@ -98,14 +117,14 @@ void HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { } else { PhraseLocation phrase_location; if (next_phrase.Arity() > 0) { - phrase_location = intersector.Intersect( + phrase_location = intersector->Intersect( node->phrase, node->matchings, next_suffix_link->phrase, next_suffix_link->matchings, next_phrase); } else { - phrase_location = matchings_finder.Find( + phrase_location = matchings_finder->Find( node->matchings, vocabulary->GetTerminalValue(word_id), state.phrase.size()); @@ -125,7 +144,10 @@ void HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { state.starts_with_x); if (!state.starts_with_x) { - rule_extractor.ExtractRules(); + PhraseLocation sample = sampler->Sample(next_node->matchings); + vector new_rules = + rule_extractor->ExtractRules(next_phrase, sample); + rules.insert(rules.end(), new_rules.begin(), new_rules.end()); } } else { next_node = node->GetChild(word_id); @@ -137,6 +159,8 @@ void HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { states.push(new_state); } } + + return Grammar(rules, scorer->GetFeatureNames()); } bool HieroCachingRuleFactory::CannotHaveMatchings( @@ -165,7 +189,7 @@ void HieroCachingRuleFactory::AddTrailingNonterminal( int var_id = vocabulary->GetNonterminalIndex(prefix.Arity() + 1); symbols.push_back(var_id); - Phrase var_phrase = phrase_builder.Build(symbols); + Phrase var_phrase = phrase_builder->Build(symbols); int suffix_var_id = vocabulary->GetNonterminalIndex( prefix.Arity() + starts_with_x == 0); diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h index 8fe8bf30..a47b6d16 100644 --- a/extractor/rule_factory.h +++ b/extractor/rule_factory.h @@ -4,17 +4,21 @@ #include #include -#include "matchings_finder.h" -#include "intersector.h" #include "matchings_trie.h" #include "phrase_builder.h" -#include "rule_extractor.h" using namespace std; class Alignment; class DataArray; +class Grammar; +class MatchingsFinder; +class Intersector; class Precomputation; +class Rule; +class RuleExtractor; +class Sampler; +class Scorer; class State; class SuffixArray; class Vocabulary; @@ -24,16 +28,19 @@ class HieroCachingRuleFactory { HieroCachingRuleFactory( shared_ptr source_suffix_array, shared_ptr target_data_array, - const Alignment& alignment, + shared_ptr alignment, const shared_ptr& vocabulary, - const Precomputation& precomputation, + shared_ptr precomputation, + shared_ptr scorer, int min_gap_size, int max_rule_span, int max_nonterminals, int max_rule_symbols, - bool use_beaza_yates); + int max_samples, + bool use_beaza_yates, + bool require_tight_phrases); - void GetGrammar(const vector& word_ids); + Grammar GetGrammar(const vector& word_ids); private: bool CannotHaveMatchings(shared_ptr node, int word_id); @@ -51,12 +58,14 @@ class HieroCachingRuleFactory { const Phrase& phrase, const shared_ptr& node); - MatchingsFinder matchings_finder; - Intersector intersector; + shared_ptr matchings_finder; + shared_ptr intersector; MatchingsTrie trie; - PhraseBuilder phrase_builder; - RuleExtractor rule_extractor; + shared_ptr phrase_builder; + shared_ptr rule_extractor; shared_ptr vocabulary; + shared_ptr sampler; + shared_ptr scorer; int min_gap_size; int max_rule_span; int max_nonterminals; diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index 4f841864..37a9cba0 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -1,16 +1,31 @@ +#include #include #include +#include +#include #include #include #include "alignment.h" #include "data_array.h" +#include "features/count_source_target.h" +#include "features/feature.h" +#include "features/is_source_singleton.h" +#include "features/is_source_target_singleton.h" +#include "features/max_lex_source_given_target.h" +#include "features/max_lex_target_given_source.h" +#include "features/sample_source_count.h" +#include "features/target_given_source_coherent.h" +#include "grammar.h" #include "grammar_extractor.h" #include "precomputation.h" +#include "rule.h" +#include "scorer.h" #include "suffix_array.h" #include "translation_table.h" +namespace fs = boost::filesystem; namespace po = boost::program_options; using namespace std; @@ -23,21 +38,26 @@ int main(int argc, char** argv) { ("target,e", po::value(), "Target language corpus") ("bitext,b", po::value(), "Parallel text (source ||| target)") ("alignment,a", po::value()->required(), "Bitext word alignment") + ("grammars,g", po::value()->required(), "Grammars output path") ("frequent", po::value()->default_value(100), "Number of precomputed frequent patterns") ("super_frequent", po::value()->default_value(10), "Number of precomputed super frequent patterns") - ("max_rule_span,s", po::value()->default_value(15), + ("max_rule_span", po::value()->default_value(15), "Maximum rule span") ("max_rule_symbols,l", po::value()->default_value(5), "Maximum number of symbols (terminals + nontermals) in a rule") - ("min_gap_size,g", po::value()->default_value(1), "Minimum gap size") - ("max_phrase_len,p", po::value()->default_value(4), + ("min_gap_size", po::value()->default_value(1), "Minimum gap size") + ("max_phrase_len", po::value()->default_value(4), "Maximum frequent phrase length") ("max_nonterminals", po::value()->default_value(2), "Maximum number of nonterminals in a rule") ("min_frequency", po::value()->default_value(1000), "Minimum number of occurences for a pharse to be considered frequent") + ("max_samples", po::value()->default_value(300), + "Maximum number of samples") + ("tight_phrases", po::value()->default_value(true), + "False if phrases may be loose (better, but slower)") ("baeza_yates", po::value()->default_value(true), "Use double binary search"); @@ -74,9 +94,10 @@ int main(int argc, char** argv) { make_shared(source_data_array); - Alignment alignment(vm["alignment"].as()); + shared_ptr alignment = + make_shared(vm["alignment"].as()); - Precomputation precomputation( + shared_ptr precomputation = make_shared( source_suffix_array, vm["frequent"].as(), vm["super_frequent"].as(), @@ -86,7 +107,19 @@ int main(int argc, char** argv) { vm["max_phrase_len"].as(), vm["min_frequency"].as()); - TranslationTable table(source_data_array, target_data_array, alignment); + shared_ptr table = make_shared( + source_data_array, target_data_array, alignment); + + vector > features = { + make_shared(), + make_shared(), + make_shared(), + make_shared(table), + make_shared(table), + make_shared(), + make_shared() + }; + shared_ptr scorer = make_shared(features); // TODO(pauldb): Add parallelization. GrammarExtractor extractor( @@ -94,15 +127,34 @@ int main(int argc, char** argv) { target_data_array, alignment, precomputation, + scorer, vm["min_gap_size"].as(), vm["max_rule_span"].as(), vm["max_nonterminals"].as(), vm["max_rule_symbols"].as(), - vm["baeza_yates"].as()); + vm["max_samples"].as(), + vm["baeza_yates"].as(), + vm["tight_phrases"].as()); - string sentence; + int grammar_id = 0; + fs::path grammar_path = vm["grammars"].as(); + string sentence, delimiter = "|||"; while (getline(cin, sentence)) { - extractor.GetGrammar(sentence); + string suffix = ""; + int position = sentence.find(delimiter); + if (position != sentence.npos) { + suffix = sentence.substr(position); + sentence = sentence.substr(0, position); + } + + Grammar grammar = extractor.GetGrammar(sentence); + fs::path grammar_file = grammar_path / to_string(grammar_id); + ofstream output(grammar_file.c_str()); + output << grammar; + + cout << " " << sentence << " " << suffix << endl; + ++grammar_id; } return 0; diff --git a/extractor/sampler.cc b/extractor/sampler.cc new file mode 100644 index 00000000..d8e0f49e --- /dev/null +++ b/extractor/sampler.cc @@ -0,0 +1,36 @@ +#include "sampler.h" + +#include "phrase_location.h" +#include "suffix_array.h" + +Sampler::Sampler(shared_ptr suffix_array, int max_samples) : + suffix_array(suffix_array), max_samples(max_samples) {} + +PhraseLocation Sampler::Sample(const PhraseLocation& location) const { + vector sample; + int num_subpatterns; + if (location.matchings == NULL) { + num_subpatterns = 1; + int low = location.sa_low, high = location.sa_high; + double step = max(1.0, (double) (high - low) / max_samples); + for (double i = low; i < high && sample.size() < max_samples; i += step) { + sample.push_back(suffix_array->GetSuffix(Round(i))); + } + } else { + num_subpatterns = location.num_subpatterns; + int num_matchings = location.matchings->size() / num_subpatterns; + double step = max(1.0, (double) num_matchings / max_samples); + for (double i = 0, num_samples = 0; + i < num_matchings && num_samples < max_samples; + i += step, ++num_samples) { + int start = Round(i) * num_subpatterns; + sample.insert(sample.end(), location.matchings->begin() + start, + location.matchings->begin() + start + num_subpatterns); + } + } + return PhraseLocation(sample, num_subpatterns); +} + +int Sampler::Round(double x) const { + return x + 0.5; +} diff --git a/extractor/sampler.h b/extractor/sampler.h new file mode 100644 index 00000000..3b3e3a4d --- /dev/null +++ b/extractor/sampler.h @@ -0,0 +1,24 @@ +#ifndef _SAMPLER_H_ +#define _SAMPLER_H_ + +#include + +using namespace std; + +class PhraseLocation; +class SuffixArray; + +class Sampler { + public: + Sampler(shared_ptr suffix_array, int max_samples); + + PhraseLocation Sample(const PhraseLocation& location) const; + + private: + int Round(double x) const; + + shared_ptr suffix_array; + int max_samples; +}; + +#endif diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc new file mode 100644 index 00000000..4f91965b --- /dev/null +++ b/extractor/sampler_test.cc @@ -0,0 +1,72 @@ +#include + +#include + +#include "mocks/mock_suffix_array.h" +#include "phrase_location.h" +#include "sampler.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class SamplerTest : public Test { + protected: + virtual void SetUp() { + suffix_array = make_shared(); + for (int i = 0; i < 10; ++i) { + EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i)); + } + } + + shared_ptr suffix_array; + shared_ptr sampler; +}; + +TEST_F(SamplerTest, TestSuffixArrayRange) { + PhraseLocation location(0, 10); + + sampler = make_shared(suffix_array, 1); + vector expected_locations = {0}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); + + sampler = make_shared(suffix_array, 2); + expected_locations = {0, 5}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); + + sampler = make_shared(suffix_array, 3); + expected_locations = {0, 3, 7}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); + + sampler = make_shared(suffix_array, 4); + expected_locations = {0, 3, 5, 8}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); + + sampler = make_shared(suffix_array, 100); + expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); +} + +TEST_F(SamplerTest, TestSubstringsSample) { + vector locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + PhraseLocation location(locations, 2); + + sampler = make_shared(suffix_array, 1); + vector expected_locations = {0, 1}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location)); + + sampler = make_shared(suffix_array, 2); + expected_locations = {0, 1, 6, 7}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location)); + + sampler = make_shared(suffix_array, 3); + expected_locations = {0, 1, 4, 5, 6, 7}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location)); + + sampler = make_shared(suffix_array, 7); + expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location)); +} + +} // namespace diff --git a/extractor/scorer.cc b/extractor/scorer.cc index 22d5be1a..c87e179d 100644 --- a/extractor/scorer.cc +++ b/extractor/scorer.cc @@ -1,9 +1,22 @@ #include "scorer.h" -Scorer::Scorer(const vector& features) : features(features) {} +#include "features/feature.h" -Scorer::~Scorer() { - for (Feature* feature: features) { - delete feature; +Scorer::Scorer(const vector >& features) : + features(features) {} + +vector Scorer::Score(const FeatureContext& context) const { + vector scores; + for (auto feature: features) { + scores.push_back(feature->Score(context)); + } + return scores; +} + +vector Scorer::GetFeatureNames() const { + vector feature_names; + for (auto feature: features) { + feature_names.push_back(feature->GetName()); } + return feature_names; } diff --git a/extractor/scorer.h b/extractor/scorer.h index 57405a6c..5b328fb4 100644 --- a/extractor/scorer.h +++ b/extractor/scorer.h @@ -1,19 +1,25 @@ #ifndef _SCORER_H_ #define _SCORER_H_ +#include +#include #include -#include "features/feature.h" - using namespace std; +class Feature; +class FeatureContext; + class Scorer { public: - Scorer(const vector& features); - ~Scorer(); + Scorer(const vector >& features); + + vector Score(const FeatureContext& context) const; + + vector GetFeatureNames() const; private: - vector features; + vector > features; }; #endif diff --git a/extractor/suffix_array.cc b/extractor/suffix_array.cc index 76f00ace..d13eacd5 100644 --- a/extractor/suffix_array.cc +++ b/extractor/suffix_array.cc @@ -15,6 +15,8 @@ SuffixArray::SuffixArray(shared_ptr data_array) : BuildSuffixArray(); } +SuffixArray::SuffixArray() {} + SuffixArray::~SuffixArray() {} void SuffixArray::BuildSuffixArray() { diff --git a/extractor/suffix_array.h b/extractor/suffix_array.h index 7708f5a2..79a22694 100644 --- a/extractor/suffix_array.h +++ b/extractor/suffix_array.h @@ -21,17 +21,20 @@ class SuffixArray { virtual int GetSize() const; - shared_ptr GetData() const; + virtual shared_ptr GetData() const; - vector BuildLCPArray() const; + virtual vector BuildLCPArray() const; - int GetSuffix(int rank) const; + virtual int GetSuffix(int rank) const; virtual PhraseLocation Lookup(int low, int high, const string& word, int offset) const; void WriteBinary(const fs::path& filepath) const; + protected: + SuffixArray(); + private: void BuildSuffixArray(); diff --git a/extractor/translation_table.cc b/extractor/translation_table.cc index 5eb4ffdc..10f1b9ed 100644 --- a/extractor/translation_table.cc +++ b/extractor/translation_table.cc @@ -13,17 +13,17 @@ using namespace tr1; TranslationTable::TranslationTable(shared_ptr source_data_array, shared_ptr target_data_array, - const Alignment& alignment) : + shared_ptr alignment) : source_data_array(source_data_array), target_data_array(target_data_array) { const vector& source_data = source_data_array->GetData(); const vector& target_data = target_data_array->GetData(); unordered_map source_links_count; unordered_map target_links_count; - unordered_map, int, boost::hash > > links_count; + unordered_map, int, PairHash > links_count; for (size_t i = 0; i < source_data_array->GetNumSentences(); ++i) { - vector > links = alignment.GetLinks(i); + const vector >& links = alignment->GetLinks(i); int source_start = source_data_array->GetSentenceStart(i); int next_source_start = source_data_array->GetSentenceStart(i + 1); int target_start = target_data_array->GetSentenceStart(i); @@ -58,7 +58,7 @@ TranslationTable::TranslationTable(shared_ptr source_data_array, } } -double TranslationTable::GetEgivenFScore( +double TranslationTable::GetTargetGivenSourceScore( const string& source_word, const string& target_word) { if (!source_data_array->HasWord(source_word) || !target_data_array->HasWord(target_word)) { @@ -70,7 +70,7 @@ double TranslationTable::GetEgivenFScore( return translation_probabilities[make_pair(source_id, target_id)].first; } -double TranslationTable::GetFgivenEScore( +double TranslationTable::GetSourceGivenTargetScore( const string& source_word, const string& target_word) { if (!source_data_array->HasWord(source_word) || !target_data_array->HasWord(target_word) == 0) { diff --git a/extractor/translation_table.h b/extractor/translation_table.h index 6004eca0..acf94af7 100644 --- a/extractor/translation_table.h +++ b/extractor/translation_table.h @@ -15,24 +15,28 @@ namespace fs = boost::filesystem; class Alignment; class DataArray; +typedef boost::hash > PairHash; + class TranslationTable { public: TranslationTable( shared_ptr source_data_array, shared_ptr target_data_array, - const Alignment& alignment); + shared_ptr alignment); - double GetEgivenFScore(const string& source_word, const string& target_word); + double GetTargetGivenSourceScore(const string& source_word, + const string& target_word); - double GetFgivenEScore(const string& source_word, const string& target_word); + double GetSourceGivenTargetScore(const string& source_word, + const string& target_word); void WriteBinary(const fs::path& filepath) const; private: - shared_ptr source_data_array; - shared_ptr target_data_array; - unordered_map, pair, - boost::hash > > translation_probabilities; + shared_ptr source_data_array; + shared_ptr target_data_array; + unordered_map, pair, PairHash> + translation_probabilities; }; #endif diff --git a/extractor/vocabulary.h b/extractor/vocabulary.h index 05744269..ed55e5e4 100644 --- a/extractor/vocabulary.h +++ b/extractor/vocabulary.h @@ -12,7 +12,7 @@ class Vocabulary { public: virtual ~Vocabulary(); - int GetTerminalIndex(const string& word); + virtual int GetTerminalIndex(const string& word); int GetNonterminalIndex(int position); -- cgit v1.2.3 From e1ffc4886b98f7ddcd2ec30d740aa2de8282cd8e Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Wed, 6 Mar 2013 15:25:54 +0000 Subject: Namespace for extractor. --- extractor/alignment.cc | 4 ++++ extractor/alignment.h | 4 ++++ extractor/alignment_test.cc | 4 +++- extractor/compile.cc | 1 + extractor/data_array.cc | 4 ++++ extractor/data_array.h | 4 ++++ extractor/data_array_test.cc | 4 +++- extractor/fast_intersector.cc | 4 ++++ extractor/fast_intersector.h | 4 ++++ extractor/fast_intersector_test.cc | 8 ++++---- extractor/features/count_source_target.cc | 6 ++++++ extractor/features/count_source_target.h | 6 ++++++ extractor/features/count_source_target_test.cc | 6 +++++- extractor/features/feature.cc | 6 ++++++ extractor/features/feature.h | 6 ++++++ extractor/features/is_source_singleton.cc | 6 ++++++ extractor/features/is_source_singleton.h | 6 ++++++ extractor/features/is_source_singleton_test.cc | 6 +++++- extractor/features/is_source_target_singleton.cc | 6 ++++++ extractor/features/is_source_target_singleton.h | 6 ++++++ extractor/features/is_source_target_singleton_test.cc | 6 +++++- extractor/features/max_lex_source_given_target.cc | 6 ++++++ extractor/features/max_lex_source_given_target.h | 7 +++++++ extractor/features/max_lex_source_given_target_test.cc | 6 +++++- extractor/features/max_lex_target_given_source.cc | 6 ++++++ extractor/features/max_lex_target_given_source.h | 7 +++++++ extractor/features/max_lex_target_given_source_test.cc | 6 +++++- extractor/features/sample_source_count.cc | 6 ++++++ extractor/features/sample_source_count.h | 6 ++++++ extractor/features/sample_source_count_test.cc | 6 +++++- extractor/features/target_given_source_coherent.cc | 6 ++++++ extractor/features/target_given_source_coherent.h | 6 ++++++ extractor/features/target_given_source_coherent_test.cc | 6 +++++- extractor/grammar.cc | 4 ++++ extractor/grammar.h | 4 ++++ extractor/grammar_extractor.cc | 5 +++++ extractor/grammar_extractor.h | 9 +++++++-- extractor/grammar_extractor_test.cc | 4 +++- extractor/matchings_finder.cc | 4 ++++ extractor/matchings_finder.h | 4 ++++ extractor/matchings_finder_test.cc | 4 +++- extractor/matchings_trie.cc | 4 ++++ extractor/matchings_trie.h | 4 ++++ extractor/mocks/mock_alignment.h | 4 ++++ extractor/mocks/mock_data_array.h | 4 ++++ extractor/mocks/mock_fast_intersector.h | 4 ++++ extractor/mocks/mock_feature.h | 6 ++++++ extractor/mocks/mock_matchings_finder.h | 4 ++++ extractor/mocks/mock_precomputation.h | 4 ++++ extractor/mocks/mock_rule_extractor.h | 4 ++++ extractor/mocks/mock_rule_extractor_helper.h | 4 ++++ extractor/mocks/mock_rule_factory.h | 4 ++++ extractor/mocks/mock_sampler.h | 4 ++++ extractor/mocks/mock_scorer.h | 7 ++++++- extractor/mocks/mock_suffix_array.h | 4 ++++ extractor/mocks/mock_target_phrase_extractor.h | 4 ++++ extractor/mocks/mock_translation_table.h | 4 ++++ extractor/mocks/mock_vocabulary.h | 4 ++++ extractor/phrase.cc | 4 ++++ extractor/phrase.h | 4 ++++ extractor/phrase_builder.cc | 4 ++++ extractor/phrase_builder.h | 4 ++++ extractor/phrase_location.cc | 4 ++++ extractor/phrase_location.h | 4 ++++ extractor/phrase_test.cc | 4 +++- extractor/precomputation.cc | 4 ++++ extractor/precomputation.h | 6 +++++- extractor/precomputation_test.cc | 5 ++++- extractor/rule.cc | 4 ++++ extractor/rule.h | 4 ++++ extractor/rule_extractor.cc | 6 +++++- extractor/rule_extractor.h | 8 ++++++-- extractor/rule_extractor_helper.cc | 4 ++++ extractor/rule_extractor_helper.h | 4 ++++ extractor/rule_extractor_helper_test.cc | 4 +++- extractor/rule_extractor_test.cc | 4 +++- extractor/rule_factory.cc | 5 +++++ extractor/rule_factory.h | 8 ++++++-- extractor/rule_factory_test.cc | 4 +++- extractor/run_extractor.cc | 17 ++++++++++------- extractor/sampler.cc | 4 ++++ extractor/sampler.h | 4 ++++ extractor/sampler_test.cc | 4 +++- extractor/scorer.cc | 8 ++++++-- extractor/scorer.h | 16 +++++++++++----- extractor/scorer_test.cc | 16 +++++++++------- extractor/suffix_array.cc | 4 ++++ extractor/suffix_array.h | 4 ++++ extractor/suffix_array_test.cc | 4 +++- extractor/target_phrase_extractor.cc | 4 ++++ extractor/target_phrase_extractor.h | 8 ++++++-- extractor/target_phrase_extractor_test.cc | 4 +++- extractor/time_util.cc | 4 ++++ extractor/time_util.h | 4 ++++ extractor/translation_table.cc | 4 ++++ extractor/translation_table.h | 8 ++++++-- extractor/translation_table_test.cc | 4 +++- extractor/vocabulary.cc | 3 +++ extractor/vocabulary.h | 4 ++++ 99 files changed, 460 insertions(+), 58 deletions(-) (limited to 'extractor/compile.cc') diff --git a/extractor/alignment.cc b/extractor/alignment.cc index ff39d484..f9bbcf6a 100644 --- a/extractor/alignment.cc +++ b/extractor/alignment.cc @@ -13,6 +13,8 @@ namespace fs = boost::filesystem; using namespace std; +namespace extractor { + Alignment::Alignment(const string& filename) { ifstream infile(filename.c_str()); string line; @@ -49,3 +51,5 @@ void Alignment::WriteBinary(const fs::path& filepath) { fwrite(alignment.data(), sizeof(pair), size, file); } } + +} // namespace extractor diff --git a/extractor/alignment.h b/extractor/alignment.h index f7e79585..ef89dc0c 100644 --- a/extractor/alignment.h +++ b/extractor/alignment.h @@ -9,6 +9,8 @@ namespace fs = boost::filesystem; using namespace std; +namespace extractor { + class Alignment { public: Alignment(const string& filename); @@ -26,4 +28,6 @@ class Alignment { vector > > alignments; }; +} // namespace extractor + #endif diff --git a/extractor/alignment_test.cc b/extractor/alignment_test.cc index 1bc51a56..a7defb66 100644 --- a/extractor/alignment_test.cc +++ b/extractor/alignment_test.cc @@ -8,6 +8,7 @@ using namespace std; using namespace ::testing; +namespace extractor { namespace { class AlignmentTest : public Test { @@ -28,4 +29,5 @@ TEST_F(AlignmentTest, TestGetLinks) { EXPECT_EQ(expected_links, alignment->GetLinks(1)); } -} // namespace +} // namespace +} // namespace extractor diff --git a/extractor/compile.cc b/extractor/compile.cc index f5cd41f4..7062ef03 100644 --- a/extractor/compile.cc +++ b/extractor/compile.cc @@ -14,6 +14,7 @@ namespace fs = boost::filesystem; namespace po = boost::program_options; using namespace std; +using namespace extractor; int main(int argc, char** argv) { po::options_description desc("Command line options"); diff --git a/extractor/data_array.cc b/extractor/data_array.cc index cd430c69..481abb80 100644 --- a/extractor/data_array.cc +++ b/extractor/data_array.cc @@ -10,6 +10,8 @@ namespace fs = boost::filesystem; using namespace std; +namespace extractor { + int DataArray::NULL_WORD = 0; int DataArray::END_OF_LINE = 1; string DataArray::NULL_WORD_STR = "__NULL__"; @@ -154,3 +156,5 @@ int DataArray::GetWordId(const string& word) const { string DataArray::GetWord(int word_id) const { return id2word[word_id]; } + +} // namespace extractor diff --git a/extractor/data_array.h b/extractor/data_array.h index 96950789..42e12135 100644 --- a/extractor/data_array.h +++ b/extractor/data_array.h @@ -10,6 +10,8 @@ namespace fs = boost::filesystem; using namespace std; +namespace extractor { + enum Side { SOURCE, TARGET @@ -73,4 +75,6 @@ class DataArray { vector sentence_start; }; +} // namespace extractor + #endif diff --git a/extractor/data_array_test.cc b/extractor/data_array_test.cc index ba5ce09e..71175fda 100644 --- a/extractor/data_array_test.cc +++ b/extractor/data_array_test.cc @@ -11,6 +11,7 @@ using namespace std; using namespace ::testing; namespace fs = boost::filesystem; +namespace extractor { namespace { class DataArrayTest : public Test { @@ -93,4 +94,5 @@ TEST_F(DataArrayTest, TestSentenceData) { } } -} // namespace +} // namespace +} // namespace extractor diff --git a/extractor/fast_intersector.cc b/extractor/fast_intersector.cc index 8c7a7af8..cec3d30b 100644 --- a/extractor/fast_intersector.cc +++ b/extractor/fast_intersector.cc @@ -9,6 +9,8 @@ #include "suffix_array.h" #include "vocabulary.h" +namespace extractor { + FastIntersector::FastIntersector(shared_ptr suffix_array, shared_ptr precomputation, shared_ptr vocabulary, @@ -189,3 +191,5 @@ pair FastIntersector::GetSearchRange(bool has_marginal_x) const { return make_pair(1, 2); } } + +} // namespace extractor diff --git a/extractor/fast_intersector.h b/extractor/fast_intersector.h index 785e428e..32c88a30 100644 --- a/extractor/fast_intersector.h +++ b/extractor/fast_intersector.h @@ -9,6 +9,8 @@ using namespace std; +namespace extractor { + typedef boost::hash > VectorHash; typedef unordered_map, vector, VectorHash> Index; @@ -62,4 +64,6 @@ class FastIntersector { Index collocations; }; +} // namespace extractor + #endif diff --git a/extractor/fast_intersector_test.cc b/extractor/fast_intersector_test.cc index 0d6ef367..76c3aaea 100644 --- a/extractor/fast_intersector_test.cc +++ b/extractor/fast_intersector_test.cc @@ -4,8 +4,8 @@ #include "fast_intersector.h" #include "mocks/mock_data_array.h" -#include "mocks/mock_suffix_array.h" #include "mocks/mock_precomputation.h" +#include "mocks/mock_suffix_array.h" #include "mocks/mock_vocabulary.h" #include "phrase.h" #include "phrase_location.h" @@ -14,6 +14,7 @@ using namespace std; using namespace ::testing; +namespace extractor { namespace { class FastIntersectorTest : public Test { @@ -112,7 +113,6 @@ TEST_F(FastIntersectorTest, TestIntersectaXbXcExtendSuffix) { EXPECT_EQ(PhraseLocation(expected_locs, 3), result); } -/* TEST_F(FastIntersectorTest, TestIntersectaXbExtendPrefix) { vector symbols = {1, -1, 3}; Phrase phrase = phrase_builder->Build(symbols); @@ -141,6 +141,6 @@ TEST_F(FastIntersectorTest, TestIntersectCheckEstimates) { EXPECT_EQ(PhraseLocation(expected_locs, 2), result); EXPECT_EQ(PhraseLocation(10, 12), suffix_location); } -*/ -} // namespace +} // namespace +} // namespace extractor diff --git a/extractor/features/count_source_target.cc b/extractor/features/count_source_target.cc index 9441b451..db0385e0 100644 --- a/extractor/features/count_source_target.cc +++ b/extractor/features/count_source_target.cc @@ -2,6 +2,9 @@ #include +namespace extractor { +namespace features { + double CountSourceTarget::Score(const FeatureContext& context) const { return log10(1 + context.pair_count); } @@ -9,3 +12,6 @@ double CountSourceTarget::Score(const FeatureContext& context) const { string CountSourceTarget::GetName() const { return "CountEF"; } + +} // namespace features +} // namespace extractor diff --git a/extractor/features/count_source_target.h b/extractor/features/count_source_target.h index a2481944..dec78883 100644 --- a/extractor/features/count_source_target.h +++ b/extractor/features/count_source_target.h @@ -3,6 +3,9 @@ #include "feature.h" +namespace extractor { +namespace features { + class CountSourceTarget : public Feature { public: double Score(const FeatureContext& context) const; @@ -10,4 +13,7 @@ class CountSourceTarget : public Feature { string GetName() const; }; +} // namespace features +} // namespace extractor + #endif diff --git a/extractor/features/count_source_target_test.cc b/extractor/features/count_source_target_test.cc index 22633bb6..1fd0c2aa 100644 --- a/extractor/features/count_source_target_test.cc +++ b/extractor/features/count_source_target_test.cc @@ -8,6 +8,8 @@ using namespace std; using namespace ::testing; +namespace extractor { +namespace features { namespace { class CountSourceTargetTest : public Test { @@ -29,4 +31,6 @@ TEST_F(CountSourceTargetTest, TestScore) { EXPECT_EQ(1.0, feature->Score(context)); } -} // namespace +} // namespace +} // namespace features +} // namespace extractor diff --git a/extractor/features/feature.cc b/extractor/features/feature.cc index 876f5f8f..939bcc59 100644 --- a/extractor/features/feature.cc +++ b/extractor/features/feature.cc @@ -1,5 +1,11 @@ #include "feature.h" +namespace extractor { +namespace features { + const double Feature::MAX_SCORE = 99.0; Feature::~Feature() {} + +} // namespace features +} // namespace extractor diff --git a/extractor/features/feature.h b/extractor/features/feature.h index aca58401..de2827bc 100644 --- a/extractor/features/feature.h +++ b/extractor/features/feature.h @@ -8,6 +8,9 @@ using namespace std; +namespace extractor { +namespace features { + struct FeatureContext { FeatureContext(const Phrase& source_phrase, const Phrase& target_phrase, double source_phrase_count, int pair_count, int num_samples) : @@ -33,4 +36,7 @@ class Feature { static const double MAX_SCORE; }; +} // namespace features +} // namespace extractor + #endif diff --git a/extractor/features/is_source_singleton.cc b/extractor/features/is_source_singleton.cc index 98d4e5fe..ab54e51a 100644 --- a/extractor/features/is_source_singleton.cc +++ b/extractor/features/is_source_singleton.cc @@ -2,6 +2,9 @@ #include +namespace extractor { +namespace features { + double IsSourceSingleton::Score(const FeatureContext& context) const { return context.source_phrase_count == 1; } @@ -9,3 +12,6 @@ double IsSourceSingleton::Score(const FeatureContext& context) const { string IsSourceSingleton::GetName() const { return "IsSingletonF"; } + +} // namespace features +} // namespace extractor diff --git a/extractor/features/is_source_singleton.h b/extractor/features/is_source_singleton.h index 7cc72828..30f76c6d 100644 --- a/extractor/features/is_source_singleton.h +++ b/extractor/features/is_source_singleton.h @@ -3,6 +3,9 @@ #include "feature.h" +namespace extractor { +namespace features { + class IsSourceSingleton : public Feature { public: double Score(const FeatureContext& context) const; @@ -10,4 +13,7 @@ class IsSourceSingleton : public Feature { string GetName() const; }; +} // namespace features +} // namespace extractor + #endif diff --git a/extractor/features/is_source_singleton_test.cc b/extractor/features/is_source_singleton_test.cc index 8c71e593..f4266671 100644 --- a/extractor/features/is_source_singleton_test.cc +++ b/extractor/features/is_source_singleton_test.cc @@ -8,6 +8,8 @@ using namespace std; using namespace ::testing; +namespace extractor { +namespace features { namespace { class IsSourceSingletonTest : public Test { @@ -32,4 +34,6 @@ TEST_F(IsSourceSingletonTest, TestScore) { EXPECT_EQ(1, feature->Score(context)); } -} // namespace +} // namespace +} // namespace features +} // namespace extractor diff --git a/extractor/features/is_source_target_singleton.cc b/extractor/features/is_source_target_singleton.cc index 31d36532..03b3c62c 100644 --- a/extractor/features/is_source_target_singleton.cc +++ b/extractor/features/is_source_target_singleton.cc @@ -2,6 +2,9 @@ #include +namespace extractor { +namespace features { + double IsSourceTargetSingleton::Score(const FeatureContext& context) const { return context.pair_count == 1; } @@ -9,3 +12,6 @@ double IsSourceTargetSingleton::Score(const FeatureContext& context) const { string IsSourceTargetSingleton::GetName() const { return "IsSingletonFE"; } + +} // namespace features +} // namespace extractor diff --git a/extractor/features/is_source_target_singleton.h b/extractor/features/is_source_target_singleton.h index 58913b74..12fb6ee6 100644 --- a/extractor/features/is_source_target_singleton.h +++ b/extractor/features/is_source_target_singleton.h @@ -3,6 +3,9 @@ #include "feature.h" +namespace extractor { +namespace features { + class IsSourceTargetSingleton : public Feature { public: double Score(const FeatureContext& context) const; @@ -10,4 +13,7 @@ class IsSourceTargetSingleton : public Feature { string GetName() const; }; +} // namespace features +} // namespace extractor + #endif diff --git a/extractor/features/is_source_target_singleton_test.cc b/extractor/features/is_source_target_singleton_test.cc index a51f77c9..929635b0 100644 --- a/extractor/features/is_source_target_singleton_test.cc +++ b/extractor/features/is_source_target_singleton_test.cc @@ -8,6 +8,8 @@ using namespace std; using namespace ::testing; +namespace extractor { +namespace features { namespace { class IsSourceTargetSingletonTest : public Test { @@ -32,4 +34,6 @@ TEST_F(IsSourceTargetSingletonTest, TestScore) { EXPECT_EQ(1, feature->Score(context)); } -} // namespace +} // namespace +} // namespace features +} // namespace extractor diff --git a/extractor/features/max_lex_source_given_target.cc b/extractor/features/max_lex_source_given_target.cc index 21f5c76a..3ffe598c 100644 --- a/extractor/features/max_lex_source_given_target.cc +++ b/extractor/features/max_lex_source_given_target.cc @@ -5,6 +5,9 @@ #include "../data_array.h" #include "../translation_table.h" +namespace extractor { +namespace features { + MaxLexSourceGivenTarget::MaxLexSourceGivenTarget( shared_ptr table) : table(table) {} @@ -29,3 +32,6 @@ double MaxLexSourceGivenTarget::Score(const FeatureContext& context) const { string MaxLexSourceGivenTarget::GetName() const { return "MaxLexFgivenE"; } + +} // namespace features +} // namespace extractor diff --git a/extractor/features/max_lex_source_given_target.h b/extractor/features/max_lex_source_given_target.h index e87c1c8e..bfa7ef1b 100644 --- a/extractor/features/max_lex_source_given_target.h +++ b/extractor/features/max_lex_source_given_target.h @@ -7,8 +7,12 @@ using namespace std; +namespace extractor { + class TranslationTable; +namespace features { + class MaxLexSourceGivenTarget : public Feature { public: MaxLexSourceGivenTarget(shared_ptr table); @@ -21,4 +25,7 @@ class MaxLexSourceGivenTarget : public Feature { shared_ptr table; }; +} // namespace features +} // namespace extractor + #endif diff --git a/extractor/features/max_lex_source_given_target_test.cc b/extractor/features/max_lex_source_given_target_test.cc index 5fd41f8b..c1edb483 100644 --- a/extractor/features/max_lex_source_given_target_test.cc +++ b/extractor/features/max_lex_source_given_target_test.cc @@ -13,6 +13,8 @@ using namespace std; using namespace ::testing; +namespace extractor { +namespace features { namespace { class MaxLexSourceGivenTargetTest : public Test { @@ -71,4 +73,6 @@ TEST_F(MaxLexSourceGivenTargetTest, TestScore) { EXPECT_EQ(99 - log10(18), feature->Score(context)); } -} // namespace +} // namespace +} // namespace features +} // namespace extractor diff --git a/extractor/features/max_lex_target_given_source.cc b/extractor/features/max_lex_target_given_source.cc index f2bc2474..30140d80 100644 --- a/extractor/features/max_lex_target_given_source.cc +++ b/extractor/features/max_lex_target_given_source.cc @@ -5,6 +5,9 @@ #include "../data_array.h" #include "../translation_table.h" +namespace extractor { +namespace features { + MaxLexTargetGivenSource::MaxLexTargetGivenSource( shared_ptr table) : table(table) {} @@ -29,3 +32,6 @@ double MaxLexTargetGivenSource::Score(const FeatureContext& context) const { string MaxLexTargetGivenSource::GetName() const { return "MaxLexEgivenF"; } + +} // namespace features +} // namespace extractor diff --git a/extractor/features/max_lex_target_given_source.h b/extractor/features/max_lex_target_given_source.h index 9585ff04..66cf0914 100644 --- a/extractor/features/max_lex_target_given_source.h +++ b/extractor/features/max_lex_target_given_source.h @@ -7,8 +7,12 @@ using namespace std; +namespace extractor { + class TranslationTable; +namespace features { + class MaxLexTargetGivenSource : public Feature { public: MaxLexTargetGivenSource(shared_ptr table); @@ -21,4 +25,7 @@ class MaxLexTargetGivenSource : public Feature { shared_ptr table; }; +} // namespace features +} // namespace extractor + #endif diff --git a/extractor/features/max_lex_target_given_source_test.cc b/extractor/features/max_lex_target_given_source_test.cc index c8701bf7..9ceb13e5 100644 --- a/extractor/features/max_lex_target_given_source_test.cc +++ b/extractor/features/max_lex_target_given_source_test.cc @@ -13,6 +13,8 @@ using namespace std; using namespace ::testing; +namespace extractor { +namespace features { namespace { class MaxLexTargetGivenSourceTest : public Test { @@ -71,4 +73,6 @@ TEST_F(MaxLexTargetGivenSourceTest, TestScore) { EXPECT_EQ(-log10(36), feature->Score(context)); } -} // namespace +} // namespace +} // namespace features +} // namespace extractor diff --git a/extractor/features/sample_source_count.cc b/extractor/features/sample_source_count.cc index 88b645b1..b110fc51 100644 --- a/extractor/features/sample_source_count.cc +++ b/extractor/features/sample_source_count.cc @@ -2,6 +2,9 @@ #include +namespace extractor { +namespace features { + double SampleSourceCount::Score(const FeatureContext& context) const { return log10(1 + context.num_samples); } @@ -9,3 +12,6 @@ double SampleSourceCount::Score(const FeatureContext& context) const { string SampleSourceCount::GetName() const { return "SampleCountF"; } + +} // namespace features +} // namespace extractor diff --git a/extractor/features/sample_source_count.h b/extractor/features/sample_source_count.h index 62d236c8..53c7f954 100644 --- a/extractor/features/sample_source_count.h +++ b/extractor/features/sample_source_count.h @@ -3,6 +3,9 @@ #include "feature.h" +namespace extractor { +namespace features { + class SampleSourceCount : public Feature { public: double Score(const FeatureContext& context) const; @@ -10,4 +13,7 @@ class SampleSourceCount : public Feature { string GetName() const; }; +} // namespace features +} // namespace extractor + #endif diff --git a/extractor/features/sample_source_count_test.cc b/extractor/features/sample_source_count_test.cc index 7d226104..63856b9d 100644 --- a/extractor/features/sample_source_count_test.cc +++ b/extractor/features/sample_source_count_test.cc @@ -9,6 +9,8 @@ using namespace std; using namespace ::testing; +namespace extractor { +namespace features { namespace { class SampleSourceCountTest : public Test { @@ -33,4 +35,6 @@ TEST_F(SampleSourceCountTest, TestScore) { EXPECT_EQ(1.0, feature->Score(context)); } -} // namespace +} // namespace +} // namespace features +} // namespace extractor diff --git a/extractor/features/target_given_source_coherent.cc b/extractor/features/target_given_source_coherent.cc index 274b3364..c4551d88 100644 --- a/extractor/features/target_given_source_coherent.cc +++ b/extractor/features/target_given_source_coherent.cc @@ -2,6 +2,9 @@ #include +namespace extractor { +namespace features { + double TargetGivenSourceCoherent::Score(const FeatureContext& context) const { double prob = (double) context.pair_count / context.num_samples; return prob > 0 ? -log10(prob) : MAX_SCORE; @@ -10,3 +13,6 @@ double TargetGivenSourceCoherent::Score(const FeatureContext& context) const { string TargetGivenSourceCoherent::GetName() const { return "EgivenFCoherent"; } + +} // namespace features +} // namespace extractor diff --git a/extractor/features/target_given_source_coherent.h b/extractor/features/target_given_source_coherent.h index 09c8edb1..80d9f617 100644 --- a/extractor/features/target_given_source_coherent.h +++ b/extractor/features/target_given_source_coherent.h @@ -3,6 +3,9 @@ #include "feature.h" +namespace extractor { +namespace features { + class TargetGivenSourceCoherent : public Feature { public: double Score(const FeatureContext& context) const; @@ -10,4 +13,7 @@ class TargetGivenSourceCoherent : public Feature { string GetName() const; }; +} // namespace features +} // namespace extractor + #endif diff --git a/extractor/features/target_given_source_coherent_test.cc b/extractor/features/target_given_source_coherent_test.cc index c54c06c2..454105e1 100644 --- a/extractor/features/target_given_source_coherent_test.cc +++ b/extractor/features/target_given_source_coherent_test.cc @@ -8,6 +8,8 @@ using namespace std; using namespace ::testing; +namespace extractor { +namespace features { namespace { class TargetGivenSourceCoherentTest : public Test { @@ -32,4 +34,6 @@ TEST_F(TargetGivenSourceCoherentTest, TestScore) { EXPECT_EQ(99.0, feature->Score(context)); } -} // namespace +} // namespace +} // namespace features +} // namespace extractor diff --git a/extractor/grammar.cc b/extractor/grammar.cc index 8124a804..8e5bcd45 100644 --- a/extractor/grammar.cc +++ b/extractor/grammar.cc @@ -6,6 +6,8 @@ using namespace std; +namespace extractor { + Grammar::Grammar(const vector& rules, const vector& feature_names) : rules(rules), feature_names(feature_names) {} @@ -37,3 +39,5 @@ ostream& operator<<(ostream& os, const Grammar& grammar) { return os; } + +} // namespace extractor diff --git a/extractor/grammar.h b/extractor/grammar.h index 889cc2f3..a424d65a 100644 --- a/extractor/grammar.h +++ b/extractor/grammar.h @@ -7,6 +7,8 @@ using namespace std; +namespace extractor { + class Rule; class Grammar { @@ -24,4 +26,6 @@ class Grammar { vector feature_names; }; +} // namespace extractor + #endif diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc index b8f6f0c7..8050ce7b 100644 --- a/extractor/grammar_extractor.cc +++ b/extractor/grammar_extractor.cc @@ -6,10 +6,13 @@ #include "grammar.h" #include "rule.h" +#include "rule_factory.h" #include "vocabulary.h" using namespace std; +namespace extractor { + GrammarExtractor::GrammarExtractor( shared_ptr source_suffix_array, shared_ptr target_data_array, @@ -55,3 +58,5 @@ vector GrammarExtractor::AnnotateWords(const vector& words) { } return result; } + +} // namespace extractor diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h index f50a8d14..6b1dcf98 100644 --- a/extractor/grammar_extractor.h +++ b/extractor/grammar_extractor.h @@ -1,18 +1,21 @@ #ifndef _GRAMMAR_EXTRACTOR_H_ #define _GRAMMAR_EXTRACTOR_H_ +#include #include #include -#include "rule_factory.h" - using namespace std; +namespace extractor { + class Alignment; class DataArray; class Grammar; +class HieroCachingRuleFactory; class Precomputation; class Rule; +class Scorer; class SuffixArray; class Vocabulary; @@ -46,4 +49,6 @@ class GrammarExtractor { shared_ptr rule_factory; }; +} // namespace extractor + #endif diff --git a/extractor/grammar_extractor_test.cc b/extractor/grammar_extractor_test.cc index d4ed7d4f..823bb8b4 100644 --- a/extractor/grammar_extractor_test.cc +++ b/extractor/grammar_extractor_test.cc @@ -13,6 +13,7 @@ using namespace std; using namespace ::testing; +namespace extractor { namespace { TEST(GrammarExtractorTest, TestAnnotatingWords) { @@ -46,4 +47,5 @@ TEST(GrammarExtractorTest, TestAnnotatingWords) { extractor.GetGrammar(sentence); } -} // namespace +} // namespace +} // namespace extractor diff --git a/extractor/matchings_finder.cc b/extractor/matchings_finder.cc index eaf493b2..ceed6891 100644 --- a/extractor/matchings_finder.cc +++ b/extractor/matchings_finder.cc @@ -3,6 +3,8 @@ #include "suffix_array.h" #include "phrase_location.h" +namespace extractor { + MatchingsFinder::MatchingsFinder(shared_ptr suffix_array) : suffix_array(suffix_array) {} @@ -19,3 +21,5 @@ PhraseLocation MatchingsFinder::Find(PhraseLocation& location, return suffix_array->Lookup(location.sa_low, location.sa_high, word, offset); } + +} // namespace extractor diff --git a/extractor/matchings_finder.h b/extractor/matchings_finder.h index ed04d8b8..fbb504ef 100644 --- a/extractor/matchings_finder.h +++ b/extractor/matchings_finder.h @@ -6,6 +6,8 @@ using namespace std; +namespace extractor { + class PhraseLocation; class SuffixArray; @@ -25,4 +27,6 @@ class MatchingsFinder { shared_ptr suffix_array; }; +} // namespace extractor + #endif diff --git a/extractor/matchings_finder_test.cc b/extractor/matchings_finder_test.cc index 817f1635..d40e5191 100644 --- a/extractor/matchings_finder_test.cc +++ b/extractor/matchings_finder_test.cc @@ -9,6 +9,7 @@ using namespace std; using namespace ::testing; +namespace extractor { namespace { class MatchingsFinderTest : public Test { @@ -39,4 +40,5 @@ TEST_F(MatchingsFinderTest, ResizeUnsetRange) { EXPECT_EQ(PhraseLocation(0, 10), phrase_location); } -} // namespace +} // namespace +} // namespace extractor diff --git a/extractor/matchings_trie.cc b/extractor/matchings_trie.cc index 8ea795db..c7b98765 100644 --- a/extractor/matchings_trie.cc +++ b/extractor/matchings_trie.cc @@ -1,5 +1,7 @@ #include "matchings_trie.h" +namespace extractor { + void MatchingsTrie::Reset() { ResetTree(root); root = make_shared(); @@ -20,3 +22,5 @@ void MatchingsTrie::ResetTree(shared_ptr root) { root.reset(); } } + +} // namespace extractor diff --git a/extractor/matchings_trie.h b/extractor/matchings_trie.h index 6e72b2db..a54671d2 100644 --- a/extractor/matchings_trie.h +++ b/extractor/matchings_trie.h @@ -9,6 +9,8 @@ using namespace std; +namespace extractor { + struct TrieNode { TrieNode(shared_ptr suffix_link = shared_ptr(), Phrase phrase = Phrase(), @@ -44,4 +46,6 @@ class MatchingsTrie { shared_ptr root; }; +} // namespace extractor + #endif diff --git a/extractor/mocks/mock_alignment.h b/extractor/mocks/mock_alignment.h index 4a5077ad..3d745e9d 100644 --- a/extractor/mocks/mock_alignment.h +++ b/extractor/mocks/mock_alignment.h @@ -2,9 +2,13 @@ #include "../alignment.h" +namespace extractor { + typedef vector > SentenceLinks; class MockAlignment : public Alignment { public: MOCK_CONST_METHOD1(GetLinks, SentenceLinks(int sentence_id)); }; + +} // namespace extractor diff --git a/extractor/mocks/mock_data_array.h b/extractor/mocks/mock_data_array.h index 004e8906..cf9f3671 100644 --- a/extractor/mocks/mock_data_array.h +++ b/extractor/mocks/mock_data_array.h @@ -2,6 +2,8 @@ #include "../data_array.h" +namespace extractor { + class MockDataArray : public DataArray { public: MOCK_CONST_METHOD0(GetData, const vector&()); @@ -17,3 +19,5 @@ class MockDataArray : public DataArray { MOCK_CONST_METHOD1(GetSentenceStart, int(int sentence_id)); MOCK_CONST_METHOD1(GetSentenceId, int(int position)); }; + +} // namespace extractor diff --git a/extractor/mocks/mock_fast_intersector.h b/extractor/mocks/mock_fast_intersector.h index 201386f2..665add65 100644 --- a/extractor/mocks/mock_fast_intersector.h +++ b/extractor/mocks/mock_fast_intersector.h @@ -4,8 +4,12 @@ #include "../phrase.h" #include "../phrase_location.h" +namespace extractor { + class MockFastIntersector : public FastIntersector { public: MOCK_METHOD3(Intersect, PhraseLocation(PhraseLocation&, PhraseLocation&, const Phrase&)); }; + +} // namespace extractor diff --git a/extractor/mocks/mock_feature.h b/extractor/mocks/mock_feature.h index d2137629..19ba4de9 100644 --- a/extractor/mocks/mock_feature.h +++ b/extractor/mocks/mock_feature.h @@ -2,8 +2,14 @@ #include "../features/feature.h" +namespace extractor { +namespace features { + class MockFeature : public Feature { public: MOCK_CONST_METHOD1(Score, double(const FeatureContext& context)); MOCK_CONST_METHOD0(GetName, string()); }; + +} // namespace features +} // namespace extractor diff --git a/extractor/mocks/mock_matchings_finder.h b/extractor/mocks/mock_matchings_finder.h index 3e80d266..ffbb06c7 100644 --- a/extractor/mocks/mock_matchings_finder.h +++ b/extractor/mocks/mock_matchings_finder.h @@ -3,7 +3,11 @@ #include "../matchings_finder.h" #include "../phrase_location.h" +namespace extractor { + class MockMatchingsFinder : public MatchingsFinder { public: MOCK_METHOD3(Find, PhraseLocation(PhraseLocation&, const string&, int)); }; + +} // namespace extractor diff --git a/extractor/mocks/mock_precomputation.h b/extractor/mocks/mock_precomputation.h index 9bc72235..64934b94 100644 --- a/extractor/mocks/mock_precomputation.h +++ b/extractor/mocks/mock_precomputation.h @@ -2,7 +2,11 @@ #include "../precomputation.h" +namespace extractor { + class MockPrecomputation : public Precomputation { public: MOCK_CONST_METHOD0(GetCollocations, const Index&()); }; + +} // namespace extractor diff --git a/extractor/mocks/mock_rule_extractor.h b/extractor/mocks/mock_rule_extractor.h index f18e009a..28b644b0 100644 --- a/extractor/mocks/mock_rule_extractor.h +++ b/extractor/mocks/mock_rule_extractor.h @@ -5,8 +5,12 @@ #include "../rule.h" #include "../rule_extractor.h" +namespace extractor { + class MockRuleExtractor : public RuleExtractor { public: MOCK_CONST_METHOD2(ExtractRules, vector(const Phrase&, const PhraseLocation&)); }; + +} // namespace extractor diff --git a/extractor/mocks/mock_rule_extractor_helper.h b/extractor/mocks/mock_rule_extractor_helper.h index 63ff1048..3b0ac0f5 100644 --- a/extractor/mocks/mock_rule_extractor_helper.h +++ b/extractor/mocks/mock_rule_extractor_helper.h @@ -6,6 +6,8 @@ using namespace std; +namespace extractor { + typedef unordered_map Indexes; class MockRuleExtractorHelper : public RuleExtractorHelper { @@ -76,3 +78,5 @@ class MockRuleExtractorHelper : public RuleExtractorHelper { bool met_constraints; bool get_gaps; }; + +} // namespace extractor diff --git a/extractor/mocks/mock_rule_factory.h b/extractor/mocks/mock_rule_factory.h index 2a96be93..11cb9ab5 100644 --- a/extractor/mocks/mock_rule_factory.h +++ b/extractor/mocks/mock_rule_factory.h @@ -3,7 +3,11 @@ #include "../grammar.h" #include "../rule_factory.h" +namespace extractor { + class MockHieroCachingRuleFactory : public HieroCachingRuleFactory { public: MOCK_METHOD1(GetGrammar, Grammar(const vector& word_ids)); }; + +} // namespace extractor diff --git a/extractor/mocks/mock_sampler.h b/extractor/mocks/mock_sampler.h index b2306109..7022f7b3 100644 --- a/extractor/mocks/mock_sampler.h +++ b/extractor/mocks/mock_sampler.h @@ -3,7 +3,11 @@ #include "../phrase_location.h" #include "../sampler.h" +namespace extractor { + class MockSampler : public Sampler { public: MOCK_CONST_METHOD1(Sample, PhraseLocation(const PhraseLocation& location)); }; + +} // namespace extractor diff --git a/extractor/mocks/mock_scorer.h b/extractor/mocks/mock_scorer.h index 48115ef4..4d593ddf 100644 --- a/extractor/mocks/mock_scorer.h +++ b/extractor/mocks/mock_scorer.h @@ -3,8 +3,13 @@ #include "../scorer.h" #include "../features/feature.h" +namespace extractor { + class MockScorer : public Scorer { public: - MOCK_CONST_METHOD1(Score, vector(const FeatureContext& context)); + MOCK_CONST_METHOD1(Score, vector( + const features::FeatureContext& context)); MOCK_CONST_METHOD0(GetFeatureNames, vector()); }; + +} // namespace extractor diff --git a/extractor/mocks/mock_suffix_array.h b/extractor/mocks/mock_suffix_array.h index 11a3a443..6886232a 100644 --- a/extractor/mocks/mock_suffix_array.h +++ b/extractor/mocks/mock_suffix_array.h @@ -9,6 +9,8 @@ using namespace std; +namespace extractor { + class MockSuffixArray : public SuffixArray { public: MOCK_CONST_METHOD0(GetSize, int()); @@ -17,3 +19,5 @@ class MockSuffixArray : public SuffixArray { MOCK_CONST_METHOD1(GetSuffix, int(int)); MOCK_CONST_METHOD4(Lookup, PhraseLocation(int, int, const string& word, int)); }; + +} // namespace extractor diff --git a/extractor/mocks/mock_target_phrase_extractor.h b/extractor/mocks/mock_target_phrase_extractor.h index 6dc6bba6..e5e9aeab 100644 --- a/extractor/mocks/mock_target_phrase_extractor.h +++ b/extractor/mocks/mock_target_phrase_extractor.h @@ -2,6 +2,8 @@ #include "../target_phrase_extractor.h" +namespace extractor { + typedef pair PhraseExtract; class MockTargetPhraseExtractor : public TargetPhraseExtractor { @@ -10,3 +12,5 @@ class MockTargetPhraseExtractor : public TargetPhraseExtractor { const vector > &, const vector&, int, int, const unordered_map&, int)); }; + +} // namespace extractor diff --git a/extractor/mocks/mock_translation_table.h b/extractor/mocks/mock_translation_table.h index a35c9327..358c854f 100644 --- a/extractor/mocks/mock_translation_table.h +++ b/extractor/mocks/mock_translation_table.h @@ -2,8 +2,12 @@ #include "../translation_table.h" +namespace extractor { + class MockTranslationTable : public TranslationTable { public: MOCK_METHOD2(GetSourceGivenTargetScore, double(const string&, const string&)); MOCK_METHOD2(GetTargetGivenSourceScore, double(const string&, const string&)); }; + +} // namespace extractor diff --git a/extractor/mocks/mock_vocabulary.h b/extractor/mocks/mock_vocabulary.h index e5c191f5..802c29b4 100644 --- a/extractor/mocks/mock_vocabulary.h +++ b/extractor/mocks/mock_vocabulary.h @@ -2,8 +2,12 @@ #include "../vocabulary.h" +namespace extractor { + class MockVocabulary : public Vocabulary { public: MOCK_METHOD1(GetTerminalValue, string(int word_id)); MOCK_METHOD1(GetTerminalIndex, int(const string& word)); }; + +} // namespace extractor diff --git a/extractor/phrase.cc b/extractor/phrase.cc index 6dc242db..244fab07 100644 --- a/extractor/phrase.cc +++ b/extractor/phrase.cc @@ -1,5 +1,7 @@ #include "phrase.h" +namespace extractor { + int Phrase::Arity() const { return var_pos.size(); } @@ -52,3 +54,5 @@ ostream& operator<<(ostream& os, const Phrase& phrase) { } return os; } + +} // namspace extractor diff --git a/extractor/phrase.h b/extractor/phrase.h index f40a8169..8c98a025 100644 --- a/extractor/phrase.h +++ b/extractor/phrase.h @@ -9,6 +9,8 @@ using namespace std; +namespace extractor { + class Phrase { public: friend Phrase PhraseBuilder::Build(const vector& phrase); @@ -38,4 +40,6 @@ class Phrase { vector words; }; +} // namespace extractor + #endif diff --git a/extractor/phrase_builder.cc b/extractor/phrase_builder.cc index 4325390c..9faee4be 100644 --- a/extractor/phrase_builder.cc +++ b/extractor/phrase_builder.cc @@ -3,6 +3,8 @@ #include "phrase.h" #include "vocabulary.h" +namespace extractor { + PhraseBuilder::PhraseBuilder(shared_ptr vocabulary) : vocabulary(vocabulary) {} @@ -42,3 +44,5 @@ Phrase PhraseBuilder::Extend(const Phrase& phrase, bool start_x, bool end_x) { return Build(symbols); } + +} // namespace extractor diff --git a/extractor/phrase_builder.h b/extractor/phrase_builder.h index a49af457..2956fd35 100644 --- a/extractor/phrase_builder.h +++ b/extractor/phrase_builder.h @@ -6,6 +6,8 @@ using namespace std; +namespace extractor { + class Phrase; class Vocabulary; @@ -21,4 +23,6 @@ class PhraseBuilder { shared_ptr vocabulary; }; +} // namespace extractor + #endif diff --git a/extractor/phrase_location.cc b/extractor/phrase_location.cc index b0bfed80..678ae270 100644 --- a/extractor/phrase_location.cc +++ b/extractor/phrase_location.cc @@ -1,5 +1,7 @@ #include "phrase_location.h" +namespace extractor { + PhraseLocation::PhraseLocation(int sa_low, int sa_high) : sa_low(sa_low), sa_high(sa_high), num_subpatterns(0) {} @@ -37,3 +39,5 @@ bool operator==(const PhraseLocation& a, const PhraseLocation& b) { return *a.matchings == *b.matchings; } + +} // namespace extractor diff --git a/extractor/phrase_location.h b/extractor/phrase_location.h index a0eb36c8..e5f3cf08 100644 --- a/extractor/phrase_location.h +++ b/extractor/phrase_location.h @@ -6,6 +6,8 @@ using namespace std; +namespace extractor { + struct PhraseLocation { PhraseLocation(int sa_low = -1, int sa_high = -1); @@ -22,4 +24,6 @@ struct PhraseLocation { int num_subpatterns; }; +} // namespace extractor + #endif diff --git a/extractor/phrase_test.cc b/extractor/phrase_test.cc index 2b553b6f..c8176178 100644 --- a/extractor/phrase_test.cc +++ b/extractor/phrase_test.cc @@ -10,6 +10,7 @@ using namespace std; using namespace ::testing; +namespace extractor { namespace { class PhraseTest : public Test { @@ -58,4 +59,5 @@ TEST_F(PhraseTest, TestGetSymbol) { } } -} // namespace +} // namespace +} // namespace extractor diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc index 189ac42c..8cc32ffd 100644 --- a/extractor/precomputation.cc +++ b/extractor/precomputation.cc @@ -8,6 +8,8 @@ using namespace std; +namespace extractor { + int Precomputation::NON_TERMINAL = -1; Precomputation::Precomputation( @@ -173,3 +175,5 @@ void Precomputation::WriteBinary(const fs::path& filepath) const { const Index& Precomputation::GetCollocations() const { return collocations; } + +} // namespace extractor diff --git a/extractor/precomputation.h b/extractor/precomputation.h index 3d44c2a6..dbd99c14 100644 --- a/extractor/precomputation.h +++ b/extractor/precomputation.h @@ -13,11 +13,13 @@ namespace fs = boost::filesystem; using namespace std; -class SuffixArray; +namespace extractor { typedef boost::hash > VectorHash; typedef unordered_map, vector, VectorHash> Index; +class SuffixArray; + class Precomputation { public: Precomputation( @@ -51,4 +53,6 @@ class Precomputation { Index collocations; }; +} // namespace extractor + #endif diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc index 6b77b9c0..04e3850d 100644 --- a/extractor/precomputation_test.cc +++ b/extractor/precomputation_test.cc @@ -10,6 +10,7 @@ using namespace std; using namespace ::testing; +namespace extractor { namespace { class PrecomputationTest : public Test { @@ -101,4 +102,6 @@ TEST_F(PrecomputationTest, TestCollocations) { EXPECT_EQ(0, collocations.count(key)); } -} // namespace +} // namespace +} // namespace extractor + diff --git a/extractor/rule.cc b/extractor/rule.cc index 9c7ac9b5..b6c7d783 100644 --- a/extractor/rule.cc +++ b/extractor/rule.cc @@ -1,5 +1,7 @@ #include "rule.h" +namespace extractor { + Rule::Rule(const Phrase& source_phrase, const Phrase& target_phrase, const vector& scores, @@ -8,3 +10,5 @@ Rule::Rule(const Phrase& source_phrase, target_phrase(target_phrase), scores(scores), alignment(alignment) {} + +} // namespace extractor diff --git a/extractor/rule.h b/extractor/rule.h index 64ff8794..b4d45fc1 100644 --- a/extractor/rule.h +++ b/extractor/rule.h @@ -7,6 +7,8 @@ using namespace std; +namespace extractor { + struct Rule { Rule(const Phrase& source_phrase, const Phrase& target_phrase, const vector& scores, const vector >& alignment); @@ -17,4 +19,6 @@ struct Rule { vector > alignment; }; +} // namespace extractor + #endif diff --git a/extractor/rule_extractor.cc b/extractor/rule_extractor.cc index 92343241..b9286472 100644 --- a/extractor/rule_extractor.cc +++ b/extractor/rule_extractor.cc @@ -14,6 +14,8 @@ using namespace std; +namespace extractor { + RuleExtractor::RuleExtractor( shared_ptr source_data_array, shared_ptr target_data_array, @@ -106,7 +108,7 @@ vector RuleExtractor::ExtractRules(const Phrase& phrase, } } - FeatureContext context(source_phrase, target_phrase, + features::FeatureContext context(source_phrase, target_phrase, source_phrase_counter[source_phrase], num_locations, num_samples); vector scores = scorer->Score(context); rules.push_back(Rule(source_phrase, target_phrase, scores, @@ -313,3 +315,5 @@ void RuleExtractor::AddNonterminalExtremities( AddExtracts(extracts, new_source_phrase, source_indexes, target_gaps, target_low, target_x_low, target_x_high, sentence_id); } + +} // namespace extractor diff --git a/extractor/rule_extractor.h b/extractor/rule_extractor.h index a087dc6d..8b6daeea 100644 --- a/extractor/rule_extractor.h +++ b/extractor/rule_extractor.h @@ -9,6 +9,10 @@ using namespace std; +namespace extractor { + +typedef vector > PhraseAlignment; + class Alignment; class DataArray; class PhraseBuilder; @@ -18,8 +22,6 @@ class RuleExtractorHelper; class Scorer; class TargetPhraseExtractor; -typedef vector > PhraseAlignment; - struct Extract { Extract(const Phrase& source_phrase, const Phrase& target_phrase, double pairs_count, const PhraseAlignment& alignment) : @@ -101,4 +103,6 @@ class RuleExtractor { bool require_tight_phrases; }; +} // namespace extractor + #endif diff --git a/extractor/rule_extractor_helper.cc b/extractor/rule_extractor_helper.cc index ed6ae3a1..553b56d4 100644 --- a/extractor/rule_extractor_helper.cc +++ b/extractor/rule_extractor_helper.cc @@ -3,6 +3,8 @@ #include "data_array.h" #include "alignment.h" +namespace extractor { + RuleExtractorHelper::RuleExtractorHelper( shared_ptr source_data_array, shared_ptr target_data_array, @@ -354,3 +356,5 @@ unordered_map RuleExtractorHelper::GetSourceIndexes( } return source_indexes; } + +} // namespace extractor diff --git a/extractor/rule_extractor_helper.h b/extractor/rule_extractor_helper.h index 3478bfc8..95274df6 100644 --- a/extractor/rule_extractor_helper.h +++ b/extractor/rule_extractor_helper.h @@ -7,6 +7,8 @@ using namespace std; +namespace extractor { + class Alignment; class DataArray; @@ -79,4 +81,6 @@ class RuleExtractorHelper { bool require_tight_phrases; }; +} // namespace extractor + #endif diff --git a/extractor/rule_extractor_helper_test.cc b/extractor/rule_extractor_helper_test.cc index 29213312..ec0635b1 100644 --- a/extractor/rule_extractor_helper_test.cc +++ b/extractor/rule_extractor_helper_test.cc @@ -9,6 +9,7 @@ using namespace std; using namespace ::testing; +namespace extractor { namespace { class RuleExtractorHelperTest : public Test { @@ -619,4 +620,5 @@ TEST_F(RuleExtractorHelperTest, TestGetGapIntegrityChecksFailed) { met_constraints)); } -} // namespace +} // namespace +} // namespace extractor diff --git a/extractor/rule_extractor_test.cc b/extractor/rule_extractor_test.cc index 0be44d4d..1b543fc9 100644 --- a/extractor/rule_extractor_test.cc +++ b/extractor/rule_extractor_test.cc @@ -17,6 +17,7 @@ using namespace std; using namespace ::testing; +namespace extractor { namespace { class RuleExtractorTest : public Test { @@ -163,4 +164,5 @@ TEST_F(RuleExtractorTest, TestExtractRulesAddExtremities) { EXPECT_EQ(4, rules.size()); } -} // namespace +} // namespace +} // namespace extractor diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index 4908dac0..51f85c30 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -9,6 +9,7 @@ #include "fast_intersector.h" #include "matchings_finder.h" #include "phrase.h" +#include "phrase_builder.h" #include "rule.h" #include "rule_extractor.h" #include "sampler.h" @@ -20,6 +21,8 @@ using namespace std; using namespace chrono; +namespace extractor { + typedef high_resolution_clock Clock; struct State { @@ -282,3 +285,5 @@ vector HieroCachingRuleFactory::ExtendState( return new_states; } + +} // namespace extractor diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h index 2d1c7f5a..0de04e40 100644 --- a/extractor/rule_factory.h +++ b/extractor/rule_factory.h @@ -5,15 +5,17 @@ #include #include "matchings_trie.h" -#include "phrase_builder.h" using namespace std; +namespace extractor { + class Alignment; class DataArray; +class FastIntersector; class Grammar; class MatchingsFinder; -class FastIntersector; +class PhraseBuilder; class Precomputation; class Rule; class RuleExtractor; @@ -92,4 +94,6 @@ class HieroCachingRuleFactory { int max_rule_symbols; }; +} // namespace extractor + #endif diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc index fc709461..2129dfa0 100644 --- a/extractor/rule_factory_test.cc +++ b/extractor/rule_factory_test.cc @@ -18,6 +18,7 @@ using namespace std; using namespace ::testing; +namespace extractor { namespace { class RuleFactoryTest : public Test { @@ -98,4 +99,5 @@ TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) { EXPECT_EQ(28, grammar.GetRules().size()); } -} // namespace +} // namespace +} // namespace extractor diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index 5255737d..c701c8d0 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -1,6 +1,7 @@ #include #include #include +#include #include #include @@ -30,6 +31,8 @@ namespace fs = boost::filesystem; namespace po = boost::program_options; using namespace std; +using namespace extractor; +using namespace features; int main(int argc, char** argv) { // TODO(pauldb): Also take arguments from config file. @@ -146,13 +149,13 @@ int main(int argc, char** argv) { Clock::time_point extraction_start_time = Clock::now(); vector > features = { - make_shared(), - make_shared(), - make_shared(), - make_shared(table), - make_shared(table), - make_shared(), - make_shared() +// make_shared(), +// make_shared(), +// make_shared(), +// make_shared(table), +// make_shared(table), +// make_shared(), +// make_shared() }; shared_ptr scorer = make_shared(features); diff --git a/extractor/sampler.cc b/extractor/sampler.cc index 5067ca8a..d128913f 100644 --- a/extractor/sampler.cc +++ b/extractor/sampler.cc @@ -3,6 +3,8 @@ #include "phrase_location.h" #include "suffix_array.h" +namespace extractor { + Sampler::Sampler(shared_ptr suffix_array, int max_samples) : suffix_array(suffix_array), max_samples(max_samples) {} @@ -39,3 +41,5 @@ int Sampler::Round(double x) const { // TODO(pauldb): Remove EPS. return x + 0.5 + 1e-8; } + +} // namespace extractor diff --git a/extractor/sampler.h b/extractor/sampler.h index 9cf321fb..cda28b10 100644 --- a/extractor/sampler.h +++ b/extractor/sampler.h @@ -5,6 +5,8 @@ using namespace std; +namespace extractor { + class PhraseLocation; class SuffixArray; @@ -26,4 +28,6 @@ class Sampler { int max_samples; }; +} // namespace extractor + #endif diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc index 4f91965b..e9abebfa 100644 --- a/extractor/sampler_test.cc +++ b/extractor/sampler_test.cc @@ -9,6 +9,7 @@ using namespace std; using namespace ::testing; +namespace extractor { namespace { class SamplerTest : public Test { @@ -69,4 +70,5 @@ TEST_F(SamplerTest, TestSubstringsSample) { EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location)); } -} // namespace +} // namespace +} // namespace extractor diff --git a/extractor/scorer.cc b/extractor/scorer.cc index f28b3181..d3ebf1c9 100644 --- a/extractor/scorer.cc +++ b/extractor/scorer.cc @@ -2,14 +2,16 @@ #include "features/feature.h" -Scorer::Scorer(const vector >& features) : +namespace extractor { + +Scorer::Scorer(const vector >& features) : features(features) {} Scorer::Scorer() {} Scorer::~Scorer() {} -vector Scorer::Score(const FeatureContext& context) const { +vector Scorer::Score(const features::FeatureContext& context) const { vector scores; for (auto feature: features) { scores.push_back(feature->Score(context)); @@ -24,3 +26,5 @@ vector Scorer::GetFeatureNames() const { } return feature_names; } + +} // namespace extractor diff --git a/extractor/scorer.h b/extractor/scorer.h index ba71a6ee..c31db0ca 100644 --- a/extractor/scorer.h +++ b/extractor/scorer.h @@ -7,16 +7,20 @@ using namespace std; -class Feature; -class FeatureContext; +namespace extractor { + +namespace features { + class Feature; + class FeatureContext; +} // namespace features class Scorer { public: - Scorer(const vector >& features); + Scorer(const vector >& features); virtual ~Scorer(); - virtual vector Score(const FeatureContext& context) const; + virtual vector Score(const features::FeatureContext& context) const; virtual vector GetFeatureNames() const; @@ -24,7 +28,9 @@ class Scorer { Scorer(); private: - vector > features; + vector > features; }; +} // namespace extractor + #endif diff --git a/extractor/scorer_test.cc b/extractor/scorer_test.cc index 56a85762..3a09c9cc 100644 --- a/extractor/scorer_test.cc +++ b/extractor/scorer_test.cc @@ -10,32 +10,33 @@ using namespace std; using namespace ::testing; +namespace extractor { namespace { class ScorerTest : public Test { protected: virtual void SetUp() { - feature1 = make_shared(); + feature1 = make_shared(); EXPECT_CALL(*feature1, Score(_)).WillRepeatedly(Return(0.5)); EXPECT_CALL(*feature1, GetName()).WillRepeatedly(Return("f1")); - feature2 = make_shared(); + feature2 = make_shared(); EXPECT_CALL(*feature2, Score(_)).WillRepeatedly(Return(-1.3)); EXPECT_CALL(*feature2, GetName()).WillRepeatedly(Return("f2")); - vector > features = {feature1, feature2}; + vector > features = {feature1, feature2}; scorer = make_shared(features); } - shared_ptr feature1; - shared_ptr feature2; + shared_ptr feature1; + shared_ptr feature2; shared_ptr scorer; }; TEST_F(ScorerTest, TestScore) { vector expected_scores = {0.5, -1.3}; Phrase phrase; - FeatureContext context(phrase, phrase, 0.3, 2, 11); + features::FeatureContext context(phrase, phrase, 0.3, 2, 11); EXPECT_EQ(expected_scores, scorer->Score(context)); } @@ -44,4 +45,5 @@ TEST_F(ScorerTest, TestGetNames) { EXPECT_EQ(expected_names, scorer->GetFeatureNames()); } -} // namespace +} // namespace +} // namespace extractor diff --git a/extractor/suffix_array.cc b/extractor/suffix_array.cc index ab8a0913..9988b1a2 100644 --- a/extractor/suffix_array.cc +++ b/extractor/suffix_array.cc @@ -13,6 +13,8 @@ namespace fs = boost::filesystem; using namespace std; using namespace chrono; +namespace extractor { + SuffixArray::SuffixArray(shared_ptr data_array) : data_array(data_array) { BuildSuffixArray(); @@ -227,3 +229,5 @@ int SuffixArray::LookupRangeStart(int low, int high, int word_id, } return result; } + +} // namespace extractor diff --git a/extractor/suffix_array.h b/extractor/suffix_array.h index 79a22694..7a4f1110 100644 --- a/extractor/suffix_array.h +++ b/extractor/suffix_array.h @@ -10,6 +10,8 @@ namespace fs = boost::filesystem; using namespace std; +namespace extractor { + class DataArray; class PhraseLocation; @@ -51,4 +53,6 @@ class SuffixArray { vector word_start; }; +} // namespace extractor + #endif diff --git a/extractor/suffix_array_test.cc b/extractor/suffix_array_test.cc index 60295567..8431a16e 100644 --- a/extractor/suffix_array_test.cc +++ b/extractor/suffix_array_test.cc @@ -9,6 +9,7 @@ using namespace std; using namespace ::testing; +namespace extractor { namespace { class SuffixArrayTest : public Test { @@ -73,4 +74,5 @@ TEST_F(SuffixArrayTest, TestLookup) { EXPECT_EQ(PhraseLocation(11, 11), suffix_array->Lookup(11, 13, "word5", 1)); } -} // namespace +} // namespace +} // namespace extractor diff --git a/extractor/target_phrase_extractor.cc b/extractor/target_phrase_extractor.cc index ac583953..9f8bc6e2 100644 --- a/extractor/target_phrase_extractor.cc +++ b/extractor/target_phrase_extractor.cc @@ -11,6 +11,8 @@ using namespace std; +namespace extractor { + TargetPhraseExtractor::TargetPhraseExtractor( shared_ptr target_data_array, shared_ptr alignment, @@ -142,3 +144,5 @@ void TargetPhraseExtractor::GeneratePhrases( ++subpatterns[index]; } } + +} // namespace extractor diff --git a/extractor/target_phrase_extractor.h b/extractor/target_phrase_extractor.h index 134f24cc..a4b54145 100644 --- a/extractor/target_phrase_extractor.h +++ b/extractor/target_phrase_extractor.h @@ -7,6 +7,10 @@ using namespace std; +namespace extractor { + +typedef vector > PhraseAlignment; + class Alignment; class DataArray; class Phrase; @@ -14,8 +18,6 @@ class PhraseBuilder; class RuleExtractorHelper; class Vocabulary; -typedef vector > PhraseAlignment; - class TargetPhraseExtractor { public: TargetPhraseExtractor(shared_ptr target_data_array, @@ -53,4 +55,6 @@ class TargetPhraseExtractor { bool require_tight_phrases; }; +} // namespace extractor + #endif diff --git a/extractor/target_phrase_extractor_test.cc b/extractor/target_phrase_extractor_test.cc index 7394f4d9..51c753ee 100644 --- a/extractor/target_phrase_extractor_test.cc +++ b/extractor/target_phrase_extractor_test.cc @@ -14,6 +14,7 @@ using namespace std; using namespace ::testing; +namespace extractor { namespace { class TargetPhraseExtractorTest : public Test { @@ -113,4 +114,5 @@ TEST_F(TargetPhraseExtractorTest, TestExtractPhrasesTightPhrasesFalse) { // look like. } -} // namespace +} // namespace +} // namespace extractor diff --git a/extractor/time_util.cc b/extractor/time_util.cc index 88395f77..e46a0c3d 100644 --- a/extractor/time_util.cc +++ b/extractor/time_util.cc @@ -1,6 +1,10 @@ #include "time_util.h" +namespace extractor { + double GetDuration(const Clock::time_point& start_time, const Clock::time_point& stop_time) { return duration_cast(stop_time - start_time).count() / 1000.0; } + +} // namespace extractor diff --git a/extractor/time_util.h b/extractor/time_util.h index 6f7eda70..45f79199 100644 --- a/extractor/time_util.h +++ b/extractor/time_util.h @@ -6,9 +6,13 @@ using namespace std; using namespace chrono; +namespace extractor { + typedef high_resolution_clock Clock; double GetDuration(const Clock::time_point& start_time, const Clock::time_point& stop_time); +} // namespace extractor + #endif diff --git a/extractor/translation_table.cc b/extractor/translation_table.cc index a48c0657..1852a357 100644 --- a/extractor/translation_table.cc +++ b/extractor/translation_table.cc @@ -10,6 +10,8 @@ using namespace std; +namespace extractor { + TranslationTable::TranslationTable(shared_ptr source_data_array, shared_ptr target_data_array, shared_ptr alignment) : @@ -115,3 +117,5 @@ void TranslationTable::WriteBinary(const fs::path& filepath) const { fwrite(&entry.second, sizeof(entry.second), 1, file); } } + +} // namespace extractor diff --git a/extractor/translation_table.h b/extractor/translation_table.h index 157ad3af..a7be26f5 100644 --- a/extractor/translation_table.h +++ b/extractor/translation_table.h @@ -11,11 +11,13 @@ using namespace std; namespace fs = boost::filesystem; -class Alignment; -class DataArray; +namespace extractor { typedef boost::hash > PairHash; +class Alignment; +class DataArray; + class TranslationTable { public: TranslationTable( @@ -50,4 +52,6 @@ class TranslationTable { translation_probabilities; }; +} // namespace extractor + #endif diff --git a/extractor/translation_table_test.cc b/extractor/translation_table_test.cc index c99f3f93..051b5715 100644 --- a/extractor/translation_table_test.cc +++ b/extractor/translation_table_test.cc @@ -11,6 +11,7 @@ using namespace std; using namespace ::testing; +namespace extractor { namespace { TEST(TranslationTableTest, TestScores) { @@ -79,4 +80,5 @@ TEST(TranslationTableTest, TestScores) { EXPECT_EQ(-1, table->GetSourceGivenTargetScore("c", "d")); } -} // namespace +} // namespace +} // namespace extractor diff --git a/extractor/vocabulary.cc b/extractor/vocabulary.cc index b68d76a9..57f564d9 100644 --- a/extractor/vocabulary.cc +++ b/extractor/vocabulary.cc @@ -1,5 +1,7 @@ #include "vocabulary.h" +namespace extractor { + Vocabulary::~Vocabulary() {} int Vocabulary::GetTerminalIndex(const string& word) { @@ -29,3 +31,4 @@ int Vocabulary::Size() { return words.size(); } +} // namespace extractor diff --git a/extractor/vocabulary.h b/extractor/vocabulary.h index ff3e7a63..dcc2a8fa 100644 --- a/extractor/vocabulary.h +++ b/extractor/vocabulary.h @@ -7,6 +7,8 @@ using namespace std; +namespace extractor { + class Vocabulary { public: virtual ~Vocabulary(); @@ -26,4 +28,6 @@ class Vocabulary { vector words; }; +} // namespace extractor + #endif -- cgit v1.2.3 From e6181c89ab8f29d8bd0fc6a3a8a359cb50c2304c Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Sun, 10 Mar 2013 01:01:01 +0000 Subject: Added comments. Hooray! --- extractor/alignment.cc | 2 -- extractor/alignment.h | 6 ++++ extractor/compile.cc | 2 +- extractor/data_array.h | 36 +++++++++++++++++++++-- extractor/fast_intersector.cc | 3 ++ extractor/fast_intersector.h | 27 +++++++++++++++++ extractor/features/count_source_target.h | 3 ++ extractor/features/feature.h | 6 ++++ extractor/features/is_source_singleton.h | 3 ++ extractor/features/is_source_target_singleton.h | 3 ++ extractor/features/max_lex_source_given_target.h | 3 ++ extractor/features/max_lex_target_given_source.h | 3 ++ extractor/features/sample_source_count.h | 4 +++ extractor/features/target_given_source_coherent.h | 4 +++ extractor/grammar.h | 3 ++ extractor/grammar_extractor.h | 8 +++++ extractor/matchings_finder.h | 5 ++++ extractor/matchings_trie.h | 12 ++++++++ extractor/phrase.h | 10 +++++++ extractor/phrase_builder.h | 5 ++++ extractor/phrase_location.h | 12 ++++++++ extractor/precomputation.cc | 11 +++++++ extractor/precomputation.h | 21 +++++++++++++ extractor/rule.h | 3 ++ extractor/rule_extractor.cc | 21 +++++++++++++ extractor/rule_extractor.h | 16 ++++++++++ extractor/rule_extractor_helper.cc | 11 +++++++ extractor/rule_extractor_helper.h | 13 ++++++++ extractor/rule_factory.cc | 13 ++++++++ extractor/rule_factory.h | 20 +++++++++++++ extractor/run_extractor.cc | 27 +++++++++++++---- extractor/sampler.cc | 2 ++ extractor/sampler.h | 5 ++++ extractor/scorer.h | 5 ++++ extractor/suffix_array.h | 17 +++++++++++ extractor/target_phrase_extractor.cc | 10 +++++++ extractor/target_phrase_extractor.h | 4 +++ extractor/time_util.h | 1 + extractor/translation_table.cc | 17 +++++++---- extractor/translation_table.h | 8 ++++- extractor/vocabulary.h | 17 +++++++++++ 41 files changed, 383 insertions(+), 19 deletions(-) (limited to 'extractor/compile.cc') diff --git a/extractor/alignment.cc b/extractor/alignment.cc index f9bbcf6a..1aea34b3 100644 --- a/extractor/alignment.cc +++ b/extractor/alignment.cc @@ -28,8 +28,6 @@ Alignment::Alignment(const string& filename) { } alignments.push_back(alignment); } - // Note: shrink_to_fit does nothing for vector > on g++ 4.6.3, - // but let's hope that the bug will be fixed in a newer version. alignments.shrink_to_fit(); } diff --git a/extractor/alignment.h b/extractor/alignment.h index ef89dc0c..e9292121 100644 --- a/extractor/alignment.h +++ b/extractor/alignment.h @@ -11,12 +11,18 @@ using namespace std; namespace extractor { +/** + * Data structure storing the word alignments for a parallel corpus. + */ class Alignment { public: + // Reads alignment from text file. Alignment(const string& filename); + // Returns the alignment for a given sentence. virtual vector > GetLinks(int sentence_index) const; + // Writes alignment to file in binary format. void WriteBinary(const fs::path& filepath); virtual ~Alignment(); diff --git a/extractor/compile.cc b/extractor/compile.cc index 7062ef03..a9ae2cef 100644 --- a/extractor/compile.cc +++ b/extractor/compile.cc @@ -37,7 +37,7 @@ int main(int argc, char** argv) { ("max_phrase_len,p", po::value()->default_value(4), "Maximum frequent phrase length") ("min_frequency", po::value()->default_value(1000), - "Minimum number of occurences for a pharse to be considered frequent"); + "Minimum number of occurrences for a pharse to be considered frequent"); po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); diff --git a/extractor/data_array.h b/extractor/data_array.h index a26bbecf..978a6931 100644 --- a/extractor/data_array.h +++ b/extractor/data_array.h @@ -17,9 +17,19 @@ enum Side { TARGET }; -// Note: This class has features for both the source and target data arrays. -// Maybe we can save some memory by having more specific implementations (e.g. -// sentence_id is only needed for the source data array). +/** + * Data structure storing information about a single side of a parallel corpus. + * + * Each word is mapped to a unique integer (word_id). The data structure holds + * the corpus in the numberized format, together with the hash table mapping + * words to word_ids. It also holds additional information such as the starting + * index for each sentence and, for each token, the index of the sentence it + * belongs to. + * + * Note: This class has features for both the source and target data arrays. + * Maybe we can save some memory by having more specific implementations (not + * likely to save a lot of memory tough). + */ class DataArray { public: static int NULL_WORD; @@ -27,45 +37,65 @@ class DataArray { static string NULL_WORD_STR; static string END_OF_LINE_STR; + // Reads data array from text file. DataArray(const string& filename); + // Reads data array from bitext file where the sentences are separated by |||. DataArray(const string& filename, const Side& side); virtual ~DataArray(); + // Returns a vector containing the word ids. virtual const vector& GetData() const; + // Returns the word id at the specified position. virtual int AtIndex(int index) const; + // Returns the original word at the specified position. virtual string GetWordAtIndex(int index) const; + // Returns the size of the data array. virtual int GetSize() const; + // Returns the number of distinct words in the data array. virtual int GetVocabularySize() const; + // Returns whether a word has ever been observed in the data array. virtual bool HasWord(const string& word) const; + // Returns the word id for a given word or -1 if it the word has never been + // observed. virtual int GetWordId(const string& word) const; + // Returns the word corresponding to a particular word id. virtual string GetWord(int word_id) const; + // Returns the number of sentences in the data. virtual int GetNumSentences() const; + // Returns the index where the sentence containing the given position starts. virtual int GetSentenceStart(int position) const; + // Returns the length of the sentence. virtual int GetSentenceLength(int sentence_id) const; + // Returns the number of the sentence containing the given position. virtual int GetSentenceId(int position) const; + // Writes data array to file in binary format. void WriteBinary(const fs::path& filepath) const; + // Writes data array to file in binary format. void WriteBinary(FILE* file) const; protected: DataArray(); private: + // Sets up specific constants. void InitializeDataArray(); + + // Constructs the data array. void CreateDataArray(const vector& lines); unordered_map word2id; diff --git a/extractor/fast_intersector.cc b/extractor/fast_intersector.cc index 1b8c32b1..2a7693b2 100644 --- a/extractor/fast_intersector.cc +++ b/extractor/fast_intersector.cc @@ -107,6 +107,7 @@ PhraseLocation FastIntersector::ExtendPrefixPhraseLocation( } else { pattern_end += phrase.GetChunkLen(phrase.Arity()) - 2; } + // Searches for the last symbol in the phrase after each prefix occurrence. for (int j = range.first; j < range.second; ++j) { if (pattern_end >= sent_end || pattern_end - positions[i] >= max_rule_span) { @@ -149,6 +150,8 @@ PhraseLocation FastIntersector::ExtendSuffixPhraseLocation( int pattern_start = positions[i] - range.first; int pattern_end = positions[i + num_subpatterns - 1] + phrase.GetChunkLen(phrase.Arity()) - 1; + // Searches for the first symbol in the phrase before each suffix + // occurrence. for (int j = range.first; j < range.second; ++j) { if (pattern_start < sent_start || pattern_end - pattern_start >= max_rule_span) { diff --git a/extractor/fast_intersector.h b/extractor/fast_intersector.h index 32c88a30..f950a2a9 100644 --- a/extractor/fast_intersector.h +++ b/extractor/fast_intersector.h @@ -20,6 +20,18 @@ class Precomputation; class SuffixArray; class Vocabulary; +/** + * Component for searching the training data for occurrences of source phrases + * containing nonterminals + * + * Given a source phrase containing a nonterminal, we first query the + * precomputed index containing frequent collocations. If the phrase is not + * frequent enough, we extend the matchings of either its prefix or its suffix, + * depending on which operation seems to require less computations. + * + * Note: This method for intersecting phrase locations is faster than both + * mergers (linear or Baeza Yates) described in Adam Lopez' dissertation. + */ class FastIntersector { public: FastIntersector(shared_ptr suffix_array, @@ -30,6 +42,8 @@ class FastIntersector { virtual ~FastIntersector(); + // Finds the locations of a phrase given the locations of its prefix and + // suffix. virtual PhraseLocation Intersect(PhraseLocation& prefix_location, PhraseLocation& suffix_location, const Phrase& phrase); @@ -38,23 +52,36 @@ class FastIntersector { FastIntersector(); private: + // Uses the vocabulary to convert the phrase from the numberized format + // specified by the source data array to the numberized format given by the + // vocabulary. vector ConvertPhrase(const vector& old_phrase); + // Estimates the number of computations needed if the prefix/suffix is + // extended. If the last/first symbol is separated from the rest of the phrase + // by a nonterminal, then for each occurrence of the prefix/suffix we need to + // check max_rule_span positions. Otherwise, we only need to check a single + // position for each occurrence. int EstimateNumOperations(const PhraseLocation& phrase_location, bool has_margin_x) const; + // Uses the occurrences of the prefix to find the occurrences of the phrase. PhraseLocation ExtendPrefixPhraseLocation(PhraseLocation& prefix_location, const Phrase& phrase, bool prefix_ends_with_x, int next_symbol) const; + // Uses the occurrences of the suffix to find the occurrences of the phrase. PhraseLocation ExtendSuffixPhraseLocation(PhraseLocation& suffix_location, const Phrase& phrase, bool suffix_starts_with_x, int prev_symbol) const; + // Extends the prefix/suffix location to a list of subpatterns positions if it + // represents a suffix array range. void ExtendPhraseLocation(PhraseLocation& location) const; + // Returns the range in which the search should be performed. pair GetSearchRange(bool has_marginal_x) const; shared_ptr suffix_array; diff --git a/extractor/features/count_source_target.h b/extractor/features/count_source_target.h index dec78883..8747fa60 100644 --- a/extractor/features/count_source_target.h +++ b/extractor/features/count_source_target.h @@ -6,6 +6,9 @@ namespace extractor { namespace features { +/** + * Feature for the number of times a word pair was found in the bitext. + */ class CountSourceTarget : public Feature { public: double Score(const FeatureContext& context) const; diff --git a/extractor/features/feature.h b/extractor/features/feature.h index 6693ccbf..36ea504a 100644 --- a/extractor/features/feature.h +++ b/extractor/features/feature.h @@ -10,6 +10,9 @@ using namespace std; namespace extractor { namespace features { +/** + * Structure providing context for computing feature scores. + */ struct FeatureContext { FeatureContext(const Phrase& source_phrase, const Phrase& target_phrase, double source_phrase_count, int pair_count, int num_samples) : @@ -24,6 +27,9 @@ struct FeatureContext { int num_samples; }; +/** + * Base class for features. + */ class Feature { public: virtual double Score(const FeatureContext& context) const = 0; diff --git a/extractor/features/is_source_singleton.h b/extractor/features/is_source_singleton.h index 30f76c6d..b8352d0e 100644 --- a/extractor/features/is_source_singleton.h +++ b/extractor/features/is_source_singleton.h @@ -6,6 +6,9 @@ namespace extractor { namespace features { +/** + * Boolean feature checking if the source phrase occurs only once in the data. + */ class IsSourceSingleton : public Feature { public: double Score(const FeatureContext& context) const; diff --git a/extractor/features/is_source_target_singleton.h b/extractor/features/is_source_target_singleton.h index 12fb6ee6..dacfebba 100644 --- a/extractor/features/is_source_target_singleton.h +++ b/extractor/features/is_source_target_singleton.h @@ -6,6 +6,9 @@ namespace extractor { namespace features { +/** + * Boolean feature checking if the phrase pair occurs only once in the data. + */ class IsSourceTargetSingleton : public Feature { public: double Score(const FeatureContext& context) const; diff --git a/extractor/features/max_lex_source_given_target.h b/extractor/features/max_lex_source_given_target.h index bfa7ef1b..461b0ebf 100644 --- a/extractor/features/max_lex_source_given_target.h +++ b/extractor/features/max_lex_source_given_target.h @@ -13,6 +13,9 @@ class TranslationTable; namespace features { +/** + * Feature computing max(p(f | e)) across all pairs of words in the phrase pair. + */ class MaxLexSourceGivenTarget : public Feature { public: MaxLexSourceGivenTarget(shared_ptr table); diff --git a/extractor/features/max_lex_target_given_source.h b/extractor/features/max_lex_target_given_source.h index 66cf0914..c3c87327 100644 --- a/extractor/features/max_lex_target_given_source.h +++ b/extractor/features/max_lex_target_given_source.h @@ -13,6 +13,9 @@ class TranslationTable; namespace features { +/** + * Feature computing max(p(e | f)) across all pairs of words in the phrase pair. + */ class MaxLexTargetGivenSource : public Feature { public: MaxLexTargetGivenSource(shared_ptr table); diff --git a/extractor/features/sample_source_count.h b/extractor/features/sample_source_count.h index 53c7f954..ee6e59a0 100644 --- a/extractor/features/sample_source_count.h +++ b/extractor/features/sample_source_count.h @@ -6,6 +6,10 @@ namespace extractor { namespace features { +/** + * Feature scoring the number of times the source phrase occurs in the sampled + * set. + */ class SampleSourceCount : public Feature { public: double Score(const FeatureContext& context) const; diff --git a/extractor/features/target_given_source_coherent.h b/extractor/features/target_given_source_coherent.h index 80d9f617..e66d70a5 100644 --- a/extractor/features/target_given_source_coherent.h +++ b/extractor/features/target_given_source_coherent.h @@ -6,6 +6,10 @@ namespace extractor { namespace features { +/** + * Feature computing the ratio of the phrase pair count over all source phrase + * occurrences (sampled). + */ class TargetGivenSourceCoherent : public Feature { public: double Score(const FeatureContext& context) const; diff --git a/extractor/grammar.h b/extractor/grammar.h index a424d65a..fed41b16 100644 --- a/extractor/grammar.h +++ b/extractor/grammar.h @@ -11,6 +11,9 @@ namespace extractor { class Rule; +/** + * Grammar class wrapping the set of rules to be extracted. + */ class Grammar { public: Grammar(const vector& rules, const vector& feature_names); diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h index 6b1dcf98..b36ceeb9 100644 --- a/extractor/grammar_extractor.h +++ b/extractor/grammar_extractor.h @@ -19,6 +19,10 @@ class Scorer; class SuffixArray; class Vocabulary; +/** + * Class wrapping all the logic for extracting the synchronous context free + * grammars. + */ class GrammarExtractor { public: GrammarExtractor( @@ -38,11 +42,15 @@ class GrammarExtractor { GrammarExtractor(shared_ptr vocabulary, shared_ptr rule_factory); + // Converts the sentence to a vector of word ids and uses the RuleFactory to + // extract the SCFG rules which may be used to decode the sentence. Grammar GetGrammar(const string& sentence); private: + // Splits the sentence in a vector of words. vector TokenizeSentence(const string& sentence); + // Maps the words to word ids. vector AnnotateWords(const vector& words); shared_ptr vocabulary; diff --git a/extractor/matchings_finder.h b/extractor/matchings_finder.h index fbb504ef..451f4a4c 100644 --- a/extractor/matchings_finder.h +++ b/extractor/matchings_finder.h @@ -11,12 +11,17 @@ namespace extractor { class PhraseLocation; class SuffixArray; +/** + * Class wrapping the suffix array lookup for a contiguous phrase. + */ class MatchingsFinder { public: MatchingsFinder(shared_ptr suffix_array); virtual ~MatchingsFinder(); + // Uses the suffix array to search only for the last word of the phrase + // starting from the range in which the prefix of the phrase occurs. virtual PhraseLocation Find(PhraseLocation& location, const string& word, int offset); diff --git a/extractor/matchings_trie.h b/extractor/matchings_trie.h index f3dcc075..1fb29693 100644 --- a/extractor/matchings_trie.h +++ b/extractor/matchings_trie.h @@ -11,20 +11,27 @@ using namespace std; namespace extractor { +/** + * Trie node containing all the occurrences of the corresponding phrase in the + * source data. + */ struct TrieNode { TrieNode(shared_ptr suffix_link = shared_ptr(), Phrase phrase = Phrase(), PhraseLocation matchings = PhraseLocation()) : suffix_link(suffix_link), phrase(phrase), matchings(matchings) {} + // Adds a trie node as a child of the current node. void AddChild(int key, shared_ptr child_node) { children[key] = child_node; } + // Checks if a child exists for a given key. bool HasChild(int key) { return children.count(key); } + // Gets the child corresponding to the given key. shared_ptr GetChild(int key) { return children[key]; } @@ -35,15 +42,20 @@ struct TrieNode { unordered_map > children; }; +/** + * Trie containing all the phrases that can be obtained from a sentence. + */ class MatchingsTrie { public: MatchingsTrie(); virtual ~MatchingsTrie(); + // Returns the root of the trie. shared_ptr GetRoot() const; private: + // Recursively deletes a subtree of the trie. void DeleteTree(shared_ptr root); shared_ptr root; diff --git a/extractor/phrase.h b/extractor/phrase.h index 6521c438..a8e91e3c 100644 --- a/extractor/phrase.h +++ b/extractor/phrase.h @@ -11,20 +11,30 @@ using namespace std; namespace extractor { +/** + * Structure containing the data for a phrase. + */ class Phrase { public: friend Phrase PhraseBuilder::Build(const vector& phrase); + // Returns the number of nonterminals in the phrase. int Arity() const; + // Returns the number of terminals (length) for the given chunk. (A chunk is a + // contiguous sequence of terminals in the phrase). int GetChunkLen(int index) const; + // Returns the symbols (word ids) marking up the phrase. vector Get() const; + // Returns the symbol located at the given position in the phrase. int GetSymbol(int position) const; + // Returns the number of symbols in the phrase. int GetNumSymbols() const; + // Returns the words making up the phrase. (Nonterminals are stripped out.) vector GetWords() const; bool operator<(const Phrase& other) const; diff --git a/extractor/phrase_builder.h b/extractor/phrase_builder.h index 2956fd35..de86dbae 100644 --- a/extractor/phrase_builder.h +++ b/extractor/phrase_builder.h @@ -11,12 +11,17 @@ namespace extractor { class Phrase; class Vocabulary; +/** + * Component for constructing phrases. + */ class PhraseBuilder { public: PhraseBuilder(shared_ptr vocabulary); + // Constructs a phrase starting from an array of symbols. Phrase Build(const vector& symbols); + // Extends a phrase with a leading and/or trailing nonterminal. Phrase Extend(const Phrase& phrase, bool start_x, bool end_x); private: diff --git a/extractor/phrase_location.h b/extractor/phrase_location.h index e5f3cf08..91950e03 100644 --- a/extractor/phrase_location.h +++ b/extractor/phrase_location.h @@ -8,13 +8,25 @@ using namespace std; namespace extractor { +/** + * Structure containing information about the occurrences of a phrase in the + * source data. + * + * Every consecutive (disjoint) group of num_subpatterns entries in matchings + * vector encodes an occurrence of the phrase. The i-th entry of a group + * represents the start of the i-th subpattern of the phrase. If the phrase + * doesn't contain any nonterminals, then it may also be represented as the + * range in the suffix array which matches the phrase. + */ struct PhraseLocation { PhraseLocation(int sa_low = -1, int sa_high = -1); PhraseLocation(const vector& matchings, int num_subpatterns); + // Checks if a phrase has any occurrences in the source data. bool IsEmpty() const; + // Returns the number of occurrences of a phrase in the source data. int GetSize() const; friend bool operator==(const PhraseLocation& a, const PhraseLocation& b); diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc index 0fadc95c..b3906943 100644 --- a/extractor/precomputation.cc +++ b/extractor/precomputation.cc @@ -23,6 +23,8 @@ Precomputation::Precomputation( suffix_array, data, num_frequent_patterns, max_frequent_phrase_len, min_frequency); + // Construct sets containing the frequent and superfrequent contiguous + // collocations. unordered_set, VectorHash> frequent_patterns_set; unordered_set, VectorHash> super_frequent_patterns_set; for (size_t i = 0; i < frequent_patterns.size(); ++i) { @@ -34,6 +36,8 @@ Precomputation::Precomputation( vector > matchings; for (size_t i = 0; i < data.size(); ++i) { + // If the sentence is over, add all the discontiguous frequent patterns to + // the index. if (data[i] == DataArray::END_OF_LINE) { AddCollocations(matchings, data, max_rule_span, min_gap_size, max_rule_symbols); @@ -41,6 +45,7 @@ Precomputation::Precomputation( continue; } vector pattern; + // Find all the contiguous frequent patterns starting at position i. for (int j = 1; j <= max_frequent_phrase_len && i + j <= data.size(); ++j) { pattern.push_back(data[i + j - 1]); if (frequent_patterns_set.count(pattern)) { @@ -65,6 +70,7 @@ vector > Precomputation::FindMostFrequentPatterns( vector lcp = suffix_array->BuildLCPArray(); vector run_start(max_frequent_phrase_len); + // Find all the patterns occurring at least min_frequency times. priority_queue > > heap; for (size_t i = 1; i < lcp.size(); ++i) { for (int len = lcp[i]; len < max_frequent_phrase_len; ++len) { @@ -77,6 +83,7 @@ vector > Precomputation::FindMostFrequentPatterns( } } + // Extract the most frequent patterns. vector > frequent_patterns; while (frequent_patterns.size() < num_frequent_patterns && !heap.empty()) { int start = heap.top().second.first; @@ -95,10 +102,12 @@ vector > Precomputation::FindMostFrequentPatterns( void Precomputation::AddCollocations( const vector >& matchings, const vector& data, int max_rule_span, int min_gap_size, int max_rule_symbols) { + // Select the leftmost subpattern. for (size_t i = 0; i < matchings.size(); ++i) { int start1, size1, is_super1; tie(start1, size1, is_super1) = matchings[i]; + // Select the second (middle) subpattern for (size_t j = i + 1; j < matchings.size(); ++j) { int start2, size2, is_super2; tie(start2, size2, is_super2) = matchings[j]; @@ -116,8 +125,10 @@ void Precomputation::AddCollocations( data.begin() + start2 + size2); AddStartPositions(collocations[pattern], start1, start2); + // Try extending the binary collocation to a ternary collocation. if (is_super2) { pattern.push_back(Precomputation::SECOND_NONTERMINAL); + // Select the rightmost subpattern. for (size_t k = j + 1; k < matchings.size(); ++k) { int start3, size3, is_super3; tie(start3, size3, is_super3) = matchings[k]; diff --git a/extractor/precomputation.h b/extractor/precomputation.h index 2c1eccf8..e3c4d26a 100644 --- a/extractor/precomputation.h +++ b/extractor/precomputation.h @@ -20,8 +20,19 @@ typedef unordered_map, vector, VectorHash> Index; class SuffixArray; +/** + * Data structure wrapping an index with all the occurrences of the most + * frequent discontiguous collocations in the source data. + * + * Let a, b, c be contiguous collocations. The index will contain an entry for + * every collocation of the form: + * - aXb, where a and b are frequent + * - aXbXc, where a and b are super-frequent and c is frequent or + * b and c are super-frequent and a is frequent. + */ class Precomputation { public: + // Constructs the index using the suffix array. Precomputation( shared_ptr suffix_array, int num_frequent_patterns, int num_super_frequent_patterns, int max_rule_span, @@ -32,6 +43,7 @@ class Precomputation { void WriteBinary(const fs::path& filepath) const; + // Returns a reference to the index. virtual const Index& GetCollocations() const; static int FIRST_NONTERMINAL; @@ -41,14 +53,23 @@ class Precomputation { Precomputation(); private: + // Finds the most frequent contiguous collocations. vector > FindMostFrequentPatterns( shared_ptr suffix_array, const vector& data, int num_frequent_patterns, int max_frequent_phrase_len, int min_frequency); + + // Given the locations of the frequent contiguous collocations in a sentence, + // it adds new entries to the index for each discontiguous collocation + // matching the criteria specified in the class description. void AddCollocations( const vector >& matchings, const vector& data, int max_rule_span, int min_gap_size, int max_rule_symbols); + + // Adds an occurrence of a binary collocation. void AddStartPositions(vector& positions, int pos1, int pos2); + + // Adds an occurrence of a ternary collocation. void AddStartPositions(vector& positions, int pos1, int pos2, int pos3); Index collocations; diff --git a/extractor/rule.h b/extractor/rule.h index b4d45fc1..bc95709e 100644 --- a/extractor/rule.h +++ b/extractor/rule.h @@ -9,6 +9,9 @@ using namespace std; namespace extractor { +/** + * Structure containing the data for a SCFG rule. + */ struct Rule { Rule(const Phrase& source_phrase, const Phrase& target_phrase, const vector& scores, const vector >& alignment); diff --git a/extractor/rule_extractor.cc b/extractor/rule_extractor.cc index b9286472..9f5e8e00 100644 --- a/extractor/rule_extractor.cc +++ b/extractor/rule_extractor.cc @@ -79,6 +79,7 @@ vector RuleExtractor::ExtractRules(const Phrase& phrase, int num_subpatterns = location.num_subpatterns; vector matchings = *location.matchings; + // Calculate statistics for the (sampled) occurrences of the source phrase. map source_phrase_counter; map > > alignments_counter; for (auto i = matchings.begin(); i != matchings.end(); i += num_subpatterns) { @@ -91,6 +92,8 @@ vector RuleExtractor::ExtractRules(const Phrase& phrase, } } + // Compute the feature scores and find the most likely (frequent) alignment + // for each pair of source-target phrases. int num_samples = matchings.size() / num_subpatterns; vector rules; for (auto source_phrase_entry: alignments_counter) { @@ -124,6 +127,8 @@ vector RuleExtractor::ExtractAlignments( int sentence_id = source_data_array->GetSentenceId(matching[0]); int source_sent_start = source_data_array->GetSentenceStart(sentence_id); + // Get the span in the opposite sentence for each word in the source-target + // sentece pair. vector source_low, source_high, target_low, target_high; helper->GetLinksSpans(source_low, source_high, target_low, target_high, sentence_id); @@ -134,6 +139,7 @@ vector RuleExtractor::ExtractAlignments( chunklen[i] = phrase.GetChunkLen(i); } + // Basic checks to see if we can extract phrase pairs for this occurrence. if (!helper->CheckAlignedTerminals(matching, chunklen, source_low) || !helper->CheckTightPhrases(matching, chunklen, source_low)) { return extracts; @@ -144,6 +150,7 @@ vector RuleExtractor::ExtractAlignments( int source_phrase_high = matching.back() + chunklen.back() - source_sent_start; int target_phrase_low = -1, target_phrase_high = -1; + // Find target span and reflected source span for the source phrase. if (!helper->FindFixPoint(source_phrase_low, source_phrase_high, source_low, source_high, target_phrase_low, target_phrase_high, target_low, target_high, source_back_low, @@ -153,6 +160,7 @@ vector RuleExtractor::ExtractAlignments( return extracts; } + // Get spans for nonterminal gaps. bool met_constraints = true; int num_symbols = phrase.GetNumSymbols(); vector > source_gaps, target_gaps; @@ -163,6 +171,7 @@ vector RuleExtractor::ExtractAlignments( return extracts; } + // Find target phrases aligned with the initial source phrase. bool starts_with_x = source_back_low != source_phrase_low; bool ends_with_x = source_back_high != source_phrase_high; Phrase source_phrase = phrase_builder->Extend( @@ -181,6 +190,8 @@ vector RuleExtractor::ExtractAlignments( return extracts; } + // Extend the source phrase by adding a leading and/or trailing nonterminal + // and find target phrases aligned with the extended source phrase. for (int i = 0; i < 2; ++i) { for (int j = 1 - i; j < 2; ++j) { AddNonterminalExtremities(extracts, matching, chunklen, source_phrase, @@ -203,6 +214,8 @@ void RuleExtractor::AddExtracts( source_indexes, sentence_id); if (target_phrases.size() > 0) { + // Split the probability equally across all target phrases that can be + // aligned with a single occurrence of the source phrase. double pairs_count = 1.0 / target_phrases.size(); for (auto target_phrase: target_phrases) { extracts.push_back(Extract(source_phrase, target_phrase.first, @@ -221,6 +234,7 @@ void RuleExtractor::AddNonterminalExtremities( int extend_right) const { int source_x_low = source_back_low, source_x_high = source_back_high; + // Check if the extended source phrase will remain tight. if (require_tight_phrases) { if (source_low[source_back_low - extend_left] == -1 || source_low[source_back_high + extend_right - 1] == -1) { @@ -228,6 +242,7 @@ void RuleExtractor::AddNonterminalExtremities( } } + // Check if we can add a nonterminal to the left. if (extend_left) { if (starts_with_x || source_back_low < min_gap_size) { return; @@ -244,6 +259,7 @@ void RuleExtractor::AddNonterminalExtremities( } } + // Check if we can add a nonterminal to the right. if (extend_right) { int source_sent_len = source_data_array->GetSentenceLength(sentence_id); if (ends_with_x || source_back_high + min_gap_size > source_sent_len) { @@ -262,6 +278,7 @@ void RuleExtractor::AddNonterminalExtremities( } } + // More length checks. int new_nonterminals = extend_left + extend_right; if (source_x_high - source_x_low > max_rule_span || target_gaps.size() + new_nonterminals > max_nonterminals || @@ -269,6 +286,7 @@ void RuleExtractor::AddNonterminalExtremities( return; } + // Find the target span for the extended phrase and the reflected source span. int target_x_low = -1, target_x_high = -1; if (!helper->FindFixPoint(source_x_low, source_x_high, source_low, source_high, target_x_low, target_x_high, @@ -279,6 +297,7 @@ void RuleExtractor::AddNonterminalExtremities( return; } + // Check gap integrity for the leading nonterminal. if (extend_left) { int source_gap_low = -1, source_gap_high = -1; int target_gap_low = -1, target_gap_high = -1; @@ -294,6 +313,7 @@ void RuleExtractor::AddNonterminalExtremities( make_pair(target_gap_low, target_gap_high)); } + // Check gap integrity for the trailing nonterminal. if (extend_right) { int target_gap_low = -1, target_gap_high = -1; int source_gap_low = -1, source_gap_high = -1; @@ -308,6 +328,7 @@ void RuleExtractor::AddNonterminalExtremities( target_gaps.push_back(make_pair(target_gap_low, target_gap_high)); } + // Find target phrases aligned with the extended source phrase. Phrase new_source_phrase = phrase_builder->Extend(source_phrase, extend_left, extend_right); unordered_map source_indexes = helper->GetSourceIndexes( diff --git a/extractor/rule_extractor.h b/extractor/rule_extractor.h index 8b6daeea..bfec0225 100644 --- a/extractor/rule_extractor.h +++ b/extractor/rule_extractor.h @@ -22,6 +22,10 @@ class RuleExtractorHelper; class Scorer; class TargetPhraseExtractor; +/** + * Structure containing data about the occurrences of a source-target phrase pair + * in the parallel corpus. + */ struct Extract { Extract(const Phrase& source_phrase, const Phrase& target_phrase, double pairs_count, const PhraseAlignment& alignment) : @@ -34,6 +38,9 @@ struct Extract { PhraseAlignment alignment; }; +/** + * Component for extracting SCFG rules. + */ class RuleExtractor { public: RuleExtractor(shared_ptr source_data_array, @@ -64,6 +71,8 @@ class RuleExtractor { virtual ~RuleExtractor(); + // Extracts SCFG rules given a source phrase and a set of its occurrences + // in the source data. virtual vector ExtractRules(const Phrase& phrase, const PhraseLocation& location) const; @@ -71,15 +80,22 @@ class RuleExtractor { RuleExtractor(); private: + // Finds all target phrases that can be aligned with the source phrase for a + // particular occurrence in the data. vector ExtractAlignments(const Phrase& phrase, const vector& matching) const; + // Extracts all target phrases for a given occurrence of the source phrase in + // the data. Constructs a vector of Extracts using these target phrases. void AddExtracts( vector& extracts, const Phrase& source_phrase, const unordered_map& source_indexes, const vector >& target_gaps, const vector& target_low, int target_phrase_low, int target_phrase_high, int sentence_id) const; + // Adds a leading and/or trailing nonterminal to the source phrase and + // extracts target phrases that can be aligned with the extended source + // phrase. void AddNonterminalExtremities( vector& extracts, const vector& matching, const vector& chunklen, const Phrase& source_phrase, diff --git a/extractor/rule_extractor_helper.cc b/extractor/rule_extractor_helper.cc index 81b522f0..6410d147 100644 --- a/extractor/rule_extractor_helper.cc +++ b/extractor/rule_extractor_helper.cc @@ -88,6 +88,7 @@ bool RuleExtractorHelper::CheckTightPhrases( return true; } + // Check if the chunk extremities are aligned. int sentence_id = source_data_array->GetSentenceId(matching[0]); int source_sent_start = source_data_array->GetSentenceStart(sentence_id); for (size_t i = 0; i + 1 < chunklen.size(); ++i) { @@ -126,6 +127,7 @@ bool RuleExtractorHelper::FindFixPoint( int source_sent_len = source_data_array->GetSentenceLength(sentence_id); int target_sent_len = target_data_array->GetSentenceLength(sentence_id); + // Extend the target span to the left. if (prev_target_low != -1 && target_phrase_low != prev_target_low) { if (prev_target_low - target_phrase_low < min_target_gap_size) { target_phrase_low = prev_target_low - min_target_gap_size; @@ -135,6 +137,7 @@ bool RuleExtractorHelper::FindFixPoint( } } + // Extend the target span to the right. if (prev_target_high != -1 && target_phrase_high != prev_target_high) { if (target_phrase_high - prev_target_high < min_target_gap_size) { target_phrase_high = prev_target_high + min_target_gap_size; @@ -144,10 +147,12 @@ bool RuleExtractorHelper::FindFixPoint( } } + // Check target span length. if (target_phrase_high - target_phrase_low > max_rule_span) { return false; } + // Find the initial reflected source span. source_back_low = source_back_high = -1; FindProjection(target_phrase_low, target_phrase_high, target_low, target_high, source_back_low, source_back_high); @@ -157,6 +162,7 @@ bool RuleExtractorHelper::FindFixPoint( source_back_low = min(source_back_low, source_phrase_low); source_back_high = max(source_back_high, source_phrase_high); + // Stop if the reflected source span matches the previous source span. if (source_back_low == source_phrase_low && source_back_high == source_phrase_high) { return true; @@ -212,10 +218,14 @@ bool RuleExtractorHelper::FindFixPoint( prev_target_low = target_phrase_low; prev_target_high = target_phrase_high; + // Find the reflection including the left gap (if one was added). FindProjection(source_back_low, source_phrase_low, source_low, source_high, target_phrase_low, target_phrase_high); + // Find the reflection including the right gap (if one was added). FindProjection(source_phrase_high, source_back_high, source_low, source_high, target_phrase_low, target_phrase_high); + // Stop if the new re-reflected target span matches the previous target + // span. if (prev_target_low == target_phrase_low && prev_target_high == target_phrase_high) { return true; @@ -232,6 +242,7 @@ bool RuleExtractorHelper::FindFixPoint( source_phrase_low = source_back_low; source_phrase_high = source_back_high; + // Re-reflect the target span. FindProjection(target_phrase_low, prev_target_low, target_low, target_high, source_back_low, source_back_high); FindProjection(prev_target_high, target_phrase_high, target_low, diff --git a/extractor/rule_extractor_helper.h b/extractor/rule_extractor_helper.h index 7bf80c4b..bea75bc3 100644 --- a/extractor/rule_extractor_helper.h +++ b/extractor/rule_extractor_helper.h @@ -12,6 +12,9 @@ namespace extractor { class Alignment; class DataArray; +/** + * Helper class for extracting SCFG rules. + */ class RuleExtractorHelper { public: RuleExtractorHelper(shared_ptr source_data_array, @@ -25,18 +28,23 @@ class RuleExtractorHelper { virtual ~RuleExtractorHelper(); + // Find the alignment span for each word in the source target sentence pair. virtual void GetLinksSpans(vector& source_low, vector& source_high, vector& target_low, vector& target_high, int sentence_id) const; + // Check if one chunk (all chunks) is aligned at least in one point. virtual bool CheckAlignedTerminals(const vector& matching, const vector& chunklen, const vector& source_low) const; + // Check if the chunks are tight. virtual bool CheckTightPhrases(const vector& matching, const vector& chunklen, const vector& source_low) const; + // Find the target span and the reflected source span for a source phrase + // occurrence. virtual bool FindFixPoint( int source_phrase_low, int source_phrase_high, const vector& source_low, const vector& source_high, @@ -47,6 +55,7 @@ class RuleExtractorHelper { int max_new_x, bool allow_low_x, bool allow_high_x, bool allow_arbitrary_expansion) const; + // Find the gap spans for each nonterminal in the source phrase. virtual bool GetGaps( vector >& source_gaps, vector >& target_gaps, const vector& matching, const vector& chunklen, @@ -55,8 +64,10 @@ class RuleExtractorHelper { int source_phrase_low, int source_phrase_high, int source_back_low, int source_back_high, int& num_symbols, bool& met_constraints) const; + // Get the order of the nonterminals in the target phrase. virtual vector GetGapOrder(const vector >& gaps) const; + // Map each terminal symbol with its position in the source phrase. virtual unordered_map GetSourceIndexes( const vector& matching, const vector& chunklen, int starts_with_x) const; @@ -65,6 +76,8 @@ class RuleExtractorHelper { RuleExtractorHelper(); private: + // Find the projection of a source phrase in the target sentence. May also be + // used to find the projection of a target phrase in the source sentence. void FindProjection( int source_phrase_low, int source_phrase_high, const vector& source_low, const vector& source_high, diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index fbc62e50..8c30fb9e 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -152,12 +152,18 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { } else { PhraseLocation phrase_location; if (next_phrase.Arity() > 0) { + // For phrases containing a nonterminal, we use either the occurrences + // of the prefix or the suffix to determine the occurrences of the + // phrase. Clock::time_point intersect_start = Clock::now(); phrase_location = fast_intersector->Intersect( node->matchings, next_suffix_link->matchings, next_phrase); Clock::time_point intersect_stop = Clock::now(); total_intersect_time += GetDuration(intersect_start, intersect_stop); } else { + // For phrases not containing any nonterminals, we simply query the + // suffix array using the suffix array range of the prefix as a + // starting point. Clock::time_point lookup_start = Clock::now(); phrase_location = matchings_finder->Find( node->matchings, @@ -170,9 +176,12 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { if (phrase_location.IsEmpty()) { continue; } + + // Create new trie node to store data about the current phrase. next_node = make_shared( next_suffix_link, next_phrase, phrase_location); } + // Add the new trie node to the trie cache. node->AddChild(word_id, next_node); // Automatically adds a trailing non terminal if allowed. Simply copy the @@ -182,6 +191,7 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { Clock::time_point extract_start = Clock::now(); if (!state.starts_with_x) { + // Extract rules for the sampled set of occurrences. PhraseLocation sample = sampler->Sample(next_node->matchings); vector new_rules = rule_extractor->ExtractRules(next_phrase, sample); @@ -193,6 +203,7 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { next_node = node->GetChild(word_id); } + // Create more states (phrases) to be analyzed. vector new_states = ExtendState(word_ids, state, phrase, next_phrase, next_node); for (State new_state: new_states) { @@ -262,6 +273,7 @@ vector HieroCachingRuleFactory::ExtendState( return new_states; } + // New state for adding the next symbol. new_states.push_back(State(state.start, state.end + 1, symbols, state.subpatterns_start, node, state.starts_with_x)); @@ -272,6 +284,7 @@ vector HieroCachingRuleFactory::ExtendState( return new_states; } + // New states for adding a nonterminal followed by a new symbol. int var_id = vocabulary->GetNonterminalIndex(phrase.Arity() + 1); symbols.push_back(var_id); vector subpatterns_start = state.subpatterns_start; diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h index d8dc2ccc..52e8712a 100644 --- a/extractor/rule_factory.h +++ b/extractor/rule_factory.h @@ -25,6 +25,17 @@ class State; class SuffixArray; class Vocabulary; +/** + * Component containing most of the logic for extracting SCFG rules for a given + * sentence. + * + * Given a sentence (as a vector of word ids), this class constructs all the + * possible source phrases starting from this sentence. For each source phrase, + * it finds all its occurrences in the source data and samples some of these + * occurrences to extract aligned source-target phrase pairs. A trie cache is + * used to avoid unnecessary computations if a source phrase can be constructed + * more than once (e.g. some words occur more than once in the sentence). + */ class HieroCachingRuleFactory { public: HieroCachingRuleFactory( @@ -58,21 +69,30 @@ class HieroCachingRuleFactory { virtual ~HieroCachingRuleFactory(); + // Constructs SCFG rules for a given sentence. + // (See class description for more details.) virtual Grammar GetGrammar(const vector& word_ids); protected: HieroCachingRuleFactory(); private: + // Checks if the phrase (if previously encountered) or its prefix have any + // occurrences in the source data. bool CannotHaveMatchings(shared_ptr node, int word_id); + // Checks if the phrase has previously been analyzed. bool RequiresLookup(shared_ptr node, int word_id); + // Creates a new state in the trie that corresponds to adding a trailing + // nonterminal to the current phrase. void AddTrailingNonterminal(vector symbols, const Phrase& prefix, const shared_ptr& prefix_node, bool starts_with_x); + // Extends the current state by possibly adding a nonterminal followed by a + // terminal. vector ExtendState(const vector& word_ids, const State& state, vector symbols, diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index dba4578c..d5ff23b2 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -35,6 +35,7 @@ using namespace std; using namespace extractor; using namespace features; +// Returns the file path in which a given grammar should be written. fs::path GetGrammarFilePath(const fs::path& grammar_path, int file_number) { string file_name = "grammar." + to_string(file_number); return grammar_path / file_name; @@ -45,6 +46,7 @@ int main(int argc, char** argv) { #pragma omp parallel num_threads_default = omp_get_num_threads(); + // Sets up the command line arguments map. po::options_description desc("Command line options"); desc.add_options() ("help,h", "Show available options") @@ -69,7 +71,7 @@ int main(int argc, char** argv) { ("max_nonterminals", po::value()->default_value(2), "Maximum number of nonterminals in a rule") ("min_frequency", po::value()->default_value(1000), - "Minimum number of occurences for a pharse to be considered frequent") + "Minimum number of occurrences for a pharse to be considered frequent") ("max_samples", po::value()->default_value(300), "Maximum number of samples") ("tight_phrases", po::value()->default_value(true), @@ -78,8 +80,8 @@ int main(int argc, char** argv) { po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); - // Check for help argument before notify, so we don't need to pass in the - // required parameters. + // Checks for the help option before calling notify, so the we don't get an + // exception for missing required arguments. if (vm.count("help")) { cout << desc << endl; return 0; @@ -94,6 +96,7 @@ int main(int argc, char** argv) { return 1; } + // Reads the parallel corpus. Clock::time_point preprocess_start_time = Clock::now(); cerr << "Reading source and target data..." << endl; Clock::time_point start_time = Clock::now(); @@ -111,6 +114,7 @@ int main(int argc, char** argv) { cerr << "Reading data took " << GetDuration(start_time, stop_time) << " seconds" << endl; + // Constructs the suffix array for the source data. cerr << "Creating source suffix array..." << endl; start_time = Clock::now(); shared_ptr source_suffix_array = @@ -119,6 +123,7 @@ int main(int argc, char** argv) { cerr << "Creating suffix array took " << GetDuration(start_time, stop_time) << " seconds" << endl; + // Reads the alignment. cerr << "Reading alignment..." << endl; start_time = Clock::now(); shared_ptr alignment = @@ -127,6 +132,8 @@ int main(int argc, char** argv) { cerr << "Reading alignment took " << GetDuration(start_time, stop_time) << " seconds" << endl; + // Constructs an index storing the occurrences in the source data for each + // frequent collocation. cerr << "Precomputing collocations..." << endl; start_time = Clock::now(); shared_ptr precomputation = make_shared( @@ -142,6 +149,8 @@ int main(int argc, char** argv) { cerr << "Precomputing collocations took " << GetDuration(start_time, stop_time) << " seconds" << endl; + // Constructs a table storing p(e | f) and p(f | e) for every pair of source + // and target words. cerr << "Precomputing conditional probabilities..." << endl; start_time = Clock::now(); shared_ptr table = make_shared( @@ -155,6 +164,7 @@ int main(int argc, char** argv) { << GetDuration(preprocess_start_time, preprocess_stop_time) << " seconds" << endl; + // Features used to score each grammar rule. Clock::time_point extraction_start_time = Clock::now(); vector > features = { make_shared(), @@ -167,6 +177,7 @@ int main(int argc, char** argv) { }; shared_ptr scorer = make_shared(features); + // Sets up the grammar extractor. GrammarExtractor extractor( source_suffix_array, target_data_array, @@ -180,26 +191,30 @@ int main(int argc, char** argv) { vm["max_samples"].as(), vm["tight_phrases"].as()); - // Release extra memory used by the initial precomputation. + // Releases extra memory used by the initial precomputation. precomputation.reset(); + // Creates the grammars directory if it doesn't exist. fs::path grammar_path = vm["grammars"].as(); if (!fs::is_directory(grammar_path)) { fs::create_directory(grammar_path); } + // Reads all sentences for which we extract grammar rules (the paralellization + // is simplified if we read all sentences upfront). string sentence; vector sentences; while (getline(cin, sentence)) { sentences.push_back(sentence); } + // Extracts the grammar for each sentence and saves it to a file. vector suffixes(sentences.size()); #pragma omp parallel for schedule(dynamic) \ num_threads(vm["threads"].as()) for (size_t i = 0; i < sentences.size(); ++i) { - string delimiter = "|||", suffix; - int position = sentences[i].find(delimiter); + string suffix; + int position = sentences[i].find("|||"); if (position != sentences[i].npos) { suffix = sentences[i].substr(position); sentences[i] = sentences[i].substr(0, position); diff --git a/extractor/sampler.cc b/extractor/sampler.cc index f64a408c..d81956b5 100644 --- a/extractor/sampler.cc +++ b/extractor/sampler.cc @@ -16,6 +16,7 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location) const { vector sample; int num_subpatterns; if (location.matchings == NULL) { + // Sample suffix array range. num_subpatterns = 1; int low = location.sa_low, high = location.sa_high; double step = max(1.0, (double) (high - low) / max_samples); @@ -23,6 +24,7 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location) const { sample.push_back(suffix_array->GetSuffix(Round(i))); } } else { + // Sample vector of occurrences. num_subpatterns = location.num_subpatterns; int num_matchings = location.matchings->size() / num_subpatterns; double step = max(1.0, (double) num_matchings / max_samples); diff --git a/extractor/sampler.h b/extractor/sampler.h index cda28b10..be4aa1bb 100644 --- a/extractor/sampler.h +++ b/extractor/sampler.h @@ -10,18 +10,23 @@ namespace extractor { class PhraseLocation; class SuffixArray; +/** + * Provides uniform sampling for a PhraseLocation. + */ class Sampler { public: Sampler(shared_ptr suffix_array, int max_samples); virtual ~Sampler(); + // Samples uniformly at most max_samples phrase occurrences. virtual PhraseLocation Sample(const PhraseLocation& location) const; protected: Sampler(); private: + // Round floating point number to the nearest integer. int Round(double x) const; shared_ptr suffix_array; diff --git a/extractor/scorer.h b/extractor/scorer.h index c31db0ca..af8a3b10 100644 --- a/extractor/scorer.h +++ b/extractor/scorer.h @@ -14,14 +14,19 @@ namespace features { class FeatureContext; } // namespace features +/** + * Computes the feature scores for a source-target phrase pair. + */ class Scorer { public: Scorer(const vector >& features); virtual ~Scorer(); + // Computes the feature score for the given context. virtual vector Score(const features::FeatureContext& context) const; + // Returns the set of feature names used to score any context. virtual vector GetFeatureNames() const; protected: diff --git a/extractor/suffix_array.h b/extractor/suffix_array.h index 7a4f1110..bf731d79 100644 --- a/extractor/suffix_array.h +++ b/extractor/suffix_array.h @@ -17,18 +17,26 @@ class PhraseLocation; class SuffixArray { public: + // Creates a suffix array from a data array. SuffixArray(shared_ptr data_array); virtual ~SuffixArray(); + // Returns the size of the suffix array. virtual int GetSize() const; + // Returns the data array on top of which the suffix array is constructed. virtual shared_ptr GetData() const; + // Constructs the longest-common-prefix array using the algorithm of Kasai et + // al. (2001). virtual vector BuildLCPArray() const; + // Returns the i-th suffix. virtual int GetSuffix(int rank) const; + // Given the range in which a phrase is located and the next word, returns the + // range corresponding to the phrase extended with the next word. virtual PhraseLocation Lookup(int low, int high, const string& word, int offset) const; @@ -38,14 +46,23 @@ class SuffixArray { SuffixArray(); private: + // Constructs the suffix array using the algorithm of Larsson and Sadakane + // (1999). void BuildSuffixArray(); + // Bucket sort on the data array (used for initializing the construction of + // the suffix array.) void InitialBucketSort(vector& groups); void TernaryQuicksort(int left, int right, int step, vector& groups); + // Constructs the suffix array in log(n) steps by doubling the length of the + // suffixes at each step. void PrefixDoublingSort(vector& groups); + // Given a [low, high) range in the suffix array in which all elements have + // the first offset-1 values the same, it returns the first position where the + // offset value is greater or equal to word_id. int LookupRangeStart(int low, int high, int word_id, int offset) const; shared_ptr data_array; diff --git a/extractor/target_phrase_extractor.cc b/extractor/target_phrase_extractor.cc index 9f8bc6e2..2b8a2e4a 100644 --- a/extractor/target_phrase_extractor.cc +++ b/extractor/target_phrase_extractor.cc @@ -43,11 +43,13 @@ vector > TargetPhraseExtractor::ExtractPhrases( int target_x_low = target_phrase_low, target_x_high = target_phrase_high; if (!require_tight_phrases) { + // Extend loose target phrase to the left. while (target_x_low > 0 && target_phrase_high - target_x_low < max_rule_span && target_low[target_x_low - 1] == -1) { --target_x_low; } + // Extend loose target phrase to the right. while (target_x_high < target_sent_len && target_x_high - target_phrase_low < max_rule_span && target_low[target_x_high] == -1) { @@ -59,10 +61,12 @@ vector > TargetPhraseExtractor::ExtractPhrases( for (size_t i = 0; i < gaps.size(); ++i) { gaps[i] = target_gaps[target_gap_order[i]]; if (!require_tight_phrases) { + // Extend gap to the left. while (gaps[i].first > target_x_low && target_low[gaps[i].first - 1] == -1) { --gaps[i].first; } + // Extend gap to the right. while (gaps[i].second < target_x_high && target_low[gaps[i].second] == -1) { ++gaps[i].second; @@ -70,6 +74,9 @@ vector > TargetPhraseExtractor::ExtractPhrases( } } + // Compute the range in which each chunk may start or end. (Even indexes + // represent the range in which the chunk may start, odd indexes represent the + // range in which the chunk may end.) vector > ranges(2 * gaps.size() + 2); ranges.front() = make_pair(target_x_low, target_phrase_low); ranges.back() = make_pair(target_phrase_high, target_x_high); @@ -101,6 +108,7 @@ void TargetPhraseExtractor::GeneratePhrases( vector symbols; unordered_map target_indexes; + // Construct target phrase chunk by chunk. int target_sent_start = target_data_array->GetSentenceStart(sentence_id); for (size_t i = 0; i * 2 < subpatterns.size(); ++i) { for (size_t j = subpatterns[i * 2]; j < subpatterns[i * 2 + 1]; ++j) { @@ -115,6 +123,7 @@ void TargetPhraseExtractor::GeneratePhrases( } } + // Construct the alignment between the source and the target phrase. vector > links = alignment->GetLinks(sentence_id); vector > alignment; for (pair link: links) { @@ -133,6 +142,7 @@ void TargetPhraseExtractor::GeneratePhrases( if (index > 0) { subpatterns[index] = max(subpatterns[index], subpatterns[index - 1]); } + // Choose every possible combination of [start, end) for the current chunk. while (subpatterns[index] <= ranges[index].second) { subpatterns[index + 1] = max(subpatterns[index], ranges[index + 1].first); while (subpatterns[index + 1] <= ranges[index + 1].second) { diff --git a/extractor/target_phrase_extractor.h b/extractor/target_phrase_extractor.h index a4b54145..289bae2f 100644 --- a/extractor/target_phrase_extractor.h +++ b/extractor/target_phrase_extractor.h @@ -30,6 +30,8 @@ class TargetPhraseExtractor { virtual ~TargetPhraseExtractor(); + // Finds all the target phrases that can extracted from a span in the + // target sentence (matching the given set of target phrase gaps). virtual vector > ExtractPhrases( const vector >& target_gaps, const vector& target_low, int target_phrase_low, int target_phrase_high, @@ -39,6 +41,8 @@ class TargetPhraseExtractor { TargetPhraseExtractor(); private: + // Computes the cartesian product over the sets of possible target phrase + // chunks. void GeneratePhrases( vector >& target_phrases, const vector >& ranges, int index, diff --git a/extractor/time_util.h b/extractor/time_util.h index 45f79199..f7fd51d3 100644 --- a/extractor/time_util.h +++ b/extractor/time_util.h @@ -10,6 +10,7 @@ namespace extractor { typedef high_resolution_clock Clock; +// Computes the duration in seconds of the specified time interval. double GetDuration(const Clock::time_point& start_time, const Clock::time_point& stop_time); diff --git a/extractor/translation_table.cc b/extractor/translation_table.cc index 1852a357..45da707a 100644 --- a/extractor/translation_table.cc +++ b/extractor/translation_table.cc @@ -23,6 +23,8 @@ TranslationTable::TranslationTable(shared_ptr source_data_array, unordered_map target_links_count; unordered_map, int, PairHash> links_count; + // For each pair of aligned source target words increment their link count by + // 1. Unaligned words are paired with the NULL token. for (size_t i = 0; i < source_data_array->GetNumSentences(); ++i) { vector > links = alignment->GetLinks(i); int source_start = source_data_array->GetSentenceStart(i); @@ -40,25 +42,28 @@ TranslationTable::TranslationTable(shared_ptr source_data_array, for (pair link: links) { source_linked_words[link.first] = 1; target_linked_words[link.second] = 1; - IncreaseLinksCount(source_links_count, target_links_count, links_count, + IncrementLinksCount(source_links_count, target_links_count, links_count, source_sentence[link.first], target_sentence[link.second]); } for (size_t i = 0; i < source_sentence.size(); ++i) { if (!source_linked_words[i]) { - IncreaseLinksCount(source_links_count, target_links_count, links_count, - source_sentence[i], DataArray::NULL_WORD); + IncrementLinksCount(source_links_count, target_links_count, links_count, + source_sentence[i], DataArray::NULL_WORD); } } for (size_t i = 0; i < target_sentence.size(); ++i) { if (!target_linked_words[i]) { - IncreaseLinksCount(source_links_count, target_links_count, links_count, - DataArray::NULL_WORD, target_sentence[i]); + IncrementLinksCount(source_links_count, target_links_count, links_count, + DataArray::NULL_WORD, target_sentence[i]); } } } + // Calculating: + // p(e | f) = count(e, f) / count(f) + // p(f | e) = count(e, f) / count(e) for (pair, int> link_count: links_count) { int source_word = link_count.first.first; int target_word = link_count.first.second; @@ -72,7 +77,7 @@ TranslationTable::TranslationTable() {} TranslationTable::~TranslationTable() {} -void TranslationTable::IncreaseLinksCount( +void TranslationTable::IncrementLinksCount( unordered_map& source_links_count, unordered_map& target_links_count, unordered_map, int, PairHash>& links_count, diff --git a/extractor/translation_table.h b/extractor/translation_table.h index a7be26f5..10504d3b 100644 --- a/extractor/translation_table.h +++ b/extractor/translation_table.h @@ -18,6 +18,9 @@ typedef boost::hash > PairHash; class Alignment; class DataArray; +/** + * Bilexical table with conditional probabilities. + */ class TranslationTable { public: TranslationTable( @@ -27,9 +30,11 @@ class TranslationTable { virtual ~TranslationTable(); + // Returns p(e | f). virtual double GetTargetGivenSourceScore(const string& source_word, const string& target_word); + // Returns p(f | e). virtual double GetSourceGivenTargetScore(const string& source_word, const string& target_word); @@ -39,7 +44,8 @@ class TranslationTable { TranslationTable(); private: - void IncreaseLinksCount( + // Increment links count for the given (f, e) word pair. + void IncrementLinksCount( unordered_map& source_links_count, unordered_map& target_links_count, unordered_map, int, PairHash>& links_count, diff --git a/extractor/vocabulary.h b/extractor/vocabulary.h index 03c7dc66..c8fd9411 100644 --- a/extractor/vocabulary.h +++ b/extractor/vocabulary.h @@ -9,16 +9,33 @@ using namespace std; namespace extractor { +/** + * Data structure for mapping words to word ids. + * + * This strucure contains words located in the frequent collocations and words + * encountered during the grammar extraction time. This dictionary is + * considerably smaller than the dictionaries in the data arrays (and so is the + * query time). Note that this is the single data structure that changes state + * and needs to have thread safe read/write operations. + * + * Note: For an experiment using different vocabulary instances for each thread, + * the running time did not improve implying that the critical regions do not + * cause bottlenecks. + */ class Vocabulary { public: virtual ~Vocabulary(); + // Returns the word id for the given word. virtual int GetTerminalIndex(const string& word); + // Returns the id for a nonterminal located at the given position in a phrase. int GetNonterminalIndex(int position); + // Checks if a symbol is a nonterminal. bool IsTerminal(int symbol); + // Returns the word corresponding to the given word id. virtual string GetTerminalValue(int symbol); private: -- cgit v1.2.3