summaryrefslogtreecommitdiff
path: root/extractor
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2013-12-04 20:13:07 +0100
committerPatrick Simianer <p@simianer.de>2013-12-04 20:13:07 +0100
commit02647059daa297e7b2b3ca3a2c03d848ae3ad9f2 (patch)
treea5d2b5d66a8cff38a9422378c66861a3fb493e80 /extractor
parent7b2cd4e93114baaa329c483a98d6f7999aad1ba0 (diff)
parent7d391d8eb88d27c5637042fefe2d27e7b12f5587 (diff)
fix merge conflict
Diffstat (limited to 'extractor')
-rw-r--r--extractor/Makefile.am68
-rw-r--r--extractor/README.md10
-rw-r--r--extractor/alignment.cc2
-rw-r--r--extractor/alignment.h2
-rw-r--r--extractor/backoff_sampler.cc66
-rw-r--r--extractor/backoff_sampler.h41
-rw-r--r--extractor/compile.cc34
-rw-r--r--extractor/data_array.cc21
-rw-r--r--extractor/data_array.h15
-rw-r--r--extractor/data_array_test.cc16
-rw-r--r--extractor/extract.cc253
-rw-r--r--extractor/fast_intersector.cc40
-rw-r--r--extractor/fast_intersector.h8
-rw-r--r--extractor/fast_intersector_test.cc10
-rw-r--r--extractor/grammar_extractor.cc11
-rw-r--r--extractor/grammar_extractor.h6
-rw-r--r--extractor/grammar_extractor_test.cc4
-rw-r--r--extractor/matchings_sampler.cc39
-rw-r--r--extractor/matchings_sampler.h31
-rw-r--r--extractor/matchings_sampler_test.cc118
-rw-r--r--extractor/mocks/mock_data_array.h5
-rw-r--r--extractor/mocks/mock_matchings_sampler.h15
-rw-r--r--extractor/mocks/mock_precomputation.h3
-rw-r--r--extractor/mocks/mock_rule_factory.h4
-rw-r--r--extractor/mocks/mock_sampler.h4
-rw-r--r--extractor/mocks/mock_suffix_array_sampler.h15
-rw-r--r--extractor/phrase_location.cc2
-rw-r--r--extractor/phrase_location_sampler.cc34
-rw-r--r--extractor/phrase_location_sampler.h35
-rw-r--r--extractor/phrase_location_sampler_test.cc50
-rw-r--r--extractor/precomputation.cc125
-rw-r--r--extractor/precomputation.h43
-rw-r--r--extractor/precomputation_test.cc73
-rw-r--r--extractor/rule_factory.cc11
-rw-r--r--extractor/rule_factory.h4
-rw-r--r--extractor/rule_factory_test.cc8
-rw-r--r--extractor/run_extractor.cc28
-rw-r--r--extractor/sampler.cc75
-rw-r--r--extractor/sampler.h24
-rw-r--r--extractor/sampler_test.cc80
-rw-r--r--extractor/sampler_test_blacklist.cc102
-rw-r--r--extractor/suffix_array.cc5
-rw-r--r--extractor/suffix_array.h2
-rw-r--r--extractor/suffix_array_sampler.cc40
-rw-r--r--extractor/suffix_array_sampler.h34
-rw-r--r--extractor/suffix_array_sampler_test.cc114
-rw-r--r--extractor/suffix_array_test.cc8
-rw-r--r--extractor/translation_table.cc14
-rw-r--r--extractor/translation_table.h2
-rw-r--r--extractor/translation_table_test.cc14
-rw-r--r--extractor/vocabulary.cc11
-rw-r--r--extractor/vocabulary.h23
-rw-r--r--extractor/vocabulary_test.cc45
53 files changed, 1298 insertions, 549 deletions
diff --git a/extractor/Makefile.am b/extractor/Makefile.am
index 65a3d436..e5b439f9 100644
--- a/extractor/Makefile.am
+++ b/extractor/Makefile.am
@@ -1,5 +1,5 @@
-bin_PROGRAMS = compile run_extractor
+bin_PROGRAMS = compile run_extractor extract
if HAVE_CXX11
@@ -15,16 +15,19 @@ EXTRA_PROGRAMS = alignment_test \
feature_target_given_source_coherent_test \
grammar_extractor_test \
matchings_finder_test \
+ matchings_sampler_test \
+ phrase_location_sampler_test \
phrase_test \
precomputation_test \
rule_extractor_helper_test \
rule_extractor_test \
rule_factory_test \
- sampler_test \
scorer_test \
+ suffix_array_sampler_test \
suffix_array_test \
target_phrase_extractor_test \
- translation_table_test
+ translation_table_test \
+ vocabulary_test
if HAVE_GTEST
RUNNABLE_TESTS = alignment_test \
@@ -39,16 +42,19 @@ if HAVE_GTEST
feature_target_given_source_coherent_test \
grammar_extractor_test \
matchings_finder_test \
+ matchings_sampler_test \
+ phrase_location_sampler_test \
phrase_test \
precomputation_test \
rule_extractor_helper_test \
rule_extractor_test \
rule_factory_test \
- sampler_test \
scorer_test \
+ suffix_array_sampler_test \
suffix_array_test \
target_phrase_extractor_test \
- translation_table_test
+ translation_table_test \
+ vocabulary_test
endif
noinst_PROGRAMS = $(RUNNABLE_TESTS)
@@ -79,6 +85,10 @@ grammar_extractor_test_SOURCES = grammar_extractor_test.cc
grammar_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
matchings_finder_test_SOURCES = matchings_finder_test.cc
matchings_finder_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+matchings_sampler_test_SOURCES = matchings_sampler_test.cc
+matchings_sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+phrase_location_sampler_test_SOURCES = phrase_location_sampler_test.cc
+phrase_location_sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
phrase_test_SOURCES = phrase_test.cc
phrase_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
precomputation_test_SOURCES = precomputation_test.cc
@@ -89,57 +99,31 @@ rule_extractor_test_SOURCES = rule_extractor_test.cc
rule_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
rule_factory_test_SOURCES = rule_factory_test.cc
rule_factory_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
-sampler_test_SOURCES = sampler_test.cc
-sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
scorer_test_SOURCES = scorer_test.cc
scorer_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+suffix_array_sampler_test_SOURCES = suffix_array_sampler_test.cc
+suffix_array_sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
suffix_array_test_SOURCES = suffix_array_test.cc
suffix_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
target_phrase_extractor_test_SOURCES = target_phrase_extractor_test.cc
target_phrase_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
translation_table_test_SOURCES = translation_table_test.cc
translation_table_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+vocabulary_test_SOURCES = vocabulary_test.cc
+vocabulary_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
-noinst_LIBRARIES = libextractor.a libcompile.a
+noinst_LIBRARIES = libextractor.a
compile_SOURCES = compile.cc
-compile_LDADD = libcompile.a
+compile_LDADD = libextractor.a
run_extractor_SOURCES = run_extractor.cc
run_extractor_LDADD = libextractor.a
-
-libcompile_a_SOURCES = \
- alignment.cc \
- data_array.cc \
- phrase_location.cc \
- precomputation.cc \
- suffix_array.cc \
- time_util.cc \
- translation_table.cc \
- alignment.h \
- data_array.h \
- fast_intersector.h \
- grammar.h \
- grammar_extractor.h \
- matchings_finder.h \
- matchings_trie.h \
- phrase.h \
- phrase_builder.h \
- phrase_location.h \
- precomputation.h \
- rule.h \
- rule_extractor.h \
- rule_extractor_helper.h \
- rule_factory.h \
- sampler.h \
- scorer.h \
- suffix_array.h \
- target_phrase_extractor.h \
- time_util.h \
- translation_table.h \
- vocabulary.h
+extract_SOURCES = extract.cc
+extract_LDADD = libextractor.a
libextractor_a_SOURCES = \
alignment.cc \
+ backoff_sampler.cc \
data_array.cc \
fast_intersector.cc \
features/count_source_target.cc \
@@ -153,18 +137,20 @@ libextractor_a_SOURCES = \
grammar.cc \
grammar_extractor.cc \
matchings_finder.cc \
+ matchings_sampler.cc \
matchings_trie.cc \
phrase.cc \
phrase_builder.cc \
phrase_location.cc \
+ phrase_location_sampler.cc \
precomputation.cc \
rule.cc \
rule_extractor.cc \
rule_extractor_helper.cc \
rule_factory.cc \
- sampler.cc \
scorer.cc \
suffix_array.cc \
+ suffix_array_sampler.cc \
target_phrase_extractor.cc \
time_util.cc \
translation_table.cc \
diff --git a/extractor/README.md b/extractor/README.md
index c9db8de8..642fbd1d 100644
--- a/extractor/README.md
+++ b/extractor/README.md
@@ -1,8 +1,14 @@
C++ implementation of the online grammar extractor originally developed by [Adam Lopez](http://www.cs.jhu.edu/~alopez/).
-To run the extractor you need to:
+The grammar extraction takes place in two steps: (a) precomputing a number of data structures and (b) actually extracting the grammars. All the flags below have the same meaning as in the cython implementation.
- cdec/extractor/run_extractor -t <num_threads> -a <alignment> -b <parallel_corpus> -g <grammar_output_path> < <input_sentences> > <sgm_file>
+To compile the data structures you need to run:
+
+ cdec/extractor/compile -a <alignment> -b <parallel_corpus> -c <compile_config_file> -o <compile_directory>
+
+To extract the grammars you need to run:
+
+ cdec/extract/extract -t <num_threads> -c <compile_config_file> -g <grammar_output_path> < <input_sentencs> > <sgm_file>
To run unit tests you need first to configure `cdec` with the [Google Test](https://code.google.com/p/googletest/) and [Google Mock](https://code.google.com/p/googlemock/) libraries:
diff --git a/extractor/alignment.cc b/extractor/alignment.cc
index 2278c825..4a7a14f4 100644
--- a/extractor/alignment.cc
+++ b/extractor/alignment.cc
@@ -8,9 +8,7 @@
#include <vector>
#include <boost/algorithm/string.hpp>
-#include <boost/filesystem.hpp>
-namespace fs = boost::filesystem;
using namespace std;
namespace extractor {
diff --git a/extractor/alignment.h b/extractor/alignment.h
index dc5a8b55..76c27da2 100644
--- a/extractor/alignment.h
+++ b/extractor/alignment.h
@@ -4,13 +4,11 @@
#include <string>
#include <vector>
-#include <boost/filesystem.hpp>
#include <boost/serialization/serialization.hpp>
#include <boost/serialization/split_member.hpp>
#include <boost/serialization/utility.hpp>
#include <boost/serialization/vector.hpp>
-namespace fs = boost::filesystem;
using namespace std;
namespace extractor {
diff --git a/extractor/backoff_sampler.cc b/extractor/backoff_sampler.cc
new file mode 100644
index 00000000..891276c6
--- /dev/null
+++ b/extractor/backoff_sampler.cc
@@ -0,0 +1,66 @@
+#include "backoff_sampler.h"
+
+#include "data_array.h"
+#include "phrase_location.h"
+
+namespace extractor {
+
+BackoffSampler::BackoffSampler(
+ shared_ptr<DataArray> source_data_array, int max_samples) :
+ source_data_array(source_data_array), max_samples(max_samples) {}
+
+BackoffSampler::BackoffSampler() {}
+
+PhraseLocation BackoffSampler::Sample(
+ const PhraseLocation& location,
+ const unordered_set<int>& blacklisted_sentence_ids) const {
+ vector<int> samples;
+ int low = GetRangeLow(location), high = GetRangeHigh(location);
+ int last = low - 1;
+ double step = max(1.0, (double) (high - low) / max_samples);
+ for (double num_samples = 0, i = low;
+ num_samples < max_samples && i < high;
+ ++num_samples, i += step) {
+ int sample = round(i);
+ int position = GetPosition(location, sample);
+ int sentence_id = source_data_array->GetSentenceId(position);
+ bool found = false;
+ if (last >= sample ||
+ blacklisted_sentence_ids.count(sentence_id)) {
+ for (double backoff_step = 1; backoff_step < step; ++backoff_step) {
+ double j = i - backoff_step;
+ sample = round(j);
+ if (sample >= 0) {
+ position = GetPosition(location, sample);
+ sentence_id = source_data_array->GetSentenceId(position);
+ if (sample > last && !blacklisted_sentence_ids.count(sentence_id)) {
+ found = true;
+ break;
+ }
+ }
+
+ double k = i + backoff_step;
+ sample = round(k);
+ if (sample < high) {
+ position = GetPosition(location, sample);
+ sentence_id = source_data_array->GetSentenceId(position);
+ if (!blacklisted_sentence_ids.count(sentence_id)) {
+ found = true;
+ break;
+ }
+ }
+ }
+ } else {
+ found = true;
+ }
+
+ if (found) {
+ last = sample;
+ AppendMatching(samples, sample, location);
+ }
+ }
+
+ return PhraseLocation(samples, GetNumSubpatterns(location));
+}
+
+} // namespace extractor
diff --git a/extractor/backoff_sampler.h b/extractor/backoff_sampler.h
new file mode 100644
index 00000000..5c244105
--- /dev/null
+++ b/extractor/backoff_sampler.h
@@ -0,0 +1,41 @@
+#ifndef _BACKOFF_SAMPLER_H_
+#define _BACKOFF_SAMPLER_H_
+
+#include <vector>
+
+#include "sampler.h"
+
+namespace extractor {
+
+class DataArray;
+class PhraseLocation;
+
+class BackoffSampler : public Sampler {
+ public:
+ BackoffSampler(shared_ptr<DataArray> source_data_array, int max_samples);
+
+ BackoffSampler();
+
+ PhraseLocation Sample(
+ const PhraseLocation& location,
+ const unordered_set<int>& blacklisted_sentence_ids) const;
+
+ private:
+ virtual int GetNumSubpatterns(const PhraseLocation& location) const = 0;
+
+ virtual int GetRangeLow(const PhraseLocation& location) const = 0;
+
+ virtual int GetRangeHigh(const PhraseLocation& location) const = 0;
+
+ virtual int GetPosition(const PhraseLocation& location, int index) const = 0;
+
+ virtual void AppendMatching(vector<int>& samples, int index,
+ const PhraseLocation& location) const = 0;
+
+ shared_ptr<DataArray> source_data_array;
+ int max_samples;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/compile.cc b/extractor/compile.cc
index 65fdd509..3ee668ce 100644
--- a/extractor/compile.cc
+++ b/extractor/compile.cc
@@ -13,6 +13,7 @@
#include "suffix_array.h"
#include "time_util.h"
#include "translation_table.h"
+#include "vocabulary.h"
namespace ar = boost::archive;
namespace fs = boost::filesystem;
@@ -29,6 +30,8 @@ int main(int argc, char** argv) {
("bitext,b", po::value<string>(), "Parallel text (source ||| target)")
("alignment,a", po::value<string>()->required(), "Bitext word alignment")
("output,o", po::value<string>()->required(), "Output path")
+ ("config,c", po::value<string>()->required(),
+ "Path where the config file will be generated")
("frequent", po::value<int>()->default_value(100),
"Number of precomputed frequent patterns")
("super_frequent", po::value<int>()->default_value(10),
@@ -81,8 +84,12 @@ int main(int argc, char** argv) {
target_data_array = make_shared<DataArray>(vm["target"].as<string>());
}
+ ofstream config_stream(vm["config"].as<string>());
+
Clock::time_point start_write = Clock::now();
- ofstream target_fstream((output_dir / fs::path("target.bin")).string());
+ string target_path = (output_dir / fs::path("target.bin")).string();
+ config_stream << "target = " << target_path << endl;
+ ofstream target_fstream(target_path);
ar::binary_oarchive target_stream(target_fstream);
target_stream << *target_data_array;
Clock::time_point stop_write = Clock::now();
@@ -99,7 +106,9 @@ int main(int argc, char** argv) {
make_shared<SuffixArray>(source_data_array);
start_write = Clock::now();
- ofstream source_fstream((output_dir / fs::path("source.bin")).string());
+ string source_path = (output_dir / fs::path("source.bin")).string();
+ config_stream << "source = " << source_path << endl;
+ ofstream source_fstream(source_path);
ar::binary_oarchive output_stream(source_fstream);
output_stream << *source_suffix_array;
stop_write = Clock::now();
@@ -115,7 +124,9 @@ int main(int argc, char** argv) {
make_shared<Alignment>(vm["alignment"].as<string>());
start_write = Clock::now();
- ofstream alignment_fstream((output_dir / fs::path("alignment.bin")).string());
+ string alignment_path = (output_dir / fs::path("alignment.bin")).string();
+ config_stream << "alignment = " << alignment_path << endl;
+ ofstream alignment_fstream(alignment_path);
ar::binary_oarchive alignment_stream(alignment_fstream);
alignment_stream << *alignment;
stop_write = Clock::now();
@@ -125,9 +136,12 @@ int main(int argc, char** argv) {
cerr << "Reading alignment took "
<< GetDuration(start_time, stop_time) << " seconds" << endl;
+ shared_ptr<Vocabulary> vocabulary = make_shared<Vocabulary>();
+
start_time = Clock::now();
cerr << "Precomputing collocations..." << endl;
Precomputation precomputation(
+ vocabulary,
source_suffix_array,
vm["frequent"].as<int>(),
vm["super_frequent"].as<int>(),
@@ -138,9 +152,17 @@ int main(int argc, char** argv) {
vm["min_frequency"].as<int>());
start_write = Clock::now();
- ofstream precomp_fstream((output_dir / fs::path("precomp.bin")).string());
+ string precomputation_path = (output_dir / fs::path("precomp.bin")).string();
+ config_stream << "precomputation = " << precomputation_path << endl;
+ ofstream precomp_fstream(precomputation_path);
ar::binary_oarchive precomp_stream(precomp_fstream);
precomp_stream << precomputation;
+
+ string vocabulary_path = (output_dir / fs::path("vocab.bin")).string();
+ config_stream << "vocabulary = " << vocabulary_path << endl;
+ ofstream vocab_fstream(vocabulary_path);
+ ar::binary_oarchive vocab_stream(vocab_fstream);
+ vocab_stream << *vocabulary;
stop_write = Clock::now();
write_duration += GetDuration(start_write, stop_write);
@@ -153,7 +175,9 @@ int main(int argc, char** argv) {
TranslationTable table(source_data_array, target_data_array, alignment);
start_write = Clock::now();
- ofstream table_fstream((output_dir / fs::path("bilex.bin")).string());
+ string table_path = (output_dir / fs::path("bilex.bin")).string();
+ config_stream << "ttable = " << table_path << endl;
+ ofstream table_fstream(table_path);
ar::binary_oarchive table_stream(table_fstream);
table_stream << table;
stop_write = Clock::now();
diff --git a/extractor/data_array.cc b/extractor/data_array.cc
index 2e4bdafb..9612aa8a 100644
--- a/extractor/data_array.cc
+++ b/extractor/data_array.cc
@@ -5,9 +5,6 @@
#include <sstream>
#include <string>
-#include <boost/filesystem.hpp>
-
-namespace fs = boost::filesystem;
using namespace std;
namespace extractor {
@@ -81,7 +78,7 @@ void DataArray::CreateDataArray(const vector<string>& lines) {
DataArray::~DataArray() {}
-const vector<int>& DataArray::GetData() const {
+vector<int> DataArray::GetData() const {
return data;
}
@@ -93,6 +90,18 @@ string DataArray::GetWordAtIndex(int index) const {
return id2word[data[index]];
}
+vector<int> DataArray::GetWordIds(int index, int size) const {
+ return vector<int>(data.begin() + index, data.begin() + index + size);
+}
+
+vector<string> DataArray::GetWords(int start_index, int size) const {
+ vector<string> words;
+ for (int word_id: GetWordIds(start_index, size)) {
+ words.push_back(id2word[word_id]);
+ }
+ return words;
+}
+
int DataArray::GetSize() const {
return data.size();
}
@@ -118,10 +127,6 @@ int DataArray::GetSentenceId(int position) const {
return sentence_id[position];
}
-bool DataArray::HasWord(const string& word) const {
- return word2id.count(word);
-}
-
int DataArray::GetWordId(const string& word) const {
auto result = word2id.find(word);
return result == word2id.end() ? -1 : result->second;
diff --git a/extractor/data_array.h b/extractor/data_array.h
index 2be6a09c..b96901d1 100644
--- a/extractor/data_array.h
+++ b/extractor/data_array.h
@@ -5,13 +5,11 @@
#include <unordered_map>
#include <vector>
-#include <boost/filesystem.hpp>
#include <boost/serialization/serialization.hpp>
#include <boost/serialization/split_member.hpp>
#include <boost/serialization/string.hpp>
#include <boost/serialization/vector.hpp>
-namespace fs = boost::filesystem;
using namespace std;
namespace extractor {
@@ -53,7 +51,7 @@ class DataArray {
virtual ~DataArray();
// Returns a vector containing the word ids.
- virtual const vector<int>& GetData() const;
+ virtual vector<int> GetData() const;
// Returns the word id at the specified position.
virtual int AtIndex(int index) const;
@@ -61,15 +59,20 @@ class DataArray {
// Returns the original word at the specified position.
virtual string GetWordAtIndex(int index) const;
+ // Returns the substring of word ids starting at the specified position and
+ // having the specified length.
+ virtual vector<int> GetWordIds(int start_index, int size) const;
+
+ // Returns the substring of words starting at the specified position and
+ // having the specified length.
+ virtual vector<string> GetWords(int start_index, int size) const;
+
// Returns the size of the data array.
virtual int GetSize() const;
// Returns the number of distinct words in the data array.
virtual int GetVocabularySize() const;
- // Returns whether a word has ever been observed in the data array.
- virtual bool HasWord(const string& word) const;
-
// Returns the word id for a given word or -1 if it the word has never been
// observed.
virtual int GetWordId(const string& word) const;
diff --git a/extractor/data_array_test.cc b/extractor/data_array_test.cc
index 6c329e34..99f79d91 100644
--- a/extractor/data_array_test.cc
+++ b/extractor/data_array_test.cc
@@ -56,18 +56,26 @@ TEST_F(DataArrayTest, TestGetData) {
}
}
+TEST_F(DataArrayTest, TestSubstrings) {
+ vector<int> expected_word_ids = {3, 4, 5};
+ vector<string> expected_words = {"are", "mere", "."};
+ EXPECT_EQ(expected_word_ids, source_data.GetWordIds(1, 3));
+ EXPECT_EQ(expected_words, source_data.GetWords(1, 3));
+
+ expected_word_ids = {7, 8};
+ expected_words = {"a", "lot"};
+ EXPECT_EQ(expected_word_ids, target_data.GetWordIds(7, 2));
+ EXPECT_EQ(expected_words, target_data.GetWords(7, 2));
+}
+
TEST_F(DataArrayTest, TestVocabulary) {
EXPECT_EQ(9, source_data.GetVocabularySize());
- EXPECT_TRUE(source_data.HasWord("mere"));
EXPECT_EQ(4, source_data.GetWordId("mere"));
EXPECT_EQ("mere", source_data.GetWord(4));
- EXPECT_FALSE(source_data.HasWord("banane"));
EXPECT_EQ(11, target_data.GetVocabularySize());
- EXPECT_TRUE(target_data.HasWord("apples"));
EXPECT_EQ(4, target_data.GetWordId("apples"));
EXPECT_EQ("apples", target_data.GetWord(4));
- EXPECT_FALSE(target_data.HasWord("bananas"));
}
TEST_F(DataArrayTest, TestSentenceData) {
diff --git a/extractor/extract.cc b/extractor/extract.cc
new file mode 100644
index 00000000..387cbe9b
--- /dev/null
+++ b/extractor/extract.cc
@@ -0,0 +1,253 @@
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <boost/archive/binary_iarchive.hpp>
+#include <boost/filesystem.hpp>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+#include <omp.h>
+
+#include "alignment.h"
+#include "data_array.h"
+#include "features/count_source_target.h"
+#include "features/feature.h"
+#include "features/is_source_singleton.h"
+#include "features/is_source_target_singleton.h"
+#include "features/max_lex_source_given_target.h"
+#include "features/max_lex_target_given_source.h"
+#include "features/sample_source_count.h"
+#include "features/target_given_source_coherent.h"
+#include "grammar.h"
+#include "grammar_extractor.h"
+#include "precomputation.h"
+#include "rule.h"
+#include "scorer.h"
+#include "suffix_array.h"
+#include "time_util.h"
+#include "translation_table.h"
+#include "vocabulary.h"
+
+namespace ar = boost::archive;
+namespace fs = boost::filesystem;
+namespace po = boost::program_options;
+using namespace extractor;
+using namespace features;
+using namespace std;
+
+// Returns the file path in which a given grammar should be written.
+fs::path GetGrammarFilePath(const fs::path& grammar_path, int file_number) {
+ string file_name = "grammar." + to_string(file_number);
+ return grammar_path / file_name;
+}
+
+int main(int argc, char** argv) {
+ po::options_description general_options("General options");
+ int max_threads = 1;
+ #pragma omp parallel
+ max_threads = omp_get_num_threads();
+ string threads_option = "Number of threads used for grammar extraction "
+ "max(" + to_string(max_threads) + ")";
+ general_options.add_options()
+ ("threads,t", po::value<int>()->required()->default_value(1),
+ threads_option.c_str())
+ ("grammars,g", po::value<string>()->required(), "Grammars output path")
+ ("max_rule_span", po::value<int>()->default_value(15),
+ "Maximum rule span")
+ ("max_rule_symbols", po::value<int>()->default_value(5),
+ "Maximum number of symbols (terminals + nontermals) in a rule")
+ ("min_gap_size", po::value<int>()->default_value(1), "Minimum gap size")
+ ("max_nonterminals", po::value<int>()->default_value(2),
+ "Maximum number of nonterminals in a rule")
+ ("max_samples", po::value<int>()->default_value(300),
+ "Maximum number of samples")
+ ("tight_phrases", po::value<bool>()->default_value(true),
+ "False if phrases may be loose (better, but slower)")
+ ("leave_one_out", po::value<bool>()->zero_tokens(),
+ "do leave-one-out estimation of grammars "
+ "(e.g. for extracting grammars for the training set");
+
+ po::options_description cmdline_options("Command line options");
+ cmdline_options.add_options()
+ ("help", "Show available options")
+ ("config,c", po::value<string>()->required(), "Path to config file");
+ cmdline_options.add(general_options);
+
+ po::options_description config_options("Config file options");
+ config_options.add_options()
+ ("target", po::value<string>()->required(),
+ "Path to target data file in binary format")
+ ("source", po::value<string>()->required(),
+ "Path to source suffix array file in binary format")
+ ("alignment", po::value<string>()->required(),
+ "Path to alignment file in binary format")
+ ("precomputation", po::value<string>()->required(),
+ "Path to precomputation file in binary format")
+ ("vocabulary", po::value<string>()->required(),
+ "Path to vocabulary file in binary format")
+ ("ttable", po::value<string>()->required(),
+ "Path to translation table in binary format");
+ config_options.add(general_options);
+
+ po::variables_map vm;
+ po::store(po::parse_command_line(argc, argv, cmdline_options), vm);
+ if (vm.count("help")) {
+ po::options_description all_options;
+ all_options.add(cmdline_options).add(config_options);
+ cout << all_options << endl;
+ return 0;
+ }
+
+ po::notify(vm);
+
+ ifstream config_stream(vm["config"].as<string>());
+ po::store(po::parse_config_file(config_stream, config_options), vm);
+ po::notify(vm);
+
+ int num_threads = vm["threads"].as<int>();
+ cerr << "Grammar extraction will use " << num_threads << " threads." << endl;
+
+ Clock::time_point read_start_time = Clock::now();
+
+ Clock::time_point start_time = Clock::now();
+ cerr << "Reading target data in binary format..." << endl;
+ shared_ptr<DataArray> target_data_array = make_shared<DataArray>();
+ ifstream target_fstream(vm["target"].as<string>());
+ ar::binary_iarchive target_stream(target_fstream);
+ target_stream >> *target_data_array;
+ Clock::time_point end_time = Clock::now();
+ cerr << "Reading target data took " << GetDuration(start_time, end_time)
+ << " seconds" << endl;
+
+ start_time = Clock::now();
+ cerr << "Reading source suffix array in binary format..." << endl;
+ shared_ptr<SuffixArray> source_suffix_array = make_shared<SuffixArray>();
+ ifstream source_fstream(vm["source"].as<string>());
+ ar::binary_iarchive source_stream(source_fstream);
+ source_stream >> *source_suffix_array;
+ end_time = Clock::now();
+ cerr << "Reading source suffix array took "
+ << GetDuration(start_time, end_time) << " seconds" << endl;
+
+ start_time = Clock::now();
+ cerr << "Reading alignment in binary format..." << endl;
+ shared_ptr<Alignment> alignment = make_shared<Alignment>();
+ ifstream alignment_fstream(vm["alignment"].as<string>());
+ ar::binary_iarchive alignment_stream(alignment_fstream);
+ alignment_stream >> *alignment;
+ end_time = Clock::now();
+ cerr << "Reading alignment took " << GetDuration(start_time, end_time)
+ << " seconds" << endl;
+
+ start_time = Clock::now();
+ cerr << "Reading precomputation in binary format..." << endl;
+ shared_ptr<Precomputation> precomputation = make_shared<Precomputation>();
+ ifstream precomputation_fstream(vm["precomputation"].as<string>());
+ ar::binary_iarchive precomputation_stream(precomputation_fstream);
+ precomputation_stream >> *precomputation;
+ end_time = Clock::now();
+ cerr << "Reading precomputation took " << GetDuration(start_time, end_time)
+ << " seconds" << endl;
+
+ start_time = Clock::now();
+ cerr << "Reading vocabulary in binary format..." << endl;
+ shared_ptr<Vocabulary> vocabulary = make_shared<Vocabulary>();
+ ifstream vocabulary_fstream(vm["vocabulary"].as<string>());
+ ar::binary_iarchive vocabulary_stream(vocabulary_fstream);
+ vocabulary_stream >> *vocabulary;
+ end_time = Clock::now();
+ cerr << "Reading vocabulary took " << GetDuration(start_time, end_time)
+ << " seconds" << endl;
+
+ start_time = Clock::now();
+ cerr << "Reading translation table in binary format..." << endl;
+ shared_ptr<TranslationTable> table = make_shared<TranslationTable>();
+ ifstream ttable_fstream(vm["ttable"].as<string>());
+ ar::binary_iarchive ttable_stream(ttable_fstream);
+ ttable_stream >> *table;
+ end_time = Clock::now();
+ cerr << "Reading translation table took " << GetDuration(start_time, end_time)
+ << " seconds" << endl;
+
+ Clock::time_point read_end_time = Clock::now();
+ cerr << "Total time spent loading data structures into memory: "
+ << GetDuration(read_start_time, read_end_time) << " seconds" << endl;
+
+ Clock::time_point extraction_start_time = Clock::now();
+ // Features used to score each grammar rule.
+ vector<shared_ptr<Feature>> features = {
+ make_shared<TargetGivenSourceCoherent>(),
+ make_shared<SampleSourceCount>(),
+ make_shared<CountSourceTarget>(),
+ make_shared<MaxLexSourceGivenTarget>(table),
+ make_shared<MaxLexTargetGivenSource>(table),
+ make_shared<IsSourceSingleton>(),
+ make_shared<IsSourceTargetSingleton>()
+ };
+ shared_ptr<Scorer> scorer = make_shared<Scorer>(features);
+
+ GrammarExtractor extractor(
+ source_suffix_array,
+ target_data_array,
+ alignment,
+ precomputation,
+ scorer,
+ vocabulary,
+ vm["min_gap_size"].as<int>(),
+ vm["max_rule_span"].as<int>(),
+ vm["max_nonterminals"].as<int>(),
+ vm["max_rule_symbols"].as<int>(),
+ vm["max_samples"].as<int>(),
+ vm["tight_phrases"].as<bool>());
+
+ // Creates the grammars directory if it doesn't exist.
+ fs::path grammar_path = vm["grammars"].as<string>();
+ if (!fs::is_directory(grammar_path)) {
+ fs::create_directory(grammar_path);
+ }
+
+ // Reads all sentences for which we extract grammar rules (the paralellization
+ // is simplified if we read all sentences upfront).
+ string sentence;
+ vector<string> sentences;
+ while (getline(cin, sentence)) {
+ sentences.push_back(sentence);
+ }
+
+ // Extracts the grammar for each sentence and saves it to a file.
+ vector<string> suffixes(sentences.size());
+ bool leave_one_out = vm.count("leave_one_out");
+ #pragma omp parallel for schedule(dynamic) num_threads(num_threads)
+ for (size_t i = 0; i < sentences.size(); ++i) {
+ string suffix;
+ int position = sentences[i].find("|||");
+ if (position != sentences[i].npos) {
+ suffix = sentences[i].substr(position);
+ sentences[i] = sentences[i].substr(0, position);
+ }
+ suffixes[i] = suffix;
+
+ unordered_set<int> blacklisted_sentence_ids;
+ if (leave_one_out) {
+ blacklisted_sentence_ids.insert(i);
+ }
+ Grammar grammar = extractor.GetGrammar(
+ sentences[i], blacklisted_sentence_ids);
+ ofstream output(GetGrammarFilePath(grammar_path, i).c_str());
+ output << grammar;
+ }
+
+ for (size_t i = 0; i < sentences.size(); ++i) {
+ cout << "<seg grammar=" << GetGrammarFilePath(grammar_path, i) << " id=\""
+ << i << "\"> " << sentences[i] << " </seg> " << suffixes[i] << endl;
+ }
+
+ Clock::time_point extraction_stop_time = Clock::now();
+ cerr << "Overall extraction step took "
+ << GetDuration(extraction_start_time, extraction_stop_time)
+ << " seconds" << endl;
+
+ return 0;
+}
diff --git a/extractor/fast_intersector.cc b/extractor/fast_intersector.cc
index a8591a72..0d1fa6d8 100644
--- a/extractor/fast_intersector.cc
+++ b/extractor/fast_intersector.cc
@@ -11,41 +11,22 @@
namespace extractor {
-FastIntersector::FastIntersector(shared_ptr<SuffixArray> suffix_array,
- shared_ptr<Precomputation> precomputation,
- shared_ptr<Vocabulary> vocabulary,
- int max_rule_span,
- int min_gap_size) :
+FastIntersector::FastIntersector(
+ shared_ptr<SuffixArray> suffix_array,
+ shared_ptr<Precomputation> precomputation,
+ shared_ptr<Vocabulary> vocabulary,
+ int max_rule_span,
+ int min_gap_size) :
suffix_array(suffix_array),
+ precomputation(precomputation),
vocabulary(vocabulary),
max_rule_span(max_rule_span),
- min_gap_size(min_gap_size) {
- Index precomputed_collocations = precomputation->GetCollocations();
- for (pair<vector<int>, vector<int>> entry: precomputed_collocations) {
- vector<int> phrase = ConvertPhrase(entry.first);
- collocations[phrase] = entry.second;
- }
-}
+ min_gap_size(min_gap_size) {}
FastIntersector::FastIntersector() {}
FastIntersector::~FastIntersector() {}
-vector<int> FastIntersector::ConvertPhrase(const vector<int>& old_phrase) {
- vector<int> new_phrase;
- new_phrase.reserve(old_phrase.size());
- shared_ptr<DataArray> data_array = suffix_array->GetData();
- for (int word_id: old_phrase) {
- if (word_id < 0) {
- new_phrase.push_back(word_id);
- } else {
- new_phrase.push_back(
- vocabulary->GetTerminalIndex(data_array->GetWord(word_id)));
- }
- }
- return new_phrase;
-}
-
PhraseLocation FastIntersector::Intersect(
PhraseLocation& prefix_location,
PhraseLocation& suffix_location,
@@ -59,8 +40,9 @@ PhraseLocation FastIntersector::Intersect(
assert(vocabulary->IsTerminal(symbols.front())
&& vocabulary->IsTerminal(symbols.back()));
- if (collocations.count(symbols)) {
- return PhraseLocation(collocations[symbols], phrase.Arity() + 1);
+ if (precomputation->Contains(symbols)) {
+ return PhraseLocation(precomputation->GetCollocations(symbols),
+ phrase.Arity() + 1);
}
bool prefix_ends_with_x =
diff --git a/extractor/fast_intersector.h b/extractor/fast_intersector.h
index 2819d239..305373dc 100644
--- a/extractor/fast_intersector.h
+++ b/extractor/fast_intersector.h
@@ -12,7 +12,6 @@ using namespace std;
namespace extractor {
typedef boost::hash<vector<int>> VectorHash;
-typedef unordered_map<vector<int>, vector<int>, VectorHash> Index;
class Phrase;
class PhraseLocation;
@@ -52,11 +51,6 @@ class FastIntersector {
FastIntersector();
private:
- // Uses the vocabulary to convert the phrase from the numberized format
- // specified by the source data array to the numberized format given by the
- // vocabulary.
- vector<int> ConvertPhrase(const vector<int>& old_phrase);
-
// Estimates the number of computations needed if the prefix/suffix is
// extended. If the last/first symbol is separated from the rest of the phrase
// by a nonterminal, then for each occurrence of the prefix/suffix we need to
@@ -85,10 +79,10 @@ class FastIntersector {
pair<int, int> GetSearchRange(bool has_marginal_x) const;
shared_ptr<SuffixArray> suffix_array;
+ shared_ptr<Precomputation> precomputation;
shared_ptr<Vocabulary> vocabulary;
int max_rule_span;
int min_gap_size;
- Index collocations;
};
} // namespace extractor
diff --git a/extractor/fast_intersector_test.cc b/extractor/fast_intersector_test.cc
index 76c3aaea..f2a26ba1 100644
--- a/extractor/fast_intersector_test.cc
+++ b/extractor/fast_intersector_test.cc
@@ -59,15 +59,13 @@ class FastIntersectorTest : public Test {
}
precomputation = make_shared<MockPrecomputation>();
- EXPECT_CALL(*precomputation, GetCollocations())
- .WillRepeatedly(ReturnRef(collocations));
+ EXPECT_CALL(*precomputation, Contains(_)).WillRepeatedly(Return(false));
phrase_builder = make_shared<PhraseBuilder>(vocabulary);
intersector = make_shared<FastIntersector>(suffix_array, precomputation,
vocabulary, 15, 1);
}
- Index collocations;
shared_ptr<MockDataArray> data_array;
shared_ptr<MockSuffixArray> suffix_array;
shared_ptr<MockPrecomputation> precomputation;
@@ -82,9 +80,9 @@ TEST_F(FastIntersectorTest, TestCachedCollocation) {
Phrase phrase = phrase_builder->Build(symbols);
PhraseLocation prefix_location(15, 16), suffix_location(16, 17);
- collocations[symbols] = expected_location;
- EXPECT_CALL(*precomputation, GetCollocations())
- .WillRepeatedly(ReturnRef(collocations));
+ EXPECT_CALL(*precomputation, Contains(symbols)).WillRepeatedly(Return(true));
+ EXPECT_CALL(*precomputation, GetCollocations(symbols)).
+ WillRepeatedly(Return(expected_location));
intersector = make_shared<FastIntersector>(suffix_array, precomputation,
vocabulary, 15, 1);
diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc
index 487abcaf..1dc94c25 100644
--- a/extractor/grammar_extractor.cc
+++ b/extractor/grammar_extractor.cc
@@ -19,10 +19,11 @@ GrammarExtractor::GrammarExtractor(
shared_ptr<SuffixArray> source_suffix_array,
shared_ptr<DataArray> target_data_array,
shared_ptr<Alignment> alignment, shared_ptr<Precomputation> precomputation,
- shared_ptr<Scorer> scorer, int min_gap_size, int max_rule_span,
+ shared_ptr<Scorer> scorer, shared_ptr<Vocabulary> vocabulary,
+ int min_gap_size, int max_rule_span,
int max_nonterminals, int max_rule_symbols, int max_samples,
bool require_tight_phrases) :
- vocabulary(make_shared<Vocabulary>()),
+ vocabulary(vocabulary),
rule_factory(make_shared<HieroCachingRuleFactory>(
source_suffix_array, target_data_array, alignment, vocabulary,
precomputation, scorer, min_gap_size, max_rule_span, max_nonterminals,
@@ -34,10 +35,12 @@ GrammarExtractor::GrammarExtractor(
vocabulary(vocabulary),
rule_factory(rule_factory) {}
-Grammar GrammarExtractor::GetGrammar(const string& sentence, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) {
+Grammar GrammarExtractor::GetGrammar(
+ const string& sentence,
+ const unordered_set<int>& blacklisted_sentence_ids) {
vector<string> words = TokenizeSentence(sentence);
vector<int> word_ids = AnnotateWords(words);
- return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
+ return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids);
}
vector<string> GrammarExtractor::TokenizeSentence(const string& sentence) {
diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h
index ae407b47..0f3069b0 100644
--- a/extractor/grammar_extractor.h
+++ b/extractor/grammar_extractor.h
@@ -15,7 +15,6 @@ class DataArray;
class Grammar;
class HieroCachingRuleFactory;
class Precomputation;
-class Rule;
class Scorer;
class SuffixArray;
class Vocabulary;
@@ -32,6 +31,7 @@ class GrammarExtractor {
shared_ptr<Alignment> alignment,
shared_ptr<Precomputation> precomputation,
shared_ptr<Scorer> scorer,
+ shared_ptr<Vocabulary> vocabulary,
int min_gap_size,
int max_rule_span,
int max_nonterminals,
@@ -45,7 +45,9 @@ class GrammarExtractor {
// Converts the sentence to a vector of word ids and uses the RuleFactory to
// extract the SCFG rules which may be used to decode the sentence.
- Grammar GetGrammar(const string& sentence, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array);
+ Grammar GetGrammar(
+ const string& sentence,
+ const unordered_set<int>& blacklisted_sentence_ids);
private:
// Splits the sentence in a vector of words.
diff --git a/extractor/grammar_extractor_test.cc b/extractor/grammar_extractor_test.cc
index f32a9599..719e90ff 100644
--- a/extractor/grammar_extractor_test.cc
+++ b/extractor/grammar_extractor_test.cc
@@ -41,13 +41,13 @@ TEST(GrammarExtractorTest, TestAnnotatingWords) {
Grammar grammar(rules, feature_names);
unordered_set<int> blacklisted_sentence_ids;
shared_ptr<DataArray> source_data_array;
- EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array))
+ EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids))
.WillOnce(Return(grammar));
GrammarExtractor extractor(vocabulary, factory);
string sentence = "Anna has many many apples .";
- extractor.GetGrammar(sentence, blacklisted_sentence_ids, source_data_array);
+ extractor.GetGrammar(sentence, blacklisted_sentence_ids);
}
} // namespace
diff --git a/extractor/matchings_sampler.cc b/extractor/matchings_sampler.cc
new file mode 100644
index 00000000..75a62366
--- /dev/null
+++ b/extractor/matchings_sampler.cc
@@ -0,0 +1,39 @@
+#include "matchings_sampler.h"
+
+#include "data_array.h"
+#include "phrase_location.h"
+
+namespace extractor {
+
+MatchingsSampler::MatchingsSampler(
+ shared_ptr<DataArray> data_array, int max_samples) :
+ BackoffSampler(data_array, max_samples) {}
+
+MatchingsSampler::MatchingsSampler() {}
+
+int MatchingsSampler::GetNumSubpatterns(const PhraseLocation& location) const {
+ return location.num_subpatterns;
+}
+
+int MatchingsSampler::GetRangeLow(const PhraseLocation&) const {
+ return 0;
+}
+
+int MatchingsSampler::GetRangeHigh(const PhraseLocation& location) const {
+ return location.matchings->size() / location.num_subpatterns;
+}
+
+int MatchingsSampler::GetPosition(const PhraseLocation& location,
+ int index) const {
+ return (*location.matchings)[index * location.num_subpatterns];
+}
+
+void MatchingsSampler::AppendMatching(vector<int>& samples, int index,
+ const PhraseLocation& location) const {
+ int start = index * location.num_subpatterns;
+ copy(location.matchings->begin() + start,
+ location.matchings->begin() + start + location.num_subpatterns,
+ back_inserter(samples));
+}
+
+} // namespace extractor
diff --git a/extractor/matchings_sampler.h b/extractor/matchings_sampler.h
new file mode 100644
index 00000000..ca4fce93
--- /dev/null
+++ b/extractor/matchings_sampler.h
@@ -0,0 +1,31 @@
+#ifndef _MATCHINGS_SAMPLER_H_
+#define _MATCHINGS_SAMPLER_H_
+
+#include "backoff_sampler.h"
+
+namespace extractor {
+
+class DataArray;
+
+class MatchingsSampler : public BackoffSampler {
+ public:
+ MatchingsSampler(shared_ptr<DataArray> data_array, int max_samples);
+
+ MatchingsSampler();
+
+ private:
+ int GetNumSubpatterns(const PhraseLocation& location) const;
+
+ int GetRangeLow(const PhraseLocation& location) const;
+
+ int GetRangeHigh(const PhraseLocation& location) const;
+
+ int GetPosition(const PhraseLocation& location, int index) const;
+
+ void AppendMatching(vector<int>& samples, int index,
+ const PhraseLocation& location) const;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/matchings_sampler_test.cc b/extractor/matchings_sampler_test.cc
new file mode 100644
index 00000000..bc927152
--- /dev/null
+++ b/extractor/matchings_sampler_test.cc
@@ -0,0 +1,118 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "mocks/mock_data_array.h"
+#include "matchings_sampler.h"
+#include "phrase_location.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+class MatchingsSamplerTest : public Test {
+ protected:
+ virtual void SetUp() {
+ vector<int> locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ location = PhraseLocation(locations, 2);
+
+ data_array = make_shared<MockDataArray>();
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_CALL(*data_array, GetSentenceId(i)).WillRepeatedly(Return(i / 2));
+ }
+ }
+
+ unordered_set<int> blacklisted_sentence_ids;
+ PhraseLocation location;
+ shared_ptr<MockDataArray> data_array;
+ shared_ptr<MatchingsSampler> sampler;
+};
+
+TEST_F(MatchingsSamplerTest, TestSample) {
+ sampler = make_shared<MatchingsSampler>(data_array, 1);
+ vector<int> expected_locations = {0, 1};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ sampler = make_shared<MatchingsSampler>(data_array, 2);
+ expected_locations = {0, 1, 6, 7};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ sampler = make_shared<MatchingsSampler>(data_array, 3);
+ expected_locations = {0, 1, 4, 5, 6, 7};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ sampler = make_shared<MatchingsSampler>(data_array, 7);
+ expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+}
+
+TEST_F(MatchingsSamplerTest, TestBackoffSample) {
+ sampler = make_shared<MatchingsSampler>(data_array, 1);
+ blacklisted_sentence_ids = {0};
+ vector<int> expected_locations = {2, 3};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ blacklisted_sentence_ids = {0, 1, 2, 3};
+ expected_locations = {8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ blacklisted_sentence_ids = {0, 1, 2, 3, 4};
+ expected_locations = {};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ sampler = make_shared<MatchingsSampler>(data_array, 2);
+ blacklisted_sentence_ids = {0, 3};
+ expected_locations = {2, 3, 4, 5};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ sampler = make_shared<MatchingsSampler>(data_array, 3);
+ blacklisted_sentence_ids = {0, 3};
+ expected_locations = {2, 3, 4, 5, 8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ blacklisted_sentence_ids = {0, 2, 3};
+ expected_locations = {2, 3, 8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ sampler = make_shared<MatchingsSampler>(data_array, 4);
+ blacklisted_sentence_ids = {0, 1, 2, 3};
+ expected_locations = {8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ blacklisted_sentence_ids = {1, 3};
+ expected_locations = {0, 1, 4, 5, 8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ sampler = make_shared<MatchingsSampler>(data_array, 7);
+ blacklisted_sentence_ids = {0, 1, 2, 3, 4};
+ expected_locations = {};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ blacklisted_sentence_ids = {0, 2, 4};
+ expected_locations = {2, 3, 6, 7};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+
+ blacklisted_sentence_ids = {1, 3};
+ expected_locations = {0, 1, 4, 5, 8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2),
+ sampler->Sample(location, blacklisted_sentence_ids));
+}
+
+}
+} // namespace extractor
diff --git a/extractor/mocks/mock_data_array.h b/extractor/mocks/mock_data_array.h
index 6f85abb4..98e711d2 100644
--- a/extractor/mocks/mock_data_array.h
+++ b/extractor/mocks/mock_data_array.h
@@ -6,12 +6,13 @@ namespace extractor {
class MockDataArray : public DataArray {
public:
- MOCK_CONST_METHOD0(GetData, const vector<int>&());
+ MOCK_CONST_METHOD0(GetData, vector<int>());
MOCK_CONST_METHOD1(AtIndex, int(int index));
MOCK_CONST_METHOD1(GetWordAtIndex, string(int index));
+ MOCK_CONST_METHOD2(GetWordIds, vector<int>(int start_index, int size));
+ MOCK_CONST_METHOD2(GetWords, vector<string>(int start_index, int size));
MOCK_CONST_METHOD0(GetSize, int());
MOCK_CONST_METHOD0(GetVocabularySize, int());
- MOCK_CONST_METHOD1(HasWord, bool(const string& word));
MOCK_CONST_METHOD1(GetWordId, int(const string& word));
MOCK_CONST_METHOD1(GetWord, string(int word_id));
MOCK_CONST_METHOD1(GetSentenceLength, int(int sentence_id));
diff --git a/extractor/mocks/mock_matchings_sampler.h b/extractor/mocks/mock_matchings_sampler.h
new file mode 100644
index 00000000..de2009c3
--- /dev/null
+++ b/extractor/mocks/mock_matchings_sampler.h
@@ -0,0 +1,15 @@
+#include <gmock/gmock.h>
+
+#include "phrase_location.h"
+#include "matchings_sampler.h"
+
+namespace extractor {
+
+class MockMatchingsSampler : public MatchingsSampler {
+ public:
+ MOCK_CONST_METHOD2(Sample, PhraseLocation(
+ const PhraseLocation& location,
+ const unordered_set<int>& blacklisted_sentence_ids));
+};
+
+} // namespace extractor
diff --git a/extractor/mocks/mock_precomputation.h b/extractor/mocks/mock_precomputation.h
index 8753343e..5f7aa999 100644
--- a/extractor/mocks/mock_precomputation.h
+++ b/extractor/mocks/mock_precomputation.h
@@ -6,7 +6,8 @@ namespace extractor {
class MockPrecomputation : public Precomputation {
public:
- MOCK_CONST_METHOD0(GetCollocations, const Index&());
+ MOCK_CONST_METHOD1(Contains, bool(const vector<int>& pattern));
+ MOCK_CONST_METHOD1(GetCollocations, vector<int>(const vector<int>& pattern));
};
} // namespace extractor
diff --git a/extractor/mocks/mock_rule_factory.h b/extractor/mocks/mock_rule_factory.h
index 86a084b5..53eb5022 100644
--- a/extractor/mocks/mock_rule_factory.h
+++ b/extractor/mocks/mock_rule_factory.h
@@ -7,7 +7,9 @@ namespace extractor {
class MockHieroCachingRuleFactory : public HieroCachingRuleFactory {
public:
- MOCK_METHOD3(GetGrammar, Grammar(const vector<int>& word_ids, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array));
+ MOCK_METHOD2(GetGrammar, Grammar(
+ const vector<int>& word_ids,
+ const unordered_set<int>& blacklisted_sentence_ids));
};
} // namespace extractor
diff --git a/extractor/mocks/mock_sampler.h b/extractor/mocks/mock_sampler.h
index 75c43c27..b2742f62 100644
--- a/extractor/mocks/mock_sampler.h
+++ b/extractor/mocks/mock_sampler.h
@@ -7,7 +7,9 @@ namespace extractor {
class MockSampler : public Sampler {
public:
- MOCK_CONST_METHOD1(Sample, PhraseLocation(const PhraseLocation& location));
+ MOCK_CONST_METHOD2(Sample, PhraseLocation(
+ const PhraseLocation& location,
+ const unordered_set<int>& blacklisted_sentence_ids));
};
} // namespace extractor
diff --git a/extractor/mocks/mock_suffix_array_sampler.h b/extractor/mocks/mock_suffix_array_sampler.h
new file mode 100644
index 00000000..d799b969
--- /dev/null
+++ b/extractor/mocks/mock_suffix_array_sampler.h
@@ -0,0 +1,15 @@
+#include <gmock/gmock.h>
+
+#include "phrase_location.h"
+#include "suffix_array_sampler.h"
+
+namespace extractor {
+
+class MockSuffixArraySampler : public SuffixArrayRangeSampler {
+ public:
+ MOCK_CONST_METHOD2(Sample, PhraseLocation(
+ const PhraseLocation& location,
+ const unordered_set<int>& blacklisted_sentence_ids));
+};
+
+} // namespace extractor
diff --git a/extractor/phrase_location.cc b/extractor/phrase_location.cc
index 13140cac..2c367893 100644
--- a/extractor/phrase_location.cc
+++ b/extractor/phrase_location.cc
@@ -1,5 +1,7 @@
#include "phrase_location.h"
+#include <iostream>
+
namespace extractor {
PhraseLocation::PhraseLocation(int sa_low, int sa_high) :
diff --git a/extractor/phrase_location_sampler.cc b/extractor/phrase_location_sampler.cc
new file mode 100644
index 00000000..a2eec105
--- /dev/null
+++ b/extractor/phrase_location_sampler.cc
@@ -0,0 +1,34 @@
+#include "phrase_location_sampler.h"
+
+#include "matchings_sampler.h"
+#include "phrase_location.h"
+#include "suffix_array.h"
+#include "suffix_array_sampler.h"
+
+namespace extractor {
+
+PhraseLocationSampler::PhraseLocationSampler(
+ shared_ptr<SuffixArray> suffix_array, int max_samples) {
+ matchings_sampler = make_shared<MatchingsSampler>(
+ suffix_array->GetData(), max_samples);
+ suffix_array_sampler = make_shared<SuffixArrayRangeSampler>(
+ suffix_array, max_samples);
+}
+
+PhraseLocationSampler::PhraseLocationSampler(
+ shared_ptr<MatchingsSampler> matchings_sampler,
+ shared_ptr<SuffixArrayRangeSampler> suffix_array_sampler) :
+ matchings_sampler(matchings_sampler),
+ suffix_array_sampler(suffix_array_sampler) {}
+
+PhraseLocation PhraseLocationSampler::Sample(
+ const PhraseLocation& location,
+ const unordered_set<int>& blacklisted_sentence_ids) const {
+ if (location.matchings == NULL) {
+ return suffix_array_sampler->Sample(location, blacklisted_sentence_ids);
+ } else {
+ return matchings_sampler->Sample(location, blacklisted_sentence_ids);
+ }
+}
+
+} // namespace extractor
diff --git a/extractor/phrase_location_sampler.h b/extractor/phrase_location_sampler.h
new file mode 100644
index 00000000..0e88335e
--- /dev/null
+++ b/extractor/phrase_location_sampler.h
@@ -0,0 +1,35 @@
+#ifndef _PHRASE_LOCATION_SAMPLER_H_
+#define _PHRASE_LOCATION_SAMPLER_H_
+
+#include <memory>
+
+#include "sampler.h"
+
+namespace extractor {
+
+class MatchingsSampler;
+class PhraseLocation;
+class SuffixArray;
+class SuffixArrayRangeSampler;
+
+class PhraseLocationSampler : public Sampler {
+ public:
+ PhraseLocationSampler(shared_ptr<SuffixArray> suffix_array, int max_samples);
+
+ // For testing only.
+ PhraseLocationSampler(
+ shared_ptr<MatchingsSampler> matchings_sampler,
+ shared_ptr<SuffixArrayRangeSampler> suffix_array_sampler);
+
+ PhraseLocation Sample(
+ const PhraseLocation& location,
+ const unordered_set<int>& blacklisted_sentence_ids) const;
+
+ private:
+ shared_ptr<MatchingsSampler> matchings_sampler;
+ shared_ptr<SuffixArrayRangeSampler> suffix_array_sampler;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/phrase_location_sampler_test.cc b/extractor/phrase_location_sampler_test.cc
new file mode 100644
index 00000000..e7520ce7
--- /dev/null
+++ b/extractor/phrase_location_sampler_test.cc
@@ -0,0 +1,50 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "mocks/mock_matchings_sampler.h"
+#include "mocks/mock_suffix_array_sampler.h"
+#include "phrase_location.h"
+#include "phrase_location_sampler.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+class MatchingsSamplerTest : public Test {
+ protected:
+ virtual void SetUp() {
+ matchings_sampler = make_shared<MockMatchingsSampler>();
+ suffix_array_sampler = make_shared<MockSuffixArraySampler>();
+
+ sampler = make_shared<PhraseLocationSampler>(
+ matchings_sampler, suffix_array_sampler);
+ }
+
+ shared_ptr<MockMatchingsSampler> matchings_sampler;
+ shared_ptr<MockSuffixArraySampler> suffix_array_sampler;
+ shared_ptr<PhraseLocationSampler> sampler;
+};
+
+TEST_F(MatchingsSamplerTest, TestSuffixArrayRange) {
+ vector<int> locations = {0, 1, 2, 3};
+ PhraseLocation location(0, 3), result(locations, 2);
+ unordered_set<int> blacklisted_sentence_ids;
+ EXPECT_CALL(*suffix_array_sampler, Sample(location, blacklisted_sentence_ids))
+ .WillOnce(Return(result));
+ EXPECT_EQ(result, sampler->Sample(location, blacklisted_sentence_ids));
+}
+
+TEST_F(MatchingsSamplerTest, TestMatchings) {
+ vector<int> locations = {0, 1, 2, 3};
+ PhraseLocation location(locations, 2), result(locations, 2);
+ unordered_set<int> blacklisted_sentence_ids;
+ EXPECT_CALL(*matchings_sampler, Sample(location, blacklisted_sentence_ids))
+ .WillOnce(Return(result));
+ EXPECT_EQ(result, sampler->Sample(location, blacklisted_sentence_ids));
+}
+
+}
+} // namespace extractor
diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc
index 3b8aed69..3e58e2a9 100644
--- a/extractor/precomputation.cc
+++ b/extractor/precomputation.cc
@@ -5,59 +5,67 @@
#include "data_array.h"
#include "suffix_array.h"
+#include "time_util.h"
+#include "vocabulary.h"
using namespace std;
namespace extractor {
-int Precomputation::FIRST_NONTERMINAL = -1;
-int Precomputation::SECOND_NONTERMINAL = -2;
-
Precomputation::Precomputation(
- shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns,
- int num_super_frequent_patterns, int max_rule_span,
- int max_rule_symbols, int min_gap_size,
+ shared_ptr<Vocabulary> vocabulary, shared_ptr<SuffixArray> suffix_array,
+ int num_frequent_patterns, int num_super_frequent_patterns,
+ int max_rule_span, int max_rule_symbols, int min_gap_size,
int max_frequent_phrase_len, int min_frequency) {
- vector<int> data = suffix_array->GetData()->GetData();
+ Clock::time_point start_time = Clock::now();
+ shared_ptr<DataArray> data_array = suffix_array->GetData();
+ vector<int> data = data_array->GetData();
vector<vector<int>> frequent_patterns = FindMostFrequentPatterns(
suffix_array, data, num_frequent_patterns, max_frequent_phrase_len,
min_frequency);
+ Clock::time_point end_time = Clock::now();
+ cerr << "Finding most frequent patterns took "
+ << GetDuration(start_time, end_time) << " seconds..." << endl;
- // Construct sets containing the frequent and superfrequent contiguous
- // collocations.
- unordered_set<vector<int>, VectorHash> frequent_patterns_set;
- unordered_set<vector<int>, VectorHash> super_frequent_patterns_set;
+ vector<vector<int>> pattern_annotations(frequent_patterns.size());
+ unordered_map<vector<int>, int, VectorHash> frequent_patterns_index;
for (size_t i = 0; i < frequent_patterns.size(); ++i) {
- frequent_patterns_set.insert(frequent_patterns[i]);
- if (i < num_super_frequent_patterns) {
- super_frequent_patterns_set.insert(frequent_patterns[i]);
- }
+ frequent_patterns_index[frequent_patterns[i]] = i;
+ pattern_annotations[i] = AnnotatePattern(vocabulary, data_array,
+ frequent_patterns[i]);
}
+ start_time = Clock::now();
vector<tuple<int, int, int>> matchings;
+ vector<vector<int>> annotations;
for (size_t i = 0; i < data.size(); ++i) {
// If the sentence is over, add all the discontiguous frequent patterns to
// the index.
if (data[i] == DataArray::END_OF_LINE) {
- AddCollocations(matchings, data, max_rule_span, min_gap_size,
- max_rule_symbols);
+ UpdateIndex(matchings, annotations, max_rule_span, min_gap_size,
+ max_rule_symbols);
matchings.clear();
+ annotations.clear();
continue;
}
- vector<int> pattern;
// Find all the contiguous frequent patterns starting at position i.
+ vector<int> pattern;
for (int j = 1; j <= max_frequent_phrase_len && i + j <= data.size(); ++j) {
pattern.push_back(data[i + j - 1]);
- if (frequent_patterns_set.count(pattern)) {
- int is_super_frequent = super_frequent_patterns_set.count(pattern);
- matchings.push_back(make_tuple(i, j, is_super_frequent));
- } else {
+ auto it = frequent_patterns_index.find(pattern);
+ if (it == frequent_patterns_index.end()) {
// If the current pattern is not frequent, any longer pattern having the
// current pattern as prefix will not be frequent.
break;
}
+ int is_super_frequent = it->second < num_super_frequent_patterns;
+ matchings.push_back(make_tuple(i, j, is_super_frequent));
+ annotations.push_back(pattern_annotations[it->second]);
}
}
+ end_time = Clock::now();
+ cerr << "Constructing collocations index took "
+ << GetDuration(start_time, end_time) << " seconds..." << endl;
}
Precomputation::Precomputation() {}
@@ -75,9 +83,9 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns(
for (size_t i = 1; i < lcp.size(); ++i) {
for (int len = lcp[i]; len < max_frequent_phrase_len; ++len) {
int frequency = i - run_start[len];
- if (frequency >= min_frequency) {
- heap.push(make_pair(frequency,
- make_pair(suffix_array->GetSuffix(run_start[len]), len + 1)));
+ int start = suffix_array->GetSuffix(run_start[len]);
+ if (frequency >= min_frequency && start + len <= data.size()) {
+ heap.push(make_pair(frequency, make_pair(start, len + 1)));
}
run_start[len] = i;
}
@@ -99,8 +107,20 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns(
return frequent_patterns;
}
-void Precomputation::AddCollocations(
- const vector<tuple<int, int, int>>& matchings, const vector<int>& data,
+vector<int> Precomputation::AnnotatePattern(
+ shared_ptr<Vocabulary> vocabulary, shared_ptr<DataArray> data_array,
+ const vector<int>& pattern) const {
+ vector<int> annotation;
+ for (int word_id: pattern) {
+ annotation.push_back(vocabulary->GetTerminalIndex(
+ data_array->GetWord(word_id)));
+ }
+ return annotation;
+}
+
+void Precomputation::UpdateIndex(
+ const vector<tuple<int, int, int>>& matchings,
+ const vector<vector<int>>& annotations,
int max_rule_span, int min_gap_size, int max_rule_symbols) {
// Select the leftmost subpattern.
for (size_t i = 0; i < matchings.size(); ++i) {
@@ -118,16 +138,14 @@ void Precomputation::AddCollocations(
if (start2 - start1 - size1 >= min_gap_size
&& start2 + size2 - start1 <= max_rule_span
&& size1 + size2 + 1 <= max_rule_symbols) {
- vector<int> pattern(data.begin() + start1,
- data.begin() + start1 + size1);
- pattern.push_back(Precomputation::FIRST_NONTERMINAL);
- pattern.insert(pattern.end(), data.begin() + start2,
- data.begin() + start2 + size2);
- AddStartPositions(collocations[pattern], start1, start2);
+ vector<int> pattern = annotations[i];
+ pattern.push_back(-1);
+ AppendSubpattern(pattern, annotations[j]);
+ AppendCollocation(index[pattern], start1, start2);
// Try extending the binary collocation to a ternary collocation.
if (is_super2) {
- pattern.push_back(Precomputation::SECOND_NONTERMINAL);
+ pattern.push_back(-2);
// Select the rightmost subpattern.
for (size_t k = j + 1; k < matchings.size(); ++k) {
int start3, size3, is_super3;
@@ -140,9 +158,8 @@ void Precomputation::AddCollocations(
&& start3 + size3 - start1 <= max_rule_span
&& size1 + size2 + size3 + 2 <= max_rule_symbols
&& (is_super1 || is_super3)) {
- pattern.insert(pattern.end(), data.begin() + start3,
- data.begin() + start3 + size3);
- AddStartPositions(collocations[pattern], start1, start2, start3);
+ AppendSubpattern(pattern, annotations[k]);
+ AppendCollocation(index[pattern], start1, start2, start3);
pattern.erase(pattern.end() - size3);
}
}
@@ -152,25 +169,35 @@ void Precomputation::AddCollocations(
}
}
-void Precomputation::AddStartPositions(
- vector<int>& positions, int pos1, int pos2) {
- positions.push_back(pos1);
- positions.push_back(pos2);
+void Precomputation::AppendSubpattern(
+ vector<int>& pattern,
+ const vector<int>& subpattern) {
+ copy(subpattern.begin(), subpattern.end(), back_inserter(pattern));
+}
+
+void Precomputation::AppendCollocation(
+ vector<int>& collocations, int pos1, int pos2) {
+ collocations.push_back(pos1);
+ collocations.push_back(pos2);
+}
+
+void Precomputation::AppendCollocation(
+ vector<int>& collocations, int pos1, int pos2, int pos3) {
+ collocations.push_back(pos1);
+ collocations.push_back(pos2);
+ collocations.push_back(pos3);
}
-void Precomputation::AddStartPositions(
- vector<int>& positions, int pos1, int pos2, int pos3) {
- positions.push_back(pos1);
- positions.push_back(pos2);
- positions.push_back(pos3);
+bool Precomputation::Contains(const vector<int>& pattern) const {
+ return index.count(pattern);
}
-const Index& Precomputation::GetCollocations() const {
- return collocations;
+vector<int> Precomputation::GetCollocations(const vector<int>& pattern) const {
+ return index.at(pattern);
}
bool Precomputation::operator==(const Precomputation& other) const {
- return collocations == other.collocations;
+ return index == other.index;
}
} // namespace extractor
diff --git a/extractor/precomputation.h b/extractor/precomputation.h
index 9f0c9424..2b34fc29 100644
--- a/extractor/precomputation.h
+++ b/extractor/precomputation.h
@@ -7,13 +7,11 @@
#include <tuple>
#include <vector>
-#include <boost/filesystem.hpp>
#include <boost/functional/hash.hpp>
#include <boost/serialization/serialization.hpp>
#include <boost/serialization/utility.hpp>
#include <boost/serialization/vector.hpp>
-namespace fs = boost::filesystem;
using namespace std;
namespace extractor {
@@ -21,7 +19,9 @@ namespace extractor {
typedef boost::hash<vector<int>> VectorHash;
typedef unordered_map<vector<int>, vector<int>, VectorHash> Index;
+class DataArray;
class SuffixArray;
+class Vocabulary;
/**
* Data structure wrapping an index with all the occurrences of the most
@@ -37,9 +37,9 @@ class Precomputation {
public:
// Constructs the index using the suffix array.
Precomputation(
- shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns,
- int num_super_frequent_patterns, int max_rule_span,
- int max_rule_symbols, int min_gap_size,
+ shared_ptr<Vocabulary> vocabulary, shared_ptr<SuffixArray> suffix_array,
+ int num_frequent_patterns, int num_super_frequent_patterns,
+ int max_rule_span, int max_rule_symbols, int min_gap_size,
int max_frequent_phrase_len, int min_frequency);
// Creates empty precomputation data structure.
@@ -47,13 +47,13 @@ class Precomputation {
virtual ~Precomputation();
- // Returns a reference to the index.
- virtual const Index& GetCollocations() const;
+ // Returns whether a pattern is contained in the index of collocations.
+ virtual bool Contains(const vector<int>& pattern) const;
- bool operator==(const Precomputation& other) const;
+ // Returns the list of collocations for a given pattern.
+ virtual vector<int> GetCollocations(const vector<int>& pattern) const;
- static int FIRST_NONTERMINAL;
- static int SECOND_NONTERMINAL;
+ bool operator==(const Precomputation& other) const;
private:
// Finds the most frequent contiguous collocations.
@@ -62,25 +62,32 @@ class Precomputation {
int num_frequent_patterns, int max_frequent_phrase_len,
int min_frequency);
+ vector<int> AnnotatePattern(shared_ptr<Vocabulary> vocabulary,
+ shared_ptr<DataArray> data_array,
+ const vector<int>& pattern) const;
+
// Given the locations of the frequent contiguous collocations in a sentence,
// it adds new entries to the index for each discontiguous collocation
// matching the criteria specified in the class description.
- void AddCollocations(
- const vector<std::tuple<int, int, int>>& matchings, const vector<int>& data,
+ void UpdateIndex(
+ const vector<tuple<int, int, int>>& matchings,
+ const vector<vector<int>>& annotations,
int max_rule_span, int min_gap_size, int max_rule_symbols);
+ void AppendSubpattern(vector<int>& pattern, const vector<int>& subpattern);
+
// Adds an occurrence of a binary collocation.
- void AddStartPositions(vector<int>& positions, int pos1, int pos2);
+ void AppendCollocation(vector<int>& collocations, int pos1, int pos2);
// Adds an occurrence of a ternary collocation.
- void AddStartPositions(vector<int>& positions, int pos1, int pos2, int pos3);
+ void AppendCollocation(vector<int>& collocations, int pos1, int pos2, int pos3);
friend class boost::serialization::access;
template<class Archive> void save(Archive& ar, unsigned int) const {
- int num_entries = collocations.size();
+ int num_entries = index.size();
ar << num_entries;
- for (pair<vector<int>, vector<int>> entry: collocations) {
+ for (pair<vector<int>, vector<int>> entry: index) {
ar << entry;
}
}
@@ -91,13 +98,13 @@ class Precomputation {
for (size_t i = 0; i < num_entries; ++i) {
pair<vector<int>, vector<int>> entry;
ar >> entry;
- collocations.insert(entry);
+ index.insert(entry);
}
}
BOOST_SERIALIZATION_SPLIT_MEMBER();
- Index collocations;
+ Index index;
};
} // namespace extractor
diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc
index e81ece5d..3a98ce05 100644
--- a/extractor/precomputation_test.cc
+++ b/extractor/precomputation_test.cc
@@ -9,6 +9,7 @@
#include "mocks/mock_data_array.h"
#include "mocks/mock_suffix_array.h"
+#include "mocks/mock_vocabulary.h"
#include "precomputation.h"
using namespace std;
@@ -23,7 +24,12 @@ class PrecomputationTest : public Test {
virtual void SetUp() {
data = {4, 2, 3, 5, 7, 2, 3, 5, 2, 3, 4, 2, 1};
data_array = make_shared<MockDataArray>();
- EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data));
+ EXPECT_CALL(*data_array, GetData()).WillRepeatedly(Return(data));
+ for (size_t i = 0; i < data.size(); ++i) {
+ EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i]));
+ }
+ EXPECT_CALL(*data_array, GetWord(2)).WillRepeatedly(Return("2"));
+ EXPECT_CALL(*data_array, GetWord(3)).WillRepeatedly(Return("3"));
vector<int> suffixes{12, 8, 5, 1, 9, 6, 2, 0, 10, 7, 3, 4, 13};
vector<int> lcp{-1, 0, 2, 3, 1, 0, 1, 2, 0, 2, 0, 1, 0, 0};
@@ -35,77 +41,98 @@ class PrecomputationTest : public Test {
}
EXPECT_CALL(*suffix_array, BuildLCPArray()).WillRepeatedly(Return(lcp));
- precomputation = Precomputation(suffix_array, 3, 3, 10, 5, 1, 4, 2);
+ vocabulary = make_shared<MockVocabulary>();
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("2")).WillRepeatedly(Return(2));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("3")).WillRepeatedly(Return(3));
+
+ precomputation = Precomputation(vocabulary, suffix_array,
+ 3, 3, 10, 5, 1, 4, 2);
}
vector<int> data;
shared_ptr<MockDataArray> data_array;
shared_ptr<MockSuffixArray> suffix_array;
+ shared_ptr<MockVocabulary> vocabulary;
Precomputation precomputation;
};
TEST_F(PrecomputationTest, TestCollocations) {
- Index collocations = precomputation.GetCollocations();
-
vector<int> key = {2, 3, -1, 2};
vector<int> expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {2, 3, -1, 2, 3};
expected_value = {1, 5, 1, 8, 5, 8};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {2, 3, -1, 3};
expected_value = {1, 6, 1, 9, 5, 9};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {3, -1, 2};
expected_value = {2, 5, 2, 8, 2, 11, 6, 8, 6, 11, 9, 11};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {3, -1, 3};
expected_value = {2, 6, 2, 9, 6, 9};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {3, -1, 2, 3};
expected_value = {2, 5, 2, 8, 6, 8};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {2, -1, 2};
expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {2, -1, 2, 3};
expected_value = {1, 5, 1, 8, 5, 8};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {2, -1, 3};
expected_value = {1, 6, 1, 9, 5, 9};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {2, -1, 2, -2, 2};
expected_value = {1, 5, 8, 5, 8, 11};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {2, -1, 2, -2, 3};
expected_value = {1, 5, 9};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {2, -1, 3, -2, 2};
expected_value = {1, 6, 8, 5, 9, 11};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {2, -1, 3, -2, 3};
expected_value = {1, 6, 9};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {3, -1, 2, -2, 2};
expected_value = {2, 5, 8, 2, 5, 11, 2, 8, 11, 6, 8, 11};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {3, -1, 2, -2, 3};
expected_value = {2, 5, 9};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {3, -1, 3, -2, 2};
expected_value = {2, 6, 8, 2, 6, 11, 2, 9, 11, 6, 9, 11};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
key = {3, -1, 3, -2, 3};
expected_value = {2, 6, 9};
- EXPECT_EQ(expected_value, collocations[key]);
+ EXPECT_TRUE(precomputation.Contains(key));
+ EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
// Exceeds max_rule_symbols.
key = {2, -1, 2, -2, 2, 3};
- EXPECT_EQ(0, collocations.count(key));
+ EXPECT_FALSE(precomputation.Contains(key));
// Contains non frequent pattern.
key = {2, -1, 5};
- EXPECT_EQ(0, collocations.count(key));
+ EXPECT_FALSE(precomputation.Contains(key));
}
TEST_F(PrecomputationTest, TestSerialization) {
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc
index 6ae2d792..18a60695 100644
--- a/extractor/rule_factory.cc
+++ b/extractor/rule_factory.cc
@@ -12,6 +12,7 @@
#include "phrase_builder.h"
#include "rule.h"
#include "rule_extractor.h"
+#include "phrase_location_sampler.h"
#include "sampler.h"
#include "scorer.h"
#include "suffix_array.h"
@@ -68,7 +69,8 @@ HieroCachingRuleFactory::HieroCachingRuleFactory(
target_data_array, alignment, phrase_builder, scorer, vocabulary,
max_rule_span, min_gap_size, max_nonterminals, max_rule_symbols, true,
false, require_tight_phrases);
- sampler = make_shared<Sampler>(source_suffix_array, max_samples);
+ sampler = make_shared<PhraseLocationSampler>(
+ source_suffix_array, max_samples);
}
HieroCachingRuleFactory::HieroCachingRuleFactory(
@@ -101,7 +103,9 @@ HieroCachingRuleFactory::HieroCachingRuleFactory() {}
HieroCachingRuleFactory::~HieroCachingRuleFactory() {}
-Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) {
+Grammar HieroCachingRuleFactory::GetGrammar(
+ const vector<int>& word_ids,
+ const unordered_set<int>& blacklisted_sentence_ids) {
Clock::time_point start_time = Clock::now();
double total_extract_time = 0;
double total_intersect_time = 0;
@@ -193,7 +197,8 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids, const u
Clock::time_point extract_start = Clock::now();
if (!state.starts_with_x) {
// Extract rules for the sampled set of occurrences.
- PhraseLocation sample = sampler->Sample(next_node->matchings, blacklisted_sentence_ids, source_data_array);
+ PhraseLocation sample = sampler->Sample(
+ next_node->matchings, blacklisted_sentence_ids);
vector<Rule> new_rules =
rule_extractor->ExtractRules(next_phrase, sample);
rules.insert(rules.end(), new_rules.begin(), new_rules.end());
diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h
index df63a9d8..1a9fa2af 100644
--- a/extractor/rule_factory.h
+++ b/extractor/rule_factory.h
@@ -72,7 +72,9 @@ class HieroCachingRuleFactory {
// Constructs SCFG rules for a given sentence.
// (See class description for more details.)
- virtual Grammar GetGrammar(const vector<int>& word_ids, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array);
+ virtual Grammar GetGrammar(
+ const vector<int>& word_ids,
+ const unordered_set<int>& blacklisted_sentence_ids);
protected:
HieroCachingRuleFactory();
diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc
index f26cc567..332c5959 100644
--- a/extractor/rule_factory_test.cc
+++ b/extractor/rule_factory_test.cc
@@ -40,7 +40,7 @@ class RuleFactoryTest : public Test {
.WillRepeatedly(Return(feature_names));
sampler = make_shared<MockSampler>();
- EXPECT_CALL(*sampler, Sample(_))
+ EXPECT_CALL(*sampler, Sample(_, _))
.WillRepeatedly(Return(PhraseLocation(0, 1)));
Phrase phrase;
@@ -77,8 +77,7 @@ TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) {
vector<int> word_ids = {2, 3, 4};
unordered_set<int> blacklisted_sentence_ids;
- shared_ptr<DataArray> source_data_array;
- Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
+ Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids);
EXPECT_EQ(feature_names, grammar.GetFeatureNames());
EXPECT_EQ(7, grammar.GetRules().size());
}
@@ -97,8 +96,7 @@ TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) {
vector<int> word_ids = {2, 3, 4, 2, 3};
unordered_set<int> blacklisted_sentence_ids;
- shared_ptr<DataArray> source_data_array;
- Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
+ Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids);
EXPECT_EQ(feature_names, grammar.GetFeatureNames());
EXPECT_EQ(28, grammar.GetRules().size());
}
diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc
index 6f59f0b6..f1aa5e35 100644
--- a/extractor/run_extractor.cc
+++ b/extractor/run_extractor.cc
@@ -5,10 +5,10 @@
#include <string>
#include <vector>
-#include <omp.h>
#include <boost/filesystem.hpp>
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
+#include <omp.h>
#include "alignment.h"
#include "data_array.h"
@@ -28,6 +28,7 @@
#include "suffix_array.h"
#include "time_util.h"
#include "translation_table.h"
+#include "vocabulary.h"
namespace fs = boost::filesystem;
namespace po = boost::program_options;
@@ -77,7 +78,8 @@ int main(int argc, char** argv) {
("tight_phrases", po::value<bool>()->default_value(true),
"False if phrases may be loose (better, but slower)")
("leave_one_out", po::value<bool>()->zero_tokens(),
- "do leave-one-out estimation (e.g. for extracting grammars for the training set)");
+ "do leave-one-out estimation of grammars "
+ "(e.g. for extracting grammars for the training set");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
@@ -98,11 +100,6 @@ int main(int argc, char** argv) {
return 1;
}
- bool leave_one_out = false;
- if (vm.count("leave_one_out")) {
- leave_one_out = true;
- }
-
int num_threads = vm["threads"].as<int>();
cerr << "Grammar extraction will use " << num_threads << " threads." << endl;
@@ -142,11 +139,14 @@ int main(int argc, char** argv) {
cerr << "Reading alignment took "
<< GetDuration(start_time, stop_time) << " seconds" << endl;
+ shared_ptr<Vocabulary> vocabulary = make_shared<Vocabulary>();
+
// Constructs an index storing the occurrences in the source data for each
// frequent collocation.
start_time = Clock::now();
cerr << "Precomputing collocations..." << endl;
shared_ptr<Precomputation> precomputation = make_shared<Precomputation>(
+ vocabulary,
source_suffix_array,
vm["frequent"].as<int>(),
vm["super_frequent"].as<int>(),
@@ -174,8 +174,8 @@ int main(int argc, char** argv) {
<< GetDuration(preprocess_start_time, preprocess_stop_time)
<< " seconds" << endl;
- // Features used to score each grammar rule.
Clock::time_point extraction_start_time = Clock::now();
+ // Features used to score each grammar rule.
vector<shared_ptr<Feature>> features = {
make_shared<TargetGivenSourceCoherent>(),
make_shared<SampleSourceCount>(),
@@ -194,6 +194,7 @@ int main(int argc, char** argv) {
alignment,
precomputation,
scorer,
+ vocabulary,
vm["min_gap_size"].as<int>(),
vm["max_rule_span"].as<int>(),
vm["max_nonterminals"].as<int>(),
@@ -201,9 +202,6 @@ int main(int argc, char** argv) {
vm["max_samples"].as<int>(),
vm["tight_phrases"].as<bool>());
- // Releases extra memory used by the initial precomputation.
- precomputation.reset();
-
// Creates the grammars directory if it doesn't exist.
fs::path grammar_path = vm["grammars"].as<string>();
if (!fs::is_directory(grammar_path)) {
@@ -219,6 +217,7 @@ int main(int argc, char** argv) {
}
// Extracts the grammar for each sentence and saves it to a file.
+ bool leave_one_out = vm.count("leave_one_out");
vector<string> suffixes(sentences.size());
#pragma omp parallel for schedule(dynamic) num_threads(num_threads)
for (size_t i = 0; i < sentences.size(); ++i) {
@@ -231,8 +230,11 @@ int main(int argc, char** argv) {
suffixes[i] = suffix;
unordered_set<int> blacklisted_sentence_ids;
- if (leave_one_out) blacklisted_sentence_ids.insert(i);
- Grammar grammar = extractor.GetGrammar(sentences[i], blacklisted_sentence_ids, source_data_array);
+ if (leave_one_out) {
+ blacklisted_sentence_ids.insert(i);
+ }
+ Grammar grammar = extractor.GetGrammar(
+ sentences[i], blacklisted_sentence_ids);
ofstream output(GetGrammarFilePath(grammar_path, i).c_str());
output << grammar;
}
diff --git a/extractor/sampler.cc b/extractor/sampler.cc
deleted file mode 100644
index 963afa7a..00000000
--- a/extractor/sampler.cc
+++ /dev/null
@@ -1,75 +0,0 @@
-#include "sampler.h"
-
-#include "phrase_location.h"
-#include "suffix_array.h"
-
-namespace extractor {
-
-Sampler::Sampler(shared_ptr<SuffixArray> suffix_array, int max_samples) :
- suffix_array(suffix_array), max_samples(max_samples) {}
-
-Sampler::Sampler() {}
-
-Sampler::~Sampler() {}
-
-PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const {
- vector<int> sample;
- int num_subpatterns;
- if (location.matchings == NULL) {
- // Sample suffix array range.
- num_subpatterns = 1;
- int low = location.sa_low, high = location.sa_high;
- double step = max(1.0, (double) (high - low) / max_samples);
- double i = low, last = i;
- bool found;
- while (sample.size() < max_samples && i < high) {
- int x = suffix_array->GetSuffix(Round(i));
- int id = source_data_array->GetSentenceId(x);
- if (find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) != blacklisted_sentence_ids.end()) {
- found = false;
- double backoff_step = 1;
- while (true) {
- if ((double)backoff_step >= step) break;
- double j = i - backoff_step;
- x = suffix_array->GetSuffix(Round(j));
- id = source_data_array->GetSentenceId(x);
- if (x >= 0 && j > last && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) {
- found = true; last = i; break;
- }
- double k = i + backoff_step;
- x = suffix_array->GetSuffix(Round(k));
- id = source_data_array->GetSentenceId(x);
- if (k < min(i+step, (double)high) && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) {
- found = true; last = k; break;
- }
- if (j <= last && k >= high) break;
- backoff_step++;
- }
- } else {
- found = true;
- last = i;
- }
- if (found) sample.push_back(x);
- i += step;
- }
- } else {
- // Sample vector of occurrences.
- num_subpatterns = location.num_subpatterns;
- int num_matchings = location.matchings->size() / num_subpatterns;
- double step = max(1.0, (double) num_matchings / max_samples);
- for (double i = 0, num_samples = 0;
- i < num_matchings && num_samples < max_samples;
- i += step, ++num_samples) {
- int start = Round(i) * num_subpatterns;
- sample.insert(sample.end(), location.matchings->begin() + start,
- location.matchings->begin() + start + num_subpatterns);
- }
- }
- return PhraseLocation(sample, num_subpatterns);
-}
-
-int Sampler::Round(double x) const {
- return x + 0.5;
-}
-
-} // namespace extractor
diff --git a/extractor/sampler.h b/extractor/sampler.h
index de450c48..3c4e37f1 100644
--- a/extractor/sampler.h
+++ b/extractor/sampler.h
@@ -4,36 +4,20 @@
#include <memory>
#include <unordered_set>
-#include "data_array.h"
-
using namespace std;
namespace extractor {
class PhraseLocation;
-class SuffixArray;
/**
- * Provides uniform sampling for a PhraseLocation.
+ * Base sampler class.
*/
class Sampler {
public:
- Sampler(shared_ptr<SuffixArray> suffix_array, int max_samples);
-
- virtual ~Sampler();
-
- // Samples uniformly at most max_samples phrase occurrences.
- virtual PhraseLocation Sample(const PhraseLocation& location, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const;
-
- protected:
- Sampler();
-
- private:
- // Round floating point number to the nearest integer.
- int Round(double x) const;
-
- shared_ptr<SuffixArray> suffix_array;
- int max_samples;
+ virtual PhraseLocation Sample(
+ const PhraseLocation& location,
+ const unordered_set<int>& blacklisted_sentence_ids) const = 0;
};
} // namespace extractor
diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc
deleted file mode 100644
index 965567ba..00000000
--- a/extractor/sampler_test.cc
+++ /dev/null
@@ -1,80 +0,0 @@
-#include <gtest/gtest.h>
-
-#include <memory>
-
-#include "mocks/mock_suffix_array.h"
-#include "mocks/mock_data_array.h"
-#include "phrase_location.h"
-#include "sampler.h"
-
-using namespace std;
-using namespace ::testing;
-
-namespace extractor {
-namespace {
-
-class SamplerTest : public Test {
- protected:
- virtual void SetUp() {
- source_data_array = make_shared<MockDataArray>();
- EXPECT_CALL(*source_data_array, GetSentenceId(_)).WillRepeatedly(Return(9999));
- suffix_array = make_shared<MockSuffixArray>();
- for (int i = 0; i < 10; ++i) {
- EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i));
- }
- }
-
- shared_ptr<MockSuffixArray> suffix_array;
- shared_ptr<Sampler> sampler;
- shared_ptr<MockDataArray> source_data_array;
-};
-
-TEST_F(SamplerTest, TestSuffixArrayRange) {
- PhraseLocation location(0, 10);
- unordered_set<int> blacklist;
-
- sampler = make_shared<Sampler>(suffix_array, 1);
- vector<int> expected_locations = {0};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
-
- sampler = make_shared<Sampler>(suffix_array, 2);
- expected_locations = {0, 5};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
-
- sampler = make_shared<Sampler>(suffix_array, 3);
- expected_locations = {0, 3, 7};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
-
- sampler = make_shared<Sampler>(suffix_array, 4);
- expected_locations = {0, 3, 5, 8};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
-
- sampler = make_shared<Sampler>(suffix_array, 100);
- expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
-}
-
-TEST_F(SamplerTest, TestSubstringsSample) {
- vector<int> locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
- unordered_set<int> blacklist;
- PhraseLocation location(locations, 2);
-
- sampler = make_shared<Sampler>(suffix_array, 1);
- vector<int> expected_locations = {0, 1};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
-
- sampler = make_shared<Sampler>(suffix_array, 2);
- expected_locations = {0, 1, 6, 7};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
-
- sampler = make_shared<Sampler>(suffix_array, 3);
- expected_locations = {0, 1, 4, 5, 6, 7};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
-
- sampler = make_shared<Sampler>(suffix_array, 7);
- expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
-}
-
-} // namespace
-} // namespace extractor
diff --git a/extractor/sampler_test_blacklist.cc b/extractor/sampler_test_blacklist.cc
deleted file mode 100644
index 3305b990..00000000
--- a/extractor/sampler_test_blacklist.cc
+++ /dev/null
@@ -1,102 +0,0 @@
-#include <gtest/gtest.h>
-
-#include <memory>
-
-#include "mocks/mock_suffix_array.h"
-#include "mocks/mock_data_array.h"
-#include "phrase_location.h"
-#include "sampler.h"
-
-using namespace std;
-using namespace ::testing;
-
-namespace extractor {
-namespace {
-
-class SamplerTestBlacklist : public Test {
- protected:
- virtual void SetUp() {
- source_data_array = make_shared<MockDataArray>();
- for (int i = 0; i < 10; ++i) {
- EXPECT_CALL(*source_data_array, GetSentenceId(i)).WillRepeatedly(Return(i));
- }
- for (int i = -10; i < 0; ++i) {
- EXPECT_CALL(*source_data_array, GetSentenceId(i)).WillRepeatedly(Return(0));
- }
- suffix_array = make_shared<MockSuffixArray>();
- for (int i = -10; i < 10; ++i) {
- EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i));
- }
- }
-
- shared_ptr<MockSuffixArray> suffix_array;
- shared_ptr<Sampler> sampler;
- shared_ptr<MockDataArray> source_data_array;
-};
-
-TEST_F(SamplerTestBlacklist, TestSuffixArrayRange) {
- PhraseLocation location(0, 10);
- unordered_set<int> blacklist;
- vector<int> expected_locations;
-
- blacklist.insert(0);
- sampler = make_shared<Sampler>(suffix_array, 1);
- expected_locations = {1};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
- blacklist.clear();
-
- for (int i = 0; i < 9; i++) {
- blacklist.insert(i);
- }
- sampler = make_shared<Sampler>(suffix_array, 1);
- expected_locations = {9};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
- blacklist.clear();
-
- blacklist.insert(0);
- blacklist.insert(5);
- sampler = make_shared<Sampler>(suffix_array, 2);
- expected_locations = {1, 4};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
- blacklist.clear();
-
- blacklist.insert(0);
- blacklist.insert(1);
- blacklist.insert(2);
- blacklist.insert(3);
- sampler = make_shared<Sampler>(suffix_array, 2);
- expected_locations = {4, 5};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
- blacklist.clear();
-
- blacklist.insert(0);
- blacklist.insert(3);
- blacklist.insert(7);
- sampler = make_shared<Sampler>(suffix_array, 3);
- expected_locations = {1, 2, 6};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
- blacklist.clear();
-
- blacklist.insert(0);
- blacklist.insert(3);
- blacklist.insert(5);
- blacklist.insert(8);
- sampler = make_shared<Sampler>(suffix_array, 4);
- expected_locations = {1, 2, 4, 7};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
- blacklist.clear();
-
- blacklist.insert(0);
- sampler = make_shared<Sampler>(suffix_array, 100);
- expected_locations = {1, 2, 3, 4, 5, 6, 7, 8, 9};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
- blacklist.clear();
-
- blacklist.insert(9);
- sampler = make_shared<Sampler>(suffix_array, 100);
- expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
-}
-
-} // namespace
-} // namespace extractor
diff --git a/extractor/suffix_array.cc b/extractor/suffix_array.cc
index 0cf4d1f6..4a514b12 100644
--- a/extractor/suffix_array.cc
+++ b/extractor/suffix_array.cc
@@ -10,7 +10,6 @@
#include "phrase_location.h"
#include "time_util.h"
-namespace fs = boost::filesystem;
using namespace std;
using namespace chrono;
@@ -188,12 +187,12 @@ shared_ptr<DataArray> SuffixArray::GetData() const {
PhraseLocation SuffixArray::Lookup(int low, int high, const string& word,
int offset) const {
- if (!data_array->HasWord(word)) {
+ int word_id = data_array->GetWordId(word);
+ if (word_id == -1) {
// Return empty phrase location.
return PhraseLocation(0, 0);
}
- int word_id = data_array->GetWordId(word);
if (offset == 0) {
return PhraseLocation(word_start[word_id], word_start[word_id + 1]);
}
diff --git a/extractor/suffix_array.h b/extractor/suffix_array.h
index 8ee454ec..df80c152 100644
--- a/extractor/suffix_array.h
+++ b/extractor/suffix_array.h
@@ -5,12 +5,10 @@
#include <string>
#include <vector>
-#include <boost/filesystem.hpp>
#include <boost/serialization/serialization.hpp>
#include <boost/serialization/split_member.hpp>
#include <boost/serialization/vector.hpp>
-namespace fs = boost::filesystem;
using namespace std;
namespace extractor {
diff --git a/extractor/suffix_array_sampler.cc b/extractor/suffix_array_sampler.cc
new file mode 100644
index 00000000..4a4ced34
--- /dev/null
+++ b/extractor/suffix_array_sampler.cc
@@ -0,0 +1,40 @@
+#include "suffix_array_sampler.h"
+
+#include "data_array.h"
+#include "phrase_location.h"
+#include "suffix_array.h"
+
+namespace extractor {
+
+SuffixArrayRangeSampler::SuffixArrayRangeSampler(
+ shared_ptr<SuffixArray> source_suffix_array, int max_samples) :
+ BackoffSampler(source_suffix_array->GetData(), max_samples),
+ source_suffix_array(source_suffix_array) {}
+
+SuffixArrayRangeSampler::SuffixArrayRangeSampler() {}
+
+int SuffixArrayRangeSampler::GetNumSubpatterns(const PhraseLocation&) const {
+ return 1;
+}
+
+int SuffixArrayRangeSampler::GetRangeLow(
+ const PhraseLocation& location) const {
+ return location.sa_low;
+}
+
+int SuffixArrayRangeSampler::GetRangeHigh(
+ const PhraseLocation& location) const {
+ return location.sa_high;
+}
+
+int SuffixArrayRangeSampler::GetPosition(
+ const PhraseLocation&, int position) const {
+ return source_suffix_array->GetSuffix(position);
+}
+
+void SuffixArrayRangeSampler::AppendMatching(
+ vector<int>& samples, int index, const PhraseLocation&) const {
+ samples.push_back(source_suffix_array->GetSuffix(index));
+}
+
+} // namespace extractor
diff --git a/extractor/suffix_array_sampler.h b/extractor/suffix_array_sampler.h
new file mode 100644
index 00000000..bb3c2653
--- /dev/null
+++ b/extractor/suffix_array_sampler.h
@@ -0,0 +1,34 @@
+#ifndef _SUFFIX_ARRAY_SAMPLER_H_
+#define _SUFFIX_ARRAY_SAMPLER_H_
+
+#include "backoff_sampler.h"
+
+namespace extractor {
+
+class SuffixArray;
+
+class SuffixArrayRangeSampler : public BackoffSampler {
+ public:
+ SuffixArrayRangeSampler(shared_ptr<SuffixArray> suffix_array,
+ int max_samples);
+
+ SuffixArrayRangeSampler();
+
+ private:
+ int GetNumSubpatterns(const PhraseLocation& location) const;
+
+ int GetRangeLow(const PhraseLocation& location) const;
+
+ int GetRangeHigh(const PhraseLocation& location) const;
+
+ int GetPosition(const PhraseLocation& location, int index) const;
+
+ void AppendMatching(vector<int>& samples, int index,
+ const PhraseLocation& location) const;
+
+ shared_ptr<SuffixArray> source_suffix_array;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/suffix_array_sampler_test.cc b/extractor/suffix_array_sampler_test.cc
new file mode 100644
index 00000000..4b88c027
--- /dev/null
+++ b/extractor/suffix_array_sampler_test.cc
@@ -0,0 +1,114 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "mocks/mock_data_array.h"
+#include "mocks/mock_suffix_array.h"
+#include "suffix_array_sampler.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+class SuffixArraySamplerTest : public Test {
+ protected:
+ virtual void SetUp() {
+ data_array = make_shared<MockDataArray>();
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_CALL(*data_array, GetSentenceId(i)).WillRepeatedly(Return(i));
+ }
+
+ suffix_array = make_shared<MockSuffixArray>();
+ EXPECT_CALL(*suffix_array, GetData()).WillRepeatedly(Return(data_array));
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i));
+ }
+ }
+
+ shared_ptr<MockDataArray> data_array;
+ shared_ptr<MockSuffixArray> suffix_array;
+};
+
+TEST_F(SuffixArraySamplerTest, TestSample) {
+ PhraseLocation location(0, 10);
+ unordered_set<int> blacklisted_sentence_ids;
+
+ SuffixArrayRangeSampler sampler(suffix_array, 1);
+ vector<int> expected_locations = {0};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+
+ sampler = SuffixArrayRangeSampler(suffix_array, 2);
+ expected_locations = {0, 5};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+
+ sampler = SuffixArrayRangeSampler(suffix_array, 3);
+ expected_locations = {0, 3, 7};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+
+ sampler = SuffixArrayRangeSampler(suffix_array, 4);
+ expected_locations = {0, 3, 5, 8};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+
+ sampler = SuffixArrayRangeSampler(suffix_array, 100);
+ expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+}
+
+TEST_F(SuffixArraySamplerTest, TestBackoffSample) {
+ PhraseLocation location(0, 10);
+
+ SuffixArrayRangeSampler sampler(suffix_array, 1);
+ unordered_set<int> blacklisted_sentence_ids = {0};
+ vector<int> expected_locations = {1};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+
+ blacklisted_sentence_ids = {0, 1, 2, 3, 4, 5, 6, 7, 8};
+ expected_locations = {9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+
+ sampler = SuffixArrayRangeSampler(suffix_array, 2);
+ blacklisted_sentence_ids = {0, 5};
+ expected_locations = {1, 4};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+
+ blacklisted_sentence_ids = {0, 1, 2, 3};
+ expected_locations = {4, 5};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+
+ sampler = SuffixArrayRangeSampler(suffix_array, 3);
+ blacklisted_sentence_ids = {0, 3, 7};
+ expected_locations = {1, 2, 6};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+
+ sampler = SuffixArrayRangeSampler(suffix_array, 4);
+ blacklisted_sentence_ids = {0, 3, 5, 8};
+ expected_locations = {1, 2, 4, 7};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+
+ sampler = SuffixArrayRangeSampler(suffix_array, 100);
+ blacklisted_sentence_ids = {0};
+ expected_locations = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+
+ blacklisted_sentence_ids = {9};
+ expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1),
+ sampler.Sample(location, blacklisted_sentence_ids));
+}
+
+}
+} // namespace extractor
diff --git a/extractor/suffix_array_test.cc b/extractor/suffix_array_test.cc
index ba0dbcc3..161edbc0 100644
--- a/extractor/suffix_array_test.cc
+++ b/extractor/suffix_array_test.cc
@@ -21,7 +21,7 @@ class SuffixArrayTest : public Test {
virtual void SetUp() {
data = {6, 4, 1, 2, 4, 5, 3, 4, 6, 6, 4, 1, 2};
data_array = make_shared<MockDataArray>();
- EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data));
+ EXPECT_CALL(*data_array, GetData()).WillRepeatedly(Return(data));
EXPECT_CALL(*data_array, GetVocabularySize()).WillRepeatedly(Return(7));
EXPECT_CALL(*data_array, GetSize()).WillRepeatedly(Return(13));
suffix_array = SuffixArray(data_array);
@@ -55,22 +55,18 @@ TEST_F(SuffixArrayTest, TestLookup) {
EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i]));
}
- EXPECT_CALL(*data_array, HasWord("word1")).WillRepeatedly(Return(true));
EXPECT_CALL(*data_array, GetWordId("word1")).WillRepeatedly(Return(6));
EXPECT_EQ(PhraseLocation(11, 14), suffix_array.Lookup(0, 14, "word1", 0));
- EXPECT_CALL(*data_array, HasWord("word2")).WillRepeatedly(Return(false));
+ EXPECT_CALL(*data_array, GetWordId("word2")).WillRepeatedly(Return(-1));
EXPECT_EQ(PhraseLocation(0, 0), suffix_array.Lookup(0, 14, "word2", 0));
- EXPECT_CALL(*data_array, HasWord("word3")).WillRepeatedly(Return(true));
EXPECT_CALL(*data_array, GetWordId("word3")).WillRepeatedly(Return(4));
EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 14, "word3", 1));
- EXPECT_CALL(*data_array, HasWord("word4")).WillRepeatedly(Return(true));
EXPECT_CALL(*data_array, GetWordId("word4")).WillRepeatedly(Return(1));
EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 13, "word4", 2));
- EXPECT_CALL(*data_array, HasWord("word5")).WillRepeatedly(Return(true));
EXPECT_CALL(*data_array, GetWordId("word5")).WillRepeatedly(Return(2));
EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 13, "word5", 3));
diff --git a/extractor/translation_table.cc b/extractor/translation_table.cc
index 1b1ba112..11e29e1e 100644
--- a/extractor/translation_table.cc
+++ b/extractor/translation_table.cc
@@ -90,13 +90,12 @@ void TranslationTable::IncrementLinksCount(
double TranslationTable::GetTargetGivenSourceScore(
const string& source_word, const string& target_word) {
- if (!source_data_array->HasWord(source_word) ||
- !target_data_array->HasWord(target_word)) {
+ int source_id = source_data_array->GetWordId(source_word);
+ int target_id = target_data_array->GetWordId(target_word);
+ if (source_id == -1 || target_id == -1) {
return -1;
}
- int source_id = source_data_array->GetWordId(source_word);
- int target_id = target_data_array->GetWordId(target_word);
auto entry = make_pair(source_id, target_id);
auto it = translation_probabilities.find(entry);
if (it == translation_probabilities.end()) {
@@ -107,13 +106,12 @@ double TranslationTable::GetTargetGivenSourceScore(
double TranslationTable::GetSourceGivenTargetScore(
const string& source_word, const string& target_word) {
- if (!source_data_array->HasWord(source_word) ||
- !target_data_array->HasWord(target_word)) {
+ int source_id = source_data_array->GetWordId(source_word);
+ int target_id = target_data_array->GetWordId(target_word);
+ if (source_id == -1 || target_id == -1) {
return -1;
}
- int source_id = source_data_array->GetWordId(source_word);
- int target_id = target_data_array->GetWordId(target_word);
auto entry = make_pair(source_id, target_id);
auto it = translation_probabilities.find(entry);
if (it == translation_probabilities.end()) {
diff --git a/extractor/translation_table.h b/extractor/translation_table.h
index 2a37bab7..97620727 100644
--- a/extractor/translation_table.h
+++ b/extractor/translation_table.h
@@ -5,14 +5,12 @@
#include <string>
#include <unordered_map>
-#include <boost/filesystem.hpp>
#include <boost/functional/hash.hpp>
#include <boost/serialization/serialization.hpp>
#include <boost/serialization/split_member.hpp>
#include <boost/serialization/utility.hpp>
using namespace std;
-namespace fs = boost::filesystem;
namespace extractor {
diff --git a/extractor/translation_table_test.cc b/extractor/translation_table_test.cc
index 606777bd..3cfc0011 100644
--- a/extractor/translation_table_test.cc
+++ b/extractor/translation_table_test.cc
@@ -28,7 +28,7 @@ class TranslationTableTest : public Test {
vector<int> source_sentence_start = {0, 6, 10, 14};
shared_ptr<MockDataArray> source_data_array = make_shared<MockDataArray>();
EXPECT_CALL(*source_data_array, GetData())
- .WillRepeatedly(ReturnRef(source_data));
+ .WillRepeatedly(Return(source_data));
EXPECT_CALL(*source_data_array, GetNumSentences())
.WillRepeatedly(Return(3));
for (size_t i = 0; i < source_sentence_start.size(); ++i) {
@@ -36,31 +36,25 @@ class TranslationTableTest : public Test {
.WillRepeatedly(Return(source_sentence_start[i]));
}
for (size_t i = 0; i < words.size(); ++i) {
- EXPECT_CALL(*source_data_array, HasWord(words[i]))
- .WillRepeatedly(Return(true));
EXPECT_CALL(*source_data_array, GetWordId(words[i]))
.WillRepeatedly(Return(i + 2));
}
- EXPECT_CALL(*source_data_array, HasWord("d"))
- .WillRepeatedly(Return(false));
+ EXPECT_CALL(*source_data_array, GetWordId("d")).WillRepeatedly(Return(-1));
vector<int> target_data = {2, 3, 2, 3, 4, 5, 0, 3, 6, 0, 2, 7, 0};
vector<int> target_sentence_start = {0, 7, 10, 13};
shared_ptr<MockDataArray> target_data_array = make_shared<MockDataArray>();
EXPECT_CALL(*target_data_array, GetData())
- .WillRepeatedly(ReturnRef(target_data));
+ .WillRepeatedly(Return(target_data));
for (size_t i = 0; i < target_sentence_start.size(); ++i) {
EXPECT_CALL(*target_data_array, GetSentenceStart(i))
.WillRepeatedly(Return(target_sentence_start[i]));
}
for (size_t i = 0; i < words.size(); ++i) {
- EXPECT_CALL(*target_data_array, HasWord(words[i]))
- .WillRepeatedly(Return(true));
EXPECT_CALL(*target_data_array, GetWordId(words[i]))
.WillRepeatedly(Return(i + 2));
}
- EXPECT_CALL(*target_data_array, HasWord("d"))
- .WillRepeatedly(Return(false));
+ EXPECT_CALL(*target_data_array, GetWordId("d")).WillRepeatedly(Return(-1));
vector<pair<int, int>> links1 = {
make_pair(0, 0), make_pair(1, 1), make_pair(2, 2), make_pair(3, 3),
diff --git a/extractor/vocabulary.cc b/extractor/vocabulary.cc
index 15795d1e..c9c2d6f4 100644
--- a/extractor/vocabulary.cc
+++ b/extractor/vocabulary.cc
@@ -8,12 +8,13 @@ int Vocabulary::GetTerminalIndex(const string& word) {
int word_id = -1;
#pragma omp critical (vocabulary)
{
- if (!dictionary.count(word)) {
+ auto it = dictionary.find(word);
+ if (it != dictionary.end()) {
+ word_id = it->second;
+ } else {
word_id = words.size();
dictionary[word] = word_id;
words.push_back(word);
- } else {
- word_id = dictionary[word];
}
}
return word_id;
@@ -34,4 +35,8 @@ string Vocabulary::GetTerminalValue(int symbol) {
return word;
}
+bool Vocabulary::operator==(const Vocabulary& other) const {
+ return words == other.words && dictionary == other.dictionary;
+}
+
} // namespace extractor
diff --git a/extractor/vocabulary.h b/extractor/vocabulary.h
index c8fd9411..db092e99 100644
--- a/extractor/vocabulary.h
+++ b/extractor/vocabulary.h
@@ -5,6 +5,10 @@
#include <unordered_map>
#include <vector>
+#include <boost/serialization/serialization.hpp>
+#include <boost/serialization/string.hpp>
+#include <boost/serialization/vector.hpp>
+
using namespace std;
namespace extractor {
@@ -14,7 +18,7 @@ namespace extractor {
*
* This strucure contains words located in the frequent collocations and words
* encountered during the grammar extraction time. This dictionary is
- * considerably smaller than the dictionaries in the data arrays (and so is the
+ * considerably smaller than the dictionaries in the data arays (and so is the
* query time). Note that this is the single data structure that changes state
* and needs to have thread safe read/write operations.
*
@@ -38,7 +42,24 @@ class Vocabulary {
// Returns the word corresponding to the given word id.
virtual string GetTerminalValue(int symbol);
+ bool operator==(const Vocabulary& vocabulary) const;
+
private:
+ friend class boost::serialization::access;
+
+ template<class Archive> void save(Archive& ar, unsigned int) const {
+ ar << words;
+ }
+
+ template<class Archive> void load(Archive& ar, unsigned int) {
+ ar >> words;
+ for (size_t i = 0; i < words.size(); ++i) {
+ dictionary[words[i]] = i;
+ }
+ }
+
+ BOOST_SERIALIZATION_SPLIT_MEMBER();
+
unordered_map<string, int> dictionary;
vector<string> words;
};
diff --git a/extractor/vocabulary_test.cc b/extractor/vocabulary_test.cc
new file mode 100644
index 00000000..cf5e3e36
--- /dev/null
+++ b/extractor/vocabulary_test.cc
@@ -0,0 +1,45 @@
+#include <gtest/gtest.h>
+
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include <boost/archive/text_iarchive.hpp>
+#include <boost/archive/text_oarchive.hpp>
+
+#include "vocabulary.h"
+
+using namespace std;
+using namespace ::testing;
+namespace ar = boost::archive;
+
+namespace extractor {
+namespace {
+
+TEST(VocabularyTest, TestIndexes) {
+ Vocabulary vocabulary;
+ EXPECT_EQ(0, vocabulary.GetTerminalIndex("zero"));
+ EXPECT_EQ("zero", vocabulary.GetTerminalValue(0));
+
+ EXPECT_EQ(1, vocabulary.GetTerminalIndex("one"));
+ EXPECT_EQ("one", vocabulary.GetTerminalValue(1));
+}
+
+TEST(VocabularyTest, TestSerialization) {
+ Vocabulary vocabulary;
+ EXPECT_EQ(0, vocabulary.GetTerminalIndex("zero"));
+ EXPECT_EQ("zero", vocabulary.GetTerminalValue(0));
+
+ stringstream stream(ios_base::out | ios_base::in);
+ ar::text_oarchive output_stream(stream, ar::no_header);
+ output_stream << vocabulary;
+
+ Vocabulary vocabulary_copy;
+ ar::text_iarchive input_stream(stream, ar::no_header);
+ input_stream >> vocabulary_copy;
+
+ EXPECT_EQ(vocabulary, vocabulary_copy);
+}
+
+} // namespace
+} // namespace extractor