summaryrefslogtreecommitdiff
path: root/extractor
diff options
context:
space:
mode:
Diffstat (limited to 'extractor')
-rw-r--r--extractor/alignment.cc2
-rw-r--r--extractor/alignment.h2
-rw-r--r--extractor/compile.cc4
-rw-r--r--extractor/data_array.cc9
-rw-r--r--extractor/data_array.h7
-rw-r--r--extractor/data_array_test.cc4
-rw-r--r--extractor/fast_intersector.cc40
-rw-r--r--extractor/fast_intersector.h8
-rw-r--r--extractor/fast_intersector_test.cc10
-rw-r--r--extractor/grammar_extractor.cc5
-rw-r--r--extractor/grammar_extractor.h1
-rw-r--r--extractor/mocks/mock_data_array.h3
-rw-r--r--extractor/mocks/mock_precomputation.h3
-rw-r--r--extractor/mocks/mock_rule_factory.h4
-rw-r--r--extractor/precomputation.cc125
-rw-r--r--extractor/precomputation.h43
-rw-r--r--extractor/precomputation_test.cc73
-rw-r--r--extractor/rule_factory.h5
-rw-r--r--extractor/run_extractor.cc5
-rw-r--r--extractor/suffix_array.cc5
-rw-r--r--extractor/suffix_array.h2
-rw-r--r--extractor/suffix_array_test.cc8
-rw-r--r--extractor/translation_table.cc14
-rw-r--r--extractor/translation_table.h2
-rw-r--r--extractor/translation_table_test.cc14
-rw-r--r--extractor/vocabulary.cc7
26 files changed, 210 insertions, 195 deletions
diff --git a/extractor/alignment.cc b/extractor/alignment.cc
index 2278c825..4a7a14f4 100644
--- a/extractor/alignment.cc
+++ b/extractor/alignment.cc
@@ -8,9 +8,7 @@
#include <vector>
#include <boost/algorithm/string.hpp>
-#include <boost/filesystem.hpp>
-namespace fs = boost::filesystem;
using namespace std;
namespace extractor {
diff --git a/extractor/alignment.h b/extractor/alignment.h
index dc5a8b55..76c27da2 100644
--- a/extractor/alignment.h
+++ b/extractor/alignment.h
@@ -4,13 +4,11 @@
#include <string>
#include <vector>
-#include <boost/filesystem.hpp>
#include <boost/serialization/serialization.hpp>
#include <boost/serialization/split_member.hpp>
#include <boost/serialization/utility.hpp>
#include <boost/serialization/vector.hpp>
-namespace fs = boost::filesystem;
using namespace std;
namespace extractor {
diff --git a/extractor/compile.cc b/extractor/compile.cc
index 65fdd509..0d62757e 100644
--- a/extractor/compile.cc
+++ b/extractor/compile.cc
@@ -13,6 +13,7 @@
#include "suffix_array.h"
#include "time_util.h"
#include "translation_table.h"
+#include "vocabulary.h"
namespace ar = boost::archive;
namespace fs = boost::filesystem;
@@ -125,9 +126,12 @@ int main(int argc, char** argv) {
cerr << "Reading alignment took "
<< GetDuration(start_time, stop_time) << " seconds" << endl;
+ shared_ptr<Vocabulary> vocabulary;
+
start_time = Clock::now();
cerr << "Precomputing collocations..." << endl;
Precomputation precomputation(
+ vocabulary,
source_suffix_array,
vm["frequent"].as<int>(),
vm["super_frequent"].as<int>(),
diff --git a/extractor/data_array.cc b/extractor/data_array.cc
index 2e4bdafb..ac0493fd 100644
--- a/extractor/data_array.cc
+++ b/extractor/data_array.cc
@@ -5,9 +5,6 @@
#include <sstream>
#include <string>
-#include <boost/filesystem.hpp>
-
-namespace fs = boost::filesystem;
using namespace std;
namespace extractor {
@@ -81,7 +78,7 @@ void DataArray::CreateDataArray(const vector<string>& lines) {
DataArray::~DataArray() {}
-const vector<int>& DataArray::GetData() const {
+vector<int> DataArray::GetData() const {
return data;
}
@@ -118,10 +115,6 @@ int DataArray::GetSentenceId(int position) const {
return sentence_id[position];
}
-bool DataArray::HasWord(const string& word) const {
- return word2id.count(word);
-}
-
int DataArray::GetWordId(const string& word) const {
auto result = word2id.find(word);
return result == word2id.end() ? -1 : result->second;
diff --git a/extractor/data_array.h b/extractor/data_array.h
index 2be6a09c..c5dc8a26 100644
--- a/extractor/data_array.h
+++ b/extractor/data_array.h
@@ -5,13 +5,11 @@
#include <unordered_map>
#include <vector>
-#include <boost/filesystem.hpp>
#include <boost/serialization/serialization.hpp>
#include <boost/serialization/split_member.hpp>
#include <boost/serialization/string.hpp>
#include <boost/serialization/vector.hpp>
-namespace fs = boost::filesystem;
using namespace std;
namespace extractor {
@@ -53,7 +51,7 @@ class DataArray {
virtual ~DataArray();
// Returns a vector containing the word ids.
- virtual const vector<int>& GetData() const;
+ virtual vector<int> GetData() const;
// Returns the word id at the specified position.
virtual int AtIndex(int index) const;
@@ -67,9 +65,6 @@ class DataArray {
// Returns the number of distinct words in the data array.
virtual int GetVocabularySize() const;
- // Returns whether a word has ever been observed in the data array.
- virtual bool HasWord(const string& word) const;
-
// Returns the word id for a given word or -1 if it the word has never been
// observed.
virtual int GetWordId(const string& word) const;
diff --git a/extractor/data_array_test.cc b/extractor/data_array_test.cc
index 6c329e34..b6b56561 100644
--- a/extractor/data_array_test.cc
+++ b/extractor/data_array_test.cc
@@ -58,16 +58,12 @@ TEST_F(DataArrayTest, TestGetData) {
TEST_F(DataArrayTest, TestVocabulary) {
EXPECT_EQ(9, source_data.GetVocabularySize());
- EXPECT_TRUE(source_data.HasWord("mere"));
EXPECT_EQ(4, source_data.GetWordId("mere"));
EXPECT_EQ("mere", source_data.GetWord(4));
- EXPECT_FALSE(source_data.HasWord("banane"));
EXPECT_EQ(11, target_data.GetVocabularySize());
- EXPECT_TRUE(target_data.HasWord("apples"));
EXPECT_EQ(4, target_data.GetWordId("apples"));
EXPECT_EQ("apples", target_data.GetWord(4));
- EXPECT_FALSE(target_data.HasWord("bananas"));
}
TEST_F(DataArrayTest, TestSentenceData) {
diff --git a/extractor/fast_intersector.cc b/extractor/fast_intersector.cc
index a8591a72..0d1fa6d8 100644
--- a/extractor/fast_intersector.cc
+++ b/extractor/fast_intersector.cc
@@ -11,41 +11,22 @@
namespace extractor {
-FastIntersector::FastIntersector(shared_ptr<SuffixArray> suffix_array,
- shared_ptr<Precomputation> precomputation,
- shared_ptr<Vocabulary> vocabulary,
- int max_rule_span,
- int min_gap_size) :
+FastIntersector::FastIntersector(
+ shared_ptr<SuffixArray> suffix_array,
+ shared_ptr<Precomputation> precomputation,
+ shared_ptr<Vocabulary> vocabulary,
+ int max_rule_span,
+ int min_gap_size) :
suffix_array(suffix_array),
+ precomputation(precomputation),
vocabulary(vocabulary),
max_rule_span(max_rule_span),
- min_gap_size(min_gap_size) {
- Index precomputed_collocations = precomputation->GetCollocations();
- for (pair<vector<int>, vector<int>> entry: precomputed_collocations) {
- vector<int> phrase = ConvertPhrase(entry.first);
- collocations[phrase] = entry.second;
- }
-}
+ min_gap_size(min_gap_size) {}
FastIntersector::FastIntersector() {}
FastIntersector::~FastIntersector() {}
-vector<int> FastIntersector::ConvertPhrase(const vector<int>& old_phrase) {
- vector<int> new_phrase;
- new_phrase.reserve(old_phrase.size());
- shared_ptr<DataArray> data_array = suffix_array->GetData();
- for (int word_id: old_phrase) {
- if (word_id < 0) {
- new_phrase.push_back(word_id);
- } else {
- new_phrase.push_back(
- vocabulary->GetTerminalIndex(data_array->GetWord(word_id)));
- }
- }
- return new_phrase;
-}
-
PhraseLocation FastIntersector::Intersect(
PhraseLocation& prefix_location,
PhraseLocation& suffix_location,
@@ -59,8 +40,9 @@ PhraseLocation FastIntersector::Intersect(
assert(vocabulary->IsTerminal(symbols.front())
&& vocabulary->IsTerminal(symbols.back()));
- if (collocations.count(symbols)) {
- return PhraseLocation(collocations[symbols], phrase.Arity() + 1);
+ if (precomputation->Contains(symbols)) {
+ return PhraseLocation(precomputation->GetCollocations(symbols),
+ phrase.Arity() + 1);
}
bool prefix_ends_with_x =
diff --git a/extractor/fast_intersector.h b/extractor/fast_intersector.h
index 2819d239..305373dc 100644
--- a/extractor/fast_intersector.h
+++ b/extractor/fast_intersector.h
@@ -12,7 +12,6 @@ using namespace std;
namespace extractor {
typedef boost::hash<vector<int>> VectorHash;
-typedef unordered_map<vector<int>, vector<int>, VectorHash> Index;
class Phrase;
class PhraseLocation;
@@ -52,11 +51,6 @@ class FastIntersector {
FastIntersector();
private:
- // Uses the vocabulary to convert the phrase from the numberized format
- // specified by the source data array to the numberized format given by the
- // vocabulary.
- vector<int> ConvertPhrase(const vector<int>& old_phrase);
-
// Estimates the number of computations needed if the prefix/suffix is
// extended. If the last/first symbol is separated from the rest of the phrase
// by a nonterminal, then for each occurrence of the prefix/suffix we need to
@@ -85,10 +79,10 @@ class FastIntersector {
pair<int, int> GetSearchRange(bool has_marginal_x) const;
shared_ptr<SuffixArray> suffix_array;
+ shared_ptr<Precomputation> precomputation;
shared_ptr<Vocabulary> vocabulary;
int max_rule_span;
int min_gap_size;
- Index collocations;
};
} // namespace extractor
diff --git a/extractor/fast_intersector_test.cc b/extractor/fast_intersector_test.cc
index 76c3aaea..f2a26ba1 100644
--- a/extractor/fast_intersector_test.cc
+++ b/extractor/fast_intersector_test.cc
@@ -59,15 +59,13 @@ class FastIntersectorTest : public Test {
}
precomputation = make_shared<MockPrecomputation>();
- EXPECT_CALL(*precomputation, GetCollocations())
- .WillRepeatedly(ReturnRef(collocations));
+ EXPECT_CALL(*precomputation, Contains(_)).WillRepeatedly(Return(false));
phrase_builder = make_shared<PhraseBuilder>(vocabulary);
intersector = make_shared<FastIntersector>(suffix_array, precomputation,
vocabulary, 15, 1);
}
- Index collocations;
shared_ptr<MockDataArray> data_array;
shared_ptr<MockSuffixArray> suffix_array;
shared_ptr<MockPrecomputation> precomputation;
@@ -82,9 +80,9 @@ TEST_F(FastIntersectorTest, TestCachedCollocation) {
Phrase phrase = phrase_builder->Build(symbols);
PhraseLocation prefix_location(15, 16), suffix_location(16, 17);
- collocations[symbols] = expected_location;
- EXPECT_CALL(*precomputation, GetCollocations())
- .WillRepeatedly(ReturnRef(collocations));
+ EXPECT_CALL(*precomputation, Contains(symbols)).WillRepeatedly(Return(true));
+ EXPECT_CALL(*precomputation, GetCollocations(symbols)).
+ WillRepeatedly(Return(expected_location));
intersector = make_shared<FastIntersector>(suffix_array, precomputation,
vocabulary, 15, 1);
diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc
index 487abcaf..4d0738f7 100644
--- a/extractor/grammar_extractor.cc
+++ b/extractor/grammar_extractor.cc
@@ -19,10 +19,11 @@ GrammarExtractor::GrammarExtractor(
shared_ptr<SuffixArray> source_suffix_array,
shared_ptr<DataArray> target_data_array,
shared_ptr<Alignment> alignment, shared_ptr<Precomputation> precomputation,
- shared_ptr<Scorer> scorer, int min_gap_size, int max_rule_span,
+ shared_ptr<Scorer> scorer, shared_ptr<Vocabulary> vocabulary,
+ int min_gap_size, int max_rule_span,
int max_nonterminals, int max_rule_symbols, int max_samples,
bool require_tight_phrases) :
- vocabulary(make_shared<Vocabulary>()),
+ vocabulary(vocabulary),
rule_factory(make_shared<HieroCachingRuleFactory>(
source_suffix_array, target_data_array, alignment, vocabulary,
precomputation, scorer, min_gap_size, max_rule_span, max_nonterminals,
diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h
index ae407b47..8f570df2 100644
--- a/extractor/grammar_extractor.h
+++ b/extractor/grammar_extractor.h
@@ -32,6 +32,7 @@ class GrammarExtractor {
shared_ptr<Alignment> alignment,
shared_ptr<Precomputation> precomputation,
shared_ptr<Scorer> scorer,
+ shared_ptr<Vocabulary> vocabulary,
int min_gap_size,
int max_rule_span,
int max_nonterminals,
diff --git a/extractor/mocks/mock_data_array.h b/extractor/mocks/mock_data_array.h
index 6f85abb4..edc525fa 100644
--- a/extractor/mocks/mock_data_array.h
+++ b/extractor/mocks/mock_data_array.h
@@ -6,12 +6,11 @@ namespace extractor {
class MockDataArray : public DataArray {
public:
- MOCK_CONST_METHOD0(GetData, const vector<int>&());
+ MOCK_CONST_METHOD0(GetData, vector<int>());
MOCK_CONST_METHOD1(AtIndex, int(int index));
MOCK_CONST_METHOD1(GetWordAtIndex, string(int index));
MOCK_CONST_METHOD0(GetSize, int());
MOCK_CONST_METHOD0(GetVocabularySize, int());
- MOCK_CONST_METHOD1(HasWord, bool(const string& word));
MOCK_CONST_METHOD1(GetWordId, int(const string& word));
MOCK_CONST_METHOD1(GetWord, string(int word_id));
MOCK_CONST_METHOD1(GetSentenceLength, int(int sentence_id));
diff --git a/extractor/mocks/mock_precomputation.h b/extractor/mocks/mock_precomputation.h
index 8753343e..5f7aa999 100644
--- a/extractor/mocks/mock_precomputation.h
+++ b/extractor/mocks/mock_precomputation.h
@@ -6,7 +6,8 @@ namespace extractor {
class MockPrecomputation : public Precomputation {
public:
- MOCK_CONST_METHOD0(GetCollocations, const Index&());
+ MOCK_CONST_METHOD1(Contains, bool(const vector<int>& pattern));
+ MOCK_CONST_METHOD1(GetCollocations, vector<int>(const vector<int>& pattern));
};
} // namespace extractor
diff --git a/extractor/mocks/mock_rule_factory.h b/extractor/mocks/mock_rule_factory.h
index 86a084b5..6b7b6586 100644
--- a/extractor/mocks/mock_rule_factory.h
+++ b/extractor/mocks/mock_rule_factory.h
@@ -7,7 +7,9 @@ namespace extractor {
class MockHieroCachingRuleFactory : public HieroCachingRuleFactory {
public:
- MOCK_METHOD3(GetGrammar, Grammar(const vector<int>& word_ids, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array));
+ MOCK_METHOD3(GetGrammar, Grammar(const vector<int>& word_ids, const
+ unordered_set<int>& blacklisted_sentence_ids,
+ const shared_ptr<DataArray> source_data_array));
};
} // namespace extractor
diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc
index 3b8aed69..3e58e2a9 100644
--- a/extractor/precomputation.cc
+++ b/extractor/precomputation.cc
@@ -5,59 +5,67 @@
#include "data_array.h"
#include "suffix_array.h"
+#include "time_util.h"
+#include "vocabulary.h"
using namespace std;
namespace extractor {
-int Precomputation::FIRST_NONTERMINAL = -1;
-int Precomputation::SECOND_NONTERMINAL = -2;
-
Precomputation::Precomputation(
- shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns,
- int num_super_frequent_patterns, int max_rule_span,
- int max_rule_symbols, int min_gap_size,
+ shared_ptr<Vocabulary> vocabulary, shared_ptr<SuffixArray> suffix_array,
+ int num_frequent_patterns, int num_super_frequent_patterns,
+ int max_rule_span, int max_rule_symbols, int min_gap_size,
int max_frequent_phrase_len, int min_frequency) {
- vector<int> data = suffix_array->GetData()->GetData();
+ Clock::time_point start_time = Clock::now();
+ shared_ptr<DataArray> data_array = suffix_array->GetData();
+ vector<int> data = data_array->GetData();
vector<vector<int>> frequent_patterns = FindMostFrequentPatterns(
suffix_array, data, num_frequent_patterns, max_frequent_phrase_len,
min_frequency);
+ Clock::time_point end_time = Clock::now();
+ cerr << "Finding most frequent patterns took "
+ << GetDuration(start_time, end_time) << " seconds..." << endl;
- // Construct sets containing the frequent and superfrequent contiguous
- // collocations.
- unordered_set<vector<int>, VectorHash> frequent_patterns_set;
- unordered_set<vector<int>, VectorHash> super_frequent_patterns_set;
+ vector<vector<int>> pattern_annotations(frequent_patterns.size());
+ unordered_map<vector<int>, int, VectorHash> frequent_patterns_index;
for (size_t i = 0; i < frequent_patterns.size(); ++i) {
- frequent_patterns_set.insert(frequent_patterns[i]);
- if (i < num_super_frequent_patterns) {
- super_frequent_patterns_set.insert(frequent_patterns[i]);
- }
+ frequent_patterns_index[frequent_patterns[i]] = i;
+ pattern_annotations[i] = AnnotatePattern(vocabulary, data_array,
+ frequent_patterns[i]);
}
+ start_time = Clock::now();
vector<tuple<int, int, int>> matchings;
+ vector<vector<int>> annotations;
for (size_t i = 0; i < data.size(); ++i) {
// If the sentence is over, add all the discontiguous frequent patterns to
// the index.
if (data[i] == DataArray::END_OF_LINE) {
- AddCollocations(matchings, data, max_rule_span, min_gap_size,
- max_rule_symbols);
+ UpdateIndex(matchings, annotations, max_rule_span, min_gap_size,
+ max_rule_symbols);
matchings.clear();
+ annotations.clear();
continue;
}
- vector<int> pattern;
// Find all the contiguous frequent patterns starting at position i.
+ vector<int> pattern;
for (int j = 1; j <= max_frequent_phrase_len && i + j <= data.size(); ++j) {
pattern.push_back(data[i + j - 1]);
- if (frequent_patterns_set.count(pattern)) {
- int is_super_frequent = super_frequent_patterns_set.count(pattern);
- matchings.push_back(make_tuple(i, j, is_super_frequent));
- } else {
+ auto it = frequent_patterns_index.find(pattern);
+ if (it == frequent_patterns_index.end()) {
// If the current pattern is not frequent, any longer pattern having the
// current pattern as prefix will not be frequent.
break;
}
+ int is_super_frequent = it->second < num_super_frequent_patterns;
+ matchings.push_back(make_tuple(i, j, is_super_frequent));
+ annotations.push_back(pattern_annotations[it->second]);
}
}
+ end_time = Clock::now();
+ cerr << "Constructing collocations index took "
+ << GetDuration(start_time, end_time) << " seconds..." << endl;
}
Precomputation::Precomputation() {}
@@ -75,9 +83,9 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns(
for (size_t i = 1; i < lcp.size(); ++i) {
for (int len = lcp[i]; len < max_frequent_phrase_len; ++len) {
int frequency = i - run_start[len];
- if (frequency >= min_frequency) {
- heap.push(make_pair(frequency,
- make_pair(suffix_array->GetSuffix(run_start[len]), len + 1)));
+ int start = suffix_array->GetSuffix(run_start[len]);
+ if (frequency >= min_frequency && start + len <= data.size()) {
+ heap.push(make_pair(frequency, make_pair(start, len + 1)));
}
run_start[len] = i;
}
@@ -99,8 +107,20 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns(
return frequent_patterns;
}
-void Precomputation::AddCollocations(
- const vector<tuple<int, int, int>>& matchings, const vector<int>& data,
+vector<int> Precomputation::AnnotatePattern(
+ shared_ptr<Vocabulary> vocabulary, shared_ptr<DataArray> data_array,
+ const vector<int>& pattern) const {
+ vector<int> annotation;
+ for (int word_id: pattern) {
+ annotation.push_back(vocabulary->GetTerminalIndex(
+ data_array->GetWord(word_id)));
+ }
+ return annotation;
+}
+
+void Precomputation::UpdateIndex(
+ const vector<tuple<int, int, int>>& matchings,
+ const vector<vector<int>>& annotations,
int max_rule_span, int min_gap_size, int max_rule_symbols) {
// Select the leftmost subpattern.
for (size_t i = 0; i < matchings.size(); ++i) {
@@ -118,16 +138,14 @@ void Precomputation::AddCollocations(
if (start2 - start1 - size1 >= min_gap_size
&& start2 + size2 - start1 <= max_rule_span
&& size1 + size2 + 1 <= max_rule_symbols) {
- vector<int> pattern(data.begin() + start1,
- data.begin() + start1 + size1);
- pattern.push_back(Precomputation::FIRST_NONTERMINAL);
- pattern.insert(pattern.end(), data.begin() + start2,
- data.begin() + start2 + size2);
- AddStartPositions(collocations[pattern], start1, start2);
+ vector<int> pattern = annotations[i];
+ pattern.push_back(-1);
+ AppendSubpattern(pattern, annotations[j]);
+ AppendCollocation(index[pattern], start1, start2);
// Try extending the binary collocation to a ternary collocation.
if (is_super2) {
- pattern.push_back(Precomputation::SECOND_NONTERMINAL);
+ pattern.push_back(-2);
// Select the rightmost subpattern.
for (size_t k = j + 1; k < matchings.size(); ++k) {
int start3, size3, is_super3;
@@ -140,9 +158,8 @@ void Precomputation::AddCollocations(
&& start3 + size3 - start1 <= max_rule_span
&& size1 + size2 + size3 + 2 <= max_rule_symbols
&& (is_super1 || is_super3)) {
- pattern.insert(pattern.end(), data.begin() + start3,
- data.begin() + start3 + size3);
- AddStartPositions(collocations[pattern], start1, start2, start3);
+ AppendSubpattern(pattern, annotations[k]);
+ AppendCollocation(index[pattern], start1, start2, start3);
pattern.erase(pattern.end() - size3);
}
}
@@ -152,25 +169,35 @@ void Precomputation::AddCollocations(
}
}
-void Precomputation::AddStartPositions(
- vector<int>& positions, int pos1, int pos2) {
- positions.push_back(pos1);
- positions.push_back(pos2);
+void Precomputation::AppendSubpattern(
+ vector<int>& pattern,
+ const vector<int>& subpattern) {
+ copy(subpattern.begin(), subpattern.end(), back_inserter(pattern));
+}
+
+void Precomputation::AppendCollocation(
+ vector<int>& collocations, int pos1, int pos2) {
+ collocations.push_back(pos1);
+ collocations.push_back(pos2);
+}
+
+void Precomputation::AppendCollocation(
+ vector<int>& collocations, int pos1, int pos2, int pos3) {
+ collocations.push_back(pos1);
+ collocations.push_back(pos2);
+ collocations.push_back(pos3);
}
-void Precomputation::AddStartPositions(
- vector<int>& positions, int pos1, int pos2, int pos3) {
- positions.push_back(pos1);
- positions.push_back(pos2);
- positions.push_back(pos3);
+bool Precomputation::Contains(const vector<int>& pattern) const {
+ return index.count(pattern);
}
-const Index& Precomputation::GetCollocations() const {
- return collocations;
+vector<int> Precomputation::GetCollocations(const vector<int>& pattern) const {
+ return index.at(pattern);
}
bool Precomputation::operator==(const Precomputation& other) const {
- return collocations == other.collocations;
+ return index == other.index;
}
} // namespace extractor
diff --git a/extractor/precomputation.h b/extractor/precomputation.h
index 9f0c9424..2b34fc29 100644
--- a/extractor/precomputation.h
+++ b/extractor/precomputation.h
@@ -7,13 +7,11 @@
#include <tuple>
#include <vector>
-#include <boost/filesystem.hpp>
#include <boost/functional/hash.hpp>
#include <boost/serialization/serialization.hpp>
#include <boost/serialization/utility.hpp>
#include <boost/serialization/vector.hpp>
-namespace fs = boost::filesystem;
using namespace std;
namespace extractor {
@@ -21,7 +19,9 @@ namespace extractor {
typedef boost::hash<vector<int>> VectorHash;
typedef unordered_map<vector<int>, vector<int>, VectorHash> Index;
+class DataArray;
class SuffixArray;
+class Vocabulary;
/**
* Data structure wrapping an index with all the occurrences of the most
@@ -37,9 +37,9 @@ class Precomputation {
public:
// Constructs the index using the suffix array.
Precomputation(
- shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns,
- int num_super_frequent_patterns, int max_rule_span,
- int max_rule_symbols, int min_gap_size,
+ shared_ptr<Vocabulary> vocabulary, shared_ptr<SuffixArray> suffix_array,
+ int num_frequent_patterns, int num_super_frequent_patterns,
+ int max_rule_span, int max_rule_symbols, int min_gap_size,
int max_frequent_phrase_len, int min_frequency);
// Creates empty precomputation data structure.
@@ -47,13 +47,13 @@ class Precomputation {
virtual ~Precomputation();
- // Returns a reference to the index.
- virtual const Index& GetCollocations() const;
+ // Returns whether a pattern is contained in the index of collocations.
+ virtual bool Contains(const vector<int>& pattern) const;
- bool operator==(const Precomputation& other) const;
+ // Returns the list of collocations for a given pattern.
+ virtual vector<int> GetCollocations(const vector<int>& pattern) const;
- static int FIRST_NONTERMINAL;
- static int SECOND_NONTERMINAL;
+ bool operator==(const Precomputation& other) const;
private:
// Finds the most frequent contiguous collocations.
@@ -62,25 +62,32 @@ class Precomputation {
int num_frequent_patterns, int max_frequent_phrase_len,
int min_frequency);
+ vector<int> AnnotatePattern(shared_ptr<Vocabulary> vocabulary,
+ shared_ptr<DataArray> data_array,
+ const vector<int>& pattern) const;
+
// Given the locations of the frequent contiguous collocations in a sentence,
// it adds new entries to the index for each discontiguous collocation
// matching the criteria specified in the class description.
- void AddCollocations(
- const vector<std::tuple<int, int, int>>& matchings, const vector<int>& data,
+ void UpdateIndex(
+ const vector<tuple<int, int, int>>& matchings,
+ const vector<vector<int>>& annotations,
int max_rule_span, int min_gap_size, int max_rule_symbols);
+ void AppendSubpattern(vector<int>& pattern, const vector<int>& subpattern);
+
// Adds an occurrence of a binary collocation.
- void AddStartPositions(vector<int>& positions, int pos1, int pos2);
+ void AppendCollocation(vector<int>& collocations, int pos1, int pos2);
// Adds an occurrence of a ternary collocation.
- void AddStartPositions(vector<int>& positions, int pos1, int pos2, int pos3);
+ void AppendCollocation(vector<int>& collocations, int pos1, int pos2, int pos3);
friend class boost::serialization::access;
template<class Archive> void save(Archive& ar, unsigned int) const {
- int num_entries = collocations.size();
+ int num_entries = index.size();
ar << num_entries;
- for (pair<vector<int>, vector<int>> entry: collocations) {
+ for (pair<vector<int>, vector<int>> entry: index) {
ar << entry;
}
}
@@ -91,13 +98,13 @@ class Precomputation {
for (size_t i = 0; i < num_entries; ++i) {
pair<vector<int>, vector<int>> entry;
ar >> entry;
- collocations.insert(entry);
+ index.insert(entry);
}
}
BOOST_SERIALIZATION_SPLIT_MEMBER();
- Index collocations;
+ Index index;
};
} // namespace extractor
diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc
index e81ece5d..3a98ce05 100644
--- a/extractor/precomputation_test.cc
+++ b/extractor/precomputation_test.cc
@@ -9,6 +9,7 @@
#include "mocks/mock_data_array.h"
#include "mocks/mock_suffix_array.h"
+#include "mocks/mock_vocabulary.h"
#include "precomputation.h"
using namespace std;
@@ -23,7 +24,12 @@ class PrecomputationTest : public Test {
virtual void SetUp() {
data = {4, 2, 3, 5, 7, 2, 3, 5, 2, 3, 4, 2, 1};
data_array = make_shared<MockDataArray>();
- EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data));
+ EXPECT_CALL(*data_array, GetData()).WillRepeatedly(Return(data));
+ for (size_t i = 0; i < data.size(); ++i) {
+ EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i]));
+ }
+ EXPECT_CALL(*data_array, GetWord(2)).WillRepeatedly(Return("2"));
+ EXPECT_CALL(*data_array, GetWord(3)).WillRepeatedly(Return("3"));
vector<int> suffixes{12, 8, 5, 1, 9, 6, 2, 0, 10, 7, 3, 4, 13};
vector<int> lcp{-1, 0, 2, 3, 1, 0, 1, 2, 0, 2, 0, 1, 0, 0};
@@ -35,77 +41,98 @@ class PrecomputationTest : public Test {
}
EXPECT_CALL(*suffix_array, BuildLCPArray()).WillRepeatedly(Return(lcp));
- precomputation = Precomputation(suffix_array, 3, 3, 10, 5, 1, 4, 2);
+ vocabulary = make_shared<MockVocabulary>();
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("2")).WillRepeatedly(Return(2));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("3")).WillRepeatedly(Return(3));
+
+ precomputation = Precomputation(vocabulary, suffix_array,
+ 3, 3, 10, 5, 1, 4, 2);
}
vector<int> data;
shared_ptr<MockDataArray> data_array;
shared_ptr<MockSuffixArray> suffix_array;
+ shared_ptr<MockVocabulary> vocabulary;
Precomputation precomputation;
};
TEST_F(PrecomputationTest, TestCollocations) {
- Index collocations = precomputation.GetCollocations();
-
vector<int> key = {2, 3, -1, 2};
vector<int> expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {2, 3, -1, 2, 3};
expected_value = {1, 5, 1, 8, 5, 8};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {2, 3, -1, 3};
expected_value = {1, 6, 1, 9, 5, 9};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {3, -1, 2};
expected_value = {2, 5, 2, 8, 2, 11, 6, 8, 6, 11, 9, 11};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {3, -1, 3};
expected_value = {2, 6, 2, 9, 6, 9};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {3, -1, 2, 3};
expected_value = {2, 5, 2, 8, 6, 8};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {2, -1, 2};
expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {2, -1, 2, 3};
expected_value = {1, 5, 1, 8, 5, 8};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {2, -1, 3};
expected_value = {1, 6, 1, 9, 5, 9};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {2, -1, 2, -2, 2};
expected_value = {1, 5, 8, 5, 8, 11};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {2, -1, 2, -2, 3};
expected_value = {1, 5, 9};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {2, -1, 3, -2, 2};
expected_value = {1, 6, 8, 5, 9, 11};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {2, -1, 3, -2, 3};
expected_value = {1, 6, 9};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {3, -1, 2, -2, 2};
expected_value = {2, 5, 8, 2, 5, 11, 2, 8, 11, 6, 8, 11};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {3, -1, 2, -2, 3};
expected_value = {2, 5, 9};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {3, -1, 3, -2, 2};
expected_value = {2, 6, 8, 2, 6, 11, 2, 9, 11, 6, 9, 11};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {3, -1, 3, -2, 3};
expected_value = {2, 6, 9};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
// Exceeds max_rule_symbols.
key = {2, -1, 2, -2, 2, 3};
- EXPECT_EQ(0, collocations.count(key));
+ EXPECT_FALSE(precomputation.Contains(key));
// Contains non frequent pattern.
key = {2, -1, 5};
- EXPECT_EQ(0, collocations.count(key));
+ EXPECT_FALSE(precomputation.Contains(key));
}
TEST_F(PrecomputationTest, TestSerialization) {
diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h
index df63a9d8..a1ff76e4 100644
--- a/extractor/rule_factory.h
+++ b/extractor/rule_factory.h
@@ -72,7 +72,10 @@ class HieroCachingRuleFactory {
// Constructs SCFG rules for a given sentence.
// (See class description for more details.)
- virtual Grammar GetGrammar(const vector<int>& word_ids, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array);
+ virtual Grammar GetGrammar(
+ const vector<int>& word_ids,
+ const unordered_set<int>& blacklisted_sentence_ids,
+ const shared_ptr<DataArray> source_data_array);
protected:
HieroCachingRuleFactory();
diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc
index 6eb55073..85c8a422 100644
--- a/extractor/run_extractor.cc
+++ b/extractor/run_extractor.cc
@@ -28,6 +28,7 @@
#include "suffix_array.h"
#include "time_util.h"
#include "translation_table.h"
+#include "vocabulary.h"
namespace fs = boost::filesystem;
namespace po = boost::program_options;
@@ -142,11 +143,14 @@ int main(int argc, char** argv) {
cerr << "Reading alignment took "
<< GetDuration(start_time, stop_time) << " seconds" << endl;
+ shared_ptr<Vocabulary> vocabulary = make_shared<Vocabulary>();
+
// Constructs an index storing the occurrences in the source data for each
// frequent collocation.
start_time = Clock::now();
cerr << "Precomputing collocations..." << endl;
shared_ptr<Precomputation> precomputation = make_shared<Precomputation>(
+ vocabulary,
source_suffix_array,
vm["frequent"].as<int>(),
vm["super_frequent"].as<int>(),
@@ -194,6 +198,7 @@ int main(int argc, char** argv) {
alignment,
precomputation,
scorer,
+ vocabulary,
vm["min_gap_size"].as<int>(),
vm["max_rule_span"].as<int>(),
vm["max_nonterminals"].as<int>(),
diff --git a/extractor/suffix_array.cc b/extractor/suffix_array.cc
index 0cf4d1f6..4a514b12 100644
--- a/extractor/suffix_array.cc
+++ b/extractor/suffix_array.cc
@@ -10,7 +10,6 @@
#include "phrase_location.h"
#include "time_util.h"
-namespace fs = boost::filesystem;
using namespace std;
using namespace chrono;
@@ -188,12 +187,12 @@ shared_ptr<DataArray> SuffixArray::GetData() const {
PhraseLocation SuffixArray::Lookup(int low, int high, const string& word,
int offset) const {
- if (!data_array->HasWord(word)) {
+ int word_id = data_array->GetWordId(word);
+ if (word_id == -1) {
// Return empty phrase location.
return PhraseLocation(0, 0);
}
- int word_id = data_array->GetWordId(word);
if (offset == 0) {
return PhraseLocation(word_start[word_id], word_start[word_id + 1]);
}
diff --git a/extractor/suffix_array.h b/extractor/suffix_array.h
index 8ee454ec..df80c152 100644
--- a/extractor/suffix_array.h
+++ b/extractor/suffix_array.h
@@ -5,12 +5,10 @@
#include <string>
#include <vector>
-#include <boost/filesystem.hpp>
#include <boost/serialization/serialization.hpp>
#include <boost/serialization/split_member.hpp>
#include <boost/serialization/vector.hpp>
-namespace fs = boost::filesystem;
using namespace std;
namespace extractor {
diff --git a/extractor/suffix_array_test.cc b/extractor/suffix_array_test.cc
index ba0dbcc3..161edbc0 100644
--- a/extractor/suffix_array_test.cc
+++ b/extractor/suffix_array_test.cc
@@ -21,7 +21,7 @@ class SuffixArrayTest : public Test {
virtual void SetUp() {
data = {6, 4, 1, 2, 4, 5, 3, 4, 6, 6, 4, 1, 2};
data_array = make_shared<MockDataArray>();
- EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data));
+ EXPECT_CALL(*data_array, GetData()).WillRepeatedly(Return(data));
EXPECT_CALL(*data_array, GetVocabularySize()).WillRepeatedly(Return(7));
EXPECT_CALL(*data_array, GetSize()).WillRepeatedly(Return(13));
suffix_array = SuffixArray(data_array);
@@ -55,22 +55,18 @@ TEST_F(SuffixArrayTest, TestLookup) {
EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i]));
}
- EXPECT_CALL(*data_array, HasWord("word1")).WillRepeatedly(Return(true));
EXPECT_CALL(*data_array, GetWordId("word1")).WillRepeatedly(Return(6));
EXPECT_EQ(PhraseLocation(11, 14), suffix_array.Lookup(0, 14, "word1", 0));
- EXPECT_CALL(*data_array, HasWord("word2")).WillRepeatedly(Return(false));
+ EXPECT_CALL(*data_array, GetWordId("word2")).WillRepeatedly(Return(-1));
EXPECT_EQ(PhraseLocation(0, 0), suffix_array.Lookup(0, 14, "word2", 0));
- EXPECT_CALL(*data_array, HasWord("word3")).WillRepeatedly(Return(true));
EXPECT_CALL(*data_array, GetWordId("word3")).WillRepeatedly(Return(4));
EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 14, "word3", 1));
- EXPECT_CALL(*data_array, HasWord("word4")).WillRepeatedly(Return(true));
EXPECT_CALL(*data_array, GetWordId("word4")).WillRepeatedly(Return(1));
EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 13, "word4", 2));
- EXPECT_CALL(*data_array, HasWord("word5")).WillRepeatedly(Return(true));
EXPECT_CALL(*data_array, GetWordId("word5")).WillRepeatedly(Return(2));
EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 13, "word5", 3));
diff --git a/extractor/translation_table.cc b/extractor/translation_table.cc
index 1b1ba112..11e29e1e 100644
--- a/extractor/translation_table.cc
+++ b/extractor/translation_table.cc
@@ -90,13 +90,12 @@ void TranslationTable::IncrementLinksCount(
double TranslationTable::GetTargetGivenSourceScore(
const string& source_word, const string& target_word) {
- if (!source_data_array->HasWord(source_word) ||
- !target_data_array->HasWord(target_word)) {
+ int source_id = source_data_array->GetWordId(source_word);
+ int target_id = target_data_array->GetWordId(target_word);
+ if (source_id == -1 || target_id == -1) {
return -1;
}
- int source_id = source_data_array->GetWordId(source_word);
- int target_id = target_data_array->GetWordId(target_word);
auto entry = make_pair(source_id, target_id);
auto it = translation_probabilities.find(entry);
if (it == translation_probabilities.end()) {
@@ -107,13 +106,12 @@ double TranslationTable::GetTargetGivenSourceScore(
double TranslationTable::GetSourceGivenTargetScore(
const string& source_word, const string& target_word) {
- if (!source_data_array->HasWord(source_word) ||
- !target_data_array->HasWord(target_word)) {
+ int source_id = source_data_array->GetWordId(source_word);
+ int target_id = target_data_array->GetWordId(target_word);
+ if (source_id == -1 || target_id == -1) {
return -1;
}
- int source_id = source_data_array->GetWordId(source_word);
- int target_id = target_data_array->GetWordId(target_word);
auto entry = make_pair(source_id, target_id);
auto it = translation_probabilities.find(entry);
if (it == translation_probabilities.end()) {
diff --git a/extractor/translation_table.h b/extractor/translation_table.h
index 2a37bab7..97620727 100644
--- a/extractor/translation_table.h
+++ b/extractor/translation_table.h
@@ -5,14 +5,12 @@
#include <string>
#include <unordered_map>
-#include <boost/filesystem.hpp>
#include <boost/functional/hash.hpp>
#include <boost/serialization/serialization.hpp>
#include <boost/serialization/split_member.hpp>
#include <boost/serialization/utility.hpp>
using namespace std;
-namespace fs = boost::filesystem;
namespace extractor {
diff --git a/extractor/translation_table_test.cc b/extractor/translation_table_test.cc
index 606777bd..3cfc0011 100644
--- a/extractor/translation_table_test.cc
+++ b/extractor/translation_table_test.cc
@@ -28,7 +28,7 @@ class TranslationTableTest : public Test {
vector<int> source_sentence_start = {0, 6, 10, 14};
shared_ptr<MockDataArray> source_data_array = make_shared<MockDataArray>();
EXPECT_CALL(*source_data_array, GetData())
- .WillRepeatedly(ReturnRef(source_data));
+ .WillRepeatedly(Return(source_data));
EXPECT_CALL(*source_data_array, GetNumSentences())
.WillRepeatedly(Return(3));
for (size_t i = 0; i < source_sentence_start.size(); ++i) {
@@ -36,31 +36,25 @@ class TranslationTableTest : public Test {
.WillRepeatedly(Return(source_sentence_start[i]));
}
for (size_t i = 0; i < words.size(); ++i) {
- EXPECT_CALL(*source_data_array, HasWord(words[i]))
- .WillRepeatedly(Return(true));
EXPECT_CALL(*source_data_array, GetWordId(words[i]))
.WillRepeatedly(Return(i + 2));
}
- EXPECT_CALL(*source_data_array, HasWord("d"))
- .WillRepeatedly(Return(false));
+ EXPECT_CALL(*source_data_array, GetWordId("d")).WillRepeatedly(Return(-1));
vector<int> target_data = {2, 3, 2, 3, 4, 5, 0, 3, 6, 0, 2, 7, 0};
vector<int> target_sentence_start = {0, 7, 10, 13};
shared_ptr<MockDataArray> target_data_array = make_shared<MockDataArray>();
EXPECT_CALL(*target_data_array, GetData())
- .WillRepeatedly(ReturnRef(target_data));
+ .WillRepeatedly(Return(target_data));
for (size_t i = 0; i < target_sentence_start.size(); ++i) {
EXPECT_CALL(*target_data_array, GetSentenceStart(i))
.WillRepeatedly(Return(target_sentence_start[i]));
}
for (size_t i = 0; i < words.size(); ++i) {
- EXPECT_CALL(*target_data_array, HasWord(words[i]))
- .WillRepeatedly(Return(true));
EXPECT_CALL(*target_data_array, GetWordId(words[i]))
.WillRepeatedly(Return(i + 2));
}
- EXPECT_CALL(*target_data_array, HasWord("d"))
- .WillRepeatedly(Return(false));
+ EXPECT_CALL(*target_data_array, GetWordId("d")).WillRepeatedly(Return(-1));
vector<pair<int, int>> links1 = {
make_pair(0, 0), make_pair(1, 1), make_pair(2, 2), make_pair(3, 3),
diff --git a/extractor/vocabulary.cc b/extractor/vocabulary.cc
index 15795d1e..aef674a5 100644
--- a/extractor/vocabulary.cc
+++ b/extractor/vocabulary.cc
@@ -8,12 +8,13 @@ int Vocabulary::GetTerminalIndex(const string& word) {
int word_id = -1;
#pragma omp critical (vocabulary)
{
- if (!dictionary.count(word)) {
+ auto it = dictionary.find(word);
+ if (it != dictionary.end()) {
+ word_id = it->second;
+ } else {
word_id = words.size();
dictionary[word] = word_id;
words.push_back(word);
- } else {
- word_id = dictionary[word];
}
}
return word_id;