diff options
Diffstat (limited to 'extractor')
127 files changed, 8378 insertions, 0 deletions
diff --git a/extractor/Makefile.am b/extractor/Makefile.am new file mode 100644 index 00000000..8f76dea5 --- /dev/null +++ b/extractor/Makefile.am @@ -0,0 +1,175 @@ +bin_PROGRAMS = compile run_extractor + +noinst_PROGRAMS = \ + alignment_test \ + binary_search_merger_test \ + data_array_test \ + fast_intersector_test \ + feature_count_source_target_test \ + feature_is_source_singleton_test \ + feature_is_source_target_singleton_test \ + feature_max_lex_source_given_target_test \ + feature_max_lex_target_given_source_test \ + feature_sample_source_count_test \ + feature_target_given_source_coherent_test \ + grammar_extractor_test \ + intersector_test \ + linear_merger_test \ + matching_comparator_test \ + matching_test \ + matchings_finder_test \ + phrase_test \ + precomputation_test \ + rule_extractor_helper_test \ + rule_extractor_test \ + rule_factory_test \ + sampler_test \ + scorer_test \ + suffix_array_test \ + target_phrase_extractor_test \ + translation_table_test \ + veb_test + +TESTS = alignment_test \ + binary_search_merger_test \ + data_array_test \ + fast_intersector_test \ + feature_count_source_target_test \ + feature_is_source_singleton_test \ + feature_is_source_target_singleton_test \ + feature_max_lex_source_given_target_test \ + feature_max_lex_target_given_source_test \ + feature_sample_source_count_test \ + feature_target_given_source_coherent_test \ + grammar_extractor_test \ + intersector_test \ + linear_merger_test \ + matching_comparator_test \ + matching_test \ + matchings_finder_test \ + phrase_test \ + precomputation_test \ + rule_extractor_helper_test \ + rule_extractor_test \ + rule_factory_test \ + sampler_test \ + scorer_test \ + suffix_array_test \ + target_phrase_extractor_test \ + translation_table_test \ + veb_test + +alignment_test_SOURCES = alignment_test.cc +alignment_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +binary_search_merger_test_SOURCES = binary_search_merger_test.cc +binary_search_merger_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +data_array_test_SOURCES = data_array_test.cc +data_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +fast_intersector_test_SOURCES = fast_intersector_test.cc +fast_intersector_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +feature_count_source_target_test_SOURCES = features/count_source_target_test.cc +feature_count_source_target_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +feature_is_source_singleton_test_SOURCES = features/is_source_singleton_test.cc +feature_is_source_singleton_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +feature_is_source_target_singleton_test_SOURCES = features/is_source_target_singleton_test.cc +feature_is_source_target_singleton_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +feature_max_lex_source_given_target_test_SOURCES = features/max_lex_source_given_target_test.cc +feature_max_lex_source_given_target_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +feature_max_lex_target_given_source_test_SOURCES = features/max_lex_target_given_source_test.cc +feature_max_lex_target_given_source_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +feature_sample_source_count_test_SOURCES = features/sample_source_count_test.cc +feature_sample_source_count_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +feature_target_given_source_coherent_test_SOURCES = features/target_given_source_coherent_test.cc +feature_target_given_source_coherent_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +grammar_extractor_test_SOURCES = grammar_extractor_test.cc +grammar_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +intersector_test_SOURCES = intersector_test.cc +intersector_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +linear_merger_test_SOURCES = linear_merger_test.cc +linear_merger_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +matching_comparator_test_SOURCES = matching_comparator_test.cc +matching_comparator_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +matching_test_SOURCES = matching_test.cc +matching_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +matchings_finder_test_SOURCES = matchings_finder_test.cc +matchings_finder_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +phrase_test_SOURCES = phrase_test.cc +phrase_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +precomputation_test_SOURCES = precomputation_test.cc +precomputation_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +rule_extractor_helper_test_SOURCES = rule_extractor_helper_test.cc +rule_extractor_helper_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +rule_extractor_test_SOURCES = rule_extractor_test.cc +rule_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +rule_factory_test_SOURCES = rule_factory_test.cc +rule_factory_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +sampler_test_SOURCES = sampler_test.cc +sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +scorer_test_SOURCES = scorer_test.cc +scorer_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +suffix_array_test_SOURCES = suffix_array_test.cc +suffix_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +target_phrase_extractor_test_SOURCES = target_phrase_extractor_test.cc +target_phrase_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +translation_table_test_SOURCES = translation_table_test.cc +translation_table_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +veb_test_SOURCES = veb_test.cc +veb_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a + +noinst_LIBRARIES = libextractor.a libcompile.a + +compile_SOURCES = compile.cc +compile_LDADD = libcompile.a +run_extractor_SOURCES = run_extractor.cc +run_extractor_LDADD = libextractor.a + +libcompile_a_SOURCES = \ + alignment.cc \ + data_array.cc \ + phrase_location.cc \ + precomputation.cc \ + suffix_array.cc \ + time_util.cc \ + translation_table.cc + +libextractor_a_SOURCES = \ + alignment.cc \ + binary_search_merger.cc \ + data_array.cc \ + fast_intersector.cc \ + features/count_source_target.cc \ + features/feature.cc \ + features/is_source_singleton.cc \ + features/is_source_target_singleton.cc \ + features/max_lex_source_given_target.cc \ + features/max_lex_target_given_source.cc \ + features/sample_source_count.cc \ + features/target_given_source_coherent.cc \ + grammar.cc \ + grammar_extractor.cc \ + matching.cc \ + matching_comparator.cc \ + matchings_finder.cc \ + intersector.cc \ + linear_merger.cc \ + matchings_trie.cc \ + phrase.cc \ + phrase_builder.cc \ + phrase_location.cc \ + precomputation.cc \ + rule.cc \ + rule_extractor.cc \ + rule_extractor_helper.cc \ + rule_factory.cc \ + sampler.cc \ + scorer.cc \ + suffix_array.cc \ + target_phrase_extractor.cc \ + time_util.cc \ + translation_table.cc \ + veb.cc \ + veb_bitset.cc \ + veb_tree.cc \ + vocabulary.cc + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare -std=c++0x $(GTEST_CPPFLAGS) $(GMOCK_CPPFLAGS) diff --git a/extractor/alignment.cc b/extractor/alignment.cc new file mode 100644 index 00000000..ff39d484 --- /dev/null +++ b/extractor/alignment.cc @@ -0,0 +1,51 @@ +#include "alignment.h" + +#include <fstream> +#include <sstream> +#include <string> +#include <fcntl.h> +#include <unistd.h> +#include <vector> + +#include <boost/algorithm/string.hpp> +#include <boost/filesystem.hpp> + +namespace fs = boost::filesystem; +using namespace std; + +Alignment::Alignment(const string& filename) { + ifstream infile(filename.c_str()); + string line; + while (getline(infile, line)) { + vector<string> items; + boost::split(items, line, boost::is_any_of(" -")); + vector<pair<int, int> > alignment; + alignment.reserve(items.size() / 2); + for (size_t i = 0; i < items.size(); i += 2) { + alignment.push_back(make_pair(stoi(items[i]), stoi(items[i + 1]))); + } + alignments.push_back(alignment); + } + // Note: shrink_to_fit does nothing for vector<vector<string> > on g++ 4.6.3, + // but let's hope that the bug will be fixed in a newer version. + alignments.shrink_to_fit(); +} + +Alignment::Alignment() {} + +Alignment::~Alignment() {} + +vector<pair<int, int> > Alignment::GetLinks(int sentence_index) const { + return alignments[sentence_index]; +} + +void Alignment::WriteBinary(const fs::path& filepath) { + FILE* file = fopen(filepath.string().c_str(), "w"); + int size = alignments.size(); + fwrite(&size, sizeof(int), 1, file); + for (vector<pair<int, int> > alignment: alignments) { + size = alignment.size(); + fwrite(&size, sizeof(int), 1, file); + fwrite(alignment.data(), sizeof(pair<int, int>), size, file); + } +} diff --git a/extractor/alignment.h b/extractor/alignment.h new file mode 100644 index 00000000..f7e79585 --- /dev/null +++ b/extractor/alignment.h @@ -0,0 +1,29 @@ +#ifndef _ALIGNMENT_H_ +#define _ALIGNMENT_H_ + +#include <string> +#include <vector> + +#include <boost/filesystem.hpp> + +namespace fs = boost::filesystem; +using namespace std; + +class Alignment { + public: + Alignment(const string& filename); + + virtual vector<pair<int, int> > GetLinks(int sentence_index) const; + + void WriteBinary(const fs::path& filepath); + + virtual ~Alignment(); + + protected: + Alignment(); + + private: + vector<vector<pair<int, int> > > alignments; +}; + +#endif diff --git a/extractor/alignment_test.cc b/extractor/alignment_test.cc new file mode 100644 index 00000000..1bc51a56 --- /dev/null +++ b/extractor/alignment_test.cc @@ -0,0 +1,31 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> + +#include "alignment.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class AlignmentTest : public Test { + protected: + virtual void SetUp() { + alignment = make_shared<Alignment>("sample_alignment.txt"); + } + + shared_ptr<Alignment> alignment; +}; + +TEST_F(AlignmentTest, TestGetLinks) { + vector<pair<int, int> > expected_links = { + make_pair(0, 0), make_pair(1, 1), make_pair(2, 2) + }; + EXPECT_EQ(expected_links, alignment->GetLinks(0)); + expected_links = {make_pair(1, 0), make_pair(2, 1)}; + EXPECT_EQ(expected_links, alignment->GetLinks(1)); +} + +} // namespace diff --git a/extractor/binary_search_merger.cc b/extractor/binary_search_merger.cc new file mode 100644 index 00000000..c1b86a77 --- /dev/null +++ b/extractor/binary_search_merger.cc @@ -0,0 +1,251 @@ +#include "binary_search_merger.h" + +#include "data_array.h" +#include "linear_merger.h" +#include "matching.h" +#include "matching_comparator.h" +#include "phrase.h" +#include "vocabulary.h" + +double BinarySearchMerger::BAEZA_YATES_FACTOR = 1.0; + +BinarySearchMerger::BinarySearchMerger( + shared_ptr<Vocabulary> vocabulary, + shared_ptr<LinearMerger> linear_merger, + shared_ptr<DataArray> data_array, + shared_ptr<MatchingComparator> comparator, + bool force_binary_search_merge) : + vocabulary(vocabulary), linear_merger(linear_merger), + data_array(data_array), comparator(comparator), + force_binary_search_merge(force_binary_search_merge) {} + +BinarySearchMerger::BinarySearchMerger() {} + +BinarySearchMerger::~BinarySearchMerger() {} + +void BinarySearchMerger::Merge( + vector<int>& locations, const Phrase& phrase, const Phrase& suffix, + const vector<int>::iterator& prefix_start, + const vector<int>::iterator& prefix_end, + const vector<int>::iterator& suffix_start, + const vector<int>::iterator& suffix_end, + int prefix_subpatterns, int suffix_subpatterns) const { + if (IsIntersectionVoid(prefix_start, prefix_end, suffix_start, suffix_end, + prefix_subpatterns, suffix_subpatterns, suffix)) { + return; + } + + int prefix_set_size = prefix_end - prefix_start; + int suffix_set_size = suffix_end - suffix_start; + if (ShouldUseLinearMerge(prefix_set_size, suffix_set_size)) { + linear_merger->Merge(locations, phrase, suffix, prefix_start, prefix_end, + suffix_start, suffix_end, prefix_subpatterns, suffix_subpatterns); + return; + } + + vector<int>::iterator low, high, prefix_low, prefix_high, suffix_mid; + if (prefix_set_size > suffix_set_size) { + // Binary search on the prefix set. + suffix_mid = GetMiddle(suffix_start, suffix_end, suffix_subpatterns); + low = prefix_start, high = prefix_end; + while (low < high) { + vector<int>::iterator prefix_mid = + GetMiddle(low, high, prefix_subpatterns); + + GetComparableMatchings(prefix_start, prefix_end, prefix_mid, + prefix_subpatterns, prefix_low, prefix_high); + int comparison = CompareMatchingsSet(prefix_low, prefix_high, suffix_mid, + prefix_subpatterns, suffix_subpatterns, suffix); + if (comparison == 0) { + break; + } else if (comparison < 0) { + low = prefix_mid + prefix_subpatterns; + } else { + high = prefix_mid; + } + } + } else { + // Binary search on the suffix set. + vector<int>::iterator prefix_mid = + GetMiddle(prefix_start, prefix_end, prefix_subpatterns); + + GetComparableMatchings(prefix_start, prefix_end, prefix_mid, + prefix_subpatterns, prefix_low, prefix_high); + low = suffix_start, high = suffix_end; + while (low < high) { + suffix_mid = GetMiddle(low, high, suffix_subpatterns); + + int comparison = CompareMatchingsSet(prefix_low, prefix_high, suffix_mid, + prefix_subpatterns, suffix_subpatterns, suffix); + if (comparison == 0) { + break; + } else if (comparison > 0) { + low = suffix_mid + suffix_subpatterns; + } else { + high = suffix_mid; + } + } + } + + vector<int> result; + int last_chunk_len = suffix.GetChunkLen(suffix.Arity()); + bool offset = !vocabulary->IsTerminal(suffix.GetSymbol(0)); + vector<int>::iterator suffix_low, suffix_high; + if (low < high) { + // We found a group of prefixes with the same starting position that give + // different results when compared to the found suffix. + // Find all matching suffixes for the previously found set of prefixes. + suffix_low = suffix_mid; + suffix_high = suffix_mid + suffix_subpatterns; + for (auto i = prefix_low; i != prefix_high; i += prefix_subpatterns) { + Matching left(i, prefix_subpatterns, data_array->GetSentenceId(*i)); + while (suffix_low != suffix_start) { + Matching right(suffix_low - suffix_subpatterns, suffix_subpatterns, + data_array->GetSentenceId(*(suffix_low - suffix_subpatterns))); + if (comparator->Compare(left, right, last_chunk_len, offset) <= 0) { + suffix_low -= suffix_subpatterns; + } else { + break; + } + } + + for (auto j = suffix_low; j != suffix_end; j += suffix_subpatterns) { + Matching right(j, suffix_subpatterns, data_array->GetSentenceId(*j)); + int comparison = comparator->Compare(left, right, last_chunk_len, + offset); + if (comparison == 0) { + vector<int> merged = left.Merge(right, phrase.Arity() + 1); + result.insert(result.end(), merged.begin(), merged.end()); + } else if (comparison < 0) { + break; + } + suffix_high = max(suffix_high, j + suffix_subpatterns); + } + } + + swap(suffix_low, suffix_high); + } else if (prefix_set_size > suffix_set_size) { + // We did the binary search on the prefix set. + suffix_low = suffix_mid; + suffix_high = suffix_mid + suffix_subpatterns; + if (CompareMatchingsSet(prefix_low, prefix_high, suffix_mid, + prefix_subpatterns, suffix_subpatterns, suffix) < 0) { + prefix_low = prefix_high; + } else { + prefix_high = prefix_low; + } + } else { + // We did the binary search on the suffix set. + if (CompareMatchingsSet(prefix_low, prefix_high, suffix_mid, + prefix_subpatterns, suffix_subpatterns, suffix) < 0) { + suffix_low = suffix_mid; + suffix_high = suffix_mid; + } else { + suffix_low = suffix_mid + suffix_subpatterns; + suffix_high = suffix_mid + suffix_subpatterns; + } + } + + Merge(locations, phrase, suffix, prefix_start, prefix_low, suffix_start, + suffix_low, prefix_subpatterns, suffix_subpatterns); + locations.insert(locations.end(), result.begin(), result.end()); + Merge(locations, phrase, suffix, prefix_high, prefix_end, suffix_high, + suffix_end, prefix_subpatterns, suffix_subpatterns); +} + +bool BinarySearchMerger::IsIntersectionVoid( + vector<int>::iterator prefix_start, vector<int>::iterator prefix_end, + vector<int>::iterator suffix_start, vector<int>::iterator suffix_end, + int prefix_subpatterns, int suffix_subpatterns, + const Phrase& suffix) const { + // Is any of the sets empty? + if (prefix_start >= prefix_end || suffix_start >= suffix_end) { + return true; + } + + int last_chunk_len = suffix.GetChunkLen(suffix.Arity()); + bool offset = !vocabulary->IsTerminal(suffix.GetSymbol(0)); + // Is the first value from the first set larger than the last value in the + // second set? + Matching left(prefix_start, prefix_subpatterns, + data_array->GetSentenceId(*prefix_start)); + Matching right(suffix_end - suffix_subpatterns, suffix_subpatterns, + data_array->GetSentenceId(*(suffix_end - suffix_subpatterns))); + if (comparator->Compare(left, right, last_chunk_len, offset) > 0) { + return true; + } + + // Is the last value from the first set smaller than the first value in the + // second set? + left = Matching(prefix_end - prefix_subpatterns, prefix_subpatterns, + data_array->GetSentenceId(*(prefix_end - prefix_subpatterns))); + right = Matching(suffix_start, suffix_subpatterns, + data_array->GetSentenceId(*suffix_start)); + if (comparator->Compare(left, right, last_chunk_len, offset) < 0) { + return true; + } + + return false; +} + +bool BinarySearchMerger::ShouldUseLinearMerge( + int prefix_set_size, int suffix_set_size) const { + if (force_binary_search_merge) { + return false; + } + + int min_size = min(prefix_set_size, suffix_set_size); + int max_size = max(prefix_set_size, suffix_set_size); + + return BAEZA_YATES_FACTOR * min_size * log2(max_size) > max_size; +} + +vector<int>::iterator BinarySearchMerger::GetMiddle( + vector<int>::iterator low, vector<int>::iterator high, + int num_subpatterns) const { + return low + (((high - low) / num_subpatterns) / 2) * num_subpatterns; +} + +void BinarySearchMerger::GetComparableMatchings( + const vector<int>::iterator& prefix_start, + const vector<int>::iterator& prefix_end, + const vector<int>::iterator& prefix_mid, + int num_subpatterns, + vector<int>::iterator& prefix_low, + vector<int>::iterator& prefix_high) const { + prefix_low = prefix_mid; + while (prefix_low != prefix_start + && *prefix_mid == *(prefix_low - num_subpatterns)) { + prefix_low -= num_subpatterns; + } + prefix_high = prefix_mid + num_subpatterns; + while (prefix_high != prefix_end + && *prefix_mid == *prefix_high) { + prefix_high += num_subpatterns; + } +} + +int BinarySearchMerger::CompareMatchingsSet( + const vector<int>::iterator& prefix_start, + const vector<int>::iterator& prefix_end, + const vector<int>::iterator& suffix_mid, + int prefix_subpatterns, + int suffix_subpatterns, + const Phrase& suffix) const { + int result = 0; + int last_chunk_len = suffix.GetChunkLen(suffix.Arity()); + bool offset = !vocabulary->IsTerminal(suffix.GetSymbol(0)); + + Matching right(suffix_mid, suffix_subpatterns, + data_array->GetSentenceId(*suffix_mid)); + for (auto i = prefix_start; i != prefix_end; i += prefix_subpatterns) { + Matching left(i, prefix_subpatterns, data_array->GetSentenceId(*i)); + int comparison = comparator->Compare(left, right, last_chunk_len, offset); + if (i == prefix_start) { + result = comparison; + } else if (comparison != result) { + return 0; + } + } + return result; +} diff --git a/extractor/binary_search_merger.h b/extractor/binary_search_merger.h new file mode 100644 index 00000000..c887e012 --- /dev/null +++ b/extractor/binary_search_merger.h @@ -0,0 +1,75 @@ +#ifndef _BINARY_SEARCH_MERGER_H_ +#define _BINARY_SEARCH_MERGER_H_ + +#include <memory> +#include <vector> + +using namespace std; + +class DataArray; +class LinearMerger; +class MatchingComparator; +class Phrase; +class Vocabulary; + +class BinarySearchMerger { + public: + BinarySearchMerger(shared_ptr<Vocabulary> vocabulary, + shared_ptr<LinearMerger> linear_merger, + shared_ptr<DataArray> data_array, + shared_ptr<MatchingComparator> comparator, + bool force_binary_search_merge = false); + + virtual ~BinarySearchMerger(); + + virtual void Merge( + vector<int>& locations, const Phrase& phrase, const Phrase& suffix, + const vector<int>::iterator& prefix_start, + const vector<int>::iterator& prefix_end, + const vector<int>::iterator& suffix_start, + const vector<int>::iterator& suffix_end, + int prefix_subpatterns, int suffix_subpatterns) const; + + static double BAEZA_YATES_FACTOR; + + protected: + BinarySearchMerger(); + + private: + bool IsIntersectionVoid( + vector<int>::iterator prefix_start, vector<int>::iterator prefix_end, + vector<int>::iterator suffix_start, vector<int>::iterator suffix_end, + int prefix_subpatterns, int suffix_subpatterns, + const Phrase& suffix) const; + + bool ShouldUseLinearMerge(int prefix_set_size, int suffix_set_size) const; + + vector<int>::iterator GetMiddle(vector<int>::iterator low, + vector<int>::iterator high, + int num_subpatterns) const; + + void GetComparableMatchings( + const vector<int>::iterator& prefix_start, + const vector<int>::iterator& prefix_end, + const vector<int>::iterator& prefix_mid, + int num_subpatterns, + vector<int>::iterator& prefix_low, + vector<int>::iterator& prefix_high) const; + + int CompareMatchingsSet( + const vector<int>::iterator& prefix_low, + const vector<int>::iterator& prefix_high, + const vector<int>::iterator& suffix_mid, + int prefix_subpatterns, + int suffix_subpatterns, + const Phrase& suffix) const; + + shared_ptr<Vocabulary> vocabulary; + shared_ptr<LinearMerger> linear_merger; + shared_ptr<DataArray> data_array; + shared_ptr<MatchingComparator> comparator; + // Should be true only for testing. + bool force_binary_search_merge; +}; + +#endif diff --git a/extractor/binary_search_merger_test.cc b/extractor/binary_search_merger_test.cc new file mode 100644 index 00000000..b1baa62f --- /dev/null +++ b/extractor/binary_search_merger_test.cc @@ -0,0 +1,157 @@ +#include <gtest/gtest.h> + +#include <memory> + +#include "binary_search_merger.h" +#include "matching_comparator.h" +#include "mocks/mock_data_array.h" +#include "mocks/mock_vocabulary.h" +#include "mocks/mock_linear_merger.h" +#include "phrase.h" +#include "phrase_location.h" +#include "phrase_builder.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class BinarySearchMergerTest : public Test { + protected: + virtual void SetUp() { + shared_ptr<MockVocabulary> vocabulary = make_shared<MockVocabulary>(); + EXPECT_CALL(*vocabulary, GetTerminalValue(_)) + .WillRepeatedly(Return("word")); + + shared_ptr<MockDataArray> data_array = make_shared<MockDataArray>(); + EXPECT_CALL(*data_array, GetSentenceId(_)) + .WillRepeatedly(Return(1)); + + shared_ptr<MatchingComparator> comparator = + make_shared<MatchingComparator>(1, 20); + + phrase_builder = make_shared<PhraseBuilder>(vocabulary); + + // We are going to force the binary_search_merger to do all the work, so we + // need to check that the linear_merger never gets called. + shared_ptr<MockLinearMerger> linear_merger = + make_shared<MockLinearMerger>(); + EXPECT_CALL(*linear_merger, Merge(_, _, _, _, _, _, _, _, _)).Times(0); + + binary_search_merger = make_shared<BinarySearchMerger>( + vocabulary, linear_merger, data_array, comparator, true); + } + + shared_ptr<BinarySearchMerger> binary_search_merger; + shared_ptr<PhraseBuilder> phrase_builder; +}; + +TEST_F(BinarySearchMergerTest, aXbTest) { + vector<int> locations; + // Encoding for him X it (see Adam's dissertation). + vector<int> symbols{1, -1, 2}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> suffix_symbols{-1, 2}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + + vector<int> prefix_locs{2, 6, 10, 15}; + vector<int> suffix_locs{0, 4, 8, 13}; + + binary_search_merger->Merge(locations, phrase, suffix, prefix_locs.begin(), + prefix_locs.end(), suffix_locs.begin(), suffix_locs.end(), 1, 1); + + vector<int> expected_locations{2, 4, 2, 8, 2, 13, 6, 8, 6, 13, 10, 13}; + EXPECT_EQ(expected_locations, locations); +} + +TEST_F(BinarySearchMergerTest, aXbXcTest) { + vector<int> locations; + // Encoding for it X him X it (see Adam's dissertation). + vector<int> symbols{1, -1, 2, -2, 1}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> suffix_symbols{-1, 2, -2, 1}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + + vector<int> prefix_locs{0, 2, 0, 6, 0, 10, 4, 6, 4, 10, 4, 15, 8, 10, 8, 15, + 13, 15}; + vector<int> suffix_locs{2, 4, 2, 8, 2, 13, 6, 8, 6, 13, 10, 13}; + + binary_search_merger->Merge(locations, phrase, suffix, prefix_locs.begin(), + prefix_locs.end(), suffix_locs.begin(), suffix_locs.end(), 2, 2); + + vector<int> expected_locs{0, 2, 4, 0, 2, 8, 0, 2, 13, 0, 6, 8, 0, 6, 13, 0, + 10, 13, 4, 6, 8, 4, 6, 13, 4, 10, 13, 8, 10, 13}; + EXPECT_EQ(expected_locs, locations); +} + +TEST_F(BinarySearchMergerTest, abXcXdTest) { + // Sentence: Anna has many many nuts and sour apples and juicy apples. + // Phrase: Anna has X and X apples. + vector<int> locations; + vector<int> symbols{1, 2, -1, 3, -2, 4}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> suffix_symbols{2, -1, 3, -2, 4}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + + vector<int> prefix_locs{1, 6, 1, 9}; + vector<int> suffix_locs{2, 6, 8, 2, 6, 11, 2, 9, 11}; + + binary_search_merger->Merge(locations, phrase, suffix, prefix_locs.begin(), + prefix_locs.end(), suffix_locs.begin(), suffix_locs.end(), 2, 3); + + vector<int> expected_locs{1, 6, 8, 1, 6, 11, 1, 9, 11}; + EXPECT_EQ(expected_locs, locations); +} + +TEST_F(BinarySearchMergerTest, LargeTest) { + vector<int> locations; + vector<int> symbols{1, -1, 2}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> suffix_symbols{-1, 2}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + + vector<int> prefix_locs; + for (int i = 0; i < 100; ++i) { + prefix_locs.push_back(i * 20 + 1); + } + vector<int> suffix_locs; + for (int i = 0; i < 100; ++i) { + suffix_locs.push_back(i * 20 + 5); + suffix_locs.push_back(i * 20 + 13); + } + + binary_search_merger->Merge(locations, phrase, suffix, prefix_locs.begin(), + prefix_locs.end(), suffix_locs.begin(), suffix_locs.end(), 1, 1); + + EXPECT_EQ(400, locations.size()); + for (int i = 0; i < 100; ++i) { + EXPECT_EQ(i * 20 + 1, locations[4 * i]); + EXPECT_EQ(i * 20 + 5, locations[4 * i + 1]); + EXPECT_EQ(i * 20 + 1, locations[4 * i + 2]); + EXPECT_EQ(i * 20 + 13, locations[4 * i + 3]); + } +} + +TEST_F(BinarySearchMergerTest, EmptyResultTest) { + vector<int> locations; + vector<int> symbols{1, -1, 2}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> suffix_symbols{-1, 2}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + + vector<int> prefix_locs; + for (int i = 0; i < 100; ++i) { + prefix_locs.push_back(i * 200 + 1); + } + vector<int> suffix_locs; + for (int i = 0; i < 100; ++i) { + suffix_locs.push_back(i * 200 + 101); + } + + binary_search_merger->Merge(locations, phrase, suffix, prefix_locs.begin(), + prefix_locs.end(), suffix_locs.begin(), suffix_locs.end(), 1, 1); + + EXPECT_EQ(0, locations.size()); +} + +} // namespace diff --git a/extractor/compile.cc b/extractor/compile.cc new file mode 100644 index 00000000..f5cd41f4 --- /dev/null +++ b/extractor/compile.cc @@ -0,0 +1,99 @@ +#include <iostream> +#include <string> + +#include <boost/filesystem.hpp> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "alignment.h" +#include "data_array.h" +#include "precomputation.h" +#include "suffix_array.h" +#include "translation_table.h" + +namespace fs = boost::filesystem; +namespace po = boost::program_options; +using namespace std; + +int main(int argc, char** argv) { + po::options_description desc("Command line options"); + desc.add_options() + ("help,h", "Show available options") + ("source,f", po::value<string>(), "Source language corpus") + ("target,e", po::value<string>(), "Target language corpus") + ("bitext,b", po::value<string>(), "Parallel text (source ||| target)") + ("alignment,a", po::value<string>()->required(), "Bitext word alignment") + ("output,o", po::value<string>()->required(), "Output path") + ("frequent", po::value<int>()->default_value(100), + "Number of precomputed frequent patterns") + ("super_frequent", po::value<int>()->default_value(10), + "Number of precomputed super frequent patterns") + ("max_rule_span,s", po::value<int>()->default_value(15), + "Maximum rule span") + ("max_rule_symbols,l", po::value<int>()->default_value(5), + "Maximum number of symbols (terminals + nontermals) in a rule") + ("min_gap_size,g", po::value<int>()->default_value(1), "Minimum gap size") + ("max_phrase_len,p", po::value<int>()->default_value(4), + "Maximum frequent phrase length") + ("min_frequency", po::value<int>()->default_value(1000), + "Minimum number of occurences for a pharse to be considered frequent"); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + + // Check for help argument before notify, so we don't need to pass in the + // required parameters. + if (vm.count("help")) { + cout << desc << endl; + return 0; + } + + po::notify(vm); + + if (!((vm.count("source") && vm.count("target")) || vm.count("bitext"))) { + cerr << "A paralel corpus is required. " + << "Use -f (source) with -e (target) or -b (bitext)." + << endl; + return 1; + } + + fs::path output_dir(vm["output"].as<string>().c_str()); + if (!fs::exists(output_dir)) { + fs::create_directory(output_dir); + } + + shared_ptr<DataArray> source_data_array, target_data_array; + if (vm.count("bitext")) { + source_data_array = make_shared<DataArray>( + vm["bitext"].as<string>(), SOURCE); + target_data_array = make_shared<DataArray>( + vm["bitext"].as<string>(), TARGET); + } else { + source_data_array = make_shared<DataArray>(vm["source"].as<string>()); + target_data_array = make_shared<DataArray>(vm["target"].as<string>()); + } + shared_ptr<SuffixArray> source_suffix_array = + make_shared<SuffixArray>(source_data_array); + source_suffix_array->WriteBinary(output_dir / fs::path("f.bin")); + target_data_array->WriteBinary(output_dir / fs::path("e.bin")); + + shared_ptr<Alignment> alignment = + make_shared<Alignment>(vm["alignment"].as<string>()); + alignment->WriteBinary(output_dir / fs::path("a.bin")); + + Precomputation precomputation( + source_suffix_array, + vm["frequent"].as<int>(), + vm["super_frequent"].as<int>(), + vm["max_rule_span"].as<int>(), + vm["max_rule_symbols"].as<int>(), + vm["min_gap_size"].as<int>(), + vm["max_phrase_len"].as<int>(), + vm["min_frequency"].as<int>()); + precomputation.WriteBinary(output_dir / fs::path("precompute.bin")); + + TranslationTable table(source_data_array, target_data_array, alignment); + table.WriteBinary(output_dir / fs::path("lex.bin")); + + return 0; +} diff --git a/extractor/data_array.cc b/extractor/data_array.cc new file mode 100644 index 00000000..cd430c69 --- /dev/null +++ b/extractor/data_array.cc @@ -0,0 +1,156 @@ +#include "data_array.h" + +#include <fstream> +#include <iostream> +#include <sstream> +#include <string> + +#include <boost/filesystem.hpp> + +namespace fs = boost::filesystem; +using namespace std; + +int DataArray::NULL_WORD = 0; +int DataArray::END_OF_LINE = 1; +string DataArray::NULL_WORD_STR = "__NULL__"; +string DataArray::END_OF_LINE_STR = "__END_OF_LINE__"; + +DataArray::DataArray() { + InitializeDataArray(); +} + +DataArray::DataArray(const string& filename) { + InitializeDataArray(); + ifstream infile(filename.c_str()); + vector<string> lines; + string line; + while (getline(infile, line)) { + lines.push_back(line); + } + CreateDataArray(lines); +} + +DataArray::DataArray(const string& filename, const Side& side) { + InitializeDataArray(); + ifstream infile(filename.c_str()); + vector<string> lines; + string line, delimiter = "|||"; + while (getline(infile, line)) { + int position = line.find(delimiter); + if (side == SOURCE) { + lines.push_back(line.substr(0, position)); + } else { + lines.push_back(line.substr(position + delimiter.size())); + } + } + CreateDataArray(lines); +} + +void DataArray::InitializeDataArray() { + word2id[NULL_WORD_STR] = NULL_WORD; + id2word.push_back(NULL_WORD_STR); + word2id[END_OF_LINE_STR] = END_OF_LINE; + id2word.push_back(END_OF_LINE_STR); +} + +void DataArray::CreateDataArray(const vector<string>& lines) { + for (size_t i = 0; i < lines.size(); ++i) { + sentence_start.push_back(data.size()); + + istringstream iss(lines[i]); + string word; + while (iss >> word) { + if (word2id.count(word) == 0) { + word2id[word] = id2word.size(); + id2word.push_back(word); + } + data.push_back(word2id[word]); + sentence_id.push_back(i); + } + data.push_back(END_OF_LINE); + sentence_id.push_back(i); + } + sentence_start.push_back(data.size()); + + data.shrink_to_fit(); + sentence_id.shrink_to_fit(); + sentence_start.shrink_to_fit(); +} + +DataArray::~DataArray() {} + +const vector<int>& DataArray::GetData() const { + return data; +} + +int DataArray::AtIndex(int index) const { + return data[index]; +} + +string DataArray::GetWordAtIndex(int index) const { + return id2word[data[index]]; +} + +int DataArray::GetSize() const { + return data.size(); +} + +int DataArray::GetVocabularySize() const { + return id2word.size(); +} + +int DataArray::GetNumSentences() const { + return sentence_start.size() - 1; +} + +int DataArray::GetSentenceStart(int position) const { + return sentence_start[position]; +} + +int DataArray::GetSentenceLength(int sentence_id) const { + // Ignore end of line markers. + return sentence_start[sentence_id + 1] - sentence_start[sentence_id] - 1; +} + +int DataArray::GetSentenceId(int position) const { + return sentence_id[position]; +} + +void DataArray::WriteBinary(const fs::path& filepath) const { + WriteBinary(fopen(filepath.string().c_str(), "w")); +} + +void DataArray::WriteBinary(FILE* file) const { + int size = id2word.size(); + fwrite(&size, sizeof(int), 1, file); + for (string word: id2word) { + size = word.size(); + fwrite(&size, sizeof(int), 1, file); + fwrite(word.data(), sizeof(char), size, file); + } + + size = data.size(); + fwrite(&size, sizeof(int), 1, file); + fwrite(data.data(), sizeof(int), size, file); + + size = sentence_id.size(); + fwrite(&size, sizeof(int), 1, file); + fwrite(sentence_id.data(), sizeof(int), size, file); + + size = sentence_start.size(); + fwrite(&size, sizeof(int), 1, file); + fwrite(sentence_start.data(), sizeof(int), 1, file); +} + +bool DataArray::HasWord(const string& word) const { + return word2id.count(word); +} + +int DataArray::GetWordId(const string& word) const { + auto result = word2id.find(word); + return result == word2id.end() ? -1 : result->second; +} + +string DataArray::GetWord(int word_id) const { + return id2word[word_id]; +} diff --git a/extractor/data_array.h b/extractor/data_array.h new file mode 100644 index 00000000..7c120b3c --- /dev/null +++ b/extractor/data_array.h @@ -0,0 +1,76 @@ +#ifndef _DATA_ARRAY_H_ +#define _DATA_ARRAY_H_ + +#include <string> +#include <unordered_map> +#include <vector> + +#include <boost/filesystem.hpp> + +namespace fs = boost::filesystem; +using namespace std; + +enum Side { + SOURCE, + TARGET +}; + +class DataArray { + public: + static int NULL_WORD; + static int END_OF_LINE; + static string NULL_WORD_STR; + static string END_OF_LINE_STR; + + DataArray(const string& filename); + + DataArray(const string& filename, const Side& side); + + virtual ~DataArray(); + + virtual const vector<int>& GetData() const; + + virtual int AtIndex(int index) const; + + virtual string GetWordAtIndex(int index) const; + + virtual int GetSize() const; + + virtual int GetVocabularySize() const; + + virtual bool HasWord(const string& word) const; + + virtual int GetWordId(const string& word) const; + + virtual string GetWord(int word_id) const; + + virtual int GetNumSentences() const; + + virtual int GetSentenceStart(int position) const; + + //TODO(pauldb): Add unit tests. + virtual int GetSentenceLength(int sentence_id) const; + + virtual int GetSentenceId(int position) const; + + void WriteBinary(const fs::path& filepath) const; + + void WriteBinary(FILE* file) const; + + protected: + DataArray(); + + private: + void InitializeDataArray(); + void CreateDataArray(const vector<string>& lines); + + unordered_map<string, int> word2id; + vector<string> id2word; + vector<int> data; + // TODO(pauldb): We only need sentence_id for the source language. Maybe we + // can save some memory here. + vector<int> sentence_id; + vector<int> sentence_start; +}; + +#endif diff --git a/extractor/data_array_test.cc b/extractor/data_array_test.cc new file mode 100644 index 00000000..772ba10e --- /dev/null +++ b/extractor/data_array_test.cc @@ -0,0 +1,70 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> + +#include <boost/filesystem.hpp> + +#include "data_array.h" + +using namespace std; +using namespace ::testing; +namespace fs = boost::filesystem; + +namespace { + +class DataArrayTest : public Test { + protected: + virtual void SetUp() { + string sample_test_file("sample_bitext.txt"); + source_data = make_shared<DataArray>(sample_test_file, SOURCE); + target_data = make_shared<DataArray>(sample_test_file, TARGET); + } + + shared_ptr<DataArray> source_data; + shared_ptr<DataArray> target_data; +}; + +TEST_F(DataArrayTest, TestGetData) { + vector<int> expected_source_data{2, 3, 4, 5, 1, 2, 6, 7, 8, 5, 1}; + EXPECT_EQ(expected_source_data, source_data->GetData()); + EXPECT_EQ(expected_source_data.size(), source_data->GetSize()); + for (size_t i = 0; i < expected_source_data.size(); ++i) { + EXPECT_EQ(expected_source_data[i], source_data->AtIndex(i)); + } + + vector<int> expected_target_data{2, 3, 4, 5, 1, 2, 6, 7, 8, 9, 10, 5, 1}; + EXPECT_EQ(expected_target_data, target_data->GetData()); + EXPECT_EQ(expected_target_data.size(), target_data->GetSize()); + for (size_t i = 0; i < expected_target_data.size(); ++i) { + EXPECT_EQ(expected_target_data[i], target_data->AtIndex(i)); + } +} + +TEST_F(DataArrayTest, TestVocabulary) { + EXPECT_EQ(9, source_data->GetVocabularySize()); + EXPECT_TRUE(source_data->HasWord("mere")); + EXPECT_EQ(4, source_data->GetWordId("mere")); + EXPECT_EQ("mere", source_data->GetWord(4)); + EXPECT_FALSE(source_data->HasWord("banane")); + + EXPECT_EQ(11, target_data->GetVocabularySize()); + EXPECT_TRUE(target_data->HasWord("apples")); + EXPECT_EQ(4, target_data->GetWordId("apples")); + EXPECT_EQ("apples", target_data->GetWord(4)); + EXPECT_FALSE(target_data->HasWord("bananas")); +} + +TEST_F(DataArrayTest, TestSentenceData) { + EXPECT_EQ(2, source_data->GetNumSentences()); + EXPECT_EQ(0, source_data->GetSentenceStart(0)); + EXPECT_EQ(5, source_data->GetSentenceStart(1)); + EXPECT_EQ(11, source_data->GetSentenceStart(2)); + + EXPECT_EQ(2, target_data->GetNumSentences()); + EXPECT_EQ(0, target_data->GetSentenceStart(0)); + EXPECT_EQ(5, target_data->GetSentenceStart(1)); + EXPECT_EQ(13, target_data->GetSentenceStart(2)); +} + +} // namespace diff --git a/extractor/fast_intersector.cc b/extractor/fast_intersector.cc new file mode 100644 index 00000000..8c7a7af8 --- /dev/null +++ b/extractor/fast_intersector.cc @@ -0,0 +1,191 @@ +#include "fast_intersector.h" + +#include <cassert> + +#include "data_array.h" +#include "phrase.h" +#include "phrase_location.h" +#include "precomputation.h" +#include "suffix_array.h" +#include "vocabulary.h" + +FastIntersector::FastIntersector(shared_ptr<SuffixArray> suffix_array, + shared_ptr<Precomputation> precomputation, + shared_ptr<Vocabulary> vocabulary, + int max_rule_span, + int min_gap_size) : + suffix_array(suffix_array), + vocabulary(vocabulary), + max_rule_span(max_rule_span), + min_gap_size(min_gap_size) { + Index precomputed_collocations = precomputation->GetCollocations(); + for (pair<vector<int>, vector<int> > entry: precomputed_collocations) { + vector<int> phrase = ConvertPhrase(entry.first); + collocations[phrase] = entry.second; + } +} + +FastIntersector::FastIntersector() {} + +FastIntersector::~FastIntersector() {} + +vector<int> FastIntersector::ConvertPhrase(const vector<int>& old_phrase) { + vector<int> new_phrase; + new_phrase.reserve(old_phrase.size()); + shared_ptr<DataArray> data_array = suffix_array->GetData(); + int num_nonterminals = 0; + for (int word_id: old_phrase) { + // TODO(pauldb): Remove overhead for relabelling the nonterminals here. + if (word_id == Precomputation::NON_TERMINAL) { + ++num_nonterminals; + new_phrase.push_back(vocabulary->GetNonterminalIndex(num_nonterminals)); + } else { + new_phrase.push_back( + vocabulary->GetTerminalIndex(data_array->GetWord(word_id))); + } + } + return new_phrase; +} + +PhraseLocation FastIntersector::Intersect( + PhraseLocation& prefix_location, + PhraseLocation& suffix_location, + const Phrase& phrase) { + vector<int> symbols = phrase.Get(); + + // We should never attempt to do an intersect query for a pattern starting or + // ending with a non terminal. The RuleFactory should handle these cases, + // initializing the matchings list with the one for the pattern without the + // starting or ending terminal. + assert(vocabulary->IsTerminal(symbols.front()) + && vocabulary->IsTerminal(symbols.back())); + + if (collocations.count(symbols)) { + return PhraseLocation(collocations[symbols], phrase.Arity() + 1); + } + + bool prefix_ends_with_x = + !vocabulary->IsTerminal(symbols[symbols.size() - 2]); + bool suffix_starts_with_x = !vocabulary->IsTerminal(symbols[1]); + if (EstimateNumOperations(prefix_location, prefix_ends_with_x) <= + EstimateNumOperations(suffix_location, suffix_starts_with_x)) { + return ExtendPrefixPhraseLocation(prefix_location, phrase, + prefix_ends_with_x, symbols.back()); + } else { + return ExtendSuffixPhraseLocation(suffix_location, phrase, + suffix_starts_with_x, symbols.front()); + } +} + +int FastIntersector::EstimateNumOperations( + const PhraseLocation& phrase_location, bool has_margin_x) const { + int num_locations = phrase_location.GetSize(); + return has_margin_x ? num_locations * max_rule_span : num_locations; +} + +PhraseLocation FastIntersector::ExtendPrefixPhraseLocation( + PhraseLocation& prefix_location, const Phrase& phrase, + bool prefix_ends_with_x, int next_symbol) const { + ExtendPhraseLocation(prefix_location); + vector<int> positions = *prefix_location.matchings; + int num_subpatterns = prefix_location.num_subpatterns; + + vector<int> new_positions; + shared_ptr<DataArray> data_array = suffix_array->GetData(); + int data_array_symbol = data_array->GetWordId( + vocabulary->GetTerminalValue(next_symbol)); + if (data_array_symbol == -1) { + return PhraseLocation(new_positions, num_subpatterns); + } + + pair<int, int> range = GetSearchRange(prefix_ends_with_x); + for (size_t i = 0; i < positions.size(); i += num_subpatterns) { + int sent_id = data_array->GetSentenceId(positions[i]); + int sent_end = data_array->GetSentenceStart(sent_id + 1) - 1; + int pattern_end = positions[i + num_subpatterns - 1] + range.first; + if (prefix_ends_with_x) { + pattern_end += phrase.GetChunkLen(phrase.Arity() - 1) - 1; + } else { + pattern_end += phrase.GetChunkLen(phrase.Arity()) - 2; + } + for (int j = range.first; j < range.second; ++j) { + if (pattern_end >= sent_end || + pattern_end - positions[i] >= max_rule_span) { + break; + } + + if (data_array->AtIndex(pattern_end) == data_array_symbol) { + new_positions.insert(new_positions.end(), positions.begin() + i, + positions.begin() + i + num_subpatterns); + if (prefix_ends_with_x) { + new_positions.push_back(pattern_end); + } + } + ++pattern_end; + } + } + + return PhraseLocation(new_positions, phrase.Arity() + 1); +} + +PhraseLocation FastIntersector::ExtendSuffixPhraseLocation( + PhraseLocation& suffix_location, const Phrase& phrase, + bool suffix_starts_with_x, int prev_symbol) const { + ExtendPhraseLocation(suffix_location); + vector<int> positions = *suffix_location.matchings; + int num_subpatterns = suffix_location.num_subpatterns; + + vector<int> new_positions; + shared_ptr<DataArray> data_array = suffix_array->GetData(); + int data_array_symbol = data_array->GetWordId( + vocabulary->GetTerminalValue(prev_symbol)); + if (data_array_symbol == -1) { + return PhraseLocation(new_positions, num_subpatterns); + } + + pair<int, int> range = GetSearchRange(suffix_starts_with_x); + for (size_t i = 0; i < positions.size(); i += num_subpatterns) { + int sent_id = data_array->GetSentenceId(positions[i]); + int sent_start = data_array->GetSentenceStart(sent_id); + int pattern_start = positions[i] - range.first; + int pattern_end = positions[i + num_subpatterns - 1] + + phrase.GetChunkLen(phrase.Arity()) - 1; + for (int j = range.first; j < range.second; ++j) { + if (pattern_start < sent_start || + pattern_end - pattern_start >= max_rule_span) { + break; + } + + if (data_array->AtIndex(pattern_start) == data_array_symbol) { + new_positions.push_back(pattern_start); + new_positions.insert(new_positions.end(), + positions.begin() + i + !suffix_starts_with_x, + positions.begin() + i + num_subpatterns); + } + --pattern_start; + } + } + + return PhraseLocation(new_positions, phrase.Arity() + 1); +} + +void FastIntersector::ExtendPhraseLocation(PhraseLocation& location) const { + if (location.matchings != NULL) { + return; + } + + location.num_subpatterns = 1; + location.matchings = make_shared<vector<int> >(); + for (int i = location.sa_low; i < location.sa_high; ++i) { + location.matchings->push_back(suffix_array->GetSuffix(i)); + } + location.sa_low = location.sa_high = 0; +} + +pair<int, int> FastIntersector::GetSearchRange(bool has_marginal_x) const { + if (has_marginal_x) { + return make_pair(min_gap_size + 1, max_rule_span); + } else { + return make_pair(1, 2); + } +} diff --git a/extractor/fast_intersector.h b/extractor/fast_intersector.h new file mode 100644 index 00000000..785e428e --- /dev/null +++ b/extractor/fast_intersector.h @@ -0,0 +1,65 @@ +#ifndef _FAST_INTERSECTOR_H_ +#define _FAST_INTERSECTOR_H_ + +#include <memory> +#include <unordered_map> +#include <vector> + +#include <boost/functional/hash.hpp> + +using namespace std; + +typedef boost::hash<vector<int> > VectorHash; +typedef unordered_map<vector<int>, vector<int>, VectorHash> Index; + +class Phrase; +class PhraseLocation; +class Precomputation; +class SuffixArray; +class Vocabulary; + +class FastIntersector { + public: + FastIntersector(shared_ptr<SuffixArray> suffix_array, + shared_ptr<Precomputation> precomputation, + shared_ptr<Vocabulary> vocabulary, + int max_rule_span, + int min_gap_size); + + virtual ~FastIntersector(); + + virtual PhraseLocation Intersect(PhraseLocation& prefix_location, + PhraseLocation& suffix_location, + const Phrase& phrase); + + protected: + FastIntersector(); + + private: + vector<int> ConvertPhrase(const vector<int>& old_phrase); + + int EstimateNumOperations(const PhraseLocation& phrase_location, + bool has_margin_x) const; + + PhraseLocation ExtendPrefixPhraseLocation(PhraseLocation& prefix_location, + const Phrase& phrase, + bool prefix_ends_with_x, + int next_symbol) const; + + PhraseLocation ExtendSuffixPhraseLocation(PhraseLocation& suffix_location, + const Phrase& phrase, + bool suffix_starts_with_x, + int prev_symbol) const; + + void ExtendPhraseLocation(PhraseLocation& location) const; + + pair<int, int> GetSearchRange(bool has_marginal_x) const; + + shared_ptr<SuffixArray> suffix_array; + shared_ptr<Vocabulary> vocabulary; + int max_rule_span; + int min_gap_size; + Index collocations; +}; + +#endif diff --git a/extractor/fast_intersector_test.cc b/extractor/fast_intersector_test.cc new file mode 100644 index 00000000..0d6ef367 --- /dev/null +++ b/extractor/fast_intersector_test.cc @@ -0,0 +1,146 @@ +#include <gtest/gtest.h> + +#include <memory> + +#include "fast_intersector.h" +#include "mocks/mock_data_array.h" +#include "mocks/mock_suffix_array.h" +#include "mocks/mock_precomputation.h" +#include "mocks/mock_vocabulary.h" +#include "phrase.h" +#include "phrase_location.h" +#include "phrase_builder.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class FastIntersectorTest : public Test { + protected: + virtual void SetUp() { + vector<string> words = {"EOL", "it", "makes", "him", "and", "mars", ",", + "sets", "on", "takes", "off", "."}; + vocabulary = make_shared<MockVocabulary>(); + for (size_t i = 0; i < words.size(); ++i) { + EXPECT_CALL(*vocabulary, GetTerminalIndex(words[i])) + .WillRepeatedly(Return(i)); + EXPECT_CALL(*vocabulary, GetTerminalValue(i)) + .WillRepeatedly(Return(words[i])); + } + + vector<int> data = {1, 2, 3, 4, 1, 5, 3, 6, 1, + 7, 3, 8, 4, 1, 9, 3, 10, 11, 0}; + data_array = make_shared<MockDataArray>(); + for (size_t i = 0; i < data.size(); ++i) { + EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i])); + EXPECT_CALL(*data_array, GetSentenceId(i)) + .WillRepeatedly(Return(0)); + } + EXPECT_CALL(*data_array, GetSentenceStart(0)) + .WillRepeatedly(Return(0)); + EXPECT_CALL(*data_array, GetSentenceStart(1)) + .WillRepeatedly(Return(19)); + for (size_t i = 0; i < words.size(); ++i) { + EXPECT_CALL(*data_array, GetWordId(words[i])) + .WillRepeatedly(Return(i)); + EXPECT_CALL(*data_array, GetWord(i)) + .WillRepeatedly(Return(words[i])); + } + + vector<int> suffixes = {18, 0, 4, 8, 13, 1, 2, 6, 10, 15, 3, 12, 5, 7, 9, + 11, 14, 16, 17}; + suffix_array = make_shared<MockSuffixArray>(); + EXPECT_CALL(*suffix_array, GetData()).WillRepeatedly(Return(data_array)); + for (size_t i = 0; i < suffixes.size(); ++i) { + EXPECT_CALL(*suffix_array, GetSuffix(i)). + WillRepeatedly(Return(suffixes[i])); + } + + precomputation = make_shared<MockPrecomputation>(); + EXPECT_CALL(*precomputation, GetCollocations()) + .WillRepeatedly(ReturnRef(collocations)); + + phrase_builder = make_shared<PhraseBuilder>(vocabulary); + intersector = make_shared<FastIntersector>(suffix_array, precomputation, + vocabulary, 15, 1); + } + + Index collocations; + shared_ptr<MockDataArray> data_array; + shared_ptr<MockSuffixArray> suffix_array; + shared_ptr<MockPrecomputation> precomputation; + shared_ptr<MockVocabulary> vocabulary; + shared_ptr<FastIntersector> intersector; + shared_ptr<PhraseBuilder> phrase_builder; +}; + +TEST_F(FastIntersectorTest, TestCachedCollocation) { + vector<int> symbols = {8, -1, 9}; + vector<int> expected_location = {11}; + Phrase phrase = phrase_builder->Build(symbols); + PhraseLocation prefix_location(15, 16), suffix_location(16, 17); + + collocations[symbols] = expected_location; + EXPECT_CALL(*precomputation, GetCollocations()) + .WillRepeatedly(ReturnRef(collocations)); + intersector = make_shared<FastIntersector>(suffix_array, precomputation, + vocabulary, 15, 1); + + PhraseLocation result = intersector->Intersect( + prefix_location, suffix_location, phrase); + + EXPECT_EQ(PhraseLocation(expected_location, 2), result); + EXPECT_EQ(PhraseLocation(15, 16), prefix_location); + EXPECT_EQ(PhraseLocation(16, 17), suffix_location); +} + +TEST_F(FastIntersectorTest, TestIntersectaXbXcExtendSuffix) { + vector<int> symbols = {1, -1, 3, -1, 1}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> prefix_locs = {0, 2, 0, 6, 0, 10, 4, 6, 4, 10, 4, 15, 8, 10, + 8, 15, 3, 15}; + vector<int> suffix_locs = {2, 4, 2, 8, 2, 13, 6, 8, 6, 13, 10, 13}; + PhraseLocation prefix_location(prefix_locs, 2); + PhraseLocation suffix_location(suffix_locs, 2); + + vector<int> expected_locs = {0, 2, 4, 0, 2, 8, 0, 2, 13, 4, 6, 8, 0, 6, 8, + 4, 6, 13, 0, 6, 13, 8, 10, 13, 4, 10, 13, + 0, 10, 13}; + PhraseLocation result = intersector->Intersect( + prefix_location, suffix_location, phrase); + EXPECT_EQ(PhraseLocation(expected_locs, 3), result); +} + +/* +TEST_F(FastIntersectorTest, TestIntersectaXbExtendPrefix) { + vector<int> symbols = {1, -1, 3}; + Phrase phrase = phrase_builder->Build(symbols); + PhraseLocation prefix_location(1, 5), suffix_location(6, 10); + + vector<int> expected_prefix_locs = {0, 4, 8, 13}; + vector<int> expected_locs = {0, 2, 0, 6, 0, 10, 4, 6, 4, 10, 4, 15, 8, 10, + 8, 15, 13, 15}; + PhraseLocation result = intersector->Intersect( + prefix_location, suffix_location, phrase); + EXPECT_EQ(PhraseLocation(expected_locs, 2), result); + EXPECT_EQ(PhraseLocation(expected_prefix_locs, 1), prefix_location); +} + +TEST_F(FastIntersectorTest, TestIntersectCheckEstimates) { + // The suffix matches in fewer positions, but because it starts with an X + // it requires more operations and we prefer extending the prefix. + vector<int> symbols = {1, -1, 4, 1}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> prefix_locs = {0, 3, 0, 12, 4, 12, 8, 12}; + PhraseLocation prefix_location(prefix_locs, 2), suffix_location(10, 12); + + vector<int> expected_locs = {0, 3, 0, 12, 4, 12, 8, 12}; + PhraseLocation result = intersector->Intersect( + prefix_location, suffix_location, phrase); + EXPECT_EQ(PhraseLocation(expected_locs, 2), result); + EXPECT_EQ(PhraseLocation(10, 12), suffix_location); +} +*/ + +} // namespace diff --git a/extractor/features/count_source_target.cc b/extractor/features/count_source_target.cc new file mode 100644 index 00000000..9441b451 --- /dev/null +++ b/extractor/features/count_source_target.cc @@ -0,0 +1,11 @@ +#include "count_source_target.h" + +#include <cmath> + +double CountSourceTarget::Score(const FeatureContext& context) const { + return log10(1 + context.pair_count); +} + +string CountSourceTarget::GetName() const { + return "CountEF"; +} diff --git a/extractor/features/count_source_target.h b/extractor/features/count_source_target.h new file mode 100644 index 00000000..a2481944 --- /dev/null +++ b/extractor/features/count_source_target.h @@ -0,0 +1,13 @@ +#ifndef _COUNT_SOURCE_TARGET_H_ +#define _COUNT_SOURCE_TARGET_H_ + +#include "feature.h" + +class CountSourceTarget : public Feature { + public: + double Score(const FeatureContext& context) const; + + string GetName() const; +}; + +#endif diff --git a/extractor/features/count_source_target_test.cc b/extractor/features/count_source_target_test.cc new file mode 100644 index 00000000..22633bb6 --- /dev/null +++ b/extractor/features/count_source_target_test.cc @@ -0,0 +1,32 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> + +#include "count_source_target.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class CountSourceTargetTest : public Test { + protected: + virtual void SetUp() { + feature = make_shared<CountSourceTarget>(); + } + + shared_ptr<CountSourceTarget> feature; +}; + +TEST_F(CountSourceTargetTest, TestGetName) { + EXPECT_EQ("CountEF", feature->GetName()); +} + +TEST_F(CountSourceTargetTest, TestScore) { + Phrase phrase; + FeatureContext context(phrase, phrase, 0.5, 9, 13); + EXPECT_EQ(1.0, feature->Score(context)); +} + +} // namespace diff --git a/extractor/features/feature.cc b/extractor/features/feature.cc new file mode 100644 index 00000000..876f5f8f --- /dev/null +++ b/extractor/features/feature.cc @@ -0,0 +1,5 @@ +#include "feature.h" + +const double Feature::MAX_SCORE = 99.0; + +Feature::~Feature() {} diff --git a/extractor/features/feature.h b/extractor/features/feature.h new file mode 100644 index 00000000..aca58401 --- /dev/null +++ b/extractor/features/feature.h @@ -0,0 +1,36 @@ +#ifndef _FEATURE_H_ +#define _FEATURE_H_ + +#include <string> + +//TODO(pauldb): include headers nicely. +#include "../phrase.h" + +using namespace std; + +struct FeatureContext { + FeatureContext(const Phrase& source_phrase, const Phrase& target_phrase, + double source_phrase_count, int pair_count, int num_samples) : + source_phrase(source_phrase), target_phrase(target_phrase), + source_phrase_count(source_phrase_count), pair_count(pair_count), + num_samples(num_samples) {} + + Phrase source_phrase; + Phrase target_phrase; + double source_phrase_count; + int pair_count; + int num_samples; +}; + +class Feature { + public: + virtual double Score(const FeatureContext& context) const = 0; + + virtual string GetName() const = 0; + + virtual ~Feature(); + + static const double MAX_SCORE; +}; + +#endif diff --git a/extractor/features/is_source_singleton.cc b/extractor/features/is_source_singleton.cc new file mode 100644 index 00000000..98d4e5fe --- /dev/null +++ b/extractor/features/is_source_singleton.cc @@ -0,0 +1,11 @@ +#include "is_source_singleton.h" + +#include <cmath> + +double IsSourceSingleton::Score(const FeatureContext& context) const { + return context.source_phrase_count == 1; +} + +string IsSourceSingleton::GetName() const { + return "IsSingletonF"; +} diff --git a/extractor/features/is_source_singleton.h b/extractor/features/is_source_singleton.h new file mode 100644 index 00000000..7cc72828 --- /dev/null +++ b/extractor/features/is_source_singleton.h @@ -0,0 +1,13 @@ +#ifndef _IS_SOURCE_SINGLETON_H_ +#define _IS_SOURCE_SINGLETON_H_ + +#include "feature.h" + +class IsSourceSingleton : public Feature { + public: + double Score(const FeatureContext& context) const; + + string GetName() const; +}; + +#endif diff --git a/extractor/features/is_source_singleton_test.cc b/extractor/features/is_source_singleton_test.cc new file mode 100644 index 00000000..8c71e593 --- /dev/null +++ b/extractor/features/is_source_singleton_test.cc @@ -0,0 +1,35 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> + +#include "is_source_singleton.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class IsSourceSingletonTest : public Test { + protected: + virtual void SetUp() { + feature = make_shared<IsSourceSingleton>(); + } + + shared_ptr<IsSourceSingleton> feature; +}; + +TEST_F(IsSourceSingletonTest, TestGetName) { + EXPECT_EQ("IsSingletonF", feature->GetName()); +} + +TEST_F(IsSourceSingletonTest, TestScore) { + Phrase phrase; + FeatureContext context(phrase, phrase, 0.5, 3, 31); + EXPECT_EQ(0, feature->Score(context)); + + context = FeatureContext(phrase, phrase, 1, 3, 25); + EXPECT_EQ(1, feature->Score(context)); +} + +} // namespace diff --git a/extractor/features/is_source_target_singleton.cc b/extractor/features/is_source_target_singleton.cc new file mode 100644 index 00000000..31d36532 --- /dev/null +++ b/extractor/features/is_source_target_singleton.cc @@ -0,0 +1,11 @@ +#include "is_source_target_singleton.h" + +#include <cmath> + +double IsSourceTargetSingleton::Score(const FeatureContext& context) const { + return context.pair_count == 1; +} + +string IsSourceTargetSingleton::GetName() const { + return "IsSingletonFE"; +} diff --git a/extractor/features/is_source_target_singleton.h b/extractor/features/is_source_target_singleton.h new file mode 100644 index 00000000..58913b74 --- /dev/null +++ b/extractor/features/is_source_target_singleton.h @@ -0,0 +1,13 @@ +#ifndef _IS_SOURCE_TARGET_SINGLETON_H_ +#define _IS_SOURCE_TARGET_SINGLETON_H_ + +#include "feature.h" + +class IsSourceTargetSingleton : public Feature { + public: + double Score(const FeatureContext& context) const; + + string GetName() const; +}; + +#endif diff --git a/extractor/features/is_source_target_singleton_test.cc b/extractor/features/is_source_target_singleton_test.cc new file mode 100644 index 00000000..a51f77c9 --- /dev/null +++ b/extractor/features/is_source_target_singleton_test.cc @@ -0,0 +1,35 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> + +#include "is_source_target_singleton.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class IsSourceTargetSingletonTest : public Test { + protected: + virtual void SetUp() { + feature = make_shared<IsSourceTargetSingleton>(); + } + + shared_ptr<IsSourceTargetSingleton> feature; +}; + +TEST_F(IsSourceTargetSingletonTest, TestGetName) { + EXPECT_EQ("IsSingletonFE", feature->GetName()); +} + +TEST_F(IsSourceTargetSingletonTest, TestScore) { + Phrase phrase; + FeatureContext context(phrase, phrase, 0.5, 3, 7); + EXPECT_EQ(0, feature->Score(context)); + + context = FeatureContext(phrase, phrase, 2.3, 1, 28); + EXPECT_EQ(1, feature->Score(context)); +} + +} // namespace diff --git a/extractor/features/max_lex_source_given_target.cc b/extractor/features/max_lex_source_given_target.cc new file mode 100644 index 00000000..21f5c76a --- /dev/null +++ b/extractor/features/max_lex_source_given_target.cc @@ -0,0 +1,31 @@ +#include "max_lex_source_given_target.h" + +#include <cmath> + +#include "../data_array.h" +#include "../translation_table.h" + +MaxLexSourceGivenTarget::MaxLexSourceGivenTarget( + shared_ptr<TranslationTable> table) : + table(table) {} + +double MaxLexSourceGivenTarget::Score(const FeatureContext& context) const { + vector<string> source_words = context.source_phrase.GetWords(); + vector<string> target_words = context.target_phrase.GetWords(); + target_words.push_back(DataArray::NULL_WORD_STR); + + double score = 0; + for (string source_word: source_words) { + double max_score = 0; + for (string target_word: target_words) { + max_score = max(max_score, + table->GetSourceGivenTargetScore(source_word, target_word)); + } + score += max_score > 0 ? -log10(max_score) : MAX_SCORE; + } + return score; +} + +string MaxLexSourceGivenTarget::GetName() const { + return "MaxLexFgivenE"; +} diff --git a/extractor/features/max_lex_source_given_target.h b/extractor/features/max_lex_source_given_target.h new file mode 100644 index 00000000..e87c1c8e --- /dev/null +++ b/extractor/features/max_lex_source_given_target.h @@ -0,0 +1,24 @@ +#ifndef _MAX_LEX_SOURCE_GIVEN_TARGET_H_ +#define _MAX_LEX_SOURCE_GIVEN_TARGET_H_ + +#include <memory> + +#include "feature.h" + +using namespace std; + +class TranslationTable; + +class MaxLexSourceGivenTarget : public Feature { + public: + MaxLexSourceGivenTarget(shared_ptr<TranslationTable> table); + + double Score(const FeatureContext& context) const; + + string GetName() const; + + private: + shared_ptr<TranslationTable> table; +}; + +#endif diff --git a/extractor/features/max_lex_source_given_target_test.cc b/extractor/features/max_lex_source_given_target_test.cc new file mode 100644 index 00000000..5fd41f8b --- /dev/null +++ b/extractor/features/max_lex_source_given_target_test.cc @@ -0,0 +1,74 @@ +#include <gtest/gtest.h> + +#include <cmath> +#include <memory> +#include <string> + +#include "../mocks/mock_translation_table.h" +#include "../mocks/mock_vocabulary.h" +#include "../data_array.h" +#include "../phrase_builder.h" +#include "max_lex_source_given_target.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class MaxLexSourceGivenTargetTest : public Test { + protected: + virtual void SetUp() { + vector<string> source_words = {"f1", "f2", "f3"}; + vector<string> target_words = {"e1", "e2", "e3"}; + + vocabulary = make_shared<MockVocabulary>(); + for (size_t i = 0; i < source_words.size(); ++i) { + EXPECT_CALL(*vocabulary, GetTerminalValue(i)) + .WillRepeatedly(Return(source_words[i])); + } + for (size_t i = 0; i < target_words.size(); ++i) { + EXPECT_CALL(*vocabulary, GetTerminalValue(i + source_words.size())) + .WillRepeatedly(Return(target_words[i])); + } + + phrase_builder = make_shared<PhraseBuilder>(vocabulary); + + table = make_shared<MockTranslationTable>(); + for (size_t i = 0; i < source_words.size(); ++i) { + for (size_t j = 0; j < target_words.size(); ++j) { + int value = i - j; + EXPECT_CALL(*table, GetSourceGivenTargetScore( + source_words[i], target_words[j])).WillRepeatedly(Return(value)); + } + } + + for (size_t i = 0; i < source_words.size(); ++i) { + int value = i * 3; + EXPECT_CALL(*table, GetSourceGivenTargetScore( + source_words[i], DataArray::NULL_WORD_STR)) + .WillRepeatedly(Return(value)); + } + + feature = make_shared<MaxLexSourceGivenTarget>(table); + } + + shared_ptr<MockVocabulary> vocabulary; + shared_ptr<PhraseBuilder> phrase_builder; + shared_ptr<MockTranslationTable> table; + shared_ptr<MaxLexSourceGivenTarget> feature; +}; + +TEST_F(MaxLexSourceGivenTargetTest, TestGetName) { + EXPECT_EQ("MaxLexFgivenE", feature->GetName()); +} + +TEST_F(MaxLexSourceGivenTargetTest, TestScore) { + vector<int> source_symbols = {0, 1, 2}; + Phrase source_phrase = phrase_builder->Build(source_symbols); + vector<int> target_symbols = {3, 4, 5}; + Phrase target_phrase = phrase_builder->Build(target_symbols); + FeatureContext context(source_phrase, target_phrase, 0.3, 7, 11); + EXPECT_EQ(99 - log10(18), feature->Score(context)); +} + +} // namespace diff --git a/extractor/features/max_lex_target_given_source.cc b/extractor/features/max_lex_target_given_source.cc new file mode 100644 index 00000000..f2bc2474 --- /dev/null +++ b/extractor/features/max_lex_target_given_source.cc @@ -0,0 +1,31 @@ +#include "max_lex_target_given_source.h" + +#include <cmath> + +#include "../data_array.h" +#include "../translation_table.h" + +MaxLexTargetGivenSource::MaxLexTargetGivenSource( + shared_ptr<TranslationTable> table) : + table(table) {} + +double MaxLexTargetGivenSource::Score(const FeatureContext& context) const { + vector<string> source_words = context.source_phrase.GetWords(); + source_words.push_back(DataArray::NULL_WORD_STR); + vector<string> target_words = context.target_phrase.GetWords(); + + double score = 0; + for (string target_word: target_words) { + double max_score = 0; + for (string source_word: source_words) { + max_score = max(max_score, + table->GetTargetGivenSourceScore(source_word, target_word)); + } + score += max_score > 0 ? -log10(max_score) : MAX_SCORE; + } + return score; +} + +string MaxLexTargetGivenSource::GetName() const { + return "MaxLexEgivenF"; +} diff --git a/extractor/features/max_lex_target_given_source.h b/extractor/features/max_lex_target_given_source.h new file mode 100644 index 00000000..9585ff04 --- /dev/null +++ b/extractor/features/max_lex_target_given_source.h @@ -0,0 +1,24 @@ +#ifndef _MAX_LEX_TARGET_GIVEN_SOURCE_H_ +#define _MAX_LEX_TARGET_GIVEN_SOURCE_H_ + +#include <memory> + +#include "feature.h" + +using namespace std; + +class TranslationTable; + +class MaxLexTargetGivenSource : public Feature { + public: + MaxLexTargetGivenSource(shared_ptr<TranslationTable> table); + + double Score(const FeatureContext& context) const; + + string GetName() const; + + private: + shared_ptr<TranslationTable> table; +}; + +#endif diff --git a/extractor/features/max_lex_target_given_source_test.cc b/extractor/features/max_lex_target_given_source_test.cc new file mode 100644 index 00000000..c8701bf7 --- /dev/null +++ b/extractor/features/max_lex_target_given_source_test.cc @@ -0,0 +1,74 @@ +#include <gtest/gtest.h> + +#include <cmath> +#include <memory> +#include <string> + +#include "../mocks/mock_translation_table.h" +#include "../mocks/mock_vocabulary.h" +#include "../data_array.h" +#include "../phrase_builder.h" +#include "max_lex_target_given_source.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class MaxLexTargetGivenSourceTest : public Test { + protected: + virtual void SetUp() { + vector<string> source_words = {"f1", "f2", "f3"}; + vector<string> target_words = {"e1", "e2", "e3"}; + + vocabulary = make_shared<MockVocabulary>(); + for (size_t i = 0; i < source_words.size(); ++i) { + EXPECT_CALL(*vocabulary, GetTerminalValue(i)) + .WillRepeatedly(Return(source_words[i])); + } + for (size_t i = 0; i < target_words.size(); ++i) { + EXPECT_CALL(*vocabulary, GetTerminalValue(i + source_words.size())) + .WillRepeatedly(Return(target_words[i])); + } + + phrase_builder = make_shared<PhraseBuilder>(vocabulary); + + table = make_shared<MockTranslationTable>(); + for (size_t i = 0; i < source_words.size(); ++i) { + for (size_t j = 0; j < target_words.size(); ++j) { + int value = i - j; + EXPECT_CALL(*table, GetTargetGivenSourceScore( + source_words[i], target_words[j])).WillRepeatedly(Return(value)); + } + } + + for (size_t i = 0; i < target_words.size(); ++i) { + int value = i * 3; + EXPECT_CALL(*table, GetTargetGivenSourceScore( + DataArray::NULL_WORD_STR, target_words[i])) + .WillRepeatedly(Return(value)); + } + + feature = make_shared<MaxLexTargetGivenSource>(table); + } + + shared_ptr<MockVocabulary> vocabulary; + shared_ptr<PhraseBuilder> phrase_builder; + shared_ptr<MockTranslationTable> table; + shared_ptr<MaxLexTargetGivenSource> feature; +}; + +TEST_F(MaxLexTargetGivenSourceTest, TestGetName) { + EXPECT_EQ("MaxLexEgivenF", feature->GetName()); +} + +TEST_F(MaxLexTargetGivenSourceTest, TestScore) { + vector<int> source_symbols = {0, 1, 2}; + Phrase source_phrase = phrase_builder->Build(source_symbols); + vector<int> target_symbols = {3, 4, 5}; + Phrase target_phrase = phrase_builder->Build(target_symbols); + FeatureContext context(source_phrase, target_phrase, 0.3, 7, 19); + EXPECT_EQ(-log10(36), feature->Score(context)); +} + +} // namespace diff --git a/extractor/features/sample_source_count.cc b/extractor/features/sample_source_count.cc new file mode 100644 index 00000000..88b645b1 --- /dev/null +++ b/extractor/features/sample_source_count.cc @@ -0,0 +1,11 @@ +#include "sample_source_count.h" + +#include <cmath> + +double SampleSourceCount::Score(const FeatureContext& context) const { + return log10(1 + context.num_samples); +} + +string SampleSourceCount::GetName() const { + return "SampleCountF"; +} diff --git a/extractor/features/sample_source_count.h b/extractor/features/sample_source_count.h new file mode 100644 index 00000000..62d236c8 --- /dev/null +++ b/extractor/features/sample_source_count.h @@ -0,0 +1,13 @@ +#ifndef _SAMPLE_SOURCE_COUNT_H_ +#define _SAMPLE_SOURCE_COUNT_H_ + +#include "feature.h" + +class SampleSourceCount : public Feature { + public: + double Score(const FeatureContext& context) const; + + string GetName() const; +}; + +#endif diff --git a/extractor/features/sample_source_count_test.cc b/extractor/features/sample_source_count_test.cc new file mode 100644 index 00000000..7d226104 --- /dev/null +++ b/extractor/features/sample_source_count_test.cc @@ -0,0 +1,36 @@ +#include <gtest/gtest.h> + +#include <cmath> +#include <memory> +#include <string> + +#include "sample_source_count.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class SampleSourceCountTest : public Test { + protected: + virtual void SetUp() { + feature = make_shared<SampleSourceCount>(); + } + + shared_ptr<SampleSourceCount> feature; +}; + +TEST_F(SampleSourceCountTest, TestGetName) { + EXPECT_EQ("SampleCountF", feature->GetName()); +} + +TEST_F(SampleSourceCountTest, TestScore) { + Phrase phrase; + FeatureContext context(phrase, phrase, 0, 3, 1); + EXPECT_EQ(log10(2), feature->Score(context)); + + context = FeatureContext(phrase, phrase, 3.2, 3, 9); + EXPECT_EQ(1.0, feature->Score(context)); +} + +} // namespace diff --git a/extractor/features/target_given_source_coherent.cc b/extractor/features/target_given_source_coherent.cc new file mode 100644 index 00000000..274b3364 --- /dev/null +++ b/extractor/features/target_given_source_coherent.cc @@ -0,0 +1,12 @@ +#include "target_given_source_coherent.h" + +#include <cmath> + +double TargetGivenSourceCoherent::Score(const FeatureContext& context) const { + double prob = (double) context.pair_count / context.num_samples; + return prob > 0 ? -log10(prob) : MAX_SCORE; +} + +string TargetGivenSourceCoherent::GetName() const { + return "EgivenFCoherent"; +} diff --git a/extractor/features/target_given_source_coherent.h b/extractor/features/target_given_source_coherent.h new file mode 100644 index 00000000..09c8edb1 --- /dev/null +++ b/extractor/features/target_given_source_coherent.h @@ -0,0 +1,13 @@ +#ifndef _TARGET_GIVEN_SOURCE_COHERENT_H_ +#define _TARGET_GIVEN_SOURCE_COHERENT_H_ + +#include "feature.h" + +class TargetGivenSourceCoherent : public Feature { + public: + double Score(const FeatureContext& context) const; + + string GetName() const; +}; + +#endif diff --git a/extractor/features/target_given_source_coherent_test.cc b/extractor/features/target_given_source_coherent_test.cc new file mode 100644 index 00000000..c54c06c2 --- /dev/null +++ b/extractor/features/target_given_source_coherent_test.cc @@ -0,0 +1,35 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> + +#include "target_given_source_coherent.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class TargetGivenSourceCoherentTest : public Test { + protected: + virtual void SetUp() { + feature = make_shared<TargetGivenSourceCoherent>(); + } + + shared_ptr<TargetGivenSourceCoherent> feature; +}; + +TEST_F(TargetGivenSourceCoherentTest, TestGetName) { + EXPECT_EQ("EgivenFCoherent", feature->GetName()); +} + +TEST_F(TargetGivenSourceCoherentTest, TestScore) { + Phrase phrase; + FeatureContext context(phrase, phrase, 0.3, 2, 20); + EXPECT_EQ(1.0, feature->Score(context)); + + context = FeatureContext(phrase, phrase, 1.9, 0, 1); + EXPECT_EQ(99.0, feature->Score(context)); +} + +} // namespace diff --git a/extractor/grammar.cc b/extractor/grammar.cc new file mode 100644 index 00000000..8124a804 --- /dev/null +++ b/extractor/grammar.cc @@ -0,0 +1,39 @@ +#include "grammar.h" + +#include <iomanip> + +#include "rule.h" + +using namespace std; + +Grammar::Grammar(const vector<Rule>& rules, + const vector<string>& feature_names) : + rules(rules), feature_names(feature_names) {} + +vector<Rule> Grammar::GetRules() const { + return rules; +} + +vector<string> Grammar::GetFeatureNames() const { + return feature_names; +} + +ostream& operator<<(ostream& os, const Grammar& grammar) { + vector<Rule> rules = grammar.GetRules(); + vector<string> feature_names = grammar.GetFeatureNames(); + os << setprecision(12); + for (Rule rule: rules) { + os << "[X] ||| " << rule.source_phrase << " ||| " + << rule.target_phrase << " |||"; + for (size_t i = 0; i < rule.scores.size(); ++i) { + os << " " << feature_names[i] << "=" << rule.scores[i]; + } + os << " |||"; + for (auto link: rule.alignment) { + os << " " << link.first << "-" << link.second; + } + os << endl; + } + + return os; +} diff --git a/extractor/grammar.h b/extractor/grammar.h new file mode 100644 index 00000000..889cc2f3 --- /dev/null +++ b/extractor/grammar.h @@ -0,0 +1,27 @@ +#ifndef _GRAMMAR_H_ +#define _GRAMMAR_H_ + +#include <iostream> +#include <string> +#include <vector> + +using namespace std; + +class Rule; + +class Grammar { + public: + Grammar(const vector<Rule>& rules, const vector<string>& feature_names); + + vector<Rule> GetRules() const; + + vector<string> GetFeatureNames() const; + + friend ostream& operator<<(ostream& os, const Grammar& grammar); + + private: + vector<Rule> rules; + vector<string> feature_names; +}; + +#endif diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc new file mode 100644 index 00000000..a03e805f --- /dev/null +++ b/extractor/grammar_extractor.cc @@ -0,0 +1,58 @@ +#include "grammar_extractor.h" + +#include <iterator> +#include <sstream> +#include <vector> + +#include "grammar.h" +#include "rule.h" +#include "vocabulary.h" + +using namespace std; + +GrammarExtractor::GrammarExtractor( + shared_ptr<SuffixArray> source_suffix_array, + shared_ptr<DataArray> target_data_array, + shared_ptr<Alignment> alignment, shared_ptr<Precomputation> precomputation, + shared_ptr<Scorer> scorer, int min_gap_size, int max_rule_span, + int max_nonterminals, int max_rule_symbols, int max_samples, + bool use_fast_intersect, bool use_baeza_yates, bool require_tight_phrases) : + vocabulary(make_shared<Vocabulary>()), + rule_factory(make_shared<HieroCachingRuleFactory>( + source_suffix_array, target_data_array, alignment, vocabulary, + precomputation, scorer, min_gap_size, max_rule_span, max_nonterminals, + max_rule_symbols, max_samples, use_fast_intersect, use_baeza_yates, + require_tight_phrases)) {} + +GrammarExtractor::GrammarExtractor( + shared_ptr<Vocabulary> vocabulary, + shared_ptr<HieroCachingRuleFactory> rule_factory) : + vocabulary(vocabulary), + rule_factory(rule_factory) {} + +Grammar GrammarExtractor::GetGrammar(const string& sentence) { + vector<string> words = TokenizeSentence(sentence); + vector<int> word_ids = AnnotateWords(words); + return rule_factory->GetGrammar(word_ids); +} + +vector<string> GrammarExtractor::TokenizeSentence(const string& sentence) { + vector<string> result; + result.push_back("<s>"); + + istringstream buffer(sentence); + copy(istream_iterator<string>(buffer), + istream_iterator<string>(), + back_inserter(result)); + + result.push_back("</s>"); + return result; +} + +vector<int> GrammarExtractor::AnnotateWords(const vector<string>& words) { + vector<int> result; + for (string word: words) { + result.push_back(vocabulary->GetTerminalIndex(word)); + } + return result; +} diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h new file mode 100644 index 00000000..a8f2090d --- /dev/null +++ b/extractor/grammar_extractor.h @@ -0,0 +1,51 @@ +#ifndef _GRAMMAR_EXTRACTOR_H_ +#define _GRAMMAR_EXTRACTOR_H_ + +#include <string> +#include <vector> + +#include "rule_factory.h" + +using namespace std; + +class Alignment; +class DataArray; +class Grammar; +class Precomputation; +class Rule; +class SuffixArray; +class Vocabulary; + +class GrammarExtractor { + public: + GrammarExtractor( + shared_ptr<SuffixArray> source_suffix_array, + shared_ptr<DataArray> target_data_array, + shared_ptr<Alignment> alignment, + shared_ptr<Precomputation> precomputation, + shared_ptr<Scorer> scorer, + int min_gap_size, + int max_rule_span, + int max_nonterminals, + int max_rule_symbols, + int max_samples, + bool use_fast_intersect, + bool use_baeza_yates, + bool require_tight_phrases); + + // For testing only. + GrammarExtractor(shared_ptr<Vocabulary> vocabulary, + shared_ptr<HieroCachingRuleFactory> rule_factory); + + Grammar GetGrammar(const string& sentence); + + private: + vector<string> TokenizeSentence(const string& sentence); + + vector<int> AnnotateWords(const vector<string>& words); + + shared_ptr<Vocabulary> vocabulary; + shared_ptr<HieroCachingRuleFactory> rule_factory; +}; + +#endif diff --git a/extractor/grammar_extractor_test.cc b/extractor/grammar_extractor_test.cc new file mode 100644 index 00000000..d4ed7d4f --- /dev/null +++ b/extractor/grammar_extractor_test.cc @@ -0,0 +1,49 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> +#include <vector> + +#include "grammar.h" +#include "grammar_extractor.h" +#include "mocks/mock_rule_factory.h" +#include "mocks/mock_vocabulary.h" +#include "rule.h" + +using namespace std; +using namespace ::testing; + +namespace { + +TEST(GrammarExtractorTest, TestAnnotatingWords) { + shared_ptr<MockVocabulary> vocabulary = make_shared<MockVocabulary>(); + EXPECT_CALL(*vocabulary, GetTerminalIndex("<s>")) + .WillRepeatedly(Return(0)); + EXPECT_CALL(*vocabulary, GetTerminalIndex("Anna")) + .WillRepeatedly(Return(1)); + EXPECT_CALL(*vocabulary, GetTerminalIndex("has")) + .WillRepeatedly(Return(2)); + EXPECT_CALL(*vocabulary, GetTerminalIndex("many")) + .WillRepeatedly(Return(3)); + EXPECT_CALL(*vocabulary, GetTerminalIndex("apples")) + .WillRepeatedly(Return(4)); + EXPECT_CALL(*vocabulary, GetTerminalIndex(".")) + .WillRepeatedly(Return(5)); + EXPECT_CALL(*vocabulary, GetTerminalIndex("</s>")) + .WillRepeatedly(Return(6)); + + shared_ptr<MockHieroCachingRuleFactory> factory = + make_shared<MockHieroCachingRuleFactory>(); + vector<int> word_ids = {0, 1, 2, 3, 3, 4, 5, 6}; + vector<Rule> rules; + vector<string> feature_names; + Grammar grammar(rules, feature_names); + EXPECT_CALL(*factory, GetGrammar(word_ids)) + .WillOnce(Return(grammar)); + + GrammarExtractor extractor(vocabulary, factory); + string sentence = "Anna has many many apples ."; + extractor.GetGrammar(sentence); +} + +} // namespace diff --git a/extractor/intersector.cc b/extractor/intersector.cc new file mode 100644 index 00000000..39a7648d --- /dev/null +++ b/extractor/intersector.cc @@ -0,0 +1,154 @@ +#include "intersector.h" + +#include "data_array.h" +#include "matching_comparator.h" +#include "phrase.h" +#include "phrase_location.h" +#include "precomputation.h" +#include "suffix_array.h" +#include "veb.h" +#include "vocabulary.h" + +Intersector::Intersector(shared_ptr<Vocabulary> vocabulary, + shared_ptr<Precomputation> precomputation, + shared_ptr<SuffixArray> suffix_array, + shared_ptr<MatchingComparator> comparator, + bool use_baeza_yates) : + vocabulary(vocabulary), + suffix_array(suffix_array), + use_baeza_yates(use_baeza_yates) { + shared_ptr<DataArray> data_array = suffix_array->GetData(); + linear_merger = make_shared<LinearMerger>(vocabulary, data_array, comparator); + binary_search_merger = make_shared<BinarySearchMerger>( + vocabulary, linear_merger, data_array, comparator); + ConvertIndexes(precomputation, data_array); +} + +Intersector::Intersector(shared_ptr<Vocabulary> vocabulary, + shared_ptr<Precomputation> precomputation, + shared_ptr<SuffixArray> suffix_array, + shared_ptr<LinearMerger> linear_merger, + shared_ptr<BinarySearchMerger> binary_search_merger, + bool use_baeza_yates) : + vocabulary(vocabulary), + suffix_array(suffix_array), + linear_merger(linear_merger), + binary_search_merger(binary_search_merger), + use_baeza_yates(use_baeza_yates) { + ConvertIndexes(precomputation, suffix_array->GetData()); +} + +Intersector::Intersector() {} + +Intersector::~Intersector() {} + +void Intersector::ConvertIndexes(shared_ptr<Precomputation> precomputation, + shared_ptr<DataArray> data_array) { + const Index& precomputed_index = precomputation->GetInvertedIndex(); + for (pair<vector<int>, vector<int> > entry: precomputed_index) { + vector<int> phrase = ConvertPhrase(entry.first, data_array); + inverted_index[phrase] = entry.second; + + phrase.push_back(vocabulary->GetNonterminalIndex(1)); + inverted_index[phrase] = entry.second; + phrase.pop_back(); + phrase.insert(phrase.begin(), vocabulary->GetNonterminalIndex(1)); + inverted_index[phrase] = entry.second; + } + + const Index& precomputed_collocations = precomputation->GetCollocations(); + for (pair<vector<int>, vector<int> > entry: precomputed_collocations) { + vector<int> phrase = ConvertPhrase(entry.first, data_array); + collocations[phrase] = entry.second; + } +} + +vector<int> Intersector::ConvertPhrase(const vector<int>& old_phrase, + shared_ptr<DataArray> data_array) { + vector<int> new_phrase; + new_phrase.reserve(old_phrase.size()); + + int arity = 0; + for (int word_id: old_phrase) { + if (word_id == Precomputation::NON_TERMINAL) { + ++arity; + new_phrase.push_back(vocabulary->GetNonterminalIndex(arity)); + } else { + new_phrase.push_back( + vocabulary->GetTerminalIndex(data_array->GetWord(word_id))); + } + } + + return new_phrase; +} + +PhraseLocation Intersector::Intersect( + const Phrase& prefix, PhraseLocation& prefix_location, + const Phrase& suffix, PhraseLocation& suffix_location, + const Phrase& phrase) { + vector<int> symbols = phrase.Get(); + + // We should never attempt to do an intersect query for a pattern starting or + // ending with a non terminal. The RuleFactory should handle these cases, + // initializing the matchings list with the one for the pattern without the + // starting or ending terminal. + assert(vocabulary->IsTerminal(symbols.front()) + && vocabulary->IsTerminal(symbols.back())); + + if (collocations.count(symbols)) { + return PhraseLocation(collocations[symbols], phrase.Arity() + 1); + } + + vector<int> locations; + ExtendPhraseLocation(prefix, prefix_location); + ExtendPhraseLocation(suffix, suffix_location); + shared_ptr<vector<int> > prefix_matchings = prefix_location.matchings; + shared_ptr<vector<int> > suffix_matchings = suffix_location.matchings; + int prefix_subpatterns = prefix_location.num_subpatterns; + int suffix_subpatterns = suffix_location.num_subpatterns; + if (use_baeza_yates) { + binary_search_merger->Merge(locations, phrase, suffix, + prefix_matchings->begin(), prefix_matchings->end(), + suffix_matchings->begin(), suffix_matchings->end(), + prefix_subpatterns, suffix_subpatterns); + } else { + linear_merger->Merge(locations, phrase, suffix, prefix_matchings->begin(), + prefix_matchings->end(), suffix_matchings->begin(), + suffix_matchings->end(), prefix_subpatterns, suffix_subpatterns); + } + return PhraseLocation(locations, phrase.Arity() + 1); +} + +void Intersector::ExtendPhraseLocation( + const Phrase& phrase, PhraseLocation& phrase_location) { + int low = phrase_location.sa_low, high = phrase_location.sa_high; + if (phrase_location.matchings != NULL) { + return; + } + + + phrase_location.num_subpatterns = 1; + phrase_location.sa_low = phrase_location.sa_high = 0; + + vector<int> symbols = phrase.Get(); + if (inverted_index.count(symbols)) { + phrase_location.matchings = + make_shared<vector<int> >(inverted_index[symbols]); + return; + } + + vector<int> matchings; + matchings.reserve(high - low + 1); + shared_ptr<VEB> veb = VEB::Create(suffix_array->GetSize()); + for (int i = low; i < high; ++i) { + veb->Insert(suffix_array->GetSuffix(i)); + } + + int value = veb->GetMinimum(); + while (value != -1) { + matchings.push_back(value); + value = veb->GetSuccessor(value); + } + + phrase_location.matchings = make_shared<vector<int> >(matchings); +} diff --git a/extractor/intersector.h b/extractor/intersector.h new file mode 100644 index 00000000..8b159f17 --- /dev/null +++ b/extractor/intersector.h @@ -0,0 +1,79 @@ +#ifndef _INTERSECTOR_H_ +#define _INTERSECTOR_H_ + +#include <memory> +#include <unordered_map> +#include <vector> + +#include <boost/functional/hash.hpp> + +#include "binary_search_merger.h" +#include "linear_merger.h" + +using namespace std; + +typedef boost::hash<vector<int> > VectorHash; +typedef unordered_map<vector<int>, vector<int>, VectorHash> Index; + +class DataArray; +class MatchingComparator; +class Phrase; +class PhraseLocation; +class Precomputation; +class SuffixArray; +class Vocabulary; + +class Intersector { + public: + Intersector( + shared_ptr<Vocabulary> vocabulary, + shared_ptr<Precomputation> precomputation, + shared_ptr<SuffixArray> source_suffix_array, + shared_ptr<MatchingComparator> comparator, + bool use_baeza_yates); + + // For testing. + Intersector( + shared_ptr<Vocabulary> vocabulary, + shared_ptr<Precomputation> precomputation, + shared_ptr<SuffixArray> source_suffix_array, + shared_ptr<LinearMerger> linear_merger, + shared_ptr<BinarySearchMerger> binary_search_merger, + bool use_baeza_yates); + + virtual ~Intersector(); + + virtual PhraseLocation Intersect( + const Phrase& prefix, PhraseLocation& prefix_location, + const Phrase& suffix, PhraseLocation& suffix_location, + const Phrase& phrase); + + protected: + Intersector(); + + private: + void ConvertIndexes(shared_ptr<Precomputation> precomputation, + shared_ptr<DataArray> data_array); + + vector<int> ConvertPhrase(const vector<int>& old_phrase, + shared_ptr<DataArray> data_array); + + void ExtendPhraseLocation(const Phrase& phrase, + PhraseLocation& phrase_location); + + shared_ptr<Vocabulary> vocabulary; + shared_ptr<SuffixArray> suffix_array; + shared_ptr<LinearMerger> linear_merger; + shared_ptr<BinarySearchMerger> binary_search_merger; + Index inverted_index; + Index collocations; + bool use_baeza_yates; + + // TODO(pauldb): Don't forget to remove these. + public: + double sort_time; + double linear_merge_time; + double binary_merge_time; +}; + +#endif diff --git a/extractor/intersector_test.cc b/extractor/intersector_test.cc new file mode 100644 index 00000000..ec318362 --- /dev/null +++ b/extractor/intersector_test.cc @@ -0,0 +1,193 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <vector> + +#include "intersector.h" +#include "mocks/mock_binary_search_merger.h" +#include "mocks/mock_data_array.h" +#include "mocks/mock_linear_merger.h" +#include "mocks/mock_precomputation.h" +#include "mocks/mock_suffix_array.h" +#include "mocks/mock_vocabulary.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class IntersectorTest : public Test { + protected: + virtual void SetUp() { + data = {2, 3, 4, 3, 4, 3}; + vector<string> words = {"a", "b", "c", "b", "c", "b"}; + data_array = make_shared<MockDataArray>(); + EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data)); + + vocabulary = make_shared<MockVocabulary>(); + for (size_t i = 0; i < data.size(); ++i) { + EXPECT_CALL(*data_array, GetWord(data[i])) + .WillRepeatedly(Return(words[i])); + EXPECT_CALL(*vocabulary, GetTerminalIndex(words[i])) + .WillRepeatedly(Return(data[i])); + EXPECT_CALL(*vocabulary, GetTerminalValue(data[i])) + .WillRepeatedly(Return(words[i])); + } + + vector<int> suffixes = {6, 0, 5, 3, 1, 4, 2}; + suffix_array = make_shared<MockSuffixArray>(); + EXPECT_CALL(*suffix_array, GetData()) + .WillRepeatedly(Return(data_array)); + EXPECT_CALL(*suffix_array, GetSize()) + .WillRepeatedly(Return(suffixes.size())); + for (size_t i = 0; i < suffixes.size(); ++i) { + EXPECT_CALL(*suffix_array, GetSuffix(i)) + .WillRepeatedly(Return(suffixes[i])); + } + + vector<int> key = {2, -1, 4}; + vector<int> values = {0, 2}; + collocations[key] = values; + precomputation = make_shared<MockPrecomputation>(); + EXPECT_CALL(*precomputation, GetInvertedIndex()) + .WillRepeatedly(ReturnRef(inverted_index)); + EXPECT_CALL(*precomputation, GetCollocations()) + .WillRepeatedly(ReturnRef(collocations)); + + linear_merger = make_shared<MockLinearMerger>(); + binary_search_merger = make_shared<MockBinarySearchMerger>(); + + phrase_builder = make_shared<PhraseBuilder>(vocabulary); + } + + Index inverted_index; + Index collocations; + vector<int> data; + shared_ptr<MockVocabulary> vocabulary; + shared_ptr<MockDataArray> data_array; + shared_ptr<MockSuffixArray> suffix_array; + shared_ptr<MockPrecomputation> precomputation; + shared_ptr<MockLinearMerger> linear_merger; + shared_ptr<MockBinarySearchMerger> binary_search_merger; + shared_ptr<PhraseBuilder> phrase_builder; + shared_ptr<Intersector> intersector; +}; + +TEST_F(IntersectorTest, TestCachedCollocation) { + intersector = make_shared<Intersector>(vocabulary, precomputation, + suffix_array, linear_merger, binary_search_merger, false); + + vector<int> prefix_symbols = {2, -1}; + Phrase prefix = phrase_builder->Build(prefix_symbols); + vector<int> suffix_symbols = {-1, 4}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + vector<int> symbols = {2, -1, 4}; + Phrase phrase = phrase_builder->Build(symbols); + PhraseLocation prefix_locs(0, 1), suffix_locs(2, 3); + + PhraseLocation result = intersector->Intersect( + prefix, prefix_locs, suffix, suffix_locs, phrase); + + vector<int> expected_locs = {0, 2}; + PhraseLocation expected_result(expected_locs, 2); + + EXPECT_EQ(expected_result, result); + EXPECT_EQ(PhraseLocation(0, 1), prefix_locs); + EXPECT_EQ(PhraseLocation(2, 3), suffix_locs); +} + +TEST_F(IntersectorTest, TestLinearMergeaXb) { + vector<int> prefix_symbols = {3, -1}; + Phrase prefix = phrase_builder->Build(prefix_symbols); + vector<int> suffix_symbols = {-1, 4}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + vector<int> symbols = {3, -1, 4}; + Phrase phrase = phrase_builder->Build(symbols); + PhraseLocation prefix_locs(2, 5), suffix_locs(5, 7); + + vector<int> ex_prefix_locs = {1, 3, 5}; + PhraseLocation extended_prefix_locs(ex_prefix_locs, 1); + vector<int> ex_suffix_locs = {2, 4}; + PhraseLocation extended_suffix_locs(ex_suffix_locs, 1); + + vector<int> expected_locs = {1, 4}; + EXPECT_CALL(*linear_merger, Merge(_, _, _, _, _, _, _, _, _)) + .Times(1) + .WillOnce(SetArgReferee<0>(expected_locs)); + EXPECT_CALL(*binary_search_merger, Merge(_, _, _, _, _, _, _, _, _)).Times(0); + + intersector = make_shared<Intersector>(vocabulary, precomputation, + suffix_array, linear_merger, binary_search_merger, false); + + PhraseLocation result = intersector->Intersect( + prefix, prefix_locs, suffix, suffix_locs, phrase); + PhraseLocation expected_result(expected_locs, 2); + + EXPECT_EQ(expected_result, result); + EXPECT_EQ(extended_prefix_locs, prefix_locs); + EXPECT_EQ(extended_suffix_locs, suffix_locs); +} + +TEST_F(IntersectorTest, TestBinarySearchMergeaXb) { + vector<int> prefix_symbols = {3, -1}; + Phrase prefix = phrase_builder->Build(prefix_symbols); + vector<int> suffix_symbols = {-1, 4}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + vector<int> symbols = {3, -1, 4}; + Phrase phrase = phrase_builder->Build(symbols); + PhraseLocation prefix_locs(2, 5), suffix_locs(5, 7); + + vector<int> ex_prefix_locs = {1, 3, 5}; + PhraseLocation extended_prefix_locs(ex_prefix_locs, 1); + vector<int> ex_suffix_locs = {2, 4}; + PhraseLocation extended_suffix_locs(ex_suffix_locs, 1); + + vector<int> expected_locs = {1, 4}; + EXPECT_CALL(*binary_search_merger, Merge(_, _, _, _, _, _, _, _, _)) + .Times(1) + .WillOnce(SetArgReferee<0>(expected_locs)); + EXPECT_CALL(*linear_merger, Merge(_, _, _, _, _, _, _, _, _)).Times(0); + + intersector = make_shared<Intersector>(vocabulary, precomputation, + suffix_array, linear_merger, binary_search_merger, true); + + PhraseLocation result = intersector->Intersect( + prefix, prefix_locs, suffix, suffix_locs, phrase); + PhraseLocation expected_result(expected_locs, 2); + + EXPECT_EQ(expected_result, result); + EXPECT_EQ(extended_prefix_locs, prefix_locs); + EXPECT_EQ(extended_suffix_locs, suffix_locs); +} + +TEST_F(IntersectorTest, TestMergeaXbXc) { + vector<int> prefix_symbols = {2, -1, 4, -1}; + Phrase prefix = phrase_builder->Build(prefix_symbols); + vector<int> suffix_symbols = {-1, 4, -1, 4}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + vector<int> symbols = {2, -1, 4, -1, 4}; + Phrase phrase = phrase_builder->Build(symbols); + + vector<int> ex_prefix_locs = {0, 2, 0, 4}; + PhraseLocation extended_prefix_locs(ex_prefix_locs, 2); + vector<int> ex_suffix_locs = {2, 4}; + PhraseLocation extended_suffix_locs(ex_suffix_locs, 2); + vector<int> expected_locs = {0, 2, 4}; + EXPECT_CALL(*linear_merger, Merge(_, _, _, _, _, _, _, _, _)) + .Times(1) + .WillOnce(SetArgReferee<0>(expected_locs)); + EXPECT_CALL(*binary_search_merger, Merge(_, _, _, _, _, _, _, _, _)).Times(0); + + intersector = make_shared<Intersector>(vocabulary, precomputation, + suffix_array, linear_merger, binary_search_merger, false); + + PhraseLocation result = intersector->Intersect( + prefix, extended_prefix_locs, suffix, extended_suffix_locs, phrase); + PhraseLocation expected_result(expected_locs, 3); + + EXPECT_EQ(expected_result, result); + EXPECT_EQ(ex_prefix_locs, *extended_prefix_locs.matchings); + EXPECT_EQ(ex_suffix_locs, *extended_suffix_locs.matchings); +} + +} // namespace diff --git a/extractor/linear_merger.cc b/extractor/linear_merger.cc new file mode 100644 index 00000000..e7a32788 --- /dev/null +++ b/extractor/linear_merger.cc @@ -0,0 +1,65 @@ +#include "linear_merger.h" + +#include <cmath> + +#include "data_array.h" +#include "matching.h" +#include "matching_comparator.h" +#include "phrase.h" +#include "phrase_location.h" +#include "vocabulary.h" + +LinearMerger::LinearMerger(shared_ptr<Vocabulary> vocabulary, + shared_ptr<DataArray> data_array, + shared_ptr<MatchingComparator> comparator) : + vocabulary(vocabulary), data_array(data_array), comparator(comparator) {} + +LinearMerger::LinearMerger() {} + +LinearMerger::~LinearMerger() {} + +void LinearMerger::Merge( + vector<int>& locations, const Phrase& phrase, const Phrase& suffix, + vector<int>::iterator prefix_start, vector<int>::iterator prefix_end, + vector<int>::iterator suffix_start, vector<int>::iterator suffix_end, + int prefix_subpatterns, int suffix_subpatterns) { + int last_chunk_len = suffix.GetChunkLen(suffix.Arity()); + bool offset = !vocabulary->IsTerminal(suffix.GetSymbol(0)); + + while (prefix_start != prefix_end) { + Matching left(prefix_start, prefix_subpatterns, + data_array->GetSentenceId(*prefix_start)); + + while (suffix_start != suffix_end) { + Matching right(suffix_start, suffix_subpatterns, + data_array->GetSentenceId(*suffix_start)); + if (comparator->Compare(left, right, last_chunk_len, offset) > 0) { + suffix_start += suffix_subpatterns; + } else { + break; + } + } + + int start_position = *prefix_start; + vector<int> :: iterator i = suffix_start; + while (prefix_start != prefix_end && *prefix_start == start_position) { + Matching left(prefix_start, prefix_subpatterns, + data_array->GetSentenceId(*prefix_start)); + + while (i != suffix_end) { + Matching right(i, suffix_subpatterns, data_array->GetSentenceId(*i)); + int comparison = comparator->Compare(left, right, last_chunk_len, + offset); + if (comparison == 0) { + vector<int> merged = left.Merge(right, phrase.Arity() + 1); + locations.insert(locations.end(), merged.begin(), merged.end()); + } else if (comparison < 0) { + break; + } + i += suffix_subpatterns; + } + + prefix_start += prefix_subpatterns; + } + } +} diff --git a/extractor/linear_merger.h b/extractor/linear_merger.h new file mode 100644 index 00000000..c3c7111e --- /dev/null +++ b/extractor/linear_merger.h @@ -0,0 +1,38 @@ +#ifndef _LINEAR_MERGER_H_ +#define _LINEAR_MERGER_H_ + +#include <memory> +#include <vector> + +using namespace std; + +class MatchingComparator; +class Phrase; +class PhraseLocation; +class DataArray; +class Vocabulary; + +class LinearMerger { + public: + LinearMerger(shared_ptr<Vocabulary> vocabulary, + shared_ptr<DataArray> data_array, + shared_ptr<MatchingComparator> comparator); + + virtual ~LinearMerger(); + + virtual void Merge( + vector<int>& locations, const Phrase& phrase, const Phrase& suffix, + vector<int>::iterator prefix_start, vector<int>::iterator prefix_end, + vector<int>::iterator suffix_start, vector<int>::iterator suffix_end, + int prefix_subpatterns, int suffix_subpatterns); + + protected: + LinearMerger(); + + private: + shared_ptr<Vocabulary> vocabulary; + shared_ptr<DataArray> data_array; + shared_ptr<MatchingComparator> comparator; +}; + +#endif diff --git a/extractor/linear_merger_test.cc b/extractor/linear_merger_test.cc new file mode 100644 index 00000000..a6963430 --- /dev/null +++ b/extractor/linear_merger_test.cc @@ -0,0 +1,149 @@ +#include <gtest/gtest.h> + +#include <memory> + +#include "linear_merger.h" +#include "matching_comparator.h" +#include "mocks/mock_data_array.h" +#include "mocks/mock_vocabulary.h" +#include "phrase.h" +#include "phrase_location.h" +#include "phrase_builder.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class LinearMergerTest : public Test { + protected: + virtual void SetUp() { + shared_ptr<MockVocabulary> vocabulary = make_shared<MockVocabulary>(); + EXPECT_CALL(*vocabulary, GetTerminalValue(_)) + .WillRepeatedly(Return("word")); + + shared_ptr<MockDataArray> data_array = make_shared<MockDataArray>(); + EXPECT_CALL(*data_array, GetSentenceId(_)) + .WillRepeatedly(Return(1)); + + shared_ptr<MatchingComparator> comparator = + make_shared<MatchingComparator>(1, 20); + + phrase_builder = make_shared<PhraseBuilder>(vocabulary); + linear_merger = make_shared<LinearMerger>(vocabulary, data_array, + comparator); + } + + shared_ptr<LinearMerger> linear_merger; + shared_ptr<PhraseBuilder> phrase_builder; +}; + +TEST_F(LinearMergerTest, aXbTest) { + vector<int> locations; + // Encoding for him X it (see Adam's dissertation). + vector<int> symbols{1, -1, 2}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> suffix_symbols{-1, 2}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + + vector<int> prefix_locs{2, 6, 10, 15}; + vector<int> suffix_locs{0, 4, 8, 13}; + + linear_merger->Merge(locations, phrase, suffix, prefix_locs.begin(), + prefix_locs.end(), suffix_locs.begin(), suffix_locs.end(), 1, 1); + + vector<int> expected_locations{2, 4, 2, 8, 2, 13, 6, 8, 6, 13, 10, 13}; + EXPECT_EQ(expected_locations, locations); +} + +TEST_F(LinearMergerTest, aXbXcTest) { + vector<int> locations; + // Encoding for it X him X it (see Adam's dissertation). + vector<int> symbols{1, -1, 2, -2, 1}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> suffix_symbols{-1, 2, -2, 1}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + + vector<int> prefix_locs{0, 2, 0, 6, 0, 10, 4, 6, 4, 10, 4, 15, 8, 10, 8, 15, + 13, 15}; + vector<int> suffix_locs{2, 4, 2, 8, 2, 13, 6, 8, 6, 13, 10, 13}; + + linear_merger->Merge(locations, phrase, suffix, prefix_locs.begin(), + prefix_locs.end(), suffix_locs.begin(), suffix_locs.end(), 2, 2); + + vector<int> expected_locs{0, 2, 4, 0, 2, 8, 0, 2, 13, 0, 6, 8, 0, 6, 13, 0, + 10, 13, 4, 6, 8, 4, 6, 13, 4, 10, 13, 8, 10, 13}; + EXPECT_EQ(expected_locs, locations); +} + +TEST_F(LinearMergerTest, abXcXdTest) { + // Sentence: Anna has many many nuts and sour apples and juicy apples. + // Phrase: Anna has X and X apples. + vector<int> locations; + vector<int> symbols{1, 2, -1, 3, -2, 4}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> suffix_symbols{2, -1, 3, -2, 4}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + + vector<int> prefix_locs{1, 6, 1, 9}; + vector<int> suffix_locs{2, 6, 8, 2, 6, 11, 2, 9, 11}; + + linear_merger->Merge(locations, phrase, suffix, prefix_locs.begin(), + prefix_locs.end(), suffix_locs.begin(), suffix_locs.end(), 2, 3); + + vector<int> expected_locs{1, 6, 8, 1, 6, 11, 1, 9, 11}; + EXPECT_EQ(expected_locs, locations); +} + +TEST_F(LinearMergerTest, LargeTest) { + vector<int> locations; + vector<int> symbols{1, -1, 2}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> suffix_symbols{-1, 2}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + + vector<int> prefix_locs; + for (int i = 0; i < 100; ++i) { + prefix_locs.push_back(i * 20 + 1); + } + vector<int> suffix_locs; + for (int i = 0; i < 100; ++i) { + suffix_locs.push_back(i * 20 + 5); + suffix_locs.push_back(i * 20 + 13); + } + + linear_merger->Merge(locations, phrase, suffix, prefix_locs.begin(), + prefix_locs.end(), suffix_locs.begin(), suffix_locs.end(), 1, 1); + + EXPECT_EQ(400, locations.size()); + for (int i = 0; i < 100; ++i) { + EXPECT_EQ(i * 20 + 1, locations[4 * i]); + EXPECT_EQ(i * 20 + 5, locations[4 * i + 1]); + EXPECT_EQ(i * 20 + 1, locations[4 * i + 2]); + EXPECT_EQ(i * 20 + 13, locations[4 * i + 3]); + } +} + +TEST_F(LinearMergerTest, EmptyResultTest) { + vector<int> locations; + vector<int> symbols{1, -1, 2}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> suffix_symbols{-1, 2}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + + vector<int> prefix_locs; + for (int i = 0; i < 100; ++i) { + prefix_locs.push_back(i * 200 + 1); + } + vector<int> suffix_locs; + for (int i = 0; i < 100; ++i) { + suffix_locs.push_back(i * 200 + 101); + } + + linear_merger->Merge(locations, phrase, suffix, prefix_locs.begin(), + prefix_locs.end(), suffix_locs.begin(), suffix_locs.end(), 1, 1); + + EXPECT_EQ(0, locations.size()); +} + +} // namespace diff --git a/extractor/matching.cc b/extractor/matching.cc new file mode 100644 index 00000000..16a3ed6f --- /dev/null +++ b/extractor/matching.cc @@ -0,0 +1,12 @@ +#include "matching.h" + +Matching::Matching(vector<int>::iterator start, int len, int sentence_id) : + positions(start, start + len), sentence_id(sentence_id) {} + +vector<int> Matching::Merge(const Matching& other, int num_subpatterns) const { + vector<int> result = positions; + if (num_subpatterns > positions.size()) { + result.push_back(other.positions.back()); + } + return result; +} diff --git a/extractor/matching.h b/extractor/matching.h new file mode 100644 index 00000000..4c46559e --- /dev/null +++ b/extractor/matching.h @@ -0,0 +1,18 @@ +#ifndef _MATCHING_H_ +#define _MATCHING_H_ + +#include <memory> +#include <vector> + +using namespace std; + +struct Matching { + Matching(vector<int>::iterator start, int len, int sentence_id); + + vector<int> Merge(const Matching& other, int num_subpatterns) const; + + vector<int> positions; + int sentence_id; +}; + +#endif diff --git a/extractor/matching_comparator.cc b/extractor/matching_comparator.cc new file mode 100644 index 00000000..03db95c0 --- /dev/null +++ b/extractor/matching_comparator.cc @@ -0,0 +1,46 @@ +#include "matching_comparator.h" + +#include "matching.h" +#include "vocabulary.h" + +MatchingComparator::MatchingComparator(int min_gap_size, int max_rule_span) : + min_gap_size(min_gap_size), max_rule_span(max_rule_span) {} + +int MatchingComparator::Compare(const Matching& left, + const Matching& right, + int last_chunk_len, + bool offset) const { + if (left.sentence_id != right.sentence_id) { + return left.sentence_id < right.sentence_id ? -1 : 1; + } + + if (left.positions.size() == 1 && right.positions.size() == 1) { + // The prefix and the suffix must be separated by a non-terminal, otherwise + // we would be doing a suffix array lookup. + if (right.positions[0] - left.positions[0] <= min_gap_size) { + return 1; + } + } else if (offset) { + for (size_t i = 1; i < left.positions.size(); ++i) { + if (left.positions[i] != right.positions[i - 1]) { + return left.positions[i] < right.positions[i - 1] ? -1 : 1; + } + } + } else { + if (left.positions[0] + 1 != right.positions[0]) { + return left.positions[0] + 1 < right.positions[0] ? -1 : 1; + } + for (size_t i = 1; i < left.positions.size(); ++i) { + if (left.positions[i] != right.positions[i]) { + return left.positions[i] < right.positions[i] ? -1 : 1; + } + } + } + + if (right.positions.back() + last_chunk_len - left.positions.front() > + max_rule_span) { + return -1; + } + + return 0; +} diff --git a/extractor/matching_comparator.h b/extractor/matching_comparator.h new file mode 100644 index 00000000..6e1bb487 --- /dev/null +++ b/extractor/matching_comparator.h @@ -0,0 +1,23 @@ +#ifndef _MATCHING_COMPARATOR_H_ +#define _MATCHING_COMPARATOR_H_ + +#include <memory> + +using namespace std; + +class Vocabulary; +class Matching; + +class MatchingComparator { + public: + MatchingComparator(int min_gap_size, int max_rule_span); + + int Compare(const Matching& left, const Matching& right, + int last_chunk_len, bool offset) const; + + private: + int min_gap_size; + int max_rule_span; +}; + +#endif diff --git a/extractor/matching_comparator_test.cc b/extractor/matching_comparator_test.cc new file mode 100644 index 00000000..b8f898cf --- /dev/null +++ b/extractor/matching_comparator_test.cc @@ -0,0 +1,139 @@ +#include <gtest/gtest.h> + +#include "matching.h" +#include "matching_comparator.h" + +using namespace ::testing; + +namespace { + +class MatchingComparatorTest : public Test { + protected: + virtual void SetUp() { + comparator = make_shared<MatchingComparator>(1, 20); + } + + shared_ptr<MatchingComparator> comparator; +}; + +TEST_F(MatchingComparatorTest, SmallerSentenceId) { + vector<int> left_locations{1}; + Matching left(left_locations.begin(), 1, 1); + vector<int> right_locations{100}; + Matching right(right_locations.begin(), 1, 5); + EXPECT_EQ(-1, comparator->Compare(left, right, 1, true)); +} + +TEST_F(MatchingComparatorTest, GreaterSentenceId) { + vector<int> left_locations{100}; + Matching left(left_locations.begin(), 1, 5); + vector<int> right_locations{1}; + Matching right(right_locations.begin(), 1, 1); + EXPECT_EQ(1, comparator->Compare(left, right, 1, true)); +} + +TEST_F(MatchingComparatorTest, SmalleraXb) { + vector<int> left_locations{1}; + Matching left(left_locations.begin(), 1, 1); + vector<int> right_locations{21}; + Matching right(right_locations.begin(), 1, 1); + // The matching exceeds the max rule span. + EXPECT_EQ(-1, comparator->Compare(left, right, 1, true)); +} + +TEST_F(MatchingComparatorTest, EqualaXb) { + vector<int> left_locations{1}; + Matching left(left_locations.begin(), 1, 1); + vector<int> lower_right_locations{3}; + Matching right(lower_right_locations.begin(), 1, 1); + EXPECT_EQ(0, comparator->Compare(left, right, 1, true)); + + vector<int> higher_right_locations{20}; + right = Matching(higher_right_locations.begin(), 1, 1); + EXPECT_EQ(0, comparator->Compare(left, right, 1, true)); +} + +TEST_F(MatchingComparatorTest, GreateraXb) { + vector<int> left_locations{1}; + Matching left(left_locations.begin(), 1, 1); + vector<int> right_locations{2}; + Matching right(right_locations.begin(), 1, 1); + // The gap between the prefix and the suffix is of size 0, less than the + // min gap size. + EXPECT_EQ(1, comparator->Compare(left, right, 1, true)); +} + +TEST_F(MatchingComparatorTest, SmalleraXbXc) { + vector<int> left_locations{1, 3}; + Matching left(left_locations.begin(), 2, 1); + vector<int> right_locations{4, 6}; + // The common part doesn't match. + Matching right(right_locations.begin(), 2, 1); + EXPECT_EQ(-1, comparator->Compare(left, right, 1, true)); + + // The common part matches, but the rule exceeds the max span. + vector<int> other_right_locations{3, 21}; + right = Matching(other_right_locations.begin(), 2, 1); + EXPECT_EQ(-1, comparator->Compare(left, right, 1, true)); +} + +TEST_F(MatchingComparatorTest, EqualaXbXc) { + vector<int> left_locations{1, 3}; + Matching left(left_locations.begin(), 2, 1); + vector<int> right_locations{3, 5}; + // The leftmost possible match. + Matching right(right_locations.begin(), 2, 1); + EXPECT_EQ(0, comparator->Compare(left, right, 1, true)); + + // The rightmost possible match. + vector<int> other_right_locations{3, 20}; + right = Matching(other_right_locations.begin(), 2, 1); + EXPECT_EQ(0, comparator->Compare(left, right, 1, true)); +} + +TEST_F(MatchingComparatorTest, GreateraXbXc) { + vector<int> left_locations{1, 4}; + Matching left(left_locations.begin(), 2, 1); + vector<int> right_locations{3, 5}; + // The common part doesn't match. + Matching right(right_locations.begin(), 2, 1); + EXPECT_EQ(1, comparator->Compare(left, right, 1, true)); +} + +TEST_F(MatchingComparatorTest, SmallerabXcXd) { + vector<int> left_locations{9, 13}; + Matching left(left_locations.begin(), 2, 1); + // The suffix doesn't start on the next position. + vector<int> right_locations{11, 13, 15}; + Matching right(right_locations.begin(), 3, 1); + EXPECT_EQ(-1, comparator->Compare(left, right, 1, false)); + + // The common part doesn't match. + vector<int> other_right_locations{10, 16, 18}; + right = Matching(other_right_locations.begin(), 3, 1); + EXPECT_EQ(-1, comparator->Compare(left, right, 1, false)); +} + +TEST_F(MatchingComparatorTest, EqualabXcXd) { + vector<int> left_locations{10, 13}; + Matching left(left_locations.begin(), 2, 1); + vector<int> right_locations{11, 13, 15}; + Matching right(right_locations.begin(), 3, 1); + EXPECT_EQ(0, comparator->Compare(left, right, 1, false)); +} + +TEST_F(MatchingComparatorTest, GreaterabXcXd) { + vector<int> left_locations{9, 15}; + Matching left(left_locations.begin(), 2, 1); + // The suffix doesn't start on the next position. + vector<int> right_locations{7, 15, 17}; + Matching right(right_locations.begin(), 3, 1); + EXPECT_EQ(1, comparator->Compare(left, right, 1, false)); + + // The common part doesn't match. + vector<int> other_right_locations{10, 13, 16}; + right = Matching(other_right_locations.begin(), 3, 1); + EXPECT_EQ(1, comparator->Compare(left, right, 1, false)); +} + +} // namespace diff --git a/extractor/matching_test.cc b/extractor/matching_test.cc new file mode 100644 index 00000000..9593aa86 --- /dev/null +++ b/extractor/matching_test.cc @@ -0,0 +1,25 @@ +#include <gtest/gtest.h> + +#include <vector> + +#include "matching.h" + +using namespace std; + +namespace { + +TEST(MatchingTest, SameSize) { + vector<int> positions{1, 2, 3}; + Matching left(positions.begin(), positions.size(), 0); + Matching right(positions.begin(), positions.size(), 0); + EXPECT_EQ(positions, left.Merge(right, positions.size())); +} + +TEST(MatchingTest, DifferentSize) { + vector<int> positions{1, 2, 3}; + Matching left(positions.begin(), positions.size() - 1, 0); + Matching right(positions.begin() + 1, positions.size() - 1, 0); + vector<int> result = left.Merge(right, positions.size()); +} + +} // namespace diff --git a/extractor/matchings_finder.cc b/extractor/matchings_finder.cc new file mode 100644 index 00000000..eaf493b2 --- /dev/null +++ b/extractor/matchings_finder.cc @@ -0,0 +1,21 @@ +#include "matchings_finder.h" + +#include "suffix_array.h" +#include "phrase_location.h" + +MatchingsFinder::MatchingsFinder(shared_ptr<SuffixArray> suffix_array) : + suffix_array(suffix_array) {} + +MatchingsFinder::MatchingsFinder() {} + +MatchingsFinder::~MatchingsFinder() {} + +PhraseLocation MatchingsFinder::Find(PhraseLocation& location, + const string& word, int offset) { + if (location.sa_low == -1 && location.sa_high == -1) { + location.sa_low = 0; + location.sa_high = suffix_array->GetSize(); + } + + return suffix_array->Lookup(location.sa_low, location.sa_high, word, offset); +} diff --git a/extractor/matchings_finder.h b/extractor/matchings_finder.h new file mode 100644 index 00000000..ed04d8b8 --- /dev/null +++ b/extractor/matchings_finder.h @@ -0,0 +1,28 @@ +#ifndef _MATCHINGS_FINDER_H_ +#define _MATCHINGS_FINDER_H_ + +#include <memory> +#include <string> + +using namespace std; + +class PhraseLocation; +class SuffixArray; + +class MatchingsFinder { + public: + MatchingsFinder(shared_ptr<SuffixArray> suffix_array); + + virtual ~MatchingsFinder(); + + virtual PhraseLocation Find(PhraseLocation& location, const string& word, + int offset); + + protected: + MatchingsFinder(); + + private: + shared_ptr<SuffixArray> suffix_array; +}; + +#endif diff --git a/extractor/matchings_finder_test.cc b/extractor/matchings_finder_test.cc new file mode 100644 index 00000000..817f1635 --- /dev/null +++ b/extractor/matchings_finder_test.cc @@ -0,0 +1,42 @@ +#include <gtest/gtest.h> + +#include <memory> + +#include "matchings_finder.h" +#include "mocks/mock_suffix_array.h" +#include "phrase_location.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class MatchingsFinderTest : public Test { + protected: + virtual void SetUp() { + suffix_array = make_shared<MockSuffixArray>(); + EXPECT_CALL(*suffix_array, Lookup(0, 10, _, _)) + .Times(1) + .WillOnce(Return(PhraseLocation(3, 5))); + + matchings_finder = make_shared<MatchingsFinder>(suffix_array); + } + + shared_ptr<MatchingsFinder> matchings_finder; + shared_ptr<MockSuffixArray> suffix_array; +}; + +TEST_F(MatchingsFinderTest, TestFind) { + PhraseLocation phrase_location(0, 10), expected_result(3, 5); + EXPECT_EQ(expected_result, matchings_finder->Find(phrase_location, "bla", 2)); +} + +TEST_F(MatchingsFinderTest, ResizeUnsetRange) { + EXPECT_CALL(*suffix_array, GetSize()).Times(1).WillOnce(Return(10)); + + PhraseLocation phrase_location, expected_result(3, 5); + EXPECT_EQ(expected_result, matchings_finder->Find(phrase_location, "bla", 2)); + EXPECT_EQ(PhraseLocation(0, 10), phrase_location); +} + +} // namespace diff --git a/extractor/matchings_trie.cc b/extractor/matchings_trie.cc new file mode 100644 index 00000000..921ec582 --- /dev/null +++ b/extractor/matchings_trie.cc @@ -0,0 +1,19 @@ +#include "matchings_trie.h" + +void MatchingsTrie::Reset() { + ResetTree(root); + root = make_shared<TrieNode>(); +} + +shared_ptr<TrieNode> MatchingsTrie::GetRoot() const { + return root; +} + +void MatchingsTrie::ResetTree(shared_ptr<TrieNode> root) { + if (root != NULL) { + for (auto child: root->children) { + ResetTree(child.second); + } + root.reset(); + } +} diff --git a/extractor/matchings_trie.h b/extractor/matchings_trie.h new file mode 100644 index 00000000..6e72b2db --- /dev/null +++ b/extractor/matchings_trie.h @@ -0,0 +1,47 @@ +#ifndef _MATCHINGS_TRIE_ +#define _MATCHINGS_TRIE_ + +#include <memory> +#include <unordered_map> + +#include "phrase.h" +#include "phrase_location.h" + +using namespace std; + +struct TrieNode { + TrieNode(shared_ptr<TrieNode> suffix_link = shared_ptr<TrieNode>(), + Phrase phrase = Phrase(), + PhraseLocation matchings = PhraseLocation()) : + suffix_link(suffix_link), phrase(phrase), matchings(matchings) {} + + void AddChild(int key, shared_ptr<TrieNode> child_node) { + children[key] = child_node; + } + + bool HasChild(int key) { + return children.count(key); + } + + shared_ptr<TrieNode> GetChild(int key) { + return children[key]; + } + + shared_ptr<TrieNode> suffix_link; + Phrase phrase; + PhraseLocation matchings; + unordered_map<int, shared_ptr<TrieNode> > children; +}; + +class MatchingsTrie { + public: + void Reset(); + shared_ptr<TrieNode> GetRoot() const; + + private: + void ResetTree(shared_ptr<TrieNode> root); + + shared_ptr<TrieNode> root; +}; + +#endif diff --git a/extractor/mocks/mock_alignment.h b/extractor/mocks/mock_alignment.h new file mode 100644 index 00000000..4a5077ad --- /dev/null +++ b/extractor/mocks/mock_alignment.h @@ -0,0 +1,10 @@ +#include <gmock/gmock.h> + +#include "../alignment.h" + +typedef vector<pair<int, int> > SentenceLinks; + +class MockAlignment : public Alignment { + public: + MOCK_CONST_METHOD1(GetLinks, SentenceLinks(int sentence_id)); +}; diff --git a/extractor/mocks/mock_binary_search_merger.h b/extractor/mocks/mock_binary_search_merger.h new file mode 100644 index 00000000..e23386f0 --- /dev/null +++ b/extractor/mocks/mock_binary_search_merger.h @@ -0,0 +1,15 @@ +#include <gmock/gmock.h> + +#include <vector> + +#include "../binary_search_merger.h" +#include "../phrase.h" + +using namespace std; + +class MockBinarySearchMerger: public BinarySearchMerger { + public: + MOCK_CONST_METHOD9(Merge, void(vector<int>&, const Phrase&, const Phrase&, + const vector<int>::iterator&, const vector<int>::iterator&, + const vector<int>::iterator&, const vector<int>::iterator&, int, int)); +}; diff --git a/extractor/mocks/mock_data_array.h b/extractor/mocks/mock_data_array.h new file mode 100644 index 00000000..004e8906 --- /dev/null +++ b/extractor/mocks/mock_data_array.h @@ -0,0 +1,19 @@ +#include <gmock/gmock.h> + +#include "../data_array.h" + +class MockDataArray : public DataArray { + public: + MOCK_CONST_METHOD0(GetData, const vector<int>&()); + MOCK_CONST_METHOD1(AtIndex, int(int index)); + MOCK_CONST_METHOD1(GetWordAtIndex, string(int index)); + MOCK_CONST_METHOD0(GetSize, int()); + MOCK_CONST_METHOD0(GetVocabularySize, int()); + MOCK_CONST_METHOD1(HasWord, bool(const string& word)); + MOCK_CONST_METHOD1(GetWordId, int(const string& word)); + MOCK_CONST_METHOD1(GetWord, string(int word_id)); + MOCK_CONST_METHOD1(GetSentenceLength, int(int sentence_id)); + MOCK_CONST_METHOD0(GetNumSentences, int()); + MOCK_CONST_METHOD1(GetSentenceStart, int(int sentence_id)); + MOCK_CONST_METHOD1(GetSentenceId, int(int position)); +}; diff --git a/extractor/mocks/mock_fast_intersector.h b/extractor/mocks/mock_fast_intersector.h new file mode 100644 index 00000000..201386f2 --- /dev/null +++ b/extractor/mocks/mock_fast_intersector.h @@ -0,0 +1,11 @@ +#include <gmock/gmock.h> + +#include "../fast_intersector.h" +#include "../phrase.h" +#include "../phrase_location.h" + +class MockFastIntersector : public FastIntersector { + public: + MOCK_METHOD3(Intersect, PhraseLocation(PhraseLocation&, PhraseLocation&, + const Phrase&)); +}; diff --git a/extractor/mocks/mock_feature.h b/extractor/mocks/mock_feature.h new file mode 100644 index 00000000..d2137629 --- /dev/null +++ b/extractor/mocks/mock_feature.h @@ -0,0 +1,9 @@ +#include <gmock/gmock.h> + +#include "../features/feature.h" + +class MockFeature : public Feature { + public: + MOCK_CONST_METHOD1(Score, double(const FeatureContext& context)); + MOCK_CONST_METHOD0(GetName, string()); +}; diff --git a/extractor/mocks/mock_intersector.h b/extractor/mocks/mock_intersector.h new file mode 100644 index 00000000..372fa7ea --- /dev/null +++ b/extractor/mocks/mock_intersector.h @@ -0,0 +1,11 @@ +#include <gmock/gmock.h> + +#include "../intersector.h" +#include "../phrase.h" +#include "../phrase_location.h" + +class MockIntersector : public Intersector { + public: + MOCK_METHOD5(Intersect, PhraseLocation(const Phrase&, PhraseLocation&, + const Phrase&, PhraseLocation&, const Phrase&)); +}; diff --git a/extractor/mocks/mock_linear_merger.h b/extractor/mocks/mock_linear_merger.h new file mode 100644 index 00000000..522c1f31 --- /dev/null +++ b/extractor/mocks/mock_linear_merger.h @@ -0,0 +1,15 @@ +#include <gmock/gmock.h> + +#include <vector> + +#include "../linear_merger.h" +#include "../phrase.h" + +using namespace std; + +class MockLinearMerger: public LinearMerger { + public: + MOCK_METHOD9(Merge, void(vector<int>&, const Phrase&, const Phrase&, + vector<int>::iterator, vector<int>::iterator, vector<int>::iterator, + vector<int>::iterator, int, int)); +}; diff --git a/extractor/mocks/mock_matchings_finder.h b/extractor/mocks/mock_matchings_finder.h new file mode 100644 index 00000000..3e80d266 --- /dev/null +++ b/extractor/mocks/mock_matchings_finder.h @@ -0,0 +1,9 @@ +#include <gmock/gmock.h> + +#include "../matchings_finder.h" +#include "../phrase_location.h" + +class MockMatchingsFinder : public MatchingsFinder { + public: + MOCK_METHOD3(Find, PhraseLocation(PhraseLocation&, const string&, int)); +}; diff --git a/extractor/mocks/mock_precomputation.h b/extractor/mocks/mock_precomputation.h new file mode 100644 index 00000000..987bdb2f --- /dev/null +++ b/extractor/mocks/mock_precomputation.h @@ -0,0 +1,9 @@ +#include <gmock/gmock.h> + +#include "../precomputation.h" + +class MockPrecomputation : public Precomputation { + public: + MOCK_CONST_METHOD0(GetInvertedIndex, const Index&()); + MOCK_CONST_METHOD0(GetCollocations, const Index&()); +}; diff --git a/extractor/mocks/mock_rule_extractor.h b/extractor/mocks/mock_rule_extractor.h new file mode 100644 index 00000000..f18e009a --- /dev/null +++ b/extractor/mocks/mock_rule_extractor.h @@ -0,0 +1,12 @@ +#include <gmock/gmock.h> + +#include "../phrase.h" +#include "../phrase_builder.h" +#include "../rule.h" +#include "../rule_extractor.h" + +class MockRuleExtractor : public RuleExtractor { + public: + MOCK_CONST_METHOD2(ExtractRules, vector<Rule>(const Phrase&, + const PhraseLocation&)); +}; diff --git a/extractor/mocks/mock_rule_extractor_helper.h b/extractor/mocks/mock_rule_extractor_helper.h new file mode 100644 index 00000000..63ff1048 --- /dev/null +++ b/extractor/mocks/mock_rule_extractor_helper.h @@ -0,0 +1,78 @@ +#include <gmock/gmock.h> + +#include <vector> + +#include "../rule_extractor_helper.h" + +using namespace std; + +typedef unordered_map<int, int> Indexes; + +class MockRuleExtractorHelper : public RuleExtractorHelper { + public: + MOCK_CONST_METHOD5(GetLinksSpans, void(vector<int>&, vector<int>&, + vector<int>&, vector<int>&, int)); + MOCK_CONST_METHOD3(CheckAlignedTerminals, bool(const vector<int>&, + const vector<int>&, const vector<int>&)); + MOCK_CONST_METHOD3(CheckTightPhrases, bool(const vector<int>&, + const vector<int>&, const vector<int>&)); + MOCK_CONST_METHOD1(GetGapOrder, vector<int>(const vector<pair<int, int> >&)); + MOCK_CONST_METHOD3(GetSourceIndexes, Indexes(const vector<int>&, + const vector<int>&, int)); + + // We need to implement these methods, because Google Mock doesn't support + // methods with more than 10 arguments. + bool FindFixPoint( + int, int, const vector<int>&, const vector<int>&, int& target_phrase_low, + int& target_phrase_high, const vector<int>&, const vector<int>&, + int& source_back_low, int& source_back_high, int, int, int, int, bool, + bool, bool) const { + target_phrase_low = this->target_phrase_low; + target_phrase_high = this->target_phrase_high; + source_back_low = this->source_back_low; + source_back_high = this->source_back_high; + return find_fix_point; + } + + bool GetGaps(vector<pair<int, int> >& source_gaps, + vector<pair<int, int> >& target_gaps, + const vector<int>&, const vector<int>&, const vector<int>&, + const vector<int>&, const vector<int>&, const vector<int>&, + int, int, int, int, int& num_symbols, + bool& met_constraints) const { + source_gaps = this->source_gaps; + target_gaps = this->target_gaps; + num_symbols = this->num_symbols; + met_constraints = this->met_constraints; + return get_gaps; + } + + void SetUp( + int target_phrase_low, int target_phrase_high, int source_back_low, + int source_back_high, bool find_fix_point, + vector<pair<int, int> > source_gaps, vector<pair<int, int> > target_gaps, + int num_symbols, bool met_constraints, bool get_gaps) { + this->target_phrase_low = target_phrase_low; + this->target_phrase_high = target_phrase_high; + this->source_back_low = source_back_low; + this->source_back_high = source_back_high; + this->find_fix_point = find_fix_point; + this->source_gaps = source_gaps; + this->target_gaps = target_gaps; + this->num_symbols = num_symbols; + this->met_constraints = met_constraints; + this->get_gaps = get_gaps; + } + + private: + int target_phrase_low; + int target_phrase_high; + int source_back_low; + int source_back_high; + bool find_fix_point; + vector<pair<int, int> > source_gaps; + vector<pair<int, int> > target_gaps; + int num_symbols; + bool met_constraints; + bool get_gaps; +}; diff --git a/extractor/mocks/mock_rule_factory.h b/extractor/mocks/mock_rule_factory.h new file mode 100644 index 00000000..2a96be93 --- /dev/null +++ b/extractor/mocks/mock_rule_factory.h @@ -0,0 +1,9 @@ +#include <gmock/gmock.h> + +#include "../grammar.h" +#include "../rule_factory.h" + +class MockHieroCachingRuleFactory : public HieroCachingRuleFactory { + public: + MOCK_METHOD1(GetGrammar, Grammar(const vector<int>& word_ids)); +}; diff --git a/extractor/mocks/mock_sampler.h b/extractor/mocks/mock_sampler.h new file mode 100644 index 00000000..b2306109 --- /dev/null +++ b/extractor/mocks/mock_sampler.h @@ -0,0 +1,9 @@ +#include <gmock/gmock.h> + +#include "../phrase_location.h" +#include "../sampler.h" + +class MockSampler : public Sampler { + public: + MOCK_CONST_METHOD1(Sample, PhraseLocation(const PhraseLocation& location)); +}; diff --git a/extractor/mocks/mock_scorer.h b/extractor/mocks/mock_scorer.h new file mode 100644 index 00000000..48115ef4 --- /dev/null +++ b/extractor/mocks/mock_scorer.h @@ -0,0 +1,10 @@ +#include <gmock/gmock.h> + +#include "../scorer.h" +#include "../features/feature.h" + +class MockScorer : public Scorer { + public: + MOCK_CONST_METHOD1(Score, vector<double>(const FeatureContext& context)); + MOCK_CONST_METHOD0(GetFeatureNames, vector<string>()); +}; diff --git a/extractor/mocks/mock_suffix_array.h b/extractor/mocks/mock_suffix_array.h new file mode 100644 index 00000000..11a3a443 --- /dev/null +++ b/extractor/mocks/mock_suffix_array.h @@ -0,0 +1,19 @@ +#include <gmock/gmock.h> + +#include <memory> +#include <string> + +#include "../data_array.h" +#include "../phrase_location.h" +#include "../suffix_array.h" + +using namespace std; + +class MockSuffixArray : public SuffixArray { + public: + MOCK_CONST_METHOD0(GetSize, int()); + MOCK_CONST_METHOD0(GetData, shared_ptr<DataArray>()); + MOCK_CONST_METHOD0(BuildLCPArray, vector<int>()); + MOCK_CONST_METHOD1(GetSuffix, int(int)); + MOCK_CONST_METHOD4(Lookup, PhraseLocation(int, int, const string& word, int)); +}; diff --git a/extractor/mocks/mock_target_phrase_extractor.h b/extractor/mocks/mock_target_phrase_extractor.h new file mode 100644 index 00000000..6dc6bba6 --- /dev/null +++ b/extractor/mocks/mock_target_phrase_extractor.h @@ -0,0 +1,12 @@ +#include <gmock/gmock.h> + +#include "../target_phrase_extractor.h" + +typedef pair<Phrase, PhraseAlignment> PhraseExtract; + +class MockTargetPhraseExtractor : public TargetPhraseExtractor { + public: + MOCK_CONST_METHOD6(ExtractPhrases, vector<PhraseExtract>( + const vector<pair<int, int> > &, const vector<int>&, int, int, + const unordered_map<int, int>&, int)); +}; diff --git a/extractor/mocks/mock_translation_table.h b/extractor/mocks/mock_translation_table.h new file mode 100644 index 00000000..a35c9327 --- /dev/null +++ b/extractor/mocks/mock_translation_table.h @@ -0,0 +1,9 @@ +#include <gmock/gmock.h> + +#include "../translation_table.h" + +class MockTranslationTable : public TranslationTable { + public: + MOCK_METHOD2(GetSourceGivenTargetScore, double(const string&, const string&)); + MOCK_METHOD2(GetTargetGivenSourceScore, double(const string&, const string&)); +}; diff --git a/extractor/mocks/mock_vocabulary.h b/extractor/mocks/mock_vocabulary.h new file mode 100644 index 00000000..e5c191f5 --- /dev/null +++ b/extractor/mocks/mock_vocabulary.h @@ -0,0 +1,9 @@ +#include <gmock/gmock.h> + +#include "../vocabulary.h" + +class MockVocabulary : public Vocabulary { + public: + MOCK_METHOD1(GetTerminalValue, string(int word_id)); + MOCK_METHOD1(GetTerminalIndex, int(const string& word)); +}; diff --git a/extractor/phrase.cc b/extractor/phrase.cc new file mode 100644 index 00000000..6dc242db --- /dev/null +++ b/extractor/phrase.cc @@ -0,0 +1,54 @@ +#include "phrase.h" + +int Phrase::Arity() const { + return var_pos.size(); +} + +int Phrase::GetChunkLen(int index) const { + if (var_pos.size() == 0) { + return symbols.size(); + } else if (index == 0) { + return var_pos[0]; + } else if (index == var_pos.size()) { + return symbols.size() - var_pos.back() - 1; + } else { + return var_pos[index] - var_pos[index - 1] - 1; + } +} + +vector<int> Phrase::Get() const { + return symbols; +} + +int Phrase::GetSymbol(int position) const { + return symbols[position]; +} + +int Phrase::GetNumSymbols() const { + return symbols.size(); +} + +vector<string> Phrase::GetWords() const { + return words; +} + +int Phrase::operator<(const Phrase& other) const { + return symbols < other.symbols; +} + +ostream& operator<<(ostream& os, const Phrase& phrase) { + int current_word = 0; + for (size_t i = 0; i < phrase.symbols.size(); ++i) { + if (phrase.symbols[i] < 0) { + os << "[X," << -phrase.symbols[i] << "]"; + } else { + os << phrase.words[current_word]; + ++current_word; + } + + if (i + 1 < phrase.symbols.size()) { + os << " "; + } + } + return os; +} diff --git a/extractor/phrase.h b/extractor/phrase.h new file mode 100644 index 00000000..f40a8169 --- /dev/null +++ b/extractor/phrase.h @@ -0,0 +1,41 @@ +#ifndef _PHRASE_H_ +#define _PHRASE_H_ + +#include <iostream> +#include <string> +#include <vector> + +#include "phrase_builder.h" + +using namespace std; + +class Phrase { + public: + friend Phrase PhraseBuilder::Build(const vector<int>& phrase); + + int Arity() const; + + int GetChunkLen(int index) const; + + vector<int> Get() const; + + int GetSymbol(int position) const; + + //TODO(pauldb): Unit test this method. + int GetNumSymbols() const; + + //TODO(pauldb): Add unit tests. + vector<string> GetWords() const; + + //TODO(pauldb): Add unit tests. + int operator<(const Phrase& other) const; + + friend ostream& operator<<(ostream& os, const Phrase& phrase); + + private: + vector<int> symbols; + vector<int> var_pos; + vector<string> words; +}; + +#endif diff --git a/extractor/phrase_builder.cc b/extractor/phrase_builder.cc new file mode 100644 index 00000000..4325390c --- /dev/null +++ b/extractor/phrase_builder.cc @@ -0,0 +1,44 @@ +#include "phrase_builder.h" + +#include "phrase.h" +#include "vocabulary.h" + +PhraseBuilder::PhraseBuilder(shared_ptr<Vocabulary> vocabulary) : + vocabulary(vocabulary) {} + +Phrase PhraseBuilder::Build(const vector<int>& symbols) { + Phrase phrase; + phrase.symbols = symbols; + for (size_t i = 0; i < symbols.size(); ++i) { + if (vocabulary->IsTerminal(symbols[i])) { + phrase.words.push_back(vocabulary->GetTerminalValue(symbols[i])); + } else { + phrase.var_pos.push_back(i); + } + } + return phrase; +} + +Phrase PhraseBuilder::Extend(const Phrase& phrase, bool start_x, bool end_x) { + vector<int> symbols = phrase.Get(); + int num_nonterminals = 0; + if (start_x) { + num_nonterminals = 1; + symbols.insert(symbols.begin(), + vocabulary->GetNonterminalIndex(num_nonterminals)); + } + + for (size_t i = start_x; i < symbols.size(); ++i) { + if (!vocabulary->IsTerminal(symbols[i])) { + ++num_nonterminals; + symbols[i] = vocabulary->GetNonterminalIndex(num_nonterminals); + } + } + + if (end_x) { + ++num_nonterminals; + symbols.push_back(vocabulary->GetNonterminalIndex(num_nonterminals)); + } + + return Build(symbols); +} diff --git a/extractor/phrase_builder.h b/extractor/phrase_builder.h new file mode 100644 index 00000000..a49af457 --- /dev/null +++ b/extractor/phrase_builder.h @@ -0,0 +1,24 @@ +#ifndef _PHRASE_BUILDER_H_ +#define _PHRASE_BUILDER_H_ + +#include <memory> +#include <vector> + +using namespace std; + +class Phrase; +class Vocabulary; + +class PhraseBuilder { + public: + PhraseBuilder(shared_ptr<Vocabulary> vocabulary); + + Phrase Build(const vector<int>& symbols); + + Phrase Extend(const Phrase& phrase, bool start_x, bool end_x); + + private: + shared_ptr<Vocabulary> vocabulary; +}; + +#endif diff --git a/extractor/phrase_location.cc b/extractor/phrase_location.cc new file mode 100644 index 00000000..b0bfed80 --- /dev/null +++ b/extractor/phrase_location.cc @@ -0,0 +1,39 @@ +#include "phrase_location.h" + +PhraseLocation::PhraseLocation(int sa_low, int sa_high) : + sa_low(sa_low), sa_high(sa_high), num_subpatterns(0) {} + +PhraseLocation::PhraseLocation(const vector<int>& matchings, + int num_subpatterns) : + sa_low(0), sa_high(0), + matchings(make_shared<vector<int> >(matchings)), + num_subpatterns(num_subpatterns) {} + +bool PhraseLocation::IsEmpty() const { + return GetSize() == 0; +} + +int PhraseLocation::GetSize() const { + if (num_subpatterns > 0) { + return matchings->size(); + } else { + return sa_high - sa_low; + } +} + +bool operator==(const PhraseLocation& a, const PhraseLocation& b) { + if (a.sa_low != b.sa_low || a.sa_high != b.sa_high || + a.num_subpatterns != b.num_subpatterns) { + return false; + } + + if (a.matchings == NULL && b.matchings == NULL) { + return true; + } + + if (a.matchings == NULL || b.matchings == NULL) { + return false; + } + + return *a.matchings == *b.matchings; +} diff --git a/extractor/phrase_location.h b/extractor/phrase_location.h new file mode 100644 index 00000000..a0eb36c8 --- /dev/null +++ b/extractor/phrase_location.h @@ -0,0 +1,25 @@ +#ifndef _PHRASE_LOCATION_H_ +#define _PHRASE_LOCATION_H_ + +#include <memory> +#include <vector> + +using namespace std; + +struct PhraseLocation { + PhraseLocation(int sa_low = -1, int sa_high = -1); + + PhraseLocation(const vector<int>& matchings, int num_subpatterns); + + bool IsEmpty() const; + + int GetSize() const; + + friend bool operator==(const PhraseLocation& a, const PhraseLocation& b); + + int sa_low, sa_high; + shared_ptr<vector<int> > matchings; + int num_subpatterns; +}; + +#endif diff --git a/extractor/phrase_test.cc b/extractor/phrase_test.cc new file mode 100644 index 00000000..2b553b6f --- /dev/null +++ b/extractor/phrase_test.cc @@ -0,0 +1,61 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <vector> + +#include "mocks/mock_vocabulary.h" +#include "phrase.h" +#include "phrase_builder.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class PhraseTest : public Test { + protected: + virtual void SetUp() { + shared_ptr<MockVocabulary> vocabulary = make_shared<MockVocabulary>(); + EXPECT_CALL(*vocabulary, GetTerminalValue(_)) + .WillRepeatedly(Return("word")); + shared_ptr<PhraseBuilder> phrase_builder = + make_shared<PhraseBuilder>(vocabulary); + + symbols1 = vector<int>{1, 2, 3}; + phrase1 = phrase_builder->Build(symbols1); + symbols2 = vector<int>{1, 2, -1, 3, -2, 4}; + phrase2 = phrase_builder->Build(symbols2); + } + + vector<int> symbols1, symbols2; + Phrase phrase1, phrase2; +}; + +TEST_F(PhraseTest, TestArity) { + EXPECT_EQ(0, phrase1.Arity()); + EXPECT_EQ(2, phrase2.Arity()); +} + +TEST_F(PhraseTest, GetChunkLen) { + EXPECT_EQ(3, phrase1.GetChunkLen(0)); + + EXPECT_EQ(2, phrase2.GetChunkLen(0)); + EXPECT_EQ(1, phrase2.GetChunkLen(1)); + EXPECT_EQ(1, phrase2.GetChunkLen(2)); +} + +TEST_F(PhraseTest, TestGet) { + EXPECT_EQ(symbols1, phrase1.Get()); + EXPECT_EQ(symbols2, phrase2.Get()); +} + +TEST_F(PhraseTest, TestGetSymbol) { + for (size_t i = 0; i < symbols1.size(); ++i) { + EXPECT_EQ(symbols1[i], phrase1.GetSymbol(i)); + } + for (size_t i = 0; i < symbols2.size(); ++i) { + EXPECT_EQ(symbols2[i], phrase2.GetSymbol(i)); + } +} + +} // namespace diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc new file mode 100644 index 00000000..8a76beb1 --- /dev/null +++ b/extractor/precomputation.cc @@ -0,0 +1,192 @@ +#include "precomputation.h" + +#include <iostream> +#include <queue> + +#include "data_array.h" +#include "suffix_array.h" + +using namespace std; + +int Precomputation::NON_TERMINAL = -1; + +Precomputation::Precomputation( + shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns, + int num_super_frequent_patterns, int max_rule_span, + int max_rule_symbols, int min_gap_size, + int max_frequent_phrase_len, int min_frequency) { + vector<int> data = suffix_array->GetData()->GetData(); + vector<vector<int> > frequent_patterns = FindMostFrequentPatterns( + suffix_array, data, num_frequent_patterns, max_frequent_phrase_len, + min_frequency); + + unordered_set<vector<int>, VectorHash> frequent_patterns_set; + unordered_set<vector<int>, VectorHash> super_frequent_patterns_set; + for (size_t i = 0; i < frequent_patterns.size(); ++i) { + frequent_patterns_set.insert(frequent_patterns[i]); + if (i < num_super_frequent_patterns) { + super_frequent_patterns_set.insert(frequent_patterns[i]); + } + } + + vector<tuple<int, int, int> > matchings; + for (size_t i = 0; i < data.size(); ++i) { + if (data[i] == DataArray::END_OF_LINE) { + AddCollocations(matchings, data, max_rule_span, min_gap_size, + max_rule_symbols); + matchings.clear(); + continue; + } + vector<int> pattern; + for (int j = 1; j <= max_frequent_phrase_len && i + j <= data.size(); ++j) { + pattern.push_back(data[i + j - 1]); + if (frequent_patterns_set.count(pattern)) { + inverted_index[pattern].push_back(i); + int is_super_frequent = super_frequent_patterns_set.count(pattern); + matchings.push_back(make_tuple(i, j, is_super_frequent)); + } else { + // If the current pattern is not frequent, any longer pattern having the + // current pattern as prefix will not be frequent. + break; + } + } + } +} + +Precomputation::Precomputation() {} + +Precomputation::~Precomputation() {} + +vector<vector<int> > Precomputation::FindMostFrequentPatterns( + shared_ptr<SuffixArray> suffix_array, const vector<int>& data, + int num_frequent_patterns, int max_frequent_phrase_len, int min_frequency) { + vector<int> lcp = suffix_array->BuildLCPArray(); + vector<int> run_start(max_frequent_phrase_len); + + priority_queue<pair<int, pair<int, int> > > heap; + for (size_t i = 1; i < lcp.size(); ++i) { + for (int len = lcp[i]; len < max_frequent_phrase_len; ++len) { + int frequency = i - run_start[len]; + // TODO(pauldb): Only add patterns that don't span across multiple + // sentences. + if (frequency >= min_frequency) { + heap.push(make_pair(frequency, + make_pair(suffix_array->GetSuffix(run_start[len]), len + 1))); + } + run_start[len] = i; + } + } + + vector<vector<int> > frequent_patterns; + while (frequent_patterns.size() < num_frequent_patterns && !heap.empty()) { + int start = heap.top().second.first; + int len = heap.top().second.second; + heap.pop(); + + vector<int> pattern(data.begin() + start, data.begin() + start + len); + if (find(pattern.begin(), pattern.end(), DataArray::END_OF_LINE) == + pattern.end()) { + frequent_patterns.push_back(pattern); + } + } + return frequent_patterns; +} + +void Precomputation::AddCollocations( + const vector<tuple<int, int, int> >& matchings, const vector<int>& data, + int max_rule_span, int min_gap_size, int max_rule_symbols) { + for (size_t i = 0; i < matchings.size(); ++i) { + int start1, size1, is_super1; + tie(start1, size1, is_super1) = matchings[i]; + + for (size_t j = i + 1; j < matchings.size(); ++j) { + int start2, size2, is_super2; + tie(start2, size2, is_super2) = matchings[j]; + if (start2 - start1 >= max_rule_span) { + break; + } + + if (start2 - start1 - size1 >= min_gap_size + && start2 + size2 - start1 <= max_rule_span + && size1 + size2 + 1 <= max_rule_symbols) { + vector<int> pattern(data.begin() + start1, + data.begin() + start1 + size1); + pattern.push_back(Precomputation::NON_TERMINAL); + pattern.insert(pattern.end(), data.begin() + start2, + data.begin() + start2 + size2); + AddStartPositions(collocations[pattern], start1, start2); + + if (is_super2) { + pattern.push_back(Precomputation::NON_TERMINAL); + for (size_t k = j + 1; k < matchings.size(); ++k) { + int start3, size3, is_super3; + tie(start3, size3, is_super3) = matchings[k]; + if (start3 - start1 >= max_rule_span) { + break; + } + + if (start3 - start2 - size2 >= min_gap_size + && start3 + size3 - start1 <= max_rule_span + && size1 + size2 + size3 + 2 <= max_rule_symbols + && (is_super1 || is_super3)) { + pattern.insert(pattern.end(), data.begin() + start3, + data.begin() + start3 + size3); + AddStartPositions(collocations[pattern], start1, start2, start3); + pattern.erase(pattern.end() - size3); + } + } + } + } + } + } +} + +void Precomputation::AddStartPositions( + vector<int>& positions, int pos1, int pos2) { + positions.push_back(pos1); + positions.push_back(pos2); +} + +void Precomputation::AddStartPositions( + vector<int>& positions, int pos1, int pos2, int pos3) { + positions.push_back(pos1); + positions.push_back(pos2); + positions.push_back(pos3); +} + +void Precomputation::WriteBinary(const fs::path& filepath) const { + FILE* file = fopen(filepath.string().c_str(), "w"); + + // TODO(pauldb): Refactor this code. + int size = inverted_index.size(); + fwrite(&size, sizeof(int), 1, file); + for (auto entry: inverted_index) { + size = entry.first.size(); + fwrite(&size, sizeof(int), 1, file); + fwrite(entry.first.data(), sizeof(int), size, file); + + size = entry.second.size(); + fwrite(&size, sizeof(int), 1, file); + fwrite(entry.second.data(), sizeof(int), size, file); + } + + size = collocations.size(); + fwrite(&size, sizeof(int), 1, file); + for (auto entry: collocations) { + size = entry.first.size(); + fwrite(&size, sizeof(int), 1, file); + fwrite(entry.first.data(), sizeof(int), size, file); + + size = entry.second.size(); + fwrite(&size, sizeof(int), 1, file); + fwrite(entry.second.data(), sizeof(int), size, file); + } +} + +const Index& Precomputation::GetInvertedIndex() const { + return inverted_index; +} + +const Index& Precomputation::GetCollocations() const { + return collocations; +} diff --git a/extractor/precomputation.h b/extractor/precomputation.h new file mode 100644 index 00000000..28426bfa --- /dev/null +++ b/extractor/precomputation.h @@ -0,0 +1,56 @@ +#ifndef _PRECOMPUTATION_H_ +#define _PRECOMPUTATION_H_ + +#include <memory> +#include <unordered_map> +#include <unordered_set> +#include <tuple> +#include <vector> + +#include <boost/filesystem.hpp> +#include <boost/functional/hash.hpp> + +namespace fs = boost::filesystem; +using namespace std; + +class SuffixArray; + +typedef boost::hash<vector<int> > VectorHash; +typedef unordered_map<vector<int>, vector<int>, VectorHash> Index; + +class Precomputation { + public: + Precomputation( + shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns, + int num_super_frequent_patterns, int max_rule_span, + int max_rule_symbols, int min_gap_size, + int max_frequent_phrase_len, int min_frequency); + + virtual ~Precomputation(); + + void WriteBinary(const fs::path& filepath) const; + + virtual const Index& GetInvertedIndex() const; + virtual const Index& GetCollocations() const; + + static int NON_TERMINAL; + + protected: + Precomputation(); + + private: + vector<vector<int> > FindMostFrequentPatterns( + shared_ptr<SuffixArray> suffix_array, const vector<int>& data, + int num_frequent_patterns, int max_frequent_phrase_len, + int min_frequency); + void AddCollocations( + const vector<std::tuple<int, int, int> >& matchings, const vector<int>& data, + int max_rule_span, int min_gap_size, int max_rule_symbols); + void AddStartPositions(vector<int>& positions, int pos1, int pos2); + void AddStartPositions(vector<int>& positions, int pos1, int pos2, int pos3); + + Index inverted_index; + Index collocations; +}; + +#endif diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc new file mode 100644 index 00000000..9edb29db --- /dev/null +++ b/extractor/precomputation_test.cc @@ -0,0 +1,138 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <vector> + +#include "mocks/mock_data_array.h" +#include "mocks/mock_suffix_array.h" +#include "precomputation.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class PrecomputationTest : public Test { + protected: + virtual void SetUp() { + data = {4, 2, 3, 5, 7, 2, 3, 5, 2, 3, 4, 2, 1}; + data_array = make_shared<MockDataArray>(); + EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data)); + + vector<int> suffixes{12, 8, 5, 1, 9, 6, 2, 0, 10, 7, 3, 4, 13}; + vector<int> lcp{-1, 0, 2, 3, 1, 0, 1, 2, 0, 2, 0, 1, 0, 0}; + suffix_array = make_shared<MockSuffixArray>(); + EXPECT_CALL(*suffix_array, GetData()).WillRepeatedly(Return(data_array)); + for (size_t i = 0; i < suffixes.size(); ++i) { + EXPECT_CALL(*suffix_array, + GetSuffix(i)).WillRepeatedly(Return(suffixes[i])); + } + EXPECT_CALL(*suffix_array, BuildLCPArray()).WillRepeatedly(Return(lcp)); + } + + vector<int> data; + shared_ptr<MockDataArray> data_array; + shared_ptr<MockSuffixArray> suffix_array; +}; + +TEST_F(PrecomputationTest, TestInvertedIndex) { + Precomputation precomputation(suffix_array, 100, 3, 10, 5, 1, 4, 2); + Index inverted_index = precomputation.GetInvertedIndex(); + + EXPECT_EQ(8, inverted_index.size()); + vector<int> key = {2}; + vector<int> expected_value = {1, 5, 8, 11}; + EXPECT_EQ(expected_value, inverted_index[key]); + key = {3}; + expected_value = {2, 6, 9}; + EXPECT_EQ(expected_value, inverted_index[key]); + key = {4}; + expected_value = {0, 10}; + EXPECT_EQ(expected_value, inverted_index[key]); + key = {5}; + expected_value = {3, 7}; + EXPECT_EQ(expected_value, inverted_index[key]); + key = {4, 2}; + expected_value = {0, 10}; + EXPECT_EQ(expected_value, inverted_index[key]); + key = {2, 3}; + expected_value = {1, 5, 8}; + EXPECT_EQ(expected_value, inverted_index[key]); + key = {3, 5}; + expected_value = {2, 6}; + EXPECT_EQ(expected_value, inverted_index[key]); + key = {2, 3, 5}; + expected_value = {1, 5}; + EXPECT_EQ(expected_value, inverted_index[key]); + + key = {2, 4}; + EXPECT_EQ(0, inverted_index.count(key)); +} + +TEST_F(PrecomputationTest, TestCollocations) { + Precomputation precomputation(suffix_array, 3, 3, 10, 5, 1, 4, 2); + Index collocations = precomputation.GetCollocations(); + + EXPECT_EQ(-1, precomputation.NON_TERMINAL); + vector<int> key = {2, 3, -1, 2}; + vector<int> expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11}; + EXPECT_EQ(expected_value, collocations[key]); + key = {2, 3, -1, 2, 3}; + expected_value = {1, 5, 1, 8, 5, 8}; + EXPECT_EQ(expected_value, collocations[key]); + key = {2, 3, -1, 3}; + expected_value = {1, 6, 1, 9, 5, 9}; + EXPECT_EQ(expected_value, collocations[key]); + key = {3, -1, 2}; + expected_value = {2, 5, 2, 8, 2, 11, 6, 8, 6, 11, 9, 11}; + EXPECT_EQ(expected_value, collocations[key]); + key = {3, -1, 3}; + expected_value = {2, 6, 2, 9, 6, 9}; + EXPECT_EQ(expected_value, collocations[key]); + key = {3, -1, 2, 3}; + expected_value = {2, 5, 2, 8, 6, 8}; + EXPECT_EQ(expected_value, collocations[key]); + key = {2, -1, 2}; + expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11}; + EXPECT_EQ(expected_value, collocations[key]); + key = {2, -1, 2, 3}; + expected_value = {1, 5, 1, 8, 5, 8}; + EXPECT_EQ(expected_value, collocations[key]); + key = {2, -1, 3}; + expected_value = {1, 6, 1, 9, 5, 9}; + EXPECT_EQ(expected_value, collocations[key]); + + key = {2, -1, 2, -1, 2}; + expected_value = {1, 5, 8, 5, 8, 11}; + EXPECT_EQ(expected_value, collocations[key]); + key = {2, -1, 2, -1, 3}; + expected_value = {1, 5, 9}; + EXPECT_EQ(expected_value, collocations[key]); + key = {2, -1, 3, -1, 2}; + expected_value = {1, 6, 8, 5, 9, 11}; + EXPECT_EQ(expected_value, collocations[key]); + key = {2, -1, 3, -1, 3}; + expected_value = {1, 6, 9}; + EXPECT_EQ(expected_value, collocations[key]); + key = {3, -1, 2, -1, 2}; + expected_value = {2, 5, 8, 2, 5, 11, 2, 8, 11, 6, 8, 11}; + EXPECT_EQ(expected_value, collocations[key]); + key = {3, -1, 2, -1, 3}; + expected_value = {2, 5, 9}; + EXPECT_EQ(expected_value, collocations[key]); + key = {3, -1, 3, -1, 2}; + expected_value = {2, 6, 8, 2, 6, 11, 2, 9, 11, 6, 9, 11}; + EXPECT_EQ(expected_value, collocations[key]); + key = {3, -1, 3, -1, 3}; + expected_value = {2, 6, 9}; + EXPECT_EQ(expected_value, collocations[key]); + + // Exceeds max_rule_symbols. + key = {2, -1, 2, -1, 2, 3}; + EXPECT_EQ(0, collocations.count(key)); + // Contains non frequent pattern. + key = {2, -1, 5}; + EXPECT_EQ(0, collocations.count(key)); +} + +} // namespace diff --git a/extractor/rule.cc b/extractor/rule.cc new file mode 100644 index 00000000..9c7ac9b5 --- /dev/null +++ b/extractor/rule.cc @@ -0,0 +1,10 @@ +#include "rule.h" + +Rule::Rule(const Phrase& source_phrase, + const Phrase& target_phrase, + const vector<double>& scores, + const vector<pair<int, int> >& alignment) : + source_phrase(source_phrase), + target_phrase(target_phrase), + scores(scores), + alignment(alignment) {} diff --git a/extractor/rule.h b/extractor/rule.h new file mode 100644 index 00000000..64ff8794 --- /dev/null +++ b/extractor/rule.h @@ -0,0 +1,20 @@ +#ifndef _RULE_H_ +#define _RULE_H_ + +#include <vector> + +#include "phrase.h" + +using namespace std; + +struct Rule { + Rule(const Phrase& source_phrase, const Phrase& target_phrase, + const vector<double>& scores, const vector<pair<int, int> >& alignment); + + Phrase source_phrase; + Phrase target_phrase; + vector<double> scores; + vector<pair<int, int> > alignment; +}; + +#endif diff --git a/extractor/rule_extractor.cc b/extractor/rule_extractor.cc new file mode 100644 index 00000000..92343241 --- /dev/null +++ b/extractor/rule_extractor.cc @@ -0,0 +1,315 @@ +#include "rule_extractor.h" + +#include <map> + +#include "alignment.h" +#include "data_array.h" +#include "features/feature.h" +#include "phrase_builder.h" +#include "phrase_location.h" +#include "rule.h" +#include "rule_extractor_helper.h" +#include "scorer.h" +#include "target_phrase_extractor.h" + +using namespace std; + +RuleExtractor::RuleExtractor( + shared_ptr<DataArray> source_data_array, + shared_ptr<DataArray> target_data_array, + shared_ptr<Alignment> alignment, + shared_ptr<PhraseBuilder> phrase_builder, + shared_ptr<Scorer> scorer, + shared_ptr<Vocabulary> vocabulary, + int max_rule_span, + int min_gap_size, + int max_nonterminals, + int max_rule_symbols, + bool require_aligned_terminal, + bool require_aligned_chunks, + bool require_tight_phrases) : + target_data_array(target_data_array), + source_data_array(source_data_array), + phrase_builder(phrase_builder), + scorer(scorer), + max_rule_span(max_rule_span), + min_gap_size(min_gap_size), + max_nonterminals(max_nonterminals), + max_rule_symbols(max_rule_symbols), + require_tight_phrases(require_tight_phrases) { + helper = make_shared<RuleExtractorHelper>( + source_data_array, target_data_array, alignment, max_rule_span, + max_rule_symbols, require_aligned_terminal, require_aligned_chunks, + require_tight_phrases); + target_phrase_extractor = make_shared<TargetPhraseExtractor>( + target_data_array, alignment, phrase_builder, helper, vocabulary, + max_rule_span, require_tight_phrases); +} + +RuleExtractor::RuleExtractor( + shared_ptr<DataArray> source_data_array, + shared_ptr<PhraseBuilder> phrase_builder, + shared_ptr<Scorer> scorer, + shared_ptr<TargetPhraseExtractor> target_phrase_extractor, + shared_ptr<RuleExtractorHelper> helper, + int max_rule_span, + int min_gap_size, + int max_nonterminals, + int max_rule_symbols, + bool require_tight_phrases) : + source_data_array(source_data_array), + phrase_builder(phrase_builder), + scorer(scorer), + target_phrase_extractor(target_phrase_extractor), + helper(helper), + max_rule_span(max_rule_span), + min_gap_size(min_gap_size), + max_nonterminals(max_nonterminals), + max_rule_symbols(max_rule_symbols), + require_tight_phrases(require_tight_phrases) {} + +RuleExtractor::RuleExtractor() {} + +RuleExtractor::~RuleExtractor() {} + +vector<Rule> RuleExtractor::ExtractRules(const Phrase& phrase, + const PhraseLocation& location) const { + int num_subpatterns = location.num_subpatterns; + vector<int> matchings = *location.matchings; + + map<Phrase, double> source_phrase_counter; + map<Phrase, map<Phrase, map<PhraseAlignment, int> > > alignments_counter; + for (auto i = matchings.begin(); i != matchings.end(); i += num_subpatterns) { + vector<int> matching(i, i + num_subpatterns); + vector<Extract> extracts = ExtractAlignments(phrase, matching); + + for (Extract e: extracts) { + source_phrase_counter[e.source_phrase] += e.pairs_count; + alignments_counter[e.source_phrase][e.target_phrase][e.alignment] += 1; + } + } + + int num_samples = matchings.size() / num_subpatterns; + vector<Rule> rules; + for (auto source_phrase_entry: alignments_counter) { + Phrase source_phrase = source_phrase_entry.first; + for (auto target_phrase_entry: source_phrase_entry.second) { + Phrase target_phrase = target_phrase_entry.first; + + int max_locations = 0, num_locations = 0; + PhraseAlignment most_frequent_alignment; + for (auto alignment_entry: target_phrase_entry.second) { + num_locations += alignment_entry.second; + if (alignment_entry.second > max_locations) { + most_frequent_alignment = alignment_entry.first; + max_locations = alignment_entry.second; + } + } + + FeatureContext context(source_phrase, target_phrase, + source_phrase_counter[source_phrase], num_locations, num_samples); + vector<double> scores = scorer->Score(context); + rules.push_back(Rule(source_phrase, target_phrase, scores, + most_frequent_alignment)); + } + } + return rules; +} + +vector<Extract> RuleExtractor::ExtractAlignments( + const Phrase& phrase, const vector<int>& matching) const { + vector<Extract> extracts; + int sentence_id = source_data_array->GetSentenceId(matching[0]); + int source_sent_start = source_data_array->GetSentenceStart(sentence_id); + + vector<int> source_low, source_high, target_low, target_high; + helper->GetLinksSpans(source_low, source_high, target_low, target_high, + sentence_id); + + int num_subpatterns = matching.size(); + vector<int> chunklen(num_subpatterns); + for (size_t i = 0; i < num_subpatterns; ++i) { + chunklen[i] = phrase.GetChunkLen(i); + } + + if (!helper->CheckAlignedTerminals(matching, chunklen, source_low) || + !helper->CheckTightPhrases(matching, chunklen, source_low)) { + return extracts; + } + + int source_back_low = -1, source_back_high = -1; + int source_phrase_low = matching[0] - source_sent_start; + int source_phrase_high = matching.back() + chunklen.back() - + source_sent_start; + int target_phrase_low = -1, target_phrase_high = -1; + if (!helper->FindFixPoint(source_phrase_low, source_phrase_high, source_low, + source_high, target_phrase_low, target_phrase_high, + target_low, target_high, source_back_low, + source_back_high, sentence_id, min_gap_size, 0, + max_nonterminals - matching.size() + 1, true, true, + false)) { + return extracts; + } + + bool met_constraints = true; + int num_symbols = phrase.GetNumSymbols(); + vector<pair<int, int> > source_gaps, target_gaps; + if (!helper->GetGaps(source_gaps, target_gaps, matching, chunklen, source_low, + source_high, target_low, target_high, source_phrase_low, + source_phrase_high, source_back_low, source_back_high, + num_symbols, met_constraints)) { + return extracts; + } + + bool starts_with_x = source_back_low != source_phrase_low; + bool ends_with_x = source_back_high != source_phrase_high; + Phrase source_phrase = phrase_builder->Extend( + phrase, starts_with_x, ends_with_x); + unordered_map<int, int> source_indexes = helper->GetSourceIndexes( + matching, chunklen, starts_with_x); + if (met_constraints) { + AddExtracts(extracts, source_phrase, source_indexes, target_gaps, + target_low, target_phrase_low, target_phrase_high, sentence_id); + } + + if (source_gaps.size() >= max_nonterminals || + source_phrase.GetNumSymbols() >= max_rule_symbols || + source_back_high - source_back_low + min_gap_size > max_rule_span) { + // Cannot add any more nonterminals. + return extracts; + } + + for (int i = 0; i < 2; ++i) { + for (int j = 1 - i; j < 2; ++j) { + AddNonterminalExtremities(extracts, matching, chunklen, source_phrase, + source_back_low, source_back_high, source_low, source_high, + target_low, target_high, target_gaps, sentence_id, starts_with_x, + ends_with_x, i, j); + } + } + + return extracts; +} + +void RuleExtractor::AddExtracts( + vector<Extract>& extracts, const Phrase& source_phrase, + const unordered_map<int, int>& source_indexes, + const vector<pair<int, int> >& target_gaps, const vector<int>& target_low, + int target_phrase_low, int target_phrase_high, int sentence_id) const { + auto target_phrases = target_phrase_extractor->ExtractPhrases( + target_gaps, target_low, target_phrase_low, target_phrase_high, + source_indexes, sentence_id); + + if (target_phrases.size() > 0) { + double pairs_count = 1.0 / target_phrases.size(); + for (auto target_phrase: target_phrases) { + extracts.push_back(Extract(source_phrase, target_phrase.first, + pairs_count, target_phrase.second)); + } + } +} + +void RuleExtractor::AddNonterminalExtremities( + vector<Extract>& extracts, const vector<int>& matching, + const vector<int>& chunklen, const Phrase& source_phrase, + int source_back_low, int source_back_high, const vector<int>& source_low, + const vector<int>& source_high, const vector<int>& target_low, + const vector<int>& target_high, vector<pair<int, int> > target_gaps, + int sentence_id, int starts_with_x, int ends_with_x, int extend_left, + int extend_right) const { + int source_x_low = source_back_low, source_x_high = source_back_high; + + if (require_tight_phrases) { + if (source_low[source_back_low - extend_left] == -1 || + source_low[source_back_high + extend_right - 1] == -1) { + return; + } + } + + if (extend_left) { + if (starts_with_x || source_back_low < min_gap_size) { + return; + } + + source_x_low = source_back_low - min_gap_size; + if (require_tight_phrases) { + while (source_x_low >= 0 && source_low[source_x_low] == -1) { + --source_x_low; + } + } + if (source_x_low < 0) { + return; + } + } + + if (extend_right) { + int source_sent_len = source_data_array->GetSentenceLength(sentence_id); + if (ends_with_x || source_back_high + min_gap_size > source_sent_len) { + return; + } + source_x_high = source_back_high + min_gap_size; + if (require_tight_phrases) { + while (source_x_high <= source_sent_len && + source_low[source_x_high - 1] == -1) { + ++source_x_high; + } + } + + if (source_x_high > source_sent_len) { + return; + } + } + + int new_nonterminals = extend_left + extend_right; + if (source_x_high - source_x_low > max_rule_span || + target_gaps.size() + new_nonterminals > max_nonterminals || + source_phrase.GetNumSymbols() + new_nonterminals > max_rule_symbols) { + return; + } + + int target_x_low = -1, target_x_high = -1; + if (!helper->FindFixPoint(source_x_low, source_x_high, source_low, + source_high, target_x_low, target_x_high, + target_low, target_high, source_x_low, + source_x_high, sentence_id, 1, 1, + new_nonterminals, extend_left, extend_right, + true)) { + return; + } + + if (extend_left) { + int source_gap_low = -1, source_gap_high = -1; + int target_gap_low = -1, target_gap_high = -1; + if ((require_tight_phrases && source_low[source_x_low] == -1) || + !helper->FindFixPoint(source_x_low, source_back_low, source_low, + source_high, target_gap_low, target_gap_high, + target_low, target_high, source_gap_low, + source_gap_high, sentence_id, 0, 0, 0, false, + false, false)) { + return; + } + target_gaps.insert(target_gaps.begin(), + make_pair(target_gap_low, target_gap_high)); + } + + if (extend_right) { + int target_gap_low = -1, target_gap_high = -1; + int source_gap_low = -1, source_gap_high = -1; + if ((require_tight_phrases && source_low[source_x_high - 1] == -1) || + !helper->FindFixPoint(source_back_high, source_x_high, source_low, + source_high, target_gap_low, target_gap_high, + target_low, target_high, source_gap_low, + source_gap_high, sentence_id, 0, 0, 0, false, + false, false)) { + return; + } + target_gaps.push_back(make_pair(target_gap_low, target_gap_high)); + } + + Phrase new_source_phrase = phrase_builder->Extend(source_phrase, extend_left, + extend_right); + unordered_map<int, int> source_indexes = helper->GetSourceIndexes( + matching, chunklen, extend_left || starts_with_x); + AddExtracts(extracts, new_source_phrase, source_indexes, target_gaps, + target_low, target_x_low, target_x_high, sentence_id); +} diff --git a/extractor/rule_extractor.h b/extractor/rule_extractor.h new file mode 100644 index 00000000..a087dc6d --- /dev/null +++ b/extractor/rule_extractor.h @@ -0,0 +1,104 @@ +#ifndef _RULE_EXTRACTOR_H_ +#define _RULE_EXTRACTOR_H_ + +#include <memory> +#include <unordered_map> +#include <vector> + +#include "phrase.h" + +using namespace std; + +class Alignment; +class DataArray; +class PhraseBuilder; +class PhraseLocation; +class Rule; +class RuleExtractorHelper; +class Scorer; +class TargetPhraseExtractor; + +typedef vector<pair<int, int> > PhraseAlignment; + +struct Extract { + Extract(const Phrase& source_phrase, const Phrase& target_phrase, + double pairs_count, const PhraseAlignment& alignment) : + source_phrase(source_phrase), target_phrase(target_phrase), + pairs_count(pairs_count), alignment(alignment) {} + + Phrase source_phrase; + Phrase target_phrase; + double pairs_count; + PhraseAlignment alignment; +}; + +class RuleExtractor { + public: + RuleExtractor(shared_ptr<DataArray> source_data_array, + shared_ptr<DataArray> target_data_array, + shared_ptr<Alignment> alingment, + shared_ptr<PhraseBuilder> phrase_builder, + shared_ptr<Scorer> scorer, + shared_ptr<Vocabulary> vocabulary, + int min_gap_size, + int max_rule_span, + int max_nonterminals, + int max_rule_symbols, + bool require_aligned_terminal, + bool require_aligned_chunks, + bool require_tight_phrases); + + // For testing only. + RuleExtractor(shared_ptr<DataArray> source_data_array, + shared_ptr<PhraseBuilder> phrase_builder, + shared_ptr<Scorer> scorer, + shared_ptr<TargetPhraseExtractor> target_phrase_extractor, + shared_ptr<RuleExtractorHelper> helper, + int max_rule_span, + int min_gap_size, + int max_nonterminals, + int max_rule_symbols, + bool require_tight_phrases); + + virtual ~RuleExtractor(); + + virtual vector<Rule> ExtractRules(const Phrase& phrase, + const PhraseLocation& location) const; + + protected: + RuleExtractor(); + + private: + vector<Extract> ExtractAlignments(const Phrase& phrase, + const vector<int>& matching) const; + + void AddExtracts( + vector<Extract>& extracts, const Phrase& source_phrase, + const unordered_map<int, int>& source_indexes, + const vector<pair<int, int> >& target_gaps, const vector<int>& target_low, + int target_phrase_low, int target_phrase_high, int sentence_id) const; + + void AddNonterminalExtremities( + vector<Extract>& extracts, const vector<int>& matching, + const vector<int>& chunklen, const Phrase& source_phrase, + int source_back_low, int source_back_high, const vector<int>& source_low, + const vector<int>& source_high, const vector<int>& target_low, + const vector<int>& target_high, vector<pair<int, int> > target_gaps, + int sentence_id, int starts_with_x, int ends_with_x, int extend_left, + int extend_right) const; + + private: + shared_ptr<DataArray> target_data_array; + shared_ptr<DataArray> source_data_array; + shared_ptr<PhraseBuilder> phrase_builder; + shared_ptr<Scorer> scorer; + shared_ptr<TargetPhraseExtractor> target_phrase_extractor; + shared_ptr<RuleExtractorHelper> helper; + int max_rule_span; + int min_gap_size; + int max_nonterminals; + int max_rule_symbols; + bool require_tight_phrases; +}; + +#endif diff --git a/extractor/rule_extractor_helper.cc b/extractor/rule_extractor_helper.cc new file mode 100644 index 00000000..ed6ae3a1 --- /dev/null +++ b/extractor/rule_extractor_helper.cc @@ -0,0 +1,356 @@ +#include "rule_extractor_helper.h" + +#include "data_array.h" +#include "alignment.h" + +RuleExtractorHelper::RuleExtractorHelper( + shared_ptr<DataArray> source_data_array, + shared_ptr<DataArray> target_data_array, + shared_ptr<Alignment> alignment, + int max_rule_span, + int max_rule_symbols, + bool require_aligned_terminal, + bool require_aligned_chunks, + bool require_tight_phrases) : + source_data_array(source_data_array), + target_data_array(target_data_array), + alignment(alignment), + max_rule_span(max_rule_span), + max_rule_symbols(max_rule_symbols), + require_aligned_terminal(require_aligned_terminal), + require_aligned_chunks(require_aligned_chunks), + require_tight_phrases(require_tight_phrases) {} + +RuleExtractorHelper::RuleExtractorHelper() {} + +RuleExtractorHelper::~RuleExtractorHelper() {} + +void RuleExtractorHelper::GetLinksSpans( + vector<int>& source_low, vector<int>& source_high, + vector<int>& target_low, vector<int>& target_high, int sentence_id) const { + int source_sent_len = source_data_array->GetSentenceLength(sentence_id); + int target_sent_len = target_data_array->GetSentenceLength(sentence_id); + source_low = vector<int>(source_sent_len, -1); + source_high = vector<int>(source_sent_len, -1); + + // TODO(pauldb): Adam Lopez claims this part is really inefficient. See if we + // can speed it up. + target_low = vector<int>(target_sent_len, -1); + target_high = vector<int>(target_sent_len, -1); + vector<pair<int, int> > links = alignment->GetLinks(sentence_id); + for (auto link: links) { + if (source_low[link.first] == -1 || source_low[link.first] > link.second) { + source_low[link.first] = link.second; + } + source_high[link.first] = max(source_high[link.first], link.second + 1); + + if (target_low[link.second] == -1 || target_low[link.second] > link.first) { + target_low[link.second] = link.first; + } + target_high[link.second] = max(target_high[link.second], link.first + 1); + } +} + +bool RuleExtractorHelper::CheckAlignedTerminals( + const vector<int>& matching, + const vector<int>& chunklen, + const vector<int>& source_low) const { + if (!require_aligned_terminal) { + return true; + } + + int sentence_id = source_data_array->GetSentenceId(matching[0]); + int source_sent_start = source_data_array->GetSentenceStart(sentence_id); + + int num_aligned_chunks = 0; + for (size_t i = 0; i < chunklen.size(); ++i) { + for (size_t j = 0; j < chunklen[i]; ++j) { + int sent_index = matching[i] - source_sent_start + j; + if (source_low[sent_index] != -1) { + ++num_aligned_chunks; + break; + } + } + } + + if (num_aligned_chunks == 0) { + return false; + } + + return !require_aligned_chunks || num_aligned_chunks == chunklen.size(); +} + +bool RuleExtractorHelper::CheckTightPhrases( + const vector<int>& matching, + const vector<int>& chunklen, + const vector<int>& source_low) const { + if (!require_tight_phrases) { + return true; + } + + int sentence_id = source_data_array->GetSentenceId(matching[0]); + int source_sent_start = source_data_array->GetSentenceStart(sentence_id); + for (size_t i = 0; i + 1 < chunklen.size(); ++i) { + int gap_start = matching[i] + chunklen[i] - source_sent_start; + int gap_end = matching[i + 1] - 1 - source_sent_start; + if (source_low[gap_start] == -1 || source_low[gap_end] == -1) { + return false; + } + } + + return true; +} + +bool RuleExtractorHelper::FindFixPoint( + int source_phrase_low, int source_phrase_high, + const vector<int>& source_low, const vector<int>& source_high, + int& target_phrase_low, int& target_phrase_high, + const vector<int>& target_low, const vector<int>& target_high, + int& source_back_low, int& source_back_high, int sentence_id, + int min_source_gap_size, int min_target_gap_size, + int max_new_x, bool allow_low_x, bool allow_high_x, + bool allow_arbitrary_expansion) const { + int prev_target_low = target_phrase_low; + int prev_target_high = target_phrase_high; + + FindProjection(source_phrase_low, source_phrase_high, source_low, + source_high, target_phrase_low, target_phrase_high); + + if (target_phrase_low == -1) { + // TODO(pauldb): Low priority corner case inherited from Adam's code: + // If w is unaligned, but we don't require aligned terminals, returning an + // error here prevents the extraction of the allowed rule + // X -> X_1 w X_2 / X_1 X_2 + return false; + } + + int source_sent_len = source_data_array->GetSentenceLength(sentence_id); + int target_sent_len = target_data_array->GetSentenceLength(sentence_id); + if (prev_target_low != -1 && target_phrase_low != prev_target_low) { + if (prev_target_low - target_phrase_low < min_target_gap_size) { + target_phrase_low = prev_target_low - min_target_gap_size; + if (target_phrase_low < 0) { + return false; + } + } + } + + if (prev_target_high != -1 && target_phrase_high != prev_target_high) { + if (target_phrase_high - prev_target_high < min_target_gap_size) { + target_phrase_high = prev_target_high + min_target_gap_size; + if (target_phrase_high > target_sent_len) { + return false; + } + } + } + + if (target_phrase_high - target_phrase_low > max_rule_span) { + return false; + } + + source_back_low = source_back_high = -1; + FindProjection(target_phrase_low, target_phrase_high, target_low, target_high, + source_back_low, source_back_high); + int new_x = 0; + bool new_low_x = false, new_high_x = false; + while (true) { + source_back_low = min(source_back_low, source_phrase_low); + source_back_high = max(source_back_high, source_phrase_high); + + if (source_back_low == source_phrase_low && + source_back_high == source_phrase_high) { + return true; + } + + if (!allow_low_x && source_back_low < source_phrase_low) { + // Extension on the left side not allowed. + return false; + } + if (!allow_high_x && source_back_high > source_phrase_high) { + // Extension on the right side not allowed. + return false; + } + + // Extend left side. + if (source_back_low < source_phrase_low) { + if (new_low_x == false) { + if (new_x >= max_new_x) { + return false; + } + new_low_x = true; + ++new_x; + } + if (source_phrase_low - source_back_low < min_source_gap_size) { + source_back_low = source_phrase_low - min_source_gap_size; + if (source_back_low < 0) { + return false; + } + } + } + + // Extend right side. + if (source_back_high > source_phrase_high) { + if (new_high_x == false) { + if (new_x >= max_new_x) { + return false; + } + new_high_x = true; + ++new_x; + } + if (source_back_high - source_phrase_high < min_source_gap_size) { + source_back_high = source_phrase_high + min_source_gap_size; + if (source_back_high > source_sent_len) { + return false; + } + } + } + + if (source_back_high - source_back_low > max_rule_span) { + // Rule span too wide. + return false; + } + + prev_target_low = target_phrase_low; + prev_target_high = target_phrase_high; + FindProjection(source_back_low, source_phrase_low, source_low, source_high, + target_phrase_low, target_phrase_high); + FindProjection(source_phrase_high, source_back_high, source_low, + source_high, target_phrase_low, target_phrase_high); + if (prev_target_low == target_phrase_low && + prev_target_high == target_phrase_high) { + return true; + } + + if (!allow_arbitrary_expansion) { + // Arbitrary expansion not allowed. + return false; + } + if (target_phrase_high - target_phrase_low > max_rule_span) { + // Target side too wide. + return false; + } + + source_phrase_low = source_back_low; + source_phrase_high = source_back_high; + FindProjection(target_phrase_low, prev_target_low, target_low, target_high, + source_back_low, source_back_high); + FindProjection(prev_target_high, target_phrase_high, target_low, + target_high, source_back_low, source_back_high); + } + + return false; +} + +void RuleExtractorHelper::FindProjection( + int source_phrase_low, int source_phrase_high, + const vector<int>& source_low, const vector<int>& source_high, + int& target_phrase_low, int& target_phrase_high) const { + for (size_t i = source_phrase_low; i < source_phrase_high; ++i) { + if (source_low[i] != -1) { + if (target_phrase_low == -1 || source_low[i] < target_phrase_low) { + target_phrase_low = source_low[i]; + } + target_phrase_high = max(target_phrase_high, source_high[i]); + } + } +} + +bool RuleExtractorHelper::GetGaps( + vector<pair<int, int> >& source_gaps, vector<pair<int, int> >& target_gaps, + const vector<int>& matching, const vector<int>& chunklen, + const vector<int>& source_low, const vector<int>& source_high, + const vector<int>& target_low, const vector<int>& target_high, + int source_phrase_low, int source_phrase_high, int source_back_low, + int source_back_high, int& num_symbols, bool& met_constraints) const { + int sentence_id = source_data_array->GetSentenceId(matching[0]); + int source_sent_start = source_data_array->GetSentenceStart(sentence_id); + + if (source_back_low < source_phrase_low) { + source_gaps.push_back(make_pair(source_back_low, source_phrase_low)); + if (num_symbols >= max_rule_symbols) { + // Source side contains too many symbols. + return false; + } + ++num_symbols; + if (require_tight_phrases && (source_low[source_back_low] == -1 || + source_low[source_phrase_low - 1] == -1)) { + // Inside edges of preceding gap are not tight. + return false; + } + } else if (require_tight_phrases && source_low[source_phrase_low] == -1) { + // This is not a hard error. We can't extract this phrase, but we might + // still be able to extract a superphrase. + met_constraints = false; + } + + for (size_t i = 0; i + 1 < chunklen.size(); ++i) { + int gap_start = matching[i] + chunklen[i] - source_sent_start; + int gap_end = matching[i + 1] - source_sent_start; + source_gaps.push_back(make_pair(gap_start, gap_end)); + } + + if (source_phrase_high < source_back_high) { + source_gaps.push_back(make_pair(source_phrase_high, source_back_high)); + if (num_symbols >= max_rule_symbols) { + // Source side contains too many symbols. + return false; + } + ++num_symbols; + if (require_tight_phrases && (source_low[source_phrase_high] == -1 || + source_low[source_back_high - 1] == -1)) { + // Inside edges of following gap are not tight. + return false; + } + } else if (require_tight_phrases && + source_low[source_phrase_high - 1] == -1) { + // This is not a hard error. We can't extract this phrase, but we might + // still be able to extract a superphrase. + met_constraints = false; + } + + target_gaps.resize(source_gaps.size(), make_pair(-1, -1)); + for (size_t i = 0; i < source_gaps.size(); ++i) { + if (!FindFixPoint(source_gaps[i].first, source_gaps[i].second, source_low, + source_high, target_gaps[i].first, target_gaps[i].second, + target_low, target_high, source_gaps[i].first, + source_gaps[i].second, sentence_id, 0, 0, 0, false, false, + false)) { + // Gap fails integrity check. + return false; + } + } + + return true; +} + +vector<int> RuleExtractorHelper::GetGapOrder( + const vector<pair<int, int> >& gaps) const { + vector<int> gap_order(gaps.size()); + for (size_t i = 0; i < gap_order.size(); ++i) { + for (size_t j = 0; j < i; ++j) { + if (gaps[gap_order[j]] < gaps[i]) { + ++gap_order[i]; + } else { + ++gap_order[j]; + } + } + } + return gap_order; +} + +unordered_map<int, int> RuleExtractorHelper::GetSourceIndexes( + const vector<int>& matching, const vector<int>& chunklen, + int starts_with_x) const { + unordered_map<int, int> source_indexes; + int sentence_id = source_data_array->GetSentenceId(matching[0]); + int source_sent_start = source_data_array->GetSentenceStart(sentence_id); + int num_symbols = starts_with_x; + for (size_t i = 0; i < matching.size(); ++i) { + for (size_t j = 0; j < chunklen[i]; ++j) { + source_indexes[matching[i] + j - source_sent_start] = num_symbols; + ++num_symbols; + } + ++num_symbols; + } + return source_indexes; +} diff --git a/extractor/rule_extractor_helper.h b/extractor/rule_extractor_helper.h new file mode 100644 index 00000000..3478bfc8 --- /dev/null +++ b/extractor/rule_extractor_helper.h @@ -0,0 +1,82 @@ +#ifndef _RULE_EXTRACTOR_HELPER_H_ +#define _RULE_EXTRACTOR_HELPER_H_ + +#include <memory> +#include <unordered_map> +#include <vector> + +using namespace std; + +class Alignment; +class DataArray; + +class RuleExtractorHelper { + public: + RuleExtractorHelper(shared_ptr<DataArray> source_data_array, + shared_ptr<DataArray> target_data_array, + shared_ptr<Alignment> alignment, + int max_rule_span, + int max_rule_symbols, + bool require_aligned_terminal, + bool require_aligned_chunks, + bool require_tight_phrases); + + virtual ~RuleExtractorHelper(); + + virtual void GetLinksSpans(vector<int>& source_low, vector<int>& source_high, + vector<int>& target_low, vector<int>& target_high, + int sentence_id) const; + + virtual bool CheckAlignedTerminals(const vector<int>& matching, + const vector<int>& chunklen, + const vector<int>& source_low) const; + + virtual bool CheckTightPhrases(const vector<int>& matching, + const vector<int>& chunklen, + const vector<int>& source_low) const; + + virtual bool FindFixPoint( + int source_phrase_low, int source_phrase_high, + const vector<int>& source_low, const vector<int>& source_high, + int& target_phrase_low, int& target_phrase_high, + const vector<int>& target_low, const vector<int>& target_high, + int& source_back_low, int& source_back_high, int sentence_id, + int min_source_gap_size, int min_target_gap_size, + int max_new_x, bool allow_low_x, bool allow_high_x, + bool allow_arbitrary_expansion) const; + + virtual bool GetGaps( + vector<pair<int, int> >& source_gaps, vector<pair<int, int> >& target_gaps, + const vector<int>& matching, const vector<int>& chunklen, + const vector<int>& source_low, const vector<int>& source_high, + const vector<int>& target_low, const vector<int>& target_high, + int source_phrase_low, int source_phrase_high, int source_back_low, + int source_back_high, int& num_symbols, bool& met_constraints) const; + + virtual vector<int> GetGapOrder(const vector<pair<int, int> >& gaps) const; + + // TODO(pauldb): Add unit tests. + virtual unordered_map<int, int> GetSourceIndexes( + const vector<int>& matching, const vector<int>& chunklen, + int starts_with_x) const; + + protected: + RuleExtractorHelper(); + + private: + void FindProjection( + int source_phrase_low, int source_phrase_high, + const vector<int>& source_low, const vector<int>& source_high, + int& target_phrase_low, int& target_phrase_high) const; + + shared_ptr<DataArray> source_data_array; + shared_ptr<DataArray> target_data_array; + shared_ptr<Alignment> alignment; + int max_rule_span; + int max_rule_symbols; + bool require_aligned_terminal; + bool require_aligned_chunks; + bool require_tight_phrases; +}; + +#endif diff --git a/extractor/rule_extractor_helper_test.cc b/extractor/rule_extractor_helper_test.cc new file mode 100644 index 00000000..29213312 --- /dev/null +++ b/extractor/rule_extractor_helper_test.cc @@ -0,0 +1,622 @@ +#include <gtest/gtest.h> + +#include <memory> + +#include "mocks/mock_alignment.h" +#include "mocks/mock_data_array.h" +#include "rule_extractor_helper.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class RuleExtractorHelperTest : public Test { + protected: + virtual void SetUp() { + source_data_array = make_shared<MockDataArray>(); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(12)); + EXPECT_CALL(*source_data_array, GetSentenceId(_)) + .WillRepeatedly(Return(5)); + EXPECT_CALL(*source_data_array, GetSentenceStart(_)) + .WillRepeatedly(Return(10)); + + target_data_array = make_shared<MockDataArray>(); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(12)); + + vector<pair<int, int> > links = { + make_pair(0, 0), make_pair(0, 1), make_pair(2, 2), make_pair(3, 1) + }; + alignment = make_shared<MockAlignment>(); + EXPECT_CALL(*alignment, GetLinks(_)).WillRepeatedly(Return(links)); + } + + shared_ptr<MockDataArray> source_data_array; + shared_ptr<MockDataArray> target_data_array; + shared_ptr<MockAlignment> alignment; + shared_ptr<RuleExtractorHelper> helper; +}; + +TEST_F(RuleExtractorHelperTest, TestGetLinksSpans) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(4)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(3)); + + vector<int> source_low, source_high, target_low, target_high; + helper->GetLinksSpans(source_low, source_high, target_low, target_high, 0); + + vector<int> expected_source_low = {0, -1, 2, 1}; + EXPECT_EQ(expected_source_low, source_low); + vector<int> expected_source_high = {2, -1, 3, 2}; + EXPECT_EQ(expected_source_high, source_high); + vector<int> expected_target_low = {0, 0, 2}; + EXPECT_EQ(expected_target_low, target_low); + vector<int> expected_target_high = {1, 4, 3}; + EXPECT_EQ(expected_target_high, target_high); +} + +TEST_F(RuleExtractorHelperTest, TestCheckAlignedFalse) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, false, false, true); + EXPECT_CALL(*source_data_array, GetSentenceId(_)).Times(0); + EXPECT_CALL(*source_data_array, GetSentenceStart(_)).Times(0); + + vector<int> matching, chunklen, source_low; + EXPECT_TRUE(helper->CheckAlignedTerminals(matching, chunklen, source_low)); +} + +TEST_F(RuleExtractorHelperTest, TestCheckAlignedTerminal) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, false, true); + + vector<int> matching = {10, 12}; + vector<int> chunklen = {1, 3}; + vector<int> source_low = {-1, 1, -1, 3, -1}; + EXPECT_TRUE(helper->CheckAlignedTerminals(matching, chunklen, source_low)); + source_low = {-1, 1, -1, -1, -1}; + EXPECT_FALSE(helper->CheckAlignedTerminals(matching, chunklen, source_low)); +} + +TEST_F(RuleExtractorHelperTest, TestCheckAlignedChunks) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + + vector<int> matching = {10, 12}; + vector<int> chunklen = {1, 3}; + vector<int> source_low = {2, 1, -1, 3, -1}; + EXPECT_TRUE(helper->CheckAlignedTerminals(matching, chunklen, source_low)); + source_low = {-1, 1, -1, 3, -1}; + EXPECT_FALSE(helper->CheckAlignedTerminals(matching, chunklen, source_low)); + source_low = {2, 1, -1, -1, -1}; + EXPECT_FALSE(helper->CheckAlignedTerminals(matching, chunklen, source_low)); +} + + +TEST_F(RuleExtractorHelperTest, TestCheckTightPhrasesFalse) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, false); + EXPECT_CALL(*source_data_array, GetSentenceId(_)).Times(0); + EXPECT_CALL(*source_data_array, GetSentenceStart(_)).Times(0); + + vector<int> matching, chunklen, source_low; + EXPECT_TRUE(helper->CheckTightPhrases(matching, chunklen, source_low)); +} + +TEST_F(RuleExtractorHelperTest, TestCheckTightPhrases) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + + vector<int> matching = {10, 14, 18}; + vector<int> chunklen = {2, 3, 1}; + // No missing links. + vector<int> source_low = {0, 1, 2, 3, 4, 5, 6, 7, 8}; + EXPECT_TRUE(helper->CheckTightPhrases(matching, chunklen, source_low)); + + // Missing link at the beginning or ending of a gap. + source_low = {0, 1, -1, 3, 4, 5, 6, 7, 8}; + EXPECT_FALSE(helper->CheckTightPhrases(matching, chunklen, source_low)); + source_low = {0, 1, 2, -1, 4, 5, 6, 7, 8}; + EXPECT_FALSE(helper->CheckTightPhrases(matching, chunklen, source_low)); + source_low = {0, 1, 2, 3, 4, 5, 6, -1, 8}; + EXPECT_FALSE(helper->CheckTightPhrases(matching, chunklen, source_low)); + + // Missing link inside the gap. + chunklen = {1, 3, 1}; + source_low = {0, 1, -1, 3, 4, 5, 6, 7, 8}; + EXPECT_TRUE(helper->CheckTightPhrases(matching, chunklen, source_low)); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPointBadEdgeCase) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + + vector<int> source_low = {0, -1, 2}; + vector<int> source_high = {1, -1, 3}; + vector<int> target_low = {0, -1, 2}; + vector<int> target_high = {1, -1, 3}; + int source_phrase_low = 1, source_phrase_high = 2; + int source_back_low, source_back_high; + int target_phrase_low = -1, target_phrase_high = 1; + + // This should be in fact true. See comment about the inherited bug. + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 0, 0, + 0, false, false, false)); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPointTargetSentenceOutOfBounds) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(3)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(3)); + + vector<int> source_low = {0, 0, 2}; + vector<int> source_high = {1, 2, 3}; + vector<int> target_low = {0, 1, 2}; + vector<int> target_high = {2, 2, 3}; + int source_phrase_low = 1, source_phrase_high = 2; + int source_back_low, source_back_high; + int target_phrase_low = 1, target_phrase_high = 2; + + // Extend out of sentence to left. + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 2, 2, + 0, false, false, false)); + source_low = {0, 1, 2}; + source_high = {1, 3, 3}; + target_low = {0, 1, 1}; + target_high = {1, 2, 3}; + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 2, 2, + 0, false, false, false)); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPointTargetTooWide) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 5, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + + vector<int> source_low = {0, 0, 0, 0, 0, 0, 0}; + vector<int> source_high = {7, 7, 7, 7, 7, 7, 7}; + vector<int> target_low = {0, -1, -1, -1, -1, -1, 0}; + vector<int> target_high = {7, -1, -1, -1, -1, -1, 7}; + int source_phrase_low = 2, source_phrase_high = 5; + int source_back_low, source_back_high; + int target_phrase_low = -1, target_phrase_high = -1; + + // Projection is too wide. + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 1, 1, + 0, false, false, false)); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPoint) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + + vector<int> source_low = {1, 1, 1, 3, 4, 5, 5}; + vector<int> source_high = {2, 2, 3, 4, 6, 6, 6}; + vector<int> target_low = {-1, 0, 2, 3, 4, 4, -1}; + vector<int> target_high = {-1, 3, 3, 4, 5, 7, -1}; + int source_phrase_low = 2, source_phrase_high = 5; + int source_back_low, source_back_high; + int target_phrase_low = 2, target_phrase_high = 5; + + EXPECT_TRUE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 1, 1, 1, + 2, true, true, false)); + EXPECT_EQ(1, target_phrase_low); + EXPECT_EQ(6, target_phrase_high); + EXPECT_EQ(0, source_back_low); + EXPECT_EQ(7, source_back_high); + + source_low = {0, -1, 1, 3, 4, -1, 6}; + source_high = {1, -1, 3, 4, 6, -1, 7}; + target_low = {0, 2, 2, 3, 4, 4, 6}; + target_high = {1, 3, 3, 4, 5, 5, 7}; + source_phrase_low = 2, source_phrase_high = 5; + target_phrase_low = -1, target_phrase_high = -1; + EXPECT_TRUE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 1, 1, 1, + 2, true, true, false)); + EXPECT_EQ(1, target_phrase_low); + EXPECT_EQ(6, target_phrase_high); + EXPECT_EQ(2, source_back_low); + EXPECT_EQ(5, source_back_high); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPointExtensionsNotAllowed) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(3)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(3)); + + vector<int> source_low = {0, 0, 2}; + vector<int> source_high = {1, 2, 3}; + vector<int> target_low = {0, 1, 2}; + vector<int> target_high = {2, 2, 3}; + int source_phrase_low = 1, source_phrase_high = 2; + int source_back_low, source_back_high; + int target_phrase_low = -1, target_phrase_high = -1; + + // Extension on the left side not allowed. + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 1, 1, + 1, false, true, false)); + // Extension on the left side is allowed, but we can't add anymore X. + target_phrase_low = -1, target_phrase_high = -1; + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 1, 1, + 0, true, true, false)); + source_low = {0, 1, 2}; + source_high = {1, 3, 3}; + target_low = {0, 1, 1}; + target_high = {1, 2, 3}; + // Extension on the right side not allowed. + target_phrase_low = -1, target_phrase_high = -1; + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 1, 1, + 1, true, false, false)); + // Extension on the right side is allowed, but we can't add anymore X. + target_phrase_low = -1, target_phrase_high = -1; + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 1, 1, + 0, true, true, false)); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPointSourceSentenceOutOfBounds) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(3)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(3)); + + vector<int> source_low = {0, 0, 2}; + vector<int> source_high = {1, 2, 3}; + vector<int> target_low = {0, 1, 2}; + vector<int> target_high = {2, 2, 3}; + int source_phrase_low = 1, source_phrase_high = 2; + int source_back_low, source_back_high; + int target_phrase_low = 1, target_phrase_high = 2; + // Extend out of sentence to left. + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 2, 1, + 1, true, true, false)); + source_low = {0, 1, 2}; + source_high = {1, 3, 3}; + target_low = {0, 1, 1}; + target_high = {1, 2, 3}; + target_phrase_low = 1, target_phrase_high = 2; + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 2, 1, + 1, true, true, false)); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPointTargetSourceWide) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 5, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + + vector<int> source_low = {2, -1, 2, 3, 4, -1, 4}; + vector<int> source_high = {3, -1, 3, 4, 5, -1, 5}; + vector<int> target_low = {-1, -1, 0, 3, 4, -1, -1}; + vector<int> target_high = {-1, -1, 3, 4, 7, -1, -1}; + int source_phrase_low = 2, source_phrase_high = 5; + int source_back_low, source_back_high; + int target_phrase_low = -1, target_phrase_high = -1; + + // Second projection (on source side) is too wide. + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 1, 1, + 2, true, true, false)); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPointArbitraryExpansion) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 20, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(11)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(11)); + + vector<int> source_low = {1, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9}; + vector<int> source_high = {2, 3, 4, 5, 5, 6, 7, 8, 9, 10, 10}; + vector<int> target_low = {-1, 0, 1, 2, 3, 5, 6, 7, 8, 9, -1}; + vector<int> target_high = {-1, 2, 3, 4, 5, 6, 8, 9, 10, 11, -1}; + int source_phrase_low = 4, source_phrase_high = 7; + int source_back_low, source_back_high; + int target_phrase_low = -1, target_phrase_high = -1; + EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 1, 1, + 10, true, true, false)); + + source_phrase_low = 4, source_phrase_high = 7; + target_phrase_low = -1, target_phrase_high = -1; + EXPECT_TRUE(helper->FindFixPoint(source_phrase_low, source_phrase_high, + source_low, source_high, target_phrase_low, + target_phrase_high, target_low, target_high, + source_back_low, source_back_high, 0, 1, 1, + 10, true, true, true)); +} + +TEST_F(RuleExtractorHelperTest, TestGetGapOrder) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + + vector<pair<int, int> > gaps = + {make_pair(0, 3), make_pair(5, 8), make_pair(11, 12), make_pair(15, 17)}; + vector<int> expected_gap_order = {0, 1, 2, 3}; + EXPECT_EQ(expected_gap_order, helper->GetGapOrder(gaps)); + + gaps = {make_pair(15, 17), make_pair(8, 9), make_pair(5, 6), make_pair(0, 3)}; + expected_gap_order = {3, 2, 1, 0}; + EXPECT_EQ(expected_gap_order, helper->GetGapOrder(gaps)); + + gaps = {make_pair(8, 9), make_pair(5, 6), make_pair(0, 3), make_pair(15, 17)}; + expected_gap_order = {2, 1, 0, 3}; + EXPECT_EQ(expected_gap_order, helper->GetGapOrder(gaps)); +} + +TEST_F(RuleExtractorHelperTest, TestGetGapsExceedNumSymbols) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + + bool met_constraints = true; + vector<int> source_low = {1, 1, 2, 3, 4, 5, 6}; + vector<int> source_high = {2, 2, 3, 4, 5, 6, 7}; + vector<int> target_low = {-1, 0, 2, 3, 4, 5, 6}; + vector<int> target_high = {-1, 2, 3, 4, 5, 6, 7}; + int source_phrase_low = 1, source_phrase_high = 6; + int source_back_low = 0, source_back_high = 6; + vector<int> matching = {11, 13, 15}; + vector<int> chunklen = {1, 1, 1}; + vector<pair<int, int> > source_gaps, target_gaps; + int num_symbols = 5; + EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, + source_low, source_high, target_low, target_high, + source_phrase_low, source_phrase_high, + source_back_low, source_back_high, num_symbols, + met_constraints)); + + source_low = {0, 1, 2, 3, 4, 5, 5}; + source_high = {1, 2, 3, 4, 5, 6, 6}; + target_low = {0, 1, 2, 3, 4, 5, -1}; + target_high = {1, 2, 3, 4, 5, 7, -1}; + source_phrase_low = 1, source_phrase_high = 6; + source_back_low = 1, source_back_high = 7; + num_symbols = 5; + EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, + source_low, source_high, target_low, target_high, + source_phrase_low, source_phrase_high, + source_back_low, source_back_high, num_symbols, + met_constraints)); +} + +TEST_F(RuleExtractorHelperTest, TestGetGapsExtensionsNotTight) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 7, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + + bool met_constraints = true; + vector<int> source_low = {-1, 1, 2, 3, 4, 5, -1}; + vector<int> source_high = {-1, 2, 3, 4, 5, 6, -1}; + vector<int> target_low = {-1, 1, 2, 3, 4, 5, -1}; + vector<int> target_high = {-1, 2, 3, 4, 5, 6, -1}; + int source_phrase_low = 1, source_phrase_high = 6; + int source_back_low = 0, source_back_high = 6; + vector<int> matching = {11, 13, 15}; + vector<int> chunklen = {1, 1, 1}; + vector<pair<int, int> > source_gaps, target_gaps; + int num_symbols = 5; + EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, + source_low, source_high, target_low, target_high, + source_phrase_low, source_phrase_high, + source_back_low, source_back_high, num_symbols, + met_constraints)); + + source_phrase_low = 1, source_phrase_high = 6; + source_back_low = 1, source_back_high = 7; + num_symbols = 5; + EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, + source_low, source_high, target_low, target_high, + source_phrase_low, source_phrase_high, + source_back_low, source_back_high, num_symbols, + met_constraints)); +} + +TEST_F(RuleExtractorHelperTest, TestGetGapsNotTightExtremities) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 7, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + + bool met_constraints = true; + vector<int> source_low = {-1, -1, 2, 3, 4, 5, 6}; + vector<int> source_high = {-1, -1, 3, 4, 5, 6, 7}; + vector<int> target_low = {-1, -1, 2, 3, 4, 5, 6}; + vector<int> target_high = {-1, -1, 3, 4, 5, 6, 7}; + int source_phrase_low = 1, source_phrase_high = 6; + int source_back_low = 1, source_back_high = 6; + vector<int> matching = {11, 13, 15}; + vector<int> chunklen = {1, 1, 1}; + vector<pair<int, int> > source_gaps, target_gaps; + int num_symbols = 5; + EXPECT_TRUE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, + source_low, source_high, target_low, target_high, + source_phrase_low, source_phrase_high, + source_back_low, source_back_high, num_symbols, + met_constraints)); + EXPECT_FALSE(met_constraints); + vector<pair<int, int> > expected_gaps = {make_pair(2, 3), make_pair(4, 5)}; + EXPECT_EQ(expected_gaps, source_gaps); + EXPECT_EQ(expected_gaps, target_gaps); + + source_low = {-1, 1, 2, 3, 4, -1, 6}; + source_high = {-1, 2, 3, 4, 5, -1, 7}; + target_low = {-1, 1, 2, 3, 4, -1, 6}; + target_high = {-1, 2, 3, 4, 5, -1, 7}; + met_constraints = true; + source_gaps.clear(); + target_gaps.clear(); + EXPECT_TRUE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, + source_low, source_high, target_low, target_high, + source_phrase_low, source_phrase_high, + source_back_low, source_back_high, num_symbols, + met_constraints)); + EXPECT_FALSE(met_constraints); + EXPECT_EQ(expected_gaps, source_gaps); + EXPECT_EQ(expected_gaps, target_gaps); +} + +TEST_F(RuleExtractorHelperTest, TestGetGapsWithExtensions) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + + bool met_constraints = true; + vector<int> source_low = {-1, 5, 2, 3, 4, 1, -1}; + vector<int> source_high = {-1, 6, 3, 4, 5, 2, -1}; + vector<int> target_low = {-1, 5, 2, 3, 4, 1, -1}; + vector<int> target_high = {-1, 6, 3, 4, 5, 2, -1}; + int source_phrase_low = 2, source_phrase_high = 5; + int source_back_low = 1, source_back_high = 6; + vector<int> matching = {12, 14}; + vector<int> chunklen = {1, 1}; + vector<pair<int, int> > source_gaps, target_gaps; + int num_symbols = 3; + EXPECT_TRUE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, + source_low, source_high, target_low, target_high, + source_phrase_low, source_phrase_high, + source_back_low, source_back_high, num_symbols, + met_constraints)); + vector<pair<int, int> > expected_source_gaps = { + make_pair(1, 2), make_pair(3, 4), make_pair(5, 6) + }; + EXPECT_EQ(expected_source_gaps, source_gaps); + vector<pair<int, int> > expected_target_gaps = { + make_pair(5, 6), make_pair(3, 4), make_pair(1, 2) + }; + EXPECT_EQ(expected_target_gaps, target_gaps); +} + +TEST_F(RuleExtractorHelperTest, TestGetGaps) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + + bool met_constraints = true; + vector<int> source_low = {-1, 1, 4, 3, 2, 5, -1}; + vector<int> source_high = {-1, 2, 5, 4, 3, 6, -1}; + vector<int> target_low = {-1, 1, 4, 3, 2, 5, -1}; + vector<int> target_high = {-1, 2, 5, 4, 3, 6, -1}; + int source_phrase_low = 1, source_phrase_high = 6; + int source_back_low = 1, source_back_high = 6; + vector<int> matching = {11, 13, 15}; + vector<int> chunklen = {1, 1, 1}; + vector<pair<int, int> > source_gaps, target_gaps; + int num_symbols = 5; + EXPECT_TRUE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, + source_low, source_high, target_low, target_high, + source_phrase_low, source_phrase_high, + source_back_low, source_back_high, num_symbols, + met_constraints)); + vector<pair<int, int> > expected_source_gaps = { + make_pair(2, 3), make_pair(4, 5) + }; + EXPECT_EQ(expected_source_gaps, source_gaps); + vector<pair<int, int> > expected_target_gaps = { + make_pair(4, 5), make_pair(2, 3) + }; + EXPECT_EQ(expected_target_gaps, target_gaps); +} + +TEST_F(RuleExtractorHelperTest, TestGetGapIntegrityChecksFailed) { + helper = make_shared<RuleExtractorHelper>(source_data_array, + target_data_array, alignment, 10, 5, true, true, true); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + EXPECT_CALL(*target_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(7)); + + bool met_constraints = true; + vector<int> source_low = {-1, 3, 2, 3, 4, 3, -1}; + vector<int> source_high = {-1, 4, 3, 4, 5, 4, -1}; + vector<int> target_low = {-1, -1, 2, 1, 4, -1, -1}; + vector<int> target_high = {-1, -1, 3, 6, 5, -1, -1}; + int source_phrase_low = 2, source_phrase_high = 5; + int source_back_low = 2, source_back_high = 5; + vector<int> matching = {12, 14}; + vector<int> chunklen = {1, 1}; + vector<pair<int, int> > source_gaps, target_gaps; + int num_symbols = 3; + EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, + source_low, source_high, target_low, target_high, + source_phrase_low, source_phrase_high, + source_back_low, source_back_high, num_symbols, + met_constraints)); +} + +} // namespace diff --git a/extractor/rule_extractor_test.cc b/extractor/rule_extractor_test.cc new file mode 100644 index 00000000..0be44d4d --- /dev/null +++ b/extractor/rule_extractor_test.cc @@ -0,0 +1,166 @@ +#include <gtest/gtest.h> + +#include <memory> + +#include "mocks/mock_alignment.h" +#include "mocks/mock_data_array.h" +#include "mocks/mock_rule_extractor_helper.h" +#include "mocks/mock_scorer.h" +#include "mocks/mock_target_phrase_extractor.h" +#include "mocks/mock_vocabulary.h" +#include "phrase.h" +#include "phrase_builder.h" +#include "phrase_location.h" +#include "rule_extractor.h" +#include "rule.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class RuleExtractorTest : public Test { + protected: + virtual void SetUp() { + source_data_array = make_shared<MockDataArray>(); + EXPECT_CALL(*source_data_array, GetSentenceId(_)) + .WillRepeatedly(Return(0)); + EXPECT_CALL(*source_data_array, GetSentenceStart(_)) + .WillRepeatedly(Return(0)); + EXPECT_CALL(*source_data_array, GetSentenceLength(_)) + .WillRepeatedly(Return(10)); + + helper = make_shared<MockRuleExtractorHelper>(); + EXPECT_CALL(*helper, CheckAlignedTerminals(_, _, _)) + .WillRepeatedly(Return(true)); + EXPECT_CALL(*helper, CheckTightPhrases(_, _, _)) + .WillRepeatedly(Return(true)); + unordered_map<int, int> source_indexes; + EXPECT_CALL(*helper, GetSourceIndexes(_, _, _)) + .WillRepeatedly(Return(source_indexes)); + + vocabulary = make_shared<MockVocabulary>(); + EXPECT_CALL(*vocabulary, GetTerminalValue(87)) + .WillRepeatedly(Return("a")); + phrase_builder = make_shared<PhraseBuilder>(vocabulary); + vector<int> symbols = {87}; + Phrase target_phrase = phrase_builder->Build(symbols); + PhraseAlignment phrase_alignment = {make_pair(0, 0)}; + + target_phrase_extractor = make_shared<MockTargetPhraseExtractor>(); + vector<pair<Phrase, PhraseAlignment> > target_phrases = { + make_pair(target_phrase, phrase_alignment) + }; + EXPECT_CALL(*target_phrase_extractor, ExtractPhrases(_, _, _, _, _, _)) + .WillRepeatedly(Return(target_phrases)); + + scorer = make_shared<MockScorer>(); + vector<double> scores = {0.3, 7.2}; + EXPECT_CALL(*scorer, Score(_)).WillRepeatedly(Return(scores)); + + extractor = make_shared<RuleExtractor>(source_data_array, phrase_builder, + scorer, target_phrase_extractor, helper, 10, 1, 3, 5, false); + } + + shared_ptr<MockDataArray> source_data_array; + shared_ptr<MockVocabulary> vocabulary; + shared_ptr<PhraseBuilder> phrase_builder; + shared_ptr<MockRuleExtractorHelper> helper; + shared_ptr<MockScorer> scorer; + shared_ptr<MockTargetPhraseExtractor> target_phrase_extractor; + shared_ptr<RuleExtractor> extractor; +}; + +TEST_F(RuleExtractorTest, TestExtractRulesAlignedTerminalsFail) { + vector<int> symbols = {87}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> matching = {2}; + PhraseLocation phrase_location(matching, 1); + EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); + EXPECT_CALL(*helper, CheckAlignedTerminals(_, _, _)) + .WillRepeatedly(Return(false)); + vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); + EXPECT_EQ(0, rules.size()); +} + +TEST_F(RuleExtractorTest, TestExtractRulesTightPhrasesFail) { + vector<int> symbols = {87}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> matching = {2}; + PhraseLocation phrase_location(matching, 1); + EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); + EXPECT_CALL(*helper, CheckTightPhrases(_, _, _)) + .WillRepeatedly(Return(false)); + vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); + EXPECT_EQ(0, rules.size()); +} + +TEST_F(RuleExtractorTest, TestExtractRulesNoFixPoint) { + vector<int> symbols = {87}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> matching = {2}; + PhraseLocation phrase_location(matching, 1); + + EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); + // Set FindFixPoint to return false. + vector<pair<int, int> > gaps; + helper->SetUp(0, 0, 0, 0, false, gaps, gaps, 0, true, true); + + vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); + EXPECT_EQ(0, rules.size()); +} + +TEST_F(RuleExtractorTest, TestExtractRulesGapsFail) { + vector<int> symbols = {87}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> matching = {2}; + PhraseLocation phrase_location(matching, 1); + + EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); + // Set CheckGaps to return false. + vector<pair<int, int> > gaps; + helper->SetUp(0, 0, 0, 0, true, gaps, gaps, 0, true, false); + + vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); + EXPECT_EQ(0, rules.size()); +} + +TEST_F(RuleExtractorTest, TestExtractRulesNoExtremities) { + vector<int> symbols = {87}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> matching = {2}; + PhraseLocation phrase_location(matching, 1); + + EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); + vector<pair<int, int> > gaps(3); + // Set FindFixPoint to return true. The number of gaps equals the number of + // nonterminals, so we won't add any extremities. + helper->SetUp(0, 0, 0, 0, true, gaps, gaps, 0, true, true); + + vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); + EXPECT_EQ(1, rules.size()); +} + +TEST_F(RuleExtractorTest, TestExtractRulesAddExtremities) { + vector<int> symbols = {87}; + Phrase phrase = phrase_builder->Build(symbols); + vector<int> matching = {2}; + PhraseLocation phrase_location(matching, 1); + + vector<int> links(10, -1); + EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).WillOnce(DoAll( + SetArgReferee<0>(links), + SetArgReferee<1>(links), + SetArgReferee<2>(links), + SetArgReferee<3>(links))); + + vector<pair<int, int> > gaps; + // Set FindFixPoint to return true. The number of gaps equals the number of + // nonterminals, so we won't add any extremities. + helper->SetUp(0, 0, 2, 3, true, gaps, gaps, 0, true, true); + + vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); + EXPECT_EQ(4, rules.size()); +} + +} // namespace diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc new file mode 100644 index 00000000..4101fcfa --- /dev/null +++ b/extractor/rule_factory.cc @@ -0,0 +1,305 @@ +#include "rule_factory.h" + +#include <chrono> +#include <memory> +#include <queue> +#include <vector> + +#include "grammar.h" +#include "fast_intersector.h" +#include "intersector.h" +#include "matchings_finder.h" +#include "matching_comparator.h" +#include "phrase.h" +#include "rule.h" +#include "rule_extractor.h" +#include "sampler.h" +#include "scorer.h" +#include "suffix_array.h" +#include "time_util.h" +#include "vocabulary.h" + +using namespace std; +using namespace chrono; + +typedef high_resolution_clock Clock; + +struct State { + State(int start, int end, const vector<int>& phrase, + const vector<int>& subpatterns_start, shared_ptr<TrieNode> node, + bool starts_with_x) : + start(start), end(end), phrase(phrase), + subpatterns_start(subpatterns_start), node(node), + starts_with_x(starts_with_x) {} + + int start, end; + vector<int> phrase, subpatterns_start; + shared_ptr<TrieNode> node; + bool starts_with_x; +}; + +HieroCachingRuleFactory::HieroCachingRuleFactory( + shared_ptr<SuffixArray> source_suffix_array, + shared_ptr<DataArray> target_data_array, + shared_ptr<Alignment> alignment, + const shared_ptr<Vocabulary>& vocabulary, + shared_ptr<Precomputation> precomputation, + shared_ptr<Scorer> scorer, + int min_gap_size, + int max_rule_span, + int max_nonterminals, + int max_rule_symbols, + int max_samples, + bool use_fast_intersect, + bool use_baeza_yates, + bool require_tight_phrases) : + vocabulary(vocabulary), + scorer(scorer), + min_gap_size(min_gap_size), + max_rule_span(max_rule_span), + max_nonterminals(max_nonterminals), + max_chunks(max_nonterminals + 1), + max_rule_symbols(max_rule_symbols), + use_fast_intersect(use_fast_intersect) { + matchings_finder = make_shared<MatchingsFinder>(source_suffix_array); + shared_ptr<MatchingComparator> comparator = + make_shared<MatchingComparator>(min_gap_size, max_rule_span); + intersector = make_shared<Intersector>(vocabulary, precomputation, + source_suffix_array, comparator, use_baeza_yates); + fast_intersector = make_shared<FastIntersector>(source_suffix_array, + precomputation, vocabulary, max_rule_span, min_gap_size); + phrase_builder = make_shared<PhraseBuilder>(vocabulary); + rule_extractor = make_shared<RuleExtractor>(source_suffix_array->GetData(), + target_data_array, alignment, phrase_builder, scorer, vocabulary, + max_rule_span, min_gap_size, max_nonterminals, max_rule_symbols, true, + false, require_tight_phrases); + sampler = make_shared<Sampler>(source_suffix_array, max_samples); +} + +HieroCachingRuleFactory::HieroCachingRuleFactory( + shared_ptr<MatchingsFinder> finder, + shared_ptr<Intersector> intersector, + shared_ptr<FastIntersector> fast_intersector, + shared_ptr<PhraseBuilder> phrase_builder, + shared_ptr<RuleExtractor> rule_extractor, + shared_ptr<Vocabulary> vocabulary, + shared_ptr<Sampler> sampler, + shared_ptr<Scorer> scorer, + int min_gap_size, + int max_rule_span, + int max_nonterminals, + int max_chunks, + int max_rule_symbols, + bool use_fast_intersect) : + matchings_finder(finder), + intersector(intersector), + fast_intersector(fast_intersector), + phrase_builder(phrase_builder), + rule_extractor(rule_extractor), + vocabulary(vocabulary), + sampler(sampler), + scorer(scorer), + min_gap_size(min_gap_size), + max_rule_span(max_rule_span), + max_nonterminals(max_nonterminals), + max_chunks(max_chunks), + max_rule_symbols(max_rule_symbols), + use_fast_intersect(use_fast_intersect) {} + +HieroCachingRuleFactory::HieroCachingRuleFactory() {} + +HieroCachingRuleFactory::~HieroCachingRuleFactory() {} + +Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { + intersector->sort_time = 0; + Clock::time_point start_time = Clock::now(); + double total_extract_time = 0; + double total_intersect_time = 0; + double total_lookup_time = 0; + // Clear cache for every new sentence. + trie.Reset(); + shared_ptr<TrieNode> root = trie.GetRoot(); + + int first_x = vocabulary->GetNonterminalIndex(1); + shared_ptr<TrieNode> x_root(new TrieNode(root)); + root->AddChild(first_x, x_root); + + queue<State> states; + for (size_t i = 0; i < word_ids.size(); ++i) { + states.push(State(i, i, vector<int>(), vector<int>(1, i), root, false)); + } + for (size_t i = min_gap_size; i < word_ids.size(); ++i) { + states.push(State(i - min_gap_size, i, vector<int>(1, first_x), + vector<int>(1, i), x_root, true)); + } + + vector<Rule> rules; + while (!states.empty()) { + State state = states.front(); + states.pop(); + + shared_ptr<TrieNode> node = state.node; + vector<int> phrase = state.phrase; + int word_id = word_ids[state.end]; + phrase.push_back(word_id); + Phrase next_phrase = phrase_builder->Build(phrase); + shared_ptr<TrieNode> next_node; + + if (CannotHaveMatchings(node, word_id)) { + if (!node->HasChild(word_id)) { + node->AddChild(word_id, shared_ptr<TrieNode>()); + } + continue; + } + + if (RequiresLookup(node, word_id)) { + shared_ptr<TrieNode> next_suffix_link = node->suffix_link == NULL ? + trie.GetRoot() : node->suffix_link->GetChild(word_id); + if (state.starts_with_x) { + // If the phrase starts with a non terminal, we simply use the matchings + // from the suffix link. + next_node = make_shared<TrieNode>( + next_suffix_link, next_phrase, next_suffix_link->matchings); + } else { + PhraseLocation phrase_location; + if (next_phrase.Arity() > 0) { + Clock::time_point intersect_start = Clock::now(); + if (use_fast_intersect) { + phrase_location = fast_intersector->Intersect( + node->matchings, next_suffix_link->matchings, next_phrase); + } else { + phrase_location = intersector->Intersect( + node->phrase, + node->matchings, + next_suffix_link->phrase, + next_suffix_link->matchings, + next_phrase); + } + Clock::time_point intersect_stop = Clock::now(); + total_intersect_time += GetDuration(intersect_start, intersect_stop); + } else { + Clock::time_point lookup_start = Clock::now(); + phrase_location = matchings_finder->Find( + node->matchings, + vocabulary->GetTerminalValue(word_id), + state.phrase.size()); + Clock::time_point lookup_stop = Clock::now(); + total_lookup_time += GetDuration(lookup_start, lookup_stop); + } + + if (phrase_location.IsEmpty()) { + continue; + } + next_node = make_shared<TrieNode>( + next_suffix_link, next_phrase, phrase_location); + } + node->AddChild(word_id, next_node); + + // Automatically adds a trailing non terminal if allowed. Simply copy the + // matchings from the prefix node. + AddTrailingNonterminal(phrase, next_phrase, next_node, + state.starts_with_x); + + Clock::time_point extract_start = Clock::now(); + if (!state.starts_with_x) { + PhraseLocation sample = sampler->Sample(next_node->matchings); + vector<Rule> new_rules = + rule_extractor->ExtractRules(next_phrase, sample); + rules.insert(rules.end(), new_rules.begin(), new_rules.end()); + } + Clock::time_point extract_stop = Clock::now(); + total_extract_time += GetDuration(extract_start, extract_stop); + } else { + next_node = node->GetChild(word_id); + } + + vector<State> new_states = ExtendState(word_ids, state, phrase, next_phrase, + next_node); + for (State new_state: new_states) { + states.push(new_state); + } + } + + Clock::time_point stop_time = Clock::now(); + cerr << "Total time for rule lookup, extraction, and scoring = " + << GetDuration(start_time, stop_time) << " seconds" << endl; + cerr << "Extract time = " << total_extract_time << " seconds" << endl; + cerr << "Intersect time = " << total_intersect_time << " seconds" << endl; + cerr << "Lookup time = " << total_lookup_time << " seconds" << endl; + return Grammar(rules, scorer->GetFeatureNames()); +} + +bool HieroCachingRuleFactory::CannotHaveMatchings( + shared_ptr<TrieNode> node, int word_id) { + if (node->HasChild(word_id) && node->GetChild(word_id) == NULL) { + return true; + } + + shared_ptr<TrieNode> suffix_link = node->suffix_link; + return suffix_link != NULL && suffix_link->GetChild(word_id) == NULL; +} + +bool HieroCachingRuleFactory::RequiresLookup( + shared_ptr<TrieNode> node, int word_id) { + return !node->HasChild(word_id); +} + +void HieroCachingRuleFactory::AddTrailingNonterminal( + vector<int> symbols, + const Phrase& prefix, + const shared_ptr<TrieNode>& prefix_node, + bool starts_with_x) { + if (prefix.Arity() >= max_nonterminals) { + return; + } + + int var_id = vocabulary->GetNonterminalIndex(prefix.Arity() + 1); + symbols.push_back(var_id); + Phrase var_phrase = phrase_builder->Build(symbols); + + int suffix_var_id = vocabulary->GetNonterminalIndex( + prefix.Arity() + (starts_with_x == 0)); + shared_ptr<TrieNode> var_suffix_link = + prefix_node->suffix_link->GetChild(suffix_var_id); + + prefix_node->AddChild(var_id, make_shared<TrieNode>( + var_suffix_link, var_phrase, prefix_node->matchings)); +} + +vector<State> HieroCachingRuleFactory::ExtendState( + const vector<int>& word_ids, + const State& state, + vector<int> symbols, + const Phrase& phrase, + const shared_ptr<TrieNode>& node) { + int span = state.end - state.start; + vector<State> new_states; + if (symbols.size() >= max_rule_symbols || state.end + 1 >= word_ids.size() || + span >= max_rule_span) { + return new_states; + } + + new_states.push_back(State(state.start, state.end + 1, symbols, + state.subpatterns_start, node, state.starts_with_x)); + + int num_subpatterns = phrase.Arity() + (state.starts_with_x == 0); + if (symbols.size() + 1 >= max_rule_symbols || + phrase.Arity() >= max_nonterminals || + num_subpatterns >= max_chunks) { + return new_states; + } + + int var_id = vocabulary->GetNonterminalIndex(phrase.Arity() + 1); + symbols.push_back(var_id); + vector<int> subpatterns_start = state.subpatterns_start; + size_t i = state.end + 1 + min_gap_size; + while (i < word_ids.size() && i - state.start <= max_rule_span) { + subpatterns_start.push_back(i); + new_states.push_back(State(state.start, i, symbols, subpatterns_start, + node->GetChild(var_id), state.starts_with_x)); + subpatterns_start.pop_back(); + ++i; + } + + return new_states; +} diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h new file mode 100644 index 00000000..a39386a8 --- /dev/null +++ b/extractor/rule_factory.h @@ -0,0 +1,102 @@ +#ifndef _RULE_FACTORY_H_ +#define _RULE_FACTORY_H_ + +#include <memory> +#include <vector> + +#include "matchings_trie.h" +#include "phrase_builder.h" + +using namespace std; + +class Alignment; +class DataArray; +class Grammar; +class MatchingsFinder; +class FastIntersector; +class Intersector; +class Precomputation; +class Rule; +class RuleExtractor; +class Sampler; +class Scorer; +class State; +class SuffixArray; +class Vocabulary; + +class HieroCachingRuleFactory { + public: + HieroCachingRuleFactory( + shared_ptr<SuffixArray> source_suffix_array, + shared_ptr<DataArray> target_data_array, + shared_ptr<Alignment> alignment, + const shared_ptr<Vocabulary>& vocabulary, + shared_ptr<Precomputation> precomputation, + shared_ptr<Scorer> scorer, + int min_gap_size, + int max_rule_span, + int max_nonterminals, + int max_rule_symbols, + int max_samples, + bool use_fast_intersect, + bool use_beaza_yates, + bool require_tight_phrases); + + // For testing only. + HieroCachingRuleFactory( + shared_ptr<MatchingsFinder> finder, + shared_ptr<Intersector> intersector, + shared_ptr<FastIntersector> fast_intersector, + shared_ptr<PhraseBuilder> phrase_builder, + shared_ptr<RuleExtractor> rule_extractor, + shared_ptr<Vocabulary> vocabulary, + shared_ptr<Sampler> sampler, + shared_ptr<Scorer> scorer, + int min_gap_size, + int max_rule_span, + int max_nonterminals, + int max_chunks, + int max_rule_symbols, + bool use_fast_intersect); + + virtual ~HieroCachingRuleFactory(); + + virtual Grammar GetGrammar(const vector<int>& word_ids); + + protected: + HieroCachingRuleFactory(); + + private: + bool CannotHaveMatchings(shared_ptr<TrieNode> node, int word_id); + + bool RequiresLookup(shared_ptr<TrieNode> node, int word_id); + + void AddTrailingNonterminal(vector<int> symbols, + const Phrase& prefix, + const shared_ptr<TrieNode>& prefix_node, + bool starts_with_x); + + vector<State> ExtendState(const vector<int>& word_ids, + const State& state, + vector<int> symbols, + const Phrase& phrase, + const shared_ptr<TrieNode>& node); + + shared_ptr<MatchingsFinder> matchings_finder; + shared_ptr<Intersector> intersector; + shared_ptr<FastIntersector> fast_intersector; + MatchingsTrie trie; + shared_ptr<PhraseBuilder> phrase_builder; + shared_ptr<RuleExtractor> rule_extractor; + shared_ptr<Vocabulary> vocabulary; + shared_ptr<Sampler> sampler; + shared_ptr<Scorer> scorer; + int min_gap_size; + int max_rule_span; + int max_nonterminals; + int max_chunks; + int max_rule_symbols; + bool use_fast_intersect; +}; + +#endif diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc new file mode 100644 index 00000000..d329382a --- /dev/null +++ b/extractor/rule_factory_test.cc @@ -0,0 +1,146 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> +#include <vector> + +#include "grammar.h" +#include "mocks/mock_fast_intersector.h" +#include "mocks/mock_intersector.h" +#include "mocks/mock_matchings_finder.h" +#include "mocks/mock_rule_extractor.h" +#include "mocks/mock_sampler.h" +#include "mocks/mock_scorer.h" +#include "mocks/mock_vocabulary.h" +#include "phrase_builder.h" +#include "phrase_location.h" +#include "rule_factory.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class RuleFactoryTest : public Test { + protected: + virtual void SetUp() { + finder = make_shared<MockMatchingsFinder>(); + intersector = make_shared<MockIntersector>(); + fast_intersector = make_shared<MockFastIntersector>(); + + vocabulary = make_shared<MockVocabulary>(); + EXPECT_CALL(*vocabulary, GetTerminalValue(2)).WillRepeatedly(Return("a")); + EXPECT_CALL(*vocabulary, GetTerminalValue(3)).WillRepeatedly(Return("b")); + EXPECT_CALL(*vocabulary, GetTerminalValue(4)).WillRepeatedly(Return("c")); + + phrase_builder = make_shared<PhraseBuilder>(vocabulary); + + scorer = make_shared<MockScorer>(); + feature_names = {"f1"}; + EXPECT_CALL(*scorer, GetFeatureNames()) + .WillRepeatedly(Return(feature_names)); + + sampler = make_shared<MockSampler>(); + EXPECT_CALL(*sampler, Sample(_)) + .WillRepeatedly(Return(PhraseLocation(0, 1))); + + Phrase phrase; + vector<double> scores = {0.5}; + vector<pair<int, int> > phrase_alignment = {make_pair(0, 0)}; + vector<Rule> rules = {Rule(phrase, phrase, scores, phrase_alignment)}; + extractor = make_shared<MockRuleExtractor>(); + EXPECT_CALL(*extractor, ExtractRules(_, _)) + .WillRepeatedly(Return(rules)); + } + + vector<string> feature_names; + shared_ptr<MockMatchingsFinder> finder; + shared_ptr<MockIntersector> intersector; + shared_ptr<MockFastIntersector> fast_intersector; + shared_ptr<MockVocabulary> vocabulary; + shared_ptr<PhraseBuilder> phrase_builder; + shared_ptr<MockScorer> scorer; + shared_ptr<MockSampler> sampler; + shared_ptr<MockRuleExtractor> extractor; + shared_ptr<HieroCachingRuleFactory> factory; +}; + +TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) { + factory = make_shared<HieroCachingRuleFactory>(finder, intersector, + fast_intersector, phrase_builder, extractor, vocabulary, sampler, + scorer, 1, 10, 2, 3, 5, false); + + EXPECT_CALL(*finder, Find(_, _, _)) + .Times(6) + .WillRepeatedly(Return(PhraseLocation(0, 1))); + + EXPECT_CALL(*intersector, Intersect(_, _, _, _, _)) + .Times(1) + .WillRepeatedly(Return(PhraseLocation(0, 1))); + EXPECT_CALL(*fast_intersector, Intersect(_, _, _)).Times(0); + + vector<int> word_ids = {2, 3, 4}; + Grammar grammar = factory->GetGrammar(word_ids); + EXPECT_EQ(feature_names, grammar.GetFeatureNames()); + EXPECT_EQ(7, grammar.GetRules().size()); + + // Test for fast intersector. + factory = make_shared<HieroCachingRuleFactory>(finder, intersector, + fast_intersector, phrase_builder, extractor, vocabulary, sampler, + scorer, 1, 10, 2, 3, 5, true); + + EXPECT_CALL(*finder, Find(_, _, _)) + .Times(6) + .WillRepeatedly(Return(PhraseLocation(0, 1))); + + EXPECT_CALL(*fast_intersector, Intersect(_, _, _)) + .Times(1) + .WillRepeatedly(Return(PhraseLocation(0, 1))); + EXPECT_CALL(*intersector, Intersect(_, _, _, _, _)).Times(0); + + grammar = factory->GetGrammar(word_ids); + EXPECT_EQ(feature_names, grammar.GetFeatureNames()); + EXPECT_EQ(7, grammar.GetRules().size()); +} + +TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) { + factory = make_shared<HieroCachingRuleFactory>(finder, intersector, + fast_intersector, phrase_builder, extractor, vocabulary, sampler, + scorer, 1, 10, 2, 3, 5, false); + + EXPECT_CALL(*finder, Find(_, _, _)) + .Times(12) + .WillRepeatedly(Return(PhraseLocation(0, 1))); + + EXPECT_CALL(*intersector, Intersect(_, _, _, _, _)) + .Times(16) + .WillRepeatedly(Return(PhraseLocation(0, 1))); + + EXPECT_CALL(*fast_intersector, Intersect(_, _, _)).Times(0); + + vector<int> word_ids = {2, 3, 4, 2, 3}; + Grammar grammar = factory->GetGrammar(word_ids); + EXPECT_EQ(feature_names, grammar.GetFeatureNames()); + EXPECT_EQ(28, grammar.GetRules().size()); + + // Test for fast intersector. + factory = make_shared<HieroCachingRuleFactory>(finder, intersector, + fast_intersector, phrase_builder, extractor, vocabulary, sampler, + scorer, 1, 10, 2, 3, 5, true); + + EXPECT_CALL(*finder, Find(_, _, _)) + .Times(12) + .WillRepeatedly(Return(PhraseLocation(0, 1))); + + EXPECT_CALL(*fast_intersector, Intersect(_, _, _)) + .Times(16) + .WillRepeatedly(Return(PhraseLocation(0, 1))); + + EXPECT_CALL(*intersector, Intersect(_, _, _, _, _)).Times(0); + + grammar = factory->GetGrammar(word_ids); + EXPECT_EQ(feature_names, grammar.GetFeatureNames()); + EXPECT_EQ(28, grammar.GetRules().size()); +} + +} // namespace diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc new file mode 100644 index 00000000..38f10a5f --- /dev/null +++ b/extractor/run_extractor.cc @@ -0,0 +1,208 @@ +#include <chrono> +#include <fstream> +#include <iostream> +#include <string> +#include <vector> + +#include <boost/filesystem.hpp> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "alignment.h" +#include "data_array.h" +#include "features/count_source_target.h" +#include "features/feature.h" +#include "features/is_source_singleton.h" +#include "features/is_source_target_singleton.h" +#include "features/max_lex_source_given_target.h" +#include "features/max_lex_target_given_source.h" +#include "features/sample_source_count.h" +#include "features/target_given_source_coherent.h" +#include "grammar.h" +#include "grammar_extractor.h" +#include "precomputation.h" +#include "rule.h" +#include "scorer.h" +#include "suffix_array.h" +#include "time_util.h" +#include "translation_table.h" + +namespace fs = boost::filesystem; +namespace po = boost::program_options; +using namespace std; + +int main(int argc, char** argv) { + // TODO(pauldb): Also take arguments from config file. + po::options_description desc("Command line options"); + desc.add_options() + ("help,h", "Show available options") + ("source,f", po::value<string>(), "Source language corpus") + ("target,e", po::value<string>(), "Target language corpus") + ("bitext,b", po::value<string>(), "Parallel text (source ||| target)") + ("alignment,a", po::value<string>()->required(), "Bitext word alignment") + ("grammars,g", po::value<string>()->required(), "Grammars output path") + ("frequent", po::value<int>()->default_value(100), + "Number of precomputed frequent patterns") + ("super_frequent", po::value<int>()->default_value(10), + "Number of precomputed super frequent patterns") + ("max_rule_span", po::value<int>()->default_value(15), + "Maximum rule span") + ("max_rule_symbols,l", po::value<int>()->default_value(5), + "Maximum number of symbols (terminals + nontermals) in a rule") + ("min_gap_size", po::value<int>()->default_value(1), "Minimum gap size") + ("max_phrase_len", po::value<int>()->default_value(4), + "Maximum frequent phrase length") + ("max_nonterminals", po::value<int>()->default_value(2), + "Maximum number of nonterminals in a rule") + ("min_frequency", po::value<int>()->default_value(1000), + "Minimum number of occurences for a pharse to be considered frequent") + ("max_samples", po::value<int>()->default_value(300), + "Maximum number of samples") + ("fast_intersect", po::value<bool>()->default_value(false), + "Enable fast intersect") + // TODO(pauldb): Check if this works when set to false. + ("tight_phrases", po::value<bool>()->default_value(true), + "False if phrases may be loose (better, but slower)") + ("baeza_yates", po::value<bool>()->default_value(true), + "Use double binary search"); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + + // Check for help argument before notify, so we don't need to pass in the + // required parameters. + if (vm.count("help")) { + cout << desc << endl; + return 0; + } + + po::notify(vm); + + if (!((vm.count("source") && vm.count("target")) || vm.count("bitext"))) { + cerr << "A paralel corpus is required. " + << "Use -f (source) with -e (target) or -b (bitext)." + << endl; + return 1; + } + + Clock::time_point preprocess_start_time = Clock::now(); + cerr << "Reading source and target data..." << endl; + Clock::time_point start_time = Clock::now(); + shared_ptr<DataArray> source_data_array, target_data_array; + if (vm.count("bitext")) { + source_data_array = make_shared<DataArray>( + vm["bitext"].as<string>(), SOURCE); + target_data_array = make_shared<DataArray>( + vm["bitext"].as<string>(), TARGET); + } else { + source_data_array = make_shared<DataArray>(vm["source"].as<string>()); + target_data_array = make_shared<DataArray>(vm["target"].as<string>()); + } + Clock::time_point stop_time = Clock::now(); + cerr << "Reading data took " << GetDuration(start_time, stop_time) + << " seconds" << endl; + + cerr << "Creating source suffix array..." << endl; + start_time = Clock::now(); + shared_ptr<SuffixArray> source_suffix_array = + make_shared<SuffixArray>(source_data_array); + stop_time = Clock::now(); + cerr << "Creating suffix array took " + << GetDuration(start_time, stop_time) << " seconds" << endl; + + cerr << "Reading alignment..." << endl; + start_time = Clock::now(); + shared_ptr<Alignment> alignment = + make_shared<Alignment>(vm["alignment"].as<string>()); + stop_time = Clock::now(); + cerr << "Reading alignment took " + << GetDuration(start_time, stop_time) << " seconds" << endl; + + cerr << "Precomputating collocations..." << endl; + start_time = Clock::now(); + shared_ptr<Precomputation> precomputation = make_shared<Precomputation>( + source_suffix_array, + vm["frequent"].as<int>(), + vm["super_frequent"].as<int>(), + vm["max_rule_span"].as<int>(), + vm["max_rule_symbols"].as<int>(), + vm["min_gap_size"].as<int>(), + vm["max_phrase_len"].as<int>(), + vm["min_frequency"].as<int>()); + stop_time = Clock::now(); + cerr << "Precomputing collocations took " + << GetDuration(start_time, stop_time) << " seconds" << endl; + + cerr << "Precomputing conditional probabilities..." << endl; + start_time = Clock::now(); + shared_ptr<TranslationTable> table = make_shared<TranslationTable>( + source_data_array, target_data_array, alignment); + stop_time = Clock::now(); + cerr << "Precomputing conditional probabilities took " + << GetDuration(start_time, stop_time) << " seconds" << endl; + + Clock::time_point preprocess_stop_time = Clock::now(); + cerr << "Overall preprocessing step took " + << GetDuration(preprocess_start_time, preprocess_stop_time) + << " seconds" << endl; + + Clock::time_point extraction_start_time = Clock::now(); + vector<shared_ptr<Feature> > features = { + make_shared<TargetGivenSourceCoherent>(), + make_shared<SampleSourceCount>(), + make_shared<CountSourceTarget>(), + make_shared<MaxLexSourceGivenTarget>(table), + make_shared<MaxLexTargetGivenSource>(table), + make_shared<IsSourceSingleton>(), + make_shared<IsSourceTargetSingleton>() + }; + shared_ptr<Scorer> scorer = make_shared<Scorer>(features); + + // TODO(pauldb): Add parallelization. + GrammarExtractor extractor( + source_suffix_array, + target_data_array, + alignment, + precomputation, + scorer, + vm["min_gap_size"].as<int>(), + vm["max_rule_span"].as<int>(), + vm["max_nonterminals"].as<int>(), + vm["max_rule_symbols"].as<int>(), + vm["max_samples"].as<int>(), + vm["fast_intersect"].as<bool>(), + vm["baeza_yates"].as<bool>(), + vm["tight_phrases"].as<bool>()); + + int grammar_id = 0; + fs::path grammar_path = vm["grammars"].as<string>(); + if (!fs::is_directory(grammar_path)) { + fs::create_directory(grammar_path); + } + + string sentence, delimiter = "|||"; + while (getline(cin, sentence)) { + string suffix = ""; + int position = sentence.find(delimiter); + if (position != sentence.npos) { + suffix = sentence.substr(position); + sentence = sentence.substr(0, position); + } + + Grammar grammar = extractor.GetGrammar(sentence); + string file_name = "grammar." + to_string(grammar_id); + fs::path grammar_file = grammar_path / file_name; + ofstream output(grammar_file.c_str()); + output << grammar; + + cout << "<seg grammar=\"" << grammar_file << "\" id=\"" << grammar_id + << "\"> " << sentence << " </seg> " << suffix << endl; + ++grammar_id; + } + Clock::time_point extraction_stop_time = Clock::now(); + cerr << "Overall extraction step took " + << GetDuration(extraction_start_time, extraction_stop_time) + << " seconds" << endl; + + return 0; +} diff --git a/extractor/sample_alignment.txt b/extractor/sample_alignment.txt new file mode 100644 index 00000000..80b446a4 --- /dev/null +++ b/extractor/sample_alignment.txt @@ -0,0 +1,2 @@ +0-0 1-1 2-2 +1-0 2-1 diff --git a/extractor/sample_bitext.txt b/extractor/sample_bitext.txt new file mode 100644 index 00000000..93d6b39d --- /dev/null +++ b/extractor/sample_bitext.txt @@ -0,0 +1,2 @@ +ana are mere . ||| anna has apples . +ana bea mult lapte . ||| anna drinks a lot of milk . diff --git a/extractor/sampler.cc b/extractor/sampler.cc new file mode 100644 index 00000000..5067ca8a --- /dev/null +++ b/extractor/sampler.cc @@ -0,0 +1,41 @@ +#include "sampler.h" + +#include "phrase_location.h" +#include "suffix_array.h" + +Sampler::Sampler(shared_ptr<SuffixArray> suffix_array, int max_samples) : + suffix_array(suffix_array), max_samples(max_samples) {} + +Sampler::Sampler() {} + +Sampler::~Sampler() {} + +PhraseLocation Sampler::Sample(const PhraseLocation& location) const { + vector<int> sample; + int num_subpatterns; + if (location.matchings == NULL) { + num_subpatterns = 1; + int low = location.sa_low, high = location.sa_high; + double step = max(1.0, (double) (high - low) / max_samples); + for (double i = low; i < high && sample.size() < max_samples; i += step) { + sample.push_back(suffix_array->GetSuffix(Round(i))); + } + } else { + num_subpatterns = location.num_subpatterns; + int num_matchings = location.matchings->size() / num_subpatterns; + double step = max(1.0, (double) num_matchings / max_samples); + for (double i = 0, num_samples = 0; + i < num_matchings && num_samples < max_samples; + i += step, ++num_samples) { + int start = Round(i) * num_subpatterns; + sample.insert(sample.end(), location.matchings->begin() + start, + location.matchings->begin() + start + num_subpatterns); + } + } + return PhraseLocation(sample, num_subpatterns); +} + +int Sampler::Round(double x) const { + // TODO(pauldb): Remove EPS. + return x + 0.5 + 1e-8; +} diff --git a/extractor/sampler.h b/extractor/sampler.h new file mode 100644 index 00000000..9cf321fb --- /dev/null +++ b/extractor/sampler.h @@ -0,0 +1,29 @@ +#ifndef _SAMPLER_H_ +#define _SAMPLER_H_ + +#include <memory> + +using namespace std; + +class PhraseLocation; +class SuffixArray; + +class Sampler { + public: + Sampler(shared_ptr<SuffixArray> suffix_array, int max_samples); + + virtual ~Sampler(); + + virtual PhraseLocation Sample(const PhraseLocation& location) const; + + protected: + Sampler(); + + private: + int Round(double x) const; + + shared_ptr<SuffixArray> suffix_array; + int max_samples; +}; + +#endif diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc new file mode 100644 index 00000000..4f91965b --- /dev/null +++ b/extractor/sampler_test.cc @@ -0,0 +1,72 @@ +#include <gtest/gtest.h> + +#include <memory> + +#include "mocks/mock_suffix_array.h" +#include "phrase_location.h" +#include "sampler.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class SamplerTest : public Test { + protected: + virtual void SetUp() { + suffix_array = make_shared<MockSuffixArray>(); + for (int i = 0; i < 10; ++i) { + EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i)); + } + } + + shared_ptr<MockSuffixArray> suffix_array; + shared_ptr<Sampler> sampler; +}; + +TEST_F(SamplerTest, TestSuffixArrayRange) { + PhraseLocation location(0, 10); + + sampler = make_shared<Sampler>(suffix_array, 1); + vector<int> expected_locations = {0}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); + + sampler = make_shared<Sampler>(suffix_array, 2); + expected_locations = {0, 5}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); + + sampler = make_shared<Sampler>(suffix_array, 3); + expected_locations = {0, 3, 7}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); + + sampler = make_shared<Sampler>(suffix_array, 4); + expected_locations = {0, 3, 5, 8}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); + + sampler = make_shared<Sampler>(suffix_array, 100); + expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); +} + +TEST_F(SamplerTest, TestSubstringsSample) { + vector<int> locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + PhraseLocation location(locations, 2); + + sampler = make_shared<Sampler>(suffix_array, 1); + vector<int> expected_locations = {0, 1}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location)); + + sampler = make_shared<Sampler>(suffix_array, 2); + expected_locations = {0, 1, 6, 7}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location)); + + sampler = make_shared<Sampler>(suffix_array, 3); + expected_locations = {0, 1, 4, 5, 6, 7}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location)); + + sampler = make_shared<Sampler>(suffix_array, 7); + expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location)); +} + +} // namespace diff --git a/extractor/scorer.cc b/extractor/scorer.cc new file mode 100644 index 00000000..f28b3181 --- /dev/null +++ b/extractor/scorer.cc @@ -0,0 +1,26 @@ +#include "scorer.h" + +#include "features/feature.h" + +Scorer::Scorer(const vector<shared_ptr<Feature> >& features) : + features(features) {} + +Scorer::Scorer() {} + +Scorer::~Scorer() {} + +vector<double> Scorer::Score(const FeatureContext& context) const { + vector<double> scores; + for (auto feature: features) { + scores.push_back(feature->Score(context)); + } + return scores; +} + +vector<string> Scorer::GetFeatureNames() const { + vector<string> feature_names; + for (auto feature: features) { + feature_names.push_back(feature->GetName()); + } + return feature_names; +} diff --git a/extractor/scorer.h b/extractor/scorer.h new file mode 100644 index 00000000..ba71a6ee --- /dev/null +++ b/extractor/scorer.h @@ -0,0 +1,30 @@ +#ifndef _SCORER_H_ +#define _SCORER_H_ + +#include <memory> +#include <string> +#include <vector> + +using namespace std; + +class Feature; +class FeatureContext; + +class Scorer { + public: + Scorer(const vector<shared_ptr<Feature> >& features); + + virtual ~Scorer(); + + virtual vector<double> Score(const FeatureContext& context) const; + + virtual vector<string> GetFeatureNames() const; + + protected: + Scorer(); + + private: + vector<shared_ptr<Feature> > features; +}; + +#endif diff --git a/extractor/scorer_test.cc b/extractor/scorer_test.cc new file mode 100644 index 00000000..56a85762 --- /dev/null +++ b/extractor/scorer_test.cc @@ -0,0 +1,47 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> +#include <vector> + +#include "mocks/mock_feature.h" +#include "scorer.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class ScorerTest : public Test { + protected: + virtual void SetUp() { + feature1 = make_shared<MockFeature>(); + EXPECT_CALL(*feature1, Score(_)).WillRepeatedly(Return(0.5)); + EXPECT_CALL(*feature1, GetName()).WillRepeatedly(Return("f1")); + + feature2 = make_shared<MockFeature>(); + EXPECT_CALL(*feature2, Score(_)).WillRepeatedly(Return(-1.3)); + EXPECT_CALL(*feature2, GetName()).WillRepeatedly(Return("f2")); + + vector<shared_ptr<Feature> > features = {feature1, feature2}; + scorer = make_shared<Scorer>(features); + } + + shared_ptr<MockFeature> feature1; + shared_ptr<MockFeature> feature2; + shared_ptr<Scorer> scorer; +}; + +TEST_F(ScorerTest, TestScore) { + vector<double> expected_scores = {0.5, -1.3}; + Phrase phrase; + FeatureContext context(phrase, phrase, 0.3, 2, 11); + EXPECT_EQ(expected_scores, scorer->Score(context)); +} + +TEST_F(ScorerTest, TestGetNames) { + vector<string> expected_names = {"f1", "f2"}; + EXPECT_EQ(expected_names, scorer->GetFeatureNames()); +} + +} // namespace diff --git a/extractor/suffix_array.cc b/extractor/suffix_array.cc new file mode 100644 index 00000000..23c458a4 --- /dev/null +++ b/extractor/suffix_array.cc @@ -0,0 +1,229 @@ +#include "suffix_array.h" + +#include <chrono> +#include <iostream> +#include <string> +#include <vector> + +#include "data_array.h" +#include "phrase_location.h" +#include "time_util.h" + +namespace fs = boost::filesystem; +using namespace std; +using namespace chrono; + +SuffixArray::SuffixArray(shared_ptr<DataArray> data_array) : + data_array(data_array) { + BuildSuffixArray(); +} + +SuffixArray::SuffixArray() {} + +SuffixArray::~SuffixArray() {} + +void SuffixArray::BuildSuffixArray() { + vector<int> groups = data_array->GetData(); + groups.reserve(groups.size() + 1); + groups.push_back(DataArray::NULL_WORD); + suffix_array.resize(groups.size()); + word_start.resize(data_array->GetVocabularySize() + 1); + + InitialBucketSort(groups); + + int combined_group_size = 0; + for (size_t i = 1; i < word_start.size(); ++i) { + if (word_start[i] - word_start[i - 1] == 1) { + ++combined_group_size; + suffix_array[word_start[i] - combined_group_size] = -combined_group_size; + } else { + combined_group_size = 0; + } + } + + PrefixDoublingSort(groups); + cerr << "\tFinalizing sort..." << endl; + + for (size_t i = 0; i < groups.size(); ++i) { + suffix_array[groups[i]] = i; + } +} + +void SuffixArray::InitialBucketSort(vector<int>& groups) { + Clock::time_point start_time = Clock::now(); + for (size_t i = 0; i < groups.size(); ++i) { + ++word_start[groups[i]]; + } + + for (size_t i = 1; i < word_start.size(); ++i) { + word_start[i] += word_start[i - 1]; + } + + for (size_t i = 0; i < groups.size(); ++i) { + --word_start[groups[i]]; + suffix_array[word_start[groups[i]]] = i; + } + + for (size_t i = 0; i < suffix_array.size(); ++i) { + groups[i] = word_start[groups[i] + 1] - 1; + } + Clock::time_point stop_time = Clock::now(); + cerr << "\tBucket sort took " << GetDuration(start_time, stop_time) + << " seconds" << endl; +} + +void SuffixArray::PrefixDoublingSort(vector<int>& groups) { + int step = 1; + while (suffix_array[0] != -suffix_array.size()) { + int combined_group_size = 0; + int i = 0; + while (i < suffix_array.size()) { + if (suffix_array[i] < 0) { + int skip = -suffix_array[i]; + combined_group_size += skip; + i += skip; + suffix_array[i - combined_group_size] = -combined_group_size; + } else { + combined_group_size = 0; + int j = groups[suffix_array[i]]; + TernaryQuicksort(i, j, step, groups); + i = j + 1; + } + } + step *= 2; + } +} + +void SuffixArray::TernaryQuicksort(int left, int right, int step, + vector<int>& groups) { + if (left > right) { + return; + } + + int pivot = left + rand() % (right - left + 1); + int pivot_value = groups[suffix_array[pivot] + step]; + swap(suffix_array[pivot], suffix_array[left]); + int mid_left = left, mid_right = left; + for (int i = left + 1; i <= right; ++i) { + if (groups[suffix_array[i] + step] < pivot_value) { + ++mid_right; + int temp = suffix_array[i]; + suffix_array[i] = suffix_array[mid_right]; + suffix_array[mid_right] = suffix_array[mid_left]; + suffix_array[mid_left] = temp; + ++mid_left; + } else if (groups[suffix_array[i] + step] == pivot_value) { + ++mid_right; + int temp = suffix_array[i]; + suffix_array[i] = suffix_array[mid_right]; + suffix_array[mid_right] = temp; + } + } + + TernaryQuicksort(left, mid_left - 1, step, groups); + + if (mid_left == mid_right) { + groups[suffix_array[mid_left]] = mid_left; + suffix_array[mid_left] = -1; + } else { + for (int i = mid_left; i <= mid_right; ++i) { + groups[suffix_array[i]] = mid_right; + } + } + + TernaryQuicksort(mid_right + 1, right, step, groups); +} + +vector<int> SuffixArray::BuildLCPArray() const { + Clock::time_point start_time = Clock::now(); + cerr << "Constructing LCP array..." << endl; + + vector<int> lcp(suffix_array.size()); + vector<int> rank(suffix_array.size()); + const vector<int>& data = data_array->GetData(); + + for (size_t i = 0; i < suffix_array.size(); ++i) { + rank[suffix_array[i]] = i; + } + + int prefix_len = 0; + for (size_t i = 0; i < suffix_array.size(); ++i) { + if (rank[i] == 0) { + lcp[rank[i]] = -1; + } else { + int j = suffix_array[rank[i] - 1]; + while (i + prefix_len < data.size() && j + prefix_len < data.size() + && data[i + prefix_len] == data[j + prefix_len]) { + ++prefix_len; + } + lcp[rank[i]] = prefix_len; + } + + if (prefix_len > 0) { + --prefix_len; + } + } + + Clock::time_point stop_time = Clock::now(); + cerr << "Constructing LCP took " + << GetDuration(start_time, stop_time) << " seconds" << endl; + + return lcp; +} + +int SuffixArray::GetSuffix(int rank) const { + return suffix_array[rank]; +} + +int SuffixArray::GetSize() const { + return suffix_array.size(); +} + +shared_ptr<DataArray> SuffixArray::GetData() const { + return data_array; +} + +void SuffixArray::WriteBinary(const fs::path& filepath) const { + FILE* file = fopen(filepath.string().c_str(), "r"); + data_array->WriteBinary(file); + + int size = suffix_array.size(); + fwrite(&size, sizeof(int), 1, file); + fwrite(suffix_array.data(), sizeof(int), size, file); + + size = word_start.size(); + fwrite(&size, sizeof(int), 1, file); + fwrite(word_start.data(), sizeof(int), size, file); +} + +PhraseLocation SuffixArray::Lookup(int low, int high, const string& word, + int offset) const { + if (!data_array->HasWord(word)) { + // Return empty phrase location. + return PhraseLocation(0, 0); + } + + int word_id = data_array->GetWordId(word); + if (offset == 0) { + return PhraseLocation(word_start[word_id], word_start[word_id + 1]); + } + + return PhraseLocation(LookupRangeStart(low, high, word_id, offset), + LookupRangeStart(low, high, word_id + 1, offset)); +} + +int SuffixArray::LookupRangeStart(int low, int high, int word_id, + int offset) const { + int result = high; + while (low < high) { + int middle = low + (high - low) / 2; + if (suffix_array[middle] + offset >= data_array->GetSize() || + data_array->AtIndex(suffix_array[middle] + offset) < word_id) { + low = middle + 1; + } else { + result = middle; + high = middle; + } + } + return result; +} diff --git a/extractor/suffix_array.h b/extractor/suffix_array.h new file mode 100644 index 00000000..79a22694 --- /dev/null +++ b/extractor/suffix_array.h @@ -0,0 +1,54 @@ +#ifndef _SUFFIX_ARRAY_H_ +#define _SUFFIX_ARRAY_H_ + +#include <memory> +#include <string> +#include <vector> + +#include <boost/filesystem.hpp> + +namespace fs = boost::filesystem; +using namespace std; + +class DataArray; +class PhraseLocation; + +class SuffixArray { + public: + SuffixArray(shared_ptr<DataArray> data_array); + + virtual ~SuffixArray(); + + virtual int GetSize() const; + + virtual shared_ptr<DataArray> GetData() const; + + virtual vector<int> BuildLCPArray() const; + + virtual int GetSuffix(int rank) const; + + virtual PhraseLocation Lookup(int low, int high, const string& word, + int offset) const; + + void WriteBinary(const fs::path& filepath) const; + + protected: + SuffixArray(); + + private: + void BuildSuffixArray(); + + void InitialBucketSort(vector<int>& groups); + + void TernaryQuicksort(int left, int right, int step, vector<int>& groups); + + void PrefixDoublingSort(vector<int>& groups); + + int LookupRangeStart(int low, int high, int word_id, int offset) const; + + shared_ptr<DataArray> data_array; + vector<int> suffix_array; + vector<int> word_start; +}; + +#endif diff --git a/extractor/suffix_array_test.cc b/extractor/suffix_array_test.cc new file mode 100644 index 00000000..60295567 --- /dev/null +++ b/extractor/suffix_array_test.cc @@ -0,0 +1,76 @@ +#include <gtest/gtest.h> + +#include "mocks/mock_data_array.h" +#include "phrase_location.h" +#include "suffix_array.h" + +#include <vector> + +using namespace std; +using namespace ::testing; + +namespace { + +class SuffixArrayTest : public Test { + protected: + virtual void SetUp() { + data = {6, 4, 1, 2, 4, 5, 3, 4, 6, 6, 4, 1, 2}; + data_array = make_shared<MockDataArray>(); + EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data)); + EXPECT_CALL(*data_array, GetVocabularySize()).WillRepeatedly(Return(7)); + EXPECT_CALL(*data_array, GetSize()).WillRepeatedly(Return(13)); + suffix_array = make_shared<SuffixArray>(data_array); + } + + vector<int> data; + shared_ptr<SuffixArray> suffix_array; + shared_ptr<MockDataArray> data_array; +}; + +TEST_F(SuffixArrayTest, TestData) { + EXPECT_EQ(data_array, suffix_array->GetData()); + EXPECT_EQ(14, suffix_array->GetSize()); +} + +TEST_F(SuffixArrayTest, TestBuildSuffixArray) { + vector<int> expected_suffix_array = + {13, 11, 2, 12, 3, 6, 10, 1, 4, 7, 5, 9, 0, 8}; + for (size_t i = 0; i < expected_suffix_array.size(); ++i) { + EXPECT_EQ(expected_suffix_array[i], suffix_array->GetSuffix(i)); + } +} + +TEST_F(SuffixArrayTest, TestBuildLCP) { + vector<int> expected_lcp = {-1, 0, 2, 0, 1, 0, 0, 3, 1, 1, 0, 0, 4, 1}; + EXPECT_EQ(expected_lcp, suffix_array->BuildLCPArray()); +} + +TEST_F(SuffixArrayTest, TestLookup) { + for (size_t i = 0; i < data.size(); ++i) { + EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i])); + } + + EXPECT_CALL(*data_array, HasWord("word1")).WillRepeatedly(Return(true)); + EXPECT_CALL(*data_array, GetWordId("word1")).WillRepeatedly(Return(6)); + EXPECT_EQ(PhraseLocation(11, 14), suffix_array->Lookup(0, 14, "word1", 0)); + + EXPECT_CALL(*data_array, HasWord("word2")).WillRepeatedly(Return(false)); + EXPECT_EQ(PhraseLocation(0, 0), suffix_array->Lookup(0, 14, "word2", 0)); + + EXPECT_CALL(*data_array, HasWord("word3")).WillRepeatedly(Return(true)); + EXPECT_CALL(*data_array, GetWordId("word3")).WillRepeatedly(Return(4)); + EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 14, "word3", 1)); + + EXPECT_CALL(*data_array, HasWord("word4")).WillRepeatedly(Return(true)); + EXPECT_CALL(*data_array, GetWordId("word4")).WillRepeatedly(Return(1)); + EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 13, "word4", 2)); + + EXPECT_CALL(*data_array, HasWord("word5")).WillRepeatedly(Return(true)); + EXPECT_CALL(*data_array, GetWordId("word5")).WillRepeatedly(Return(2)); + EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 13, "word5", 3)); + + EXPECT_EQ(PhraseLocation(12, 13), suffix_array->Lookup(11, 13, "word3", 4)); + EXPECT_EQ(PhraseLocation(11, 11), suffix_array->Lookup(11, 13, "word5", 1)); +} + +} // namespace diff --git a/extractor/target_phrase_extractor.cc b/extractor/target_phrase_extractor.cc new file mode 100644 index 00000000..ac583953 --- /dev/null +++ b/extractor/target_phrase_extractor.cc @@ -0,0 +1,144 @@ +#include "target_phrase_extractor.h" + +#include <unordered_set> + +#include "alignment.h" +#include "data_array.h" +#include "phrase.h" +#include "phrase_builder.h" +#include "rule_extractor_helper.h" +#include "vocabulary.h" + +using namespace std; + +TargetPhraseExtractor::TargetPhraseExtractor( + shared_ptr<DataArray> target_data_array, + shared_ptr<Alignment> alignment, + shared_ptr<PhraseBuilder> phrase_builder, + shared_ptr<RuleExtractorHelper> helper, + shared_ptr<Vocabulary> vocabulary, + int max_rule_span, + bool require_tight_phrases) : + target_data_array(target_data_array), + alignment(alignment), + phrase_builder(phrase_builder), + helper(helper), + vocabulary(vocabulary), + max_rule_span(max_rule_span), + require_tight_phrases(require_tight_phrases) {} + +TargetPhraseExtractor::TargetPhraseExtractor() {} + +TargetPhraseExtractor::~TargetPhraseExtractor() {} + +vector<pair<Phrase, PhraseAlignment> > TargetPhraseExtractor::ExtractPhrases( + const vector<pair<int, int> >& target_gaps, const vector<int>& target_low, + int target_phrase_low, int target_phrase_high, + const unordered_map<int, int>& source_indexes, int sentence_id) const { + int target_sent_len = target_data_array->GetSentenceLength(sentence_id); + + vector<int> target_gap_order = helper->GetGapOrder(target_gaps); + + int target_x_low = target_phrase_low, target_x_high = target_phrase_high; + if (!require_tight_phrases) { + while (target_x_low > 0 && + target_phrase_high - target_x_low < max_rule_span && + target_low[target_x_low - 1] == -1) { + --target_x_low; + } + while (target_x_high < target_sent_len && + target_x_high - target_phrase_low < max_rule_span && + target_low[target_x_high] == -1) { + ++target_x_high; + } + } + + vector<pair<int, int> > gaps(target_gaps.size()); + for (size_t i = 0; i < gaps.size(); ++i) { + gaps[i] = target_gaps[target_gap_order[i]]; + if (!require_tight_phrases) { + while (gaps[i].first > target_x_low && + target_low[gaps[i].first - 1] == -1) { + --gaps[i].first; + } + while (gaps[i].second < target_x_high && + target_low[gaps[i].second] == -1) { + ++gaps[i].second; + } + } + } + + vector<pair<int, int> > ranges(2 * gaps.size() + 2); + ranges.front() = make_pair(target_x_low, target_phrase_low); + ranges.back() = make_pair(target_phrase_high, target_x_high); + for (size_t i = 0; i < gaps.size(); ++i) { + int j = target_gap_order[i]; + ranges[i * 2 + 1] = make_pair(gaps[i].first, target_gaps[j].first); + ranges[i * 2 + 2] = make_pair(target_gaps[j].second, gaps[i].second); + } + + vector<pair<Phrase, PhraseAlignment> > target_phrases; + vector<int> subpatterns(ranges.size()); + GeneratePhrases(target_phrases, ranges, 0, subpatterns, target_gap_order, + target_phrase_low, target_phrase_high, source_indexes, + sentence_id); + return target_phrases; +} + +void TargetPhraseExtractor::GeneratePhrases( + vector<pair<Phrase, PhraseAlignment> >& target_phrases, + const vector<pair<int, int> >& ranges, int index, vector<int>& subpatterns, + const vector<int>& target_gap_order, int target_phrase_low, + int target_phrase_high, const unordered_map<int, int>& source_indexes, + int sentence_id) const { + if (index >= ranges.size()) { + if (subpatterns.back() - subpatterns.front() > max_rule_span) { + return; + } + + vector<int> symbols; + unordered_map<int, int> target_indexes; + + int target_sent_start = target_data_array->GetSentenceStart(sentence_id); + for (size_t i = 0; i * 2 < subpatterns.size(); ++i) { + for (size_t j = subpatterns[i * 2]; j < subpatterns[i * 2 + 1]; ++j) { + target_indexes[j] = symbols.size(); + string target_word = target_data_array->GetWordAtIndex( + target_sent_start + j); + symbols.push_back(vocabulary->GetTerminalIndex(target_word)); + } + if (i < target_gap_order.size()) { + symbols.push_back(vocabulary->GetNonterminalIndex( + target_gap_order[i] + 1)); + } + } + + vector<pair<int, int> > links = alignment->GetLinks(sentence_id); + vector<pair<int, int> > alignment; + for (pair<int, int> link: links) { + if (target_indexes.count(link.second)) { + alignment.push_back(make_pair(source_indexes.find(link.first)->second, + target_indexes[link.second])); + } + } + + Phrase target_phrase = phrase_builder->Build(symbols); + target_phrases.push_back(make_pair(target_phrase, alignment)); + return; + } + + subpatterns[index] = ranges[index].first; + if (index > 0) { + subpatterns[index] = max(subpatterns[index], subpatterns[index - 1]); + } + while (subpatterns[index] <= ranges[index].second) { + subpatterns[index + 1] = max(subpatterns[index], ranges[index + 1].first); + while (subpatterns[index + 1] <= ranges[index + 1].second) { + GeneratePhrases(target_phrases, ranges, index + 2, subpatterns, + target_gap_order, target_phrase_low, target_phrase_high, + source_indexes, sentence_id); + ++subpatterns[index + 1]; + } + ++subpatterns[index]; + } +} diff --git a/extractor/target_phrase_extractor.h b/extractor/target_phrase_extractor.h new file mode 100644 index 00000000..134f24cc --- /dev/null +++ b/extractor/target_phrase_extractor.h @@ -0,0 +1,56 @@ +#ifndef _TARGET_PHRASE_EXTRACTOR_H_ +#define _TARGET_PHRASE_EXTRACTOR_H_ + +#include <memory> +#include <unordered_map> +#include <vector> + +using namespace std; + +class Alignment; +class DataArray; +class Phrase; +class PhraseBuilder; +class RuleExtractorHelper; +class Vocabulary; + +typedef vector<pair<int, int> > PhraseAlignment; + +class TargetPhraseExtractor { + public: + TargetPhraseExtractor(shared_ptr<DataArray> target_data_array, + shared_ptr<Alignment> alignment, + shared_ptr<PhraseBuilder> phrase_builder, + shared_ptr<RuleExtractorHelper> helper, + shared_ptr<Vocabulary> vocabulary, + int max_rule_span, + bool require_tight_phrases); + + virtual ~TargetPhraseExtractor(); + + virtual vector<pair<Phrase, PhraseAlignment> > ExtractPhrases( + const vector<pair<int, int> >& target_gaps, const vector<int>& target_low, + int target_phrase_low, int target_phrase_high, + const unordered_map<int, int>& source_indexes, int sentence_id) const; + + protected: + TargetPhraseExtractor(); + + private: + void GeneratePhrases( + vector<pair<Phrase, PhraseAlignment> >& target_phrases, + const vector<pair<int, int> >& ranges, int index, + vector<int>& subpatterns, const vector<int>& target_gap_order, + int target_phrase_low, int target_phrase_high, + const unordered_map<int, int>& source_indexes, int sentence_id) const; + + shared_ptr<DataArray> target_data_array; + shared_ptr<Alignment> alignment; + shared_ptr<PhraseBuilder> phrase_builder; + shared_ptr<RuleExtractorHelper> helper; + shared_ptr<Vocabulary> vocabulary; + int max_rule_span; + bool require_tight_phrases; +}; + +#endif diff --git a/extractor/target_phrase_extractor_test.cc b/extractor/target_phrase_extractor_test.cc new file mode 100644 index 00000000..7394f4d9 --- /dev/null +++ b/extractor/target_phrase_extractor_test.cc @@ -0,0 +1,116 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <vector> + +#include "mocks/mock_alignment.h" +#include "mocks/mock_data_array.h" +#include "mocks/mock_rule_extractor_helper.h" +#include "mocks/mock_vocabulary.h" +#include "phrase.h" +#include "phrase_builder.h" +#include "target_phrase_extractor.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class TargetPhraseExtractorTest : public Test { + protected: + virtual void SetUp() { + data_array = make_shared<MockDataArray>(); + alignment = make_shared<MockAlignment>(); + vocabulary = make_shared<MockVocabulary>(); + phrase_builder = make_shared<PhraseBuilder>(vocabulary); + helper = make_shared<MockRuleExtractorHelper>(); + } + + shared_ptr<MockDataArray> data_array; + shared_ptr<MockAlignment> alignment; + shared_ptr<MockVocabulary> vocabulary; + shared_ptr<PhraseBuilder> phrase_builder; + shared_ptr<MockRuleExtractorHelper> helper; + shared_ptr<TargetPhraseExtractor> extractor; +}; + +TEST_F(TargetPhraseExtractorTest, TestExtractTightPhrasesTrue) { + EXPECT_CALL(*data_array, GetSentenceLength(1)).WillRepeatedly(Return(5)); + EXPECT_CALL(*data_array, GetSentenceStart(1)).WillRepeatedly(Return(3)); + + vector<string> target_words = {"a", "b", "c", "d", "e"}; + vector<int> target_symbols = {20, 21, 22, 23, 24}; + for (size_t i = 0; i < target_words.size(); ++i) { + EXPECT_CALL(*data_array, GetWordAtIndex(i + 3)) + .WillRepeatedly(Return(target_words[i])); + EXPECT_CALL(*vocabulary, GetTerminalIndex(target_words[i])) + .WillRepeatedly(Return(target_symbols[i])); + EXPECT_CALL(*vocabulary, GetTerminalValue(target_symbols[i])) + .WillRepeatedly(Return(target_words[i])); + } + + vector<pair<int, int> > links = { + make_pair(0, 0), make_pair(1, 3), make_pair(2, 2), make_pair(3, 1), + make_pair(4, 4) + }; + EXPECT_CALL(*alignment, GetLinks(1)).WillRepeatedly(Return(links)); + + vector<int> gap_order = {1, 0}; + EXPECT_CALL(*helper, GetGapOrder(_)).WillRepeatedly(Return(gap_order)); + + extractor = make_shared<TargetPhraseExtractor>( + data_array, alignment, phrase_builder, helper, vocabulary, 10, true); + + vector<pair<int, int> > target_gaps = {make_pair(3, 4), make_pair(1, 2)}; + vector<int> target_low = {0, 3, 2, 1, 4}; + unordered_map<int, int> source_indexes = {{0, 0}, {2, 2}, {4, 4}}; + + vector<pair<Phrase, PhraseAlignment> > results = extractor->ExtractPhrases( + target_gaps, target_low, 0, 5, source_indexes, 1); + EXPECT_EQ(1, results.size()); + vector<int> expected_symbols = {20, -2, 22, -1, 24}; + EXPECT_EQ(expected_symbols, results[0].first.Get()); + vector<string> expected_words = {"a", "c", "e"}; + EXPECT_EQ(expected_words, results[0].first.GetWords()); + vector<pair<int, int> > expected_alignment = { + make_pair(0, 0), make_pair(2, 2), make_pair(4, 4) + }; + EXPECT_EQ(expected_alignment, results[0].second); +} + +TEST_F(TargetPhraseExtractorTest, TestExtractPhrasesTightPhrasesFalse) { + vector<string> target_words = {"a", "b", "c", "d", "e", "f"}; + vector<int> target_symbols = {20, 21, 22, 23, 24, 25, 26}; + EXPECT_CALL(*data_array, GetSentenceLength(0)).WillRepeatedly(Return(6)); + EXPECT_CALL(*data_array, GetSentenceStart(0)).WillRepeatedly(Return(0)); + + for (size_t i = 0; i < target_words.size(); ++i) { + EXPECT_CALL(*data_array, GetWordAtIndex(i)) + .WillRepeatedly(Return(target_words[i])); + EXPECT_CALL(*vocabulary, GetTerminalIndex(target_words[i])) + .WillRepeatedly(Return(target_symbols[i])); + EXPECT_CALL(*vocabulary, GetTerminalValue(target_symbols[i])) + .WillRepeatedly(Return(target_words[i])); + } + + vector<pair<int, int> > links = {make_pair(1, 1)}; + EXPECT_CALL(*alignment, GetLinks(0)).WillRepeatedly(Return(links)); + + vector<int> gap_order = {0}; + EXPECT_CALL(*helper, GetGapOrder(_)).WillRepeatedly(Return(gap_order)); + + extractor = make_shared<TargetPhraseExtractor>( + data_array, alignment, phrase_builder, helper, vocabulary, 10, false); + + vector<pair<int, int> > target_gaps = {make_pair(2, 4)}; + vector<int> target_low = {-1, 1, -1, -1, -1, -1}; + unordered_map<int, int> source_indexes = {{1, 1}}; + + vector<pair<Phrase, PhraseAlignment> > results = extractor->ExtractPhrases( + target_gaps, target_low, 1, 5, source_indexes, 0); + EXPECT_EQ(10, results.size()); + // TODO(pauldb): Finish unit test once it's clear how these alignments should + // look like. +} + +} // namespace diff --git a/extractor/time_util.cc b/extractor/time_util.cc new file mode 100644 index 00000000..88395f77 --- /dev/null +++ b/extractor/time_util.cc @@ -0,0 +1,6 @@ +#include "time_util.h" + +double GetDuration(const Clock::time_point& start_time, + const Clock::time_point& stop_time) { + return duration_cast<milliseconds>(stop_time - start_time).count() / 1000.0; +} diff --git a/extractor/time_util.h b/extractor/time_util.h new file mode 100644 index 00000000..6f7eda70 --- /dev/null +++ b/extractor/time_util.h @@ -0,0 +1,14 @@ +#ifndef _TIME_UTIL_H_ +#define _TIME_UTIL_H_ + +#include <chrono> + +using namespace std; +using namespace chrono; + +typedef high_resolution_clock Clock; + +double GetDuration(const Clock::time_point& start_time, + const Clock::time_point& stop_time); + +#endif diff --git a/extractor/translation_table.cc b/extractor/translation_table.cc new file mode 100644 index 00000000..a48c0657 --- /dev/null +++ b/extractor/translation_table.cc @@ -0,0 +1,117 @@ +#include "translation_table.h" + +#include <string> +#include <vector> + +#include <boost/functional/hash.hpp> + +#include "alignment.h" +#include "data_array.h" + +using namespace std; + +TranslationTable::TranslationTable(shared_ptr<DataArray> source_data_array, + shared_ptr<DataArray> target_data_array, + shared_ptr<Alignment> alignment) : + source_data_array(source_data_array), target_data_array(target_data_array) { + const vector<int>& source_data = source_data_array->GetData(); + const vector<int>& target_data = target_data_array->GetData(); + + unordered_map<int, int> source_links_count; + unordered_map<int, int> target_links_count; + unordered_map<pair<int, int>, int, PairHash> links_count; + + for (size_t i = 0; i < source_data_array->GetNumSentences(); ++i) { + vector<pair<int, int> > links = alignment->GetLinks(i); + int source_start = source_data_array->GetSentenceStart(i); + int target_start = target_data_array->GetSentenceStart(i); + // Ignore END_OF_LINE markers. + int next_source_start = source_data_array->GetSentenceStart(i + 1) - 1; + int next_target_start = target_data_array->GetSentenceStart(i + 1) - 1; + vector<int> source_sentence(source_data.begin() + source_start, + source_data.begin() + next_source_start); + vector<int> target_sentence(target_data.begin() + target_start, + target_data.begin() + next_target_start); + vector<int> source_linked_words(source_sentence.size()); + vector<int> target_linked_words(target_sentence.size()); + + for (pair<int, int> link: links) { + source_linked_words[link.first] = 1; + target_linked_words[link.second] = 1; + IncreaseLinksCount(source_links_count, target_links_count, links_count, + source_sentence[link.first], target_sentence[link.second]); + } + + for (size_t i = 0; i < source_sentence.size(); ++i) { + if (!source_linked_words[i]) { + IncreaseLinksCount(source_links_count, target_links_count, links_count, + source_sentence[i], DataArray::NULL_WORD); + } + } + + for (size_t i = 0; i < target_sentence.size(); ++i) { + if (!target_linked_words[i]) { + IncreaseLinksCount(source_links_count, target_links_count, links_count, + DataArray::NULL_WORD, target_sentence[i]); + } + } + } + + for (pair<pair<int, int>, int> link_count: links_count) { + int source_word = link_count.first.first; + int target_word = link_count.first.second; + double score1 = 1.0 * link_count.second / source_links_count[source_word]; + double score2 = 1.0 * link_count.second / target_links_count[target_word]; + translation_probabilities[link_count.first] = make_pair(score1, score2); + } +} + +TranslationTable::TranslationTable() {} + +TranslationTable::~TranslationTable() {} + +void TranslationTable::IncreaseLinksCount( + unordered_map<int, int>& source_links_count, + unordered_map<int, int>& target_links_count, + unordered_map<pair<int, int>, int, PairHash>& links_count, + int source_word_id, + int target_word_id) const { + ++source_links_count[source_word_id]; + ++target_links_count[target_word_id]; + ++links_count[make_pair(source_word_id, target_word_id)]; +} + +double TranslationTable::GetTargetGivenSourceScore( + const string& source_word, const string& target_word) { + if (!source_data_array->HasWord(source_word) || + !target_data_array->HasWord(target_word)) { + return -1; + } + + int source_id = source_data_array->GetWordId(source_word); + int target_id = target_data_array->GetWordId(target_word); + return translation_probabilities[make_pair(source_id, target_id)].first; +} + +double TranslationTable::GetSourceGivenTargetScore( + const string& source_word, const string& target_word) { + if (!source_data_array->HasWord(source_word) || + !target_data_array->HasWord(target_word)) { + return -1; + } + + int source_id = source_data_array->GetWordId(source_word); + int target_id = target_data_array->GetWordId(target_word); + return translation_probabilities[make_pair(source_id, target_id)].second; +} + +void TranslationTable::WriteBinary(const fs::path& filepath) const { + FILE* file = fopen(filepath.string().c_str(), "w"); + + int size = translation_probabilities.size(); + fwrite(&size, sizeof(int), 1, file); + for (auto entry: translation_probabilities) { + fwrite(&entry.first, sizeof(entry.first), 1, file); + fwrite(&entry.second, sizeof(entry.second), 1, file); + } +} diff --git a/extractor/translation_table.h b/extractor/translation_table.h new file mode 100644 index 00000000..157ad3af --- /dev/null +++ b/extractor/translation_table.h @@ -0,0 +1,53 @@ +#ifndef _TRANSLATION_TABLE_ +#define _TRANSLATION_TABLE_ + +#include <memory> +#include <string> +#include <unordered_map> + +#include <boost/filesystem.hpp> +#include <boost/functional/hash.hpp> + +using namespace std; +namespace fs = boost::filesystem; + +class Alignment; +class DataArray; + +typedef boost::hash<pair<int, int> > PairHash; + +class TranslationTable { + public: + TranslationTable( + shared_ptr<DataArray> source_data_array, + shared_ptr<DataArray> target_data_array, + shared_ptr<Alignment> alignment); + + virtual ~TranslationTable(); + + virtual double GetTargetGivenSourceScore(const string& source_word, + const string& target_word); + + virtual double GetSourceGivenTargetScore(const string& source_word, + const string& target_word); + + void WriteBinary(const fs::path& filepath) const; + + protected: + TranslationTable(); + + private: + void IncreaseLinksCount( + unordered_map<int, int>& source_links_count, + unordered_map<int, int>& target_links_count, + unordered_map<pair<int, int>, int, PairHash>& links_count, + int source_word_id, + int target_word_id) const; + + shared_ptr<DataArray> source_data_array; + shared_ptr<DataArray> target_data_array; + unordered_map<pair<int, int>, pair<double, double>, PairHash> + translation_probabilities; +}; + +#endif diff --git a/extractor/translation_table_test.cc b/extractor/translation_table_test.cc new file mode 100644 index 00000000..c99f3f93 --- /dev/null +++ b/extractor/translation_table_test.cc @@ -0,0 +1,82 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> +#include <vector> + +#include "mocks/mock_alignment.h" +#include "mocks/mock_data_array.h" +#include "translation_table.h" + +using namespace std; +using namespace ::testing; + +namespace { + +TEST(TranslationTableTest, TestScores) { + vector<string> words = {"a", "b", "c"}; + + vector<int> source_data = {2, 3, 2, 3, 4, 0, 2, 3, 6, 0, 2, 3, 6, 0}; + vector<int> source_sentence_start = {0, 6, 10, 14}; + shared_ptr<MockDataArray> source_data_array = make_shared<MockDataArray>(); + EXPECT_CALL(*source_data_array, GetData()) + .WillRepeatedly(ReturnRef(source_data)); + EXPECT_CALL(*source_data_array, GetNumSentences()) + .WillRepeatedly(Return(3)); + for (size_t i = 0; i < source_sentence_start.size(); ++i) { + EXPECT_CALL(*source_data_array, GetSentenceStart(i)) + .WillRepeatedly(Return(source_sentence_start[i])); + } + for (size_t i = 0; i < words.size(); ++i) { + EXPECT_CALL(*source_data_array, HasWord(words[i])) + .WillRepeatedly(Return(true)); + EXPECT_CALL(*source_data_array, GetWordId(words[i])) + .WillRepeatedly(Return(i + 2)); + } + EXPECT_CALL(*source_data_array, HasWord("d")) + .WillRepeatedly(Return(false)); + + vector<int> target_data = {2, 3, 2, 3, 4, 5, 0, 3, 6, 0, 2, 7, 0}; + vector<int> target_sentence_start = {0, 7, 10, 13}; + shared_ptr<MockDataArray> target_data_array = make_shared<MockDataArray>(); + EXPECT_CALL(*target_data_array, GetData()) + .WillRepeatedly(ReturnRef(target_data)); + for (size_t i = 0; i < target_sentence_start.size(); ++i) { + EXPECT_CALL(*target_data_array, GetSentenceStart(i)) + .WillRepeatedly(Return(target_sentence_start[i])); + } + for (size_t i = 0; i < words.size(); ++i) { + EXPECT_CALL(*target_data_array, HasWord(words[i])) + .WillRepeatedly(Return(true)); + EXPECT_CALL(*target_data_array, GetWordId(words[i])) + .WillRepeatedly(Return(i + 2)); + } + EXPECT_CALL(*target_data_array, HasWord("d")) + .WillRepeatedly(Return(false)); + + vector<pair<int, int> > links1 = { + make_pair(0, 0), make_pair(1, 1), make_pair(2, 2), make_pair(3, 3), + make_pair(4, 4), make_pair(4, 5) + }; + vector<pair<int, int> > links2 = {make_pair(1, 0), make_pair(2, 1)}; + vector<pair<int, int> > links3 = {make_pair(0, 0), make_pair(2, 1)}; + shared_ptr<MockAlignment> alignment = make_shared<MockAlignment>(); + EXPECT_CALL(*alignment, GetLinks(0)).WillRepeatedly(Return(links1)); + EXPECT_CALL(*alignment, GetLinks(1)).WillRepeatedly(Return(links2)); + EXPECT_CALL(*alignment, GetLinks(2)).WillRepeatedly(Return(links3)); + + shared_ptr<TranslationTable> table = make_shared<TranslationTable>( + source_data_array, target_data_array, alignment); + + EXPECT_EQ(0.75, table->GetTargetGivenSourceScore("a", "a")); + EXPECT_EQ(0, table->GetTargetGivenSourceScore("a", "b")); + EXPECT_EQ(0.5, table->GetTargetGivenSourceScore("c", "c")); + EXPECT_EQ(-1, table->GetTargetGivenSourceScore("c", "d")); + + EXPECT_EQ(1, table->GetSourceGivenTargetScore("a", "a")); + EXPECT_EQ(0, table->GetSourceGivenTargetScore("a", "b")); + EXPECT_EQ(1, table->GetSourceGivenTargetScore("c", "c")); + EXPECT_EQ(-1, table->GetSourceGivenTargetScore("c", "d")); +} + +} // namespace diff --git a/extractor/veb.cc b/extractor/veb.cc new file mode 100644 index 00000000..f38f672e --- /dev/null +++ b/extractor/veb.cc @@ -0,0 +1,25 @@ +#include "veb.h" + +#include "veb_bitset.h" +#include "veb_tree.h" + +int VEB::MIN_BOTTOM_BITS = 5; +int VEB::MIN_BOTTOM_SIZE = 1 << VEB::MIN_BOTTOM_BITS; + +shared_ptr<VEB> VEB::Create(int size) { + if (size > MIN_BOTTOM_SIZE) { + return shared_ptr<VEB>(new VEBTree(size)); + } else { + return shared_ptr<VEB>(new VEBBitset(size)); + } +} + +int VEB::GetMinimum() { + return min; +} + +int VEB::GetMaximum() { + return max; +} + +VEB::VEB(int min, int max) : min(min), max(max) {} diff --git a/extractor/veb.h b/extractor/veb.h new file mode 100644 index 00000000..c8209cf7 --- /dev/null +++ b/extractor/veb.h @@ -0,0 +1,29 @@ +#ifndef _VEB_H_ +#define _VEB_H_ + +#include <memory> + +using namespace std; + +class VEB { + public: + static shared_ptr<VEB> Create(int size); + + virtual void Insert(int value) = 0; + + virtual int GetSuccessor(int value) = 0; + + int GetMinimum(); + + int GetMaximum(); + + static int MIN_BOTTOM_BITS; + static int MIN_BOTTOM_SIZE; + + protected: + VEB(int min = -1, int max = -1); + + int min, max; +}; + +#endif diff --git a/extractor/veb_bitset.cc b/extractor/veb_bitset.cc new file mode 100644 index 00000000..4e364cc5 --- /dev/null +++ b/extractor/veb_bitset.cc @@ -0,0 +1,25 @@ +#include "veb_bitset.h" + +using namespace std; + +VEBBitset::VEBBitset(int size) : bitset(size) { + min = max = -1; +} + +void VEBBitset::Insert(int value) { + bitset[value] = 1; + if (min == -1 || value < min) { + min = value; + } + if (max == - 1 || value > max) { + max = value; + } +} + +int VEBBitset::GetSuccessor(int value) { + int next_value = bitset.find_next(value); + if (next_value == bitset.npos) { + return -1; + } + return next_value; +} diff --git a/extractor/veb_bitset.h b/extractor/veb_bitset.h new file mode 100644 index 00000000..f8a91234 --- /dev/null +++ b/extractor/veb_bitset.h @@ -0,0 +1,22 @@ +#ifndef _VEB_BITSET_H_ +#define _VEB_BITSET_H_ + +#include <boost/dynamic_bitset.hpp> + +#include "veb.h" + +class VEBBitset: public VEB { + public: + VEBBitset(int size); + + void Insert(int value); + + int GetMinimum(); + + int GetSuccessor(int value); + + private: + boost::dynamic_bitset<> bitset; +}; + +#endif diff --git a/extractor/veb_test.cc b/extractor/veb_test.cc new file mode 100644 index 00000000..c40c9f28 --- /dev/null +++ b/extractor/veb_test.cc @@ -0,0 +1,56 @@ +#include <gtest/gtest.h> + +#include <algorithm> +#include <vector> + +#include "veb.h" + +using namespace std; + +namespace { + +class VEBTest : public ::testing::Test { + protected: + void VEBSortTester(vector<int> values, int max_value) { + shared_ptr<VEB> veb = VEB::Create(max_value); + for (int value: values) { + veb->Insert(value); + } + + sort(values.begin(), values.end()); + EXPECT_EQ(values.front(), veb->GetMinimum()); + EXPECT_EQ(values.back(), veb->GetMaximum()); + for (size_t i = 0; i + 1 < values.size(); ++i) { + EXPECT_EQ(values[i + 1], veb->GetSuccessor(values[i])); + } + EXPECT_EQ(-1, veb->GetSuccessor(values.back())); + } +}; + +TEST_F(VEBTest, SmallRange) { + vector<int> values{8, 13, 5, 1, 4, 15, 2, 10, 6, 7}; + VEBSortTester(values, 16); +} + +TEST_F(VEBTest, MediumRange) { + vector<int> values{167, 243, 88, 12, 137, 199, 212, 45, 150, 189}; + VEBSortTester(values, 255); +} + +TEST_F(VEBTest, LargeRangeSparse) { + vector<int> values; + for (size_t i = 0; i < 100; ++i) { + values.push_back(i * 1000000); + } + VEBSortTester(values, 100000000); +} + +TEST_F(VEBTest, LargeRangeDense) { + vector<int> values; + for (size_t i = 0; i < 1000000; ++i) { + values.push_back(i); + } + VEBSortTester(values, 1000000); +} + +} // namespace diff --git a/extractor/veb_tree.cc b/extractor/veb_tree.cc new file mode 100644 index 00000000..f8945445 --- /dev/null +++ b/extractor/veb_tree.cc @@ -0,0 +1,71 @@ +#include <cmath> + +#include "veb_tree.h" + +VEBTree::VEBTree(int size) { + int num_bits = ceil(log2(size)); + + lower_bits = num_bits >> 1; + upper_size = (size >> lower_bits) + 1; + + clusters.reserve(upper_size); + clusters.resize(upper_size); +} + +int VEBTree::GetNextValue(int value) { + return value & ((1 << lower_bits) - 1); +} + +int VEBTree::GetCluster(int value) { + return value >> lower_bits; +} + +int VEBTree::Compose(int cluster, int value) { + return (cluster << lower_bits) + value; +} + +void VEBTree::Insert(int value) { + if (min == -1 && max == -1) { + min = max = value; + return; + } + + if (value < min) { + swap(min, value); + } + + int cluster = GetCluster(value), next_value = GetNextValue(value); + if (clusters[cluster] == NULL) { + clusters[cluster] = VEB::Create(1 << lower_bits); + if (summary == NULL) { + summary = VEB::Create(upper_size); + } + summary->Insert(cluster); + } + clusters[cluster]->Insert(next_value); + + if (value > max) { + max = value; + } +} + +int VEBTree::GetSuccessor(int value) { + if (value >= max) { + return -1; + } + if (value < min) { + return min; + } + + int cluster = GetCluster(value), next_value = GetNextValue(value); + if (clusters[cluster] != NULL && + next_value < clusters[cluster]->GetMaximum()) { + return Compose(cluster, clusters[cluster]->GetSuccessor(next_value)); + } else { + int next_cluster = summary->GetSuccessor(cluster); + if (next_cluster == -1) { + return -1; + } + return Compose(next_cluster, clusters[next_cluster]->GetMinimum()); + } +} diff --git a/extractor/veb_tree.h b/extractor/veb_tree.h new file mode 100644 index 00000000..578d3e6a --- /dev/null +++ b/extractor/veb_tree.h @@ -0,0 +1,29 @@ +#ifndef _VEB_TREE_H_ +#define _VEB_TREE_H_ + +#include <memory> +#include <vector> + +using namespace std; + +#include "veb.h" + +class VEBTree: public VEB { + public: + VEBTree(int size); + + void Insert(int value); + + int GetSuccessor(int value); + + private: + int GetNextValue(int value); + int GetCluster(int value); + int Compose(int cluster, int value); + + int lower_bits, upper_size; + shared_ptr<VEB> summary; + vector<shared_ptr<VEB> > clusters; +}; + +#endif diff --git a/extractor/vocabulary.cc b/extractor/vocabulary.cc new file mode 100644 index 00000000..5c379a29 --- /dev/null +++ b/extractor/vocabulary.cc @@ -0,0 +1,26 @@ +#include "vocabulary.h" + +Vocabulary::~Vocabulary() {} + +int Vocabulary::GetTerminalIndex(const string& word) { + if (!dictionary.count(word)) { + int word_id = words.size(); + dictionary[word] = word_id; + words.push_back(word); + return word_id; + } + + return dictionary[word]; +} + +int Vocabulary::GetNonterminalIndex(int position) { + return -position; +} + +bool Vocabulary::IsTerminal(int symbol) { + return symbol >= 0; +} + +string Vocabulary::GetTerminalValue(int symbol) { + return words[symbol]; +} diff --git a/extractor/vocabulary.h b/extractor/vocabulary.h new file mode 100644 index 00000000..c6a8b3e8 --- /dev/null +++ b/extractor/vocabulary.h @@ -0,0 +1,27 @@ +#ifndef _VOCABULARY_H_ +#define _VOCABULARY_H_ + +#include <string> +#include <unordered_map> +#include <vector> + +using namespace std; + +class Vocabulary { + public: + virtual ~Vocabulary(); + + virtual int GetTerminalIndex(const string& word); + + int GetNonterminalIndex(int position); + + bool IsTerminal(int symbol); + + virtual string GetTerminalValue(int symbol); + + private: + unordered_map<string, int> dictionary; + vector<string> words; +}; + +#endif |