summaryrefslogtreecommitdiff
path: root/extractor
diff options
context:
space:
mode:
authorChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2013-04-23 19:35:18 -0400
committerChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2013-04-23 19:35:18 -0400
commitc164dc0ed8a32e4095ba1b36495e0f743b8cc1ea (patch)
tree78b81e4c63adfa67adb7b8f80c3e6be87b4a2b2a /extractor
parent0e46089cafa4e8e2f060e370d7afaceeda6b90a9 (diff)
parentd467e14b28085809c31431be0478eb3d9322fe96 (diff)
merge paul's extractor code
Diffstat (limited to 'extractor')
-rw-r--r--extractor/Makefile.am149
-rw-r--r--extractor/alignment.cc53
-rw-r--r--extractor/alignment.h39
-rw-r--r--extractor/alignment_test.cc33
-rw-r--r--extractor/compile.cc100
-rw-r--r--extractor/data_array.cc161
-rw-r--r--extractor/data_array.h110
-rw-r--r--extractor/data_array_test.cc98
-rw-r--r--extractor/fast_intersector.cc195
-rw-r--r--extractor/fast_intersector.h96
-rw-r--r--extractor/fast_intersector_test.cc146
-rw-r--r--extractor/features/count_source_target.cc17
-rw-r--r--extractor/features/count_source_target.h22
-rw-r--r--extractor/features/count_source_target_test.cc36
-rw-r--r--extractor/features/feature.cc11
-rw-r--r--extractor/features/feature.h47
-rw-r--r--extractor/features/is_source_singleton.cc17
-rw-r--r--extractor/features/is_source_singleton.h22
-rw-r--r--extractor/features/is_source_singleton_test.cc39
-rw-r--r--extractor/features/is_source_target_singleton.cc17
-rw-r--r--extractor/features/is_source_target_singleton.h22
-rw-r--r--extractor/features/is_source_target_singleton_test.cc39
-rw-r--r--extractor/features/max_lex_source_given_target.cc37
-rw-r--r--extractor/features/max_lex_source_given_target.h34
-rw-r--r--extractor/features/max_lex_source_given_target_test.cc78
-rw-r--r--extractor/features/max_lex_target_given_source.cc37
-rw-r--r--extractor/features/max_lex_target_given_source.h34
-rw-r--r--extractor/features/max_lex_target_given_source_test.cc78
-rw-r--r--extractor/features/sample_source_count.cc17
-rw-r--r--extractor/features/sample_source_count.h23
-rw-r--r--extractor/features/sample_source_count_test.cc40
-rw-r--r--extractor/features/target_given_source_coherent.cc18
-rw-r--r--extractor/features/target_given_source_coherent.h23
-rw-r--r--extractor/features/target_given_source_coherent_test.cc39
-rw-r--r--extractor/grammar.cc43
-rw-r--r--extractor/grammar.h34
-rw-r--r--extractor/grammar_extractor.cc62
-rw-r--r--extractor/grammar_extractor.h62
-rw-r--r--extractor/grammar_extractor_test.cc51
-rw-r--r--extractor/matchings_finder.cc25
-rw-r--r--extractor/matchings_finder.h37
-rw-r--r--extractor/matchings_finder_test.cc44
-rw-r--r--extractor/matchings_trie.cc29
-rw-r--r--extractor/matchings_trie.h66
-rw-r--r--extractor/mocks/mock_alignment.h14
-rw-r--r--extractor/mocks/mock_data_array.h23
-rw-r--r--extractor/mocks/mock_fast_intersector.h15
-rw-r--r--extractor/mocks/mock_feature.h15
-rw-r--r--extractor/mocks/mock_matchings_finder.h13
-rw-r--r--extractor/mocks/mock_precomputation.h12
-rw-r--r--extractor/mocks/mock_rule_extractor.h16
-rw-r--r--extractor/mocks/mock_rule_extractor_helper.h82
-rw-r--r--extractor/mocks/mock_rule_factory.h13
-rw-r--r--extractor/mocks/mock_sampler.h13
-rw-r--r--extractor/mocks/mock_scorer.h15
-rw-r--r--extractor/mocks/mock_suffix_array.h23
-rw-r--r--extractor/mocks/mock_target_phrase_extractor.h16
-rw-r--r--extractor/mocks/mock_translation_table.h13
-rw-r--r--extractor/mocks/mock_vocabulary.h13
-rw-r--r--extractor/phrase.cc58
-rw-r--r--extractor/phrase.h52
-rw-r--r--extractor/phrase_builder.cc48
-rw-r--r--extractor/phrase_builder.h33
-rw-r--r--extractor/phrase_location.cc43
-rw-r--r--extractor/phrase_location.h41
-rw-r--r--extractor/phrase_test.cc83
-rw-r--r--extractor/precomputation.cc189
-rw-r--r--extractor/precomputation.h80
-rw-r--r--extractor/precomputation_test.cc106
-rw-r--r--extractor/rule.cc14
-rw-r--r--extractor/rule.h27
-rw-r--r--extractor/rule_extractor.cc343
-rw-r--r--extractor/rule_extractor.h124
-rw-r--r--extractor/rule_extractor_helper.cc362
-rw-r--r--extractor/rule_extractor_helper.h101
-rw-r--r--extractor/rule_extractor_helper_test.cc645
-rw-r--r--extractor/rule_extractor_test.cc168
-rw-r--r--extractor/rule_factory.cc303
-rw-r--r--extractor/rule_factory.h118
-rw-r--r--extractor/rule_factory_test.cc103
-rw-r--r--extractor/run_extractor.cc242
-rw-r--r--extractor/sample_alignment.txt2
-rw-r--r--extractor/sample_bitext.txt2
-rw-r--r--extractor/sampler.cc46
-rw-r--r--extractor/sampler.h38
-rw-r--r--extractor/sampler_test.cc74
-rw-r--r--extractor/scorer.cc30
-rw-r--r--extractor/scorer.h41
-rw-r--r--extractor/scorer_test.cc49
-rw-r--r--extractor/suffix_array.cc235
-rw-r--r--extractor/suffix_array.h75
-rw-r--r--extractor/suffix_array_test.cc78
-rw-r--r--extractor/target_phrase_extractor.cc158
-rw-r--r--extractor/target_phrase_extractor.h64
-rw-r--r--extractor/target_phrase_extractor_test.cc143
-rw-r--r--extractor/time_util.cc10
-rw-r--r--extractor/time_util.h19
-rw-r--r--extractor/translation_table.cc126
-rw-r--r--extractor/translation_table.h63
-rw-r--r--extractor/translation_table_test.cc84
-rw-r--r--extractor/vocabulary.cc37
-rw-r--r--extractor/vocabulary.h48
102 files changed, 7381 insertions, 0 deletions
diff --git a/extractor/Makefile.am b/extractor/Makefile.am
new file mode 100644
index 00000000..d8239b7d
--- /dev/null
+++ b/extractor/Makefile.am
@@ -0,0 +1,149 @@
+bin_PROGRAMS = compile run_extractor
+
+EXTRA_PROGRAMS = alignment_test \
+ data_array_test \
+ fast_intersector_test \
+ feature_count_source_target_test \
+ feature_is_source_singleton_test \
+ feature_is_source_target_singleton_test \
+ feature_max_lex_source_given_target_test \
+ feature_max_lex_target_given_source_test \
+ feature_sample_source_count_test \
+ feature_target_given_source_coherent_test \
+ grammar_extractor_test \
+ matchings_finder_test \
+ phrase_test \
+ precomputation_test \
+ rule_extractor_helper_test \
+ rule_extractor_test \
+ rule_factory_test \
+ sampler_test \
+ scorer_test \
+ suffix_array_test \
+ target_phrase_extractor_test \
+ translation_table_test
+
+if HAVE_GTEST
+ RUNNABLE_TESTS = alignment_test \
+ data_array_test \
+ fast_intersector_test \
+ feature_count_source_target_test \
+ feature_is_source_singleton_test \
+ feature_is_source_target_singleton_test \
+ feature_max_lex_source_given_target_test \
+ feature_max_lex_target_given_source_test \
+ feature_sample_source_count_test \
+ feature_target_given_source_coherent_test \
+ grammar_extractor_test \
+ matchings_finder_test \
+ phrase_test \
+ precomputation_test \
+ rule_extractor_helper_test \
+ rule_extractor_test \
+ rule_factory_test \
+ sampler_test \
+ scorer_test \
+ suffix_array_test \
+ target_phrase_extractor_test \
+ translation_table_test
+endif
+
+noinst_PROGRAMS = $(RUNNABLE_TESTS)
+
+TESTS = $(RUNNABLE_TESTS)
+
+alignment_test_SOURCES = alignment_test.cc
+alignment_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
+data_array_test_SOURCES = data_array_test.cc
+data_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
+fast_intersector_test_SOURCES = fast_intersector_test.cc
+fast_intersector_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+feature_count_source_target_test_SOURCES = features/count_source_target_test.cc
+feature_count_source_target_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
+feature_is_source_singleton_test_SOURCES = features/is_source_singleton_test.cc
+feature_is_source_singleton_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
+feature_is_source_target_singleton_test_SOURCES = features/is_source_target_singleton_test.cc
+feature_is_source_target_singleton_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
+feature_max_lex_source_given_target_test_SOURCES = features/max_lex_source_given_target_test.cc
+feature_max_lex_source_given_target_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+feature_max_lex_target_given_source_test_SOURCES = features/max_lex_target_given_source_test.cc
+feature_max_lex_target_given_source_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+feature_sample_source_count_test_SOURCES = features/sample_source_count_test.cc
+feature_sample_source_count_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
+feature_target_given_source_coherent_test_SOURCES = features/target_given_source_coherent_test.cc
+feature_target_given_source_coherent_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
+grammar_extractor_test_SOURCES = grammar_extractor_test.cc
+grammar_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+matchings_finder_test_SOURCES = matchings_finder_test.cc
+matchings_finder_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+phrase_test_SOURCES = phrase_test.cc
+phrase_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+precomputation_test_SOURCES = precomputation_test.cc
+precomputation_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+rule_extractor_helper_test_SOURCES = rule_extractor_helper_test.cc
+rule_extractor_helper_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+rule_extractor_test_SOURCES = rule_extractor_test.cc
+rule_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+rule_factory_test_SOURCES = rule_factory_test.cc
+rule_factory_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+sampler_test_SOURCES = sampler_test.cc
+sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+scorer_test_SOURCES = scorer_test.cc
+scorer_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+suffix_array_test_SOURCES = suffix_array_test.cc
+suffix_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+target_phrase_extractor_test_SOURCES = target_phrase_extractor_test.cc
+target_phrase_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+translation_table_test_SOURCES = translation_table_test.cc
+translation_table_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+
+noinst_LIBRARIES = libextractor.a libcompile.a
+
+compile_SOURCES = compile.cc
+compile_LDADD = libcompile.a
+run_extractor_SOURCES = run_extractor.cc
+run_extractor_LDADD = libextractor.a
+
+libcompile_a_SOURCES = \
+ alignment.cc \
+ data_array.cc \
+ phrase_location.cc \
+ precomputation.cc \
+ suffix_array.cc \
+ time_util.cc \
+ translation_table.cc
+
+libextractor_a_SOURCES = \
+ alignment.cc \
+ data_array.cc \
+ fast_intersector.cc \
+ features/count_source_target.cc \
+ features/feature.cc \
+ features/is_source_singleton.cc \
+ features/is_source_target_singleton.cc \
+ features/max_lex_source_given_target.cc \
+ features/max_lex_target_given_source.cc \
+ features/sample_source_count.cc \
+ features/target_given_source_coherent.cc \
+ grammar.cc \
+ grammar_extractor.cc \
+ matchings_finder.cc \
+ matchings_trie.cc \
+ phrase.cc \
+ phrase_builder.cc \
+ phrase_location.cc \
+ precomputation.cc \
+ rule.cc \
+ rule_extractor.cc \
+ rule_extractor_helper.cc \
+ rule_factory.cc \
+ sampler.cc \
+ scorer.cc \
+ suffix_array.cc \
+ target_phrase_extractor.cc \
+ time_util.cc \
+ translation_table.cc \
+ vocabulary.cc
+
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare -std=c++0x -fopenmp $(GTEST_CPPFLAGS) $(GMOCK_CPPFLAGS)
+AM_LDFLAGS = -fopenmp
diff --git a/extractor/alignment.cc b/extractor/alignment.cc
new file mode 100644
index 00000000..1aea34b3
--- /dev/null
+++ b/extractor/alignment.cc
@@ -0,0 +1,53 @@
+#include "alignment.h"
+
+#include <fstream>
+#include <sstream>
+#include <string>
+#include <fcntl.h>
+#include <unistd.h>
+#include <vector>
+
+#include <boost/algorithm/string.hpp>
+#include <boost/filesystem.hpp>
+
+namespace fs = boost::filesystem;
+using namespace std;
+
+namespace extractor {
+
+Alignment::Alignment(const string& filename) {
+ ifstream infile(filename.c_str());
+ string line;
+ while (getline(infile, line)) {
+ vector<string> items;
+ boost::split(items, line, boost::is_any_of(" -"));
+ vector<pair<int, int> > alignment;
+ alignment.reserve(items.size() / 2);
+ for (size_t i = 0; i < items.size(); i += 2) {
+ alignment.push_back(make_pair(stoi(items[i]), stoi(items[i + 1])));
+ }
+ alignments.push_back(alignment);
+ }
+ alignments.shrink_to_fit();
+}
+
+Alignment::Alignment() {}
+
+Alignment::~Alignment() {}
+
+vector<pair<int, int> > Alignment::GetLinks(int sentence_index) const {
+ return alignments[sentence_index];
+}
+
+void Alignment::WriteBinary(const fs::path& filepath) {
+ FILE* file = fopen(filepath.string().c_str(), "w");
+ int size = alignments.size();
+ fwrite(&size, sizeof(int), 1, file);
+ for (vector<pair<int, int> > alignment: alignments) {
+ size = alignment.size();
+ fwrite(&size, sizeof(int), 1, file);
+ fwrite(alignment.data(), sizeof(pair<int, int>), size, file);
+ }
+}
+
+} // namespace extractor
diff --git a/extractor/alignment.h b/extractor/alignment.h
new file mode 100644
index 00000000..e9292121
--- /dev/null
+++ b/extractor/alignment.h
@@ -0,0 +1,39 @@
+#ifndef _ALIGNMENT_H_
+#define _ALIGNMENT_H_
+
+#include <string>
+#include <vector>
+
+#include <boost/filesystem.hpp>
+
+namespace fs = boost::filesystem;
+using namespace std;
+
+namespace extractor {
+
+/**
+ * Data structure storing the word alignments for a parallel corpus.
+ */
+class Alignment {
+ public:
+ // Reads alignment from text file.
+ Alignment(const string& filename);
+
+ // Returns the alignment for a given sentence.
+ virtual vector<pair<int, int> > GetLinks(int sentence_index) const;
+
+ // Writes alignment to file in binary format.
+ void WriteBinary(const fs::path& filepath);
+
+ virtual ~Alignment();
+
+ protected:
+ Alignment();
+
+ private:
+ vector<vector<pair<int, int> > > alignments;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/alignment_test.cc b/extractor/alignment_test.cc
new file mode 100644
index 00000000..a7defb66
--- /dev/null
+++ b/extractor/alignment_test.cc
@@ -0,0 +1,33 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+
+#include "alignment.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+class AlignmentTest : public Test {
+ protected:
+ virtual void SetUp() {
+ alignment = make_shared<Alignment>("sample_alignment.txt");
+ }
+
+ shared_ptr<Alignment> alignment;
+};
+
+TEST_F(AlignmentTest, TestGetLinks) {
+ vector<pair<int, int> > expected_links = {
+ make_pair(0, 0), make_pair(1, 1), make_pair(2, 2)
+ };
+ EXPECT_EQ(expected_links, alignment->GetLinks(0));
+ expected_links = {make_pair(1, 0), make_pair(2, 1)};
+ EXPECT_EQ(expected_links, alignment->GetLinks(1));
+}
+
+} // namespace
+} // namespace extractor
diff --git a/extractor/compile.cc b/extractor/compile.cc
new file mode 100644
index 00000000..a9ae2cef
--- /dev/null
+++ b/extractor/compile.cc
@@ -0,0 +1,100 @@
+#include <iostream>
+#include <string>
+
+#include <boost/filesystem.hpp>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "alignment.h"
+#include "data_array.h"
+#include "precomputation.h"
+#include "suffix_array.h"
+#include "translation_table.h"
+
+namespace fs = boost::filesystem;
+namespace po = boost::program_options;
+using namespace std;
+using namespace extractor;
+
+int main(int argc, char** argv) {
+ po::options_description desc("Command line options");
+ desc.add_options()
+ ("help,h", "Show available options")
+ ("source,f", po::value<string>(), "Source language corpus")
+ ("target,e", po::value<string>(), "Target language corpus")
+ ("bitext,b", po::value<string>(), "Parallel text (source ||| target)")
+ ("alignment,a", po::value<string>()->required(), "Bitext word alignment")
+ ("output,o", po::value<string>()->required(), "Output path")
+ ("frequent", po::value<int>()->default_value(100),
+ "Number of precomputed frequent patterns")
+ ("super_frequent", po::value<int>()->default_value(10),
+ "Number of precomputed super frequent patterns")
+ ("max_rule_span,s", po::value<int>()->default_value(15),
+ "Maximum rule span")
+ ("max_rule_symbols,l", po::value<int>()->default_value(5),
+ "Maximum number of symbols (terminals + nontermals) in a rule")
+ ("min_gap_size,g", po::value<int>()->default_value(1), "Minimum gap size")
+ ("max_phrase_len,p", po::value<int>()->default_value(4),
+ "Maximum frequent phrase length")
+ ("min_frequency", po::value<int>()->default_value(1000),
+ "Minimum number of occurrences for a pharse to be considered frequent");
+
+ po::variables_map vm;
+ po::store(po::parse_command_line(argc, argv, desc), vm);
+
+ // Check for help argument before notify, so we don't need to pass in the
+ // required parameters.
+ if (vm.count("help")) {
+ cout << desc << endl;
+ return 0;
+ }
+
+ po::notify(vm);
+
+ if (!((vm.count("source") && vm.count("target")) || vm.count("bitext"))) {
+ cerr << "A paralel corpus is required. "
+ << "Use -f (source) with -e (target) or -b (bitext)."
+ << endl;
+ return 1;
+ }
+
+ fs::path output_dir(vm["output"].as<string>().c_str());
+ if (!fs::exists(output_dir)) {
+ fs::create_directory(output_dir);
+ }
+
+ shared_ptr<DataArray> source_data_array, target_data_array;
+ if (vm.count("bitext")) {
+ source_data_array = make_shared<DataArray>(
+ vm["bitext"].as<string>(), SOURCE);
+ target_data_array = make_shared<DataArray>(
+ vm["bitext"].as<string>(), TARGET);
+ } else {
+ source_data_array = make_shared<DataArray>(vm["source"].as<string>());
+ target_data_array = make_shared<DataArray>(vm["target"].as<string>());
+ }
+ shared_ptr<SuffixArray> source_suffix_array =
+ make_shared<SuffixArray>(source_data_array);
+ source_suffix_array->WriteBinary(output_dir / fs::path("f.bin"));
+ target_data_array->WriteBinary(output_dir / fs::path("e.bin"));
+
+ shared_ptr<Alignment> alignment =
+ make_shared<Alignment>(vm["alignment"].as<string>());
+ alignment->WriteBinary(output_dir / fs::path("a.bin"));
+
+ Precomputation precomputation(
+ source_suffix_array,
+ vm["frequent"].as<int>(),
+ vm["super_frequent"].as<int>(),
+ vm["max_rule_span"].as<int>(),
+ vm["max_rule_symbols"].as<int>(),
+ vm["min_gap_size"].as<int>(),
+ vm["max_phrase_len"].as<int>(),
+ vm["min_frequency"].as<int>());
+ precomputation.WriteBinary(output_dir / fs::path("precompute.bin"));
+
+ TranslationTable table(source_data_array, target_data_array, alignment);
+ table.WriteBinary(output_dir / fs::path("lex.bin"));
+
+ return 0;
+}
diff --git a/extractor/data_array.cc b/extractor/data_array.cc
new file mode 100644
index 00000000..203fe219
--- /dev/null
+++ b/extractor/data_array.cc
@@ -0,0 +1,161 @@
+#include "data_array.h"
+
+#include <fstream>
+#include <iostream>
+#include <sstream>
+#include <string>
+
+#include <boost/filesystem.hpp>
+
+namespace fs = boost::filesystem;
+using namespace std;
+
+namespace extractor {
+
+int DataArray::NULL_WORD = 0;
+int DataArray::END_OF_LINE = 1;
+string DataArray::NULL_WORD_STR = "__NULL__";
+string DataArray::END_OF_LINE_STR = "__END_OF_LINE__";
+
+DataArray::DataArray() {
+ InitializeDataArray();
+}
+
+DataArray::DataArray(const string& filename) {
+ InitializeDataArray();
+ ifstream infile(filename.c_str());
+ vector<string> lines;
+ string line;
+ while (getline(infile, line)) {
+ lines.push_back(line);
+ }
+ CreateDataArray(lines);
+}
+
+DataArray::DataArray(const string& filename, const Side& side) {
+ InitializeDataArray();
+ ifstream infile(filename.c_str());
+ vector<string> lines;
+ string line, delimiter = "|||";
+ while (getline(infile, line)) {
+ int position = line.find(delimiter);
+ if (side == SOURCE) {
+ lines.push_back(line.substr(0, position));
+ } else {
+ lines.push_back(line.substr(position + delimiter.size()));
+ }
+ }
+ CreateDataArray(lines);
+}
+
+void DataArray::InitializeDataArray() {
+ word2id[NULL_WORD_STR] = NULL_WORD;
+ id2word.push_back(NULL_WORD_STR);
+ word2id[END_OF_LINE_STR] = END_OF_LINE;
+ id2word.push_back(END_OF_LINE_STR);
+}
+
+void DataArray::CreateDataArray(const vector<string>& lines) {
+ for (size_t i = 0; i < lines.size(); ++i) {
+ sentence_start.push_back(data.size());
+
+ istringstream iss(lines[i]);
+ string word;
+ while (iss >> word) {
+ if (word2id.count(word) == 0) {
+ word2id[word] = id2word.size();
+ id2word.push_back(word);
+ }
+ data.push_back(word2id[word]);
+ sentence_id.push_back(i);
+ }
+ data.push_back(END_OF_LINE);
+ sentence_id.push_back(i);
+ }
+ sentence_start.push_back(data.size());
+
+ data.shrink_to_fit();
+ sentence_id.shrink_to_fit();
+ sentence_start.shrink_to_fit();
+}
+
+DataArray::~DataArray() {}
+
+const vector<int>& DataArray::GetData() const {
+ return data;
+}
+
+int DataArray::AtIndex(int index) const {
+ return data[index];
+}
+
+string DataArray::GetWordAtIndex(int index) const {
+ return id2word[data[index]];
+}
+
+int DataArray::GetSize() const {
+ return data.size();
+}
+
+int DataArray::GetVocabularySize() const {
+ return id2word.size();
+}
+
+int DataArray::GetNumSentences() const {
+ return sentence_start.size() - 1;
+}
+
+int DataArray::GetSentenceStart(int position) const {
+ return sentence_start[position];
+}
+
+int DataArray::GetSentenceLength(int sentence_id) const {
+ // Ignore end of line markers.
+ return sentence_start[sentence_id + 1] - sentence_start[sentence_id] - 1;
+}
+
+int DataArray::GetSentenceId(int position) const {
+ return sentence_id[position];
+}
+
+void DataArray::WriteBinary(const fs::path& filepath) const {
+ std::cerr << "File: " << filepath.string() << std::endl;
+ WriteBinary(fopen(filepath.string().c_str(), "w"));
+}
+
+void DataArray::WriteBinary(FILE* file) const {
+ int size = id2word.size();
+ fwrite(&size, sizeof(int), 1, file);
+ for (string word: id2word) {
+ size = word.size();
+ fwrite(&size, sizeof(int), 1, file);
+ fwrite(word.data(), sizeof(char), size, file);
+ }
+
+ size = data.size();
+ fwrite(&size, sizeof(int), 1, file);
+ fwrite(data.data(), sizeof(int), size, file);
+
+ size = sentence_id.size();
+ fwrite(&size, sizeof(int), 1, file);
+ fwrite(sentence_id.data(), sizeof(int), size, file);
+
+ size = sentence_start.size();
+ fwrite(&size, sizeof(int), 1, file);
+ fwrite(sentence_start.data(), sizeof(int), 1, file);
+}
+
+bool DataArray::HasWord(const string& word) const {
+ return word2id.count(word);
+}
+
+int DataArray::GetWordId(const string& word) const {
+ auto result = word2id.find(word);
+ return result == word2id.end() ? -1 : result->second;
+}
+
+string DataArray::GetWord(int word_id) const {
+ return id2word[word_id];
+}
+
+} // namespace extractor
diff --git a/extractor/data_array.h b/extractor/data_array.h
new file mode 100644
index 00000000..978a6931
--- /dev/null
+++ b/extractor/data_array.h
@@ -0,0 +1,110 @@
+#ifndef _DATA_ARRAY_H_
+#define _DATA_ARRAY_H_
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include <boost/filesystem.hpp>
+
+namespace fs = boost::filesystem;
+using namespace std;
+
+namespace extractor {
+
+enum Side {
+ SOURCE,
+ TARGET
+};
+
+/**
+ * Data structure storing information about a single side of a parallel corpus.
+ *
+ * Each word is mapped to a unique integer (word_id). The data structure holds
+ * the corpus in the numberized format, together with the hash table mapping
+ * words to word_ids. It also holds additional information such as the starting
+ * index for each sentence and, for each token, the index of the sentence it
+ * belongs to.
+ *
+ * Note: This class has features for both the source and target data arrays.
+ * Maybe we can save some memory by having more specific implementations (not
+ * likely to save a lot of memory tough).
+ */
+class DataArray {
+ public:
+ static int NULL_WORD;
+ static int END_OF_LINE;
+ static string NULL_WORD_STR;
+ static string END_OF_LINE_STR;
+
+ // Reads data array from text file.
+ DataArray(const string& filename);
+
+ // Reads data array from bitext file where the sentences are separated by |||.
+ DataArray(const string& filename, const Side& side);
+
+ virtual ~DataArray();
+
+ // Returns a vector containing the word ids.
+ virtual const vector<int>& GetData() const;
+
+ // Returns the word id at the specified position.
+ virtual int AtIndex(int index) const;
+
+ // Returns the original word at the specified position.
+ virtual string GetWordAtIndex(int index) const;
+
+ // Returns the size of the data array.
+ virtual int GetSize() const;
+
+ // Returns the number of distinct words in the data array.
+ virtual int GetVocabularySize() const;
+
+ // Returns whether a word has ever been observed in the data array.
+ virtual bool HasWord(const string& word) const;
+
+ // Returns the word id for a given word or -1 if it the word has never been
+ // observed.
+ virtual int GetWordId(const string& word) const;
+
+ // Returns the word corresponding to a particular word id.
+ virtual string GetWord(int word_id) const;
+
+ // Returns the number of sentences in the data.
+ virtual int GetNumSentences() const;
+
+ // Returns the index where the sentence containing the given position starts.
+ virtual int GetSentenceStart(int position) const;
+
+ // Returns the length of the sentence.
+ virtual int GetSentenceLength(int sentence_id) const;
+
+ // Returns the number of the sentence containing the given position.
+ virtual int GetSentenceId(int position) const;
+
+ // Writes data array to file in binary format.
+ void WriteBinary(const fs::path& filepath) const;
+
+ // Writes data array to file in binary format.
+ void WriteBinary(FILE* file) const;
+
+ protected:
+ DataArray();
+
+ private:
+ // Sets up specific constants.
+ void InitializeDataArray();
+
+ // Constructs the data array.
+ void CreateDataArray(const vector<string>& lines);
+
+ unordered_map<string, int> word2id;
+ vector<string> id2word;
+ vector<int> data;
+ vector<int> sentence_id;
+ vector<int> sentence_start;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/data_array_test.cc b/extractor/data_array_test.cc
new file mode 100644
index 00000000..71175fda
--- /dev/null
+++ b/extractor/data_array_test.cc
@@ -0,0 +1,98 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+
+#include <boost/filesystem.hpp>
+
+#include "data_array.h"
+
+using namespace std;
+using namespace ::testing;
+namespace fs = boost::filesystem;
+
+namespace extractor {
+namespace {
+
+class DataArrayTest : public Test {
+ protected:
+ virtual void SetUp() {
+ string sample_test_file("sample_bitext.txt");
+ source_data = make_shared<DataArray>(sample_test_file, SOURCE);
+ target_data = make_shared<DataArray>(sample_test_file, TARGET);
+ }
+
+ shared_ptr<DataArray> source_data;
+ shared_ptr<DataArray> target_data;
+};
+
+TEST_F(DataArrayTest, TestGetData) {
+ vector<int> expected_source_data = {2, 3, 4, 5, 1, 2, 6, 7, 8, 5, 1};
+ vector<string> expected_source_words = {
+ "ana", "are", "mere", ".", "__END_OF_LINE__",
+ "ana", "bea", "mult", "lapte", ".", "__END_OF_LINE__"
+ };
+ EXPECT_EQ(expected_source_data, source_data->GetData());
+ EXPECT_EQ(expected_source_data.size(), source_data->GetSize());
+ for (size_t i = 0; i < expected_source_data.size(); ++i) {
+ EXPECT_EQ(expected_source_data[i], source_data->AtIndex(i));
+ EXPECT_EQ(expected_source_words[i], source_data->GetWordAtIndex(i));
+ }
+
+ vector<int> expected_target_data = {2, 3, 4, 5, 1, 2, 6, 7, 8, 9, 10, 5, 1};
+ vector<string> expected_target_words = {
+ "anna", "has", "apples", ".", "__END_OF_LINE__",
+ "anna", "drinks", "a", "lot", "of", "milk", ".", "__END_OF_LINE__"
+ };
+ EXPECT_EQ(expected_target_data, target_data->GetData());
+ EXPECT_EQ(expected_target_data.size(), target_data->GetSize());
+ for (size_t i = 0; i < expected_target_data.size(); ++i) {
+ EXPECT_EQ(expected_target_data[i], target_data->AtIndex(i));
+ EXPECT_EQ(expected_target_words[i], target_data->GetWordAtIndex(i));
+ }
+}
+
+TEST_F(DataArrayTest, TestVocabulary) {
+ EXPECT_EQ(9, source_data->GetVocabularySize());
+ EXPECT_TRUE(source_data->HasWord("mere"));
+ EXPECT_EQ(4, source_data->GetWordId("mere"));
+ EXPECT_EQ("mere", source_data->GetWord(4));
+ EXPECT_FALSE(source_data->HasWord("banane"));
+
+ EXPECT_EQ(11, target_data->GetVocabularySize());
+ EXPECT_TRUE(target_data->HasWord("apples"));
+ EXPECT_EQ(4, target_data->GetWordId("apples"));
+ EXPECT_EQ("apples", target_data->GetWord(4));
+ EXPECT_FALSE(target_data->HasWord("bananas"));
+}
+
+TEST_F(DataArrayTest, TestSentenceData) {
+ EXPECT_EQ(2, source_data->GetNumSentences());
+ EXPECT_EQ(0, source_data->GetSentenceStart(0));
+ EXPECT_EQ(5, source_data->GetSentenceStart(1));
+ EXPECT_EQ(11, source_data->GetSentenceStart(2));
+
+ EXPECT_EQ(4, source_data->GetSentenceLength(0));
+ EXPECT_EQ(5, source_data->GetSentenceLength(1));
+
+ vector<int> expected_source_ids = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1};
+ for (size_t i = 0; i < expected_source_ids.size(); ++i) {
+ EXPECT_EQ(expected_source_ids[i], source_data->GetSentenceId(i));
+ }
+
+ EXPECT_EQ(2, target_data->GetNumSentences());
+ EXPECT_EQ(0, target_data->GetSentenceStart(0));
+ EXPECT_EQ(5, target_data->GetSentenceStart(1));
+ EXPECT_EQ(13, target_data->GetSentenceStart(2));
+
+ EXPECT_EQ(4, target_data->GetSentenceLength(0));
+ EXPECT_EQ(7, target_data->GetSentenceLength(1));
+
+ vector<int> expected_target_ids = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1};
+ for (size_t i = 0; i < expected_target_ids.size(); ++i) {
+ EXPECT_EQ(expected_target_ids[i], target_data->GetSentenceId(i));
+ }
+}
+
+} // namespace
+} // namespace extractor
diff --git a/extractor/fast_intersector.cc b/extractor/fast_intersector.cc
new file mode 100644
index 00000000..2a7693b2
--- /dev/null
+++ b/extractor/fast_intersector.cc
@@ -0,0 +1,195 @@
+#include "fast_intersector.h"
+
+#include <cassert>
+
+#include "data_array.h"
+#include "phrase.h"
+#include "phrase_location.h"
+#include "precomputation.h"
+#include "suffix_array.h"
+#include "vocabulary.h"
+
+namespace extractor {
+
+FastIntersector::FastIntersector(shared_ptr<SuffixArray> suffix_array,
+ shared_ptr<Precomputation> precomputation,
+ shared_ptr<Vocabulary> vocabulary,
+ int max_rule_span,
+ int min_gap_size) :
+ suffix_array(suffix_array),
+ vocabulary(vocabulary),
+ max_rule_span(max_rule_span),
+ min_gap_size(min_gap_size) {
+ Index precomputed_collocations = precomputation->GetCollocations();
+ for (pair<vector<int>, vector<int> > entry: precomputed_collocations) {
+ vector<int> phrase = ConvertPhrase(entry.first);
+ collocations[phrase] = entry.second;
+ }
+}
+
+FastIntersector::FastIntersector() {}
+
+FastIntersector::~FastIntersector() {}
+
+vector<int> FastIntersector::ConvertPhrase(const vector<int>& old_phrase) {
+ vector<int> new_phrase;
+ new_phrase.reserve(old_phrase.size());
+ shared_ptr<DataArray> data_array = suffix_array->GetData();
+ for (int word_id: old_phrase) {
+ if (word_id < 0) {
+ new_phrase.push_back(word_id);
+ } else {
+ new_phrase.push_back(
+ vocabulary->GetTerminalIndex(data_array->GetWord(word_id)));
+ }
+ }
+ return new_phrase;
+}
+
+PhraseLocation FastIntersector::Intersect(
+ PhraseLocation& prefix_location,
+ PhraseLocation& suffix_location,
+ const Phrase& phrase) {
+ vector<int> symbols = phrase.Get();
+
+ // We should never attempt to do an intersect query for a pattern starting or
+ // ending with a non terminal. The RuleFactory should handle these cases,
+ // initializing the matchings list with the one for the pattern without the
+ // starting or ending terminal.
+ assert(vocabulary->IsTerminal(symbols.front())
+ && vocabulary->IsTerminal(symbols.back()));
+
+ if (collocations.count(symbols)) {
+ return PhraseLocation(collocations[symbols], phrase.Arity() + 1);
+ }
+
+ bool prefix_ends_with_x =
+ !vocabulary->IsTerminal(symbols[symbols.size() - 2]);
+ bool suffix_starts_with_x = !vocabulary->IsTerminal(symbols[1]);
+ if (EstimateNumOperations(prefix_location, prefix_ends_with_x) <=
+ EstimateNumOperations(suffix_location, suffix_starts_with_x)) {
+ return ExtendPrefixPhraseLocation(prefix_location, phrase,
+ prefix_ends_with_x, symbols.back());
+ } else {
+ return ExtendSuffixPhraseLocation(suffix_location, phrase,
+ suffix_starts_with_x, symbols.front());
+ }
+}
+
+int FastIntersector::EstimateNumOperations(
+ const PhraseLocation& phrase_location, bool has_margin_x) const {
+ int num_locations = phrase_location.GetSize();
+ return has_margin_x ? num_locations * max_rule_span : num_locations;
+}
+
+PhraseLocation FastIntersector::ExtendPrefixPhraseLocation(
+ PhraseLocation& prefix_location, const Phrase& phrase,
+ bool prefix_ends_with_x, int next_symbol) const {
+ ExtendPhraseLocation(prefix_location);
+ vector<int> positions = *prefix_location.matchings;
+ int num_subpatterns = prefix_location.num_subpatterns;
+
+ vector<int> new_positions;
+ shared_ptr<DataArray> data_array = suffix_array->GetData();
+ int data_array_symbol = data_array->GetWordId(
+ vocabulary->GetTerminalValue(next_symbol));
+ if (data_array_symbol == -1) {
+ return PhraseLocation(new_positions, num_subpatterns);
+ }
+
+ pair<int, int> range = GetSearchRange(prefix_ends_with_x);
+ for (size_t i = 0; i < positions.size(); i += num_subpatterns) {
+ int sent_id = data_array->GetSentenceId(positions[i]);
+ int sent_end = data_array->GetSentenceStart(sent_id + 1) - 1;
+ int pattern_end = positions[i + num_subpatterns - 1] + range.first;
+ if (prefix_ends_with_x) {
+ pattern_end += phrase.GetChunkLen(phrase.Arity() - 1) - 1;
+ } else {
+ pattern_end += phrase.GetChunkLen(phrase.Arity()) - 2;
+ }
+ // Searches for the last symbol in the phrase after each prefix occurrence.
+ for (int j = range.first; j < range.second; ++j) {
+ if (pattern_end >= sent_end ||
+ pattern_end - positions[i] >= max_rule_span) {
+ break;
+ }
+
+ if (data_array->AtIndex(pattern_end) == data_array_symbol) {
+ new_positions.insert(new_positions.end(), positions.begin() + i,
+ positions.begin() + i + num_subpatterns);
+ if (prefix_ends_with_x) {
+ new_positions.push_back(pattern_end);
+ }
+ }
+ ++pattern_end;
+ }
+ }
+
+ return PhraseLocation(new_positions, phrase.Arity() + 1);
+}
+
+PhraseLocation FastIntersector::ExtendSuffixPhraseLocation(
+ PhraseLocation& suffix_location, const Phrase& phrase,
+ bool suffix_starts_with_x, int prev_symbol) const {
+ ExtendPhraseLocation(suffix_location);
+ vector<int> positions = *suffix_location.matchings;
+ int num_subpatterns = suffix_location.num_subpatterns;
+
+ vector<int> new_positions;
+ shared_ptr<DataArray> data_array = suffix_array->GetData();
+ int data_array_symbol = data_array->GetWordId(
+ vocabulary->GetTerminalValue(prev_symbol));
+ if (data_array_symbol == -1) {
+ return PhraseLocation(new_positions, num_subpatterns);
+ }
+
+ pair<int, int> range = GetSearchRange(suffix_starts_with_x);
+ for (size_t i = 0; i < positions.size(); i += num_subpatterns) {
+ int sent_id = data_array->GetSentenceId(positions[i]);
+ int sent_start = data_array->GetSentenceStart(sent_id);
+ int pattern_start = positions[i] - range.first;
+ int pattern_end = positions[i + num_subpatterns - 1] +
+ phrase.GetChunkLen(phrase.Arity()) - 1;
+ // Searches for the first symbol in the phrase before each suffix
+ // occurrence.
+ for (int j = range.first; j < range.second; ++j) {
+ if (pattern_start < sent_start ||
+ pattern_end - pattern_start >= max_rule_span) {
+ break;
+ }
+
+ if (data_array->AtIndex(pattern_start) == data_array_symbol) {
+ new_positions.push_back(pattern_start);
+ new_positions.insert(new_positions.end(),
+ positions.begin() + i + !suffix_starts_with_x,
+ positions.begin() + i + num_subpatterns);
+ }
+ --pattern_start;
+ }
+ }
+
+ return PhraseLocation(new_positions, phrase.Arity() + 1);
+}
+
+void FastIntersector::ExtendPhraseLocation(PhraseLocation& location) const {
+ if (location.matchings != NULL) {
+ return;
+ }
+
+ location.num_subpatterns = 1;
+ location.matchings = make_shared<vector<int> >();
+ for (int i = location.sa_low; i < location.sa_high; ++i) {
+ location.matchings->push_back(suffix_array->GetSuffix(i));
+ }
+ location.sa_low = location.sa_high = 0;
+}
+
+pair<int, int> FastIntersector::GetSearchRange(bool has_marginal_x) const {
+ if (has_marginal_x) {
+ return make_pair(min_gap_size + 1, max_rule_span);
+ } else {
+ return make_pair(1, 2);
+ }
+}
+
+} // namespace extractor
diff --git a/extractor/fast_intersector.h b/extractor/fast_intersector.h
new file mode 100644
index 00000000..f950a2a9
--- /dev/null
+++ b/extractor/fast_intersector.h
@@ -0,0 +1,96 @@
+#ifndef _FAST_INTERSECTOR_H_
+#define _FAST_INTERSECTOR_H_
+
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+#include <boost/functional/hash.hpp>
+
+using namespace std;
+
+namespace extractor {
+
+typedef boost::hash<vector<int> > VectorHash;
+typedef unordered_map<vector<int>, vector<int>, VectorHash> Index;
+
+class Phrase;
+class PhraseLocation;
+class Precomputation;
+class SuffixArray;
+class Vocabulary;
+
+/**
+ * Component for searching the training data for occurrences of source phrases
+ * containing nonterminals
+ *
+ * Given a source phrase containing a nonterminal, we first query the
+ * precomputed index containing frequent collocations. If the phrase is not
+ * frequent enough, we extend the matchings of either its prefix or its suffix,
+ * depending on which operation seems to require less computations.
+ *
+ * Note: This method for intersecting phrase locations is faster than both
+ * mergers (linear or Baeza Yates) described in Adam Lopez' dissertation.
+ */
+class FastIntersector {
+ public:
+ FastIntersector(shared_ptr<SuffixArray> suffix_array,
+ shared_ptr<Precomputation> precomputation,
+ shared_ptr<Vocabulary> vocabulary,
+ int max_rule_span,
+ int min_gap_size);
+
+ virtual ~FastIntersector();
+
+ // Finds the locations of a phrase given the locations of its prefix and
+ // suffix.
+ virtual PhraseLocation Intersect(PhraseLocation& prefix_location,
+ PhraseLocation& suffix_location,
+ const Phrase& phrase);
+
+ protected:
+ FastIntersector();
+
+ private:
+ // Uses the vocabulary to convert the phrase from the numberized format
+ // specified by the source data array to the numberized format given by the
+ // vocabulary.
+ vector<int> ConvertPhrase(const vector<int>& old_phrase);
+
+ // Estimates the number of computations needed if the prefix/suffix is
+ // extended. If the last/first symbol is separated from the rest of the phrase
+ // by a nonterminal, then for each occurrence of the prefix/suffix we need to
+ // check max_rule_span positions. Otherwise, we only need to check a single
+ // position for each occurrence.
+ int EstimateNumOperations(const PhraseLocation& phrase_location,
+ bool has_margin_x) const;
+
+ // Uses the occurrences of the prefix to find the occurrences of the phrase.
+ PhraseLocation ExtendPrefixPhraseLocation(PhraseLocation& prefix_location,
+ const Phrase& phrase,
+ bool prefix_ends_with_x,
+ int next_symbol) const;
+
+ // Uses the occurrences of the suffix to find the occurrences of the phrase.
+ PhraseLocation ExtendSuffixPhraseLocation(PhraseLocation& suffix_location,
+ const Phrase& phrase,
+ bool suffix_starts_with_x,
+ int prev_symbol) const;
+
+ // Extends the prefix/suffix location to a list of subpatterns positions if it
+ // represents a suffix array range.
+ void ExtendPhraseLocation(PhraseLocation& location) const;
+
+ // Returns the range in which the search should be performed.
+ pair<int, int> GetSearchRange(bool has_marginal_x) const;
+
+ shared_ptr<SuffixArray> suffix_array;
+ shared_ptr<Vocabulary> vocabulary;
+ int max_rule_span;
+ int min_gap_size;
+ Index collocations;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/fast_intersector_test.cc b/extractor/fast_intersector_test.cc
new file mode 100644
index 00000000..76c3aaea
--- /dev/null
+++ b/extractor/fast_intersector_test.cc
@@ -0,0 +1,146 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "fast_intersector.h"
+#include "mocks/mock_data_array.h"
+#include "mocks/mock_precomputation.h"
+#include "mocks/mock_suffix_array.h"
+#include "mocks/mock_vocabulary.h"
+#include "phrase.h"
+#include "phrase_location.h"
+#include "phrase_builder.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+class FastIntersectorTest : public Test {
+ protected:
+ virtual void SetUp() {
+ vector<string> words = {"EOL", "it", "makes", "him", "and", "mars", ",",
+ "sets", "on", "takes", "off", "."};
+ vocabulary = make_shared<MockVocabulary>();
+ for (size_t i = 0; i < words.size(); ++i) {
+ EXPECT_CALL(*vocabulary, GetTerminalIndex(words[i]))
+ .WillRepeatedly(Return(i));
+ EXPECT_CALL(*vocabulary, GetTerminalValue(i))
+ .WillRepeatedly(Return(words[i]));
+ }
+
+ vector<int> data = {1, 2, 3, 4, 1, 5, 3, 6, 1,
+ 7, 3, 8, 4, 1, 9, 3, 10, 11, 0};
+ data_array = make_shared<MockDataArray>();
+ for (size_t i = 0; i < data.size(); ++i) {
+ EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i]));
+ EXPECT_CALL(*data_array, GetSentenceId(i))
+ .WillRepeatedly(Return(0));
+ }
+ EXPECT_CALL(*data_array, GetSentenceStart(0))
+ .WillRepeatedly(Return(0));
+ EXPECT_CALL(*data_array, GetSentenceStart(1))
+ .WillRepeatedly(Return(19));
+ for (size_t i = 0; i < words.size(); ++i) {
+ EXPECT_CALL(*data_array, GetWordId(words[i]))
+ .WillRepeatedly(Return(i));
+ EXPECT_CALL(*data_array, GetWord(i))
+ .WillRepeatedly(Return(words[i]));
+ }
+
+ vector<int> suffixes = {18, 0, 4, 8, 13, 1, 2, 6, 10, 15, 3, 12, 5, 7, 9,
+ 11, 14, 16, 17};
+ suffix_array = make_shared<MockSuffixArray>();
+ EXPECT_CALL(*suffix_array, GetData()).WillRepeatedly(Return(data_array));
+ for (size_t i = 0; i < suffixes.size(); ++i) {
+ EXPECT_CALL(*suffix_array, GetSuffix(i)).
+ WillRepeatedly(Return(suffixes[i]));
+ }
+
+ precomputation = make_shared<MockPrecomputation>();
+ EXPECT_CALL(*precomputation, GetCollocations())
+ .WillRepeatedly(ReturnRef(collocations));
+
+ phrase_builder = make_shared<PhraseBuilder>(vocabulary);
+ intersector = make_shared<FastIntersector>(suffix_array, precomputation,
+ vocabulary, 15, 1);
+ }
+
+ Index collocations;
+ shared_ptr<MockDataArray> data_array;
+ shared_ptr<MockSuffixArray> suffix_array;
+ shared_ptr<MockPrecomputation> precomputation;
+ shared_ptr<MockVocabulary> vocabulary;
+ shared_ptr<FastIntersector> intersector;
+ shared_ptr<PhraseBuilder> phrase_builder;
+};
+
+TEST_F(FastIntersectorTest, TestCachedCollocation) {
+ vector<int> symbols = {8, -1, 9};
+ vector<int> expected_location = {11};
+ Phrase phrase = phrase_builder->Build(symbols);
+ PhraseLocation prefix_location(15, 16), suffix_location(16, 17);
+
+ collocations[symbols] = expected_location;
+ EXPECT_CALL(*precomputation, GetCollocations())
+ .WillRepeatedly(ReturnRef(collocations));
+ intersector = make_shared<FastIntersector>(suffix_array, precomputation,
+ vocabulary, 15, 1);
+
+ PhraseLocation result = intersector->Intersect(
+ prefix_location, suffix_location, phrase);
+
+ EXPECT_EQ(PhraseLocation(expected_location, 2), result);
+ EXPECT_EQ(PhraseLocation(15, 16), prefix_location);
+ EXPECT_EQ(PhraseLocation(16, 17), suffix_location);
+}
+
+TEST_F(FastIntersectorTest, TestIntersectaXbXcExtendSuffix) {
+ vector<int> symbols = {1, -1, 3, -1, 1};
+ Phrase phrase = phrase_builder->Build(symbols);
+ vector<int> prefix_locs = {0, 2, 0, 6, 0, 10, 4, 6, 4, 10, 4, 15, 8, 10,
+ 8, 15, 3, 15};
+ vector<int> suffix_locs = {2, 4, 2, 8, 2, 13, 6, 8, 6, 13, 10, 13};
+ PhraseLocation prefix_location(prefix_locs, 2);
+ PhraseLocation suffix_location(suffix_locs, 2);
+
+ vector<int> expected_locs = {0, 2, 4, 0, 2, 8, 0, 2, 13, 4, 6, 8, 0, 6, 8,
+ 4, 6, 13, 0, 6, 13, 8, 10, 13, 4, 10, 13,
+ 0, 10, 13};
+ PhraseLocation result = intersector->Intersect(
+ prefix_location, suffix_location, phrase);
+ EXPECT_EQ(PhraseLocation(expected_locs, 3), result);
+}
+
+TEST_F(FastIntersectorTest, TestIntersectaXbExtendPrefix) {
+ vector<int> symbols = {1, -1, 3};
+ Phrase phrase = phrase_builder->Build(symbols);
+ PhraseLocation prefix_location(1, 5), suffix_location(6, 10);
+
+ vector<int> expected_prefix_locs = {0, 4, 8, 13};
+ vector<int> expected_locs = {0, 2, 0, 6, 0, 10, 4, 6, 4, 10, 4, 15, 8, 10,
+ 8, 15, 13, 15};
+ PhraseLocation result = intersector->Intersect(
+ prefix_location, suffix_location, phrase);
+ EXPECT_EQ(PhraseLocation(expected_locs, 2), result);
+ EXPECT_EQ(PhraseLocation(expected_prefix_locs, 1), prefix_location);
+}
+
+TEST_F(FastIntersectorTest, TestIntersectCheckEstimates) {
+ // The suffix matches in fewer positions, but because it starts with an X
+ // it requires more operations and we prefer extending the prefix.
+ vector<int> symbols = {1, -1, 4, 1};
+ Phrase phrase = phrase_builder->Build(symbols);
+ vector<int> prefix_locs = {0, 3, 0, 12, 4, 12, 8, 12};
+ PhraseLocation prefix_location(prefix_locs, 2), suffix_location(10, 12);
+
+ vector<int> expected_locs = {0, 3, 0, 12, 4, 12, 8, 12};
+ PhraseLocation result = intersector->Intersect(
+ prefix_location, suffix_location, phrase);
+ EXPECT_EQ(PhraseLocation(expected_locs, 2), result);
+ EXPECT_EQ(PhraseLocation(10, 12), suffix_location);
+}
+
+} // namespace
+} // namespace extractor
diff --git a/extractor/features/count_source_target.cc b/extractor/features/count_source_target.cc
new file mode 100644
index 00000000..db0385e0
--- /dev/null
+++ b/extractor/features/count_source_target.cc
@@ -0,0 +1,17 @@
+#include "count_source_target.h"
+
+#include <cmath>
+
+namespace extractor {
+namespace features {
+
+double CountSourceTarget::Score(const FeatureContext& context) const {
+ return log10(1 + context.pair_count);
+}
+
+string CountSourceTarget::GetName() const {
+ return "CountEF";
+}
+
+} // namespace features
+} // namespace extractor
diff --git a/extractor/features/count_source_target.h b/extractor/features/count_source_target.h
new file mode 100644
index 00000000..8747fa60
--- /dev/null
+++ b/extractor/features/count_source_target.h
@@ -0,0 +1,22 @@
+#ifndef _COUNT_SOURCE_TARGET_H_
+#define _COUNT_SOURCE_TARGET_H_
+
+#include "feature.h"
+
+namespace extractor {
+namespace features {
+
+/**
+ * Feature for the number of times a word pair was found in the bitext.
+ */
+class CountSourceTarget : public Feature {
+ public:
+ double Score(const FeatureContext& context) const;
+
+ string GetName() const;
+};
+
+} // namespace features
+} // namespace extractor
+
+#endif
diff --git a/extractor/features/count_source_target_test.cc b/extractor/features/count_source_target_test.cc
new file mode 100644
index 00000000..1fd0c2aa
--- /dev/null
+++ b/extractor/features/count_source_target_test.cc
@@ -0,0 +1,36 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+
+#include "count_source_target.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace features {
+namespace {
+
+class CountSourceTargetTest : public Test {
+ protected:
+ virtual void SetUp() {
+ feature = make_shared<CountSourceTarget>();
+ }
+
+ shared_ptr<CountSourceTarget> feature;
+};
+
+TEST_F(CountSourceTargetTest, TestGetName) {
+ EXPECT_EQ("CountEF", feature->GetName());
+}
+
+TEST_F(CountSourceTargetTest, TestScore) {
+ Phrase phrase;
+ FeatureContext context(phrase, phrase, 0.5, 9, 13);
+ EXPECT_EQ(1.0, feature->Score(context));
+}
+
+} // namespace
+} // namespace features
+} // namespace extractor
diff --git a/extractor/features/feature.cc b/extractor/features/feature.cc
new file mode 100644
index 00000000..939bcc59
--- /dev/null
+++ b/extractor/features/feature.cc
@@ -0,0 +1,11 @@
+#include "feature.h"
+
+namespace extractor {
+namespace features {
+
+const double Feature::MAX_SCORE = 99.0;
+
+Feature::~Feature() {}
+
+} // namespace features
+} // namespace extractor
diff --git a/extractor/features/feature.h b/extractor/features/feature.h
new file mode 100644
index 00000000..36ea504a
--- /dev/null
+++ b/extractor/features/feature.h
@@ -0,0 +1,47 @@
+#ifndef _FEATURE_H_
+#define _FEATURE_H_
+
+#include <string>
+
+#include "phrase.h"
+
+using namespace std;
+
+namespace extractor {
+namespace features {
+
+/**
+ * Structure providing context for computing feature scores.
+ */
+struct FeatureContext {
+ FeatureContext(const Phrase& source_phrase, const Phrase& target_phrase,
+ double source_phrase_count, int pair_count, int num_samples) :
+ source_phrase(source_phrase), target_phrase(target_phrase),
+ source_phrase_count(source_phrase_count), pair_count(pair_count),
+ num_samples(num_samples) {}
+
+ Phrase source_phrase;
+ Phrase target_phrase;
+ double source_phrase_count;
+ int pair_count;
+ int num_samples;
+};
+
+/**
+ * Base class for features.
+ */
+class Feature {
+ public:
+ virtual double Score(const FeatureContext& context) const = 0;
+
+ virtual string GetName() const = 0;
+
+ virtual ~Feature();
+
+ static const double MAX_SCORE;
+};
+
+} // namespace features
+} // namespace extractor
+
+#endif
diff --git a/extractor/features/is_source_singleton.cc b/extractor/features/is_source_singleton.cc
new file mode 100644
index 00000000..1abb486f
--- /dev/null
+++ b/extractor/features/is_source_singleton.cc
@@ -0,0 +1,17 @@
+#include "is_source_singleton.h"
+
+#include <cmath>
+
+namespace extractor {
+namespace features {
+
+double IsSourceSingleton::Score(const FeatureContext& context) const {
+ return fabs(context.source_phrase_count - 1) < 1e-6;
+}
+
+string IsSourceSingleton::GetName() const {
+ return "IsSingletonF";
+}
+
+} // namespace features
+} // namespace extractor
diff --git a/extractor/features/is_source_singleton.h b/extractor/features/is_source_singleton.h
new file mode 100644
index 00000000..b8352d0e
--- /dev/null
+++ b/extractor/features/is_source_singleton.h
@@ -0,0 +1,22 @@
+#ifndef _IS_SOURCE_SINGLETON_H_
+#define _IS_SOURCE_SINGLETON_H_
+
+#include "feature.h"
+
+namespace extractor {
+namespace features {
+
+/**
+ * Boolean feature checking if the source phrase occurs only once in the data.
+ */
+class IsSourceSingleton : public Feature {
+ public:
+ double Score(const FeatureContext& context) const;
+
+ string GetName() const;
+};
+
+} // namespace features
+} // namespace extractor
+
+#endif
diff --git a/extractor/features/is_source_singleton_test.cc b/extractor/features/is_source_singleton_test.cc
new file mode 100644
index 00000000..f4266671
--- /dev/null
+++ b/extractor/features/is_source_singleton_test.cc
@@ -0,0 +1,39 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+
+#include "is_source_singleton.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace features {
+namespace {
+
+class IsSourceSingletonTest : public Test {
+ protected:
+ virtual void SetUp() {
+ feature = make_shared<IsSourceSingleton>();
+ }
+
+ shared_ptr<IsSourceSingleton> feature;
+};
+
+TEST_F(IsSourceSingletonTest, TestGetName) {
+ EXPECT_EQ("IsSingletonF", feature->GetName());
+}
+
+TEST_F(IsSourceSingletonTest, TestScore) {
+ Phrase phrase;
+ FeatureContext context(phrase, phrase, 0.5, 3, 31);
+ EXPECT_EQ(0, feature->Score(context));
+
+ context = FeatureContext(phrase, phrase, 1, 3, 25);
+ EXPECT_EQ(1, feature->Score(context));
+}
+
+} // namespace
+} // namespace features
+} // namespace extractor
diff --git a/extractor/features/is_source_target_singleton.cc b/extractor/features/is_source_target_singleton.cc
new file mode 100644
index 00000000..03b3c62c
--- /dev/null
+++ b/extractor/features/is_source_target_singleton.cc
@@ -0,0 +1,17 @@
+#include "is_source_target_singleton.h"
+
+#include <cmath>
+
+namespace extractor {
+namespace features {
+
+double IsSourceTargetSingleton::Score(const FeatureContext& context) const {
+ return context.pair_count == 1;
+}
+
+string IsSourceTargetSingleton::GetName() const {
+ return "IsSingletonFE";
+}
+
+} // namespace features
+} // namespace extractor
diff --git a/extractor/features/is_source_target_singleton.h b/extractor/features/is_source_target_singleton.h
new file mode 100644
index 00000000..dacfebba
--- /dev/null
+++ b/extractor/features/is_source_target_singleton.h
@@ -0,0 +1,22 @@
+#ifndef _IS_SOURCE_TARGET_SINGLETON_H_
+#define _IS_SOURCE_TARGET_SINGLETON_H_
+
+#include "feature.h"
+
+namespace extractor {
+namespace features {
+
+/**
+ * Boolean feature checking if the phrase pair occurs only once in the data.
+ */
+class IsSourceTargetSingleton : public Feature {
+ public:
+ double Score(const FeatureContext& context) const;
+
+ string GetName() const;
+};
+
+} // namespace features
+} // namespace extractor
+
+#endif
diff --git a/extractor/features/is_source_target_singleton_test.cc b/extractor/features/is_source_target_singleton_test.cc
new file mode 100644
index 00000000..929635b0
--- /dev/null
+++ b/extractor/features/is_source_target_singleton_test.cc
@@ -0,0 +1,39 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+
+#include "is_source_target_singleton.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace features {
+namespace {
+
+class IsSourceTargetSingletonTest : public Test {
+ protected:
+ virtual void SetUp() {
+ feature = make_shared<IsSourceTargetSingleton>();
+ }
+
+ shared_ptr<IsSourceTargetSingleton> feature;
+};
+
+TEST_F(IsSourceTargetSingletonTest, TestGetName) {
+ EXPECT_EQ("IsSingletonFE", feature->GetName());
+}
+
+TEST_F(IsSourceTargetSingletonTest, TestScore) {
+ Phrase phrase;
+ FeatureContext context(phrase, phrase, 0.5, 3, 7);
+ EXPECT_EQ(0, feature->Score(context));
+
+ context = FeatureContext(phrase, phrase, 2.3, 1, 28);
+ EXPECT_EQ(1, feature->Score(context));
+}
+
+} // namespace
+} // namespace features
+} // namespace extractor
diff --git a/extractor/features/max_lex_source_given_target.cc b/extractor/features/max_lex_source_given_target.cc
new file mode 100644
index 00000000..65d0ec68
--- /dev/null
+++ b/extractor/features/max_lex_source_given_target.cc
@@ -0,0 +1,37 @@
+#include "max_lex_source_given_target.h"
+
+#include <cmath>
+
+#include "data_array.h"
+#include "translation_table.h"
+
+namespace extractor {
+namespace features {
+
+MaxLexSourceGivenTarget::MaxLexSourceGivenTarget(
+ shared_ptr<TranslationTable> table) :
+ table(table) {}
+
+double MaxLexSourceGivenTarget::Score(const FeatureContext& context) const {
+ vector<string> source_words = context.source_phrase.GetWords();
+ vector<string> target_words = context.target_phrase.GetWords();
+ target_words.push_back(DataArray::NULL_WORD_STR);
+
+ double score = 0;
+ for (string source_word: source_words) {
+ double max_score = 0;
+ for (string target_word: target_words) {
+ max_score = max(max_score,
+ table->GetSourceGivenTargetScore(source_word, target_word));
+ }
+ score += max_score > 0 ? -log10(max_score) : MAX_SCORE;
+ }
+ return score;
+}
+
+string MaxLexSourceGivenTarget::GetName() const {
+ return "MaxLexFgivenE";
+}
+
+} // namespace features
+} // namespace extractor
diff --git a/extractor/features/max_lex_source_given_target.h b/extractor/features/max_lex_source_given_target.h
new file mode 100644
index 00000000..461b0ebf
--- /dev/null
+++ b/extractor/features/max_lex_source_given_target.h
@@ -0,0 +1,34 @@
+#ifndef _MAX_LEX_SOURCE_GIVEN_TARGET_H_
+#define _MAX_LEX_SOURCE_GIVEN_TARGET_H_
+
+#include <memory>
+
+#include "feature.h"
+
+using namespace std;
+
+namespace extractor {
+
+class TranslationTable;
+
+namespace features {
+
+/**
+ * Feature computing max(p(f | e)) across all pairs of words in the phrase pair.
+ */
+class MaxLexSourceGivenTarget : public Feature {
+ public:
+ MaxLexSourceGivenTarget(shared_ptr<TranslationTable> table);
+
+ double Score(const FeatureContext& context) const;
+
+ string GetName() const;
+
+ private:
+ shared_ptr<TranslationTable> table;
+};
+
+} // namespace features
+} // namespace extractor
+
+#endif
diff --git a/extractor/features/max_lex_source_given_target_test.cc b/extractor/features/max_lex_source_given_target_test.cc
new file mode 100644
index 00000000..7f6aae41
--- /dev/null
+++ b/extractor/features/max_lex_source_given_target_test.cc
@@ -0,0 +1,78 @@
+#include <gtest/gtest.h>
+
+#include <cmath>
+#include <memory>
+#include <string>
+
+#include "data_array.h"
+#include "mocks/mock_translation_table.h"
+#include "mocks/mock_vocabulary.h"
+#include "phrase_builder.h"
+#include "max_lex_source_given_target.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace features {
+namespace {
+
+class MaxLexSourceGivenTargetTest : public Test {
+ protected:
+ virtual void SetUp() {
+ vector<string> source_words = {"f1", "f2", "f3"};
+ vector<string> target_words = {"e1", "e2", "e3"};
+
+ vocabulary = make_shared<MockVocabulary>();
+ for (size_t i = 0; i < source_words.size(); ++i) {
+ EXPECT_CALL(*vocabulary, GetTerminalValue(i))
+ .WillRepeatedly(Return(source_words[i]));
+ }
+ for (size_t i = 0; i < target_words.size(); ++i) {
+ EXPECT_CALL(*vocabulary, GetTerminalValue(i + source_words.size()))
+ .WillRepeatedly(Return(target_words[i]));
+ }
+
+ phrase_builder = make_shared<PhraseBuilder>(vocabulary);
+
+ table = make_shared<MockTranslationTable>();
+ for (size_t i = 0; i < source_words.size(); ++i) {
+ for (size_t j = 0; j < target_words.size(); ++j) {
+ int value = i - j;
+ EXPECT_CALL(*table, GetSourceGivenTargetScore(
+ source_words[i], target_words[j])).WillRepeatedly(Return(value));
+ }
+ }
+
+ for (size_t i = 0; i < source_words.size(); ++i) {
+ int value = i * 3;
+ EXPECT_CALL(*table, GetSourceGivenTargetScore(
+ source_words[i], DataArray::NULL_WORD_STR))
+ .WillRepeatedly(Return(value));
+ }
+
+ feature = make_shared<MaxLexSourceGivenTarget>(table);
+ }
+
+ shared_ptr<MockVocabulary> vocabulary;
+ shared_ptr<PhraseBuilder> phrase_builder;
+ shared_ptr<MockTranslationTable> table;
+ shared_ptr<MaxLexSourceGivenTarget> feature;
+};
+
+TEST_F(MaxLexSourceGivenTargetTest, TestGetName) {
+ EXPECT_EQ("MaxLexFgivenE", feature->GetName());
+}
+
+TEST_F(MaxLexSourceGivenTargetTest, TestScore) {
+ vector<int> source_symbols = {0, 1, 2};
+ Phrase source_phrase = phrase_builder->Build(source_symbols);
+ vector<int> target_symbols = {3, 4, 5};
+ Phrase target_phrase = phrase_builder->Build(target_symbols);
+ FeatureContext context(source_phrase, target_phrase, 0.3, 7, 11);
+ EXPECT_EQ(99 - log10(18), feature->Score(context));
+}
+
+} // namespace
+} // namespace features
+} // namespace extractor
diff --git a/extractor/features/max_lex_target_given_source.cc b/extractor/features/max_lex_target_given_source.cc
new file mode 100644
index 00000000..33783054
--- /dev/null
+++ b/extractor/features/max_lex_target_given_source.cc
@@ -0,0 +1,37 @@
+#include "max_lex_target_given_source.h"
+
+#include <cmath>
+
+#include "data_array.h"
+#include "translation_table.h"
+
+namespace extractor {
+namespace features {
+
+MaxLexTargetGivenSource::MaxLexTargetGivenSource(
+ shared_ptr<TranslationTable> table) :
+ table(table) {}
+
+double MaxLexTargetGivenSource::Score(const FeatureContext& context) const {
+ vector<string> source_words = context.source_phrase.GetWords();
+ source_words.push_back(DataArray::NULL_WORD_STR);
+ vector<string> target_words = context.target_phrase.GetWords();
+
+ double score = 0;
+ for (string target_word: target_words) {
+ double max_score = 0;
+ for (string source_word: source_words) {
+ max_score = max(max_score,
+ table->GetTargetGivenSourceScore(source_word, target_word));
+ }
+ score += max_score > 0 ? -log10(max_score) : MAX_SCORE;
+ }
+ return score;
+}
+
+string MaxLexTargetGivenSource::GetName() const {
+ return "MaxLexEgivenF";
+}
+
+} // namespace features
+} // namespace extractor
diff --git a/extractor/features/max_lex_target_given_source.h b/extractor/features/max_lex_target_given_source.h
new file mode 100644
index 00000000..c3c87327
--- /dev/null
+++ b/extractor/features/max_lex_target_given_source.h
@@ -0,0 +1,34 @@
+#ifndef _MAX_LEX_TARGET_GIVEN_SOURCE_H_
+#define _MAX_LEX_TARGET_GIVEN_SOURCE_H_
+
+#include <memory>
+
+#include "feature.h"
+
+using namespace std;
+
+namespace extractor {
+
+class TranslationTable;
+
+namespace features {
+
+/**
+ * Feature computing max(p(e | f)) across all pairs of words in the phrase pair.
+ */
+class MaxLexTargetGivenSource : public Feature {
+ public:
+ MaxLexTargetGivenSource(shared_ptr<TranslationTable> table);
+
+ double Score(const FeatureContext& context) const;
+
+ string GetName() const;
+
+ private:
+ shared_ptr<TranslationTable> table;
+};
+
+} // namespace features
+} // namespace extractor
+
+#endif
diff --git a/extractor/features/max_lex_target_given_source_test.cc b/extractor/features/max_lex_target_given_source_test.cc
new file mode 100644
index 00000000..6d0efd9c
--- /dev/null
+++ b/extractor/features/max_lex_target_given_source_test.cc
@@ -0,0 +1,78 @@
+#include <gtest/gtest.h>
+
+#include <cmath>
+#include <memory>
+#include <string>
+
+#include "data_array.h"
+#include "mocks/mock_translation_table.h"
+#include "mocks/mock_vocabulary.h"
+#include "phrase_builder.h"
+#include "max_lex_target_given_source.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace features {
+namespace {
+
+class MaxLexTargetGivenSourceTest : public Test {
+ protected:
+ virtual void SetUp() {
+ vector<string> source_words = {"f1", "f2", "f3"};
+ vector<string> target_words = {"e1", "e2", "e3"};
+
+ vocabulary = make_shared<MockVocabulary>();
+ for (size_t i = 0; i < source_words.size(); ++i) {
+ EXPECT_CALL(*vocabulary, GetTerminalValue(i))
+ .WillRepeatedly(Return(source_words[i]));
+ }
+ for (size_t i = 0; i < target_words.size(); ++i) {
+ EXPECT_CALL(*vocabulary, GetTerminalValue(i + source_words.size()))
+ .WillRepeatedly(Return(target_words[i]));
+ }
+
+ phrase_builder = make_shared<PhraseBuilder>(vocabulary);
+
+ table = make_shared<MockTranslationTable>();
+ for (size_t i = 0; i < source_words.size(); ++i) {
+ for (size_t j = 0; j < target_words.size(); ++j) {
+ int value = i - j;
+ EXPECT_CALL(*table, GetTargetGivenSourceScore(
+ source_words[i], target_words[j])).WillRepeatedly(Return(value));
+ }
+ }
+
+ for (size_t i = 0; i < target_words.size(); ++i) {
+ int value = i * 3;
+ EXPECT_CALL(*table, GetTargetGivenSourceScore(
+ DataArray::NULL_WORD_STR, target_words[i]))
+ .WillRepeatedly(Return(value));
+ }
+
+ feature = make_shared<MaxLexTargetGivenSource>(table);
+ }
+
+ shared_ptr<MockVocabulary> vocabulary;
+ shared_ptr<PhraseBuilder> phrase_builder;
+ shared_ptr<MockTranslationTable> table;
+ shared_ptr<MaxLexTargetGivenSource> feature;
+};
+
+TEST_F(MaxLexTargetGivenSourceTest, TestGetName) {
+ EXPECT_EQ("MaxLexEgivenF", feature->GetName());
+}
+
+TEST_F(MaxLexTargetGivenSourceTest, TestScore) {
+ vector<int> source_symbols = {0, 1, 2};
+ Phrase source_phrase = phrase_builder->Build(source_symbols);
+ vector<int> target_symbols = {3, 4, 5};
+ Phrase target_phrase = phrase_builder->Build(target_symbols);
+ FeatureContext context(source_phrase, target_phrase, 0.3, 7, 19);
+ EXPECT_EQ(-log10(36), feature->Score(context));
+}
+
+} // namespace
+} // namespace features
+} // namespace extractor
diff --git a/extractor/features/sample_source_count.cc b/extractor/features/sample_source_count.cc
new file mode 100644
index 00000000..b110fc51
--- /dev/null
+++ b/extractor/features/sample_source_count.cc
@@ -0,0 +1,17 @@
+#include "sample_source_count.h"
+
+#include <cmath>
+
+namespace extractor {
+namespace features {
+
+double SampleSourceCount::Score(const FeatureContext& context) const {
+ return log10(1 + context.num_samples);
+}
+
+string SampleSourceCount::GetName() const {
+ return "SampleCountF";
+}
+
+} // namespace features
+} // namespace extractor
diff --git a/extractor/features/sample_source_count.h b/extractor/features/sample_source_count.h
new file mode 100644
index 00000000..ee6e59a0
--- /dev/null
+++ b/extractor/features/sample_source_count.h
@@ -0,0 +1,23 @@
+#ifndef _SAMPLE_SOURCE_COUNT_H_
+#define _SAMPLE_SOURCE_COUNT_H_
+
+#include "feature.h"
+
+namespace extractor {
+namespace features {
+
+/**
+ * Feature scoring the number of times the source phrase occurs in the sampled
+ * set.
+ */
+class SampleSourceCount : public Feature {
+ public:
+ double Score(const FeatureContext& context) const;
+
+ string GetName() const;
+};
+
+} // namespace features
+} // namespace extractor
+
+#endif
diff --git a/extractor/features/sample_source_count_test.cc b/extractor/features/sample_source_count_test.cc
new file mode 100644
index 00000000..63856b9d
--- /dev/null
+++ b/extractor/features/sample_source_count_test.cc
@@ -0,0 +1,40 @@
+#include <gtest/gtest.h>
+
+#include <cmath>
+#include <memory>
+#include <string>
+
+#include "sample_source_count.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace features {
+namespace {
+
+class SampleSourceCountTest : public Test {
+ protected:
+ virtual void SetUp() {
+ feature = make_shared<SampleSourceCount>();
+ }
+
+ shared_ptr<SampleSourceCount> feature;
+};
+
+TEST_F(SampleSourceCountTest, TestGetName) {
+ EXPECT_EQ("SampleCountF", feature->GetName());
+}
+
+TEST_F(SampleSourceCountTest, TestScore) {
+ Phrase phrase;
+ FeatureContext context(phrase, phrase, 0, 3, 1);
+ EXPECT_EQ(log10(2), feature->Score(context));
+
+ context = FeatureContext(phrase, phrase, 3.2, 3, 9);
+ EXPECT_EQ(1.0, feature->Score(context));
+}
+
+} // namespace
+} // namespace features
+} // namespace extractor
diff --git a/extractor/features/target_given_source_coherent.cc b/extractor/features/target_given_source_coherent.cc
new file mode 100644
index 00000000..c4551d88
--- /dev/null
+++ b/extractor/features/target_given_source_coherent.cc
@@ -0,0 +1,18 @@
+#include "target_given_source_coherent.h"
+
+#include <cmath>
+
+namespace extractor {
+namespace features {
+
+double TargetGivenSourceCoherent::Score(const FeatureContext& context) const {
+ double prob = (double) context.pair_count / context.num_samples;
+ return prob > 0 ? -log10(prob) : MAX_SCORE;
+}
+
+string TargetGivenSourceCoherent::GetName() const {
+ return "EgivenFCoherent";
+}
+
+} // namespace features
+} // namespace extractor
diff --git a/extractor/features/target_given_source_coherent.h b/extractor/features/target_given_source_coherent.h
new file mode 100644
index 00000000..e66d70a5
--- /dev/null
+++ b/extractor/features/target_given_source_coherent.h
@@ -0,0 +1,23 @@
+#ifndef _TARGET_GIVEN_SOURCE_COHERENT_H_
+#define _TARGET_GIVEN_SOURCE_COHERENT_H_
+
+#include "feature.h"
+
+namespace extractor {
+namespace features {
+
+/**
+ * Feature computing the ratio of the phrase pair count over all source phrase
+ * occurrences (sampled).
+ */
+class TargetGivenSourceCoherent : public Feature {
+ public:
+ double Score(const FeatureContext& context) const;
+
+ string GetName() const;
+};
+
+} // namespace features
+} // namespace extractor
+
+#endif
diff --git a/extractor/features/target_given_source_coherent_test.cc b/extractor/features/target_given_source_coherent_test.cc
new file mode 100644
index 00000000..454105e1
--- /dev/null
+++ b/extractor/features/target_given_source_coherent_test.cc
@@ -0,0 +1,39 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+
+#include "target_given_source_coherent.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace features {
+namespace {
+
+class TargetGivenSourceCoherentTest : public Test {
+ protected:
+ virtual void SetUp() {
+ feature = make_shared<TargetGivenSourceCoherent>();
+ }
+
+ shared_ptr<TargetGivenSourceCoherent> feature;
+};
+
+TEST_F(TargetGivenSourceCoherentTest, TestGetName) {
+ EXPECT_EQ("EgivenFCoherent", feature->GetName());
+}
+
+TEST_F(TargetGivenSourceCoherentTest, TestScore) {
+ Phrase phrase;
+ FeatureContext context(phrase, phrase, 0.3, 2, 20);
+ EXPECT_EQ(1.0, feature->Score(context));
+
+ context = FeatureContext(phrase, phrase, 1.9, 0, 1);
+ EXPECT_EQ(99.0, feature->Score(context));
+}
+
+} // namespace
+} // namespace features
+} // namespace extractor
diff --git a/extractor/grammar.cc b/extractor/grammar.cc
new file mode 100644
index 00000000..b45a8261
--- /dev/null
+++ b/extractor/grammar.cc
@@ -0,0 +1,43 @@
+#include "grammar.h"
+
+#include <iomanip>
+
+#include "rule.h"
+
+using namespace std;
+
+namespace extractor {
+
+Grammar::Grammar(const vector<Rule>& rules,
+ const vector<string>& feature_names) :
+ rules(rules), feature_names(feature_names) {}
+
+vector<Rule> Grammar::GetRules() const {
+ return rules;
+}
+
+vector<string> Grammar::GetFeatureNames() const {
+ return feature_names;
+}
+
+ostream& operator<<(ostream& os, const Grammar& grammar) {
+ vector<Rule> rules = grammar.GetRules();
+ vector<string> feature_names = grammar.GetFeatureNames();
+ os << setprecision(12);
+ for (Rule rule: rules) {
+ os << "[X] ||| " << rule.source_phrase << " ||| "
+ << rule.target_phrase << " |||";
+ for (size_t i = 0; i < rule.scores.size(); ++i) {
+ os << " " << feature_names[i] << "=" << rule.scores[i];
+ }
+ os << " |||";
+ for (auto link: rule.alignment) {
+ os << " " << link.first << "-" << link.second;
+ }
+ os << '\n';
+ }
+
+ return os;
+}
+
+} // namespace extractor
diff --git a/extractor/grammar.h b/extractor/grammar.h
new file mode 100644
index 00000000..fed41b16
--- /dev/null
+++ b/extractor/grammar.h
@@ -0,0 +1,34 @@
+#ifndef _GRAMMAR_H_
+#define _GRAMMAR_H_
+
+#include <iostream>
+#include <string>
+#include <vector>
+
+using namespace std;
+
+namespace extractor {
+
+class Rule;
+
+/**
+ * Grammar class wrapping the set of rules to be extracted.
+ */
+class Grammar {
+ public:
+ Grammar(const vector<Rule>& rules, const vector<string>& feature_names);
+
+ vector<Rule> GetRules() const;
+
+ vector<string> GetFeatureNames() const;
+
+ friend ostream& operator<<(ostream& os, const Grammar& grammar);
+
+ private:
+ vector<Rule> rules;
+ vector<string> feature_names;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc
new file mode 100644
index 00000000..8050ce7b
--- /dev/null
+++ b/extractor/grammar_extractor.cc
@@ -0,0 +1,62 @@
+#include "grammar_extractor.h"
+
+#include <iterator>
+#include <sstream>
+#include <vector>
+
+#include "grammar.h"
+#include "rule.h"
+#include "rule_factory.h"
+#include "vocabulary.h"
+
+using namespace std;
+
+namespace extractor {
+
+GrammarExtractor::GrammarExtractor(
+ shared_ptr<SuffixArray> source_suffix_array,
+ shared_ptr<DataArray> target_data_array,
+ shared_ptr<Alignment> alignment, shared_ptr<Precomputation> precomputation,
+ shared_ptr<Scorer> scorer, int min_gap_size, int max_rule_span,
+ int max_nonterminals, int max_rule_symbols, int max_samples,
+ bool require_tight_phrases) :
+ vocabulary(make_shared<Vocabulary>()),
+ rule_factory(make_shared<HieroCachingRuleFactory>(
+ source_suffix_array, target_data_array, alignment, vocabulary,
+ precomputation, scorer, min_gap_size, max_rule_span, max_nonterminals,
+ max_rule_symbols, max_samples, require_tight_phrases)) {}
+
+GrammarExtractor::GrammarExtractor(
+ shared_ptr<Vocabulary> vocabulary,
+ shared_ptr<HieroCachingRuleFactory> rule_factory) :
+ vocabulary(vocabulary),
+ rule_factory(rule_factory) {}
+
+Grammar GrammarExtractor::GetGrammar(const string& sentence) {
+ vector<string> words = TokenizeSentence(sentence);
+ vector<int> word_ids = AnnotateWords(words);
+ return rule_factory->GetGrammar(word_ids);
+}
+
+vector<string> GrammarExtractor::TokenizeSentence(const string& sentence) {
+ vector<string> result;
+ result.push_back("<s>");
+
+ istringstream buffer(sentence);
+ copy(istream_iterator<string>(buffer),
+ istream_iterator<string>(),
+ back_inserter(result));
+
+ result.push_back("</s>");
+ return result;
+}
+
+vector<int> GrammarExtractor::AnnotateWords(const vector<string>& words) {
+ vector<int> result;
+ for (string word: words) {
+ result.push_back(vocabulary->GetTerminalIndex(word));
+ }
+ return result;
+}
+
+} // namespace extractor
diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h
new file mode 100644
index 00000000..b36ceeb9
--- /dev/null
+++ b/extractor/grammar_extractor.h
@@ -0,0 +1,62 @@
+#ifndef _GRAMMAR_EXTRACTOR_H_
+#define _GRAMMAR_EXTRACTOR_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+using namespace std;
+
+namespace extractor {
+
+class Alignment;
+class DataArray;
+class Grammar;
+class HieroCachingRuleFactory;
+class Precomputation;
+class Rule;
+class Scorer;
+class SuffixArray;
+class Vocabulary;
+
+/**
+ * Class wrapping all the logic for extracting the synchronous context free
+ * grammars.
+ */
+class GrammarExtractor {
+ public:
+ GrammarExtractor(
+ shared_ptr<SuffixArray> source_suffix_array,
+ shared_ptr<DataArray> target_data_array,
+ shared_ptr<Alignment> alignment,
+ shared_ptr<Precomputation> precomputation,
+ shared_ptr<Scorer> scorer,
+ int min_gap_size,
+ int max_rule_span,
+ int max_nonterminals,
+ int max_rule_symbols,
+ int max_samples,
+ bool require_tight_phrases);
+
+ // For testing only.
+ GrammarExtractor(shared_ptr<Vocabulary> vocabulary,
+ shared_ptr<HieroCachingRuleFactory> rule_factory);
+
+ // Converts the sentence to a vector of word ids and uses the RuleFactory to
+ // extract the SCFG rules which may be used to decode the sentence.
+ Grammar GetGrammar(const string& sentence);
+
+ private:
+ // Splits the sentence in a vector of words.
+ vector<string> TokenizeSentence(const string& sentence);
+
+ // Maps the words to word ids.
+ vector<int> AnnotateWords(const vector<string>& words);
+
+ shared_ptr<Vocabulary> vocabulary;
+ shared_ptr<HieroCachingRuleFactory> rule_factory;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/grammar_extractor_test.cc b/extractor/grammar_extractor_test.cc
new file mode 100644
index 00000000..823bb8b4
--- /dev/null
+++ b/extractor/grammar_extractor_test.cc
@@ -0,0 +1,51 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "grammar.h"
+#include "grammar_extractor.h"
+#include "mocks/mock_rule_factory.h"
+#include "mocks/mock_vocabulary.h"
+#include "rule.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+TEST(GrammarExtractorTest, TestAnnotatingWords) {
+ shared_ptr<MockVocabulary> vocabulary = make_shared<MockVocabulary>();
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("<s>"))
+ .WillRepeatedly(Return(0));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("Anna"))
+ .WillRepeatedly(Return(1));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("has"))
+ .WillRepeatedly(Return(2));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("many"))
+ .WillRepeatedly(Return(3));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("apples"))
+ .WillRepeatedly(Return(4));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("."))
+ .WillRepeatedly(Return(5));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("</s>"))
+ .WillRepeatedly(Return(6));
+
+ shared_ptr<MockHieroCachingRuleFactory> factory =
+ make_shared<MockHieroCachingRuleFactory>();
+ vector<int> word_ids = {0, 1, 2, 3, 3, 4, 5, 6};
+ vector<Rule> rules;
+ vector<string> feature_names;
+ Grammar grammar(rules, feature_names);
+ EXPECT_CALL(*factory, GetGrammar(word_ids))
+ .WillOnce(Return(grammar));
+
+ GrammarExtractor extractor(vocabulary, factory);
+ string sentence = "Anna has many many apples .";
+ extractor.GetGrammar(sentence);
+}
+
+} // namespace
+} // namespace extractor
diff --git a/extractor/matchings_finder.cc b/extractor/matchings_finder.cc
new file mode 100644
index 00000000..ceed6891
--- /dev/null
+++ b/extractor/matchings_finder.cc
@@ -0,0 +1,25 @@
+#include "matchings_finder.h"
+
+#include "suffix_array.h"
+#include "phrase_location.h"
+
+namespace extractor {
+
+MatchingsFinder::MatchingsFinder(shared_ptr<SuffixArray> suffix_array) :
+ suffix_array(suffix_array) {}
+
+MatchingsFinder::MatchingsFinder() {}
+
+MatchingsFinder::~MatchingsFinder() {}
+
+PhraseLocation MatchingsFinder::Find(PhraseLocation& location,
+ const string& word, int offset) {
+ if (location.sa_low == -1 && location.sa_high == -1) {
+ location.sa_low = 0;
+ location.sa_high = suffix_array->GetSize();
+ }
+
+ return suffix_array->Lookup(location.sa_low, location.sa_high, word, offset);
+}
+
+} // namespace extractor
diff --git a/extractor/matchings_finder.h b/extractor/matchings_finder.h
new file mode 100644
index 00000000..451f4a4c
--- /dev/null
+++ b/extractor/matchings_finder.h
@@ -0,0 +1,37 @@
+#ifndef _MATCHINGS_FINDER_H_
+#define _MATCHINGS_FINDER_H_
+
+#include <memory>
+#include <string>
+
+using namespace std;
+
+namespace extractor {
+
+class PhraseLocation;
+class SuffixArray;
+
+/**
+ * Class wrapping the suffix array lookup for a contiguous phrase.
+ */
+class MatchingsFinder {
+ public:
+ MatchingsFinder(shared_ptr<SuffixArray> suffix_array);
+
+ virtual ~MatchingsFinder();
+
+ // Uses the suffix array to search only for the last word of the phrase
+ // starting from the range in which the prefix of the phrase occurs.
+ virtual PhraseLocation Find(PhraseLocation& location, const string& word,
+ int offset);
+
+ protected:
+ MatchingsFinder();
+
+ private:
+ shared_ptr<SuffixArray> suffix_array;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/matchings_finder_test.cc b/extractor/matchings_finder_test.cc
new file mode 100644
index 00000000..d40e5191
--- /dev/null
+++ b/extractor/matchings_finder_test.cc
@@ -0,0 +1,44 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "matchings_finder.h"
+#include "mocks/mock_suffix_array.h"
+#include "phrase_location.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+class MatchingsFinderTest : public Test {
+ protected:
+ virtual void SetUp() {
+ suffix_array = make_shared<MockSuffixArray>();
+ EXPECT_CALL(*suffix_array, Lookup(0, 10, _, _))
+ .Times(1)
+ .WillOnce(Return(PhraseLocation(3, 5)));
+
+ matchings_finder = make_shared<MatchingsFinder>(suffix_array);
+ }
+
+ shared_ptr<MatchingsFinder> matchings_finder;
+ shared_ptr<MockSuffixArray> suffix_array;
+};
+
+TEST_F(MatchingsFinderTest, TestFind) {
+ PhraseLocation phrase_location(0, 10), expected_result(3, 5);
+ EXPECT_EQ(expected_result, matchings_finder->Find(phrase_location, "bla", 2));
+}
+
+TEST_F(MatchingsFinderTest, ResizeUnsetRange) {
+ EXPECT_CALL(*suffix_array, GetSize()).Times(1).WillOnce(Return(10));
+
+ PhraseLocation phrase_location, expected_result(3, 5);
+ EXPECT_EQ(expected_result, matchings_finder->Find(phrase_location, "bla", 2));
+ EXPECT_EQ(PhraseLocation(0, 10), phrase_location);
+}
+
+} // namespace
+} // namespace extractor
diff --git a/extractor/matchings_trie.cc b/extractor/matchings_trie.cc
new file mode 100644
index 00000000..7fb7a529
--- /dev/null
+++ b/extractor/matchings_trie.cc
@@ -0,0 +1,29 @@
+#include "matchings_trie.h"
+
+namespace extractor {
+
+MatchingsTrie::MatchingsTrie() {
+ root = make_shared<TrieNode>();
+}
+
+MatchingsTrie::~MatchingsTrie() {
+ DeleteTree(root);
+}
+
+shared_ptr<TrieNode> MatchingsTrie::GetRoot() const {
+ return root;
+}
+
+void MatchingsTrie::DeleteTree(shared_ptr<TrieNode> root) {
+ if (root != NULL) {
+ for (auto child: root->children) {
+ DeleteTree(child.second);
+ }
+ if (root->suffix_link != NULL) {
+ root->suffix_link.reset();
+ }
+ root.reset();
+ }
+}
+
+} // namespace extractor
diff --git a/extractor/matchings_trie.h b/extractor/matchings_trie.h
new file mode 100644
index 00000000..1fb29693
--- /dev/null
+++ b/extractor/matchings_trie.h
@@ -0,0 +1,66 @@
+#ifndef _MATCHINGS_TRIE_
+#define _MATCHINGS_TRIE_
+
+#include <memory>
+#include <unordered_map>
+
+#include "phrase.h"
+#include "phrase_location.h"
+
+using namespace std;
+
+namespace extractor {
+
+/**
+ * Trie node containing all the occurrences of the corresponding phrase in the
+ * source data.
+ */
+struct TrieNode {
+ TrieNode(shared_ptr<TrieNode> suffix_link = shared_ptr<TrieNode>(),
+ Phrase phrase = Phrase(),
+ PhraseLocation matchings = PhraseLocation()) :
+ suffix_link(suffix_link), phrase(phrase), matchings(matchings) {}
+
+ // Adds a trie node as a child of the current node.
+ void AddChild(int key, shared_ptr<TrieNode> child_node) {
+ children[key] = child_node;
+ }
+
+ // Checks if a child exists for a given key.
+ bool HasChild(int key) {
+ return children.count(key);
+ }
+
+ // Gets the child corresponding to the given key.
+ shared_ptr<TrieNode> GetChild(int key) {
+ return children[key];
+ }
+
+ shared_ptr<TrieNode> suffix_link;
+ Phrase phrase;
+ PhraseLocation matchings;
+ unordered_map<int, shared_ptr<TrieNode> > children;
+};
+
+/**
+ * Trie containing all the phrases that can be obtained from a sentence.
+ */
+class MatchingsTrie {
+ public:
+ MatchingsTrie();
+
+ virtual ~MatchingsTrie();
+
+ // Returns the root of the trie.
+ shared_ptr<TrieNode> GetRoot() const;
+
+ private:
+ // Recursively deletes a subtree of the trie.
+ void DeleteTree(shared_ptr<TrieNode> root);
+
+ shared_ptr<TrieNode> root;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/mocks/mock_alignment.h b/extractor/mocks/mock_alignment.h
new file mode 100644
index 00000000..299c3d1c
--- /dev/null
+++ b/extractor/mocks/mock_alignment.h
@@ -0,0 +1,14 @@
+#include <gmock/gmock.h>
+
+#include "alignment.h"
+
+namespace extractor {
+
+typedef vector<pair<int, int> > SentenceLinks;
+
+class MockAlignment : public Alignment {
+ public:
+ MOCK_CONST_METHOD1(GetLinks, SentenceLinks(int sentence_id));
+};
+
+} // namespace extractor
diff --git a/extractor/mocks/mock_data_array.h b/extractor/mocks/mock_data_array.h
new file mode 100644
index 00000000..6f85abb4
--- /dev/null
+++ b/extractor/mocks/mock_data_array.h
@@ -0,0 +1,23 @@
+#include <gmock/gmock.h>
+
+#include "data_array.h"
+
+namespace extractor {
+
+class MockDataArray : public DataArray {
+ public:
+ MOCK_CONST_METHOD0(GetData, const vector<int>&());
+ MOCK_CONST_METHOD1(AtIndex, int(int index));
+ MOCK_CONST_METHOD1(GetWordAtIndex, string(int index));
+ MOCK_CONST_METHOD0(GetSize, int());
+ MOCK_CONST_METHOD0(GetVocabularySize, int());
+ MOCK_CONST_METHOD1(HasWord, bool(const string& word));
+ MOCK_CONST_METHOD1(GetWordId, int(const string& word));
+ MOCK_CONST_METHOD1(GetWord, string(int word_id));
+ MOCK_CONST_METHOD1(GetSentenceLength, int(int sentence_id));
+ MOCK_CONST_METHOD0(GetNumSentences, int());
+ MOCK_CONST_METHOD1(GetSentenceStart, int(int sentence_id));
+ MOCK_CONST_METHOD1(GetSentenceId, int(int position));
+};
+
+} // namespace extractor
diff --git a/extractor/mocks/mock_fast_intersector.h b/extractor/mocks/mock_fast_intersector.h
new file mode 100644
index 00000000..f0b628d7
--- /dev/null
+++ b/extractor/mocks/mock_fast_intersector.h
@@ -0,0 +1,15 @@
+#include <gmock/gmock.h>
+
+#include "fast_intersector.h"
+#include "phrase.h"
+#include "phrase_location.h"
+
+namespace extractor {
+
+class MockFastIntersector : public FastIntersector {
+ public:
+ MOCK_METHOD3(Intersect, PhraseLocation(PhraseLocation&, PhraseLocation&,
+ const Phrase&));
+};
+
+} // namespace extractor
diff --git a/extractor/mocks/mock_feature.h b/extractor/mocks/mock_feature.h
new file mode 100644
index 00000000..0b0f0ead
--- /dev/null
+++ b/extractor/mocks/mock_feature.h
@@ -0,0 +1,15 @@
+#include <gmock/gmock.h>
+
+#include "features/feature.h"
+
+namespace extractor {
+namespace features {
+
+class MockFeature : public Feature {
+ public:
+ MOCK_CONST_METHOD1(Score, double(const FeatureContext& context));
+ MOCK_CONST_METHOD0(GetName, string());
+};
+
+} // namespace features
+} // namespace extractor
diff --git a/extractor/mocks/mock_matchings_finder.h b/extractor/mocks/mock_matchings_finder.h
new file mode 100644
index 00000000..827526fd
--- /dev/null
+++ b/extractor/mocks/mock_matchings_finder.h
@@ -0,0 +1,13 @@
+#include <gmock/gmock.h>
+
+#include "matchings_finder.h"
+#include "phrase_location.h"
+
+namespace extractor {
+
+class MockMatchingsFinder : public MatchingsFinder {
+ public:
+ MOCK_METHOD3(Find, PhraseLocation(PhraseLocation&, const string&, int));
+};
+
+} // namespace extractor
diff --git a/extractor/mocks/mock_precomputation.h b/extractor/mocks/mock_precomputation.h
new file mode 100644
index 00000000..8753343e
--- /dev/null
+++ b/extractor/mocks/mock_precomputation.h
@@ -0,0 +1,12 @@
+#include <gmock/gmock.h>
+
+#include "precomputation.h"
+
+namespace extractor {
+
+class MockPrecomputation : public Precomputation {
+ public:
+ MOCK_CONST_METHOD0(GetCollocations, const Index&());
+};
+
+} // namespace extractor
diff --git a/extractor/mocks/mock_rule_extractor.h b/extractor/mocks/mock_rule_extractor.h
new file mode 100644
index 00000000..aad11651
--- /dev/null
+++ b/extractor/mocks/mock_rule_extractor.h
@@ -0,0 +1,16 @@
+#include <gmock/gmock.h>
+
+#include "phrase.h"
+#include "phrase_builder.h"
+#include "rule.h"
+#include "rule_extractor.h"
+
+namespace extractor {
+
+class MockRuleExtractor : public RuleExtractor {
+ public:
+ MOCK_CONST_METHOD2(ExtractRules, vector<Rule>(const Phrase&,
+ const PhraseLocation&));
+};
+
+} // namespace extractor
diff --git a/extractor/mocks/mock_rule_extractor_helper.h b/extractor/mocks/mock_rule_extractor_helper.h
new file mode 100644
index 00000000..468468f6
--- /dev/null
+++ b/extractor/mocks/mock_rule_extractor_helper.h
@@ -0,0 +1,82 @@
+#include <gmock/gmock.h>
+
+#include <vector>
+
+#include "rule_extractor_helper.h"
+
+using namespace std;
+
+namespace extractor {
+
+typedef unordered_map<int, int> Indexes;
+
+class MockRuleExtractorHelper : public RuleExtractorHelper {
+ public:
+ MOCK_CONST_METHOD5(GetLinksSpans, void(vector<int>&, vector<int>&,
+ vector<int>&, vector<int>&, int));
+ MOCK_CONST_METHOD4(CheckAlignedTerminals, bool(const vector<int>&,
+ const vector<int>&, const vector<int>&, int));
+ MOCK_CONST_METHOD4(CheckTightPhrases, bool(const vector<int>&,
+ const vector<int>&, const vector<int>&, int));
+ MOCK_CONST_METHOD1(GetGapOrder, vector<int>(const vector<pair<int, int> >&));
+ MOCK_CONST_METHOD4(GetSourceIndexes, Indexes(const vector<int>&,
+ const vector<int>&, int, int));
+
+ // We need to implement these methods, because Google Mock doesn't support
+ // methods with more than 10 arguments.
+ bool FindFixPoint(
+ int, int, const vector<int>&, const vector<int>&, int& target_phrase_low,
+ int& target_phrase_high, const vector<int>&, const vector<int>&,
+ int& source_back_low, int& source_back_high, int, int, int, int, bool,
+ bool, bool) const {
+ target_phrase_low = this->target_phrase_low;
+ target_phrase_high = this->target_phrase_high;
+ source_back_low = this->source_back_low;
+ source_back_high = this->source_back_high;
+ return find_fix_point;
+ }
+
+ bool GetGaps(vector<pair<int, int> >& source_gaps,
+ vector<pair<int, int> >& target_gaps,
+ const vector<int>&, const vector<int>&, const vector<int>&,
+ const vector<int>&, const vector<int>&, const vector<int>&,
+ int, int, int, int, int, int, int& num_symbols,
+ bool& met_constraints) const {
+ source_gaps = this->source_gaps;
+ target_gaps = this->target_gaps;
+ num_symbols = this->num_symbols;
+ met_constraints = this->met_constraints;
+ return get_gaps;
+ }
+
+ void SetUp(
+ int target_phrase_low, int target_phrase_high, int source_back_low,
+ int source_back_high, bool find_fix_point,
+ vector<pair<int, int> > source_gaps, vector<pair<int, int> > target_gaps,
+ int num_symbols, bool met_constraints, bool get_gaps) {
+ this->target_phrase_low = target_phrase_low;
+ this->target_phrase_high = target_phrase_high;
+ this->source_back_low = source_back_low;
+ this->source_back_high = source_back_high;
+ this->find_fix_point = find_fix_point;
+ this->source_gaps = source_gaps;
+ this->target_gaps = target_gaps;
+ this->num_symbols = num_symbols;
+ this->met_constraints = met_constraints;
+ this->get_gaps = get_gaps;
+ }
+
+ private:
+ int target_phrase_low;
+ int target_phrase_high;
+ int source_back_low;
+ int source_back_high;
+ bool find_fix_point;
+ vector<pair<int, int> > source_gaps;
+ vector<pair<int, int> > target_gaps;
+ int num_symbols;
+ bool met_constraints;
+ bool get_gaps;
+};
+
+} // namespace extractor
diff --git a/extractor/mocks/mock_rule_factory.h b/extractor/mocks/mock_rule_factory.h
new file mode 100644
index 00000000..7389b396
--- /dev/null
+++ b/extractor/mocks/mock_rule_factory.h
@@ -0,0 +1,13 @@
+#include <gmock/gmock.h>
+
+#include "grammar.h"
+#include "rule_factory.h"
+
+namespace extractor {
+
+class MockHieroCachingRuleFactory : public HieroCachingRuleFactory {
+ public:
+ MOCK_METHOD1(GetGrammar, Grammar(const vector<int>& word_ids));
+};
+
+} // namespace extractor
diff --git a/extractor/mocks/mock_sampler.h b/extractor/mocks/mock_sampler.h
new file mode 100644
index 00000000..75c43c27
--- /dev/null
+++ b/extractor/mocks/mock_sampler.h
@@ -0,0 +1,13 @@
+#include <gmock/gmock.h>
+
+#include "phrase_location.h"
+#include "sampler.h"
+
+namespace extractor {
+
+class MockSampler : public Sampler {
+ public:
+ MOCK_CONST_METHOD1(Sample, PhraseLocation(const PhraseLocation& location));
+};
+
+} // namespace extractor
diff --git a/extractor/mocks/mock_scorer.h b/extractor/mocks/mock_scorer.h
new file mode 100644
index 00000000..cc0c444d
--- /dev/null
+++ b/extractor/mocks/mock_scorer.h
@@ -0,0 +1,15 @@
+#include <gmock/gmock.h>
+
+#include "scorer.h"
+#include "features/feature.h"
+
+namespace extractor {
+
+class MockScorer : public Scorer {
+ public:
+ MOCK_CONST_METHOD1(Score, vector<double>(
+ const features::FeatureContext& context));
+ MOCK_CONST_METHOD0(GetFeatureNames, vector<string>());
+};
+
+} // namespace extractor
diff --git a/extractor/mocks/mock_suffix_array.h b/extractor/mocks/mock_suffix_array.h
new file mode 100644
index 00000000..7018acc7
--- /dev/null
+++ b/extractor/mocks/mock_suffix_array.h
@@ -0,0 +1,23 @@
+#include <gmock/gmock.h>
+
+#include <memory>
+#include <string>
+
+#include "data_array.h"
+#include "phrase_location.h"
+#include "suffix_array.h"
+
+using namespace std;
+
+namespace extractor {
+
+class MockSuffixArray : public SuffixArray {
+ public:
+ MOCK_CONST_METHOD0(GetSize, int());
+ MOCK_CONST_METHOD0(GetData, shared_ptr<DataArray>());
+ MOCK_CONST_METHOD0(BuildLCPArray, vector<int>());
+ MOCK_CONST_METHOD1(GetSuffix, int(int));
+ MOCK_CONST_METHOD4(Lookup, PhraseLocation(int, int, const string& word, int));
+};
+
+} // namespace extractor
diff --git a/extractor/mocks/mock_target_phrase_extractor.h b/extractor/mocks/mock_target_phrase_extractor.h
new file mode 100644
index 00000000..6aad853c
--- /dev/null
+++ b/extractor/mocks/mock_target_phrase_extractor.h
@@ -0,0 +1,16 @@
+#include <gmock/gmock.h>
+
+#include "target_phrase_extractor.h"
+
+namespace extractor {
+
+typedef pair<Phrase, PhraseAlignment> PhraseExtract;
+
+class MockTargetPhraseExtractor : public TargetPhraseExtractor {
+ public:
+ MOCK_CONST_METHOD6(ExtractPhrases, vector<PhraseExtract>(
+ const vector<pair<int, int> > &, const vector<int>&, int, int,
+ const unordered_map<int, int>&, int));
+};
+
+} // namespace extractor
diff --git a/extractor/mocks/mock_translation_table.h b/extractor/mocks/mock_translation_table.h
new file mode 100644
index 00000000..307e4282
--- /dev/null
+++ b/extractor/mocks/mock_translation_table.h
@@ -0,0 +1,13 @@
+#include <gmock/gmock.h>
+
+#include "translation_table.h"
+
+namespace extractor {
+
+class MockTranslationTable : public TranslationTable {
+ public:
+ MOCK_METHOD2(GetSourceGivenTargetScore, double(const string&, const string&));
+ MOCK_METHOD2(GetTargetGivenSourceScore, double(const string&, const string&));
+};
+
+} // namespace extractor
diff --git a/extractor/mocks/mock_vocabulary.h b/extractor/mocks/mock_vocabulary.h
new file mode 100644
index 00000000..042c9ce2
--- /dev/null
+++ b/extractor/mocks/mock_vocabulary.h
@@ -0,0 +1,13 @@
+#include <gmock/gmock.h>
+
+#include "vocabulary.h"
+
+namespace extractor {
+
+class MockVocabulary : public Vocabulary {
+ public:
+ MOCK_METHOD1(GetTerminalValue, string(int word_id));
+ MOCK_METHOD1(GetTerminalIndex, int(const string& word));
+};
+
+} // namespace extractor
diff --git a/extractor/phrase.cc b/extractor/phrase.cc
new file mode 100644
index 00000000..e619bfe5
--- /dev/null
+++ b/extractor/phrase.cc
@@ -0,0 +1,58 @@
+#include "phrase.h"
+
+namespace extractor {
+
+int Phrase::Arity() const {
+ return var_pos.size();
+}
+
+int Phrase::GetChunkLen(int index) const {
+ if (var_pos.size() == 0) {
+ return symbols.size();
+ } else if (index == 0) {
+ return var_pos[0];
+ } else if (index == var_pos.size()) {
+ return symbols.size() - var_pos.back() - 1;
+ } else {
+ return var_pos[index] - var_pos[index - 1] - 1;
+ }
+}
+
+vector<int> Phrase::Get() const {
+ return symbols;
+}
+
+int Phrase::GetSymbol(int position) const {
+ return symbols[position];
+}
+
+int Phrase::GetNumSymbols() const {
+ return symbols.size();
+}
+
+vector<string> Phrase::GetWords() const {
+ return words;
+}
+
+bool Phrase::operator<(const Phrase& other) const {
+ return symbols < other.symbols;
+}
+
+ostream& operator<<(ostream& os, const Phrase& phrase) {
+ int current_word = 0;
+ for (size_t i = 0; i < phrase.symbols.size(); ++i) {
+ if (phrase.symbols[i] < 0) {
+ os << "[X," << -phrase.symbols[i] << "]";
+ } else {
+ os << phrase.words[current_word];
+ ++current_word;
+ }
+
+ if (i + 1 < phrase.symbols.size()) {
+ os << " ";
+ }
+ }
+ return os;
+}
+
+} // namspace extractor
diff --git a/extractor/phrase.h b/extractor/phrase.h
new file mode 100644
index 00000000..a8e91e3c
--- /dev/null
+++ b/extractor/phrase.h
@@ -0,0 +1,52 @@
+#ifndef _PHRASE_H_
+#define _PHRASE_H_
+
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "phrase_builder.h"
+
+using namespace std;
+
+namespace extractor {
+
+/**
+ * Structure containing the data for a phrase.
+ */
+class Phrase {
+ public:
+ friend Phrase PhraseBuilder::Build(const vector<int>& phrase);
+
+ // Returns the number of nonterminals in the phrase.
+ int Arity() const;
+
+ // Returns the number of terminals (length) for the given chunk. (A chunk is a
+ // contiguous sequence of terminals in the phrase).
+ int GetChunkLen(int index) const;
+
+ // Returns the symbols (word ids) marking up the phrase.
+ vector<int> Get() const;
+
+ // Returns the symbol located at the given position in the phrase.
+ int GetSymbol(int position) const;
+
+ // Returns the number of symbols in the phrase.
+ int GetNumSymbols() const;
+
+ // Returns the words making up the phrase. (Nonterminals are stripped out.)
+ vector<string> GetWords() const;
+
+ bool operator<(const Phrase& other) const;
+
+ friend ostream& operator<<(ostream& os, const Phrase& phrase);
+
+ private:
+ vector<int> symbols;
+ vector<int> var_pos;
+ vector<string> words;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/phrase_builder.cc b/extractor/phrase_builder.cc
new file mode 100644
index 00000000..9faee4be
--- /dev/null
+++ b/extractor/phrase_builder.cc
@@ -0,0 +1,48 @@
+#include "phrase_builder.h"
+
+#include "phrase.h"
+#include "vocabulary.h"
+
+namespace extractor {
+
+PhraseBuilder::PhraseBuilder(shared_ptr<Vocabulary> vocabulary) :
+ vocabulary(vocabulary) {}
+
+Phrase PhraseBuilder::Build(const vector<int>& symbols) {
+ Phrase phrase;
+ phrase.symbols = symbols;
+ for (size_t i = 0; i < symbols.size(); ++i) {
+ if (vocabulary->IsTerminal(symbols[i])) {
+ phrase.words.push_back(vocabulary->GetTerminalValue(symbols[i]));
+ } else {
+ phrase.var_pos.push_back(i);
+ }
+ }
+ return phrase;
+}
+
+Phrase PhraseBuilder::Extend(const Phrase& phrase, bool start_x, bool end_x) {
+ vector<int> symbols = phrase.Get();
+ int num_nonterminals = 0;
+ if (start_x) {
+ num_nonterminals = 1;
+ symbols.insert(symbols.begin(),
+ vocabulary->GetNonterminalIndex(num_nonterminals));
+ }
+
+ for (size_t i = start_x; i < symbols.size(); ++i) {
+ if (!vocabulary->IsTerminal(symbols[i])) {
+ ++num_nonterminals;
+ symbols[i] = vocabulary->GetNonterminalIndex(num_nonterminals);
+ }
+ }
+
+ if (end_x) {
+ ++num_nonterminals;
+ symbols.push_back(vocabulary->GetNonterminalIndex(num_nonterminals));
+ }
+
+ return Build(symbols);
+}
+
+} // namespace extractor
diff --git a/extractor/phrase_builder.h b/extractor/phrase_builder.h
new file mode 100644
index 00000000..de86dbae
--- /dev/null
+++ b/extractor/phrase_builder.h
@@ -0,0 +1,33 @@
+#ifndef _PHRASE_BUILDER_H_
+#define _PHRASE_BUILDER_H_
+
+#include <memory>
+#include <vector>
+
+using namespace std;
+
+namespace extractor {
+
+class Phrase;
+class Vocabulary;
+
+/**
+ * Component for constructing phrases.
+ */
+class PhraseBuilder {
+ public:
+ PhraseBuilder(shared_ptr<Vocabulary> vocabulary);
+
+ // Constructs a phrase starting from an array of symbols.
+ Phrase Build(const vector<int>& symbols);
+
+ // Extends a phrase with a leading and/or trailing nonterminal.
+ Phrase Extend(const Phrase& phrase, bool start_x, bool end_x);
+
+ private:
+ shared_ptr<Vocabulary> vocabulary;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/phrase_location.cc b/extractor/phrase_location.cc
new file mode 100644
index 00000000..678ae270
--- /dev/null
+++ b/extractor/phrase_location.cc
@@ -0,0 +1,43 @@
+#include "phrase_location.h"
+
+namespace extractor {
+
+PhraseLocation::PhraseLocation(int sa_low, int sa_high) :
+ sa_low(sa_low), sa_high(sa_high), num_subpatterns(0) {}
+
+PhraseLocation::PhraseLocation(const vector<int>& matchings,
+ int num_subpatterns) :
+ sa_low(0), sa_high(0),
+ matchings(make_shared<vector<int> >(matchings)),
+ num_subpatterns(num_subpatterns) {}
+
+bool PhraseLocation::IsEmpty() const {
+ return GetSize() == 0;
+}
+
+int PhraseLocation::GetSize() const {
+ if (num_subpatterns > 0) {
+ return matchings->size();
+ } else {
+ return sa_high - sa_low;
+ }
+}
+
+bool operator==(const PhraseLocation& a, const PhraseLocation& b) {
+ if (a.sa_low != b.sa_low || a.sa_high != b.sa_high ||
+ a.num_subpatterns != b.num_subpatterns) {
+ return false;
+ }
+
+ if (a.matchings == NULL && b.matchings == NULL) {
+ return true;
+ }
+
+ if (a.matchings == NULL || b.matchings == NULL) {
+ return false;
+ }
+
+ return *a.matchings == *b.matchings;
+}
+
+} // namespace extractor
diff --git a/extractor/phrase_location.h b/extractor/phrase_location.h
new file mode 100644
index 00000000..91950e03
--- /dev/null
+++ b/extractor/phrase_location.h
@@ -0,0 +1,41 @@
+#ifndef _PHRASE_LOCATION_H_
+#define _PHRASE_LOCATION_H_
+
+#include <memory>
+#include <vector>
+
+using namespace std;
+
+namespace extractor {
+
+/**
+ * Structure containing information about the occurrences of a phrase in the
+ * source data.
+ *
+ * Every consecutive (disjoint) group of num_subpatterns entries in matchings
+ * vector encodes an occurrence of the phrase. The i-th entry of a group
+ * represents the start of the i-th subpattern of the phrase. If the phrase
+ * doesn't contain any nonterminals, then it may also be represented as the
+ * range in the suffix array which matches the phrase.
+ */
+struct PhraseLocation {
+ PhraseLocation(int sa_low = -1, int sa_high = -1);
+
+ PhraseLocation(const vector<int>& matchings, int num_subpatterns);
+
+ // Checks if a phrase has any occurrences in the source data.
+ bool IsEmpty() const;
+
+ // Returns the number of occurrences of a phrase in the source data.
+ int GetSize() const;
+
+ friend bool operator==(const PhraseLocation& a, const PhraseLocation& b);
+
+ int sa_low, sa_high;
+ shared_ptr<vector<int> > matchings;
+ int num_subpatterns;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/phrase_test.cc b/extractor/phrase_test.cc
new file mode 100644
index 00000000..3ba9368a
--- /dev/null
+++ b/extractor/phrase_test.cc
@@ -0,0 +1,83 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <vector>
+
+#include "mocks/mock_vocabulary.h"
+#include "phrase.h"
+#include "phrase_builder.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+class PhraseTest : public Test {
+ protected:
+ virtual void SetUp() {
+ shared_ptr<MockVocabulary> vocabulary = make_shared<MockVocabulary>();
+ vector<string> words = {"w1", "w2", "w3", "w4"};
+ for (size_t i = 0; i < words.size(); ++i) {
+ EXPECT_CALL(*vocabulary, GetTerminalValue(i + 1))
+ .WillRepeatedly(Return(words[i]));
+ }
+ shared_ptr<PhraseBuilder> phrase_builder =
+ make_shared<PhraseBuilder>(vocabulary);
+
+ symbols1 = vector<int>{1, 2, 3};
+ phrase1 = phrase_builder->Build(symbols1);
+ symbols2 = vector<int>{1, 2, -1, 3, -2, 4};
+ phrase2 = phrase_builder->Build(symbols2);
+ }
+
+ vector<int> symbols1, symbols2;
+ Phrase phrase1, phrase2;
+};
+
+TEST_F(PhraseTest, TestArity) {
+ EXPECT_EQ(0, phrase1.Arity());
+ EXPECT_EQ(2, phrase2.Arity());
+}
+
+TEST_F(PhraseTest, GetChunkLen) {
+ EXPECT_EQ(3, phrase1.GetChunkLen(0));
+
+ EXPECT_EQ(2, phrase2.GetChunkLen(0));
+ EXPECT_EQ(1, phrase2.GetChunkLen(1));
+ EXPECT_EQ(1, phrase2.GetChunkLen(2));
+}
+
+TEST_F(PhraseTest, TestGet) {
+ EXPECT_EQ(symbols1, phrase1.Get());
+ EXPECT_EQ(symbols2, phrase2.Get());
+}
+
+TEST_F(PhraseTest, TestGetSymbol) {
+ for (size_t i = 0; i < symbols1.size(); ++i) {
+ EXPECT_EQ(symbols1[i], phrase1.GetSymbol(i));
+ }
+ for (size_t i = 0; i < symbols2.size(); ++i) {
+ EXPECT_EQ(symbols2[i], phrase2.GetSymbol(i));
+ }
+}
+
+TEST_F(PhraseTest, TestGetNumSymbols) {
+ EXPECT_EQ(3, phrase1.GetNumSymbols());
+ EXPECT_EQ(6, phrase2.GetNumSymbols());
+}
+
+TEST_F(PhraseTest, TestGetWords) {
+ vector<string> expected_words = {"w1", "w2", "w3"};
+ EXPECT_EQ(expected_words, phrase1.GetWords());
+ expected_words = {"w1", "w2", "w3", "w4"};
+ EXPECT_EQ(expected_words, phrase2.GetWords());
+}
+
+TEST_F(PhraseTest, TestComparator) {
+ EXPECT_FALSE(phrase1 < phrase2);
+ EXPECT_TRUE(phrase2 < phrase1);
+}
+
+} // namespace
+} // namespace extractor
diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc
new file mode 100644
index 00000000..b3906943
--- /dev/null
+++ b/extractor/precomputation.cc
@@ -0,0 +1,189 @@
+#include "precomputation.h"
+
+#include <iostream>
+#include <queue>
+
+#include "data_array.h"
+#include "suffix_array.h"
+
+using namespace std;
+
+namespace extractor {
+
+int Precomputation::FIRST_NONTERMINAL = -1;
+int Precomputation::SECOND_NONTERMINAL = -2;
+
+Precomputation::Precomputation(
+ shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns,
+ int num_super_frequent_patterns, int max_rule_span,
+ int max_rule_symbols, int min_gap_size,
+ int max_frequent_phrase_len, int min_frequency) {
+ vector<int> data = suffix_array->GetData()->GetData();
+ vector<vector<int> > frequent_patterns = FindMostFrequentPatterns(
+ suffix_array, data, num_frequent_patterns, max_frequent_phrase_len,
+ min_frequency);
+
+ // Construct sets containing the frequent and superfrequent contiguous
+ // collocations.
+ unordered_set<vector<int>, VectorHash> frequent_patterns_set;
+ unordered_set<vector<int>, VectorHash> super_frequent_patterns_set;
+ for (size_t i = 0; i < frequent_patterns.size(); ++i) {
+ frequent_patterns_set.insert(frequent_patterns[i]);
+ if (i < num_super_frequent_patterns) {
+ super_frequent_patterns_set.insert(frequent_patterns[i]);
+ }
+ }
+
+ vector<tuple<int, int, int> > matchings;
+ for (size_t i = 0; i < data.size(); ++i) {
+ // If the sentence is over, add all the discontiguous frequent patterns to
+ // the index.
+ if (data[i] == DataArray::END_OF_LINE) {
+ AddCollocations(matchings, data, max_rule_span, min_gap_size,
+ max_rule_symbols);
+ matchings.clear();
+ continue;
+ }
+ vector<int> pattern;
+ // Find all the contiguous frequent patterns starting at position i.
+ for (int j = 1; j <= max_frequent_phrase_len && i + j <= data.size(); ++j) {
+ pattern.push_back(data[i + j - 1]);
+ if (frequent_patterns_set.count(pattern)) {
+ int is_super_frequent = super_frequent_patterns_set.count(pattern);
+ matchings.push_back(make_tuple(i, j, is_super_frequent));
+ } else {
+ // If the current pattern is not frequent, any longer pattern having the
+ // current pattern as prefix will not be frequent.
+ break;
+ }
+ }
+ }
+}
+
+Precomputation::Precomputation() {}
+
+Precomputation::~Precomputation() {}
+
+vector<vector<int> > Precomputation::FindMostFrequentPatterns(
+ shared_ptr<SuffixArray> suffix_array, const vector<int>& data,
+ int num_frequent_patterns, int max_frequent_phrase_len, int min_frequency) {
+ vector<int> lcp = suffix_array->BuildLCPArray();
+ vector<int> run_start(max_frequent_phrase_len);
+
+ // Find all the patterns occurring at least min_frequency times.
+ priority_queue<pair<int, pair<int, int> > > heap;
+ for (size_t i = 1; i < lcp.size(); ++i) {
+ for (int len = lcp[i]; len < max_frequent_phrase_len; ++len) {
+ int frequency = i - run_start[len];
+ if (frequency >= min_frequency) {
+ heap.push(make_pair(frequency,
+ make_pair(suffix_array->GetSuffix(run_start[len]), len + 1)));
+ }
+ run_start[len] = i;
+ }
+ }
+
+ // Extract the most frequent patterns.
+ vector<vector<int> > frequent_patterns;
+ while (frequent_patterns.size() < num_frequent_patterns && !heap.empty()) {
+ int start = heap.top().second.first;
+ int len = heap.top().second.second;
+ heap.pop();
+
+ vector<int> pattern(data.begin() + start, data.begin() + start + len);
+ if (find(pattern.begin(), pattern.end(), DataArray::END_OF_LINE) ==
+ pattern.end()) {
+ frequent_patterns.push_back(pattern);
+ }
+ }
+ return frequent_patterns;
+}
+
+void Precomputation::AddCollocations(
+ const vector<tuple<int, int, int> >& matchings, const vector<int>& data,
+ int max_rule_span, int min_gap_size, int max_rule_symbols) {
+ // Select the leftmost subpattern.
+ for (size_t i = 0; i < matchings.size(); ++i) {
+ int start1, size1, is_super1;
+ tie(start1, size1, is_super1) = matchings[i];
+
+ // Select the second (middle) subpattern
+ for (size_t j = i + 1; j < matchings.size(); ++j) {
+ int start2, size2, is_super2;
+ tie(start2, size2, is_super2) = matchings[j];
+ if (start2 - start1 >= max_rule_span) {
+ break;
+ }
+
+ if (start2 - start1 - size1 >= min_gap_size
+ && start2 + size2 - start1 <= max_rule_span
+ && size1 + size2 + 1 <= max_rule_symbols) {
+ vector<int> pattern(data.begin() + start1,
+ data.begin() + start1 + size1);
+ pattern.push_back(Precomputation::FIRST_NONTERMINAL);
+ pattern.insert(pattern.end(), data.begin() + start2,
+ data.begin() + start2 + size2);
+ AddStartPositions(collocations[pattern], start1, start2);
+
+ // Try extending the binary collocation to a ternary collocation.
+ if (is_super2) {
+ pattern.push_back(Precomputation::SECOND_NONTERMINAL);
+ // Select the rightmost subpattern.
+ for (size_t k = j + 1; k < matchings.size(); ++k) {
+ int start3, size3, is_super3;
+ tie(start3, size3, is_super3) = matchings[k];
+ if (start3 - start1 >= max_rule_span) {
+ break;
+ }
+
+ if (start3 - start2 - size2 >= min_gap_size
+ && start3 + size3 - start1 <= max_rule_span
+ && size1 + size2 + size3 + 2 <= max_rule_symbols
+ && (is_super1 || is_super3)) {
+ pattern.insert(pattern.end(), data.begin() + start3,
+ data.begin() + start3 + size3);
+ AddStartPositions(collocations[pattern], start1, start2, start3);
+ pattern.erase(pattern.end() - size3);
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+void Precomputation::AddStartPositions(
+ vector<int>& positions, int pos1, int pos2) {
+ positions.push_back(pos1);
+ positions.push_back(pos2);
+}
+
+void Precomputation::AddStartPositions(
+ vector<int>& positions, int pos1, int pos2, int pos3) {
+ positions.push_back(pos1);
+ positions.push_back(pos2);
+ positions.push_back(pos3);
+}
+
+void Precomputation::WriteBinary(const fs::path& filepath) const {
+ FILE* file = fopen(filepath.string().c_str(), "w");
+
+ // TODO(pauldb): Refactor this code.
+ int size = collocations.size();
+ fwrite(&size, sizeof(int), 1, file);
+ for (auto entry: collocations) {
+ size = entry.first.size();
+ fwrite(&size, sizeof(int), 1, file);
+ fwrite(entry.first.data(), sizeof(int), size, file);
+
+ size = entry.second.size();
+ fwrite(&size, sizeof(int), 1, file);
+ fwrite(entry.second.data(), sizeof(int), size, file);
+ }
+}
+
+const Index& Precomputation::GetCollocations() const {
+ return collocations;
+}
+
+} // namespace extractor
diff --git a/extractor/precomputation.h b/extractor/precomputation.h
new file mode 100644
index 00000000..e3c4d26a
--- /dev/null
+++ b/extractor/precomputation.h
@@ -0,0 +1,80 @@
+#ifndef _PRECOMPUTATION_H_
+#define _PRECOMPUTATION_H_
+
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <tuple>
+#include <vector>
+
+#include <boost/filesystem.hpp>
+#include <boost/functional/hash.hpp>
+
+namespace fs = boost::filesystem;
+using namespace std;
+
+namespace extractor {
+
+typedef boost::hash<vector<int> > VectorHash;
+typedef unordered_map<vector<int>, vector<int>, VectorHash> Index;
+
+class SuffixArray;
+
+/**
+ * Data structure wrapping an index with all the occurrences of the most
+ * frequent discontiguous collocations in the source data.
+ *
+ * Let a, b, c be contiguous collocations. The index will contain an entry for
+ * every collocation of the form:
+ * - aXb, where a and b are frequent
+ * - aXbXc, where a and b are super-frequent and c is frequent or
+ * b and c are super-frequent and a is frequent.
+ */
+class Precomputation {
+ public:
+ // Constructs the index using the suffix array.
+ Precomputation(
+ shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns,
+ int num_super_frequent_patterns, int max_rule_span,
+ int max_rule_symbols, int min_gap_size,
+ int max_frequent_phrase_len, int min_frequency);
+
+ virtual ~Precomputation();
+
+ void WriteBinary(const fs::path& filepath) const;
+
+ // Returns a reference to the index.
+ virtual const Index& GetCollocations() const;
+
+ static int FIRST_NONTERMINAL;
+ static int SECOND_NONTERMINAL;
+
+ protected:
+ Precomputation();
+
+ private:
+ // Finds the most frequent contiguous collocations.
+ vector<vector<int> > FindMostFrequentPatterns(
+ shared_ptr<SuffixArray> suffix_array, const vector<int>& data,
+ int num_frequent_patterns, int max_frequent_phrase_len,
+ int min_frequency);
+
+ // Given the locations of the frequent contiguous collocations in a sentence,
+ // it adds new entries to the index for each discontiguous collocation
+ // matching the criteria specified in the class description.
+ void AddCollocations(
+ const vector<std::tuple<int, int, int> >& matchings, const vector<int>& data,
+ int max_rule_span, int min_gap_size, int max_rule_symbols);
+
+ // Adds an occurrence of a binary collocation.
+ void AddStartPositions(vector<int>& positions, int pos1, int pos2);
+
+ // Adds an occurrence of a ternary collocation.
+ void AddStartPositions(vector<int>& positions, int pos1, int pos2, int pos3);
+
+ Index collocations;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc
new file mode 100644
index 00000000..363febb7
--- /dev/null
+++ b/extractor/precomputation_test.cc
@@ -0,0 +1,106 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <vector>
+
+#include "mocks/mock_data_array.h"
+#include "mocks/mock_suffix_array.h"
+#include "precomputation.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+class PrecomputationTest : public Test {
+ protected:
+ virtual void SetUp() {
+ data = {4, 2, 3, 5, 7, 2, 3, 5, 2, 3, 4, 2, 1};
+ data_array = make_shared<MockDataArray>();
+ EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data));
+
+ vector<int> suffixes{12, 8, 5, 1, 9, 6, 2, 0, 10, 7, 3, 4, 13};
+ vector<int> lcp{-1, 0, 2, 3, 1, 0, 1, 2, 0, 2, 0, 1, 0, 0};
+ suffix_array = make_shared<MockSuffixArray>();
+ EXPECT_CALL(*suffix_array, GetData()).WillRepeatedly(Return(data_array));
+ for (size_t i = 0; i < suffixes.size(); ++i) {
+ EXPECT_CALL(*suffix_array,
+ GetSuffix(i)).WillRepeatedly(Return(suffixes[i]));
+ }
+ EXPECT_CALL(*suffix_array, BuildLCPArray()).WillRepeatedly(Return(lcp));
+ }
+
+ vector<int> data;
+ shared_ptr<MockDataArray> data_array;
+ shared_ptr<MockSuffixArray> suffix_array;
+};
+
+TEST_F(PrecomputationTest, TestCollocations) {
+ Precomputation precomputation(suffix_array, 3, 3, 10, 5, 1, 4, 2);
+ Index collocations = precomputation.GetCollocations();
+
+ vector<int> key = {2, 3, -1, 2};
+ vector<int> expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {2, 3, -1, 2, 3};
+ expected_value = {1, 5, 1, 8, 5, 8};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {2, 3, -1, 3};
+ expected_value = {1, 6, 1, 9, 5, 9};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {3, -1, 2};
+ expected_value = {2, 5, 2, 8, 2, 11, 6, 8, 6, 11, 9, 11};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {3, -1, 3};
+ expected_value = {2, 6, 2, 9, 6, 9};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {3, -1, 2, 3};
+ expected_value = {2, 5, 2, 8, 6, 8};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {2, -1, 2};
+ expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {2, -1, 2, 3};
+ expected_value = {1, 5, 1, 8, 5, 8};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {2, -1, 3};
+ expected_value = {1, 6, 1, 9, 5, 9};
+ EXPECT_EQ(expected_value, collocations[key]);
+
+ key = {2, -1, 2, -2, 2};
+ expected_value = {1, 5, 8, 5, 8, 11};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {2, -1, 2, -2, 3};
+ expected_value = {1, 5, 9};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {2, -1, 3, -2, 2};
+ expected_value = {1, 6, 8, 5, 9, 11};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {2, -1, 3, -2, 3};
+ expected_value = {1, 6, 9};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {3, -1, 2, -2, 2};
+ expected_value = {2, 5, 8, 2, 5, 11, 2, 8, 11, 6, 8, 11};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {3, -1, 2, -2, 3};
+ expected_value = {2, 5, 9};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {3, -1, 3, -2, 2};
+ expected_value = {2, 6, 8, 2, 6, 11, 2, 9, 11, 6, 9, 11};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {3, -1, 3, -2, 3};
+ expected_value = {2, 6, 9};
+ EXPECT_EQ(expected_value, collocations[key]);
+
+ // Exceeds max_rule_symbols.
+ key = {2, -1, 2, -2, 2, 3};
+ EXPECT_EQ(0, collocations.count(key));
+ // Contains non frequent pattern.
+ key = {2, -1, 5};
+ EXPECT_EQ(0, collocations.count(key));
+}
+
+} // namespace
+} // namespace extractor
+
diff --git a/extractor/rule.cc b/extractor/rule.cc
new file mode 100644
index 00000000..b6c7d783
--- /dev/null
+++ b/extractor/rule.cc
@@ -0,0 +1,14 @@
+#include "rule.h"
+
+namespace extractor {
+
+Rule::Rule(const Phrase& source_phrase,
+ const Phrase& target_phrase,
+ const vector<double>& scores,
+ const vector<pair<int, int> >& alignment) :
+ source_phrase(source_phrase),
+ target_phrase(target_phrase),
+ scores(scores),
+ alignment(alignment) {}
+
+} // namespace extractor
diff --git a/extractor/rule.h b/extractor/rule.h
new file mode 100644
index 00000000..bc95709e
--- /dev/null
+++ b/extractor/rule.h
@@ -0,0 +1,27 @@
+#ifndef _RULE_H_
+#define _RULE_H_
+
+#include <vector>
+
+#include "phrase.h"
+
+using namespace std;
+
+namespace extractor {
+
+/**
+ * Structure containing the data for a SCFG rule.
+ */
+struct Rule {
+ Rule(const Phrase& source_phrase, const Phrase& target_phrase,
+ const vector<double>& scores, const vector<pair<int, int> >& alignment);
+
+ Phrase source_phrase;
+ Phrase target_phrase;
+ vector<double> scores;
+ vector<pair<int, int> > alignment;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/rule_extractor.cc b/extractor/rule_extractor.cc
new file mode 100644
index 00000000..fa7386a4
--- /dev/null
+++ b/extractor/rule_extractor.cc
@@ -0,0 +1,343 @@
+#include "rule_extractor.h"
+
+#include <map>
+
+#include "alignment.h"
+#include "data_array.h"
+#include "features/feature.h"
+#include "phrase_builder.h"
+#include "phrase_location.h"
+#include "rule.h"
+#include "rule_extractor_helper.h"
+#include "scorer.h"
+#include "target_phrase_extractor.h"
+
+using namespace std;
+
+namespace extractor {
+
+RuleExtractor::RuleExtractor(
+ shared_ptr<DataArray> source_data_array,
+ shared_ptr<DataArray> target_data_array,
+ shared_ptr<Alignment> alignment,
+ shared_ptr<PhraseBuilder> phrase_builder,
+ shared_ptr<Scorer> scorer,
+ shared_ptr<Vocabulary> vocabulary,
+ int max_rule_span,
+ int min_gap_size,
+ int max_nonterminals,
+ int max_rule_symbols,
+ bool require_aligned_terminal,
+ bool require_aligned_chunks,
+ bool require_tight_phrases) :
+ target_data_array(target_data_array),
+ source_data_array(source_data_array),
+ phrase_builder(phrase_builder),
+ scorer(scorer),
+ max_rule_span(max_rule_span),
+ min_gap_size(min_gap_size),
+ max_nonterminals(max_nonterminals),
+ max_rule_symbols(max_rule_symbols),
+ require_tight_phrases(require_tight_phrases) {
+ helper = make_shared<RuleExtractorHelper>(
+ source_data_array, target_data_array, alignment, max_rule_span,
+ max_rule_symbols, require_aligned_terminal, require_aligned_chunks,
+ require_tight_phrases);
+ target_phrase_extractor = make_shared<TargetPhraseExtractor>(
+ target_data_array, alignment, phrase_builder, helper, vocabulary,
+ max_rule_span, require_tight_phrases);
+}
+
+RuleExtractor::RuleExtractor(
+ shared_ptr<DataArray> source_data_array,
+ shared_ptr<PhraseBuilder> phrase_builder,
+ shared_ptr<Scorer> scorer,
+ shared_ptr<TargetPhraseExtractor> target_phrase_extractor,
+ shared_ptr<RuleExtractorHelper> helper,
+ int max_rule_span,
+ int min_gap_size,
+ int max_nonterminals,
+ int max_rule_symbols,
+ bool require_tight_phrases) :
+ source_data_array(source_data_array),
+ phrase_builder(phrase_builder),
+ scorer(scorer),
+ target_phrase_extractor(target_phrase_extractor),
+ helper(helper),
+ max_rule_span(max_rule_span),
+ min_gap_size(min_gap_size),
+ max_nonterminals(max_nonterminals),
+ max_rule_symbols(max_rule_symbols),
+ require_tight_phrases(require_tight_phrases) {}
+
+RuleExtractor::RuleExtractor() {}
+
+RuleExtractor::~RuleExtractor() {}
+
+vector<Rule> RuleExtractor::ExtractRules(const Phrase& phrase,
+ const PhraseLocation& location) const {
+ int num_subpatterns = location.num_subpatterns;
+ vector<int> matchings = *location.matchings;
+
+ // Calculate statistics for the (sampled) occurrences of the source phrase.
+ map<Phrase, double> source_phrase_counter;
+ map<Phrase, map<Phrase, map<PhraseAlignment, int> > > alignments_counter;
+ for (auto i = matchings.begin(); i != matchings.end(); i += num_subpatterns) {
+ vector<int> matching(i, i + num_subpatterns);
+ vector<Extract> extracts = ExtractAlignments(phrase, matching);
+
+ for (Extract e: extracts) {
+ source_phrase_counter[e.source_phrase] += e.pairs_count;
+ alignments_counter[e.source_phrase][e.target_phrase][e.alignment] += 1;
+ }
+ }
+
+ // Compute the feature scores and find the most likely (frequent) alignment
+ // for each pair of source-target phrases.
+ int num_samples = matchings.size() / num_subpatterns;
+ vector<Rule> rules;
+ for (auto source_phrase_entry: alignments_counter) {
+ Phrase source_phrase = source_phrase_entry.first;
+ for (auto target_phrase_entry: source_phrase_entry.second) {
+ Phrase target_phrase = target_phrase_entry.first;
+
+ int max_locations = 0, num_locations = 0;
+ PhraseAlignment most_frequent_alignment;
+ for (auto alignment_entry: target_phrase_entry.second) {
+ num_locations += alignment_entry.second;
+ if (alignment_entry.second > max_locations) {
+ most_frequent_alignment = alignment_entry.first;
+ max_locations = alignment_entry.second;
+ }
+ }
+
+ features::FeatureContext context(source_phrase, target_phrase,
+ source_phrase_counter[source_phrase], num_locations, num_samples);
+ vector<double> scores = scorer->Score(context);
+ rules.push_back(Rule(source_phrase, target_phrase, scores,
+ most_frequent_alignment));
+ }
+ }
+ return rules;
+}
+
+vector<Extract> RuleExtractor::ExtractAlignments(
+ const Phrase& phrase, const vector<int>& matching) const {
+ vector<Extract> extracts;
+ int sentence_id = source_data_array->GetSentenceId(matching[0]);
+ int source_sent_start = source_data_array->GetSentenceStart(sentence_id);
+
+ // Get the span in the opposite sentence for each word in the source-target
+ // sentece pair.
+ vector<int> source_low, source_high, target_low, target_high;
+ helper->GetLinksSpans(source_low, source_high, target_low, target_high,
+ sentence_id);
+
+ int num_subpatterns = matching.size();
+ vector<int> chunklen(num_subpatterns);
+ for (size_t i = 0; i < num_subpatterns; ++i) {
+ chunklen[i] = phrase.GetChunkLen(i);
+ }
+
+ // Basic checks to see if we can extract phrase pairs for this occurrence.
+ if (!helper->CheckAlignedTerminals(matching, chunklen, source_low,
+ source_sent_start) ||
+ !helper->CheckTightPhrases(matching, chunklen, source_low,
+ source_sent_start)) {
+ return extracts;
+ }
+
+ int source_back_low = -1, source_back_high = -1;
+ int source_phrase_low = matching[0] - source_sent_start;
+ int source_phrase_high = matching.back() + chunklen.back() -
+ source_sent_start;
+ int target_phrase_low = -1, target_phrase_high = -1;
+ // Find target span and reflected source span for the source phrase.
+ if (!helper->FindFixPoint(source_phrase_low, source_phrase_high, source_low,
+ source_high, target_phrase_low, target_phrase_high,
+ target_low, target_high, source_back_low,
+ source_back_high, sentence_id, min_gap_size, 0,
+ max_nonterminals - matching.size() + 1, true, true,
+ false)) {
+ return extracts;
+ }
+
+ // Get spans for nonterminal gaps.
+ bool met_constraints = true;
+ int num_symbols = phrase.GetNumSymbols();
+ vector<pair<int, int> > source_gaps, target_gaps;
+ if (!helper->GetGaps(source_gaps, target_gaps, matching, chunklen, source_low,
+ source_high, target_low, target_high, source_phrase_low,
+ source_phrase_high, source_back_low, source_back_high,
+ sentence_id, source_sent_start, num_symbols,
+ met_constraints)) {
+ return extracts;
+ }
+
+ // Find target phrases aligned with the initial source phrase.
+ bool starts_with_x = source_back_low != source_phrase_low;
+ bool ends_with_x = source_back_high != source_phrase_high;
+ Phrase source_phrase = phrase_builder->Extend(
+ phrase, starts_with_x, ends_with_x);
+ unordered_map<int, int> source_indexes = helper->GetSourceIndexes(
+ matching, chunklen, starts_with_x, source_sent_start);
+ if (met_constraints) {
+ AddExtracts(extracts, source_phrase, source_indexes, target_gaps,
+ target_low, target_phrase_low, target_phrase_high, sentence_id);
+ }
+
+ if (source_gaps.size() >= max_nonterminals ||
+ source_phrase.GetNumSymbols() >= max_rule_symbols ||
+ source_back_high - source_back_low + min_gap_size > max_rule_span) {
+ // Cannot add any more nonterminals.
+ return extracts;
+ }
+
+ // Extend the source phrase by adding a leading and/or trailing nonterminal
+ // and find target phrases aligned with the extended source phrase.
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 1 - i; j < 2; ++j) {
+ AddNonterminalExtremities(extracts, matching, chunklen, source_phrase,
+ source_back_low, source_back_high, source_low, source_high,
+ target_low, target_high, target_gaps, sentence_id, source_sent_start,
+ starts_with_x, ends_with_x, i, j);
+ }
+ }
+
+ return extracts;
+}
+
+void RuleExtractor::AddExtracts(
+ vector<Extract>& extracts, const Phrase& source_phrase,
+ const unordered_map<int, int>& source_indexes,
+ const vector<pair<int, int> >& target_gaps, const vector<int>& target_low,
+ int target_phrase_low, int target_phrase_high, int sentence_id) const {
+ auto target_phrases = target_phrase_extractor->ExtractPhrases(
+ target_gaps, target_low, target_phrase_low, target_phrase_high,
+ source_indexes, sentence_id);
+
+ if (target_phrases.size() > 0) {
+ // Split the probability equally across all target phrases that can be
+ // aligned with a single occurrence of the source phrase.
+ double pairs_count = 1.0 / target_phrases.size();
+ for (auto target_phrase: target_phrases) {
+ extracts.push_back(Extract(source_phrase, target_phrase.first,
+ pairs_count, target_phrase.second));
+ }
+ }
+}
+
+void RuleExtractor::AddNonterminalExtremities(
+ vector<Extract>& extracts, const vector<int>& matching,
+ const vector<int>& chunklen, const Phrase& source_phrase,
+ int source_back_low, int source_back_high, const vector<int>& source_low,
+ const vector<int>& source_high, const vector<int>& target_low,
+ const vector<int>& target_high, vector<pair<int, int> > target_gaps,
+ int sentence_id, int source_sent_start, int starts_with_x, int ends_with_x,
+ int extend_left, int extend_right) const {
+ int source_x_low = source_back_low, source_x_high = source_back_high;
+
+ // Check if the extended source phrase will remain tight.
+ if (require_tight_phrases) {
+ if (source_low[source_back_low - extend_left] == -1 ||
+ source_low[source_back_high + extend_right - 1] == -1) {
+ return;
+ }
+ }
+
+ // Check if we can add a nonterminal to the left.
+ if (extend_left) {
+ if (starts_with_x || source_back_low < min_gap_size) {
+ return;
+ }
+
+ source_x_low = source_back_low - min_gap_size;
+ if (require_tight_phrases) {
+ while (source_x_low >= 0 && source_low[source_x_low] == -1) {
+ --source_x_low;
+ }
+ }
+ if (source_x_low < 0) {
+ return;
+ }
+ }
+
+ // Check if we can add a nonterminal to the right.
+ if (extend_right) {
+ int source_sent_len = source_data_array->GetSentenceLength(sentence_id);
+ if (ends_with_x || source_back_high + min_gap_size > source_sent_len) {
+ return;
+ }
+ source_x_high = source_back_high + min_gap_size;
+ if (require_tight_phrases) {
+ while (source_x_high <= source_sent_len &&
+ source_low[source_x_high - 1] == -1) {
+ ++source_x_high;
+ }
+ }
+
+ if (source_x_high > source_sent_len) {
+ return;
+ }
+ }
+
+ // More length checks.
+ int new_nonterminals = extend_left + extend_right;
+ if (source_x_high - source_x_low > max_rule_span ||
+ target_gaps.size() + new_nonterminals > max_nonterminals ||
+ source_phrase.GetNumSymbols() + new_nonterminals > max_rule_symbols) {
+ return;
+ }
+
+ // Find the target span for the extended phrase and the reflected source span.
+ int target_x_low = -1, target_x_high = -1;
+ if (!helper->FindFixPoint(source_x_low, source_x_high, source_low,
+ source_high, target_x_low, target_x_high,
+ target_low, target_high, source_x_low,
+ source_x_high, sentence_id, 1, 1,
+ new_nonterminals, extend_left, extend_right,
+ true)) {
+ return;
+ }
+
+ // Check gap integrity for the leading nonterminal.
+ if (extend_left) {
+ int source_gap_low = -1, source_gap_high = -1;
+ int target_gap_low = -1, target_gap_high = -1;
+ if ((require_tight_phrases && source_low[source_x_low] == -1) ||
+ !helper->FindFixPoint(source_x_low, source_back_low, source_low,
+ source_high, target_gap_low, target_gap_high,
+ target_low, target_high, source_gap_low,
+ source_gap_high, sentence_id, 0, 0, 0, false,
+ false, false)) {
+ return;
+ }
+ target_gaps.insert(target_gaps.begin(),
+ make_pair(target_gap_low, target_gap_high));
+ }
+
+ // Check gap integrity for the trailing nonterminal.
+ if (extend_right) {
+ int target_gap_low = -1, target_gap_high = -1;
+ int source_gap_low = -1, source_gap_high = -1;
+ if ((require_tight_phrases && source_low[source_x_high - 1] == -1) ||
+ !helper->FindFixPoint(source_back_high, source_x_high, source_low,
+ source_high, target_gap_low, target_gap_high,
+ target_low, target_high, source_gap_low,
+ source_gap_high, sentence_id, 0, 0, 0, false,
+ false, false)) {
+ return;
+ }
+ target_gaps.push_back(make_pair(target_gap_low, target_gap_high));
+ }
+
+ // Find target phrases aligned with the extended source phrase.
+ Phrase new_source_phrase = phrase_builder->Extend(source_phrase, extend_left,
+ extend_right);
+ unordered_map<int, int> source_indexes = helper->GetSourceIndexes(
+ matching, chunklen, extend_left || starts_with_x, source_sent_start);
+ AddExtracts(extracts, new_source_phrase, source_indexes, target_gaps,
+ target_low, target_x_low, target_x_high, sentence_id);
+}
+
+} // namespace extractor
diff --git a/extractor/rule_extractor.h b/extractor/rule_extractor.h
new file mode 100644
index 00000000..26e6f21c
--- /dev/null
+++ b/extractor/rule_extractor.h
@@ -0,0 +1,124 @@
+#ifndef _RULE_EXTRACTOR_H_
+#define _RULE_EXTRACTOR_H_
+
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+#include "phrase.h"
+
+using namespace std;
+
+namespace extractor {
+
+typedef vector<pair<int, int> > PhraseAlignment;
+
+class Alignment;
+class DataArray;
+class PhraseBuilder;
+class PhraseLocation;
+class Rule;
+class RuleExtractorHelper;
+class Scorer;
+class TargetPhraseExtractor;
+
+/**
+ * Structure containing data about the occurrences of a source-target phrase pair
+ * in the parallel corpus.
+ */
+struct Extract {
+ Extract(const Phrase& source_phrase, const Phrase& target_phrase,
+ double pairs_count, const PhraseAlignment& alignment) :
+ source_phrase(source_phrase), target_phrase(target_phrase),
+ pairs_count(pairs_count), alignment(alignment) {}
+
+ Phrase source_phrase;
+ Phrase target_phrase;
+ double pairs_count;
+ PhraseAlignment alignment;
+};
+
+/**
+ * Component for extracting SCFG rules.
+ */
+class RuleExtractor {
+ public:
+ RuleExtractor(shared_ptr<DataArray> source_data_array,
+ shared_ptr<DataArray> target_data_array,
+ shared_ptr<Alignment> alingment,
+ shared_ptr<PhraseBuilder> phrase_builder,
+ shared_ptr<Scorer> scorer,
+ shared_ptr<Vocabulary> vocabulary,
+ int min_gap_size,
+ int max_rule_span,
+ int max_nonterminals,
+ int max_rule_symbols,
+ bool require_aligned_terminal,
+ bool require_aligned_chunks,
+ bool require_tight_phrases);
+
+ // For testing only.
+ RuleExtractor(shared_ptr<DataArray> source_data_array,
+ shared_ptr<PhraseBuilder> phrase_builder,
+ shared_ptr<Scorer> scorer,
+ shared_ptr<TargetPhraseExtractor> target_phrase_extractor,
+ shared_ptr<RuleExtractorHelper> helper,
+ int max_rule_span,
+ int min_gap_size,
+ int max_nonterminals,
+ int max_rule_symbols,
+ bool require_tight_phrases);
+
+ virtual ~RuleExtractor();
+
+ // Extracts SCFG rules given a source phrase and a set of its occurrences
+ // in the source data.
+ virtual vector<Rule> ExtractRules(const Phrase& phrase,
+ const PhraseLocation& location) const;
+
+ protected:
+ RuleExtractor();
+
+ private:
+ // Finds all target phrases that can be aligned with the source phrase for a
+ // particular occurrence in the data.
+ vector<Extract> ExtractAlignments(const Phrase& phrase,
+ const vector<int>& matching) const;
+
+ // Extracts all target phrases for a given occurrence of the source phrase in
+ // the data. Constructs a vector of Extracts using these target phrases.
+ void AddExtracts(
+ vector<Extract>& extracts, const Phrase& source_phrase,
+ const unordered_map<int, int>& source_indexes,
+ const vector<pair<int, int> >& target_gaps, const vector<int>& target_low,
+ int target_phrase_low, int target_phrase_high, int sentence_id) const;
+
+ // Adds a leading and/or trailing nonterminal to the source phrase and
+ // extracts target phrases that can be aligned with the extended source
+ // phrase.
+ void AddNonterminalExtremities(
+ vector<Extract>& extracts, const vector<int>& matching,
+ const vector<int>& chunklen, const Phrase& source_phrase,
+ int source_back_low, int source_back_high, const vector<int>& source_low,
+ const vector<int>& source_high, const vector<int>& target_low,
+ const vector<int>& target_high, vector<pair<int, int> > target_gaps,
+ int sentence_id, int source_sent_start, int starts_with_x,
+ int ends_with_x, int extend_left, int extend_right) const;
+
+ private:
+ shared_ptr<DataArray> target_data_array;
+ shared_ptr<DataArray> source_data_array;
+ shared_ptr<PhraseBuilder> phrase_builder;
+ shared_ptr<Scorer> scorer;
+ shared_ptr<TargetPhraseExtractor> target_phrase_extractor;
+ shared_ptr<RuleExtractorHelper> helper;
+ int max_rule_span;
+ int min_gap_size;
+ int max_nonterminals;
+ int max_rule_symbols;
+ bool require_tight_phrases;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/rule_extractor_helper.cc b/extractor/rule_extractor_helper.cc
new file mode 100644
index 00000000..8a9516f2
--- /dev/null
+++ b/extractor/rule_extractor_helper.cc
@@ -0,0 +1,362 @@
+#include "rule_extractor_helper.h"
+
+#include "data_array.h"
+#include "alignment.h"
+
+namespace extractor {
+
+RuleExtractorHelper::RuleExtractorHelper(
+ shared_ptr<DataArray> source_data_array,
+ shared_ptr<DataArray> target_data_array,
+ shared_ptr<Alignment> alignment,
+ int max_rule_span,
+ int max_rule_symbols,
+ bool require_aligned_terminal,
+ bool require_aligned_chunks,
+ bool require_tight_phrases) :
+ source_data_array(source_data_array),
+ target_data_array(target_data_array),
+ alignment(alignment),
+ max_rule_span(max_rule_span),
+ max_rule_symbols(max_rule_symbols),
+ require_aligned_terminal(require_aligned_terminal),
+ require_aligned_chunks(require_aligned_chunks),
+ require_tight_phrases(require_tight_phrases) {}
+
+RuleExtractorHelper::RuleExtractorHelper() {}
+
+RuleExtractorHelper::~RuleExtractorHelper() {}
+
+void RuleExtractorHelper::GetLinksSpans(
+ vector<int>& source_low, vector<int>& source_high,
+ vector<int>& target_low, vector<int>& target_high, int sentence_id) const {
+ int source_sent_len = source_data_array->GetSentenceLength(sentence_id);
+ int target_sent_len = target_data_array->GetSentenceLength(sentence_id);
+ source_low = vector<int>(source_sent_len, -1);
+ source_high = vector<int>(source_sent_len, -1);
+
+ target_low = vector<int>(target_sent_len, -1);
+ target_high = vector<int>(target_sent_len, -1);
+ vector<pair<int, int> > links = alignment->GetLinks(sentence_id);
+ for (auto link: links) {
+ if (source_low[link.first] == -1 || source_low[link.first] > link.second) {
+ source_low[link.first] = link.second;
+ }
+ source_high[link.first] = max(source_high[link.first], link.second + 1);
+
+ if (target_low[link.second] == -1 || target_low[link.second] > link.first) {
+ target_low[link.second] = link.first;
+ }
+ target_high[link.second] = max(target_high[link.second], link.first + 1);
+ }
+}
+
+bool RuleExtractorHelper::CheckAlignedTerminals(
+ const vector<int>& matching,
+ const vector<int>& chunklen,
+ const vector<int>& source_low,
+ int source_sent_start) const {
+ if (!require_aligned_terminal) {
+ return true;
+ }
+
+ int num_aligned_chunks = 0;
+ for (size_t i = 0; i < chunklen.size(); ++i) {
+ for (size_t j = 0; j < chunklen[i]; ++j) {
+ int sent_index = matching[i] - source_sent_start + j;
+ if (source_low[sent_index] != -1) {
+ ++num_aligned_chunks;
+ break;
+ }
+ }
+ }
+
+ if (num_aligned_chunks == 0) {
+ return false;
+ }
+
+ return !require_aligned_chunks || num_aligned_chunks == chunklen.size();
+}
+
+bool RuleExtractorHelper::CheckTightPhrases(
+ const vector<int>& matching,
+ const vector<int>& chunklen,
+ const vector<int>& source_low,
+ int source_sent_start) const {
+ if (!require_tight_phrases) {
+ return true;
+ }
+
+ // Check if the chunk extremities are aligned.
+ for (size_t i = 0; i + 1 < chunklen.size(); ++i) {
+ int gap_start = matching[i] + chunklen[i] - source_sent_start;
+ int gap_end = matching[i + 1] - 1 - source_sent_start;
+ if (source_low[gap_start] == -1 || source_low[gap_end] == -1) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool RuleExtractorHelper::FindFixPoint(
+ int source_phrase_low, int source_phrase_high,
+ const vector<int>& source_low, const vector<int>& source_high,
+ int& target_phrase_low, int& target_phrase_high,
+ const vector<int>& target_low, const vector<int>& target_high,
+ int& source_back_low, int& source_back_high, int sentence_id,
+ int min_source_gap_size, int min_target_gap_size,
+ int max_new_x, bool allow_low_x, bool allow_high_x,
+ bool allow_arbitrary_expansion) const {
+ int prev_target_low = target_phrase_low;
+ int prev_target_high = target_phrase_high;
+
+ FindProjection(source_phrase_low, source_phrase_high, source_low,
+ source_high, target_phrase_low, target_phrase_high);
+
+ if (target_phrase_low == -1) {
+ // Note: Low priority corner case inherited from Adam's code:
+ // If w is unaligned, but we don't require aligned terminals, returning an
+ // error here prevents the extraction of the allowed rule
+ // X -> X_1 w X_2 / X_1 X_2
+ return false;
+ }
+
+ int source_sent_len = source_data_array->GetSentenceLength(sentence_id);
+ int target_sent_len = target_data_array->GetSentenceLength(sentence_id);
+ // Extend the target span to the left.
+ if (prev_target_low != -1 && target_phrase_low != prev_target_low) {
+ if (prev_target_low - target_phrase_low < min_target_gap_size) {
+ target_phrase_low = prev_target_low - min_target_gap_size;
+ if (target_phrase_low < 0) {
+ return false;
+ }
+ }
+ }
+
+ // Extend the target span to the right.
+ if (prev_target_high != -1 && target_phrase_high != prev_target_high) {
+ if (target_phrase_high - prev_target_high < min_target_gap_size) {
+ target_phrase_high = prev_target_high + min_target_gap_size;
+ if (target_phrase_high > target_sent_len) {
+ return false;
+ }
+ }
+ }
+
+ // Check target span length.
+ if (target_phrase_high - target_phrase_low > max_rule_span) {
+ return false;
+ }
+
+ // Find the initial reflected source span.
+ source_back_low = source_back_high = -1;
+ FindProjection(target_phrase_low, target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high);
+ int new_x = 0;
+ bool new_low_x = false, new_high_x = false;
+ while (true) {
+ source_back_low = min(source_back_low, source_phrase_low);
+ source_back_high = max(source_back_high, source_phrase_high);
+
+ // Stop if the reflected source span matches the previous source span.
+ if (source_back_low == source_phrase_low &&
+ source_back_high == source_phrase_high) {
+ return true;
+ }
+
+ if (!allow_low_x && source_back_low < source_phrase_low) {
+ // Extension on the left side not allowed.
+ return false;
+ }
+ if (!allow_high_x && source_back_high > source_phrase_high) {
+ // Extension on the right side not allowed.
+ return false;
+ }
+
+ // Extend left side.
+ if (source_back_low < source_phrase_low) {
+ if (new_low_x == false) {
+ if (new_x >= max_new_x) {
+ return false;
+ }
+ new_low_x = true;
+ ++new_x;
+ }
+ if (source_phrase_low - source_back_low < min_source_gap_size) {
+ source_back_low = source_phrase_low - min_source_gap_size;
+ if (source_back_low < 0) {
+ return false;
+ }
+ }
+ }
+
+ // Extend right side.
+ if (source_back_high > source_phrase_high) {
+ if (new_high_x == false) {
+ if (new_x >= max_new_x) {
+ return false;
+ }
+ new_high_x = true;
+ ++new_x;
+ }
+ if (source_back_high - source_phrase_high < min_source_gap_size) {
+ source_back_high = source_phrase_high + min_source_gap_size;
+ if (source_back_high > source_sent_len) {
+ return false;
+ }
+ }
+ }
+
+ if (source_back_high - source_back_low > max_rule_span) {
+ // Rule span too wide.
+ return false;
+ }
+
+ prev_target_low = target_phrase_low;
+ prev_target_high = target_phrase_high;
+ // Find the reflection including the left gap (if one was added).
+ FindProjection(source_back_low, source_phrase_low, source_low, source_high,
+ target_phrase_low, target_phrase_high);
+ // Find the reflection including the right gap (if one was added).
+ FindProjection(source_phrase_high, source_back_high, source_low,
+ source_high, target_phrase_low, target_phrase_high);
+ // Stop if the new re-reflected target span matches the previous target
+ // span.
+ if (prev_target_low == target_phrase_low &&
+ prev_target_high == target_phrase_high) {
+ return true;
+ }
+
+ if (!allow_arbitrary_expansion) {
+ // Arbitrary expansion not allowed.
+ return false;
+ }
+ if (target_phrase_high - target_phrase_low > max_rule_span) {
+ // Target side too wide.
+ return false;
+ }
+
+ source_phrase_low = source_back_low;
+ source_phrase_high = source_back_high;
+ // Re-reflect the target span.
+ FindProjection(target_phrase_low, prev_target_low, target_low, target_high,
+ source_back_low, source_back_high);
+ FindProjection(prev_target_high, target_phrase_high, target_low,
+ target_high, source_back_low, source_back_high);
+ }
+
+ return false;
+}
+
+void RuleExtractorHelper::FindProjection(
+ int source_phrase_low, int source_phrase_high,
+ const vector<int>& source_low, const vector<int>& source_high,
+ int& target_phrase_low, int& target_phrase_high) const {
+ for (size_t i = source_phrase_low; i < source_phrase_high; ++i) {
+ if (source_low[i] != -1) {
+ if (target_phrase_low == -1 || source_low[i] < target_phrase_low) {
+ target_phrase_low = source_low[i];
+ }
+ target_phrase_high = max(target_phrase_high, source_high[i]);
+ }
+ }
+}
+
+bool RuleExtractorHelper::GetGaps(
+ vector<pair<int, int> >& source_gaps, vector<pair<int, int> >& target_gaps,
+ const vector<int>& matching, const vector<int>& chunklen,
+ const vector<int>& source_low, const vector<int>& source_high,
+ const vector<int>& target_low, const vector<int>& target_high,
+ int source_phrase_low, int source_phrase_high, int source_back_low,
+ int source_back_high, int sentence_id, int source_sent_start,
+ int& num_symbols, bool& met_constraints) const {
+ if (source_back_low < source_phrase_low) {
+ source_gaps.push_back(make_pair(source_back_low, source_phrase_low));
+ if (num_symbols >= max_rule_symbols) {
+ // Source side contains too many symbols.
+ return false;
+ }
+ ++num_symbols;
+ if (require_tight_phrases && (source_low[source_back_low] == -1 ||
+ source_low[source_phrase_low - 1] == -1)) {
+ // Inside edges of preceding gap are not tight.
+ return false;
+ }
+ } else if (require_tight_phrases && source_low[source_phrase_low] == -1) {
+ // This is not a hard error. We can't extract this phrase, but we might
+ // still be able to extract a superphrase.
+ met_constraints = false;
+ }
+
+ for (size_t i = 0; i + 1 < chunklen.size(); ++i) {
+ int gap_start = matching[i] + chunklen[i] - source_sent_start;
+ int gap_end = matching[i + 1] - source_sent_start;
+ source_gaps.push_back(make_pair(gap_start, gap_end));
+ }
+
+ if (source_phrase_high < source_back_high) {
+ source_gaps.push_back(make_pair(source_phrase_high, source_back_high));
+ if (num_symbols >= max_rule_symbols) {
+ // Source side contains too many symbols.
+ return false;
+ }
+ ++num_symbols;
+ if (require_tight_phrases && (source_low[source_phrase_high] == -1 ||
+ source_low[source_back_high - 1] == -1)) {
+ // Inside edges of following gap are not tight.
+ return false;
+ }
+ } else if (require_tight_phrases &&
+ source_low[source_phrase_high - 1] == -1) {
+ // This is not a hard error. We can't extract this phrase, but we might
+ // still be able to extract a superphrase.
+ met_constraints = false;
+ }
+
+ target_gaps.resize(source_gaps.size(), make_pair(-1, -1));
+ for (size_t i = 0; i < source_gaps.size(); ++i) {
+ if (!FindFixPoint(source_gaps[i].first, source_gaps[i].second, source_low,
+ source_high, target_gaps[i].first, target_gaps[i].second,
+ target_low, target_high, source_gaps[i].first,
+ source_gaps[i].second, sentence_id, 0, 0, 0, false, false,
+ false)) {
+ // Gap fails integrity check.
+ return false;
+ }
+ }
+
+ return true;
+}
+
+vector<int> RuleExtractorHelper::GetGapOrder(
+ const vector<pair<int, int> >& gaps) const {
+ vector<int> gap_order(gaps.size());
+ for (size_t i = 0; i < gap_order.size(); ++i) {
+ for (size_t j = 0; j < i; ++j) {
+ if (gaps[gap_order[j]] < gaps[i]) {
+ ++gap_order[i];
+ } else {
+ ++gap_order[j];
+ }
+ }
+ }
+ return gap_order;
+}
+
+unordered_map<int, int> RuleExtractorHelper::GetSourceIndexes(
+ const vector<int>& matching, const vector<int>& chunklen,
+ int starts_with_x, int source_sent_start) const {
+ unordered_map<int, int> source_indexes;
+ int num_symbols = starts_with_x;
+ for (size_t i = 0; i < matching.size(); ++i) {
+ for (size_t j = 0; j < chunklen[i]; ++j) {
+ source_indexes[matching[i] + j - source_sent_start] = num_symbols;
+ ++num_symbols;
+ }
+ ++num_symbols;
+ }
+ return source_indexes;
+}
+
+} // namespace extractor
diff --git a/extractor/rule_extractor_helper.h b/extractor/rule_extractor_helper.h
new file mode 100644
index 00000000..d4ae45d4
--- /dev/null
+++ b/extractor/rule_extractor_helper.h
@@ -0,0 +1,101 @@
+#ifndef _RULE_EXTRACTOR_HELPER_H_
+#define _RULE_EXTRACTOR_HELPER_H_
+
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+using namespace std;
+
+namespace extractor {
+
+class Alignment;
+class DataArray;
+
+/**
+ * Helper class for extracting SCFG rules.
+ */
+class RuleExtractorHelper {
+ public:
+ RuleExtractorHelper(shared_ptr<DataArray> source_data_array,
+ shared_ptr<DataArray> target_data_array,
+ shared_ptr<Alignment> alignment,
+ int max_rule_span,
+ int max_rule_symbols,
+ bool require_aligned_terminal,
+ bool require_aligned_chunks,
+ bool require_tight_phrases);
+
+ virtual ~RuleExtractorHelper();
+
+ // Find the alignment span for each word in the source target sentence pair.
+ virtual void GetLinksSpans(vector<int>& source_low, vector<int>& source_high,
+ vector<int>& target_low, vector<int>& target_high,
+ int sentence_id) const;
+
+ // Check if one chunk (all chunks) is aligned at least in one point.
+ virtual bool CheckAlignedTerminals(const vector<int>& matching,
+ const vector<int>& chunklen,
+ const vector<int>& source_low,
+ int source_sent_start) const;
+
+ // Check if the chunks are tight.
+ virtual bool CheckTightPhrases(const vector<int>& matching,
+ const vector<int>& chunklen,
+ const vector<int>& source_low,
+ int source_sent_start) const;
+
+ // Find the target span and the reflected source span for a source phrase
+ // occurrence.
+ virtual bool FindFixPoint(
+ int source_phrase_low, int source_phrase_high,
+ const vector<int>& source_low, const vector<int>& source_high,
+ int& target_phrase_low, int& target_phrase_high,
+ const vector<int>& target_low, const vector<int>& target_high,
+ int& source_back_low, int& source_back_high, int sentence_id,
+ int min_source_gap_size, int min_target_gap_size,
+ int max_new_x, bool allow_low_x, bool allow_high_x,
+ bool allow_arbitrary_expansion) const;
+
+ // Find the gap spans for each nonterminal in the source phrase.
+ virtual bool GetGaps(
+ vector<pair<int, int> >& source_gaps, vector<pair<int, int> >& target_gaps,
+ const vector<int>& matching, const vector<int>& chunklen,
+ const vector<int>& source_low, const vector<int>& source_high,
+ const vector<int>& target_low, const vector<int>& target_high,
+ int source_phrase_low, int source_phrase_high, int source_back_low,
+ int source_back_high, int sentence_id, int source_sent_start,
+ int& num_symbols, bool& met_constraints) const;
+
+ // Get the order of the nonterminals in the target phrase.
+ virtual vector<int> GetGapOrder(const vector<pair<int, int> >& gaps) const;
+
+ // Map each terminal symbol with its position in the source phrase.
+ virtual unordered_map<int, int> GetSourceIndexes(
+ const vector<int>& matching, const vector<int>& chunklen,
+ int starts_with_x, int source_sent_start) const;
+
+ protected:
+ RuleExtractorHelper();
+
+ private:
+ // Find the projection of a source phrase in the target sentence. May also be
+ // used to find the projection of a target phrase in the source sentence.
+ void FindProjection(
+ int source_phrase_low, int source_phrase_high,
+ const vector<int>& source_low, const vector<int>& source_high,
+ int& target_phrase_low, int& target_phrase_high) const;
+
+ shared_ptr<DataArray> source_data_array;
+ shared_ptr<DataArray> target_data_array;
+ shared_ptr<Alignment> alignment;
+ int max_rule_span;
+ int max_rule_symbols;
+ bool require_aligned_terminal;
+ bool require_aligned_chunks;
+ bool require_tight_phrases;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/rule_extractor_helper_test.cc b/extractor/rule_extractor_helper_test.cc
new file mode 100644
index 00000000..9b82abb1
--- /dev/null
+++ b/extractor/rule_extractor_helper_test.cc
@@ -0,0 +1,645 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "mocks/mock_alignment.h"
+#include "mocks/mock_data_array.h"
+#include "rule_extractor_helper.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+class RuleExtractorHelperTest : public Test {
+ protected:
+ virtual void SetUp() {
+ source_data_array = make_shared<MockDataArray>();
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(12));
+
+ target_data_array = make_shared<MockDataArray>();
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(12));
+
+ vector<pair<int, int> > links = {
+ make_pair(0, 0), make_pair(0, 1), make_pair(2, 2), make_pair(3, 1)
+ };
+ alignment = make_shared<MockAlignment>();
+ EXPECT_CALL(*alignment, GetLinks(_)).WillRepeatedly(Return(links));
+ }
+
+ shared_ptr<MockDataArray> source_data_array;
+ shared_ptr<MockDataArray> target_data_array;
+ shared_ptr<MockAlignment> alignment;
+ shared_ptr<RuleExtractorHelper> helper;
+};
+
+TEST_F(RuleExtractorHelperTest, TestGetLinksSpans) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(4));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(3));
+
+ vector<int> source_low, source_high, target_low, target_high;
+ helper->GetLinksSpans(source_low, source_high, target_low, target_high, 0);
+
+ vector<int> expected_source_low = {0, -1, 2, 1};
+ EXPECT_EQ(expected_source_low, source_low);
+ vector<int> expected_source_high = {2, -1, 3, 2};
+ EXPECT_EQ(expected_source_high, source_high);
+ vector<int> expected_target_low = {0, 0, 2};
+ EXPECT_EQ(expected_target_low, target_low);
+ vector<int> expected_target_high = {1, 4, 3};
+ EXPECT_EQ(expected_target_high, target_high);
+}
+
+TEST_F(RuleExtractorHelperTest, TestCheckAlignedFalse) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, false, false, true);
+ EXPECT_CALL(*source_data_array, GetSentenceId(_)).Times(0);
+ EXPECT_CALL(*source_data_array, GetSentenceStart(_)).Times(0);
+
+ vector<int> matching, chunklen, source_low;
+ EXPECT_TRUE(helper->CheckAlignedTerminals(matching, chunklen,
+ source_low, 10));
+}
+
+TEST_F(RuleExtractorHelperTest, TestCheckAlignedTerminal) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, false, true);
+
+ vector<int> matching = {10, 12};
+ vector<int> chunklen = {1, 3};
+ vector<int> source_low = {-1, 1, -1, 3, -1};
+ EXPECT_TRUE(helper->CheckAlignedTerminals(matching, chunklen,
+ source_low, 10));
+ source_low = {-1, 1, -1, -1, -1};
+ EXPECT_FALSE(helper->CheckAlignedTerminals(matching, chunklen,
+ source_low, 10));
+}
+
+TEST_F(RuleExtractorHelperTest, TestCheckAlignedChunks) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+
+ vector<int> matching = {10, 12};
+ vector<int> chunklen = {1, 3};
+ vector<int> source_low = {2, 1, -1, 3, -1};
+ EXPECT_TRUE(helper->CheckAlignedTerminals(matching, chunklen,
+ source_low, 10));
+ source_low = {-1, 1, -1, 3, -1};
+ EXPECT_FALSE(helper->CheckAlignedTerminals(matching, chunklen,
+ source_low, 10));
+ source_low = {2, 1, -1, -1, -1};
+ EXPECT_FALSE(helper->CheckAlignedTerminals(matching, chunklen,
+ source_low, 10));
+}
+
+
+TEST_F(RuleExtractorHelperTest, TestCheckTightPhrasesFalse) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, false);
+ EXPECT_CALL(*source_data_array, GetSentenceId(_)).Times(0);
+ EXPECT_CALL(*source_data_array, GetSentenceStart(_)).Times(0);
+
+ vector<int> matching, chunklen, source_low;
+ EXPECT_TRUE(helper->CheckTightPhrases(matching, chunklen, source_low, 10));
+}
+
+TEST_F(RuleExtractorHelperTest, TestCheckTightPhrases) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+
+ vector<int> matching = {10, 14, 18};
+ vector<int> chunklen = {2, 3, 1};
+ // No missing links.
+ vector<int> source_low = {0, 1, 2, 3, 4, 5, 6, 7, 8};
+ EXPECT_TRUE(helper->CheckTightPhrases(matching, chunklen, source_low, 10));
+
+ // Missing link at the beginning or ending of a gap.
+ source_low = {0, 1, -1, 3, 4, 5, 6, 7, 8};
+ EXPECT_FALSE(helper->CheckTightPhrases(matching, chunklen, source_low, 10));
+ source_low = {0, 1, 2, -1, 4, 5, 6, 7, 8};
+ EXPECT_FALSE(helper->CheckTightPhrases(matching, chunklen, source_low, 10));
+ source_low = {0, 1, 2, 3, 4, 5, 6, -1, 8};
+ EXPECT_FALSE(helper->CheckTightPhrases(matching, chunklen, source_low, 10));
+
+ // Missing link inside the gap.
+ chunklen = {1, 3, 1};
+ source_low = {0, 1, -1, 3, 4, 5, 6, 7, 8};
+ EXPECT_TRUE(helper->CheckTightPhrases(matching, chunklen, source_low, 10));
+}
+
+TEST_F(RuleExtractorHelperTest, TestFindFixPointBadEdgeCase) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+
+ vector<int> source_low = {0, -1, 2};
+ vector<int> source_high = {1, -1, 3};
+ vector<int> target_low = {0, -1, 2};
+ vector<int> target_high = {1, -1, 3};
+ int source_phrase_low = 1, source_phrase_high = 2;
+ int source_back_low, source_back_high;
+ int target_phrase_low = -1, target_phrase_high = 1;
+
+ // This should be in fact true. See comment about the inherited bug.
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 0, 0,
+ 0, false, false, false));
+}
+
+TEST_F(RuleExtractorHelperTest, TestFindFixPointTargetSentenceOutOfBounds) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(3));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(3));
+
+ vector<int> source_low = {0, 0, 2};
+ vector<int> source_high = {1, 2, 3};
+ vector<int> target_low = {0, 1, 2};
+ vector<int> target_high = {2, 2, 3};
+ int source_phrase_low = 1, source_phrase_high = 2;
+ int source_back_low, source_back_high;
+ int target_phrase_low = 1, target_phrase_high = 2;
+
+ // Extend out of sentence to left.
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 2, 2,
+ 0, false, false, false));
+ source_low = {0, 1, 2};
+ source_high = {1, 3, 3};
+ target_low = {0, 1, 1};
+ target_high = {1, 2, 3};
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 2, 2,
+ 0, false, false, false));
+}
+
+TEST_F(RuleExtractorHelperTest, TestFindFixPointTargetTooWide) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 5, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+
+ vector<int> source_low = {0, 0, 0, 0, 0, 0, 0};
+ vector<int> source_high = {7, 7, 7, 7, 7, 7, 7};
+ vector<int> target_low = {0, -1, -1, -1, -1, -1, 0};
+ vector<int> target_high = {7, -1, -1, -1, -1, -1, 7};
+ int source_phrase_low = 2, source_phrase_high = 5;
+ int source_back_low, source_back_high;
+ int target_phrase_low = -1, target_phrase_high = -1;
+
+ // Projection is too wide.
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 1, 1,
+ 0, false, false, false));
+}
+
+TEST_F(RuleExtractorHelperTest, TestFindFixPoint) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+
+ vector<int> source_low = {1, 1, 1, 3, 4, 5, 5};
+ vector<int> source_high = {2, 2, 3, 4, 6, 6, 6};
+ vector<int> target_low = {-1, 0, 2, 3, 4, 4, -1};
+ vector<int> target_high = {-1, 3, 3, 4, 5, 7, -1};
+ int source_phrase_low = 2, source_phrase_high = 5;
+ int source_back_low, source_back_high;
+ int target_phrase_low = 2, target_phrase_high = 5;
+
+ EXPECT_TRUE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 1, 1, 1,
+ 2, true, true, false));
+ EXPECT_EQ(1, target_phrase_low);
+ EXPECT_EQ(6, target_phrase_high);
+ EXPECT_EQ(0, source_back_low);
+ EXPECT_EQ(7, source_back_high);
+
+ source_low = {0, -1, 1, 3, 4, -1, 6};
+ source_high = {1, -1, 3, 4, 6, -1, 7};
+ target_low = {0, 2, 2, 3, 4, 4, 6};
+ target_high = {1, 3, 3, 4, 5, 5, 7};
+ source_phrase_low = 2, source_phrase_high = 5;
+ target_phrase_low = -1, target_phrase_high = -1;
+ EXPECT_TRUE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 1, 1, 1,
+ 2, true, true, false));
+ EXPECT_EQ(1, target_phrase_low);
+ EXPECT_EQ(6, target_phrase_high);
+ EXPECT_EQ(2, source_back_low);
+ EXPECT_EQ(5, source_back_high);
+}
+
+TEST_F(RuleExtractorHelperTest, TestFindFixPointExtensionsNotAllowed) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(3));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(3));
+
+ vector<int> source_low = {0, 0, 2};
+ vector<int> source_high = {1, 2, 3};
+ vector<int> target_low = {0, 1, 2};
+ vector<int> target_high = {2, 2, 3};
+ int source_phrase_low = 1, source_phrase_high = 2;
+ int source_back_low, source_back_high;
+ int target_phrase_low = -1, target_phrase_high = -1;
+
+ // Extension on the left side not allowed.
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 1, 1,
+ 1, false, true, false));
+ // Extension on the left side is allowed, but we can't add anymore X.
+ target_phrase_low = -1, target_phrase_high = -1;
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 1, 1,
+ 0, true, true, false));
+ source_low = {0, 1, 2};
+ source_high = {1, 3, 3};
+ target_low = {0, 1, 1};
+ target_high = {1, 2, 3};
+ // Extension on the right side not allowed.
+ target_phrase_low = -1, target_phrase_high = -1;
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 1, 1,
+ 1, true, false, false));
+ // Extension on the right side is allowed, but we can't add anymore X.
+ target_phrase_low = -1, target_phrase_high = -1;
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 1, 1,
+ 0, true, true, false));
+}
+
+TEST_F(RuleExtractorHelperTest, TestFindFixPointSourceSentenceOutOfBounds) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(3));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(3));
+
+ vector<int> source_low = {0, 0, 2};
+ vector<int> source_high = {1, 2, 3};
+ vector<int> target_low = {0, 1, 2};
+ vector<int> target_high = {2, 2, 3};
+ int source_phrase_low = 1, source_phrase_high = 2;
+ int source_back_low, source_back_high;
+ int target_phrase_low = 1, target_phrase_high = 2;
+ // Extend out of sentence to left.
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 2, 1,
+ 1, true, true, false));
+ source_low = {0, 1, 2};
+ source_high = {1, 3, 3};
+ target_low = {0, 1, 1};
+ target_high = {1, 2, 3};
+ target_phrase_low = 1, target_phrase_high = 2;
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 2, 1,
+ 1, true, true, false));
+}
+
+TEST_F(RuleExtractorHelperTest, TestFindFixPointTargetSourceWide) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 5, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+
+ vector<int> source_low = {2, -1, 2, 3, 4, -1, 4};
+ vector<int> source_high = {3, -1, 3, 4, 5, -1, 5};
+ vector<int> target_low = {-1, -1, 0, 3, 4, -1, -1};
+ vector<int> target_high = {-1, -1, 3, 4, 7, -1, -1};
+ int source_phrase_low = 2, source_phrase_high = 5;
+ int source_back_low, source_back_high;
+ int target_phrase_low = -1, target_phrase_high = -1;
+
+ // Second projection (on source side) is too wide.
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 1, 1,
+ 2, true, true, false));
+}
+
+TEST_F(RuleExtractorHelperTest, TestFindFixPointArbitraryExpansion) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 20, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(11));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(11));
+
+ vector<int> source_low = {1, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9};
+ vector<int> source_high = {2, 3, 4, 5, 5, 6, 7, 8, 9, 10, 10};
+ vector<int> target_low = {-1, 0, 1, 2, 3, 5, 6, 7, 8, 9, -1};
+ vector<int> target_high = {-1, 2, 3, 4, 5, 6, 8, 9, 10, 11, -1};
+ int source_phrase_low = 4, source_phrase_high = 7;
+ int source_back_low, source_back_high;
+ int target_phrase_low = -1, target_phrase_high = -1;
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 1, 1,
+ 10, true, true, false));
+
+ source_phrase_low = 4, source_phrase_high = 7;
+ target_phrase_low = -1, target_phrase_high = -1;
+ EXPECT_TRUE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 1, 1,
+ 10, true, true, true));
+}
+
+TEST_F(RuleExtractorHelperTest, TestGetGapOrder) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+
+ vector<pair<int, int> > gaps =
+ {make_pair(0, 3), make_pair(5, 8), make_pair(11, 12), make_pair(15, 17)};
+ vector<int> expected_gap_order = {0, 1, 2, 3};
+ EXPECT_EQ(expected_gap_order, helper->GetGapOrder(gaps));
+
+ gaps = {make_pair(15, 17), make_pair(8, 9), make_pair(5, 6), make_pair(0, 3)};
+ expected_gap_order = {3, 2, 1, 0};
+ EXPECT_EQ(expected_gap_order, helper->GetGapOrder(gaps));
+
+ gaps = {make_pair(8, 9), make_pair(5, 6), make_pair(0, 3), make_pair(15, 17)};
+ expected_gap_order = {2, 1, 0, 3};
+ EXPECT_EQ(expected_gap_order, helper->GetGapOrder(gaps));
+}
+
+TEST_F(RuleExtractorHelperTest, TestGetGapsExceedNumSymbols) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+
+ bool met_constraints = true;
+ vector<int> source_low = {1, 1, 2, 3, 4, 5, 6};
+ vector<int> source_high = {2, 2, 3, 4, 5, 6, 7};
+ vector<int> target_low = {-1, 0, 2, 3, 4, 5, 6};
+ vector<int> target_high = {-1, 2, 3, 4, 5, 6, 7};
+ int source_phrase_low = 1, source_phrase_high = 6;
+ int source_back_low = 0, source_back_high = 6;
+ vector<int> matching = {11, 13, 15};
+ vector<int> chunklen = {1, 1, 1};
+ vector<pair<int, int> > source_gaps, target_gaps;
+ int num_symbols = 5;
+ EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen,
+ source_low, source_high, target_low, target_high,
+ source_phrase_low, source_phrase_high,
+ source_back_low, source_back_high, 5, 10,
+ num_symbols, met_constraints));
+
+ source_low = {0, 1, 2, 3, 4, 5, 5};
+ source_high = {1, 2, 3, 4, 5, 6, 6};
+ target_low = {0, 1, 2, 3, 4, 5, -1};
+ target_high = {1, 2, 3, 4, 5, 7, -1};
+ source_phrase_low = 1, source_phrase_high = 6;
+ source_back_low = 1, source_back_high = 7;
+ num_symbols = 5;
+ EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen,
+ source_low, source_high, target_low, target_high,
+ source_phrase_low, source_phrase_high,
+ source_back_low, source_back_high, 5, 10,
+ num_symbols, met_constraints));
+}
+
+TEST_F(RuleExtractorHelperTest, TestGetGapsExtensionsNotTight) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 7, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+
+ bool met_constraints = true;
+ vector<int> source_low = {-1, 1, 2, 3, 4, 5, -1};
+ vector<int> source_high = {-1, 2, 3, 4, 5, 6, -1};
+ vector<int> target_low = {-1, 1, 2, 3, 4, 5, -1};
+ vector<int> target_high = {-1, 2, 3, 4, 5, 6, -1};
+ int source_phrase_low = 1, source_phrase_high = 6;
+ int source_back_low = 0, source_back_high = 6;
+ vector<int> matching = {11, 13, 15};
+ vector<int> chunklen = {1, 1, 1};
+ vector<pair<int, int> > source_gaps, target_gaps;
+ int num_symbols = 5;
+ EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen,
+ source_low, source_high, target_low, target_high,
+ source_phrase_low, source_phrase_high,
+ source_back_low, source_back_high, 5, 10,
+ num_symbols, met_constraints));
+
+ source_phrase_low = 1, source_phrase_high = 6;
+ source_back_low = 1, source_back_high = 7;
+ num_symbols = 5;
+ EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen,
+ source_low, source_high, target_low, target_high,
+ source_phrase_low, source_phrase_high,
+ source_back_low, source_back_high, 5, 10,
+ num_symbols, met_constraints));
+}
+
+TEST_F(RuleExtractorHelperTest, TestGetGapsNotTightExtremities) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 7, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+
+ bool met_constraints = true;
+ vector<int> source_low = {-1, -1, 2, 3, 4, 5, 6};
+ vector<int> source_high = {-1, -1, 3, 4, 5, 6, 7};
+ vector<int> target_low = {-1, -1, 2, 3, 4, 5, 6};
+ vector<int> target_high = {-1, -1, 3, 4, 5, 6, 7};
+ int source_phrase_low = 1, source_phrase_high = 6;
+ int source_back_low = 1, source_back_high = 6;
+ vector<int> matching = {11, 13, 15};
+ vector<int> chunklen = {1, 1, 1};
+ vector<pair<int, int> > source_gaps, target_gaps;
+ int num_symbols = 5;
+ EXPECT_TRUE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen,
+ source_low, source_high, target_low, target_high,
+ source_phrase_low, source_phrase_high,
+ source_back_low, source_back_high, 5, 10,
+ num_symbols, met_constraints));
+ EXPECT_FALSE(met_constraints);
+ vector<pair<int, int> > expected_gaps = {make_pair(2, 3), make_pair(4, 5)};
+ EXPECT_EQ(expected_gaps, source_gaps);
+ EXPECT_EQ(expected_gaps, target_gaps);
+
+ source_low = {-1, 1, 2, 3, 4, -1, 6};
+ source_high = {-1, 2, 3, 4, 5, -1, 7};
+ target_low = {-1, 1, 2, 3, 4, -1, 6};
+ target_high = {-1, 2, 3, 4, 5, -1, 7};
+ met_constraints = true;
+ source_gaps.clear();
+ target_gaps.clear();
+ EXPECT_TRUE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen,
+ source_low, source_high, target_low, target_high,
+ source_phrase_low, source_phrase_high,
+ source_back_low, source_back_high, 5, 10,
+ num_symbols, met_constraints));
+ EXPECT_FALSE(met_constraints);
+ EXPECT_EQ(expected_gaps, source_gaps);
+ EXPECT_EQ(expected_gaps, target_gaps);
+}
+
+TEST_F(RuleExtractorHelperTest, TestGetGapsWithExtensions) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+
+ bool met_constraints = true;
+ vector<int> source_low = {-1, 5, 2, 3, 4, 1, -1};
+ vector<int> source_high = {-1, 6, 3, 4, 5, 2, -1};
+ vector<int> target_low = {-1, 5, 2, 3, 4, 1, -1};
+ vector<int> target_high = {-1, 6, 3, 4, 5, 2, -1};
+ int source_phrase_low = 2, source_phrase_high = 5;
+ int source_back_low = 1, source_back_high = 6;
+ vector<int> matching = {12, 14};
+ vector<int> chunklen = {1, 1};
+ vector<pair<int, int> > source_gaps, target_gaps;
+ int num_symbols = 3;
+ EXPECT_TRUE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen,
+ source_low, source_high, target_low, target_high,
+ source_phrase_low, source_phrase_high,
+ source_back_low, source_back_high, 5, 10,
+ num_symbols, met_constraints));
+ vector<pair<int, int> > expected_source_gaps = {
+ make_pair(1, 2), make_pair(3, 4), make_pair(5, 6)
+ };
+ EXPECT_EQ(expected_source_gaps, source_gaps);
+ vector<pair<int, int> > expected_target_gaps = {
+ make_pair(5, 6), make_pair(3, 4), make_pair(1, 2)
+ };
+ EXPECT_EQ(expected_target_gaps, target_gaps);
+}
+
+TEST_F(RuleExtractorHelperTest, TestGetGaps) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+
+ bool met_constraints = true;
+ vector<int> source_low = {-1, 1, 4, 3, 2, 5, -1};
+ vector<int> source_high = {-1, 2, 5, 4, 3, 6, -1};
+ vector<int> target_low = {-1, 1, 4, 3, 2, 5, -1};
+ vector<int> target_high = {-1, 2, 5, 4, 3, 6, -1};
+ int source_phrase_low = 1, source_phrase_high = 6;
+ int source_back_low = 1, source_back_high = 6;
+ vector<int> matching = {11, 13, 15};
+ vector<int> chunklen = {1, 1, 1};
+ vector<pair<int, int> > source_gaps, target_gaps;
+ int num_symbols = 5;
+ EXPECT_TRUE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen,
+ source_low, source_high, target_low, target_high,
+ source_phrase_low, source_phrase_high,
+ source_back_low, source_back_high, 5, 10,
+ num_symbols, met_constraints));
+ vector<pair<int, int> > expected_source_gaps = {
+ make_pair(2, 3), make_pair(4, 5)
+ };
+ EXPECT_EQ(expected_source_gaps, source_gaps);
+ vector<pair<int, int> > expected_target_gaps = {
+ make_pair(4, 5), make_pair(2, 3)
+ };
+ EXPECT_EQ(expected_target_gaps, target_gaps);
+}
+
+TEST_F(RuleExtractorHelperTest, TestGetGapIntegrityChecksFailed) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+
+ bool met_constraints = true;
+ vector<int> source_low = {-1, 3, 2, 3, 4, 3, -1};
+ vector<int> source_high = {-1, 4, 3, 4, 5, 4, -1};
+ vector<int> target_low = {-1, -1, 2, 1, 4, -1, -1};
+ vector<int> target_high = {-1, -1, 3, 6, 5, -1, -1};
+ int source_phrase_low = 2, source_phrase_high = 5;
+ int source_back_low = 2, source_back_high = 5;
+ vector<int> matching = {12, 14};
+ vector<int> chunklen = {1, 1};
+ vector<pair<int, int> > source_gaps, target_gaps;
+ int num_symbols = 3;
+ EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen,
+ source_low, source_high, target_low, target_high,
+ source_phrase_low, source_phrase_high,
+ source_back_low, source_back_high, 5, 10,
+ num_symbols, met_constraints));
+}
+
+TEST_F(RuleExtractorHelperTest, TestGetSourceIndexes) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+
+ vector<int> matching = {13, 18, 21};
+ vector<int> chunklen = {3, 2, 1};
+ unordered_map<int, int> expected_indexes = {
+ {3, 1}, {4, 2}, {5, 3}, {8, 5}, {9, 6}, {11, 8}
+ };
+ EXPECT_EQ(expected_indexes, helper->GetSourceIndexes(matching, chunklen,
+ 1, 10));
+
+ matching = {12, 17};
+ chunklen = {2, 4};
+ expected_indexes = {{2, 0}, {3, 1}, {7, 3}, {8, 4}, {9, 5}, {10, 6}};
+ EXPECT_EQ(expected_indexes, helper->GetSourceIndexes(matching, chunklen,
+ 0, 10));
+}
+
+} // namespace
+} // namespace extractor
diff --git a/extractor/rule_extractor_test.cc b/extractor/rule_extractor_test.cc
new file mode 100644
index 00000000..5c1501c7
--- /dev/null
+++ b/extractor/rule_extractor_test.cc
@@ -0,0 +1,168 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "mocks/mock_alignment.h"
+#include "mocks/mock_data_array.h"
+#include "mocks/mock_rule_extractor_helper.h"
+#include "mocks/mock_scorer.h"
+#include "mocks/mock_target_phrase_extractor.h"
+#include "mocks/mock_vocabulary.h"
+#include "phrase.h"
+#include "phrase_builder.h"
+#include "phrase_location.h"
+#include "rule_extractor.h"
+#include "rule.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+class RuleExtractorTest : public Test {
+ protected:
+ virtual void SetUp() {
+ source_data_array = make_shared<MockDataArray>();
+ EXPECT_CALL(*source_data_array, GetSentenceId(_))
+ .WillRepeatedly(Return(0));
+ EXPECT_CALL(*source_data_array, GetSentenceStart(_))
+ .WillRepeatedly(Return(0));
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(10));
+
+ helper = make_shared<MockRuleExtractorHelper>();
+ EXPECT_CALL(*helper, CheckAlignedTerminals(_, _, _, _))
+ .WillRepeatedly(Return(true));
+ EXPECT_CALL(*helper, CheckTightPhrases(_, _, _, _))
+ .WillRepeatedly(Return(true));
+ unordered_map<int, int> source_indexes;
+ EXPECT_CALL(*helper, GetSourceIndexes(_, _, _, _))
+ .WillRepeatedly(Return(source_indexes));
+
+ vocabulary = make_shared<MockVocabulary>();
+ EXPECT_CALL(*vocabulary, GetTerminalValue(87))
+ .WillRepeatedly(Return("a"));
+ phrase_builder = make_shared<PhraseBuilder>(vocabulary);
+ vector<int> symbols = {87};
+ Phrase target_phrase = phrase_builder->Build(symbols);
+ PhraseAlignment phrase_alignment = {make_pair(0, 0)};
+
+ target_phrase_extractor = make_shared<MockTargetPhraseExtractor>();
+ vector<pair<Phrase, PhraseAlignment> > target_phrases = {
+ make_pair(target_phrase, phrase_alignment)
+ };
+ EXPECT_CALL(*target_phrase_extractor, ExtractPhrases(_, _, _, _, _, _))
+ .WillRepeatedly(Return(target_phrases));
+
+ scorer = make_shared<MockScorer>();
+ vector<double> scores = {0.3, 7.2};
+ EXPECT_CALL(*scorer, Score(_)).WillRepeatedly(Return(scores));
+
+ extractor = make_shared<RuleExtractor>(source_data_array, phrase_builder,
+ scorer, target_phrase_extractor, helper, 10, 1, 3, 5, false);
+ }
+
+ shared_ptr<MockDataArray> source_data_array;
+ shared_ptr<MockVocabulary> vocabulary;
+ shared_ptr<PhraseBuilder> phrase_builder;
+ shared_ptr<MockRuleExtractorHelper> helper;
+ shared_ptr<MockScorer> scorer;
+ shared_ptr<MockTargetPhraseExtractor> target_phrase_extractor;
+ shared_ptr<RuleExtractor> extractor;
+};
+
+TEST_F(RuleExtractorTest, TestExtractRulesAlignedTerminalsFail) {
+ vector<int> symbols = {87};
+ Phrase phrase = phrase_builder->Build(symbols);
+ vector<int> matching = {2};
+ PhraseLocation phrase_location(matching, 1);
+ EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1);
+ EXPECT_CALL(*helper, CheckAlignedTerminals(_, _, _, _))
+ .WillRepeatedly(Return(false));
+ vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
+ EXPECT_EQ(0, rules.size());
+}
+
+TEST_F(RuleExtractorTest, TestExtractRulesTightPhrasesFail) {
+ vector<int> symbols = {87};
+ Phrase phrase = phrase_builder->Build(symbols);
+ vector<int> matching = {2};
+ PhraseLocation phrase_location(matching, 1);
+ EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1);
+ EXPECT_CALL(*helper, CheckTightPhrases(_, _, _, _))
+ .WillRepeatedly(Return(false));
+ vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
+ EXPECT_EQ(0, rules.size());
+}
+
+TEST_F(RuleExtractorTest, TestExtractRulesNoFixPoint) {
+ vector<int> symbols = {87};
+ Phrase phrase = phrase_builder->Build(symbols);
+ vector<int> matching = {2};
+ PhraseLocation phrase_location(matching, 1);
+
+ EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1);
+ // Set FindFixPoint to return false.
+ vector<pair<int, int> > gaps;
+ helper->SetUp(0, 0, 0, 0, false, gaps, gaps, 0, true, true);
+
+ vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
+ EXPECT_EQ(0, rules.size());
+}
+
+TEST_F(RuleExtractorTest, TestExtractRulesGapsFail) {
+ vector<int> symbols = {87};
+ Phrase phrase = phrase_builder->Build(symbols);
+ vector<int> matching = {2};
+ PhraseLocation phrase_location(matching, 1);
+
+ EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1);
+ // Set CheckGaps to return false.
+ vector<pair<int, int> > gaps;
+ helper->SetUp(0, 0, 0, 0, true, gaps, gaps, 0, true, false);
+
+ vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
+ EXPECT_EQ(0, rules.size());
+}
+
+TEST_F(RuleExtractorTest, TestExtractRulesNoExtremities) {
+ vector<int> symbols = {87};
+ Phrase phrase = phrase_builder->Build(symbols);
+ vector<int> matching = {2};
+ PhraseLocation phrase_location(matching, 1);
+
+ EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1);
+ vector<pair<int, int> > gaps(3);
+ // Set FindFixPoint to return true. The number of gaps equals the number of
+ // nonterminals, so we won't add any extremities.
+ helper->SetUp(0, 0, 0, 0, true, gaps, gaps, 0, true, true);
+
+ vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
+ EXPECT_EQ(1, rules.size());
+}
+
+TEST_F(RuleExtractorTest, TestExtractRulesAddExtremities) {
+ vector<int> symbols = {87};
+ Phrase phrase = phrase_builder->Build(symbols);
+ vector<int> matching = {2};
+ PhraseLocation phrase_location(matching, 1);
+
+ vector<int> links(10, -1);
+ EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).WillOnce(DoAll(
+ SetArgReferee<0>(links),
+ SetArgReferee<1>(links),
+ SetArgReferee<2>(links),
+ SetArgReferee<3>(links)));
+
+ vector<pair<int, int> > gaps;
+ // Set FindFixPoint to return true. The number of gaps equals the number of
+ // nonterminals, so we won't add any extremities.
+ helper->SetUp(0, 0, 2, 3, true, gaps, gaps, 0, true, true);
+
+ vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
+ EXPECT_EQ(4, rules.size());
+}
+
+} // namespace
+} // namespace extractor
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc
new file mode 100644
index 00000000..8c30fb9e
--- /dev/null
+++ b/extractor/rule_factory.cc
@@ -0,0 +1,303 @@
+#include "rule_factory.h"
+
+#include <chrono>
+#include <memory>
+#include <queue>
+#include <vector>
+
+#include "grammar.h"
+#include "fast_intersector.h"
+#include "matchings_finder.h"
+#include "phrase.h"
+#include "phrase_builder.h"
+#include "rule.h"
+#include "rule_extractor.h"
+#include "sampler.h"
+#include "scorer.h"
+#include "suffix_array.h"
+#include "time_util.h"
+#include "vocabulary.h"
+
+using namespace std;
+using namespace chrono;
+
+namespace extractor {
+
+typedef high_resolution_clock Clock;
+
+struct State {
+ State(int start, int end, const vector<int>& phrase,
+ const vector<int>& subpatterns_start, shared_ptr<TrieNode> node,
+ bool starts_with_x) :
+ start(start), end(end), phrase(phrase),
+ subpatterns_start(subpatterns_start), node(node),
+ starts_with_x(starts_with_x) {}
+
+ int start, end;
+ vector<int> phrase, subpatterns_start;
+ shared_ptr<TrieNode> node;
+ bool starts_with_x;
+};
+
+HieroCachingRuleFactory::HieroCachingRuleFactory(
+ shared_ptr<SuffixArray> source_suffix_array,
+ shared_ptr<DataArray> target_data_array,
+ shared_ptr<Alignment> alignment,
+ const shared_ptr<Vocabulary>& vocabulary,
+ shared_ptr<Precomputation> precomputation,
+ shared_ptr<Scorer> scorer,
+ int min_gap_size,
+ int max_rule_span,
+ int max_nonterminals,
+ int max_rule_symbols,
+ int max_samples,
+ bool require_tight_phrases) :
+ vocabulary(vocabulary),
+ scorer(scorer),
+ min_gap_size(min_gap_size),
+ max_rule_span(max_rule_span),
+ max_nonterminals(max_nonterminals),
+ max_chunks(max_nonterminals + 1),
+ max_rule_symbols(max_rule_symbols) {
+ matchings_finder = make_shared<MatchingsFinder>(source_suffix_array);
+ fast_intersector = make_shared<FastIntersector>(source_suffix_array,
+ precomputation, vocabulary, max_rule_span, min_gap_size);
+ phrase_builder = make_shared<PhraseBuilder>(vocabulary);
+ rule_extractor = make_shared<RuleExtractor>(source_suffix_array->GetData(),
+ target_data_array, alignment, phrase_builder, scorer, vocabulary,
+ max_rule_span, min_gap_size, max_nonterminals, max_rule_symbols, true,
+ false, require_tight_phrases);
+ sampler = make_shared<Sampler>(source_suffix_array, max_samples);
+}
+
+HieroCachingRuleFactory::HieroCachingRuleFactory(
+ shared_ptr<MatchingsFinder> finder,
+ shared_ptr<FastIntersector> fast_intersector,
+ shared_ptr<PhraseBuilder> phrase_builder,
+ shared_ptr<RuleExtractor> rule_extractor,
+ shared_ptr<Vocabulary> vocabulary,
+ shared_ptr<Sampler> sampler,
+ shared_ptr<Scorer> scorer,
+ int min_gap_size,
+ int max_rule_span,
+ int max_nonterminals,
+ int max_chunks,
+ int max_rule_symbols) :
+ matchings_finder(finder),
+ fast_intersector(fast_intersector),
+ phrase_builder(phrase_builder),
+ rule_extractor(rule_extractor),
+ vocabulary(vocabulary),
+ sampler(sampler),
+ scorer(scorer),
+ min_gap_size(min_gap_size),
+ max_rule_span(max_rule_span),
+ max_nonterminals(max_nonterminals),
+ max_chunks(max_chunks),
+ max_rule_symbols(max_rule_symbols) {}
+
+HieroCachingRuleFactory::HieroCachingRuleFactory() {}
+
+HieroCachingRuleFactory::~HieroCachingRuleFactory() {}
+
+Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
+ Clock::time_point start_time = Clock::now();
+ double total_extract_time = 0;
+ double total_intersect_time = 0;
+ double total_lookup_time = 0;
+
+ MatchingsTrie trie;
+ shared_ptr<TrieNode> root = trie.GetRoot();
+
+ int first_x = vocabulary->GetNonterminalIndex(1);
+ shared_ptr<TrieNode> x_root(new TrieNode(root));
+ root->AddChild(first_x, x_root);
+
+ queue<State> states;
+ for (size_t i = 0; i < word_ids.size(); ++i) {
+ states.push(State(i, i, vector<int>(), vector<int>(1, i), root, false));
+ }
+ for (size_t i = min_gap_size; i < word_ids.size(); ++i) {
+ states.push(State(i - min_gap_size, i, vector<int>(1, first_x),
+ vector<int>(1, i), x_root, true));
+ }
+
+ vector<Rule> rules;
+ while (!states.empty()) {
+ State state = states.front();
+ states.pop();
+
+ shared_ptr<TrieNode> node = state.node;
+ vector<int> phrase = state.phrase;
+ int word_id = word_ids[state.end];
+ phrase.push_back(word_id);
+ Phrase next_phrase = phrase_builder->Build(phrase);
+ shared_ptr<TrieNode> next_node;
+
+ if (CannotHaveMatchings(node, word_id)) {
+ if (!node->HasChild(word_id)) {
+ node->AddChild(word_id, shared_ptr<TrieNode>());
+ }
+ continue;
+ }
+
+ if (RequiresLookup(node, word_id)) {
+ shared_ptr<TrieNode> next_suffix_link = node->suffix_link == NULL ?
+ trie.GetRoot() : node->suffix_link->GetChild(word_id);
+ if (state.starts_with_x) {
+ // If the phrase starts with a non terminal, we simply use the matchings
+ // from the suffix link.
+ next_node = make_shared<TrieNode>(
+ next_suffix_link, next_phrase, next_suffix_link->matchings);
+ } else {
+ PhraseLocation phrase_location;
+ if (next_phrase.Arity() > 0) {
+ // For phrases containing a nonterminal, we use either the occurrences
+ // of the prefix or the suffix to determine the occurrences of the
+ // phrase.
+ Clock::time_point intersect_start = Clock::now();
+ phrase_location = fast_intersector->Intersect(
+ node->matchings, next_suffix_link->matchings, next_phrase);
+ Clock::time_point intersect_stop = Clock::now();
+ total_intersect_time += GetDuration(intersect_start, intersect_stop);
+ } else {
+ // For phrases not containing any nonterminals, we simply query the
+ // suffix array using the suffix array range of the prefix as a
+ // starting point.
+ Clock::time_point lookup_start = Clock::now();
+ phrase_location = matchings_finder->Find(
+ node->matchings,
+ vocabulary->GetTerminalValue(word_id),
+ state.phrase.size());
+ Clock::time_point lookup_stop = Clock::now();
+ total_lookup_time += GetDuration(lookup_start, lookup_stop);
+ }
+
+ if (phrase_location.IsEmpty()) {
+ continue;
+ }
+
+ // Create new trie node to store data about the current phrase.
+ next_node = make_shared<TrieNode>(
+ next_suffix_link, next_phrase, phrase_location);
+ }
+ // Add the new trie node to the trie cache.
+ node->AddChild(word_id, next_node);
+
+ // Automatically adds a trailing non terminal if allowed. Simply copy the
+ // matchings from the prefix node.
+ AddTrailingNonterminal(phrase, next_phrase, next_node,
+ state.starts_with_x);
+
+ Clock::time_point extract_start = Clock::now();
+ if (!state.starts_with_x) {
+ // Extract rules for the sampled set of occurrences.
+ PhraseLocation sample = sampler->Sample(next_node->matchings);
+ vector<Rule> new_rules =
+ rule_extractor->ExtractRules(next_phrase, sample);
+ rules.insert(rules.end(), new_rules.begin(), new_rules.end());
+ }
+ Clock::time_point extract_stop = Clock::now();
+ total_extract_time += GetDuration(extract_start, extract_stop);
+ } else {
+ next_node = node->GetChild(word_id);
+ }
+
+ // Create more states (phrases) to be analyzed.
+ vector<State> new_states = ExtendState(word_ids, state, phrase, next_phrase,
+ next_node);
+ for (State new_state: new_states) {
+ states.push(new_state);
+ }
+ }
+
+ Clock::time_point stop_time = Clock::now();
+ #pragma omp critical (stderr_write)
+ {
+ cerr << "Total time for rule lookup, extraction, and scoring = "
+ << GetDuration(start_time, stop_time) << " seconds" << endl;
+ cerr << "Extract time = " << total_extract_time << " seconds" << endl;
+ cerr << "Intersect time = " << total_intersect_time << " seconds" << endl;
+ cerr << "Lookup time = " << total_lookup_time << " seconds" << endl;
+ }
+ return Grammar(rules, scorer->GetFeatureNames());
+}
+
+bool HieroCachingRuleFactory::CannotHaveMatchings(
+ shared_ptr<TrieNode> node, int word_id) {
+ if (node->HasChild(word_id) && node->GetChild(word_id) == NULL) {
+ return true;
+ }
+
+ shared_ptr<TrieNode> suffix_link = node->suffix_link;
+ return suffix_link != NULL && suffix_link->GetChild(word_id) == NULL;
+}
+
+bool HieroCachingRuleFactory::RequiresLookup(
+ shared_ptr<TrieNode> node, int word_id) {
+ return !node->HasChild(word_id);
+}
+
+void HieroCachingRuleFactory::AddTrailingNonterminal(
+ vector<int> symbols,
+ const Phrase& prefix,
+ const shared_ptr<TrieNode>& prefix_node,
+ bool starts_with_x) {
+ if (prefix.Arity() >= max_nonterminals) {
+ return;
+ }
+
+ int var_id = vocabulary->GetNonterminalIndex(prefix.Arity() + 1);
+ symbols.push_back(var_id);
+ Phrase var_phrase = phrase_builder->Build(symbols);
+
+ int suffix_var_id = vocabulary->GetNonterminalIndex(
+ prefix.Arity() + (starts_with_x == 0));
+ shared_ptr<TrieNode> var_suffix_link =
+ prefix_node->suffix_link->GetChild(suffix_var_id);
+
+ prefix_node->AddChild(var_id, make_shared<TrieNode>(
+ var_suffix_link, var_phrase, prefix_node->matchings));
+}
+
+vector<State> HieroCachingRuleFactory::ExtendState(
+ const vector<int>& word_ids,
+ const State& state,
+ vector<int> symbols,
+ const Phrase& phrase,
+ const shared_ptr<TrieNode>& node) {
+ int span = state.end - state.start;
+ vector<State> new_states;
+ if (symbols.size() >= max_rule_symbols || state.end + 1 >= word_ids.size() ||
+ span >= max_rule_span) {
+ return new_states;
+ }
+
+ // New state for adding the next symbol.
+ new_states.push_back(State(state.start, state.end + 1, symbols,
+ state.subpatterns_start, node, state.starts_with_x));
+
+ int num_subpatterns = phrase.Arity() + (state.starts_with_x == 0);
+ if (symbols.size() + 1 >= max_rule_symbols ||
+ phrase.Arity() >= max_nonterminals ||
+ num_subpatterns >= max_chunks) {
+ return new_states;
+ }
+
+ // New states for adding a nonterminal followed by a new symbol.
+ int var_id = vocabulary->GetNonterminalIndex(phrase.Arity() + 1);
+ symbols.push_back(var_id);
+ vector<int> subpatterns_start = state.subpatterns_start;
+ size_t i = state.end + 1 + min_gap_size;
+ while (i < word_ids.size() && i - state.start <= max_rule_span) {
+ subpatterns_start.push_back(i);
+ new_states.push_back(State(state.start, i, symbols, subpatterns_start,
+ node->GetChild(var_id), state.starts_with_x));
+ subpatterns_start.pop_back();
+ ++i;
+ }
+
+ return new_states;
+}
+
+} // namespace extractor
diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h
new file mode 100644
index 00000000..52e8712a
--- /dev/null
+++ b/extractor/rule_factory.h
@@ -0,0 +1,118 @@
+#ifndef _RULE_FACTORY_H_
+#define _RULE_FACTORY_H_
+
+#include <memory>
+#include <vector>
+
+#include "matchings_trie.h"
+
+using namespace std;
+
+namespace extractor {
+
+class Alignment;
+class DataArray;
+class FastIntersector;
+class Grammar;
+class MatchingsFinder;
+class PhraseBuilder;
+class Precomputation;
+class Rule;
+class RuleExtractor;
+class Sampler;
+class Scorer;
+class State;
+class SuffixArray;
+class Vocabulary;
+
+/**
+ * Component containing most of the logic for extracting SCFG rules for a given
+ * sentence.
+ *
+ * Given a sentence (as a vector of word ids), this class constructs all the
+ * possible source phrases starting from this sentence. For each source phrase,
+ * it finds all its occurrences in the source data and samples some of these
+ * occurrences to extract aligned source-target phrase pairs. A trie cache is
+ * used to avoid unnecessary computations if a source phrase can be constructed
+ * more than once (e.g. some words occur more than once in the sentence).
+ */
+class HieroCachingRuleFactory {
+ public:
+ HieroCachingRuleFactory(
+ shared_ptr<SuffixArray> source_suffix_array,
+ shared_ptr<DataArray> target_data_array,
+ shared_ptr<Alignment> alignment,
+ const shared_ptr<Vocabulary>& vocabulary,
+ shared_ptr<Precomputation> precomputation,
+ shared_ptr<Scorer> scorer,
+ int min_gap_size,
+ int max_rule_span,
+ int max_nonterminals,
+ int max_rule_symbols,
+ int max_samples,
+ bool require_tight_phrases);
+
+ // For testing only.
+ HieroCachingRuleFactory(
+ shared_ptr<MatchingsFinder> finder,
+ shared_ptr<FastIntersector> fast_intersector,
+ shared_ptr<PhraseBuilder> phrase_builder,
+ shared_ptr<RuleExtractor> rule_extractor,
+ shared_ptr<Vocabulary> vocabulary,
+ shared_ptr<Sampler> sampler,
+ shared_ptr<Scorer> scorer,
+ int min_gap_size,
+ int max_rule_span,
+ int max_nonterminals,
+ int max_chunks,
+ int max_rule_symbols);
+
+ virtual ~HieroCachingRuleFactory();
+
+ // Constructs SCFG rules for a given sentence.
+ // (See class description for more details.)
+ virtual Grammar GetGrammar(const vector<int>& word_ids);
+
+ protected:
+ HieroCachingRuleFactory();
+
+ private:
+ // Checks if the phrase (if previously encountered) or its prefix have any
+ // occurrences in the source data.
+ bool CannotHaveMatchings(shared_ptr<TrieNode> node, int word_id);
+
+ // Checks if the phrase has previously been analyzed.
+ bool RequiresLookup(shared_ptr<TrieNode> node, int word_id);
+
+ // Creates a new state in the trie that corresponds to adding a trailing
+ // nonterminal to the current phrase.
+ void AddTrailingNonterminal(vector<int> symbols,
+ const Phrase& prefix,
+ const shared_ptr<TrieNode>& prefix_node,
+ bool starts_with_x);
+
+ // Extends the current state by possibly adding a nonterminal followed by a
+ // terminal.
+ vector<State> ExtendState(const vector<int>& word_ids,
+ const State& state,
+ vector<int> symbols,
+ const Phrase& phrase,
+ const shared_ptr<TrieNode>& node);
+
+ shared_ptr<MatchingsFinder> matchings_finder;
+ shared_ptr<FastIntersector> fast_intersector;
+ shared_ptr<PhraseBuilder> phrase_builder;
+ shared_ptr<RuleExtractor> rule_extractor;
+ shared_ptr<Vocabulary> vocabulary;
+ shared_ptr<Sampler> sampler;
+ shared_ptr<Scorer> scorer;
+ int min_gap_size;
+ int max_rule_span;
+ int max_nonterminals;
+ int max_chunks;
+ int max_rule_symbols;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc
new file mode 100644
index 00000000..2129dfa0
--- /dev/null
+++ b/extractor/rule_factory_test.cc
@@ -0,0 +1,103 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "grammar.h"
+#include "mocks/mock_fast_intersector.h"
+#include "mocks/mock_matchings_finder.h"
+#include "mocks/mock_rule_extractor.h"
+#include "mocks/mock_sampler.h"
+#include "mocks/mock_scorer.h"
+#include "mocks/mock_vocabulary.h"
+#include "phrase_builder.h"
+#include "phrase_location.h"
+#include "rule_factory.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+class RuleFactoryTest : public Test {
+ protected:
+ virtual void SetUp() {
+ finder = make_shared<MockMatchingsFinder>();
+ fast_intersector = make_shared<MockFastIntersector>();
+
+ vocabulary = make_shared<MockVocabulary>();
+ EXPECT_CALL(*vocabulary, GetTerminalValue(2)).WillRepeatedly(Return("a"));
+ EXPECT_CALL(*vocabulary, GetTerminalValue(3)).WillRepeatedly(Return("b"));
+ EXPECT_CALL(*vocabulary, GetTerminalValue(4)).WillRepeatedly(Return("c"));
+
+ phrase_builder = make_shared<PhraseBuilder>(vocabulary);
+
+ scorer = make_shared<MockScorer>();
+ feature_names = {"f1"};
+ EXPECT_CALL(*scorer, GetFeatureNames())
+ .WillRepeatedly(Return(feature_names));
+
+ sampler = make_shared<MockSampler>();
+ EXPECT_CALL(*sampler, Sample(_))
+ .WillRepeatedly(Return(PhraseLocation(0, 1)));
+
+ Phrase phrase;
+ vector<double> scores = {0.5};
+ vector<pair<int, int> > phrase_alignment = {make_pair(0, 0)};
+ vector<Rule> rules = {Rule(phrase, phrase, scores, phrase_alignment)};
+ extractor = make_shared<MockRuleExtractor>();
+ EXPECT_CALL(*extractor, ExtractRules(_, _))
+ .WillRepeatedly(Return(rules));
+ }
+
+ vector<string> feature_names;
+ shared_ptr<MockMatchingsFinder> finder;
+ shared_ptr<MockFastIntersector> fast_intersector;
+ shared_ptr<MockVocabulary> vocabulary;
+ shared_ptr<PhraseBuilder> phrase_builder;
+ shared_ptr<MockScorer> scorer;
+ shared_ptr<MockSampler> sampler;
+ shared_ptr<MockRuleExtractor> extractor;
+ shared_ptr<HieroCachingRuleFactory> factory;
+};
+
+TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) {
+ factory = make_shared<HieroCachingRuleFactory>(finder, fast_intersector,
+ phrase_builder, extractor, vocabulary, sampler, scorer, 1, 10, 2, 3, 5);
+
+ EXPECT_CALL(*finder, Find(_, _, _))
+ .Times(6)
+ .WillRepeatedly(Return(PhraseLocation(0, 1)));
+
+ EXPECT_CALL(*fast_intersector, Intersect(_, _, _))
+ .Times(1)
+ .WillRepeatedly(Return(PhraseLocation(0, 1)));
+
+ vector<int> word_ids = {2, 3, 4};
+ Grammar grammar = factory->GetGrammar(word_ids);
+ EXPECT_EQ(feature_names, grammar.GetFeatureNames());
+ EXPECT_EQ(7, grammar.GetRules().size());
+}
+
+TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) {
+ factory = make_shared<HieroCachingRuleFactory>(finder, fast_intersector,
+ phrase_builder, extractor, vocabulary, sampler, scorer, 1, 10, 2, 3, 5);
+
+ EXPECT_CALL(*finder, Find(_, _, _))
+ .Times(12)
+ .WillRepeatedly(Return(PhraseLocation(0, 1)));
+
+ EXPECT_CALL(*fast_intersector, Intersect(_, _, _))
+ .Times(16)
+ .WillRepeatedly(Return(PhraseLocation(0, 1)));
+
+ vector<int> word_ids = {2, 3, 4, 2, 3};
+ Grammar grammar = factory->GetGrammar(word_ids);
+ EXPECT_EQ(feature_names, grammar.GetFeatureNames());
+ EXPECT_EQ(28, grammar.GetRules().size());
+}
+
+} // namespace
+} // namespace extractor
diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc
new file mode 100644
index 00000000..aec83e3b
--- /dev/null
+++ b/extractor/run_extractor.cc
@@ -0,0 +1,242 @@
+#include <chrono>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <omp.h>
+#include <boost/filesystem.hpp>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "alignment.h"
+#include "data_array.h"
+#include "features/count_source_target.h"
+#include "features/feature.h"
+#include "features/is_source_singleton.h"
+#include "features/is_source_target_singleton.h"
+#include "features/max_lex_source_given_target.h"
+#include "features/max_lex_target_given_source.h"
+#include "features/sample_source_count.h"
+#include "features/target_given_source_coherent.h"
+#include "grammar.h"
+#include "grammar_extractor.h"
+#include "precomputation.h"
+#include "rule.h"
+#include "scorer.h"
+#include "suffix_array.h"
+#include "time_util.h"
+#include "translation_table.h"
+
+namespace fs = boost::filesystem;
+namespace po = boost::program_options;
+using namespace std;
+using namespace extractor;
+using namespace features;
+
+// Returns the file path in which a given grammar should be written.
+fs::path GetGrammarFilePath(const fs::path& grammar_path, int file_number) {
+ string file_name = "grammar." + to_string(file_number);
+ return grammar_path / file_name;
+}
+
+int main(int argc, char** argv) {
+ int num_threads_default = 1;
+ #pragma omp parallel
+ num_threads_default = omp_get_num_threads();
+
+ // Sets up the command line arguments map.
+ po::options_description desc("Command line options");
+ desc.add_options()
+ ("help,h", "Show available options")
+ ("source,f", po::value<string>(), "Source language corpus")
+ ("target,e", po::value<string>(), "Target language corpus")
+ ("bitext,b", po::value<string>(), "Parallel text (source ||| target)")
+ ("alignment,a", po::value<string>()->required(), "Bitext word alignment")
+ ("grammars,g", po::value<string>()->required(), "Grammars output path")
+ ("threads,t", po::value<int>()->default_value(num_threads_default),
+ "Number of parallel extractors")
+ ("frequent", po::value<int>()->default_value(100),
+ "Number of precomputed frequent patterns")
+ ("super_frequent", po::value<int>()->default_value(10),
+ "Number of precomputed super frequent patterns")
+ ("max_rule_span", po::value<int>()->default_value(15),
+ "Maximum rule span")
+ ("max_rule_symbols", po::value<int>()->default_value(5),
+ "Maximum number of symbols (terminals + nontermals) in a rule")
+ ("min_gap_size", po::value<int>()->default_value(1), "Minimum gap size")
+ ("max_phrase_len", po::value<int>()->default_value(4),
+ "Maximum frequent phrase length")
+ ("max_nonterminals", po::value<int>()->default_value(2),
+ "Maximum number of nonterminals in a rule")
+ ("min_frequency", po::value<int>()->default_value(1000),
+ "Minimum number of occurrences for a pharse to be considered frequent")
+ ("max_samples", po::value<int>()->default_value(300),
+ "Maximum number of samples")
+ ("tight_phrases", po::value<bool>()->default_value(true),
+ "False if phrases may be loose (better, but slower)");
+
+ po::variables_map vm;
+ po::store(po::parse_command_line(argc, argv, desc), vm);
+
+ // Checks for the help option before calling notify, so the we don't get an
+ // exception for missing required arguments.
+ if (vm.count("help")) {
+ cout << desc << endl;
+ return 0;
+ }
+
+ po::notify(vm);
+
+ if (!((vm.count("source") && vm.count("target")) || vm.count("bitext"))) {
+ cerr << "A paralel corpus is required. "
+ << "Use -f (source) with -e (target) or -b (bitext)."
+ << endl;
+ return 1;
+ }
+
+ int num_threads = vm["threads"].as<int>();
+ cout << "Grammar extraction will use " << num_threads << " threads." << endl;
+
+ // Reads the parallel corpus.
+ Clock::time_point preprocess_start_time = Clock::now();
+ cerr << "Reading source and target data..." << endl;
+ Clock::time_point start_time = Clock::now();
+ shared_ptr<DataArray> source_data_array, target_data_array;
+ if (vm.count("bitext")) {
+ source_data_array = make_shared<DataArray>(
+ vm["bitext"].as<string>(), SOURCE);
+ target_data_array = make_shared<DataArray>(
+ vm["bitext"].as<string>(), TARGET);
+ } else {
+ source_data_array = make_shared<DataArray>(vm["source"].as<string>());
+ target_data_array = make_shared<DataArray>(vm["target"].as<string>());
+ }
+ Clock::time_point stop_time = Clock::now();
+ cerr << "Reading data took " << GetDuration(start_time, stop_time)
+ << " seconds" << endl;
+
+ // Constructs the suffix array for the source data.
+ cerr << "Creating source suffix array..." << endl;
+ start_time = Clock::now();
+ shared_ptr<SuffixArray> source_suffix_array =
+ make_shared<SuffixArray>(source_data_array);
+ stop_time = Clock::now();
+ cerr << "Creating suffix array took "
+ << GetDuration(start_time, stop_time) << " seconds" << endl;
+
+ // Reads the alignment.
+ cerr << "Reading alignment..." << endl;
+ start_time = Clock::now();
+ shared_ptr<Alignment> alignment =
+ make_shared<Alignment>(vm["alignment"].as<string>());
+ stop_time = Clock::now();
+ cerr << "Reading alignment took "
+ << GetDuration(start_time, stop_time) << " seconds" << endl;
+
+ // Constructs an index storing the occurrences in the source data for each
+ // frequent collocation.
+ cerr << "Precomputing collocations..." << endl;
+ start_time = Clock::now();
+ shared_ptr<Precomputation> precomputation = make_shared<Precomputation>(
+ source_suffix_array,
+ vm["frequent"].as<int>(),
+ vm["super_frequent"].as<int>(),
+ vm["max_rule_span"].as<int>(),
+ vm["max_rule_symbols"].as<int>(),
+ vm["min_gap_size"].as<int>(),
+ vm["max_phrase_len"].as<int>(),
+ vm["min_frequency"].as<int>());
+ stop_time = Clock::now();
+ cerr << "Precomputing collocations took "
+ << GetDuration(start_time, stop_time) << " seconds" << endl;
+
+ // Constructs a table storing p(e | f) and p(f | e) for every pair of source
+ // and target words.
+ cerr << "Precomputing conditional probabilities..." << endl;
+ start_time = Clock::now();
+ shared_ptr<TranslationTable> table = make_shared<TranslationTable>(
+ source_data_array, target_data_array, alignment);
+ stop_time = Clock::now();
+ cerr << "Precomputing conditional probabilities took "
+ << GetDuration(start_time, stop_time) << " seconds" << endl;
+
+ Clock::time_point preprocess_stop_time = Clock::now();
+ cerr << "Overall preprocessing step took "
+ << GetDuration(preprocess_start_time, preprocess_stop_time)
+ << " seconds" << endl;
+
+ // Features used to score each grammar rule.
+ Clock::time_point extraction_start_time = Clock::now();
+ vector<shared_ptr<Feature> > features = {
+ make_shared<TargetGivenSourceCoherent>(),
+ make_shared<SampleSourceCount>(),
+ make_shared<CountSourceTarget>(),
+ make_shared<MaxLexSourceGivenTarget>(table),
+ make_shared<MaxLexTargetGivenSource>(table),
+ make_shared<IsSourceSingleton>(),
+ make_shared<IsSourceTargetSingleton>()
+ };
+ shared_ptr<Scorer> scorer = make_shared<Scorer>(features);
+
+ // Sets up the grammar extractor.
+ GrammarExtractor extractor(
+ source_suffix_array,
+ target_data_array,
+ alignment,
+ precomputation,
+ scorer,
+ vm["min_gap_size"].as<int>(),
+ vm["max_rule_span"].as<int>(),
+ vm["max_nonterminals"].as<int>(),
+ vm["max_rule_symbols"].as<int>(),
+ vm["max_samples"].as<int>(),
+ vm["tight_phrases"].as<bool>());
+
+ // Releases extra memory used by the initial precomputation.
+ precomputation.reset();
+
+ // Creates the grammars directory if it doesn't exist.
+ fs::path grammar_path = vm["grammars"].as<string>();
+ if (!fs::is_directory(grammar_path)) {
+ fs::create_directory(grammar_path);
+ }
+
+ // Reads all sentences for which we extract grammar rules (the paralellization
+ // is simplified if we read all sentences upfront).
+ string sentence;
+ vector<string> sentences;
+ while (getline(cin, sentence)) {
+ sentences.push_back(sentence);
+ }
+
+ // Extracts the grammar for each sentence and saves it to a file.
+ vector<string> suffixes(sentences.size());
+ #pragma omp parallel for schedule(dynamic) num_threads(num_threads)
+ for (size_t i = 0; i < sentences.size(); ++i) {
+ string suffix;
+ int position = sentences[i].find("|||");
+ if (position != sentences[i].npos) {
+ suffix = sentences[i].substr(position);
+ sentences[i] = sentences[i].substr(0, position);
+ }
+ suffixes[i] = suffix;
+
+ Grammar grammar = extractor.GetGrammar(sentences[i]);
+ ofstream output(GetGrammarFilePath(grammar_path, i).c_str());
+ output << grammar;
+ }
+
+ for (size_t i = 0; i < sentences.size(); ++i) {
+ cout << "<seg grammar=\"" << GetGrammarFilePath(grammar_path, i) << "\" id=\""
+ << i << "\"> " << sentences[i] << " </seg> " << suffixes[i] << endl;
+ }
+
+ Clock::time_point extraction_stop_time = Clock::now();
+ cerr << "Overall extraction step took "
+ << GetDuration(extraction_start_time, extraction_stop_time)
+ << " seconds" << endl;
+
+ return 0;
+}
diff --git a/extractor/sample_alignment.txt b/extractor/sample_alignment.txt
new file mode 100644
index 00000000..80b446a4
--- /dev/null
+++ b/extractor/sample_alignment.txt
@@ -0,0 +1,2 @@
+0-0 1-1 2-2
+1-0 2-1
diff --git a/extractor/sample_bitext.txt b/extractor/sample_bitext.txt
new file mode 100644
index 00000000..93d6b39d
--- /dev/null
+++ b/extractor/sample_bitext.txt
@@ -0,0 +1,2 @@
+ana are mere . ||| anna has apples .
+ana bea mult lapte . ||| anna drinks a lot of milk .
diff --git a/extractor/sampler.cc b/extractor/sampler.cc
new file mode 100644
index 00000000..d81956b5
--- /dev/null
+++ b/extractor/sampler.cc
@@ -0,0 +1,46 @@
+#include "sampler.h"
+
+#include "phrase_location.h"
+#include "suffix_array.h"
+
+namespace extractor {
+
+Sampler::Sampler(shared_ptr<SuffixArray> suffix_array, int max_samples) :
+ suffix_array(suffix_array), max_samples(max_samples) {}
+
+Sampler::Sampler() {}
+
+Sampler::~Sampler() {}
+
+PhraseLocation Sampler::Sample(const PhraseLocation& location) const {
+ vector<int> sample;
+ int num_subpatterns;
+ if (location.matchings == NULL) {
+ // Sample suffix array range.
+ num_subpatterns = 1;
+ int low = location.sa_low, high = location.sa_high;
+ double step = max(1.0, (double) (high - low) / max_samples);
+ for (double i = low; i < high && sample.size() < max_samples; i += step) {
+ sample.push_back(suffix_array->GetSuffix(Round(i)));
+ }
+ } else {
+ // Sample vector of occurrences.
+ num_subpatterns = location.num_subpatterns;
+ int num_matchings = location.matchings->size() / num_subpatterns;
+ double step = max(1.0, (double) num_matchings / max_samples);
+ for (double i = 0, num_samples = 0;
+ i < num_matchings && num_samples < max_samples;
+ i += step, ++num_samples) {
+ int start = Round(i) * num_subpatterns;
+ sample.insert(sample.end(), location.matchings->begin() + start,
+ location.matchings->begin() + start + num_subpatterns);
+ }
+ }
+ return PhraseLocation(sample, num_subpatterns);
+}
+
+int Sampler::Round(double x) const {
+ return x + 0.5;
+}
+
+} // namespace extractor
diff --git a/extractor/sampler.h b/extractor/sampler.h
new file mode 100644
index 00000000..be4aa1bb
--- /dev/null
+++ b/extractor/sampler.h
@@ -0,0 +1,38 @@
+#ifndef _SAMPLER_H_
+#define _SAMPLER_H_
+
+#include <memory>
+
+using namespace std;
+
+namespace extractor {
+
+class PhraseLocation;
+class SuffixArray;
+
+/**
+ * Provides uniform sampling for a PhraseLocation.
+ */
+class Sampler {
+ public:
+ Sampler(shared_ptr<SuffixArray> suffix_array, int max_samples);
+
+ virtual ~Sampler();
+
+ // Samples uniformly at most max_samples phrase occurrences.
+ virtual PhraseLocation Sample(const PhraseLocation& location) const;
+
+ protected:
+ Sampler();
+
+ private:
+ // Round floating point number to the nearest integer.
+ int Round(double x) const;
+
+ shared_ptr<SuffixArray> suffix_array;
+ int max_samples;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc
new file mode 100644
index 00000000..e9abebfa
--- /dev/null
+++ b/extractor/sampler_test.cc
@@ -0,0 +1,74 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "mocks/mock_suffix_array.h"
+#include "phrase_location.h"
+#include "sampler.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+class SamplerTest : public Test {
+ protected:
+ virtual void SetUp() {
+ suffix_array = make_shared<MockSuffixArray>();
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i));
+ }
+ }
+
+ shared_ptr<MockSuffixArray> suffix_array;
+ shared_ptr<Sampler> sampler;
+};
+
+TEST_F(SamplerTest, TestSuffixArrayRange) {
+ PhraseLocation location(0, 10);
+
+ sampler = make_shared<Sampler>(suffix_array, 1);
+ vector<int> expected_locations = {0};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+
+ sampler = make_shared<Sampler>(suffix_array, 2);
+ expected_locations = {0, 5};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+
+ sampler = make_shared<Sampler>(suffix_array, 3);
+ expected_locations = {0, 3, 7};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+
+ sampler = make_shared<Sampler>(suffix_array, 4);
+ expected_locations = {0, 3, 5, 8};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+
+ sampler = make_shared<Sampler>(suffix_array, 100);
+ expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+}
+
+TEST_F(SamplerTest, TestSubstringsSample) {
+ vector<int> locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ PhraseLocation location(locations, 2);
+
+ sampler = make_shared<Sampler>(suffix_array, 1);
+ vector<int> expected_locations = {0, 1};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location));
+
+ sampler = make_shared<Sampler>(suffix_array, 2);
+ expected_locations = {0, 1, 6, 7};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location));
+
+ sampler = make_shared<Sampler>(suffix_array, 3);
+ expected_locations = {0, 1, 4, 5, 6, 7};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location));
+
+ sampler = make_shared<Sampler>(suffix_array, 7);
+ expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location));
+}
+
+} // namespace
+} // namespace extractor
diff --git a/extractor/scorer.cc b/extractor/scorer.cc
new file mode 100644
index 00000000..d3ebf1c9
--- /dev/null
+++ b/extractor/scorer.cc
@@ -0,0 +1,30 @@
+#include "scorer.h"
+
+#include "features/feature.h"
+
+namespace extractor {
+
+Scorer::Scorer(const vector<shared_ptr<features::Feature> >& features) :
+ features(features) {}
+
+Scorer::Scorer() {}
+
+Scorer::~Scorer() {}
+
+vector<double> Scorer::Score(const features::FeatureContext& context) const {
+ vector<double> scores;
+ for (auto feature: features) {
+ scores.push_back(feature->Score(context));
+ }
+ return scores;
+}
+
+vector<string> Scorer::GetFeatureNames() const {
+ vector<string> feature_names;
+ for (auto feature: features) {
+ feature_names.push_back(feature->GetName());
+ }
+ return feature_names;
+}
+
+} // namespace extractor
diff --git a/extractor/scorer.h b/extractor/scorer.h
new file mode 100644
index 00000000..af8a3b10
--- /dev/null
+++ b/extractor/scorer.h
@@ -0,0 +1,41 @@
+#ifndef _SCORER_H_
+#define _SCORER_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+using namespace std;
+
+namespace extractor {
+
+namespace features {
+ class Feature;
+ class FeatureContext;
+} // namespace features
+
+/**
+ * Computes the feature scores for a source-target phrase pair.
+ */
+class Scorer {
+ public:
+ Scorer(const vector<shared_ptr<features::Feature> >& features);
+
+ virtual ~Scorer();
+
+ // Computes the feature score for the given context.
+ virtual vector<double> Score(const features::FeatureContext& context) const;
+
+ // Returns the set of feature names used to score any context.
+ virtual vector<string> GetFeatureNames() const;
+
+ protected:
+ Scorer();
+
+ private:
+ vector<shared_ptr<features::Feature> > features;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/scorer_test.cc b/extractor/scorer_test.cc
new file mode 100644
index 00000000..3a09c9cc
--- /dev/null
+++ b/extractor/scorer_test.cc
@@ -0,0 +1,49 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "mocks/mock_feature.h"
+#include "scorer.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+class ScorerTest : public Test {
+ protected:
+ virtual void SetUp() {
+ feature1 = make_shared<features::MockFeature>();
+ EXPECT_CALL(*feature1, Score(_)).WillRepeatedly(Return(0.5));
+ EXPECT_CALL(*feature1, GetName()).WillRepeatedly(Return("f1"));
+
+ feature2 = make_shared<features::MockFeature>();
+ EXPECT_CALL(*feature2, Score(_)).WillRepeatedly(Return(-1.3));
+ EXPECT_CALL(*feature2, GetName()).WillRepeatedly(Return("f2"));
+
+ vector<shared_ptr<features::Feature> > features = {feature1, feature2};
+ scorer = make_shared<Scorer>(features);
+ }
+
+ shared_ptr<features::MockFeature> feature1;
+ shared_ptr<features::MockFeature> feature2;
+ shared_ptr<Scorer> scorer;
+};
+
+TEST_F(ScorerTest, TestScore) {
+ vector<double> expected_scores = {0.5, -1.3};
+ Phrase phrase;
+ features::FeatureContext context(phrase, phrase, 0.3, 2, 11);
+ EXPECT_EQ(expected_scores, scorer->Score(context));
+}
+
+TEST_F(ScorerTest, TestGetNames) {
+ vector<string> expected_names = {"f1", "f2"};
+ EXPECT_EQ(expected_names, scorer->GetFeatureNames());
+}
+
+} // namespace
+} // namespace extractor
diff --git a/extractor/suffix_array.cc b/extractor/suffix_array.cc
new file mode 100644
index 00000000..65b2d581
--- /dev/null
+++ b/extractor/suffix_array.cc
@@ -0,0 +1,235 @@
+#include "suffix_array.h"
+
+#include <cassert>
+#include <chrono>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "data_array.h"
+#include "phrase_location.h"
+#include "time_util.h"
+
+namespace fs = boost::filesystem;
+using namespace std;
+using namespace chrono;
+
+namespace extractor {
+
+SuffixArray::SuffixArray(shared_ptr<DataArray> data_array) :
+ data_array(data_array) {
+ BuildSuffixArray();
+}
+
+SuffixArray::SuffixArray() {}
+
+SuffixArray::~SuffixArray() {}
+
+void SuffixArray::BuildSuffixArray() {
+ vector<int> groups = data_array->GetData();
+ groups.reserve(groups.size() + 1);
+ groups.push_back(DataArray::NULL_WORD);
+ suffix_array.resize(groups.size());
+ word_start.resize(data_array->GetVocabularySize() + 1);
+
+ InitialBucketSort(groups);
+
+ int combined_group_size = 0;
+ for (size_t i = 1; i < word_start.size(); ++i) {
+ if (word_start[i] - word_start[i - 1] == 1) {
+ ++combined_group_size;
+ suffix_array[word_start[i] - combined_group_size] = -combined_group_size;
+ } else {
+ combined_group_size = 0;
+ }
+ }
+
+ PrefixDoublingSort(groups);
+ cerr << "\tFinalizing sort..." << endl;
+
+ for (size_t i = 0; i < groups.size(); ++i) {
+ suffix_array[groups[i]] = i;
+ }
+}
+
+void SuffixArray::InitialBucketSort(vector<int>& groups) {
+ Clock::time_point start_time = Clock::now();
+ for (size_t i = 0; i < groups.size(); ++i) {
+ ++word_start[groups[i]];
+ }
+
+ for (size_t i = 1; i < word_start.size(); ++i) {
+ word_start[i] += word_start[i - 1];
+ }
+
+ for (size_t i = 0; i < groups.size(); ++i) {
+ --word_start[groups[i]];
+ suffix_array[word_start[groups[i]]] = i;
+ }
+
+ for (size_t i = 0; i < suffix_array.size(); ++i) {
+ groups[i] = word_start[groups[i] + 1] - 1;
+ }
+ Clock::time_point stop_time = Clock::now();
+ cerr << "\tBucket sort took " << GetDuration(start_time, stop_time)
+ << " seconds" << endl;
+}
+
+void SuffixArray::PrefixDoublingSort(vector<int>& groups) {
+ int step = 1;
+ while (suffix_array[0] != -suffix_array.size()) {
+ int combined_group_size = 0;
+ int i = 0;
+ while (i < suffix_array.size()) {
+ if (suffix_array[i] < 0) {
+ int skip = -suffix_array[i];
+ combined_group_size += skip;
+ i += skip;
+ suffix_array[i - combined_group_size] = -combined_group_size;
+ } else {
+ combined_group_size = 0;
+ int j = groups[suffix_array[i]];
+ TernaryQuicksort(i, j, step, groups);
+ i = j + 1;
+ }
+ }
+ step *= 2;
+ }
+}
+
+void SuffixArray::TernaryQuicksort(int left, int right, int step,
+ vector<int>& groups) {
+ if (left > right) {
+ return;
+ }
+
+ int pivot = left + rand() % (right - left + 1);
+ int pivot_value = groups[suffix_array[pivot] + step];
+ swap(suffix_array[pivot], suffix_array[left]);
+ int mid_left = left, mid_right = left;
+ for (int i = left + 1; i <= right; ++i) {
+ if (groups[suffix_array[i] + step] < pivot_value) {
+ ++mid_right;
+ int temp = suffix_array[i];
+ suffix_array[i] = suffix_array[mid_right];
+ suffix_array[mid_right] = suffix_array[mid_left];
+ suffix_array[mid_left] = temp;
+ ++mid_left;
+ } else if (groups[suffix_array[i] + step] == pivot_value) {
+ ++mid_right;
+ int temp = suffix_array[i];
+ suffix_array[i] = suffix_array[mid_right];
+ suffix_array[mid_right] = temp;
+ }
+ }
+
+ TernaryQuicksort(left, mid_left - 1, step, groups);
+
+ if (mid_left == mid_right) {
+ groups[suffix_array[mid_left]] = mid_left;
+ suffix_array[mid_left] = -1;
+ } else {
+ for (int i = mid_left; i <= mid_right; ++i) {
+ groups[suffix_array[i]] = mid_right;
+ }
+ }
+
+ TernaryQuicksort(mid_right + 1, right, step, groups);
+}
+
+vector<int> SuffixArray::BuildLCPArray() const {
+ Clock::time_point start_time = Clock::now();
+ cerr << "\tConstructing LCP array..." << endl;
+
+ vector<int> lcp(suffix_array.size());
+ vector<int> rank(suffix_array.size());
+ const vector<int>& data = data_array->GetData();
+
+ for (size_t i = 0; i < suffix_array.size(); ++i) {
+ rank[suffix_array[i]] = i;
+ }
+
+ int prefix_len = 0;
+ for (size_t i = 0; i < suffix_array.size(); ++i) {
+ if (rank[i] == 0) {
+ lcp[rank[i]] = -1;
+ } else {
+ int j = suffix_array[rank[i] - 1];
+ while (i + prefix_len < data.size() && j + prefix_len < data.size()
+ && data[i + prefix_len] == data[j + prefix_len]) {
+ ++prefix_len;
+ }
+ lcp[rank[i]] = prefix_len;
+ }
+
+ if (prefix_len > 0) {
+ --prefix_len;
+ }
+ }
+
+ Clock::time_point stop_time = Clock::now();
+ cerr << "\tConstructing LCP took "
+ << GetDuration(start_time, stop_time) << " seconds" << endl;
+
+ return lcp;
+}
+
+int SuffixArray::GetSuffix(int rank) const {
+ return suffix_array[rank];
+}
+
+int SuffixArray::GetSize() const {
+ return suffix_array.size();
+}
+
+shared_ptr<DataArray> SuffixArray::GetData() const {
+ return data_array;
+}
+
+void SuffixArray::WriteBinary(const fs::path& filepath) const {
+ FILE* file = fopen(filepath.string().c_str(), "w");
+ assert(file);
+ data_array->WriteBinary(file);
+
+ int size = suffix_array.size();
+ fwrite(&size, sizeof(int), 1, file);
+ fwrite(suffix_array.data(), sizeof(int), size, file);
+
+ size = word_start.size();
+ fwrite(&size, sizeof(int), 1, file);
+ fwrite(word_start.data(), sizeof(int), size, file);
+}
+
+PhraseLocation SuffixArray::Lookup(int low, int high, const string& word,
+ int offset) const {
+ if (!data_array->HasWord(word)) {
+ // Return empty phrase location.
+ return PhraseLocation(0, 0);
+ }
+
+ int word_id = data_array->GetWordId(word);
+ if (offset == 0) {
+ return PhraseLocation(word_start[word_id], word_start[word_id + 1]);
+ }
+
+ return PhraseLocation(LookupRangeStart(low, high, word_id, offset),
+ LookupRangeStart(low, high, word_id + 1, offset));
+}
+
+int SuffixArray::LookupRangeStart(int low, int high, int word_id,
+ int offset) const {
+ int result = high;
+ while (low < high) {
+ int middle = low + (high - low) / 2;
+ if (suffix_array[middle] + offset >= data_array->GetSize() ||
+ data_array->AtIndex(suffix_array[middle] + offset) < word_id) {
+ low = middle + 1;
+ } else {
+ result = middle;
+ high = middle;
+ }
+ }
+ return result;
+}
+
+} // namespace extractor
diff --git a/extractor/suffix_array.h b/extractor/suffix_array.h
new file mode 100644
index 00000000..bf731d79
--- /dev/null
+++ b/extractor/suffix_array.h
@@ -0,0 +1,75 @@
+#ifndef _SUFFIX_ARRAY_H_
+#define _SUFFIX_ARRAY_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <boost/filesystem.hpp>
+
+namespace fs = boost::filesystem;
+using namespace std;
+
+namespace extractor {
+
+class DataArray;
+class PhraseLocation;
+
+class SuffixArray {
+ public:
+ // Creates a suffix array from a data array.
+ SuffixArray(shared_ptr<DataArray> data_array);
+
+ virtual ~SuffixArray();
+
+ // Returns the size of the suffix array.
+ virtual int GetSize() const;
+
+ // Returns the data array on top of which the suffix array is constructed.
+ virtual shared_ptr<DataArray> GetData() const;
+
+ // Constructs the longest-common-prefix array using the algorithm of Kasai et
+ // al. (2001).
+ virtual vector<int> BuildLCPArray() const;
+
+ // Returns the i-th suffix.
+ virtual int GetSuffix(int rank) const;
+
+ // Given the range in which a phrase is located and the next word, returns the
+ // range corresponding to the phrase extended with the next word.
+ virtual PhraseLocation Lookup(int low, int high, const string& word,
+ int offset) const;
+
+ void WriteBinary(const fs::path& filepath) const;
+
+ protected:
+ SuffixArray();
+
+ private:
+ // Constructs the suffix array using the algorithm of Larsson and Sadakane
+ // (1999).
+ void BuildSuffixArray();
+
+ // Bucket sort on the data array (used for initializing the construction of
+ // the suffix array.)
+ void InitialBucketSort(vector<int>& groups);
+
+ void TernaryQuicksort(int left, int right, int step, vector<int>& groups);
+
+ // Constructs the suffix array in log(n) steps by doubling the length of the
+ // suffixes at each step.
+ void PrefixDoublingSort(vector<int>& groups);
+
+ // Given a [low, high) range in the suffix array in which all elements have
+ // the first offset-1 values the same, it returns the first position where the
+ // offset value is greater or equal to word_id.
+ int LookupRangeStart(int low, int high, int word_id, int offset) const;
+
+ shared_ptr<DataArray> data_array;
+ vector<int> suffix_array;
+ vector<int> word_start;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/suffix_array_test.cc b/extractor/suffix_array_test.cc
new file mode 100644
index 00000000..8431a16e
--- /dev/null
+++ b/extractor/suffix_array_test.cc
@@ -0,0 +1,78 @@
+#include <gtest/gtest.h>
+
+#include "mocks/mock_data_array.h"
+#include "phrase_location.h"
+#include "suffix_array.h"
+
+#include <vector>
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+class SuffixArrayTest : public Test {
+ protected:
+ virtual void SetUp() {
+ data = {6, 4, 1, 2, 4, 5, 3, 4, 6, 6, 4, 1, 2};
+ data_array = make_shared<MockDataArray>();
+ EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data));
+ EXPECT_CALL(*data_array, GetVocabularySize()).WillRepeatedly(Return(7));
+ EXPECT_CALL(*data_array, GetSize()).WillRepeatedly(Return(13));
+ suffix_array = make_shared<SuffixArray>(data_array);
+ }
+
+ vector<int> data;
+ shared_ptr<SuffixArray> suffix_array;
+ shared_ptr<MockDataArray> data_array;
+};
+
+TEST_F(SuffixArrayTest, TestData) {
+ EXPECT_EQ(data_array, suffix_array->GetData());
+ EXPECT_EQ(14, suffix_array->GetSize());
+}
+
+TEST_F(SuffixArrayTest, TestBuildSuffixArray) {
+ vector<int> expected_suffix_array =
+ {13, 11, 2, 12, 3, 6, 10, 1, 4, 7, 5, 9, 0, 8};
+ for (size_t i = 0; i < expected_suffix_array.size(); ++i) {
+ EXPECT_EQ(expected_suffix_array[i], suffix_array->GetSuffix(i));
+ }
+}
+
+TEST_F(SuffixArrayTest, TestBuildLCP) {
+ vector<int> expected_lcp = {-1, 0, 2, 0, 1, 0, 0, 3, 1, 1, 0, 0, 4, 1};
+ EXPECT_EQ(expected_lcp, suffix_array->BuildLCPArray());
+}
+
+TEST_F(SuffixArrayTest, TestLookup) {
+ for (size_t i = 0; i < data.size(); ++i) {
+ EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i]));
+ }
+
+ EXPECT_CALL(*data_array, HasWord("word1")).WillRepeatedly(Return(true));
+ EXPECT_CALL(*data_array, GetWordId("word1")).WillRepeatedly(Return(6));
+ EXPECT_EQ(PhraseLocation(11, 14), suffix_array->Lookup(0, 14, "word1", 0));
+
+ EXPECT_CALL(*data_array, HasWord("word2")).WillRepeatedly(Return(false));
+ EXPECT_EQ(PhraseLocation(0, 0), suffix_array->Lookup(0, 14, "word2", 0));
+
+ EXPECT_CALL(*data_array, HasWord("word3")).WillRepeatedly(Return(true));
+ EXPECT_CALL(*data_array, GetWordId("word3")).WillRepeatedly(Return(4));
+ EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 14, "word3", 1));
+
+ EXPECT_CALL(*data_array, HasWord("word4")).WillRepeatedly(Return(true));
+ EXPECT_CALL(*data_array, GetWordId("word4")).WillRepeatedly(Return(1));
+ EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 13, "word4", 2));
+
+ EXPECT_CALL(*data_array, HasWord("word5")).WillRepeatedly(Return(true));
+ EXPECT_CALL(*data_array, GetWordId("word5")).WillRepeatedly(Return(2));
+ EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 13, "word5", 3));
+
+ EXPECT_EQ(PhraseLocation(12, 13), suffix_array->Lookup(11, 13, "word3", 4));
+ EXPECT_EQ(PhraseLocation(11, 11), suffix_array->Lookup(11, 13, "word5", 1));
+}
+
+} // namespace
+} // namespace extractor
diff --git a/extractor/target_phrase_extractor.cc b/extractor/target_phrase_extractor.cc
new file mode 100644
index 00000000..2b8a2e4a
--- /dev/null
+++ b/extractor/target_phrase_extractor.cc
@@ -0,0 +1,158 @@
+#include "target_phrase_extractor.h"
+
+#include <unordered_set>
+
+#include "alignment.h"
+#include "data_array.h"
+#include "phrase.h"
+#include "phrase_builder.h"
+#include "rule_extractor_helper.h"
+#include "vocabulary.h"
+
+using namespace std;
+
+namespace extractor {
+
+TargetPhraseExtractor::TargetPhraseExtractor(
+ shared_ptr<DataArray> target_data_array,
+ shared_ptr<Alignment> alignment,
+ shared_ptr<PhraseBuilder> phrase_builder,
+ shared_ptr<RuleExtractorHelper> helper,
+ shared_ptr<Vocabulary> vocabulary,
+ int max_rule_span,
+ bool require_tight_phrases) :
+ target_data_array(target_data_array),
+ alignment(alignment),
+ phrase_builder(phrase_builder),
+ helper(helper),
+ vocabulary(vocabulary),
+ max_rule_span(max_rule_span),
+ require_tight_phrases(require_tight_phrases) {}
+
+TargetPhraseExtractor::TargetPhraseExtractor() {}
+
+TargetPhraseExtractor::~TargetPhraseExtractor() {}
+
+vector<pair<Phrase, PhraseAlignment> > TargetPhraseExtractor::ExtractPhrases(
+ const vector<pair<int, int> >& target_gaps, const vector<int>& target_low,
+ int target_phrase_low, int target_phrase_high,
+ const unordered_map<int, int>& source_indexes, int sentence_id) const {
+ int target_sent_len = target_data_array->GetSentenceLength(sentence_id);
+
+ vector<int> target_gap_order = helper->GetGapOrder(target_gaps);
+
+ int target_x_low = target_phrase_low, target_x_high = target_phrase_high;
+ if (!require_tight_phrases) {
+ // Extend loose target phrase to the left.
+ while (target_x_low > 0 &&
+ target_phrase_high - target_x_low < max_rule_span &&
+ target_low[target_x_low - 1] == -1) {
+ --target_x_low;
+ }
+ // Extend loose target phrase to the right.
+ while (target_x_high < target_sent_len &&
+ target_x_high - target_phrase_low < max_rule_span &&
+ target_low[target_x_high] == -1) {
+ ++target_x_high;
+ }
+ }
+
+ vector<pair<int, int> > gaps(target_gaps.size());
+ for (size_t i = 0; i < gaps.size(); ++i) {
+ gaps[i] = target_gaps[target_gap_order[i]];
+ if (!require_tight_phrases) {
+ // Extend gap to the left.
+ while (gaps[i].first > target_x_low &&
+ target_low[gaps[i].first - 1] == -1) {
+ --gaps[i].first;
+ }
+ // Extend gap to the right.
+ while (gaps[i].second < target_x_high &&
+ target_low[gaps[i].second] == -1) {
+ ++gaps[i].second;
+ }
+ }
+ }
+
+ // Compute the range in which each chunk may start or end. (Even indexes
+ // represent the range in which the chunk may start, odd indexes represent the
+ // range in which the chunk may end.)
+ vector<pair<int, int> > ranges(2 * gaps.size() + 2);
+ ranges.front() = make_pair(target_x_low, target_phrase_low);
+ ranges.back() = make_pair(target_phrase_high, target_x_high);
+ for (size_t i = 0; i < gaps.size(); ++i) {
+ int j = target_gap_order[i];
+ ranges[i * 2 + 1] = make_pair(gaps[i].first, target_gaps[j].first);
+ ranges[i * 2 + 2] = make_pair(target_gaps[j].second, gaps[i].second);
+ }
+
+ vector<pair<Phrase, PhraseAlignment> > target_phrases;
+ vector<int> subpatterns(ranges.size());
+ GeneratePhrases(target_phrases, ranges, 0, subpatterns, target_gap_order,
+ target_phrase_low, target_phrase_high, source_indexes,
+ sentence_id);
+ return target_phrases;
+}
+
+void TargetPhraseExtractor::GeneratePhrases(
+ vector<pair<Phrase, PhraseAlignment> >& target_phrases,
+ const vector<pair<int, int> >& ranges, int index, vector<int>& subpatterns,
+ const vector<int>& target_gap_order, int target_phrase_low,
+ int target_phrase_high, const unordered_map<int, int>& source_indexes,
+ int sentence_id) const {
+ if (index >= ranges.size()) {
+ if (subpatterns.back() - subpatterns.front() > max_rule_span) {
+ return;
+ }
+
+ vector<int> symbols;
+ unordered_map<int, int> target_indexes;
+
+ // Construct target phrase chunk by chunk.
+ int target_sent_start = target_data_array->GetSentenceStart(sentence_id);
+ for (size_t i = 0; i * 2 < subpatterns.size(); ++i) {
+ for (size_t j = subpatterns[i * 2]; j < subpatterns[i * 2 + 1]; ++j) {
+ target_indexes[j] = symbols.size();
+ string target_word = target_data_array->GetWordAtIndex(
+ target_sent_start + j);
+ symbols.push_back(vocabulary->GetTerminalIndex(target_word));
+ }
+ if (i < target_gap_order.size()) {
+ symbols.push_back(vocabulary->GetNonterminalIndex(
+ target_gap_order[i] + 1));
+ }
+ }
+
+ // Construct the alignment between the source and the target phrase.
+ vector<pair<int, int> > links = alignment->GetLinks(sentence_id);
+ vector<pair<int, int> > alignment;
+ for (pair<int, int> link: links) {
+ if (target_indexes.count(link.second)) {
+ alignment.push_back(make_pair(source_indexes.find(link.first)->second,
+ target_indexes[link.second]));
+ }
+ }
+
+ Phrase target_phrase = phrase_builder->Build(symbols);
+ target_phrases.push_back(make_pair(target_phrase, alignment));
+ return;
+ }
+
+ subpatterns[index] = ranges[index].first;
+ if (index > 0) {
+ subpatterns[index] = max(subpatterns[index], subpatterns[index - 1]);
+ }
+ // Choose every possible combination of [start, end) for the current chunk.
+ while (subpatterns[index] <= ranges[index].second) {
+ subpatterns[index + 1] = max(subpatterns[index], ranges[index + 1].first);
+ while (subpatterns[index + 1] <= ranges[index + 1].second) {
+ GeneratePhrases(target_phrases, ranges, index + 2, subpatterns,
+ target_gap_order, target_phrase_low, target_phrase_high,
+ source_indexes, sentence_id);
+ ++subpatterns[index + 1];
+ }
+ ++subpatterns[index];
+ }
+}
+
+} // namespace extractor
diff --git a/extractor/target_phrase_extractor.h b/extractor/target_phrase_extractor.h
new file mode 100644
index 00000000..289bae2f
--- /dev/null
+++ b/extractor/target_phrase_extractor.h
@@ -0,0 +1,64 @@
+#ifndef _TARGET_PHRASE_EXTRACTOR_H_
+#define _TARGET_PHRASE_EXTRACTOR_H_
+
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+using namespace std;
+
+namespace extractor {
+
+typedef vector<pair<int, int> > PhraseAlignment;
+
+class Alignment;
+class DataArray;
+class Phrase;
+class PhraseBuilder;
+class RuleExtractorHelper;
+class Vocabulary;
+
+class TargetPhraseExtractor {
+ public:
+ TargetPhraseExtractor(shared_ptr<DataArray> target_data_array,
+ shared_ptr<Alignment> alignment,
+ shared_ptr<PhraseBuilder> phrase_builder,
+ shared_ptr<RuleExtractorHelper> helper,
+ shared_ptr<Vocabulary> vocabulary,
+ int max_rule_span,
+ bool require_tight_phrases);
+
+ virtual ~TargetPhraseExtractor();
+
+ // Finds all the target phrases that can extracted from a span in the
+ // target sentence (matching the given set of target phrase gaps).
+ virtual vector<pair<Phrase, PhraseAlignment> > ExtractPhrases(
+ const vector<pair<int, int> >& target_gaps, const vector<int>& target_low,
+ int target_phrase_low, int target_phrase_high,
+ const unordered_map<int, int>& source_indexes, int sentence_id) const;
+
+ protected:
+ TargetPhraseExtractor();
+
+ private:
+ // Computes the cartesian product over the sets of possible target phrase
+ // chunks.
+ void GeneratePhrases(
+ vector<pair<Phrase, PhraseAlignment> >& target_phrases,
+ const vector<pair<int, int> >& ranges, int index,
+ vector<int>& subpatterns, const vector<int>& target_gap_order,
+ int target_phrase_low, int target_phrase_high,
+ const unordered_map<int, int>& source_indexes, int sentence_id) const;
+
+ shared_ptr<DataArray> target_data_array;
+ shared_ptr<Alignment> alignment;
+ shared_ptr<PhraseBuilder> phrase_builder;
+ shared_ptr<RuleExtractorHelper> helper;
+ shared_ptr<Vocabulary> vocabulary;
+ int max_rule_span;
+ bool require_tight_phrases;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/target_phrase_extractor_test.cc b/extractor/target_phrase_extractor_test.cc
new file mode 100644
index 00000000..80927dee
--- /dev/null
+++ b/extractor/target_phrase_extractor_test.cc
@@ -0,0 +1,143 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <vector>
+
+#include "mocks/mock_alignment.h"
+#include "mocks/mock_data_array.h"
+#include "mocks/mock_rule_extractor_helper.h"
+#include "mocks/mock_vocabulary.h"
+#include "phrase.h"
+#include "phrase_builder.h"
+#include "target_phrase_extractor.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+class TargetPhraseExtractorTest : public Test {
+ protected:
+ virtual void SetUp() {
+ data_array = make_shared<MockDataArray>();
+ alignment = make_shared<MockAlignment>();
+ vocabulary = make_shared<MockVocabulary>();
+ phrase_builder = make_shared<PhraseBuilder>(vocabulary);
+ helper = make_shared<MockRuleExtractorHelper>();
+ }
+
+ shared_ptr<MockDataArray> data_array;
+ shared_ptr<MockAlignment> alignment;
+ shared_ptr<MockVocabulary> vocabulary;
+ shared_ptr<PhraseBuilder> phrase_builder;
+ shared_ptr<MockRuleExtractorHelper> helper;
+ shared_ptr<TargetPhraseExtractor> extractor;
+};
+
+TEST_F(TargetPhraseExtractorTest, TestExtractTightPhrasesTrue) {
+ EXPECT_CALL(*data_array, GetSentenceLength(1)).WillRepeatedly(Return(5));
+ EXPECT_CALL(*data_array, GetSentenceStart(1)).WillRepeatedly(Return(3));
+
+ vector<string> target_words = {"a", "b", "c", "d", "e"};
+ vector<int> target_symbols = {20, 21, 22, 23, 24};
+ for (size_t i = 0; i < target_words.size(); ++i) {
+ EXPECT_CALL(*data_array, GetWordAtIndex(i + 3))
+ .WillRepeatedly(Return(target_words[i]));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex(target_words[i]))
+ .WillRepeatedly(Return(target_symbols[i]));
+ EXPECT_CALL(*vocabulary, GetTerminalValue(target_symbols[i]))
+ .WillRepeatedly(Return(target_words[i]));
+ }
+
+ vector<pair<int, int> > links = {
+ make_pair(0, 0), make_pair(1, 3), make_pair(2, 2), make_pair(3, 1),
+ make_pair(4, 4)
+ };
+ EXPECT_CALL(*alignment, GetLinks(1)).WillRepeatedly(Return(links));
+
+ vector<int> gap_order = {1, 0};
+ EXPECT_CALL(*helper, GetGapOrder(_)).WillRepeatedly(Return(gap_order));
+
+ extractor = make_shared<TargetPhraseExtractor>(
+ data_array, alignment, phrase_builder, helper, vocabulary, 10, true);
+
+ vector<pair<int, int> > target_gaps = {make_pair(3, 4), make_pair(1, 2)};
+ vector<int> target_low = {0, 3, 2, 1, 4};
+ unordered_map<int, int> source_indexes = {{0, 0}, {2, 2}, {4, 4}};
+
+ vector<pair<Phrase, PhraseAlignment> > results = extractor->ExtractPhrases(
+ target_gaps, target_low, 0, 5, source_indexes, 1);
+ EXPECT_EQ(1, results.size());
+ vector<int> expected_symbols = {20, -2, 22, -1, 24};
+ EXPECT_EQ(expected_symbols, results[0].first.Get());
+ vector<string> expected_words = {"a", "c", "e"};
+ EXPECT_EQ(expected_words, results[0].first.GetWords());
+ vector<pair<int, int> > expected_alignment = {
+ make_pair(0, 0), make_pair(2, 2), make_pair(4, 4)
+ };
+ EXPECT_EQ(expected_alignment, results[0].second);
+}
+
+TEST_F(TargetPhraseExtractorTest, TestExtractPhrasesTightPhrasesFalse) {
+ vector<string> target_words = {"a", "b", "c", "d", "e", "f", "END_OF_LINE"};
+ vector<int> target_symbols = {20, 21, 22, 23, 24, 25, 1};
+ EXPECT_CALL(*data_array, GetSentenceLength(0)).WillRepeatedly(Return(6));
+ EXPECT_CALL(*data_array, GetSentenceStart(0)).WillRepeatedly(Return(0));
+
+ for (size_t i = 0; i < target_words.size(); ++i) {
+ EXPECT_CALL(*data_array, GetWordAtIndex(i))
+ .WillRepeatedly(Return(target_words[i]));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex(target_words[i]))
+ .WillRepeatedly(Return(target_symbols[i]));
+ EXPECT_CALL(*vocabulary, GetTerminalValue(target_symbols[i]))
+ .WillRepeatedly(Return(target_words[i]));
+ }
+
+ vector<pair<int, int> > links = {make_pair(1, 1)};
+ EXPECT_CALL(*alignment, GetLinks(0)).WillRepeatedly(Return(links));
+
+ vector<int> gap_order = {0};
+ EXPECT_CALL(*helper, GetGapOrder(_)).WillRepeatedly(Return(gap_order));
+
+ extractor = make_shared<TargetPhraseExtractor>(
+ data_array, alignment, phrase_builder, helper, vocabulary, 10, false);
+
+ vector<pair<int, int> > target_gaps = {make_pair(2, 4)};
+ vector<int> target_low = {-1, 1, -1, -1, -1, -1};
+ unordered_map<int, int> source_indexes = {{1, 1}};
+
+ vector<pair<Phrase, PhraseAlignment> > results = extractor->ExtractPhrases(
+ target_gaps, target_low, 1, 5, source_indexes, 0);
+ EXPECT_EQ(10, results.size());
+
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 4; j <= 6; ++j) {
+ for (int k = 4; k <= j; ++k) {
+ vector<string> expected_words;
+ for (int l = i; l < 2; ++l) {
+ expected_words.push_back(target_words[l]);
+ }
+ for (int l = k; l < j; ++l) {
+ expected_words.push_back(target_words[l]);
+ }
+
+ PhraseAlignment expected_alignment;
+ expected_alignment.push_back(make_pair(1, 1 - i));
+
+ bool found_expected_pair = false;
+ for (auto result: results) {
+ if (result.first.GetWords() == expected_words &&
+ result.second == expected_alignment) {
+ found_expected_pair = true;
+ }
+ }
+
+ EXPECT_TRUE(found_expected_pair);
+ }
+ }
+ }
+}
+
+} // namespace
+} // namespace extractor
diff --git a/extractor/time_util.cc b/extractor/time_util.cc
new file mode 100644
index 00000000..e46a0c3d
--- /dev/null
+++ b/extractor/time_util.cc
@@ -0,0 +1,10 @@
+#include "time_util.h"
+
+namespace extractor {
+
+double GetDuration(const Clock::time_point& start_time,
+ const Clock::time_point& stop_time) {
+ return duration_cast<milliseconds>(stop_time - start_time).count() / 1000.0;
+}
+
+} // namespace extractor
diff --git a/extractor/time_util.h b/extractor/time_util.h
new file mode 100644
index 00000000..f7fd51d3
--- /dev/null
+++ b/extractor/time_util.h
@@ -0,0 +1,19 @@
+#ifndef _TIME_UTIL_H_
+#define _TIME_UTIL_H_
+
+#include <chrono>
+
+using namespace std;
+using namespace chrono;
+
+namespace extractor {
+
+typedef high_resolution_clock Clock;
+
+// Computes the duration in seconds of the specified time interval.
+double GetDuration(const Clock::time_point& start_time,
+ const Clock::time_point& stop_time);
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/translation_table.cc b/extractor/translation_table.cc
new file mode 100644
index 00000000..45da707a
--- /dev/null
+++ b/extractor/translation_table.cc
@@ -0,0 +1,126 @@
+#include "translation_table.h"
+
+#include <string>
+#include <vector>
+
+#include <boost/functional/hash.hpp>
+
+#include "alignment.h"
+#include "data_array.h"
+
+using namespace std;
+
+namespace extractor {
+
+TranslationTable::TranslationTable(shared_ptr<DataArray> source_data_array,
+ shared_ptr<DataArray> target_data_array,
+ shared_ptr<Alignment> alignment) :
+ source_data_array(source_data_array), target_data_array(target_data_array) {
+ const vector<int>& source_data = source_data_array->GetData();
+ const vector<int>& target_data = target_data_array->GetData();
+
+ unordered_map<int, int> source_links_count;
+ unordered_map<int, int> target_links_count;
+ unordered_map<pair<int, int>, int, PairHash> links_count;
+
+ // For each pair of aligned source target words increment their link count by
+ // 1. Unaligned words are paired with the NULL token.
+ for (size_t i = 0; i < source_data_array->GetNumSentences(); ++i) {
+ vector<pair<int, int> > links = alignment->GetLinks(i);
+ int source_start = source_data_array->GetSentenceStart(i);
+ int target_start = target_data_array->GetSentenceStart(i);
+ // Ignore END_OF_LINE markers.
+ int next_source_start = source_data_array->GetSentenceStart(i + 1) - 1;
+ int next_target_start = target_data_array->GetSentenceStart(i + 1) - 1;
+ vector<int> source_sentence(source_data.begin() + source_start,
+ source_data.begin() + next_source_start);
+ vector<int> target_sentence(target_data.begin() + target_start,
+ target_data.begin() + next_target_start);
+ vector<int> source_linked_words(source_sentence.size());
+ vector<int> target_linked_words(target_sentence.size());
+
+ for (pair<int, int> link: links) {
+ source_linked_words[link.first] = 1;
+ target_linked_words[link.second] = 1;
+ IncrementLinksCount(source_links_count, target_links_count, links_count,
+ source_sentence[link.first], target_sentence[link.second]);
+ }
+
+ for (size_t i = 0; i < source_sentence.size(); ++i) {
+ if (!source_linked_words[i]) {
+ IncrementLinksCount(source_links_count, target_links_count, links_count,
+ source_sentence[i], DataArray::NULL_WORD);
+ }
+ }
+
+ for (size_t i = 0; i < target_sentence.size(); ++i) {
+ if (!target_linked_words[i]) {
+ IncrementLinksCount(source_links_count, target_links_count, links_count,
+ DataArray::NULL_WORD, target_sentence[i]);
+ }
+ }
+ }
+
+ // Calculating:
+ // p(e | f) = count(e, f) / count(f)
+ // p(f | e) = count(e, f) / count(e)
+ for (pair<pair<int, int>, int> link_count: links_count) {
+ int source_word = link_count.first.first;
+ int target_word = link_count.first.second;
+ double score1 = 1.0 * link_count.second / source_links_count[source_word];
+ double score2 = 1.0 * link_count.second / target_links_count[target_word];
+ translation_probabilities[link_count.first] = make_pair(score1, score2);
+ }
+}
+
+TranslationTable::TranslationTable() {}
+
+TranslationTable::~TranslationTable() {}
+
+void TranslationTable::IncrementLinksCount(
+ unordered_map<int, int>& source_links_count,
+ unordered_map<int, int>& target_links_count,
+ unordered_map<pair<int, int>, int, PairHash>& links_count,
+ int source_word_id,
+ int target_word_id) const {
+ ++source_links_count[source_word_id];
+ ++target_links_count[target_word_id];
+ ++links_count[make_pair(source_word_id, target_word_id)];
+}
+
+double TranslationTable::GetTargetGivenSourceScore(
+ const string& source_word, const string& target_word) {
+ if (!source_data_array->HasWord(source_word) ||
+ !target_data_array->HasWord(target_word)) {
+ return -1;
+ }
+
+ int source_id = source_data_array->GetWordId(source_word);
+ int target_id = target_data_array->GetWordId(target_word);
+ return translation_probabilities[make_pair(source_id, target_id)].first;
+}
+
+double TranslationTable::GetSourceGivenTargetScore(
+ const string& source_word, const string& target_word) {
+ if (!source_data_array->HasWord(source_word) ||
+ !target_data_array->HasWord(target_word)) {
+ return -1;
+ }
+
+ int source_id = source_data_array->GetWordId(source_word);
+ int target_id = target_data_array->GetWordId(target_word);
+ return translation_probabilities[make_pair(source_id, target_id)].second;
+}
+
+void TranslationTable::WriteBinary(const fs::path& filepath) const {
+ FILE* file = fopen(filepath.string().c_str(), "w");
+
+ int size = translation_probabilities.size();
+ fwrite(&size, sizeof(int), 1, file);
+ for (auto entry: translation_probabilities) {
+ fwrite(&entry.first, sizeof(entry.first), 1, file);
+ fwrite(&entry.second, sizeof(entry.second), 1, file);
+ }
+}
+
+} // namespace extractor
diff --git a/extractor/translation_table.h b/extractor/translation_table.h
new file mode 100644
index 00000000..10504d3b
--- /dev/null
+++ b/extractor/translation_table.h
@@ -0,0 +1,63 @@
+#ifndef _TRANSLATION_TABLE_
+#define _TRANSLATION_TABLE_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+#include <boost/filesystem.hpp>
+#include <boost/functional/hash.hpp>
+
+using namespace std;
+namespace fs = boost::filesystem;
+
+namespace extractor {
+
+typedef boost::hash<pair<int, int> > PairHash;
+
+class Alignment;
+class DataArray;
+
+/**
+ * Bilexical table with conditional probabilities.
+ */
+class TranslationTable {
+ public:
+ TranslationTable(
+ shared_ptr<DataArray> source_data_array,
+ shared_ptr<DataArray> target_data_array,
+ shared_ptr<Alignment> alignment);
+
+ virtual ~TranslationTable();
+
+ // Returns p(e | f).
+ virtual double GetTargetGivenSourceScore(const string& source_word,
+ const string& target_word);
+
+ // Returns p(f | e).
+ virtual double GetSourceGivenTargetScore(const string& source_word,
+ const string& target_word);
+
+ void WriteBinary(const fs::path& filepath) const;
+
+ protected:
+ TranslationTable();
+
+ private:
+ // Increment links count for the given (f, e) word pair.
+ void IncrementLinksCount(
+ unordered_map<int, int>& source_links_count,
+ unordered_map<int, int>& target_links_count,
+ unordered_map<pair<int, int>, int, PairHash>& links_count,
+ int source_word_id,
+ int target_word_id) const;
+
+ shared_ptr<DataArray> source_data_array;
+ shared_ptr<DataArray> target_data_array;
+ unordered_map<pair<int, int>, pair<double, double>, PairHash>
+ translation_probabilities;
+};
+
+} // namespace extractor
+
+#endif
diff --git a/extractor/translation_table_test.cc b/extractor/translation_table_test.cc
new file mode 100644
index 00000000..051b5715
--- /dev/null
+++ b/extractor/translation_table_test.cc
@@ -0,0 +1,84 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "mocks/mock_alignment.h"
+#include "mocks/mock_data_array.h"
+#include "translation_table.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+TEST(TranslationTableTest, TestScores) {
+ vector<string> words = {"a", "b", "c"};
+
+ vector<int> source_data = {2, 3, 2, 3, 4, 0, 2, 3, 6, 0, 2, 3, 6, 0};
+ vector<int> source_sentence_start = {0, 6, 10, 14};
+ shared_ptr<MockDataArray> source_data_array = make_shared<MockDataArray>();
+ EXPECT_CALL(*source_data_array, GetData())
+ .WillRepeatedly(ReturnRef(source_data));
+ EXPECT_CALL(*source_data_array, GetNumSentences())
+ .WillRepeatedly(Return(3));
+ for (size_t i = 0; i < source_sentence_start.size(); ++i) {
+ EXPECT_CALL(*source_data_array, GetSentenceStart(i))
+ .WillRepeatedly(Return(source_sentence_start[i]));
+ }
+ for (size_t i = 0; i < words.size(); ++i) {
+ EXPECT_CALL(*source_data_array, HasWord(words[i]))
+ .WillRepeatedly(Return(true));
+ EXPECT_CALL(*source_data_array, GetWordId(words[i]))
+ .WillRepeatedly(Return(i + 2));
+ }
+ EXPECT_CALL(*source_data_array, HasWord("d"))
+ .WillRepeatedly(Return(false));
+
+ vector<int> target_data = {2, 3, 2, 3, 4, 5, 0, 3, 6, 0, 2, 7, 0};
+ vector<int> target_sentence_start = {0, 7, 10, 13};
+ shared_ptr<MockDataArray> target_data_array = make_shared<MockDataArray>();
+ EXPECT_CALL(*target_data_array, GetData())
+ .WillRepeatedly(ReturnRef(target_data));
+ for (size_t i = 0; i < target_sentence_start.size(); ++i) {
+ EXPECT_CALL(*target_data_array, GetSentenceStart(i))
+ .WillRepeatedly(Return(target_sentence_start[i]));
+ }
+ for (size_t i = 0; i < words.size(); ++i) {
+ EXPECT_CALL(*target_data_array, HasWord(words[i]))
+ .WillRepeatedly(Return(true));
+ EXPECT_CALL(*target_data_array, GetWordId(words[i]))
+ .WillRepeatedly(Return(i + 2));
+ }
+ EXPECT_CALL(*target_data_array, HasWord("d"))
+ .WillRepeatedly(Return(false));
+
+ vector<pair<int, int> > links1 = {
+ make_pair(0, 0), make_pair(1, 1), make_pair(2, 2), make_pair(3, 3),
+ make_pair(4, 4), make_pair(4, 5)
+ };
+ vector<pair<int, int> > links2 = {make_pair(1, 0), make_pair(2, 1)};
+ vector<pair<int, int> > links3 = {make_pair(0, 0), make_pair(2, 1)};
+ shared_ptr<MockAlignment> alignment = make_shared<MockAlignment>();
+ EXPECT_CALL(*alignment, GetLinks(0)).WillRepeatedly(Return(links1));
+ EXPECT_CALL(*alignment, GetLinks(1)).WillRepeatedly(Return(links2));
+ EXPECT_CALL(*alignment, GetLinks(2)).WillRepeatedly(Return(links3));
+
+ shared_ptr<TranslationTable> table = make_shared<TranslationTable>(
+ source_data_array, target_data_array, alignment);
+
+ EXPECT_EQ(0.75, table->GetTargetGivenSourceScore("a", "a"));
+ EXPECT_EQ(0, table->GetTargetGivenSourceScore("a", "b"));
+ EXPECT_EQ(0.5, table->GetTargetGivenSourceScore("c", "c"));
+ EXPECT_EQ(-1, table->GetTargetGivenSourceScore("c", "d"));
+
+ EXPECT_EQ(1, table->GetSourceGivenTargetScore("a", "a"));
+ EXPECT_EQ(0, table->GetSourceGivenTargetScore("a", "b"));
+ EXPECT_EQ(1, table->GetSourceGivenTargetScore("c", "c"));
+ EXPECT_EQ(-1, table->GetSourceGivenTargetScore("c", "d"));
+}
+
+} // namespace
+} // namespace extractor
diff --git a/extractor/vocabulary.cc b/extractor/vocabulary.cc
new file mode 100644
index 00000000..15795d1e
--- /dev/null
+++ b/extractor/vocabulary.cc
@@ -0,0 +1,37 @@
+#include "vocabulary.h"
+
+namespace extractor {
+
+Vocabulary::~Vocabulary() {}
+
+int Vocabulary::GetTerminalIndex(const string& word) {
+ int word_id = -1;
+ #pragma omp critical (vocabulary)
+ {
+ if (!dictionary.count(word)) {
+ word_id = words.size();
+ dictionary[word] = word_id;
+ words.push_back(word);
+ } else {
+ word_id = dictionary[word];
+ }
+ }
+ return word_id;
+}
+
+int Vocabulary::GetNonterminalIndex(int position) {
+ return -position;
+}
+
+bool Vocabulary::IsTerminal(int symbol) {
+ return symbol >= 0;
+}
+
+string Vocabulary::GetTerminalValue(int symbol) {
+ string word;
+ #pragma omp critical (vocabulary)
+ word = words[symbol];
+ return word;
+}
+
+} // namespace extractor
diff --git a/extractor/vocabulary.h b/extractor/vocabulary.h
new file mode 100644
index 00000000..c8fd9411
--- /dev/null
+++ b/extractor/vocabulary.h
@@ -0,0 +1,48 @@
+#ifndef _VOCABULARY_H_
+#define _VOCABULARY_H_
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+using namespace std;
+
+namespace extractor {
+
+/**
+ * Data structure for mapping words to word ids.
+ *
+ * This strucure contains words located in the frequent collocations and words
+ * encountered during the grammar extraction time. This dictionary is
+ * considerably smaller than the dictionaries in the data arrays (and so is the
+ * query time). Note that this is the single data structure that changes state
+ * and needs to have thread safe read/write operations.
+ *
+ * Note: For an experiment using different vocabulary instances for each thread,
+ * the running time did not improve implying that the critical regions do not
+ * cause bottlenecks.
+ */
+class Vocabulary {
+ public:
+ virtual ~Vocabulary();
+
+ // Returns the word id for the given word.
+ virtual int GetTerminalIndex(const string& word);
+
+ // Returns the id for a nonterminal located at the given position in a phrase.
+ int GetNonterminalIndex(int position);
+
+ // Checks if a symbol is a nonterminal.
+ bool IsTerminal(int symbol);
+
+ // Returns the word corresponding to the given word id.
+ virtual string GetTerminalValue(int symbol);
+
+ private:
+ unordered_map<string, int> dictionary;
+ vector<string> words;
+};
+
+} // namespace extractor
+
+#endif