From 5530575ae0ad939e17f08d6bd49978acea388ab7 Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Mon, 28 Jan 2013 11:56:31 +0000 Subject: Initial working commit. --- extractor/Makefile.am | 85 ++++++++++++ extractor/alignment.cc | 47 +++++++ extractor/alignment.h | 24 ++++ extractor/binary_search_merger.cc | 245 +++++++++++++++++++++++++++++++++ extractor/binary_search_merger.h | 68 +++++++++ extractor/binary_search_merger_test.cc | 157 +++++++++++++++++++++ extractor/compile.cc | 98 +++++++++++++ extractor/data_array.cc | 146 ++++++++++++++++++++ extractor/data_array.h | 71 ++++++++++ extractor/data_array_test.cc | 70 ++++++++++ extractor/grammar_extractor.cc | 45 ++++++ extractor/grammar_extractor.h | 39 ++++++ extractor/intersector.cc | 129 +++++++++++++++++ extractor/intersector.h | 57 ++++++++ extractor/linear_merger.cc | 63 +++++++++ extractor/linear_merger.h | 35 +++++ extractor/linear_merger_test.cc | 149 ++++++++++++++++++++ extractor/matching.cc | 12 ++ extractor/matching.h | 18 +++ extractor/matching_comparator.cc | 46 +++++++ extractor/matching_comparator.h | 23 ++++ extractor/matching_comparator_test.cc | 139 +++++++++++++++++++ extractor/matching_test.cc | 25 ++++ extractor/matchings_finder.cc | 17 +++ extractor/matchings_finder.h | 22 +++ extractor/matchings_finder_test.cc | 42 ++++++ extractor/matchings_trie.cc | 11 ++ extractor/matchings_trie.h | 46 +++++++ extractor/mocks/mock_data_array.h | 14 ++ extractor/mocks/mock_linear_merger.h | 21 +++ extractor/mocks/mock_suffix_array.h | 17 +++ extractor/mocks/mock_vocabulary.h | 8 ++ extractor/phrase.cc | 25 ++++ extractor/phrase.h | 29 ++++ extractor/phrase_builder.cc | 21 +++ extractor/phrase_builder.h | 22 +++ extractor/phrase_location.cc | 35 +++++ extractor/phrase_location.h | 23 ++++ extractor/phrase_test.cc | 61 ++++++++ extractor/precomputation.cc | 192 ++++++++++++++++++++++++++ extractor/precomputation.h | 52 +++++++ extractor/rule_extractor.cc | 10 ++ extractor/rule_extractor.h | 22 +++ extractor/rule_factory.cc | 215 +++++++++++++++++++++++++++++ extractor/rule_factory.h | 67 +++++++++ extractor/run_extractor.cc | 109 +++++++++++++++ extractor/sample_bitext.txt | 2 + extractor/scorer.cc | 9 ++ extractor/scorer.h | 19 +++ extractor/suffix_array.cc | 211 ++++++++++++++++++++++++++++ extractor/suffix_array.h | 51 +++++++ extractor/suffix_array_test.cc | 75 ++++++++++ extractor/translation_table.cc | 94 +++++++++++++ extractor/translation_table.h | 38 +++++ extractor/veb.cc | 25 ++++ extractor/veb.h | 29 ++++ extractor/veb_bitset.cc | 25 ++++ extractor/veb_bitset.h | 22 +++ extractor/veb_test.cc | 56 ++++++++ extractor/veb_tree.cc | 71 ++++++++++ extractor/veb_tree.h | 29 ++++ extractor/vocabulary.cc | 26 ++++ extractor/vocabulary.h | 28 ++++ 63 files changed, 3682 insertions(+) create mode 100644 extractor/Makefile.am create mode 100644 extractor/alignment.cc create mode 100644 extractor/alignment.h create mode 100644 extractor/binary_search_merger.cc create mode 100644 extractor/binary_search_merger.h create mode 100644 extractor/binary_search_merger_test.cc create mode 100644 extractor/compile.cc create mode 100644 extractor/data_array.cc create mode 100644 extractor/data_array.h create mode 100644 extractor/data_array_test.cc create mode 100644 extractor/grammar_extractor.cc create mode 100644 extractor/grammar_extractor.h create mode 100644 extractor/intersector.cc create mode 100644 extractor/intersector.h create mode 100644 extractor/linear_merger.cc create mode 100644 extractor/linear_merger.h create mode 100644 extractor/linear_merger_test.cc create mode 100644 extractor/matching.cc create mode 100644 extractor/matching.h create mode 100644 extractor/matching_comparator.cc create mode 100644 extractor/matching_comparator.h create mode 100644 extractor/matching_comparator_test.cc create mode 100644 extractor/matching_test.cc create mode 100644 extractor/matchings_finder.cc create mode 100644 extractor/matchings_finder.h create mode 100644 extractor/matchings_finder_test.cc create mode 100644 extractor/matchings_trie.cc create mode 100644 extractor/matchings_trie.h create mode 100644 extractor/mocks/mock_data_array.h create mode 100644 extractor/mocks/mock_linear_merger.h create mode 100644 extractor/mocks/mock_suffix_array.h create mode 100644 extractor/mocks/mock_vocabulary.h create mode 100644 extractor/phrase.cc create mode 100644 extractor/phrase.h create mode 100644 extractor/phrase_builder.cc create mode 100644 extractor/phrase_builder.h create mode 100644 extractor/phrase_location.cc create mode 100644 extractor/phrase_location.h create mode 100644 extractor/phrase_test.cc create mode 100644 extractor/precomputation.cc create mode 100644 extractor/precomputation.h create mode 100644 extractor/rule_extractor.cc create mode 100644 extractor/rule_extractor.h create mode 100644 extractor/rule_factory.cc create mode 100644 extractor/rule_factory.h create mode 100644 extractor/run_extractor.cc create mode 100644 extractor/sample_bitext.txt create mode 100644 extractor/scorer.cc create mode 100644 extractor/scorer.h create mode 100644 extractor/suffix_array.cc create mode 100644 extractor/suffix_array.h create mode 100644 extractor/suffix_array_test.cc create mode 100644 extractor/translation_table.cc create mode 100644 extractor/translation_table.h create mode 100644 extractor/veb.cc create mode 100644 extractor/veb.h create mode 100644 extractor/veb_bitset.cc create mode 100644 extractor/veb_bitset.h create mode 100644 extractor/veb_test.cc create mode 100644 extractor/veb_tree.cc create mode 100644 extractor/veb_tree.h create mode 100644 extractor/vocabulary.cc create mode 100644 extractor/vocabulary.h (limited to 'extractor') diff --git a/extractor/Makefile.am b/extractor/Makefile.am new file mode 100644 index 00000000..844c0ef3 --- /dev/null +++ b/extractor/Makefile.am @@ -0,0 +1,85 @@ +bin_PROGRAMS = compile run_extractor + +noinst_PROGRAMS = \ + binary_search_merger_test \ + data_array_test \ + linear_merger_test \ + matching_comparator_test \ + matching_test \ + matchings_finder_test \ + phrase_test \ + precomputation_test \ + suffix_array_test \ + veb_test + +TESTS = precomputation_test +#TESTS = binary_search_merger_test \ +# data_array_test \ +# linear_merger_test \ +# matching_comparator_test \ +# matching_test \ +# phrase_test \ +# suffix_array_test \ +# veb_test + +binary_search_merger_test_SOURCES = binary_search_merger_test.cc +binary_search_merger_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +data_array_test_SOURCES = data_array_test.cc +data_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +linear_merger_test_SOURCES = linear_merger_test.cc +linear_merger_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +matching_comparator_test_SOURCES = matching_comparator_test.cc +matching_comparator_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +matching_test_SOURCES = matching_test.cc +matching_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +matchings_finder_test_SOURCES = matchings_finder_test.cc +matchings_finder_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +phrase_test_SOURCES = phrase_test.cc +phrase_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +precomputation_test_SOURCES = precomputation_test.cc +precomputation_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +suffix_array_test_SOURCES = suffix_array_test.cc +suffix_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +veb_test_SOURCES = veb_test.cc +veb_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a + +noinst_LIBRARIES = libextractor.a libcompile.a + +compile_SOURCES = compile.cc +compile_LDADD = libcompile.a +run_extractor_SOURCES = run_extractor.cc +run_extractor_LDADD = libextractor.a + +libcompile_a_SOURCES = \ + alignment.cc \ + data_array.cc \ + phrase_location.cc \ + precomputation.cc \ + suffix_array.cc \ + translation_table.cc + +libextractor_a_SOURCES = \ + alignment.cc \ + binary_search_merger.cc \ + data_array.cc \ + grammar_extractor.cc \ + matching.cc \ + matching_comparator.cc \ + matchings_finder.cc \ + intersector.cc \ + linear_merger.cc \ + matchings_trie.cc \ + phrase.cc \ + phrase_builder.cc \ + phrase_location.cc \ + precomputation.cc \ + rule_extractor.cc \ + rule_factory.cc \ + suffix_array.cc \ + translation_table.cc \ + veb.cc \ + veb_bitset.cc \ + veb_tree.cc \ + vocabulary.cc + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare -std=c++0x $(GTEST_CPPFLAGS) $(GMOCK_CPPFLAGS) diff --git a/extractor/alignment.cc b/extractor/alignment.cc new file mode 100644 index 00000000..cad28a72 --- /dev/null +++ b/extractor/alignment.cc @@ -0,0 +1,47 @@ +#include "alignment.h" + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace fs = boost::filesystem; +using namespace std; + +Alignment::Alignment(const string& filename) { + ifstream infile(filename.c_str()); + string line; + while (getline(infile, line)) { + vector items; + boost::split(items, line, boost::is_any_of(" -")); + vector > alignment; + alignment.reserve(items.size() / 2); + for (size_t i = 0; i < items.size(); i += 2) { + alignment.push_back(make_pair(stoi(items[i]), stoi(items[i + 1]))); + } + alignments.push_back(alignment); + } + // Note: shrink_to_fit does nothing for vector > on g++ 4.6.3, + // but let's hope that the bug will be fixed in a newer version. + alignments.shrink_to_fit(); +} + +vector > Alignment::GetLinks(int sentence_index) const { + return alignments[sentence_index]; +} + +void Alignment::WriteBinary(const fs::path& filepath) { + FILE* file = fopen(filepath.string().c_str(), "w"); + int size = alignments.size(); + fwrite(&size, sizeof(int), 1, file); + for (vector > alignment: alignments) { + size = alignment.size(); + fwrite(&size, sizeof(int), 1, file); + fwrite(alignment.data(), sizeof(pair), size, file); + } +} diff --git a/extractor/alignment.h b/extractor/alignment.h new file mode 100644 index 00000000..e357e468 --- /dev/null +++ b/extractor/alignment.h @@ -0,0 +1,24 @@ +#ifndef _ALIGNMENT_H_ +#define _ALIGNMENT_H_ + +#include +#include + +#include + +namespace fs = boost::filesystem; +using namespace std; + +class Alignment { + public: + Alignment(const string& filename); + + vector > GetLinks(int sentence_index) const; + + void WriteBinary(const fs::path& filepath); + + private: + vector > > alignments; +}; + +#endif diff --git a/extractor/binary_search_merger.cc b/extractor/binary_search_merger.cc new file mode 100644 index 00000000..7b018876 --- /dev/null +++ b/extractor/binary_search_merger.cc @@ -0,0 +1,245 @@ +#include "binary_search_merger.h" + +#include "data_array.h" +#include "linear_merger.h" +#include "matching.h" +#include "matching_comparator.h" +#include "phrase.h" +#include "vocabulary.h" + +double BinarySearchMerger::BAEZA_YATES_FACTOR = 1.0; + +BinarySearchMerger::BinarySearchMerger( + shared_ptr vocabulary, + shared_ptr linear_merger, + shared_ptr data_array, + shared_ptr comparator, + bool force_binary_search_merge) : + vocabulary(vocabulary), linear_merger(linear_merger), + data_array(data_array), comparator(comparator), + force_binary_search_merge(force_binary_search_merge) {} + +void BinarySearchMerger::Merge( + vector& locations, const Phrase& phrase, const Phrase& suffix, + vector::iterator prefix_start, vector::iterator prefix_end, + vector::iterator suffix_start, vector::iterator suffix_end, + int prefix_subpatterns, int suffix_subpatterns) const { + if (IsIntersectionVoid(prefix_start, prefix_end, suffix_start, suffix_end, + prefix_subpatterns, suffix_subpatterns, suffix)) { + return; + } + + int prefix_set_size = prefix_end - prefix_start; + int suffix_set_size = suffix_end - suffix_start; + if (ShouldUseLinearMerge(prefix_set_size, suffix_set_size)) { + linear_merger->Merge(locations, phrase, suffix, prefix_start, prefix_end, + suffix_start, suffix_end, prefix_subpatterns, suffix_subpatterns); + return; + } + + vector::iterator low, high, prefix_low, prefix_high, suffix_mid; + if (prefix_set_size > suffix_set_size) { + // Binary search on the prefix set. + suffix_mid = GetMiddle(suffix_start, suffix_end, suffix_subpatterns); + low = prefix_start, high = prefix_end; + while (low < high) { + vector::iterator prefix_mid = + GetMiddle(low, high, prefix_subpatterns); + + GetComparableMatchings(prefix_start, prefix_end, prefix_mid, + prefix_subpatterns, prefix_low, prefix_high); + int comparison = CompareMatchingsSet(prefix_low, prefix_high, suffix_mid, + prefix_subpatterns, suffix_subpatterns, suffix); + if (comparison == 0) { + break; + } else if (comparison < 0) { + low = prefix_mid + prefix_subpatterns; + } else { + high = prefix_mid; + } + } + } else { + // Binary search on the suffix set. + vector::iterator prefix_mid = + GetMiddle(prefix_start, prefix_end, prefix_subpatterns); + + GetComparableMatchings(prefix_start, prefix_end, prefix_mid, + prefix_subpatterns, prefix_low, prefix_high); + low = suffix_start, high = suffix_end; + while (low < high) { + suffix_mid = GetMiddle(low, high, suffix_subpatterns); + + int comparison = CompareMatchingsSet(prefix_low, prefix_high, suffix_mid, + prefix_subpatterns, suffix_subpatterns, suffix); + if (comparison == 0) { + break; + } else if (comparison > 0) { + low = suffix_mid + suffix_subpatterns; + } else { + high = suffix_mid; + } + } + } + + vector result; + int last_chunk_len = suffix.GetChunkLen(suffix.Arity()); + bool offset = !vocabulary->IsTerminal(suffix.GetSymbol(0)); + vector::iterator suffix_low, suffix_high; + if (low < high) { + // We found a group of prefixes with the same starting position that give + // different results when compared to the found suffix. + // Find all matching suffixes for the previously found set of prefixes. + suffix_low = suffix_mid; + suffix_high = suffix_mid + suffix_subpatterns; + for (auto i = prefix_low; i != prefix_high; i += prefix_subpatterns) { + Matching left(i, prefix_subpatterns, data_array->GetSentenceId(*i)); + while (suffix_low != suffix_start) { + Matching right(suffix_low - suffix_subpatterns, suffix_subpatterns, + data_array->GetSentenceId(*(suffix_low - suffix_subpatterns))); + if (comparator->Compare(left, right, last_chunk_len, offset) <= 0) { + suffix_low -= suffix_subpatterns; + } else { + break; + } + } + + for (auto j = suffix_low; j != suffix_end; j += suffix_subpatterns) { + Matching right(j, suffix_subpatterns, data_array->GetSentenceId(*j)); + int comparison = comparator->Compare(left, right, last_chunk_len, + offset); + if (comparison == 0) { + vector merged = left.Merge(right, phrase.Arity() + 1); + result.insert(result.end(), merged.begin(), merged.end()); + } else if (comparison < 0) { + break; + } + suffix_high = max(suffix_high, j + suffix_subpatterns); + } + } + + swap(suffix_low, suffix_high); + } else if (prefix_set_size > suffix_set_size) { + // We did the binary search on the prefix set. + suffix_low = suffix_mid; + suffix_high = suffix_mid + suffix_subpatterns; + if (CompareMatchingsSet(prefix_low, prefix_high, suffix_mid, + prefix_subpatterns, suffix_subpatterns, suffix) < 0) { + prefix_low = prefix_high; + } else { + prefix_high = prefix_low; + } + } else { + // We did the binary search on the suffix set. + if (CompareMatchingsSet(prefix_low, prefix_high, suffix_mid, + prefix_subpatterns, suffix_subpatterns, suffix) < 0) { + suffix_low = suffix_mid; + suffix_high = suffix_mid; + } else { + suffix_low = suffix_mid + suffix_subpatterns; + suffix_high = suffix_mid + suffix_subpatterns; + } + } + + Merge(locations, phrase, suffix, prefix_start, prefix_low, suffix_start, + suffix_low, prefix_subpatterns, suffix_subpatterns); + locations.insert(locations.end(), result.begin(), result.end()); + Merge(locations, phrase, suffix, prefix_high, prefix_end, suffix_high, + suffix_end, prefix_subpatterns, suffix_subpatterns); +} + +bool BinarySearchMerger::IsIntersectionVoid( + vector::iterator prefix_start, vector::iterator prefix_end, + vector::iterator suffix_start, vector::iterator suffix_end, + int prefix_subpatterns, int suffix_subpatterns, + const Phrase& suffix) const { + // Is any of the sets empty? + if (prefix_start >= prefix_end || suffix_start >= suffix_end) { + return true; + } + + int last_chunk_len = suffix.GetChunkLen(suffix.Arity()); + bool offset = !vocabulary->IsTerminal(suffix.GetSymbol(0)); + // Is the first value from the first set larger than the last value in the + // second set? + Matching left(prefix_start, prefix_subpatterns, + data_array->GetSentenceId(*prefix_start)); + Matching right(suffix_end - suffix_subpatterns, suffix_subpatterns, + data_array->GetSentenceId(*(suffix_end - suffix_subpatterns))); + if (comparator->Compare(left, right, last_chunk_len, offset) > 0) { + return true; + } + + // Is the last value from the first set smaller than the first value in the + // second set? + left = Matching(prefix_end - prefix_subpatterns, prefix_subpatterns, + data_array->GetSentenceId(*(prefix_end - prefix_subpatterns))); + right = Matching(suffix_start, suffix_subpatterns, + data_array->GetSentenceId(*suffix_start)); + if (comparator->Compare(left, right, last_chunk_len, offset) < 0) { + return true; + } + + return false; +} + +bool BinarySearchMerger::ShouldUseLinearMerge( + int prefix_set_size, int suffix_set_size) const { + if (force_binary_search_merge) { + return false; + } + + int min_size = min(prefix_set_size, suffix_set_size); + int max_size = max(prefix_set_size, suffix_set_size); + + return BAEZA_YATES_FACTOR * min_size * log2(max_size) > max_size; +} + +vector::iterator BinarySearchMerger::GetMiddle( + vector::iterator low, vector::iterator high, + int num_subpatterns) const { + return low + (((high - low) / num_subpatterns) / 2) * num_subpatterns; +} + +void BinarySearchMerger::GetComparableMatchings( + const vector::iterator& prefix_start, + const vector::iterator& prefix_end, + const vector::iterator& prefix_mid, + int num_subpatterns, + vector::iterator& prefix_low, + vector::iterator& prefix_high) const { + prefix_low = prefix_mid; + while (prefix_low != prefix_start + && *prefix_mid == *(prefix_low - num_subpatterns)) { + prefix_low -= num_subpatterns; + } + prefix_high = prefix_mid + num_subpatterns; + while (prefix_high != prefix_end + && *prefix_mid == *prefix_high) { + prefix_high += num_subpatterns; + } +} + +int BinarySearchMerger::CompareMatchingsSet( + const vector::iterator& prefix_start, + const vector::iterator& prefix_end, + const vector::iterator& suffix_mid, + int prefix_subpatterns, + int suffix_subpatterns, + const Phrase& suffix) const { + int result = 0; + int last_chunk_len = suffix.GetChunkLen(suffix.Arity()); + bool offset = !vocabulary->IsTerminal(suffix.GetSymbol(0)); + + Matching right(suffix_mid, suffix_subpatterns, + data_array->GetSentenceId(*suffix_mid)); + for (auto i = prefix_start; i != prefix_end; i += prefix_subpatterns) { + Matching left(i, prefix_subpatterns, data_array->GetSentenceId(*i)); + int comparison = comparator->Compare(left, right, last_chunk_len, offset); + if (i == prefix_start) { + result = comparison; + } else if (comparison != result) { + return 0; + } + } + return result; +} diff --git a/extractor/binary_search_merger.h b/extractor/binary_search_merger.h new file mode 100644 index 00000000..0e229b3b --- /dev/null +++ b/extractor/binary_search_merger.h @@ -0,0 +1,68 @@ +#ifndef _BINARY_SEARCH_MERGER_H_ +#define _BINARY_SEARCH_MERGER_H_ + +#include +#include + +using namespace std; + +class DataArray; +class LinearMerger; +class MatchingComparator; +class Phrase; +class Vocabulary; + +class BinarySearchMerger { + public: + BinarySearchMerger(shared_ptr vocabulary, + shared_ptr linear_merger, + shared_ptr data_array, + shared_ptr comparator, + bool force_binary_search_merge = false); + + void Merge( + vector& locations, const Phrase& phrase, const Phrase& suffix, + vector::iterator prefix_start, vector::iterator prefix_end, + vector::iterator suffix_start, vector::iterator suffix_end, + int prefix_subpatterns, int suffix_subpatterns) const; + + static double BAEZA_YATES_FACTOR; + + private: + bool IsIntersectionVoid( + vector::iterator prefix_start, vector::iterator prefix_end, + vector::iterator suffix_start, vector::iterator suffix_end, + int prefix_subpatterns, int suffix_subpatterns, + const Phrase& suffix) const; + + bool ShouldUseLinearMerge(int prefix_set_size, int suffix_set_size) const; + + vector::iterator GetMiddle(vector::iterator low, + vector::iterator high, + int num_subpatterns) const; + + void GetComparableMatchings( + const vector::iterator& prefix_start, + const vector::iterator& prefix_end, + const vector::iterator& prefix_mid, + int num_subpatterns, + vector::iterator& prefix_low, + vector::iterator& prefix_high) const; + + int CompareMatchingsSet( + const vector::iterator& prefix_low, + const vector::iterator& prefix_high, + const vector::iterator& suffix_mid, + int prefix_subpatterns, + int suffix_subpatterns, + const Phrase& suffix) const; + + shared_ptr vocabulary; + shared_ptr linear_merger; + shared_ptr data_array; + shared_ptr comparator; + // Should be true only for testing. + bool force_binary_search_merge; +}; + +#endif diff --git a/extractor/binary_search_merger_test.cc b/extractor/binary_search_merger_test.cc new file mode 100644 index 00000000..20350b1e --- /dev/null +++ b/extractor/binary_search_merger_test.cc @@ -0,0 +1,157 @@ +#include + +#include + +#include "binary_search_merger.h" +#include "matching_comparator.h" +#include "mocks/mock_data_array.h" +#include "mocks/mock_vocabulary.h" +#include "mocks/mock_linear_merger.h" +#include "phrase.h" +#include "phrase_location.h" +#include "phrase_builder.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class BinarySearchMergerTest : public Test { + protected: + virtual void SetUp() { + shared_ptr vocabulary = make_shared(); + EXPECT_CALL(*vocabulary, GetTerminalValue(_)) + .WillRepeatedly(Return("word")); + + shared_ptr data_array = make_shared(); + EXPECT_CALL(*data_array, GetSentenceId(_)) + .WillRepeatedly(Return(1)); + + shared_ptr comparator = + make_shared(1, 20); + + phrase_builder = make_shared(vocabulary); + + // We are going to force the binary_search_merger to do all the work, so we + // need to check that the linear_merger never gets called. + shared_ptr linear_merger = make_shared( + vocabulary, data_array, comparator); + EXPECT_CALL(*linear_merger, Merge(_, _, _, _, _, _, _, _, _)).Times(0); + + binary_search_merger = make_shared( + vocabulary, linear_merger, data_array, comparator, true); + } + + shared_ptr binary_search_merger; + shared_ptr phrase_builder; +}; + +TEST_F(BinarySearchMergerTest, aXbTest) { + vector locations; + // Encoding for him X it (see Adam's dissertation). + vector symbols{1, -1, 2}; + Phrase phrase = phrase_builder->Build(symbols); + vector suffix_symbols{-1, 2}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + + vector prefix_locs{2, 6, 10, 15}; + vector suffix_locs{0, 4, 8, 13}; + + binary_search_merger->Merge(locations, phrase, suffix, prefix_locs.begin(), + prefix_locs.end(), suffix_locs.begin(), suffix_locs.end(), 1, 1); + + vector expected_locations{2, 4, 2, 8, 2, 13, 6, 8, 6, 13, 10, 13}; + EXPECT_EQ(expected_locations, locations); +} + +TEST_F(BinarySearchMergerTest, aXbXcTest) { + vector locations; + // Encoding for it X him X it (see Adam's dissertation). + vector symbols{1, -1, 2, -2, 1}; + Phrase phrase = phrase_builder->Build(symbols); + vector suffix_symbols{-1, 2, -2, 1}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + + vector prefix_locs{0, 2, 0, 6, 0, 10, 4, 6, 4, 10, 4, 15, 8, 10, 8, 15, + 13, 15}; + vector suffix_locs{2, 4, 2, 8, 2, 13, 6, 8, 6, 13, 10, 13}; + + binary_search_merger->Merge(locations, phrase, suffix, prefix_locs.begin(), + prefix_locs.end(), suffix_locs.begin(), suffix_locs.end(), 2, 2); + + vector expected_locs{0, 2, 4, 0, 2, 8, 0, 2, 13, 0, 6, 8, 0, 6, 13, 0, + 10, 13, 4, 6, 8, 4, 6, 13, 4, 10, 13, 8, 10, 13}; + EXPECT_EQ(expected_locs, locations); +} + +TEST_F(BinarySearchMergerTest, abXcXdTest) { + // Sentence: Anna has many many nuts and sour apples and juicy apples. + // Phrase: Anna has X and X apples. + vector locations; + vector symbols{1, 2, -1, 3, -2, 4}; + Phrase phrase = phrase_builder->Build(symbols); + vector suffix_symbols{2, -1, 3, -2, 4}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + + vector prefix_locs{1, 6, 1, 9}; + vector suffix_locs{2, 6, 8, 2, 6, 11, 2, 9, 11}; + + binary_search_merger->Merge(locations, phrase, suffix, prefix_locs.begin(), + prefix_locs.end(), suffix_locs.begin(), suffix_locs.end(), 2, 3); + + vector expected_locs{1, 6, 8, 1, 6, 11, 1, 9, 11}; + EXPECT_EQ(expected_locs, locations); +} + +TEST_F(BinarySearchMergerTest, LargeTest) { + vector locations; + vector symbols{1, -1, 2}; + Phrase phrase = phrase_builder->Build(symbols); + vector suffix_symbols{-1, 2}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + + vector prefix_locs; + for (int i = 0; i < 100; ++i) { + prefix_locs.push_back(i * 20 + 1); + } + vector suffix_locs; + for (int i = 0; i < 100; ++i) { + suffix_locs.push_back(i * 20 + 5); + suffix_locs.push_back(i * 20 + 13); + } + + binary_search_merger->Merge(locations, phrase, suffix, prefix_locs.begin(), + prefix_locs.end(), suffix_locs.begin(), suffix_locs.end(), 1, 1); + + EXPECT_EQ(400, locations.size()); + for (int i = 0; i < 100; ++i) { + EXPECT_EQ(i * 20 + 1, locations[4 * i]); + EXPECT_EQ(i * 20 + 5, locations[4 * i + 1]); + EXPECT_EQ(i * 20 + 1, locations[4 * i + 2]); + EXPECT_EQ(i * 20 + 13, locations[4 * i + 3]); + } +} + +TEST_F(BinarySearchMergerTest, EmptyResultTest) { + vector locations; + vector symbols{1, -1, 2}; + Phrase phrase = phrase_builder->Build(symbols); + vector suffix_symbols{-1, 2}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + + vector prefix_locs; + for (int i = 0; i < 100; ++i) { + prefix_locs.push_back(i * 200 + 1); + } + vector suffix_locs; + for (int i = 0; i < 100; ++i) { + suffix_locs.push_back(i * 200 + 101); + } + + binary_search_merger->Merge(locations, phrase, suffix, prefix_locs.begin(), + prefix_locs.end(), suffix_locs.begin(), suffix_locs.end(), 1, 1); + + EXPECT_EQ(0, locations.size()); +} + +} // namespace diff --git a/extractor/compile.cc b/extractor/compile.cc new file mode 100644 index 00000000..c3ea3c8d --- /dev/null +++ b/extractor/compile.cc @@ -0,0 +1,98 @@ +#include +#include + +#include +#include +#include + +#include "alignment.h" +#include "data_array.h" +#include "precomputation.h" +#include "suffix_array.h" +#include "translation_table.h" + +namespace fs = boost::filesystem; +namespace po = boost::program_options; +using namespace std; + +int main(int argc, char** argv) { + po::options_description desc("Command line options"); + desc.add_options() + ("help,h", "Show available options") + ("source,f", po::value(), "Source language corpus") + ("target,e", po::value(), "Target language corpus") + ("bitext,b", po::value(), "Parallel text (source ||| target)") + ("alignment,a", po::value()->required(), "Bitext word alignment") + ("output,o", po::value()->required(), "Output path") + ("frequent", po::value()->default_value(100), + "Number of precomputed frequent patterns") + ("super_frequent", po::value()->default_value(10), + "Number of precomputed super frequent patterns") + ("max_rule_span,s", po::value()->default_value(15), + "Maximum rule span") + ("max_rule_symbols,l", po::value()->default_value(5), + "Maximum number of symbols (terminals + nontermals) in a rule") + ("min_gap_size,g", po::value()->default_value(1), "Minimum gap size") + ("max_phrase_len,p", po::value()->default_value(4), + "Maximum frequent phrase length") + ("min_frequency", po::value()->default_value(1000), + "Minimum number of occurences for a pharse to be considered frequent"); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + + // Check for help argument before notify, so we don't need to pass in the + // required parameters. + if (vm.count("help")) { + cout << desc << endl; + return 0; + } + + po::notify(vm); + + if (!((vm.count("source") && vm.count("target")) || vm.count("bitext"))) { + cerr << "A paralel corpus is required. " + << "Use -f (source) with -e (target) or -b (bitext)." + << endl; + return 1; + } + + fs::path output_dir(vm["output"].as().c_str()); + if (!fs::exists(output_dir)) { + fs::create_directory(output_dir); + } + + shared_ptr source_data_array, target_data_array; + if (vm.count("bitext")) { + source_data_array = make_shared( + vm["bitext"].as(), SOURCE); + target_data_array = make_shared( + vm["bitext"].as(), TARGET); + } else { + source_data_array = make_shared(vm["source"].as()); + target_data_array = make_shared(vm["target"].as()); + } + shared_ptr source_suffix_array = + make_shared(source_data_array); + source_suffix_array->WriteBinary(output_dir / fs::path("f.bin")); + target_data_array->WriteBinary(output_dir / fs::path("e.bin")); + + Alignment alignment(vm["alignment"].as()); + alignment.WriteBinary(output_dir / fs::path("a.bin")); + + Precomputation precomputation( + source_suffix_array, + vm["frequent"].as(), + vm["super_frequent"].as(), + vm["max_rule_span"].as(), + vm["max_rule_symbols"].as(), + vm["min_gap_size"].as(), + vm["max_phrase_len"].as(), + vm["min_frequency"].as()); + precomputation.WriteBinary(output_dir / fs::path("precompute.bin")); + + TranslationTable table(source_data_array, target_data_array, alignment); + table.WriteBinary(output_dir / fs::path("lex.bin")); + + return 0; +} diff --git a/extractor/data_array.cc b/extractor/data_array.cc new file mode 100644 index 00000000..383b08a7 --- /dev/null +++ b/extractor/data_array.cc @@ -0,0 +1,146 @@ +#include "data_array.h" + +#include +#include +#include +#include + +#include + +namespace fs = boost::filesystem; +using namespace std; + +int DataArray::END_OF_FILE = 0; +int DataArray::END_OF_LINE = 1; +string DataArray::END_OF_FILE_STR = "__END_OF_FILE__"; +string DataArray::END_OF_LINE_STR = "__END_OF_LINE__"; + +DataArray::DataArray() { + InitializeDataArray(); +} + +DataArray::DataArray(const string& filename) { + InitializeDataArray(); + ifstream infile(filename.c_str()); + vector lines; + string line; + while (getline(infile, line)) { + lines.push_back(line); + } + CreateDataArray(lines); +} + +DataArray::DataArray(const string& filename, const Side& side) { + InitializeDataArray(); + ifstream infile(filename.c_str()); + vector lines; + string line, delimiter = "|||"; + while (getline(infile, line)) { + int position = line.find(delimiter); + if (side == SOURCE) { + lines.push_back(line.substr(0, position)); + } else { + lines.push_back(line.substr(position + delimiter.size())); + } + } + CreateDataArray(lines); +} + +void DataArray::InitializeDataArray() { + word2id[END_OF_FILE_STR] = END_OF_FILE; + id2word.push_back(END_OF_FILE_STR); + word2id[END_OF_LINE_STR] = END_OF_FILE; + id2word.push_back(END_OF_LINE_STR); +} + +void DataArray::CreateDataArray(const vector& lines) { + for (size_t i = 0; i < lines.size(); ++i) { + sentence_start.push_back(data.size()); + + istringstream iss(lines[i]); + string word; + while (iss >> word) { + if (word2id.count(word) == 0) { + word2id[word] = id2word.size(); + id2word.push_back(word); + } + data.push_back(word2id[word]); + sentence_id.push_back(i); + } + data.push_back(END_OF_LINE); + sentence_id.push_back(i); + } + sentence_start.push_back(data.size()); + + data.shrink_to_fit(); + sentence_id.shrink_to_fit(); + sentence_start.shrink_to_fit(); +} + +DataArray::~DataArray() {} + +const vector& DataArray::GetData() const { + return data; +} + +int DataArray::AtIndex(int index) const { + return data[index]; +} + +int DataArray::GetSize() const { + return data.size(); +} + +int DataArray::GetVocabularySize() const { + return id2word.size(); +} + +int DataArray::GetNumSentences() const { + return sentence_start.size() - 1; +} + +int DataArray::GetSentenceStart(int position) const { + return sentence_start[position]; +} + +int DataArray::GetSentenceId(int position) const { + return sentence_id[position]; +} + +void DataArray::WriteBinary(const fs::path& filepath) const { + WriteBinary(fopen(filepath.string().c_str(), "w")); +} + +void DataArray::WriteBinary(FILE* file) const { + int size = id2word.size(); + fwrite(&size, sizeof(int), 1, file); + for (string word: id2word) { + size = word.size(); + fwrite(&size, sizeof(int), 1, file); + fwrite(word.data(), sizeof(char), size, file); + } + + size = data.size(); + fwrite(&size, sizeof(int), 1, file); + fwrite(data.data(), sizeof(int), size, file); + + size = sentence_id.size(); + fwrite(&size, sizeof(int), 1, file); + fwrite(sentence_id.data(), sizeof(int), size, file); + + size = sentence_start.size(); + fwrite(&size, sizeof(int), 1, file); + fwrite(sentence_start.data(), sizeof(int), 1, file); +} + +bool DataArray::HasWord(const string& word) const { + return word2id.count(word); +} + +int DataArray::GetWordId(const string& word) const { + return word2id.find(word)->second; +} + +string DataArray::GetWord(int word_id) const { + return id2word[word_id]; +} diff --git a/extractor/data_array.h b/extractor/data_array.h new file mode 100644 index 00000000..6d3e99d5 --- /dev/null +++ b/extractor/data_array.h @@ -0,0 +1,71 @@ +#ifndef _DATA_ARRAY_H_ +#define _DATA_ARRAY_H_ + +#include +#include +#include + +#include + +namespace fs = boost::filesystem; +using namespace std; +using namespace tr1; + +enum Side { + SOURCE, + TARGET +}; + +class DataArray { + public: + static int END_OF_FILE; + static int END_OF_LINE; + static string END_OF_FILE_STR; + static string END_OF_LINE_STR; + + DataArray(); + + DataArray(const string& filename); + + DataArray(const string& filename, const Side& side); + + virtual ~DataArray(); + + virtual const vector& GetData() const; + + virtual int AtIndex(int index) const; + + virtual int GetSize() const; + + virtual int GetVocabularySize() const; + + virtual bool HasWord(const string& word) const; + + virtual int GetWordId(const string& word) const; + + string GetWord(int word_id) const; + + int GetNumSentences() const; + + int GetSentenceStart(int position) const; + + virtual int GetSentenceId(int position) const; + + void WriteBinary(const fs::path& filepath) const; + + void WriteBinary(FILE* file) const; + + private: + void InitializeDataArray(); + void CreateDataArray(const vector& lines); + + unordered_map word2id; + vector id2word; + vector data; + // TODO(pauldb): We only need sentence_id for the source language. Maybe we + // can save some memory here. + vector sentence_id; + vector sentence_start; +}; + +#endif diff --git a/extractor/data_array_test.cc b/extractor/data_array_test.cc new file mode 100644 index 00000000..772ba10e --- /dev/null +++ b/extractor/data_array_test.cc @@ -0,0 +1,70 @@ +#include + +#include +#include + +#include + +#include "data_array.h" + +using namespace std; +using namespace ::testing; +namespace fs = boost::filesystem; + +namespace { + +class DataArrayTest : public Test { + protected: + virtual void SetUp() { + string sample_test_file("sample_bitext.txt"); + source_data = make_shared(sample_test_file, SOURCE); + target_data = make_shared(sample_test_file, TARGET); + } + + shared_ptr source_data; + shared_ptr target_data; +}; + +TEST_F(DataArrayTest, TestGetData) { + vector expected_source_data{2, 3, 4, 5, 1, 2, 6, 7, 8, 5, 1}; + EXPECT_EQ(expected_source_data, source_data->GetData()); + EXPECT_EQ(expected_source_data.size(), source_data->GetSize()); + for (size_t i = 0; i < expected_source_data.size(); ++i) { + EXPECT_EQ(expected_source_data[i], source_data->AtIndex(i)); + } + + vector expected_target_data{2, 3, 4, 5, 1, 2, 6, 7, 8, 9, 10, 5, 1}; + EXPECT_EQ(expected_target_data, target_data->GetData()); + EXPECT_EQ(expected_target_data.size(), target_data->GetSize()); + for (size_t i = 0; i < expected_target_data.size(); ++i) { + EXPECT_EQ(expected_target_data[i], target_data->AtIndex(i)); + } +} + +TEST_F(DataArrayTest, TestVocabulary) { + EXPECT_EQ(9, source_data->GetVocabularySize()); + EXPECT_TRUE(source_data->HasWord("mere")); + EXPECT_EQ(4, source_data->GetWordId("mere")); + EXPECT_EQ("mere", source_data->GetWord(4)); + EXPECT_FALSE(source_data->HasWord("banane")); + + EXPECT_EQ(11, target_data->GetVocabularySize()); + EXPECT_TRUE(target_data->HasWord("apples")); + EXPECT_EQ(4, target_data->GetWordId("apples")); + EXPECT_EQ("apples", target_data->GetWord(4)); + EXPECT_FALSE(target_data->HasWord("bananas")); +} + +TEST_F(DataArrayTest, TestSentenceData) { + EXPECT_EQ(2, source_data->GetNumSentences()); + EXPECT_EQ(0, source_data->GetSentenceStart(0)); + EXPECT_EQ(5, source_data->GetSentenceStart(1)); + EXPECT_EQ(11, source_data->GetSentenceStart(2)); + + EXPECT_EQ(2, target_data->GetNumSentences()); + EXPECT_EQ(0, target_data->GetSentenceStart(0)); + EXPECT_EQ(5, target_data->GetSentenceStart(1)); + EXPECT_EQ(13, target_data->GetSentenceStart(2)); +} + +} // namespace diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc new file mode 100644 index 00000000..3014c2e9 --- /dev/null +++ b/extractor/grammar_extractor.cc @@ -0,0 +1,45 @@ +#include "grammar_extractor.h" + +#include +#include +#include + +using namespace std; + +vector Tokenize(const string& sentence) { + vector result; + result.push_back(""); + + istringstream buffer(sentence); + copy(istream_iterator(buffer), + istream_iterator(), + back_inserter(result)); + + result.push_back(""); + return result; +} + +GrammarExtractor::GrammarExtractor( + shared_ptr source_suffix_array, + shared_ptr target_data_array, + const Alignment& alignment, const Precomputation& precomputation, + int min_gap_size, int max_rule_span, int max_nonterminals, + int max_rule_symbols, bool use_baeza_yates) : + vocabulary(make_shared()), + rule_factory(source_suffix_array, target_data_array, alignment, + vocabulary, precomputation, min_gap_size, max_rule_span, + max_nonterminals, max_rule_symbols, use_baeza_yates) {} + +void GrammarExtractor::GetGrammar(const string& sentence) { + vector words = Tokenize(sentence); + vector word_ids = AnnotateWords(words); + rule_factory.GetGrammar(word_ids); +} + +vector GrammarExtractor::AnnotateWords(const vector& words) { + vector result; + for (string word: words) { + result.push_back(vocabulary->GetTerminalIndex(word)); + } + return result; +} diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h new file mode 100644 index 00000000..05e153fc --- /dev/null +++ b/extractor/grammar_extractor.h @@ -0,0 +1,39 @@ +#ifndef _GRAMMAR_EXTRACTOR_H_ +#define _GRAMMAR_EXTRACTOR_H_ + +#include +#include + +#include "rule_factory.h" +#include "vocabulary.h" + +using namespace std; + +class Alignment; +class DataArray; +class Precomputation; +class SuffixArray; + +class GrammarExtractor { + public: + GrammarExtractor( + shared_ptr source_suffix_array, + shared_ptr target_data_array, + const Alignment& alignment, + const Precomputation& precomputation, + int min_gap_size, + int max_rule_span, + int max_nonterminals, + int max_rule_symbols, + bool use_baeza_yates); + + void GetGrammar(const string& sentence); + + private: + vector AnnotateWords(const vector& words); + + shared_ptr vocabulary; + HieroCachingRuleFactory rule_factory; +}; + +#endif diff --git a/extractor/intersector.cc b/extractor/intersector.cc new file mode 100644 index 00000000..9d9b54c0 --- /dev/null +++ b/extractor/intersector.cc @@ -0,0 +1,129 @@ +#include "intersector.h" + +#include "data_array.h" +#include "matching_comparator.h" +#include "phrase.h" +#include "phrase_location.h" +#include "precomputation.h" +#include "suffix_array.h" +#include "veb.h" +#include "vocabulary.h" + +Intersector::Intersector(shared_ptr vocabulary, + const Precomputation& precomputation, + shared_ptr suffix_array, + shared_ptr comparator, + bool use_baeza_yates) : + vocabulary(vocabulary), + suffix_array(suffix_array), + use_baeza_yates(use_baeza_yates) { + linear_merger = make_shared( + vocabulary, suffix_array->GetData(), comparator); + binary_search_merger = make_shared( + vocabulary, linear_merger, suffix_array->GetData(), comparator); + + shared_ptr source_data_array = suffix_array->GetData(); + + const Index& precomputed_index = precomputation.GetInvertedIndex(); + for (pair, vector > entry: precomputed_index) { + vector phrase = Convert(entry.first, source_data_array); + inverted_index[phrase] = entry.second; + } + + const Index& precomputed_collocations = precomputation.GetCollocations(); + for (pair, vector > entry: precomputed_collocations) { + vector phrase = Convert(entry.first, source_data_array); + collocations[phrase] = entry.second; + } +} + +vector Intersector::Convert( + const vector& old_phrase, shared_ptr source_data_array) { + vector new_phrase; + new_phrase.reserve(old_phrase.size()); + + int arity = 0; + for (int word_id: old_phrase) { + if (word_id == Precomputation::NON_TERMINAL) { + ++arity; + new_phrase.push_back(vocabulary->GetNonterminalIndex(arity)); + } else { + new_phrase.push_back( + vocabulary->GetTerminalIndex(source_data_array->GetWord(word_id))); + } + } + + return new_phrase; +} + +PhraseLocation Intersector::Intersect( + const Phrase& prefix, PhraseLocation& prefix_location, + const Phrase& suffix, PhraseLocation& suffix_location, + const Phrase& phrase) { + vector symbols = phrase.Get(); + + // We should never attempt to do an intersect query for a pattern starting or + // ending with a non terminal. The RuleFactory should handle these cases, + // initializing the matchings list with the one for the pattern without the + // starting or ending terminal. + assert(vocabulary->IsTerminal(symbols.front()) + && vocabulary->IsTerminal(symbols.back())); + + if (collocations.count(symbols)) { + return PhraseLocation(make_shared >(collocations[symbols]), + phrase.Arity()); + } + + vector locations; + ExtendPhraseLocation(prefix, prefix_location); + ExtendPhraseLocation(suffix, suffix_location); + shared_ptr > prefix_matchings = prefix_location.matchings; + shared_ptr > suffix_matchings = suffix_location.matchings; + int prefix_subpatterns = prefix_location.num_subpatterns; + int suffix_subpatterns = prefix_location.num_subpatterns; + if (use_baeza_yates) { + binary_search_merger->Merge(locations, phrase, suffix, + prefix_matchings->begin(), prefix_matchings->end(), + suffix_matchings->begin(), suffix_matchings->end(), + prefix_subpatterns, suffix_subpatterns); + } else { + linear_merger->Merge(locations, phrase, suffix, prefix_matchings->begin(), + prefix_matchings->end(), suffix_matchings->begin(), + suffix_matchings->end(), prefix_subpatterns, suffix_subpatterns); + } + return PhraseLocation(shared_ptr >(new vector(locations)), + phrase.Arity() + 1); +} + +void Intersector::ExtendPhraseLocation( + const Phrase& phrase, PhraseLocation& phrase_location) { + int low = phrase_location.sa_low, high = phrase_location.sa_high; + if (phrase.Arity() || phrase_location.num_subpatterns || + phrase_location.IsEmpty()) { + return; + } + + phrase_location.num_subpatterns = 1; + + vector symbols = phrase.Get(); + if (inverted_index.count(symbols)) { + phrase_location.matchings = + make_shared >(inverted_index[symbols]); + return; + } + + vector matchings; + matchings.reserve(high - low + 1); + shared_ptr veb = VEB::Create(suffix_array->GetSize()); + for (int i = low; i < high; ++i) { + veb->Insert(suffix_array->GetSuffix(i)); + } + + int value = veb->GetMinimum(); + while (value != -1) { + matchings.push_back(value); + value = veb->GetSuccessor(value); + } + + phrase_location.matchings = make_shared >(matchings); +} diff --git a/extractor/intersector.h b/extractor/intersector.h new file mode 100644 index 00000000..874ffc1b --- /dev/null +++ b/extractor/intersector.h @@ -0,0 +1,57 @@ +#ifndef _INTERSECTOR_H_ +#define _INTERSECTOR_H_ + +#include +#include +#include + +#include + +#include "binary_search_merger.h" +#include "linear_merger.h" + +using namespace std; +using namespace tr1; + +typedef boost::hash > vector_hash; +typedef unordered_map, vector, vector_hash> Index; + +class DataArray; +class MatchingComparator; +class Phrase; +class PhraseLocation; +class Precomputation; +class SuffixArray; +class Vocabulary; + +class Intersector { + public: + Intersector( + shared_ptr vocabulary, + const Precomputation& precomputaiton, + shared_ptr source_suffix_array, + shared_ptr comparator, + bool use_baeza_yates); + + PhraseLocation Intersect( + const Phrase& prefix, PhraseLocation& prefix_location, + const Phrase& suffix, PhraseLocation& suffix_location, + const Phrase& phrase); + + private: + vector Convert(const vector& old_phrase, + shared_ptr source_data_array); + + void ExtendPhraseLocation(const Phrase& phrase, + PhraseLocation& phrase_location); + + shared_ptr vocabulary; + shared_ptr suffix_array; + shared_ptr linear_merger; + shared_ptr binary_search_merger; + Index inverted_index; + Index collocations; + bool use_baeza_yates; +}; + +#endif diff --git a/extractor/linear_merger.cc b/extractor/linear_merger.cc new file mode 100644 index 00000000..59e5f34c --- /dev/null +++ b/extractor/linear_merger.cc @@ -0,0 +1,63 @@ +#include "linear_merger.h" + +#include + +#include "data_array.h" +#include "matching.h" +#include "matching_comparator.h" +#include "phrase.h" +#include "phrase_location.h" +#include "vocabulary.h" + +LinearMerger::LinearMerger(shared_ptr vocabulary, + shared_ptr data_array, + shared_ptr comparator) : + vocabulary(vocabulary), data_array(data_array), comparator(comparator) {} + +LinearMerger::~LinearMerger() {} + +void LinearMerger::Merge( + vector& locations, const Phrase& phrase, const Phrase& suffix, + vector::iterator prefix_start, vector::iterator prefix_end, + vector::iterator suffix_start, vector::iterator suffix_end, + int prefix_subpatterns, int suffix_subpatterns) const { + int last_chunk_len = suffix.GetChunkLen(suffix.Arity()); + bool offset = !vocabulary->IsTerminal(suffix.GetSymbol(0)); + + while (prefix_start != prefix_end) { + Matching left(prefix_start, prefix_subpatterns, + data_array->GetSentenceId(*prefix_start)); + + while (suffix_start != suffix_end) { + Matching right(suffix_start, suffix_subpatterns, + data_array->GetSentenceId(*suffix_start)); + if (comparator->Compare(left, right, last_chunk_len, offset) > 0) { + suffix_start += suffix_subpatterns; + } else { + break; + } + } + + int start_position = *prefix_start; + vector :: iterator i = suffix_start; + while (prefix_start != prefix_end && *prefix_start == start_position) { + Matching left(prefix_start, prefix_subpatterns, + data_array->GetSentenceId(*prefix_start)); + + while (i != suffix_end) { + Matching right(i, suffix_subpatterns, data_array->GetSentenceId(*i)); + int comparison = comparator->Compare(left, right, last_chunk_len, + offset); + if (comparison == 0) { + vector merged = left.Merge(right, phrase.Arity() + 1); + locations.insert(locations.end(), merged.begin(), merged.end()); + } else if (comparison < 0) { + break; + } + i += suffix_subpatterns; + } + + prefix_start += prefix_subpatterns; + } + } +} diff --git a/extractor/linear_merger.h b/extractor/linear_merger.h new file mode 100644 index 00000000..7bfb9246 --- /dev/null +++ b/extractor/linear_merger.h @@ -0,0 +1,35 @@ +#ifndef _LINEAR_MERGER_H_ +#define _LINEAR_MERGER_H_ + +#include +#include + +using namespace std; + +class MatchingComparator; +class Phrase; +class PhraseLocation; +class DataArray; +class Vocabulary; + +class LinearMerger { + public: + LinearMerger(shared_ptr vocabulary, + shared_ptr data_array, + shared_ptr comparator); + + virtual ~LinearMerger(); + + virtual void Merge( + vector& locations, const Phrase& phrase, const Phrase& suffix, + vector::iterator prefix_start, vector::iterator prefix_end, + vector::iterator suffix_start, vector::iterator suffix_end, + int prefix_subpatterns, int suffix_subpatterns) const; + + private: + shared_ptr vocabulary; + shared_ptr data_array; + shared_ptr comparator; +}; + +#endif diff --git a/extractor/linear_merger_test.cc b/extractor/linear_merger_test.cc new file mode 100644 index 00000000..a6963430 --- /dev/null +++ b/extractor/linear_merger_test.cc @@ -0,0 +1,149 @@ +#include + +#include + +#include "linear_merger.h" +#include "matching_comparator.h" +#include "mocks/mock_data_array.h" +#include "mocks/mock_vocabulary.h" +#include "phrase.h" +#include "phrase_location.h" +#include "phrase_builder.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class LinearMergerTest : public Test { + protected: + virtual void SetUp() { + shared_ptr vocabulary = make_shared(); + EXPECT_CALL(*vocabulary, GetTerminalValue(_)) + .WillRepeatedly(Return("word")); + + shared_ptr data_array = make_shared(); + EXPECT_CALL(*data_array, GetSentenceId(_)) + .WillRepeatedly(Return(1)); + + shared_ptr comparator = + make_shared(1, 20); + + phrase_builder = make_shared(vocabulary); + linear_merger = make_shared(vocabulary, data_array, + comparator); + } + + shared_ptr linear_merger; + shared_ptr phrase_builder; +}; + +TEST_F(LinearMergerTest, aXbTest) { + vector locations; + // Encoding for him X it (see Adam's dissertation). + vector symbols{1, -1, 2}; + Phrase phrase = phrase_builder->Build(symbols); + vector suffix_symbols{-1, 2}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + + vector prefix_locs{2, 6, 10, 15}; + vector suffix_locs{0, 4, 8, 13}; + + linear_merger->Merge(locations, phrase, suffix, prefix_locs.begin(), + prefix_locs.end(), suffix_locs.begin(), suffix_locs.end(), 1, 1); + + vector expected_locations{2, 4, 2, 8, 2, 13, 6, 8, 6, 13, 10, 13}; + EXPECT_EQ(expected_locations, locations); +} + +TEST_F(LinearMergerTest, aXbXcTest) { + vector locations; + // Encoding for it X him X it (see Adam's dissertation). + vector symbols{1, -1, 2, -2, 1}; + Phrase phrase = phrase_builder->Build(symbols); + vector suffix_symbols{-1, 2, -2, 1}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + + vector prefix_locs{0, 2, 0, 6, 0, 10, 4, 6, 4, 10, 4, 15, 8, 10, 8, 15, + 13, 15}; + vector suffix_locs{2, 4, 2, 8, 2, 13, 6, 8, 6, 13, 10, 13}; + + linear_merger->Merge(locations, phrase, suffix, prefix_locs.begin(), + prefix_locs.end(), suffix_locs.begin(), suffix_locs.end(), 2, 2); + + vector expected_locs{0, 2, 4, 0, 2, 8, 0, 2, 13, 0, 6, 8, 0, 6, 13, 0, + 10, 13, 4, 6, 8, 4, 6, 13, 4, 10, 13, 8, 10, 13}; + EXPECT_EQ(expected_locs, locations); +} + +TEST_F(LinearMergerTest, abXcXdTest) { + // Sentence: Anna has many many nuts and sour apples and juicy apples. + // Phrase: Anna has X and X apples. + vector locations; + vector symbols{1, 2, -1, 3, -2, 4}; + Phrase phrase = phrase_builder->Build(symbols); + vector suffix_symbols{2, -1, 3, -2, 4}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + + vector prefix_locs{1, 6, 1, 9}; + vector suffix_locs{2, 6, 8, 2, 6, 11, 2, 9, 11}; + + linear_merger->Merge(locations, phrase, suffix, prefix_locs.begin(), + prefix_locs.end(), suffix_locs.begin(), suffix_locs.end(), 2, 3); + + vector expected_locs{1, 6, 8, 1, 6, 11, 1, 9, 11}; + EXPECT_EQ(expected_locs, locations); +} + +TEST_F(LinearMergerTest, LargeTest) { + vector locations; + vector symbols{1, -1, 2}; + Phrase phrase = phrase_builder->Build(symbols); + vector suffix_symbols{-1, 2}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + + vector prefix_locs; + for (int i = 0; i < 100; ++i) { + prefix_locs.push_back(i * 20 + 1); + } + vector suffix_locs; + for (int i = 0; i < 100; ++i) { + suffix_locs.push_back(i * 20 + 5); + suffix_locs.push_back(i * 20 + 13); + } + + linear_merger->Merge(locations, phrase, suffix, prefix_locs.begin(), + prefix_locs.end(), suffix_locs.begin(), suffix_locs.end(), 1, 1); + + EXPECT_EQ(400, locations.size()); + for (int i = 0; i < 100; ++i) { + EXPECT_EQ(i * 20 + 1, locations[4 * i]); + EXPECT_EQ(i * 20 + 5, locations[4 * i + 1]); + EXPECT_EQ(i * 20 + 1, locations[4 * i + 2]); + EXPECT_EQ(i * 20 + 13, locations[4 * i + 3]); + } +} + +TEST_F(LinearMergerTest, EmptyResultTest) { + vector locations; + vector symbols{1, -1, 2}; + Phrase phrase = phrase_builder->Build(symbols); + vector suffix_symbols{-1, 2}; + Phrase suffix = phrase_builder->Build(suffix_symbols); + + vector prefix_locs; + for (int i = 0; i < 100; ++i) { + prefix_locs.push_back(i * 200 + 1); + } + vector suffix_locs; + for (int i = 0; i < 100; ++i) { + suffix_locs.push_back(i * 200 + 101); + } + + linear_merger->Merge(locations, phrase, suffix, prefix_locs.begin(), + prefix_locs.end(), suffix_locs.begin(), suffix_locs.end(), 1, 1); + + EXPECT_EQ(0, locations.size()); +} + +} // namespace diff --git a/extractor/matching.cc b/extractor/matching.cc new file mode 100644 index 00000000..16a3ed6f --- /dev/null +++ b/extractor/matching.cc @@ -0,0 +1,12 @@ +#include "matching.h" + +Matching::Matching(vector::iterator start, int len, int sentence_id) : + positions(start, start + len), sentence_id(sentence_id) {} + +vector Matching::Merge(const Matching& other, int num_subpatterns) const { + vector result = positions; + if (num_subpatterns > positions.size()) { + result.push_back(other.positions.back()); + } + return result; +} diff --git a/extractor/matching.h b/extractor/matching.h new file mode 100644 index 00000000..4c46559e --- /dev/null +++ b/extractor/matching.h @@ -0,0 +1,18 @@ +#ifndef _MATCHING_H_ +#define _MATCHING_H_ + +#include +#include + +using namespace std; + +struct Matching { + Matching(vector::iterator start, int len, int sentence_id); + + vector Merge(const Matching& other, int num_subpatterns) const; + + vector positions; + int sentence_id; +}; + +#endif diff --git a/extractor/matching_comparator.cc b/extractor/matching_comparator.cc new file mode 100644 index 00000000..03db95c0 --- /dev/null +++ b/extractor/matching_comparator.cc @@ -0,0 +1,46 @@ +#include "matching_comparator.h" + +#include "matching.h" +#include "vocabulary.h" + +MatchingComparator::MatchingComparator(int min_gap_size, int max_rule_span) : + min_gap_size(min_gap_size), max_rule_span(max_rule_span) {} + +int MatchingComparator::Compare(const Matching& left, + const Matching& right, + int last_chunk_len, + bool offset) const { + if (left.sentence_id != right.sentence_id) { + return left.sentence_id < right.sentence_id ? -1 : 1; + } + + if (left.positions.size() == 1 && right.positions.size() == 1) { + // The prefix and the suffix must be separated by a non-terminal, otherwise + // we would be doing a suffix array lookup. + if (right.positions[0] - left.positions[0] <= min_gap_size) { + return 1; + } + } else if (offset) { + for (size_t i = 1; i < left.positions.size(); ++i) { + if (left.positions[i] != right.positions[i - 1]) { + return left.positions[i] < right.positions[i - 1] ? -1 : 1; + } + } + } else { + if (left.positions[0] + 1 != right.positions[0]) { + return left.positions[0] + 1 < right.positions[0] ? -1 : 1; + } + for (size_t i = 1; i < left.positions.size(); ++i) { + if (left.positions[i] != right.positions[i]) { + return left.positions[i] < right.positions[i] ? -1 : 1; + } + } + } + + if (right.positions.back() + last_chunk_len - left.positions.front() > + max_rule_span) { + return -1; + } + + return 0; +} diff --git a/extractor/matching_comparator.h b/extractor/matching_comparator.h new file mode 100644 index 00000000..6e1bb487 --- /dev/null +++ b/extractor/matching_comparator.h @@ -0,0 +1,23 @@ +#ifndef _MATCHING_COMPARATOR_H_ +#define _MATCHING_COMPARATOR_H_ + +#include + +using namespace std; + +class Vocabulary; +class Matching; + +class MatchingComparator { + public: + MatchingComparator(int min_gap_size, int max_rule_span); + + int Compare(const Matching& left, const Matching& right, + int last_chunk_len, bool offset) const; + + private: + int min_gap_size; + int max_rule_span; +}; + +#endif diff --git a/extractor/matching_comparator_test.cc b/extractor/matching_comparator_test.cc new file mode 100644 index 00000000..b8f898cf --- /dev/null +++ b/extractor/matching_comparator_test.cc @@ -0,0 +1,139 @@ +#include + +#include "matching.h" +#include "matching_comparator.h" + +using namespace ::testing; + +namespace { + +class MatchingComparatorTest : public Test { + protected: + virtual void SetUp() { + comparator = make_shared(1, 20); + } + + shared_ptr comparator; +}; + +TEST_F(MatchingComparatorTest, SmallerSentenceId) { + vector left_locations{1}; + Matching left(left_locations.begin(), 1, 1); + vector right_locations{100}; + Matching right(right_locations.begin(), 1, 5); + EXPECT_EQ(-1, comparator->Compare(left, right, 1, true)); +} + +TEST_F(MatchingComparatorTest, GreaterSentenceId) { + vector left_locations{100}; + Matching left(left_locations.begin(), 1, 5); + vector right_locations{1}; + Matching right(right_locations.begin(), 1, 1); + EXPECT_EQ(1, comparator->Compare(left, right, 1, true)); +} + +TEST_F(MatchingComparatorTest, SmalleraXb) { + vector left_locations{1}; + Matching left(left_locations.begin(), 1, 1); + vector right_locations{21}; + Matching right(right_locations.begin(), 1, 1); + // The matching exceeds the max rule span. + EXPECT_EQ(-1, comparator->Compare(left, right, 1, true)); +} + +TEST_F(MatchingComparatorTest, EqualaXb) { + vector left_locations{1}; + Matching left(left_locations.begin(), 1, 1); + vector lower_right_locations{3}; + Matching right(lower_right_locations.begin(), 1, 1); + EXPECT_EQ(0, comparator->Compare(left, right, 1, true)); + + vector higher_right_locations{20}; + right = Matching(higher_right_locations.begin(), 1, 1); + EXPECT_EQ(0, comparator->Compare(left, right, 1, true)); +} + +TEST_F(MatchingComparatorTest, GreateraXb) { + vector left_locations{1}; + Matching left(left_locations.begin(), 1, 1); + vector right_locations{2}; + Matching right(right_locations.begin(), 1, 1); + // The gap between the prefix and the suffix is of size 0, less than the + // min gap size. + EXPECT_EQ(1, comparator->Compare(left, right, 1, true)); +} + +TEST_F(MatchingComparatorTest, SmalleraXbXc) { + vector left_locations{1, 3}; + Matching left(left_locations.begin(), 2, 1); + vector right_locations{4, 6}; + // The common part doesn't match. + Matching right(right_locations.begin(), 2, 1); + EXPECT_EQ(-1, comparator->Compare(left, right, 1, true)); + + // The common part matches, but the rule exceeds the max span. + vector other_right_locations{3, 21}; + right = Matching(other_right_locations.begin(), 2, 1); + EXPECT_EQ(-1, comparator->Compare(left, right, 1, true)); +} + +TEST_F(MatchingComparatorTest, EqualaXbXc) { + vector left_locations{1, 3}; + Matching left(left_locations.begin(), 2, 1); + vector right_locations{3, 5}; + // The leftmost possible match. + Matching right(right_locations.begin(), 2, 1); + EXPECT_EQ(0, comparator->Compare(left, right, 1, true)); + + // The rightmost possible match. + vector other_right_locations{3, 20}; + right = Matching(other_right_locations.begin(), 2, 1); + EXPECT_EQ(0, comparator->Compare(left, right, 1, true)); +} + +TEST_F(MatchingComparatorTest, GreateraXbXc) { + vector left_locations{1, 4}; + Matching left(left_locations.begin(), 2, 1); + vector right_locations{3, 5}; + // The common part doesn't match. + Matching right(right_locations.begin(), 2, 1); + EXPECT_EQ(1, comparator->Compare(left, right, 1, true)); +} + +TEST_F(MatchingComparatorTest, SmallerabXcXd) { + vector left_locations{9, 13}; + Matching left(left_locations.begin(), 2, 1); + // The suffix doesn't start on the next position. + vector right_locations{11, 13, 15}; + Matching right(right_locations.begin(), 3, 1); + EXPECT_EQ(-1, comparator->Compare(left, right, 1, false)); + + // The common part doesn't match. + vector other_right_locations{10, 16, 18}; + right = Matching(other_right_locations.begin(), 3, 1); + EXPECT_EQ(-1, comparator->Compare(left, right, 1, false)); +} + +TEST_F(MatchingComparatorTest, EqualabXcXd) { + vector left_locations{10, 13}; + Matching left(left_locations.begin(), 2, 1); + vector right_locations{11, 13, 15}; + Matching right(right_locations.begin(), 3, 1); + EXPECT_EQ(0, comparator->Compare(left, right, 1, false)); +} + +TEST_F(MatchingComparatorTest, GreaterabXcXd) { + vector left_locations{9, 15}; + Matching left(left_locations.begin(), 2, 1); + // The suffix doesn't start on the next position. + vector right_locations{7, 15, 17}; + Matching right(right_locations.begin(), 3, 1); + EXPECT_EQ(1, comparator->Compare(left, right, 1, false)); + + // The common part doesn't match. + vector other_right_locations{10, 13, 16}; + right = Matching(other_right_locations.begin(), 3, 1); + EXPECT_EQ(1, comparator->Compare(left, right, 1, false)); +} + +} // namespace diff --git a/extractor/matching_test.cc b/extractor/matching_test.cc new file mode 100644 index 00000000..9593aa86 --- /dev/null +++ b/extractor/matching_test.cc @@ -0,0 +1,25 @@ +#include + +#include + +#include "matching.h" + +using namespace std; + +namespace { + +TEST(MatchingTest, SameSize) { + vector positions{1, 2, 3}; + Matching left(positions.begin(), positions.size(), 0); + Matching right(positions.begin(), positions.size(), 0); + EXPECT_EQ(positions, left.Merge(right, positions.size())); +} + +TEST(MatchingTest, DifferentSize) { + vector positions{1, 2, 3}; + Matching left(positions.begin(), positions.size() - 1, 0); + Matching right(positions.begin() + 1, positions.size() - 1, 0); + vector result = left.Merge(right, positions.size()); +} + +} // namespace diff --git a/extractor/matchings_finder.cc b/extractor/matchings_finder.cc new file mode 100644 index 00000000..ba4edab1 --- /dev/null +++ b/extractor/matchings_finder.cc @@ -0,0 +1,17 @@ +#include "matchings_finder.h" + +#include "suffix_array.h" +#include "phrase_location.h" + +MatchingsFinder::MatchingsFinder(shared_ptr suffix_array) : + suffix_array(suffix_array) {} + +PhraseLocation MatchingsFinder::Find(PhraseLocation& location, + const string& word, int offset) { + if (location.sa_low == -1 && location.sa_high == -1) { + location.sa_low = 0; + location.sa_high = suffix_array->GetSize(); + } + + return suffix_array->Lookup(location.sa_low, location.sa_high, word, offset); +} diff --git a/extractor/matchings_finder.h b/extractor/matchings_finder.h new file mode 100644 index 00000000..0458a4d8 --- /dev/null +++ b/extractor/matchings_finder.h @@ -0,0 +1,22 @@ +#ifndef _MATCHINGS_FINDER_H_ +#define _MATCHINGS_FINDER_H_ + +#include +#include + +using namespace std; + +class PhraseLocation; +class SuffixArray; + +class MatchingsFinder { + public: + MatchingsFinder(shared_ptr suffix_array); + + PhraseLocation Find(PhraseLocation& location, const string& word, int offset); + + private: + shared_ptr suffix_array; +}; + +#endif diff --git a/extractor/matchings_finder_test.cc b/extractor/matchings_finder_test.cc new file mode 100644 index 00000000..817f1635 --- /dev/null +++ b/extractor/matchings_finder_test.cc @@ -0,0 +1,42 @@ +#include + +#include + +#include "matchings_finder.h" +#include "mocks/mock_suffix_array.h" +#include "phrase_location.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class MatchingsFinderTest : public Test { + protected: + virtual void SetUp() { + suffix_array = make_shared(); + EXPECT_CALL(*suffix_array, Lookup(0, 10, _, _)) + .Times(1) + .WillOnce(Return(PhraseLocation(3, 5))); + + matchings_finder = make_shared(suffix_array); + } + + shared_ptr matchings_finder; + shared_ptr suffix_array; +}; + +TEST_F(MatchingsFinderTest, TestFind) { + PhraseLocation phrase_location(0, 10), expected_result(3, 5); + EXPECT_EQ(expected_result, matchings_finder->Find(phrase_location, "bla", 2)); +} + +TEST_F(MatchingsFinderTest, ResizeUnsetRange) { + EXPECT_CALL(*suffix_array, GetSize()).Times(1).WillOnce(Return(10)); + + PhraseLocation phrase_location, expected_result(3, 5); + EXPECT_EQ(expected_result, matchings_finder->Find(phrase_location, "bla", 2)); + EXPECT_EQ(PhraseLocation(0, 10), phrase_location); +} + +} // namespace diff --git a/extractor/matchings_trie.cc b/extractor/matchings_trie.cc new file mode 100644 index 00000000..851d4596 --- /dev/null +++ b/extractor/matchings_trie.cc @@ -0,0 +1,11 @@ +#include "matchings_trie.h" + +void MatchingsTrie::Reset() { + // TODO(pauldb): This is probably memory leaking because of the suffix links. + // Check if it's true and free the memory properly. + root.reset(new TrieNode()); +} + +shared_ptr MatchingsTrie::GetRoot() const { + return root; +} diff --git a/extractor/matchings_trie.h b/extractor/matchings_trie.h new file mode 100644 index 00000000..f935d1a9 --- /dev/null +++ b/extractor/matchings_trie.h @@ -0,0 +1,46 @@ +#ifndef _MATCHINGS_TRIE_ +#define _MATCHINGS_TRIE_ + +#include +#include + +#include "phrase.h" +#include "phrase_location.h" + +using namespace std; +using namespace tr1; + +struct TrieNode { + TrieNode(shared_ptr suffix_link = shared_ptr(), + Phrase phrase = Phrase(), + PhraseLocation matchings = PhraseLocation()) : + suffix_link(suffix_link), phrase(phrase), matchings(matchings) {} + + void AddChild(int key, shared_ptr child_node) { + children[key] = child_node; + } + + bool HasChild(int key) { + return children.count(key); + } + + shared_ptr GetChild(int key) { + return children[key]; + } + + shared_ptr suffix_link; + Phrase phrase; + PhraseLocation matchings; + unordered_map > children; +}; + +class MatchingsTrie { + public: + void Reset(); + shared_ptr GetRoot() const; + + private: + shared_ptr root; +}; + +#endif diff --git a/extractor/mocks/mock_data_array.h b/extractor/mocks/mock_data_array.h new file mode 100644 index 00000000..cda8f7a6 --- /dev/null +++ b/extractor/mocks/mock_data_array.h @@ -0,0 +1,14 @@ +#include + +#include "../data_array.h" + +class MockDataArray : public DataArray { + public: + MOCK_CONST_METHOD0(GetData, const vector&()); + MOCK_CONST_METHOD1(AtIndex, int(int index)); + MOCK_CONST_METHOD0(GetSize, int()); + MOCK_CONST_METHOD0(GetVocabularySize, int()); + MOCK_CONST_METHOD1(HasWord, bool(const string& word)); + MOCK_CONST_METHOD1(GetWordId, int(const string& word)); + MOCK_CONST_METHOD1(GetSentenceId, int(int position)); +}; diff --git a/extractor/mocks/mock_linear_merger.h b/extractor/mocks/mock_linear_merger.h new file mode 100644 index 00000000..0defa88a --- /dev/null +++ b/extractor/mocks/mock_linear_merger.h @@ -0,0 +1,21 @@ +#include + +#include + +#include "linear_merger.h" +#include "phrase.h" + +using namespace std; + +class MockLinearMerger: public LinearMerger { + public: + MockLinearMerger(shared_ptr vocabulary, + shared_ptr data_array, + shared_ptr comparator) : + LinearMerger(vocabulary, data_array, comparator) {} + + + MOCK_CONST_METHOD9(Merge, void(vector&, const Phrase&, const Phrase&, + vector::iterator, vector::iterator, vector::iterator, + vector::iterator, int, int)); +}; diff --git a/extractor/mocks/mock_suffix_array.h b/extractor/mocks/mock_suffix_array.h new file mode 100644 index 00000000..38d8bad6 --- /dev/null +++ b/extractor/mocks/mock_suffix_array.h @@ -0,0 +1,17 @@ +#include + +#include + +#include "../data_array.h" +#include "../phrase_location.h" +#include "../suffix_array.h" + +using namespace std; + +class MockSuffixArray : public SuffixArray { + public: + MockSuffixArray() : SuffixArray(make_shared()) {} + + MOCK_CONST_METHOD0(GetSize, int()); + MOCK_CONST_METHOD4(Lookup, PhraseLocation(int, int, const string& word, int)); +}; diff --git a/extractor/mocks/mock_vocabulary.h b/extractor/mocks/mock_vocabulary.h new file mode 100644 index 00000000..06dea10f --- /dev/null +++ b/extractor/mocks/mock_vocabulary.h @@ -0,0 +1,8 @@ +#include + +#include "../vocabulary.h" + +class MockVocabulary : public Vocabulary { + public: + MOCK_METHOD1(GetTerminalValue, string(int word_id)); +}; diff --git a/extractor/phrase.cc b/extractor/phrase.cc new file mode 100644 index 00000000..f9bd9908 --- /dev/null +++ b/extractor/phrase.cc @@ -0,0 +1,25 @@ +#include "phrase.h" + +int Phrase::Arity() const { + return var_pos.size(); +} + +int Phrase::GetChunkLen(int index) const { + if (var_pos.size() == 0) { + return symbols.size(); + } else if (index == 0) { + return var_pos[0]; + } else if (index == var_pos.size()) { + return symbols.size() - var_pos.back() - 1; + } else { + return var_pos[index] - var_pos[index - 1] - 1; + } +} + +vector Phrase::Get() const { + return symbols; +} + +int Phrase::GetSymbol(int position) const { + return symbols[position]; +} diff --git a/extractor/phrase.h b/extractor/phrase.h new file mode 100644 index 00000000..5a5124d9 --- /dev/null +++ b/extractor/phrase.h @@ -0,0 +1,29 @@ +#ifndef _PHRASE_H_ +#define _PHRASE_H_ + +#include +#include + +#include "phrase_builder.h" + +using namespace std; + +class Phrase { + public: + friend Phrase PhraseBuilder::Build(const vector& phrase); + + int Arity() const; + + int GetChunkLen(int index) const; + + vector Get() const; + + int GetSymbol(int position) const; + + private: + vector symbols; + vector var_pos; + vector words; +}; + +#endif diff --git a/extractor/phrase_builder.cc b/extractor/phrase_builder.cc new file mode 100644 index 00000000..7f3447e5 --- /dev/null +++ b/extractor/phrase_builder.cc @@ -0,0 +1,21 @@ +#include "phrase_builder.h" + +#include "phrase.h" +#include "vocabulary.h" + +PhraseBuilder::PhraseBuilder(shared_ptr vocabulary) : + vocabulary(vocabulary) {} + +Phrase PhraseBuilder::Build(const vector& symbols) { + Phrase phrase; + phrase.symbols = symbols; + phrase.words.resize(symbols.size()); + for (size_t i = 0; i < symbols.size(); ++i) { + if (vocabulary->IsTerminal(symbols[i])) { + phrase.words[i] = vocabulary->GetTerminalValue(symbols[i]); + } else { + phrase.var_pos.push_back(i); + } + } + return phrase; +} diff --git a/extractor/phrase_builder.h b/extractor/phrase_builder.h new file mode 100644 index 00000000..f01cb23b --- /dev/null +++ b/extractor/phrase_builder.h @@ -0,0 +1,22 @@ +#ifndef _PHRASE_BUILDER_H_ +#define _PHRASE_BUILDER_H_ + +#include +#include + +using namespace std; + +class Phrase; +class Vocabulary; + +class PhraseBuilder { + public: + PhraseBuilder(shared_ptr vocabulary); + + Phrase Build(const vector& symbols); + + private: + shared_ptr vocabulary; +}; + +#endif diff --git a/extractor/phrase_location.cc b/extractor/phrase_location.cc new file mode 100644 index 00000000..b5b68549 --- /dev/null +++ b/extractor/phrase_location.cc @@ -0,0 +1,35 @@ +#include "phrase_location.h" + +#include + +PhraseLocation::PhraseLocation(int sa_low, int sa_high) : + sa_low(sa_low), sa_high(sa_high), + matchings(shared_ptr >()), + num_subpatterns(0) {} + +PhraseLocation::PhraseLocation(shared_ptr > matchings, + int num_subpatterns) : + sa_high(0), sa_low(0), + matchings(matchings), + num_subpatterns(num_subpatterns) {} + +bool PhraseLocation::IsEmpty() { + return sa_low >= sa_high || (num_subpatterns > 0 && matchings->size() == 0); +} + +bool operator==(const PhraseLocation& a, const PhraseLocation& b) { + if (a.sa_low != b.sa_low || a.sa_high != b.sa_high || + a.num_subpatterns != b.num_subpatterns) { + return false; + } + + if (a.matchings == NULL && b.matchings == NULL) { + return true; + } + + if (a.matchings == NULL || b.matchings == NULL) { + return false; + } + + return *a.matchings == *b.matchings; +} diff --git a/extractor/phrase_location.h b/extractor/phrase_location.h new file mode 100644 index 00000000..96004b33 --- /dev/null +++ b/extractor/phrase_location.h @@ -0,0 +1,23 @@ +#ifndef _PHRASE_LOCATION_H_ +#define _PHRASE_LOCATION_H_ + +#include +#include + +using namespace std; + +struct PhraseLocation { + PhraseLocation(int sa_low = -1, int sa_high = -1); + + PhraseLocation(shared_ptr > matchings, int num_subpatterns); + + bool IsEmpty(); + + friend bool operator==(const PhraseLocation& a, const PhraseLocation& b); + + int sa_low, sa_high; + shared_ptr > matchings; + int num_subpatterns; +}; + +#endif diff --git a/extractor/phrase_test.cc b/extractor/phrase_test.cc new file mode 100644 index 00000000..2b553b6f --- /dev/null +++ b/extractor/phrase_test.cc @@ -0,0 +1,61 @@ +#include + +#include +#include + +#include "mocks/mock_vocabulary.h" +#include "phrase.h" +#include "phrase_builder.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class PhraseTest : public Test { + protected: + virtual void SetUp() { + shared_ptr vocabulary = make_shared(); + EXPECT_CALL(*vocabulary, GetTerminalValue(_)) + .WillRepeatedly(Return("word")); + shared_ptr phrase_builder = + make_shared(vocabulary); + + symbols1 = vector{1, 2, 3}; + phrase1 = phrase_builder->Build(symbols1); + symbols2 = vector{1, 2, -1, 3, -2, 4}; + phrase2 = phrase_builder->Build(symbols2); + } + + vector symbols1, symbols2; + Phrase phrase1, phrase2; +}; + +TEST_F(PhraseTest, TestArity) { + EXPECT_EQ(0, phrase1.Arity()); + EXPECT_EQ(2, phrase2.Arity()); +} + +TEST_F(PhraseTest, GetChunkLen) { + EXPECT_EQ(3, phrase1.GetChunkLen(0)); + + EXPECT_EQ(2, phrase2.GetChunkLen(0)); + EXPECT_EQ(1, phrase2.GetChunkLen(1)); + EXPECT_EQ(1, phrase2.GetChunkLen(2)); +} + +TEST_F(PhraseTest, TestGet) { + EXPECT_EQ(symbols1, phrase1.Get()); + EXPECT_EQ(symbols2, phrase2.Get()); +} + +TEST_F(PhraseTest, TestGetSymbol) { + for (size_t i = 0; i < symbols1.size(); ++i) { + EXPECT_EQ(symbols1[i], phrase1.GetSymbol(i)); + } + for (size_t i = 0; i < symbols2.size(); ++i) { + EXPECT_EQ(symbols2[i], phrase2.GetSymbol(i)); + } +} + +} // namespace diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc new file mode 100644 index 00000000..97a70554 --- /dev/null +++ b/extractor/precomputation.cc @@ -0,0 +1,192 @@ +#include "precomputation.h" + +#include +#include +#include +#include +#include + +#include + +#include "data_array.h" +#include "suffix_array.h" + +using namespace std; +using namespace tr1; + +int Precomputation::NON_TERMINAL = -1; + +Precomputation::Precomputation( + shared_ptr suffix_array, int num_frequent_patterns, + int num_super_frequent_patterns, int max_rule_span, + int max_rule_symbols, int min_gap_size, + int max_frequent_phrase_len, int min_frequency) { + vector data = suffix_array->GetData()->GetData(); + vector > frequent_patterns = FindMostFrequentPatterns( + suffix_array, data, num_frequent_patterns, max_frequent_phrase_len, + min_frequency); + + unordered_set, boost::hash > > frequent_patterns_set; + unordered_set, boost::hash > > + super_frequent_patterns_set; + for (size_t i = 0; i < frequent_patterns.size(); ++i) { + frequent_patterns_set.insert(frequent_patterns[i]); + if (i < num_super_frequent_patterns) { + super_frequent_patterns_set.insert(frequent_patterns[i]); + } + } + + vector > matchings; + for (size_t i = 0; i < data.size(); ++i) { + if (data[i] == DataArray::END_OF_LINE) { + AddCollocations(matchings, data, max_rule_span, min_gap_size, + max_rule_symbols); + matchings.clear(); + continue; + } + vector pattern; + for (int j = 1; j <= max_frequent_phrase_len && i + j <= data.size(); ++j) { + pattern.push_back(data[i + j - 1]); + if (frequent_patterns_set.count(pattern)) { + inverted_index[pattern].push_back(i); + int is_super_frequent = super_frequent_patterns_set.count(pattern); + matchings.push_back(make_tuple(i, j, is_super_frequent)); + } else { + // If the current pattern is not frequent, any longer pattern having the + // current pattern as prefix will not be frequent. + break; + } + } + } +} + +vector > Precomputation::FindMostFrequentPatterns( + shared_ptr suffix_array, const vector& data, + int num_frequent_patterns, int max_frequent_phrase_len, int min_frequency) { + vector lcp = suffix_array->BuildLCPArray(); + vector run_start(max_frequent_phrase_len); + + priority_queue > > heap; + for (size_t i = 1; i < lcp.size(); ++i) { + for (int len = lcp[i]; len < max_frequent_phrase_len; ++len) { + int frequency = i - run_start[len]; + // TODO(pauldb): Only add patterns that don't span across multiple + // sentences. + if (frequency >= min_frequency) { + heap.push(make_pair(frequency, + make_pair(suffix_array->GetSuffix(run_start[len]), len + 1))); + } + run_start[len] = i; + } + } + + vector > frequent_patterns; + for (size_t i = 0; i < num_frequent_patterns && !heap.empty(); ++i) { + int start = heap.top().second.first; + int len = heap.top().second.second; + heap.pop(); + + vector pattern(data.begin() + start, data.begin() + start + len); + frequent_patterns.push_back(pattern); + } + return frequent_patterns; +} + +void Precomputation::AddCollocations( + const vector >& matchings, const vector& data, + int max_rule_span, int min_gap_size, int max_rule_symbols) { + for (size_t i = 0; i < matchings.size(); ++i) { + int start1, size1, is_super1; + tie(start1, size1, is_super1) = matchings[i]; + + for (size_t j = i + 1; j < matchings.size(); ++j) { + int start2, size2, is_super2; + tie(start2, size2, is_super2) = matchings[j]; + if (start2 - start1 >= max_rule_span) { + break; + } + + if (start2 - start1 - size1 >= min_gap_size + && start2 + size2 - size1 <= max_rule_span + && size1 + size2 + 1 <= max_rule_symbols) { + vector pattern(data.begin() + start1, + data.begin() + start1 + size1); + pattern.push_back(Precomputation::NON_TERMINAL); + pattern.insert(pattern.end(), data.begin() + start2, + data.begin() + start2 + size2); + AddStartPositions(collocations[pattern], start1, start2); + + if (is_super2) { + pattern.push_back(Precomputation::NON_TERMINAL); + for (size_t k = j + 1; k < matchings.size(); ++k) { + int start3, size3, is_super3; + tie(start3, size3, is_super3) = matchings[k]; + if (start3 - start1 >= max_rule_span) { + break; + } + + if (start3 - start2 - size2 >= min_gap_size + && start3 + size3 - size1 <= max_rule_span + && size1 + size2 + size3 + 2 <= max_rule_symbols + && (is_super1 || is_super3)) { + pattern.insert(pattern.end(), data.begin() + start3, + data.begin() + start3 + size3); + AddStartPositions(collocations[pattern], start1, start2, start3); + pattern.erase(pattern.end() - size3); + } + } + } + } + } + } +} + +void Precomputation::AddStartPositions( + vector& positions, int pos1, int pos2) { + positions.push_back(pos1); + positions.push_back(pos2); +} + +void Precomputation::AddStartPositions( + vector& positions, int pos1, int pos2, int pos3) { + positions.push_back(pos1); + positions.push_back(pos2); + positions.push_back(pos3); +} + +void Precomputation::WriteBinary(const fs::path& filepath) const { + FILE* file = fopen(filepath.string().c_str(), "w"); + + // TODO(pauldb): Refactor this code. + int size = inverted_index.size(); + fwrite(&size, sizeof(int), 1, file); + for (auto entry: inverted_index) { + size = entry.first.size(); + fwrite(&size, sizeof(int), 1, file); + fwrite(entry.first.data(), sizeof(int), size, file); + + size = entry.second.size(); + fwrite(&size, sizeof(int), 1, file); + fwrite(entry.second.data(), sizeof(int), size, file); + } + + size = collocations.size(); + fwrite(&size, sizeof(int), 1, file); + for (auto entry: collocations) { + size = entry.first.size(); + fwrite(&size, sizeof(int), 1, file); + fwrite(entry.first.data(), sizeof(int), size, file); + + size = entry.second.size(); + fwrite(&size, sizeof(int), 1, file); + fwrite(entry.second.data(), sizeof(int), size, file); + } +} + +const Index& Precomputation::GetInvertedIndex() const { + return inverted_index; +} + +const Index& Precomputation::GetCollocations() const { + return collocations; +} diff --git a/extractor/precomputation.h b/extractor/precomputation.h new file mode 100644 index 00000000..0d1b269f --- /dev/null +++ b/extractor/precomputation.h @@ -0,0 +1,52 @@ +#ifndef _PRECOMPUTATION_H_ +#define _PRECOMPUTATION_H_ + +#include +#include +#include +#include +#include + +#include +#include + +namespace fs = boost::filesystem; +using namespace std; +using namespace tr1; + +class SuffixArray; + +typedef boost::hash > vector_hash; +typedef unordered_map, vector, vector_hash> Index; + +class Precomputation { + public: + Precomputation( + shared_ptr suffix_array, int num_frequent_patterns, + int num_super_frequent_patterns, int max_rule_span, + int max_rule_symbols, int min_gap_size, + int max_frequent_phrase_len, int min_frequency); + + void WriteBinary(const fs::path& filepath) const; + + const Index& GetInvertedIndex() const; + const Index& GetCollocations() const; + + static int NON_TERMINAL; + + private: + vector > FindMostFrequentPatterns( + shared_ptr suffix_array, const vector& data, + int num_frequent_patterns, int max_frequent_phrase_len, + int min_frequency); + void AddCollocations( + const vector >& matchings, const vector& data, + int max_rule_span, int min_gap_size, int max_rule_symbols); + void AddStartPositions(vector& positions, int pos1, int pos2); + void AddStartPositions(vector& positions, int pos1, int pos2, int pos3); + + Index inverted_index; + Index collocations; +}; + +#endif diff --git a/extractor/rule_extractor.cc b/extractor/rule_extractor.cc new file mode 100644 index 00000000..48b39b63 --- /dev/null +++ b/extractor/rule_extractor.cc @@ -0,0 +1,10 @@ +#include "rule_extractor.h" + +RuleExtractor::RuleExtractor( + shared_ptr source_suffix_array, + shared_ptr target_data_array, + const Alignment& alingment) { +} + +void RuleExtractor::ExtractRules() { +} diff --git a/extractor/rule_extractor.h b/extractor/rule_extractor.h new file mode 100644 index 00000000..13b5447a --- /dev/null +++ b/extractor/rule_extractor.h @@ -0,0 +1,22 @@ +#ifndef _RULE_EXTRACTOR_H_ +#define _RULE_EXTRACTOR_H_ + +#include + +using namespace std; + +class Alignment; +class DataArray; +class SuffixArray; + +class RuleExtractor { + public: + RuleExtractor( + shared_ptr source_suffix_array, + shared_ptr target_data_array, + const Alignment& alingment); + + void ExtractRules(); +}; + +#endif diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc new file mode 100644 index 00000000..7a8356b8 --- /dev/null +++ b/extractor/rule_factory.cc @@ -0,0 +1,215 @@ +#include "rule_factory.h" + +#include +#include +#include +#include + +#include "matching_comparator.h" +#include "phrase.h" +#include "suffix_array.h" +#include "vocabulary.h" + +using namespace std; +using namespace tr1; + +struct State { + State(int start, int end, const vector& phrase, + const vector& subpatterns_start, shared_ptr node, + bool starts_with_x) : + start(start), end(end), phrase(phrase), + subpatterns_start(subpatterns_start), node(node), + starts_with_x(starts_with_x) {} + + int start, end; + vector phrase, subpatterns_start; + shared_ptr node; + bool starts_with_x; +}; + +HieroCachingRuleFactory::HieroCachingRuleFactory( + shared_ptr source_suffix_array, + shared_ptr target_data_array, + const Alignment& alignment, + const shared_ptr& vocabulary, + const Precomputation& precomputation, + int min_gap_size, + int max_rule_span, + int max_nonterminals, + int max_rule_symbols, + bool use_baeza_yates) : + matchings_finder(source_suffix_array), + intersector(vocabulary, precomputation, source_suffix_array, + make_shared(min_gap_size, max_rule_span), + use_baeza_yates), + phrase_builder(vocabulary), + rule_extractor(source_suffix_array, target_data_array, alignment), + vocabulary(vocabulary), + min_gap_size(min_gap_size), + max_rule_span(max_rule_span), + max_nonterminals(max_nonterminals), + max_chunks(max_nonterminals + 1), + max_rule_symbols(max_rule_symbols) {} + +void HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { + // Clear cache for every new sentence. + trie.Reset(); + shared_ptr root = trie.GetRoot(); + + int first_x = vocabulary->GetNonterminalIndex(1); + shared_ptr x_root(new TrieNode(root)); + root->AddChild(first_x, x_root); + + queue states; + for (size_t i = 0; i < word_ids.size(); ++i) { + states.push(State(i, i, vector(), vector(1, i), root, false)); + } + for (size_t i = min_gap_size; i < word_ids.size(); ++i) { + states.push(State(i - min_gap_size, i, vector(1, first_x), + vector(1, i), x_root, true)); + } + + while (!states.empty()) { + State state = states.front(); + states.pop(); + + shared_ptr node = state.node; + vector phrase = state.phrase; + int word_id = word_ids[state.end]; + phrase.push_back(word_id); + Phrase next_phrase = phrase_builder.Build(phrase); + shared_ptr next_node; + + if (CannotHaveMatchings(node, word_id)) { + if (!node->HasChild(word_id)) { + node->AddChild(word_id, shared_ptr()); + } + continue; + } + + if (RequiresLookup(node, word_id)) { + shared_ptr next_suffix_link = + node->suffix_link->GetChild(word_id); + if (state.starts_with_x) { + // If the phrase starts with a non terminal, we simply use the matchings + // from the suffix link. + next_node = shared_ptr(new TrieNode( + next_suffix_link, next_phrase, next_suffix_link->matchings)); + } else { + PhraseLocation phrase_location; + if (next_phrase.Arity() > 0) { + phrase_location = intersector.Intersect( + node->phrase, + node->matchings, + next_suffix_link->phrase, + next_suffix_link->matchings, + next_phrase); + } else { + phrase_location = matchings_finder.Find( + node->matchings, + vocabulary->GetTerminalValue(word_id), + state.phrase.size()); + } + + if (phrase_location.IsEmpty()) { + continue; + } + next_node = shared_ptr(new TrieNode( + next_suffix_link, next_phrase, phrase_location)); + } + node->AddChild(word_id, next_node); + + // Automatically adds a trailing non terminal if allowed. Simply copy the + // matchings from the prefix node. + AddTrailingNonterminal(phrase, next_phrase, next_node, + state.starts_with_x); + + if (!state.starts_with_x) { + rule_extractor.ExtractRules(); + } + } else { + next_node = node->GetChild(word_id); + } + + vector new_states = ExtendState(word_ids, state, phrase, next_phrase, + next_node); + for (State new_state: new_states) { + states.push(new_state); + } + } +} + +bool HieroCachingRuleFactory::CannotHaveMatchings( + shared_ptr node, int word_id) { + if (node->HasChild(word_id) && node->GetChild(word_id) == NULL) { + return true; + } + + shared_ptr suffix_link = node->suffix_link; + return suffix_link != NULL && suffix_link->GetChild(word_id) == NULL; +} + +bool HieroCachingRuleFactory::RequiresLookup( + shared_ptr node, int word_id) { + return !node->HasChild(word_id); +} + +void HieroCachingRuleFactory::AddTrailingNonterminal( + vector symbols, + const Phrase& prefix, + const shared_ptr& prefix_node, + bool starts_with_x) { + if (prefix.Arity() >= max_nonterminals) { + return; + } + + int var_id = vocabulary->GetNonterminalIndex(prefix.Arity() + 1); + symbols.push_back(var_id); + Phrase var_phrase = phrase_builder.Build(symbols); + + int suffix_var_id = vocabulary->GetNonterminalIndex( + prefix.Arity() + starts_with_x == 0); + shared_ptr var_suffix_link = + prefix_node->suffix_link->GetChild(suffix_var_id); + + prefix_node->AddChild(var_id, shared_ptr(new TrieNode( + var_suffix_link, var_phrase, prefix_node->matchings))); +} + +vector HieroCachingRuleFactory::ExtendState( + const vector& word_ids, + const State& state, + vector symbols, + const Phrase& phrase, + const shared_ptr& node) { + int span = state.end - state.start; + vector new_states; + if (symbols.size() >= max_rule_symbols || state.end + 1 >= word_ids.size() || + span >= max_rule_span) { + return new_states; + } + + new_states.push_back(State(state.start, state.end + 1, symbols, + state.subpatterns_start, node, state.starts_with_x)); + + int num_subpatterns = phrase.Arity() + state.starts_with_x == 0; + if (symbols.size() + 1 >= max_rule_symbols || + phrase.Arity() >= max_nonterminals || + num_subpatterns >= max_chunks) { + return new_states; + } + + int var_id = vocabulary->GetNonterminalIndex(phrase.Arity() + 1); + symbols.push_back(var_id); + vector subpatterns_start = state.subpatterns_start; + size_t i = state.end + 1 + min_gap_size; + while (i < word_ids.size() && i - state.start <= max_rule_span) { + subpatterns_start.push_back(i); + new_states.push_back(State(state.start, i, symbols, subpatterns_start, + node->GetChild(var_id), state.starts_with_x)); + subpatterns_start.pop_back(); + ++i; + } + + return new_states; +} diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h new file mode 100644 index 00000000..8fe8bf30 --- /dev/null +++ b/extractor/rule_factory.h @@ -0,0 +1,67 @@ +#ifndef _RULE_FACTORY_H_ +#define _RULE_FACTORY_H_ + +#include +#include + +#include "matchings_finder.h" +#include "intersector.h" +#include "matchings_trie.h" +#include "phrase_builder.h" +#include "rule_extractor.h" + +using namespace std; + +class Alignment; +class DataArray; +class Precomputation; +class State; +class SuffixArray; +class Vocabulary; + +class HieroCachingRuleFactory { + public: + HieroCachingRuleFactory( + shared_ptr source_suffix_array, + shared_ptr target_data_array, + const Alignment& alignment, + const shared_ptr& vocabulary, + const Precomputation& precomputation, + int min_gap_size, + int max_rule_span, + int max_nonterminals, + int max_rule_symbols, + bool use_beaza_yates); + + void GetGrammar(const vector& word_ids); + + private: + bool CannotHaveMatchings(shared_ptr node, int word_id); + + bool RequiresLookup(shared_ptr node, int word_id); + + void AddTrailingNonterminal(vector symbols, + const Phrase& prefix, + const shared_ptr& prefix_node, + bool starts_with_x); + + vector ExtendState(const vector& word_ids, + const State& state, + vector symbols, + const Phrase& phrase, + const shared_ptr& node); + + MatchingsFinder matchings_finder; + Intersector intersector; + MatchingsTrie trie; + PhraseBuilder phrase_builder; + RuleExtractor rule_extractor; + shared_ptr vocabulary; + int min_gap_size; + int max_rule_span; + int max_nonterminals; + int max_chunks; + int max_rule_symbols; +}; + +#endif diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc new file mode 100644 index 00000000..4f841864 --- /dev/null +++ b/extractor/run_extractor.cc @@ -0,0 +1,109 @@ +#include +#include + +#include +#include + +#include "alignment.h" +#include "data_array.h" +#include "grammar_extractor.h" +#include "precomputation.h" +#include "suffix_array.h" +#include "translation_table.h" + +namespace po = boost::program_options; +using namespace std; + +int main(int argc, char** argv) { + // TODO(pauldb): Also take arguments from config file. + po::options_description desc("Command line options"); + desc.add_options() + ("help,h", "Show available options") + ("source,f", po::value(), "Source language corpus") + ("target,e", po::value(), "Target language corpus") + ("bitext,b", po::value(), "Parallel text (source ||| target)") + ("alignment,a", po::value()->required(), "Bitext word alignment") + ("frequent", po::value()->default_value(100), + "Number of precomputed frequent patterns") + ("super_frequent", po::value()->default_value(10), + "Number of precomputed super frequent patterns") + ("max_rule_span,s", po::value()->default_value(15), + "Maximum rule span") + ("max_rule_symbols,l", po::value()->default_value(5), + "Maximum number of symbols (terminals + nontermals) in a rule") + ("min_gap_size,g", po::value()->default_value(1), "Minimum gap size") + ("max_phrase_len,p", po::value()->default_value(4), + "Maximum frequent phrase length") + ("max_nonterminals", po::value()->default_value(2), + "Maximum number of nonterminals in a rule") + ("min_frequency", po::value()->default_value(1000), + "Minimum number of occurences for a pharse to be considered frequent") + ("baeza_yates", po::value()->default_value(true), + "Use double binary search"); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + + // Check for help argument before notify, so we don't need to pass in the + // required parameters. + if (vm.count("help")) { + cout << desc << endl; + return 0; + } + + po::notify(vm); + + if (!((vm.count("source") && vm.count("target")) || vm.count("bitext"))) { + cerr << "A paralel corpus is required. " + << "Use -f (source) with -e (target) or -b (bitext)." + << endl; + return 1; + } + + shared_ptr source_data_array, target_data_array; + if (vm.count("bitext")) { + source_data_array = make_shared( + vm["bitext"].as(), SOURCE); + target_data_array = make_shared( + vm["bitext"].as(), TARGET); + } else { + source_data_array = make_shared(vm["source"].as()); + target_data_array = make_shared(vm["target"].as()); + } + shared_ptr source_suffix_array = + make_shared(source_data_array); + + + Alignment alignment(vm["alignment"].as()); + + Precomputation precomputation( + source_suffix_array, + vm["frequent"].as(), + vm["super_frequent"].as(), + vm["max_rule_span"].as(), + vm["max_rule_symbols"].as(), + vm["min_gap_size"].as(), + vm["max_phrase_len"].as(), + vm["min_frequency"].as()); + + TranslationTable table(source_data_array, target_data_array, alignment); + + // TODO(pauldb): Add parallelization. + GrammarExtractor extractor( + source_suffix_array, + target_data_array, + alignment, + precomputation, + vm["min_gap_size"].as(), + vm["max_rule_span"].as(), + vm["max_nonterminals"].as(), + vm["max_rule_symbols"].as(), + vm["baeza_yates"].as()); + + string sentence; + while (getline(cin, sentence)) { + extractor.GetGrammar(sentence); + } + + return 0; +} diff --git a/extractor/sample_bitext.txt b/extractor/sample_bitext.txt new file mode 100644 index 00000000..93d6b39d --- /dev/null +++ b/extractor/sample_bitext.txt @@ -0,0 +1,2 @@ +ana are mere . ||| anna has apples . +ana bea mult lapte . ||| anna drinks a lot of milk . diff --git a/extractor/scorer.cc b/extractor/scorer.cc new file mode 100644 index 00000000..22d5be1a --- /dev/null +++ b/extractor/scorer.cc @@ -0,0 +1,9 @@ +#include "scorer.h" + +Scorer::Scorer(const vector& features) : features(features) {} + +Scorer::~Scorer() { + for (Feature* feature: features) { + delete feature; + } +} diff --git a/extractor/scorer.h b/extractor/scorer.h new file mode 100644 index 00000000..57405a6c --- /dev/null +++ b/extractor/scorer.h @@ -0,0 +1,19 @@ +#ifndef _SCORER_H_ +#define _SCORER_H_ + +#include + +#include "features/feature.h" + +using namespace std; + +class Scorer { + public: + Scorer(const vector& features); + ~Scorer(); + + private: + vector features; +}; + +#endif diff --git a/extractor/suffix_array.cc b/extractor/suffix_array.cc new file mode 100644 index 00000000..76f00ace --- /dev/null +++ b/extractor/suffix_array.cc @@ -0,0 +1,211 @@ +#include "suffix_array.h" + +#include +#include +#include + +#include "data_array.h" +#include "phrase_location.h" + +namespace fs = boost::filesystem; +using namespace std; + +SuffixArray::SuffixArray(shared_ptr data_array) : + data_array(data_array) { + BuildSuffixArray(); +} + +SuffixArray::~SuffixArray() {} + +void SuffixArray::BuildSuffixArray() { + vector groups = data_array->GetData(); + groups.reserve(groups.size() + 1); + groups.push_back(data_array->GetVocabularySize()); + suffix_array.resize(groups.size()); + word_start.resize(data_array->GetVocabularySize() + 2); + + InitialBucketSort(groups); + + int combined_group_size = 0; + for (size_t i = 1; i < word_start.size(); ++i) { + if (word_start[i] - word_start[i - 1] == 1) { + ++combined_group_size; + suffix_array[word_start[i] - combined_group_size] = -combined_group_size; + } else { + combined_group_size = 0; + } + } + + PrefixDoublingSort(groups); + + for (size_t i = 0; i < groups.size(); ++i) { + suffix_array[groups[i]] = i; + } +} + +void SuffixArray::InitialBucketSort(vector& groups) { + for (size_t i = 0; i < groups.size(); ++i) { + ++word_start[groups[i]]; + } + + for (size_t i = 1; i < word_start.size(); ++i) { + word_start[i] += word_start[i - 1]; + } + + for (size_t i = 0; i < groups.size(); ++i) { + --word_start[groups[i]]; + suffix_array[word_start[groups[i]]] = i; + } + + for (size_t i = 0; i < suffix_array.size(); ++i) { + groups[i] = word_start[groups[i] + 1] - 1; + } +} + +void SuffixArray::PrefixDoublingSort(vector& groups) { + int step = 1; + while (suffix_array[0] != -suffix_array.size()) { + int combined_group_size = 0; + int i = 0; + while (i < suffix_array.size()) { + if (suffix_array[i] < 0) { + int skip = -suffix_array[i]; + combined_group_size += skip; + i += skip; + suffix_array[i - combined_group_size] = -combined_group_size; + } else { + combined_group_size = 0; + int j = groups[suffix_array[i]]; + TernaryQuicksort(i, j, step, groups); + i = j + 1; + } + } + step *= 2; + } +} + +void SuffixArray::TernaryQuicksort(int left, int right, int step, + vector& groups) { + if (left > right) { + return; + } + + int pivot = left + rand() % (right - left + 1); + int pivot_value = groups[suffix_array[pivot] + step]; + swap(suffix_array[pivot], suffix_array[left]); + int mid_left = left, mid_right = left; + for (int i = left + 1; i <= right; ++i) { + if (groups[suffix_array[i] + step] < pivot_value) { + ++mid_right; + int temp = suffix_array[i]; + suffix_array[i] = suffix_array[mid_right]; + suffix_array[mid_right] = suffix_array[mid_left]; + suffix_array[mid_left] = temp; + ++mid_left; + } else if (groups[suffix_array[i] + step] == pivot_value) { + ++mid_right; + int temp = suffix_array[i]; + suffix_array[i] = suffix_array[mid_right]; + suffix_array[mid_right] = temp; + } + } + + if (mid_left == mid_right) { + groups[suffix_array[mid_left]] = mid_left; + suffix_array[mid_left] = -1; + } else { + for (int i = mid_left; i <= mid_right; ++i) { + groups[suffix_array[i]] = mid_right; + } + } + + TernaryQuicksort(left, mid_left - 1, step, groups); + TernaryQuicksort(mid_right + 1, right, step, groups); +} + +vector SuffixArray::BuildLCPArray() const { + vector lcp(suffix_array.size()); + vector rank(suffix_array.size()); + const vector& data = data_array->GetData(); + + for (size_t i = 0; i < suffix_array.size(); ++i) { + rank[suffix_array[i]] = i; + } + + int prefix_len = 0; + for (size_t i = 0; i < suffix_array.size(); ++i) { + if (rank[i] == 0) { + lcp[rank[i]] = -1; + } else { + int j = suffix_array[rank[i] - 1]; + while (i + prefix_len < data.size() && j + prefix_len < data.size() + && data[i + prefix_len] == data[j + prefix_len]) { + ++prefix_len; + } + lcp[rank[i]] = prefix_len; + } + + if (prefix_len > 0) { + --prefix_len; + } + } + + return lcp; +} + +int SuffixArray::GetSuffix(int rank) const { + return suffix_array[rank]; +} + +int SuffixArray::GetSize() const { + return suffix_array.size(); +} + +shared_ptr SuffixArray::GetData() const { + return data_array; +} + +void SuffixArray::WriteBinary(const fs::path& filepath) const { + FILE* file = fopen(filepath.string().c_str(), "r"); + data_array->WriteBinary(file); + + int size = suffix_array.size(); + fwrite(&size, sizeof(int), 1, file); + fwrite(suffix_array.data(), sizeof(int), size, file); + + size = word_start.size(); + fwrite(&size, sizeof(int), 1, file); + fwrite(word_start.data(), sizeof(int), size, file); +} + +PhraseLocation SuffixArray::Lookup(int low, int high, const string& word, + int offset) const { + if (!data_array->HasWord(word)) { + // Return empty phrase location. + return PhraseLocation(0, 0); + } + + int word_id = data_array->GetWordId(word); + if (offset == 0) { + return PhraseLocation(word_start[word_id], word_start[word_id + 1]); + } + + return PhraseLocation(LookupRangeStart(low, high, word_id, offset), + LookupRangeStart(low, high, word_id + 1, offset)); +} + +int SuffixArray::LookupRangeStart(int low, int high, int word_id, + int offset) const { + int result = high; + while (low < high) { + int middle = low + (high - low) / 2; + if (suffix_array[middle] + offset < data_array->GetSize() && + data_array->AtIndex(suffix_array[middle] + offset) < word_id) { + low = middle + 1; + } else { + result = middle; + high = middle; + } + } + return result; +} diff --git a/extractor/suffix_array.h b/extractor/suffix_array.h new file mode 100644 index 00000000..7708f5a2 --- /dev/null +++ b/extractor/suffix_array.h @@ -0,0 +1,51 @@ +#ifndef _SUFFIX_ARRAY_H_ +#define _SUFFIX_ARRAY_H_ + +#include +#include +#include + +#include + +namespace fs = boost::filesystem; +using namespace std; + +class DataArray; +class PhraseLocation; + +class SuffixArray { + public: + SuffixArray(shared_ptr data_array); + + virtual ~SuffixArray(); + + virtual int GetSize() const; + + shared_ptr GetData() const; + + vector BuildLCPArray() const; + + int GetSuffix(int rank) const; + + virtual PhraseLocation Lookup(int low, int high, const string& word, + int offset) const; + + void WriteBinary(const fs::path& filepath) const; + + private: + void BuildSuffixArray(); + + void InitialBucketSort(vector& groups); + + void TernaryQuicksort(int left, int right, int step, vector& groups); + + void PrefixDoublingSort(vector& groups); + + int LookupRangeStart(int low, int high, int word_id, int offset) const; + + shared_ptr data_array; + vector suffix_array; + vector word_start; +}; + +#endif diff --git a/extractor/suffix_array_test.cc b/extractor/suffix_array_test.cc new file mode 100644 index 00000000..d891933c --- /dev/null +++ b/extractor/suffix_array_test.cc @@ -0,0 +1,75 @@ +#include + +#include "mocks/mock_data_array.h" +#include "phrase_location.h" +#include "suffix_array.h" + +#include + +using namespace std; +using namespace ::testing; + +namespace { + +class SuffixArrayTest : public Test { + protected: + virtual void SetUp() { + data = vector{5, 3, 0, 1, 3, 4, 2, 3, 5, 5, 3, 0, 1}; + data_array = make_shared(); + EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data)); + EXPECT_CALL(*data_array, GetVocabularySize()).WillRepeatedly(Return(6)); + EXPECT_CALL(*data_array, GetSize()).WillRepeatedly(Return(13)); + suffix_array = make_shared(data_array); + } + + vector data; + shared_ptr suffix_array; + shared_ptr data_array; +}; + +TEST_F(SuffixArrayTest, TestData) { + EXPECT_EQ(data_array, suffix_array->GetData()); + EXPECT_EQ(14, suffix_array->GetSize()); +} + +TEST_F(SuffixArrayTest, TestBuildSuffixArray) { + vector expected_suffix_array{2, 11, 3, 12, 6, 1, 10, 4, 7, 5, 0, 9, 8}; + for (size_t i = 0; i < expected_suffix_array.size(); ++i) { + EXPECT_EQ(expected_suffix_array[i], suffix_array->GetSuffix(i)); + } +} + +TEST_F(SuffixArrayTest, TestBuildLCP) { + vector expected_lcp{-1, 2, 0, 1, 0, 0, 3, 1, 1, 0, 0, 4, 1, 0}; + EXPECT_EQ(expected_lcp, suffix_array->BuildLCPArray()); +} + +TEST_F(SuffixArrayTest, TestLookup) { + for (size_t i = 0; i < data.size(); ++i) { + EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i])); + } + + EXPECT_CALL(*data_array, HasWord("word1")).WillRepeatedly(Return(true)); + EXPECT_CALL(*data_array, GetWordId("word1")).WillRepeatedly(Return(5)); + EXPECT_EQ(PhraseLocation(10, 13), suffix_array->Lookup(0, 14, "word1", 0)); + + EXPECT_CALL(*data_array, HasWord("word2")).WillRepeatedly(Return(false)); + EXPECT_EQ(PhraseLocation(0, 0), suffix_array->Lookup(0, 14, "word2", 0)); + + EXPECT_CALL(*data_array, HasWord("word3")).WillRepeatedly(Return(true)); + EXPECT_CALL(*data_array, GetWordId("word3")).WillRepeatedly(Return(3)); + EXPECT_EQ(PhraseLocation(10, 12), suffix_array->Lookup(10, 13, "word3", 1)); + + EXPECT_CALL(*data_array, HasWord("word4")).WillRepeatedly(Return(true)); + EXPECT_CALL(*data_array, GetWordId("word4")).WillRepeatedly(Return(0)); + EXPECT_EQ(PhraseLocation(10, 12), suffix_array->Lookup(10, 12, "word4", 2)); + + EXPECT_CALL(*data_array, HasWord("word5")).WillRepeatedly(Return(true)); + EXPECT_CALL(*data_array, GetWordId("word5")).WillRepeatedly(Return(1)); + EXPECT_EQ(PhraseLocation(10, 12), suffix_array->Lookup(10, 12, "word5", 3)); + + EXPECT_EQ(PhraseLocation(10, 11), suffix_array->Lookup(10, 12, "word3", 4)); + EXPECT_EQ(PhraseLocation(10, 10), suffix_array->Lookup(10, 12, "word5", 1)); +} + +} // namespace diff --git a/extractor/translation_table.cc b/extractor/translation_table.cc new file mode 100644 index 00000000..5eb4ffdc --- /dev/null +++ b/extractor/translation_table.cc @@ -0,0 +1,94 @@ +#include "translation_table.h" + +#include +#include + +#include + +#include "alignment.h" +#include "data_array.h" + +using namespace std; +using namespace tr1; + +TranslationTable::TranslationTable(shared_ptr source_data_array, + shared_ptr target_data_array, + const Alignment& alignment) : + source_data_array(source_data_array), target_data_array(target_data_array) { + const vector& source_data = source_data_array->GetData(); + const vector& target_data = target_data_array->GetData(); + + unordered_map source_links_count; + unordered_map target_links_count; + unordered_map, int, boost::hash > > links_count; + + for (size_t i = 0; i < source_data_array->GetNumSentences(); ++i) { + vector > links = alignment.GetLinks(i); + int source_start = source_data_array->GetSentenceStart(i); + int next_source_start = source_data_array->GetSentenceStart(i + 1); + int target_start = target_data_array->GetSentenceStart(i); + int next_target_start = target_data_array->GetSentenceStart(i + 1); + vector source_sentence(source_data.begin() + source_start, + source_data.begin() + next_source_start); + vector target_sentence(target_data.begin() + target_start, + target_data.begin() + next_target_start); + vector source_linked_words(source_sentence.size()); + vector target_linked_words(target_sentence.size()); + + for (pair link: links) { + source_linked_words[link.first] = 1; + target_linked_words[link.second] = 1; + int source_word = source_sentence[link.first]; + int target_word = target_sentence[link.second]; + + ++source_links_count[source_word]; + ++target_links_count[target_word]; + ++links_count[make_pair(source_word, target_word)]; + } + + // TODO(pauldb): Something seems wrong here. No NULL word? + } + + for (pair, int> link_count: links_count) { + int source_word = link_count.first.first; + int target_word = link_count.first.second; + double score1 = 1.0 * link_count.second / source_links_count[source_word]; + double score2 = 1.0 * link_count.second / target_links_count[target_word]; + translation_probabilities[link_count.first] = make_pair(score1, score2); + } +} + +double TranslationTable::GetEgivenFScore( + const string& source_word, const string& target_word) { + if (!source_data_array->HasWord(source_word) || + !target_data_array->HasWord(target_word)) { + return -1; + } + + int source_id = source_data_array->GetWordId(source_word); + int target_id = target_data_array->GetWordId(target_word); + return translation_probabilities[make_pair(source_id, target_id)].first; +} + +double TranslationTable::GetFgivenEScore( + const string& source_word, const string& target_word) { + if (!source_data_array->HasWord(source_word) || + !target_data_array->HasWord(target_word) == 0) { + return -1; + } + + int source_id = source_data_array->GetWordId(source_word); + int target_id = target_data_array->GetWordId(target_word); + return translation_probabilities[make_pair(source_id, target_id)].second; +} + +void TranslationTable::WriteBinary(const fs::path& filepath) const { + FILE* file = fopen(filepath.string().c_str(), "w"); + + int size = translation_probabilities.size(); + fwrite(&size, sizeof(int), 1, file); + for (auto entry: translation_probabilities) { + fwrite(&entry.first, sizeof(entry.first), 1, file); + fwrite(&entry.second, sizeof(entry.second), 1, file); + } +} diff --git a/extractor/translation_table.h b/extractor/translation_table.h new file mode 100644 index 00000000..6004eca0 --- /dev/null +++ b/extractor/translation_table.h @@ -0,0 +1,38 @@ +#ifndef _TRANSLATION_TABLE_ +#define _TRANSLATION_TABLE_ + +#include +#include +#include + +#include +#include + +using namespace std; +using namespace tr1; +namespace fs = boost::filesystem; + +class Alignment; +class DataArray; + +class TranslationTable { + public: + TranslationTable( + shared_ptr source_data_array, + shared_ptr target_data_array, + const Alignment& alignment); + + double GetEgivenFScore(const string& source_word, const string& target_word); + + double GetFgivenEScore(const string& source_word, const string& target_word); + + void WriteBinary(const fs::path& filepath) const; + + private: + shared_ptr source_data_array; + shared_ptr target_data_array; + unordered_map, pair, + boost::hash > > translation_probabilities; +}; + +#endif diff --git a/extractor/veb.cc b/extractor/veb.cc new file mode 100644 index 00000000..f38f672e --- /dev/null +++ b/extractor/veb.cc @@ -0,0 +1,25 @@ +#include "veb.h" + +#include "veb_bitset.h" +#include "veb_tree.h" + +int VEB::MIN_BOTTOM_BITS = 5; +int VEB::MIN_BOTTOM_SIZE = 1 << VEB::MIN_BOTTOM_BITS; + +shared_ptr VEB::Create(int size) { + if (size > MIN_BOTTOM_SIZE) { + return shared_ptr(new VEBTree(size)); + } else { + return shared_ptr(new VEBBitset(size)); + } +} + +int VEB::GetMinimum() { + return min; +} + +int VEB::GetMaximum() { + return max; +} + +VEB::VEB(int min, int max) : min(min), max(max) {} diff --git a/extractor/veb.h b/extractor/veb.h new file mode 100644 index 00000000..c8209cf7 --- /dev/null +++ b/extractor/veb.h @@ -0,0 +1,29 @@ +#ifndef _VEB_H_ +#define _VEB_H_ + +#include + +using namespace std; + +class VEB { + public: + static shared_ptr Create(int size); + + virtual void Insert(int value) = 0; + + virtual int GetSuccessor(int value) = 0; + + int GetMinimum(); + + int GetMaximum(); + + static int MIN_BOTTOM_BITS; + static int MIN_BOTTOM_SIZE; + + protected: + VEB(int min = -1, int max = -1); + + int min, max; +}; + +#endif diff --git a/extractor/veb_bitset.cc b/extractor/veb_bitset.cc new file mode 100644 index 00000000..4e364cc5 --- /dev/null +++ b/extractor/veb_bitset.cc @@ -0,0 +1,25 @@ +#include "veb_bitset.h" + +using namespace std; + +VEBBitset::VEBBitset(int size) : bitset(size) { + min = max = -1; +} + +void VEBBitset::Insert(int value) { + bitset[value] = 1; + if (min == -1 || value < min) { + min = value; + } + if (max == - 1 || value > max) { + max = value; + } +} + +int VEBBitset::GetSuccessor(int value) { + int next_value = bitset.find_next(value); + if (next_value == bitset.npos) { + return -1; + } + return next_value; +} diff --git a/extractor/veb_bitset.h b/extractor/veb_bitset.h new file mode 100644 index 00000000..f8a91234 --- /dev/null +++ b/extractor/veb_bitset.h @@ -0,0 +1,22 @@ +#ifndef _VEB_BITSET_H_ +#define _VEB_BITSET_H_ + +#include + +#include "veb.h" + +class VEBBitset: public VEB { + public: + VEBBitset(int size); + + void Insert(int value); + + int GetMinimum(); + + int GetSuccessor(int value); + + private: + boost::dynamic_bitset<> bitset; +}; + +#endif diff --git a/extractor/veb_test.cc b/extractor/veb_test.cc new file mode 100644 index 00000000..c40c9f28 --- /dev/null +++ b/extractor/veb_test.cc @@ -0,0 +1,56 @@ +#include + +#include +#include + +#include "veb.h" + +using namespace std; + +namespace { + +class VEBTest : public ::testing::Test { + protected: + void VEBSortTester(vector values, int max_value) { + shared_ptr veb = VEB::Create(max_value); + for (int value: values) { + veb->Insert(value); + } + + sort(values.begin(), values.end()); + EXPECT_EQ(values.front(), veb->GetMinimum()); + EXPECT_EQ(values.back(), veb->GetMaximum()); + for (size_t i = 0; i + 1 < values.size(); ++i) { + EXPECT_EQ(values[i + 1], veb->GetSuccessor(values[i])); + } + EXPECT_EQ(-1, veb->GetSuccessor(values.back())); + } +}; + +TEST_F(VEBTest, SmallRange) { + vector values{8, 13, 5, 1, 4, 15, 2, 10, 6, 7}; + VEBSortTester(values, 16); +} + +TEST_F(VEBTest, MediumRange) { + vector values{167, 243, 88, 12, 137, 199, 212, 45, 150, 189}; + VEBSortTester(values, 255); +} + +TEST_F(VEBTest, LargeRangeSparse) { + vector values; + for (size_t i = 0; i < 100; ++i) { + values.push_back(i * 1000000); + } + VEBSortTester(values, 100000000); +} + +TEST_F(VEBTest, LargeRangeDense) { + vector values; + for (size_t i = 0; i < 1000000; ++i) { + values.push_back(i); + } + VEBSortTester(values, 1000000); +} + +} // namespace diff --git a/extractor/veb_tree.cc b/extractor/veb_tree.cc new file mode 100644 index 00000000..f8945445 --- /dev/null +++ b/extractor/veb_tree.cc @@ -0,0 +1,71 @@ +#include + +#include "veb_tree.h" + +VEBTree::VEBTree(int size) { + int num_bits = ceil(log2(size)); + + lower_bits = num_bits >> 1; + upper_size = (size >> lower_bits) + 1; + + clusters.reserve(upper_size); + clusters.resize(upper_size); +} + +int VEBTree::GetNextValue(int value) { + return value & ((1 << lower_bits) - 1); +} + +int VEBTree::GetCluster(int value) { + return value >> lower_bits; +} + +int VEBTree::Compose(int cluster, int value) { + return (cluster << lower_bits) + value; +} + +void VEBTree::Insert(int value) { + if (min == -1 && max == -1) { + min = max = value; + return; + } + + if (value < min) { + swap(min, value); + } + + int cluster = GetCluster(value), next_value = GetNextValue(value); + if (clusters[cluster] == NULL) { + clusters[cluster] = VEB::Create(1 << lower_bits); + if (summary == NULL) { + summary = VEB::Create(upper_size); + } + summary->Insert(cluster); + } + clusters[cluster]->Insert(next_value); + + if (value > max) { + max = value; + } +} + +int VEBTree::GetSuccessor(int value) { + if (value >= max) { + return -1; + } + if (value < min) { + return min; + } + + int cluster = GetCluster(value), next_value = GetNextValue(value); + if (clusters[cluster] != NULL && + next_value < clusters[cluster]->GetMaximum()) { + return Compose(cluster, clusters[cluster]->GetSuccessor(next_value)); + } else { + int next_cluster = summary->GetSuccessor(cluster); + if (next_cluster == -1) { + return -1; + } + return Compose(next_cluster, clusters[next_cluster]->GetMinimum()); + } +} diff --git a/extractor/veb_tree.h b/extractor/veb_tree.h new file mode 100644 index 00000000..578d3e6a --- /dev/null +++ b/extractor/veb_tree.h @@ -0,0 +1,29 @@ +#ifndef _VEB_TREE_H_ +#define _VEB_TREE_H_ + +#include +#include + +using namespace std; + +#include "veb.h" + +class VEBTree: public VEB { + public: + VEBTree(int size); + + void Insert(int value); + + int GetSuccessor(int value); + + private: + int GetNextValue(int value); + int GetCluster(int value); + int Compose(int cluster, int value); + + int lower_bits, upper_size; + shared_ptr summary; + vector > clusters; +}; + +#endif diff --git a/extractor/vocabulary.cc b/extractor/vocabulary.cc new file mode 100644 index 00000000..5c379a29 --- /dev/null +++ b/extractor/vocabulary.cc @@ -0,0 +1,26 @@ +#include "vocabulary.h" + +Vocabulary::~Vocabulary() {} + +int Vocabulary::GetTerminalIndex(const string& word) { + if (!dictionary.count(word)) { + int word_id = words.size(); + dictionary[word] = word_id; + words.push_back(word); + return word_id; + } + + return dictionary[word]; +} + +int Vocabulary::GetNonterminalIndex(int position) { + return -position; +} + +bool Vocabulary::IsTerminal(int symbol) { + return symbol >= 0; +} + +string Vocabulary::GetTerminalValue(int symbol) { + return words[symbol]; +} diff --git a/extractor/vocabulary.h b/extractor/vocabulary.h new file mode 100644 index 00000000..05744269 --- /dev/null +++ b/extractor/vocabulary.h @@ -0,0 +1,28 @@ +#ifndef _VOCABULARY_H_ +#define _VOCABULARY_H_ + +#include +#include +#include + +using namespace std; +using namespace tr1; + +class Vocabulary { + public: + virtual ~Vocabulary(); + + int GetTerminalIndex(const string& word); + + int GetNonterminalIndex(int position); + + bool IsTerminal(int symbol); + + virtual string GetTerminalValue(int symbol); + + private: + unordered_map dictionary; + vector words; +}; + +#endif -- cgit v1.2.3