summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Baltescu <pauldb89@gmail.com>2013-02-14 23:17:15 +0000
committerPaul Baltescu <pauldb89@gmail.com>2013-02-14 23:17:15 +0000
commit9a026ba2db8fa7723374109e6a4a8dcaff8733cd (patch)
tree34a60703a53ada76e7213da5940e86d6f476f1e4
parent252fb164c208ec8f3005f8a652eb3b48c0644e3d (diff)
Working version of the grammar extractor.
-rw-r--r--extractor/Makefile.am90
-rw-r--r--extractor/alignment.cc6
-rw-r--r--extractor/alignment.h7
-rw-r--r--extractor/alignment_test.cc31
-rw-r--r--extractor/binary_search_merger.cc6
-rw-r--r--extractor/binary_search_merger.h6
-rw-r--r--extractor/data_array.cc19
-rw-r--r--extractor/data_array.h16
-rw-r--r--extractor/features/count_source_target_test.cc32
-rw-r--r--extractor/features/feature.cc2
-rw-r--r--extractor/features/feature.h10
-rw-r--r--extractor/features/is_source_singleton.cc2
-rw-r--r--extractor/features/is_source_singleton_test.cc35
-rw-r--r--extractor/features/is_source_target_singleton.cc2
-rw-r--r--extractor/features/is_source_target_singleton_test.cc35
-rw-r--r--extractor/features/max_lex_source_given_target.cc5
-rw-r--r--extractor/features/max_lex_source_given_target_test.cc74
-rw-r--r--extractor/features/max_lex_target_given_source.cc5
-rw-r--r--extractor/features/max_lex_target_given_source_test.cc74
-rw-r--r--extractor/features/sample_source_count.cc2
-rw-r--r--extractor/features/sample_source_count_test.cc36
-rw-r--r--extractor/features/target_given_source_coherent.cc4
-rw-r--r--extractor/features/target_given_source_coherent_test.cc35
-rw-r--r--extractor/grammar.cc19
-rw-r--r--extractor/grammar.h4
-rw-r--r--extractor/grammar_extractor.cc45
-rw-r--r--extractor/grammar_extractor.h8
-rw-r--r--extractor/grammar_extractor_test.cc49
-rw-r--r--extractor/intersector.cc31
-rw-r--r--extractor/intersector.h16
-rw-r--r--extractor/intersector_test.cc6
-rw-r--r--extractor/linear_merger.cc12
-rw-r--r--extractor/linear_merger.h6
-rw-r--r--extractor/matchings_finder.cc4
-rw-r--r--extractor/matchings_finder.h8
-rw-r--r--extractor/matchings_trie.cc14
-rw-r--r--extractor/matchings_trie.h5
-rw-r--r--extractor/mocks/mock_alignment.h10
-rw-r--r--extractor/mocks/mock_binary_search_merger.h4
-rw-r--r--extractor/mocks/mock_data_array.h4
-rw-r--r--extractor/mocks/mock_feature.h9
-rw-r--r--extractor/mocks/mock_intersector.h11
-rw-r--r--extractor/mocks/mock_linear_merger.h2
-rw-r--r--extractor/mocks/mock_matchings_finder.h9
-rw-r--r--extractor/mocks/mock_rule_extractor.h12
-rw-r--r--extractor/mocks/mock_rule_extractor_helper.h78
-rw-r--r--extractor/mocks/mock_rule_factory.h9
-rw-r--r--extractor/mocks/mock_sampler.h9
-rw-r--r--extractor/mocks/mock_scorer.h10
-rw-r--r--extractor/mocks/mock_target_phrase_extractor.h12
-rw-r--r--extractor/mocks/mock_translation_table.h9
-rw-r--r--extractor/phrase_builder.cc5
-rw-r--r--extractor/phrase_location.cc6
-rw-r--r--extractor/precomputation.cc8
-rw-r--r--extractor/precomputation.h5
-rw-r--r--extractor/rule_extractor.cc618
-rw-r--r--extractor/rule_extractor.h92
-rw-r--r--extractor/rule_extractor_helper.cc356
-rw-r--r--extractor/rule_extractor_helper.h82
-rw-r--r--extractor/rule_extractor_helper_test.cc622
-rw-r--r--extractor/rule_extractor_test.cc166
-rw-r--r--extractor/rule_factory.cc84
-rw-r--r--extractor/rule_factory.h22
-rw-r--r--extractor/rule_factory_test.cc98
-rw-r--r--extractor/run_extractor.cc9
-rw-r--r--extractor/sample_alignment.txt2
-rw-r--r--extractor/sampler.cc7
-rw-r--r--extractor/sampler.h7
-rw-r--r--extractor/scorer.cc4
-rw-r--r--extractor/scorer.h9
-rw-r--r--extractor/scorer_test.cc47
-rw-r--r--extractor/suffix_array.cc9
-rw-r--r--extractor/suffix_array_test.cc29
-rw-r--r--extractor/target_phrase_extractor.cc144
-rw-r--r--extractor/target_phrase_extractor.h56
-rw-r--r--extractor/target_phrase_extractor_test.cc116
-rw-r--r--extractor/translation_table.cc47
-rw-r--r--extractor/translation_table.h23
-rw-r--r--extractor/translation_table_test.cc82
-rw-r--r--extractor/vocabulary.h3
80 files changed, 3007 insertions, 700 deletions
diff --git a/extractor/Makefile.am b/extractor/Makefile.am
index ded06239..c82fc1ae 100644
--- a/extractor/Makefile.am
+++ b/extractor/Makefile.am
@@ -1,8 +1,17 @@
bin_PROGRAMS = compile run_extractor
noinst_PROGRAMS = \
+ alignment_test \
binary_search_merger_test \
data_array_test \
+ feature_count_source_target_test \
+ feature_is_source_singleton_test \
+ feature_is_source_target_singleton_test \
+ feature_max_lex_source_given_target_test \
+ feature_max_lex_target_given_source_test \
+ feature_sample_source_count_test \
+ feature_target_given_source_coherent_test \
+ grammar_extractor_test \
intersector_test \
linear_merger_test \
matching_comparator_test \
@@ -10,27 +19,66 @@ noinst_PROGRAMS = \
matchings_finder_test \
phrase_test \
precomputation_test \
+ rule_extractor_helper_test \
+ rule_extractor_test \
+ rule_factory_test \
sampler_test \
+ scorer_test \
suffix_array_test \
+ target_phrase_extractor_test \
+ translation_table_test \
veb_test
-TESTS = sampler_test
-#TESTS = binary_search_merger_test \
-# data_array_test \
-# intersector_test \
-# linear_merger_test \
-# matching_comparator_test \
-# matching_test \
-# matchings_finder_test \
-# phrase_test \
-# precomputation_test \
-# suffix_array_test \
-# veb_test
+TESTS = alignment_test \
+ binary_search_merger_test \
+ data_array_test \
+ feature_count_source_target_test \
+ feature_is_source_singleton_test \
+ feature_is_source_target_singleton_test \
+ feature_max_lex_source_given_target_test \
+ feature_max_lex_target_given_source_test \
+ feature_sample_source_count_test \
+ feature_target_given_source_coherent_test \
+ grammar_extractor_test \
+ intersector_test \
+ linear_merger_test \
+ matching_comparator_test \
+ matching_test \
+ matchings_finder_test \
+ phrase_test \
+ precomputation_test \
+ rule_extractor_helper_test \
+ rule_extractor_test \
+ rule_factory_test \
+ sampler_test \
+ scorer_test \
+ suffix_array_test \
+ target_phrase_extractor_test \
+ translation_table_test \
+ veb_test
+alignment_test_SOURCES = alignment_test.cc
+alignment_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
binary_search_merger_test_SOURCES = binary_search_merger_test.cc
binary_search_merger_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
data_array_test_SOURCES = data_array_test.cc
data_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
+feature_count_source_target_test_SOURCES = features/count_source_target_test.cc
+feature_count_source_target_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
+feature_is_source_singleton_test_SOURCES = features/is_source_singleton_test.cc
+feature_is_source_singleton_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
+feature_is_source_target_singleton_test_SOURCES = features/is_source_target_singleton_test.cc
+feature_is_source_target_singleton_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
+feature_max_lex_source_given_target_test_SOURCES = features/max_lex_source_given_target_test.cc
+feature_max_lex_source_given_target_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+feature_max_lex_target_given_source_test_SOURCES = features/max_lex_target_given_source_test.cc
+feature_max_lex_target_given_source_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+feature_sample_source_count_test_SOURCES = features/sample_source_count_test.cc
+feature_sample_source_count_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
+feature_target_given_source_coherent_test_SOURCES = features/target_given_source_coherent_test.cc
+feature_target_given_source_coherent_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
+grammar_extractor_test_SOURCES = grammar_extractor_test.cc
+grammar_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
intersector_test_SOURCES = intersector_test.cc
intersector_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
linear_merger_test_SOURCES = linear_merger_test.cc
@@ -45,10 +93,22 @@ phrase_test_SOURCES = phrase_test.cc
phrase_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
precomputation_test_SOURCES = precomputation_test.cc
precomputation_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
-suffix_array_test_SOURCES = suffix_array_test.cc
-suffix_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+rule_extractor_helper_test_SOURCES = rule_extractor_helper_test.cc
+rule_extractor_helper_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+rule_extractor_test_SOURCES = rule_extractor_test.cc
+rule_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+rule_factory_test_SOURCES = rule_factory_test.cc
+rule_factory_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
sampler_test_SOURCES = sampler_test.cc
sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+scorer_test_SOURCES = scorer_test.cc
+scorer_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+suffix_array_test_SOURCES = suffix_array_test.cc
+suffix_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+target_phrase_extractor_test_SOURCES = target_phrase_extractor_test.cc
+target_phrase_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+translation_table_test_SOURCES = translation_table_test.cc
+translation_table_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
veb_test_SOURCES = veb_test.cc
veb_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
@@ -93,10 +153,12 @@ libextractor_a_SOURCES = \
precomputation.cc \
rule.cc \
rule_extractor.cc \
+ rule_extractor_helper.cc \
rule_factory.cc \
sampler.cc \
scorer.cc \
suffix_array.cc \
+ target_phrase_extractor.cc \
translation_table.cc \
veb.cc \
veb_bitset.cc \
diff --git a/extractor/alignment.cc b/extractor/alignment.cc
index 2fa0abac..ff39d484 100644
--- a/extractor/alignment.cc
+++ b/extractor/alignment.cc
@@ -31,7 +31,11 @@ Alignment::Alignment(const string& filename) {
alignments.shrink_to_fit();
}
-const vector<pair<int, int> >& Alignment::GetLinks(int sentence_index) const {
+Alignment::Alignment() {}
+
+Alignment::~Alignment() {}
+
+vector<pair<int, int> > Alignment::GetLinks(int sentence_index) const {
return alignments[sentence_index];
}
diff --git a/extractor/alignment.h b/extractor/alignment.h
index 290d6015..f7e79585 100644
--- a/extractor/alignment.h
+++ b/extractor/alignment.h
@@ -13,10 +13,15 @@ class Alignment {
public:
Alignment(const string& filename);
- const vector<pair<int, int> >& GetLinks(int sentence_index) const;
+ virtual vector<pair<int, int> > GetLinks(int sentence_index) const;
void WriteBinary(const fs::path& filepath);
+ virtual ~Alignment();
+
+ protected:
+ Alignment();
+
private:
vector<vector<pair<int, int> > > alignments;
};
diff --git a/extractor/alignment_test.cc b/extractor/alignment_test.cc
new file mode 100644
index 00000000..1bc51a56
--- /dev/null
+++ b/extractor/alignment_test.cc
@@ -0,0 +1,31 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+
+#include "alignment.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace {
+
+class AlignmentTest : public Test {
+ protected:
+ virtual void SetUp() {
+ alignment = make_shared<Alignment>("sample_alignment.txt");
+ }
+
+ shared_ptr<Alignment> alignment;
+};
+
+TEST_F(AlignmentTest, TestGetLinks) {
+ vector<pair<int, int> > expected_links = {
+ make_pair(0, 0), make_pair(1, 1), make_pair(2, 2)
+ };
+ EXPECT_EQ(expected_links, alignment->GetLinks(0));
+ expected_links = {make_pair(1, 0), make_pair(2, 1)};
+ EXPECT_EQ(expected_links, alignment->GetLinks(1));
+}
+
+} // namespace
diff --git a/extractor/binary_search_merger.cc b/extractor/binary_search_merger.cc
index 43d2f734..c1b86a77 100644
--- a/extractor/binary_search_merger.cc
+++ b/extractor/binary_search_merger.cc
@@ -25,8 +25,10 @@ BinarySearchMerger::~BinarySearchMerger() {}
void BinarySearchMerger::Merge(
vector<int>& locations, const Phrase& phrase, const Phrase& suffix,
- vector<int>::iterator prefix_start, vector<int>::iterator prefix_end,
- vector<int>::iterator suffix_start, vector<int>::iterator suffix_end,
+ const vector<int>::iterator& prefix_start,
+ const vector<int>::iterator& prefix_end,
+ const vector<int>::iterator& suffix_start,
+ const vector<int>::iterator& suffix_end,
int prefix_subpatterns, int suffix_subpatterns) const {
if (IsIntersectionVoid(prefix_start, prefix_end, suffix_start, suffix_end,
prefix_subpatterns, suffix_subpatterns, suffix)) {
diff --git a/extractor/binary_search_merger.h b/extractor/binary_search_merger.h
index ffa47c8e..c887e012 100644
--- a/extractor/binary_search_merger.h
+++ b/extractor/binary_search_merger.h
@@ -24,8 +24,10 @@ class BinarySearchMerger {
virtual void Merge(
vector<int>& locations, const Phrase& phrase, const Phrase& suffix,
- vector<int>::iterator prefix_start, vector<int>::iterator prefix_end,
- vector<int>::iterator suffix_start, vector<int>::iterator suffix_end,
+ const vector<int>::iterator& prefix_start,
+ const vector<int>::iterator& prefix_end,
+ const vector<int>::iterator& suffix_start,
+ const vector<int>::iterator& suffix_end,
int prefix_subpatterns, int suffix_subpatterns) const;
static double BAEZA_YATES_FACTOR;
diff --git a/extractor/data_array.cc b/extractor/data_array.cc
index 383b08a7..1097caf3 100644
--- a/extractor/data_array.cc
+++ b/extractor/data_array.cc
@@ -10,9 +10,9 @@
namespace fs = boost::filesystem;
using namespace std;
-int DataArray::END_OF_FILE = 0;
+int DataArray::NULL_WORD = 0;
int DataArray::END_OF_LINE = 1;
-string DataArray::END_OF_FILE_STR = "__END_OF_FILE__";
+string DataArray::NULL_WORD_STR = "__NULL__";
string DataArray::END_OF_LINE_STR = "__END_OF_LINE__";
DataArray::DataArray() {
@@ -47,9 +47,9 @@ DataArray::DataArray(const string& filename, const Side& side) {
}
void DataArray::InitializeDataArray() {
- word2id[END_OF_FILE_STR] = END_OF_FILE;
- id2word.push_back(END_OF_FILE_STR);
- word2id[END_OF_LINE_STR] = END_OF_FILE;
+ word2id[NULL_WORD_STR] = NULL_WORD;
+ id2word.push_back(NULL_WORD_STR);
+ word2id[END_OF_LINE_STR] = END_OF_LINE;
id2word.push_back(END_OF_LINE_STR);
}
@@ -87,6 +87,10 @@ int DataArray::AtIndex(int index) const {
return data[index];
}
+string DataArray::GetWordAtIndex(int index) const {
+ return id2word[data[index]];
+}
+
int DataArray::GetSize() const {
return data.size();
}
@@ -103,6 +107,11 @@ int DataArray::GetSentenceStart(int position) const {
return sentence_start[position];
}
+int DataArray::GetSentenceLength(int sentence_id) const {
+ // Ignore end of line markers.
+ return sentence_start[sentence_id + 1] - sentence_start[sentence_id] - 1;
+}
+
int DataArray::GetSentenceId(int position) const {
return sentence_id[position];
}
diff --git a/extractor/data_array.h b/extractor/data_array.h
index 19fbff88..7c120b3c 100644
--- a/extractor/data_array.h
+++ b/extractor/data_array.h
@@ -2,14 +2,13 @@
#define _DATA_ARRAY_H_
#include <string>
-#include <tr1/unordered_map>
+#include <unordered_map>
#include <vector>
#include <boost/filesystem.hpp>
namespace fs = boost::filesystem;
using namespace std;
-using namespace tr1;
enum Side {
SOURCE,
@@ -18,9 +17,9 @@ enum Side {
class DataArray {
public:
- static int END_OF_FILE;
+ static int NULL_WORD;
static int END_OF_LINE;
- static string END_OF_FILE_STR;
+ static string NULL_WORD_STR;
static string END_OF_LINE_STR;
DataArray(const string& filename);
@@ -33,6 +32,8 @@ class DataArray {
virtual int AtIndex(int index) const;
+ virtual string GetWordAtIndex(int index) const;
+
virtual int GetSize() const;
virtual int GetVocabularySize() const;
@@ -43,9 +44,12 @@ class DataArray {
virtual string GetWord(int word_id) const;
- int GetNumSentences() const;
+ virtual int GetNumSentences() const;
+
+ virtual int GetSentenceStart(int position) const;
- int GetSentenceStart(int position) const;
+ //TODO(pauldb): Add unit tests.
+ virtual int GetSentenceLength(int sentence_id) const;
virtual int GetSentenceId(int position) const;
diff --git a/extractor/features/count_source_target_test.cc b/extractor/features/count_source_target_test.cc
new file mode 100644
index 00000000..22633bb6
--- /dev/null
+++ b/extractor/features/count_source_target_test.cc
@@ -0,0 +1,32 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+
+#include "count_source_target.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace {
+
+class CountSourceTargetTest : public Test {
+ protected:
+ virtual void SetUp() {
+ feature = make_shared<CountSourceTarget>();
+ }
+
+ shared_ptr<CountSourceTarget> feature;
+};
+
+TEST_F(CountSourceTargetTest, TestGetName) {
+ EXPECT_EQ("CountEF", feature->GetName());
+}
+
+TEST_F(CountSourceTargetTest, TestScore) {
+ Phrase phrase;
+ FeatureContext context(phrase, phrase, 0.5, 9, 13);
+ EXPECT_EQ(1.0, feature->Score(context));
+}
+
+} // namespace
diff --git a/extractor/features/feature.cc b/extractor/features/feature.cc
index 7381c35a..876f5f8f 100644
--- a/extractor/features/feature.cc
+++ b/extractor/features/feature.cc
@@ -1,3 +1,5 @@
#include "feature.h"
const double Feature::MAX_SCORE = 99.0;
+
+Feature::~Feature() {}
diff --git a/extractor/features/feature.h b/extractor/features/feature.h
index ad22d3e7..aca58401 100644
--- a/extractor/features/feature.h
+++ b/extractor/features/feature.h
@@ -10,14 +10,16 @@ using namespace std;
struct FeatureContext {
FeatureContext(const Phrase& source_phrase, const Phrase& target_phrase,
- double sample_source_count, int pair_count) :
+ double source_phrase_count, int pair_count, int num_samples) :
source_phrase(source_phrase), target_phrase(target_phrase),
- sample_source_count(sample_source_count), pair_count(pair_count) {}
+ source_phrase_count(source_phrase_count), pair_count(pair_count),
+ num_samples(num_samples) {}
Phrase source_phrase;
Phrase target_phrase;
- double sample_source_count;
+ double source_phrase_count;
int pair_count;
+ int num_samples;
};
class Feature {
@@ -26,6 +28,8 @@ class Feature {
virtual string GetName() const = 0;
+ virtual ~Feature();
+
static const double MAX_SCORE;
};
diff --git a/extractor/features/is_source_singleton.cc b/extractor/features/is_source_singleton.cc
index 754df3bf..98d4e5fe 100644
--- a/extractor/features/is_source_singleton.cc
+++ b/extractor/features/is_source_singleton.cc
@@ -3,7 +3,7 @@
#include <cmath>
double IsSourceSingleton::Score(const FeatureContext& context) const {
- return context.sample_source_count == 1;
+ return context.source_phrase_count == 1;
}
string IsSourceSingleton::GetName() const {
diff --git a/extractor/features/is_source_singleton_test.cc b/extractor/features/is_source_singleton_test.cc
new file mode 100644
index 00000000..8c71e593
--- /dev/null
+++ b/extractor/features/is_source_singleton_test.cc
@@ -0,0 +1,35 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+
+#include "is_source_singleton.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace {
+
+class IsSourceSingletonTest : public Test {
+ protected:
+ virtual void SetUp() {
+ feature = make_shared<IsSourceSingleton>();
+ }
+
+ shared_ptr<IsSourceSingleton> feature;
+};
+
+TEST_F(IsSourceSingletonTest, TestGetName) {
+ EXPECT_EQ("IsSingletonF", feature->GetName());
+}
+
+TEST_F(IsSourceSingletonTest, TestScore) {
+ Phrase phrase;
+ FeatureContext context(phrase, phrase, 0.5, 3, 31);
+ EXPECT_EQ(0, feature->Score(context));
+
+ context = FeatureContext(phrase, phrase, 1, 3, 25);
+ EXPECT_EQ(1, feature->Score(context));
+}
+
+} // namespace
diff --git a/extractor/features/is_source_target_singleton.cc b/extractor/features/is_source_target_singleton.cc
index ec816509..31d36532 100644
--- a/extractor/features/is_source_target_singleton.cc
+++ b/extractor/features/is_source_target_singleton.cc
@@ -7,5 +7,5 @@ double IsSourceTargetSingleton::Score(const FeatureContext& context) const {
}
string IsSourceTargetSingleton::GetName() const {
- return "IsSingletonEF";
+ return "IsSingletonFE";
}
diff --git a/extractor/features/is_source_target_singleton_test.cc b/extractor/features/is_source_target_singleton_test.cc
new file mode 100644
index 00000000..a51f77c9
--- /dev/null
+++ b/extractor/features/is_source_target_singleton_test.cc
@@ -0,0 +1,35 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+
+#include "is_source_target_singleton.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace {
+
+class IsSourceTargetSingletonTest : public Test {
+ protected:
+ virtual void SetUp() {
+ feature = make_shared<IsSourceTargetSingleton>();
+ }
+
+ shared_ptr<IsSourceTargetSingleton> feature;
+};
+
+TEST_F(IsSourceTargetSingletonTest, TestGetName) {
+ EXPECT_EQ("IsSingletonFE", feature->GetName());
+}
+
+TEST_F(IsSourceTargetSingletonTest, TestScore) {
+ Phrase phrase;
+ FeatureContext context(phrase, phrase, 0.5, 3, 7);
+ EXPECT_EQ(0, feature->Score(context));
+
+ context = FeatureContext(phrase, phrase, 2.3, 1, 28);
+ EXPECT_EQ(1, feature->Score(context));
+}
+
+} // namespace
diff --git a/extractor/features/max_lex_source_given_target.cc b/extractor/features/max_lex_source_given_target.cc
index c4792d49..21f5c76a 100644
--- a/extractor/features/max_lex_source_given_target.cc
+++ b/extractor/features/max_lex_source_given_target.cc
@@ -2,6 +2,7 @@
#include <cmath>
+#include "../data_array.h"
#include "../translation_table.h"
MaxLexSourceGivenTarget::MaxLexSourceGivenTarget(
@@ -10,8 +11,8 @@ MaxLexSourceGivenTarget::MaxLexSourceGivenTarget(
double MaxLexSourceGivenTarget::Score(const FeatureContext& context) const {
vector<string> source_words = context.source_phrase.GetWords();
- // TODO(pauldb): Add NULL to target_words, after fixing translation table.
vector<string> target_words = context.target_phrase.GetWords();
+ target_words.push_back(DataArray::NULL_WORD_STR);
double score = 0;
for (string source_word: source_words) {
@@ -26,5 +27,5 @@ double MaxLexSourceGivenTarget::Score(const FeatureContext& context) const {
}
string MaxLexSourceGivenTarget::GetName() const {
- return "MaxLexFGivenE";
+ return "MaxLexFgivenE";
}
diff --git a/extractor/features/max_lex_source_given_target_test.cc b/extractor/features/max_lex_source_given_target_test.cc
new file mode 100644
index 00000000..5fd41f8b
--- /dev/null
+++ b/extractor/features/max_lex_source_given_target_test.cc
@@ -0,0 +1,74 @@
+#include <gtest/gtest.h>
+
+#include <cmath>
+#include <memory>
+#include <string>
+
+#include "../mocks/mock_translation_table.h"
+#include "../mocks/mock_vocabulary.h"
+#include "../data_array.h"
+#include "../phrase_builder.h"
+#include "max_lex_source_given_target.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace {
+
+class MaxLexSourceGivenTargetTest : public Test {
+ protected:
+ virtual void SetUp() {
+ vector<string> source_words = {"f1", "f2", "f3"};
+ vector<string> target_words = {"e1", "e2", "e3"};
+
+ vocabulary = make_shared<MockVocabulary>();
+ for (size_t i = 0; i < source_words.size(); ++i) {
+ EXPECT_CALL(*vocabulary, GetTerminalValue(i))
+ .WillRepeatedly(Return(source_words[i]));
+ }
+ for (size_t i = 0; i < target_words.size(); ++i) {
+ EXPECT_CALL(*vocabulary, GetTerminalValue(i + source_words.size()))
+ .WillRepeatedly(Return(target_words[i]));
+ }
+
+ phrase_builder = make_shared<PhraseBuilder>(vocabulary);
+
+ table = make_shared<MockTranslationTable>();
+ for (size_t i = 0; i < source_words.size(); ++i) {
+ for (size_t j = 0; j < target_words.size(); ++j) {
+ int value = i - j;
+ EXPECT_CALL(*table, GetSourceGivenTargetScore(
+ source_words[i], target_words[j])).WillRepeatedly(Return(value));
+ }
+ }
+
+ for (size_t i = 0; i < source_words.size(); ++i) {
+ int value = i * 3;
+ EXPECT_CALL(*table, GetSourceGivenTargetScore(
+ source_words[i], DataArray::NULL_WORD_STR))
+ .WillRepeatedly(Return(value));
+ }
+
+ feature = make_shared<MaxLexSourceGivenTarget>(table);
+ }
+
+ shared_ptr<MockVocabulary> vocabulary;
+ shared_ptr<PhraseBuilder> phrase_builder;
+ shared_ptr<MockTranslationTable> table;
+ shared_ptr<MaxLexSourceGivenTarget> feature;
+};
+
+TEST_F(MaxLexSourceGivenTargetTest, TestGetName) {
+ EXPECT_EQ("MaxLexFgivenE", feature->GetName());
+}
+
+TEST_F(MaxLexSourceGivenTargetTest, TestScore) {
+ vector<int> source_symbols = {0, 1, 2};
+ Phrase source_phrase = phrase_builder->Build(source_symbols);
+ vector<int> target_symbols = {3, 4, 5};
+ Phrase target_phrase = phrase_builder->Build(target_symbols);
+ FeatureContext context(source_phrase, target_phrase, 0.3, 7, 11);
+ EXPECT_EQ(99 - log10(18), feature->Score(context));
+}
+
+} // namespace
diff --git a/extractor/features/max_lex_target_given_source.cc b/extractor/features/max_lex_target_given_source.cc
index d82182fe..f2bc2474 100644
--- a/extractor/features/max_lex_target_given_source.cc
+++ b/extractor/features/max_lex_target_given_source.cc
@@ -2,6 +2,7 @@
#include <cmath>
+#include "../data_array.h"
#include "../translation_table.h"
MaxLexTargetGivenSource::MaxLexTargetGivenSource(
@@ -9,8 +10,8 @@ MaxLexTargetGivenSource::MaxLexTargetGivenSource(
table(table) {}
double MaxLexTargetGivenSource::Score(const FeatureContext& context) const {
- // TODO(pauldb): Add NULL to source_words, after fixing translation table.
vector<string> source_words = context.source_phrase.GetWords();
+ source_words.push_back(DataArray::NULL_WORD_STR);
vector<string> target_words = context.target_phrase.GetWords();
double score = 0;
@@ -26,5 +27,5 @@ double MaxLexTargetGivenSource::Score(const FeatureContext& context) const {
}
string MaxLexTargetGivenSource::GetName() const {
- return "MaxLexEGivenF";
+ return "MaxLexEgivenF";
}
diff --git a/extractor/features/max_lex_target_given_source_test.cc b/extractor/features/max_lex_target_given_source_test.cc
new file mode 100644
index 00000000..c8701bf7
--- /dev/null
+++ b/extractor/features/max_lex_target_given_source_test.cc
@@ -0,0 +1,74 @@
+#include <gtest/gtest.h>
+
+#include <cmath>
+#include <memory>
+#include <string>
+
+#include "../mocks/mock_translation_table.h"
+#include "../mocks/mock_vocabulary.h"
+#include "../data_array.h"
+#include "../phrase_builder.h"
+#include "max_lex_target_given_source.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace {
+
+class MaxLexTargetGivenSourceTest : public Test {
+ protected:
+ virtual void SetUp() {
+ vector<string> source_words = {"f1", "f2", "f3"};
+ vector<string> target_words = {"e1", "e2", "e3"};
+
+ vocabulary = make_shared<MockVocabulary>();
+ for (size_t i = 0; i < source_words.size(); ++i) {
+ EXPECT_CALL(*vocabulary, GetTerminalValue(i))
+ .WillRepeatedly(Return(source_words[i]));
+ }
+ for (size_t i = 0; i < target_words.size(); ++i) {
+ EXPECT_CALL(*vocabulary, GetTerminalValue(i + source_words.size()))
+ .WillRepeatedly(Return(target_words[i]));
+ }
+
+ phrase_builder = make_shared<PhraseBuilder>(vocabulary);
+
+ table = make_shared<MockTranslationTable>();
+ for (size_t i = 0; i < source_words.size(); ++i) {
+ for (size_t j = 0; j < target_words.size(); ++j) {
+ int value = i - j;
+ EXPECT_CALL(*table, GetTargetGivenSourceScore(
+ source_words[i], target_words[j])).WillRepeatedly(Return(value));
+ }
+ }
+
+ for (size_t i = 0; i < target_words.size(); ++i) {
+ int value = i * 3;
+ EXPECT_CALL(*table, GetTargetGivenSourceScore(
+ DataArray::NULL_WORD_STR, target_words[i]))
+ .WillRepeatedly(Return(value));
+ }
+
+ feature = make_shared<MaxLexTargetGivenSource>(table);
+ }
+
+ shared_ptr<MockVocabulary> vocabulary;
+ shared_ptr<PhraseBuilder> phrase_builder;
+ shared_ptr<MockTranslationTable> table;
+ shared_ptr<MaxLexTargetGivenSource> feature;
+};
+
+TEST_F(MaxLexTargetGivenSourceTest, TestGetName) {
+ EXPECT_EQ("MaxLexEgivenF", feature->GetName());
+}
+
+TEST_F(MaxLexTargetGivenSourceTest, TestScore) {
+ vector<int> source_symbols = {0, 1, 2};
+ Phrase source_phrase = phrase_builder->Build(source_symbols);
+ vector<int> target_symbols = {3, 4, 5};
+ Phrase target_phrase = phrase_builder->Build(target_symbols);
+ FeatureContext context(source_phrase, target_phrase, 0.3, 7, 19);
+ EXPECT_EQ(-log10(36), feature->Score(context));
+}
+
+} // namespace
diff --git a/extractor/features/sample_source_count.cc b/extractor/features/sample_source_count.cc
index c8124cfb..88b645b1 100644
--- a/extractor/features/sample_source_count.cc
+++ b/extractor/features/sample_source_count.cc
@@ -3,7 +3,7 @@
#include <cmath>
double SampleSourceCount::Score(const FeatureContext& context) const {
- return log10(1 + context.sample_source_count);
+ return log10(1 + context.num_samples);
}
string SampleSourceCount::GetName() const {
diff --git a/extractor/features/sample_source_count_test.cc b/extractor/features/sample_source_count_test.cc
new file mode 100644
index 00000000..7d226104
--- /dev/null
+++ b/extractor/features/sample_source_count_test.cc
@@ -0,0 +1,36 @@
+#include <gtest/gtest.h>
+
+#include <cmath>
+#include <memory>
+#include <string>
+
+#include "sample_source_count.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace {
+
+class SampleSourceCountTest : public Test {
+ protected:
+ virtual void SetUp() {
+ feature = make_shared<SampleSourceCount>();
+ }
+
+ shared_ptr<SampleSourceCount> feature;
+};
+
+TEST_F(SampleSourceCountTest, TestGetName) {
+ EXPECT_EQ("SampleCountF", feature->GetName());
+}
+
+TEST_F(SampleSourceCountTest, TestScore) {
+ Phrase phrase;
+ FeatureContext context(phrase, phrase, 0, 3, 1);
+ EXPECT_EQ(log10(2), feature->Score(context));
+
+ context = FeatureContext(phrase, phrase, 3.2, 3, 9);
+ EXPECT_EQ(1.0, feature->Score(context));
+}
+
+} // namespace
diff --git a/extractor/features/target_given_source_coherent.cc b/extractor/features/target_given_source_coherent.cc
index 748413c3..274b3364 100644
--- a/extractor/features/target_given_source_coherent.cc
+++ b/extractor/features/target_given_source_coherent.cc
@@ -3,10 +3,10 @@
#include <cmath>
double TargetGivenSourceCoherent::Score(const FeatureContext& context) const {
- double prob = context.pair_count / context.sample_source_count;
+ double prob = (double) context.pair_count / context.num_samples;
return prob > 0 ? -log10(prob) : MAX_SCORE;
}
string TargetGivenSourceCoherent::GetName() const {
- return "EGivenFCoherent";
+ return "EgivenFCoherent";
}
diff --git a/extractor/features/target_given_source_coherent_test.cc b/extractor/features/target_given_source_coherent_test.cc
new file mode 100644
index 00000000..c54c06c2
--- /dev/null
+++ b/extractor/features/target_given_source_coherent_test.cc
@@ -0,0 +1,35 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+
+#include "target_given_source_coherent.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace {
+
+class TargetGivenSourceCoherentTest : public Test {
+ protected:
+ virtual void SetUp() {
+ feature = make_shared<TargetGivenSourceCoherent>();
+ }
+
+ shared_ptr<TargetGivenSourceCoherent> feature;
+};
+
+TEST_F(TargetGivenSourceCoherentTest, TestGetName) {
+ EXPECT_EQ("EgivenFCoherent", feature->GetName());
+}
+
+TEST_F(TargetGivenSourceCoherentTest, TestScore) {
+ Phrase phrase;
+ FeatureContext context(phrase, phrase, 0.3, 2, 20);
+ EXPECT_EQ(1.0, feature->Score(context));
+
+ context = FeatureContext(phrase, phrase, 1.9, 0, 1);
+ EXPECT_EQ(99.0, feature->Score(context));
+}
+
+} // namespace
diff --git a/extractor/grammar.cc b/extractor/grammar.cc
index 79a0541d..8124a804 100644
--- a/extractor/grammar.cc
+++ b/extractor/grammar.cc
@@ -1,17 +1,32 @@
#include "grammar.h"
+#include <iomanip>
+
#include "rule.h"
+using namespace std;
+
Grammar::Grammar(const vector<Rule>& rules,
const vector<string>& feature_names) :
rules(rules), feature_names(feature_names) {}
+vector<Rule> Grammar::GetRules() const {
+ return rules;
+}
+
+vector<string> Grammar::GetFeatureNames() const {
+ return feature_names;
+}
+
ostream& operator<<(ostream& os, const Grammar& grammar) {
- for (Rule rule: grammar.rules) {
+ vector<Rule> rules = grammar.GetRules();
+ vector<string> feature_names = grammar.GetFeatureNames();
+ os << setprecision(12);
+ for (Rule rule: rules) {
os << "[X] ||| " << rule.source_phrase << " ||| "
<< rule.target_phrase << " |||";
for (size_t i = 0; i < rule.scores.size(); ++i) {
- os << " " << grammar.feature_names[i] << "=" << rule.scores[i];
+ os << " " << feature_names[i] << "=" << rule.scores[i];
}
os << " |||";
for (auto link: rule.alignment) {
diff --git a/extractor/grammar.h b/extractor/grammar.h
index db15fa7e..889cc2f3 100644
--- a/extractor/grammar.h
+++ b/extractor/grammar.h
@@ -13,6 +13,10 @@ class Grammar {
public:
Grammar(const vector<Rule>& rules, const vector<string>& feature_names);
+ vector<Rule> GetRules() const;
+
+ vector<string> GetFeatureNames() const;
+
friend ostream& operator<<(ostream& os, const Grammar& grammar);
private:
diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc
index 15268165..2f008026 100644
--- a/extractor/grammar_extractor.cc
+++ b/extractor/grammar_extractor.cc
@@ -10,19 +10,6 @@
using namespace std;
-vector<string> Tokenize(const string& sentence) {
- vector<string> result;
- result.push_back("<s>");
-
- istringstream buffer(sentence);
- copy(istream_iterator<string>(buffer),
- istream_iterator<string>(),
- back_inserter(result));
-
- result.push_back("</s>");
- return result;
-}
-
GrammarExtractor::GrammarExtractor(
shared_ptr<SuffixArray> source_suffix_array,
shared_ptr<DataArray> target_data_array,
@@ -31,15 +18,35 @@ GrammarExtractor::GrammarExtractor(
int max_nonterminals, int max_rule_symbols, int max_samples,
bool use_baeza_yates, bool require_tight_phrases) :
vocabulary(make_shared<Vocabulary>()),
- rule_factory(source_suffix_array, target_data_array, alignment,
- vocabulary, precomputation, scorer, min_gap_size, max_rule_span,
- max_nonterminals, max_rule_symbols, max_samples, use_baeza_yates,
- require_tight_phrases) {}
+ rule_factory(make_shared<HieroCachingRuleFactory>(
+ source_suffix_array, target_data_array, alignment, vocabulary,
+ precomputation, scorer, min_gap_size, max_rule_span, max_nonterminals,
+ max_rule_symbols, max_samples, use_baeza_yates,
+ require_tight_phrases)) {}
+
+GrammarExtractor::GrammarExtractor(
+ shared_ptr<Vocabulary> vocabulary,
+ shared_ptr<HieroCachingRuleFactory> rule_factory) :
+ vocabulary(vocabulary),
+ rule_factory(rule_factory) {}
Grammar GrammarExtractor::GetGrammar(const string& sentence) {
- vector<string> words = Tokenize(sentence);
+ vector<string> words = TokenizeSentence(sentence);
vector<int> word_ids = AnnotateWords(words);
- return rule_factory.GetGrammar(word_ids);
+ return rule_factory->GetGrammar(word_ids);
+}
+
+vector<string> GrammarExtractor::TokenizeSentence(const string& sentence) {
+ vector<string> result;
+ result.push_back("<s>");
+
+ istringstream buffer(sentence);
+ copy(istream_iterator<string>(buffer),
+ istream_iterator<string>(),
+ back_inserter(result));
+
+ result.push_back("</s>");
+ return result;
}
vector<int> GrammarExtractor::AnnotateWords(const vector<string>& words) {
diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h
index 243f33cf..5f87faa7 100644
--- a/extractor/grammar_extractor.h
+++ b/extractor/grammar_extractor.h
@@ -32,13 +32,19 @@ class GrammarExtractor {
bool use_baeza_yates,
bool require_tight_phrases);
+ // For testing only.
+ GrammarExtractor(shared_ptr<Vocabulary> vocabulary,
+ shared_ptr<HieroCachingRuleFactory> rule_factory);
+
Grammar GetGrammar(const string& sentence);
private:
+ vector<string> TokenizeSentence(const string& sentence);
+
vector<int> AnnotateWords(const vector<string>& words);
shared_ptr<Vocabulary> vocabulary;
- HieroCachingRuleFactory rule_factory;
+ shared_ptr<HieroCachingRuleFactory> rule_factory;
};
#endif
diff --git a/extractor/grammar_extractor_test.cc b/extractor/grammar_extractor_test.cc
new file mode 100644
index 00000000..d4ed7d4f
--- /dev/null
+++ b/extractor/grammar_extractor_test.cc
@@ -0,0 +1,49 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "grammar.h"
+#include "grammar_extractor.h"
+#include "mocks/mock_rule_factory.h"
+#include "mocks/mock_vocabulary.h"
+#include "rule.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace {
+
+TEST(GrammarExtractorTest, TestAnnotatingWords) {
+ shared_ptr<MockVocabulary> vocabulary = make_shared<MockVocabulary>();
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("<s>"))
+ .WillRepeatedly(Return(0));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("Anna"))
+ .WillRepeatedly(Return(1));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("has"))
+ .WillRepeatedly(Return(2));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("many"))
+ .WillRepeatedly(Return(3));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("apples"))
+ .WillRepeatedly(Return(4));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("."))
+ .WillRepeatedly(Return(5));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("</s>"))
+ .WillRepeatedly(Return(6));
+
+ shared_ptr<MockHieroCachingRuleFactory> factory =
+ make_shared<MockHieroCachingRuleFactory>();
+ vector<int> word_ids = {0, 1, 2, 3, 3, 4, 5, 6};
+ vector<Rule> rules;
+ vector<string> feature_names;
+ Grammar grammar(rules, feature_names);
+ EXPECT_CALL(*factory, GetGrammar(word_ids))
+ .WillOnce(Return(grammar));
+
+ GrammarExtractor extractor(vocabulary, factory);
+ string sentence = "Anna has many many apples .";
+ extractor.GetGrammar(sentence);
+}
+
+} // namespace
diff --git a/extractor/intersector.cc b/extractor/intersector.cc
index b53479af..cf42f630 100644
--- a/extractor/intersector.cc
+++ b/extractor/intersector.cc
@@ -1,5 +1,7 @@
#include "intersector.h"
+#include <chrono>
+
#include "data_array.h"
#include "matching_comparator.h"
#include "phrase.h"
@@ -9,6 +11,10 @@
#include "veb.h"
#include "vocabulary.h"
+using namespace std::chrono;
+
+typedef high_resolution_clock Clock;
+
Intersector::Intersector(shared_ptr<Vocabulary> vocabulary,
shared_ptr<Precomputation> precomputation,
shared_ptr<SuffixArray> suffix_array,
@@ -38,12 +44,22 @@ Intersector::Intersector(shared_ptr<Vocabulary> vocabulary,
ConvertIndexes(precomputation, suffix_array->GetData());
}
+Intersector::Intersector() {}
+
+Intersector::~Intersector() {}
+
void Intersector::ConvertIndexes(shared_ptr<Precomputation> precomputation,
shared_ptr<DataArray> data_array) {
const Index& precomputed_index = precomputation->GetInvertedIndex();
for (pair<vector<int>, vector<int> > entry: precomputed_index) {
vector<int> phrase = ConvertPhrase(entry.first, data_array);
inverted_index[phrase] = entry.second;
+
+ phrase.push_back(vocabulary->GetNonterminalIndex(1));
+ inverted_index[phrase] = entry.second;
+ phrase.pop_back();
+ phrase.insert(phrase.begin(), vocabulary->GetNonterminalIndex(1));
+ inverted_index[phrase] = entry.second;
}
const Index& precomputed_collocations = precomputation->GetCollocations();
@@ -76,6 +92,9 @@ PhraseLocation Intersector::Intersect(
const Phrase& prefix, PhraseLocation& prefix_location,
const Phrase& suffix, PhraseLocation& suffix_location,
const Phrase& phrase) {
+ if (linear_merge_time == 0) {
+ linear_merger->linear_merge_time = 0;
+ }
vector<int> symbols = phrase.Get();
// We should never attempt to do an intersect query for a pattern starting or
@@ -95,17 +114,23 @@ PhraseLocation Intersector::Intersect(
shared_ptr<vector<int> > prefix_matchings = prefix_location.matchings;
shared_ptr<vector<int> > suffix_matchings = suffix_location.matchings;
int prefix_subpatterns = prefix_location.num_subpatterns;
- int suffix_subpatterns = prefix_location.num_subpatterns;
+ int suffix_subpatterns = suffix_location.num_subpatterns;
if (use_baeza_yates) {
+ double prev_linear_merge_time = linear_merger->linear_merge_time;
+ Clock::time_point start = Clock::now();
binary_search_merger->Merge(locations, phrase, suffix,
prefix_matchings->begin(), prefix_matchings->end(),
suffix_matchings->begin(), suffix_matchings->end(),
prefix_subpatterns, suffix_subpatterns);
+ Clock::time_point stop = Clock::now();
+ binary_merge_time += duration_cast<milliseconds>(stop - start).count() -
+ (linear_merger->linear_merge_time - prev_linear_merge_time);
} else {
linear_merger->Merge(locations, phrase, suffix, prefix_matchings->begin(),
prefix_matchings->end(), suffix_matchings->begin(),
suffix_matchings->end(), prefix_subpatterns, suffix_subpatterns);
}
+ linear_merge_time = linear_merger->linear_merge_time;
return PhraseLocation(locations, phrase.Arity() + 1);
}
@@ -116,6 +141,8 @@ void Intersector::ExtendPhraseLocation(
return;
}
+ Clock::time_point sort_start = Clock::now();
+
phrase_location.num_subpatterns = 1;
phrase_location.sa_low = phrase_location.sa_high = 0;
@@ -140,4 +167,6 @@ void Intersector::ExtendPhraseLocation(
}
phrase_location.matchings = make_shared<vector<int> >(matchings);
+ Clock::time_point sort_stop = Clock::now();
+ sort_time += duration_cast<milliseconds>(sort_stop - sort_start).count();
}
diff --git a/extractor/intersector.h b/extractor/intersector.h
index f023cc96..8b159f17 100644
--- a/extractor/intersector.h
+++ b/extractor/intersector.h
@@ -2,7 +2,7 @@
#define _INTERSECTOR_H_
#include <memory>
-#include <tr1/unordered_map>
+#include <unordered_map>
#include <vector>
#include <boost/functional/hash.hpp>
@@ -11,7 +11,6 @@
#include "linear_merger.h"
using namespace std;
-using namespace tr1;
typedef boost::hash<vector<int> > VectorHash;
typedef unordered_map<vector<int>, vector<int>, VectorHash> Index;
@@ -42,11 +41,16 @@ class Intersector {
shared_ptr<BinarySearchMerger> binary_search_merger,
bool use_baeza_yates);
- PhraseLocation Intersect(
+ virtual ~Intersector();
+
+ virtual PhraseLocation Intersect(
const Phrase& prefix, PhraseLocation& prefix_location,
const Phrase& suffix, PhraseLocation& suffix_location,
const Phrase& phrase);
+ protected:
+ Intersector();
+
private:
void ConvertIndexes(shared_ptr<Precomputation> precomputation,
shared_ptr<DataArray> data_array);
@@ -64,6 +68,12 @@ class Intersector {
Index inverted_index;
Index collocations;
bool use_baeza_yates;
+
+ // TODO(pauldb): Don't forget to remove these.
+ public:
+ double sort_time;
+ double linear_merge_time;
+ double binary_merge_time;
};
#endif
diff --git a/extractor/intersector_test.cc b/extractor/intersector_test.cc
index a3756902..ec318362 100644
--- a/extractor/intersector_test.cc
+++ b/extractor/intersector_test.cc
@@ -34,7 +34,7 @@ class IntersectorTest : public Test {
.WillRepeatedly(Return(words[i]));
}
- vector<int> suffixes = {0, 1, 3, 5, 2, 4, 6};
+ vector<int> suffixes = {6, 0, 5, 3, 1, 4, 2};
suffix_array = make_shared<MockSuffixArray>();
EXPECT_CALL(*suffix_array, GetData())
.WillRepeatedly(Return(data_array));
@@ -103,7 +103,7 @@ TEST_F(IntersectorTest, TestLinearMergeaXb) {
Phrase suffix = phrase_builder->Build(suffix_symbols);
vector<int> symbols = {3, -1, 4};
Phrase phrase = phrase_builder->Build(symbols);
- PhraseLocation prefix_locs(1, 4), suffix_locs(4, 6);
+ PhraseLocation prefix_locs(2, 5), suffix_locs(5, 7);
vector<int> ex_prefix_locs = {1, 3, 5};
PhraseLocation extended_prefix_locs(ex_prefix_locs, 1);
@@ -135,7 +135,7 @@ TEST_F(IntersectorTest, TestBinarySearchMergeaXb) {
Phrase suffix = phrase_builder->Build(suffix_symbols);
vector<int> symbols = {3, -1, 4};
Phrase phrase = phrase_builder->Build(symbols);
- PhraseLocation prefix_locs(1, 4), suffix_locs(4, 6);
+ PhraseLocation prefix_locs(2, 5), suffix_locs(5, 7);
vector<int> ex_prefix_locs = {1, 3, 5};
PhraseLocation extended_prefix_locs(ex_prefix_locs, 1);
diff --git a/extractor/linear_merger.cc b/extractor/linear_merger.cc
index 666f8d87..7233f945 100644
--- a/extractor/linear_merger.cc
+++ b/extractor/linear_merger.cc
@@ -1,5 +1,6 @@
#include "linear_merger.h"
+#include <chrono>
#include <cmath>
#include "data_array.h"
@@ -9,6 +10,10 @@
#include "phrase_location.h"
#include "vocabulary.h"
+using namespace std::chrono;
+
+typedef high_resolution_clock Clock;
+
LinearMerger::LinearMerger(shared_ptr<Vocabulary> vocabulary,
shared_ptr<DataArray> data_array,
shared_ptr<MatchingComparator> comparator) :
@@ -22,7 +27,9 @@ void LinearMerger::Merge(
vector<int>& locations, const Phrase& phrase, const Phrase& suffix,
vector<int>::iterator prefix_start, vector<int>::iterator prefix_end,
vector<int>::iterator suffix_start, vector<int>::iterator suffix_end,
- int prefix_subpatterns, int suffix_subpatterns) const {
+ int prefix_subpatterns, int suffix_subpatterns) {
+ Clock::time_point start = Clock::now();
+
int last_chunk_len = suffix.GetChunkLen(suffix.Arity());
bool offset = !vocabulary->IsTerminal(suffix.GetSymbol(0));
@@ -62,4 +69,7 @@ void LinearMerger::Merge(
prefix_start += prefix_subpatterns;
}
}
+
+ Clock::time_point stop = Clock::now();
+ linear_merge_time += duration_cast<milliseconds>(stop - start).count();
}
diff --git a/extractor/linear_merger.h b/extractor/linear_merger.h
index 6a69b804..25692b15 100644
--- a/extractor/linear_merger.h
+++ b/extractor/linear_merger.h
@@ -24,7 +24,7 @@ class LinearMerger {
vector<int>& locations, const Phrase& phrase, const Phrase& suffix,
vector<int>::iterator prefix_start, vector<int>::iterator prefix_end,
vector<int>::iterator suffix_start, vector<int>::iterator suffix_end,
- int prefix_subpatterns, int suffix_subpatterns) const;
+ int prefix_subpatterns, int suffix_subpatterns);
protected:
LinearMerger();
@@ -33,6 +33,10 @@ class LinearMerger {
shared_ptr<Vocabulary> vocabulary;
shared_ptr<DataArray> data_array;
shared_ptr<MatchingComparator> comparator;
+
+ // TODO(pauldb): Remove this eventually.
+ public:
+ double linear_merge_time;
};
#endif
diff --git a/extractor/matchings_finder.cc b/extractor/matchings_finder.cc
index ba4edab1..eaf493b2 100644
--- a/extractor/matchings_finder.cc
+++ b/extractor/matchings_finder.cc
@@ -6,6 +6,10 @@
MatchingsFinder::MatchingsFinder(shared_ptr<SuffixArray> suffix_array) :
suffix_array(suffix_array) {}
+MatchingsFinder::MatchingsFinder() {}
+
+MatchingsFinder::~MatchingsFinder() {}
+
PhraseLocation MatchingsFinder::Find(PhraseLocation& location,
const string& word, int offset) {
if (location.sa_low == -1 && location.sa_high == -1) {
diff --git a/extractor/matchings_finder.h b/extractor/matchings_finder.h
index 0458a4d8..ed04d8b8 100644
--- a/extractor/matchings_finder.h
+++ b/extractor/matchings_finder.h
@@ -13,7 +13,13 @@ class MatchingsFinder {
public:
MatchingsFinder(shared_ptr<SuffixArray> suffix_array);
- PhraseLocation Find(PhraseLocation& location, const string& word, int offset);
+ virtual ~MatchingsFinder();
+
+ virtual PhraseLocation Find(PhraseLocation& location, const string& word,
+ int offset);
+
+ protected:
+ MatchingsFinder();
private:
shared_ptr<SuffixArray> suffix_array;
diff --git a/extractor/matchings_trie.cc b/extractor/matchings_trie.cc
index 851d4596..921ec582 100644
--- a/extractor/matchings_trie.cc
+++ b/extractor/matchings_trie.cc
@@ -1,11 +1,19 @@
#include "matchings_trie.h"
void MatchingsTrie::Reset() {
- // TODO(pauldb): This is probably memory leaking because of the suffix links.
- // Check if it's true and free the memory properly.
- root.reset(new TrieNode());
+ ResetTree(root);
+ root = make_shared<TrieNode>();
}
shared_ptr<TrieNode> MatchingsTrie::GetRoot() const {
return root;
}
+
+void MatchingsTrie::ResetTree(shared_ptr<TrieNode> root) {
+ if (root != NULL) {
+ for (auto child: root->children) {
+ ResetTree(child.second);
+ }
+ root.reset();
+ }
+}
diff --git a/extractor/matchings_trie.h b/extractor/matchings_trie.h
index f935d1a9..6e72b2db 100644
--- a/extractor/matchings_trie.h
+++ b/extractor/matchings_trie.h
@@ -2,13 +2,12 @@
#define _MATCHINGS_TRIE_
#include <memory>
-#include <tr1/unordered_map>
+#include <unordered_map>
#include "phrase.h"
#include "phrase_location.h"
using namespace std;
-using namespace tr1;
struct TrieNode {
TrieNode(shared_ptr<TrieNode> suffix_link = shared_ptr<TrieNode>(),
@@ -40,6 +39,8 @@ class MatchingsTrie {
shared_ptr<TrieNode> GetRoot() const;
private:
+ void ResetTree(shared_ptr<TrieNode> root);
+
shared_ptr<TrieNode> root;
};
diff --git a/extractor/mocks/mock_alignment.h b/extractor/mocks/mock_alignment.h
new file mode 100644
index 00000000..4a5077ad
--- /dev/null
+++ b/extractor/mocks/mock_alignment.h
@@ -0,0 +1,10 @@
+#include <gmock/gmock.h>
+
+#include "../alignment.h"
+
+typedef vector<pair<int, int> > SentenceLinks;
+
+class MockAlignment : public Alignment {
+ public:
+ MOCK_CONST_METHOD1(GetLinks, SentenceLinks(int sentence_id));
+};
diff --git a/extractor/mocks/mock_binary_search_merger.h b/extractor/mocks/mock_binary_search_merger.h
index e1375ee3..e23386f0 100644
--- a/extractor/mocks/mock_binary_search_merger.h
+++ b/extractor/mocks/mock_binary_search_merger.h
@@ -10,6 +10,6 @@ using namespace std;
class MockBinarySearchMerger: public BinarySearchMerger {
public:
MOCK_CONST_METHOD9(Merge, void(vector<int>&, const Phrase&, const Phrase&,
- vector<int>::iterator, vector<int>::iterator, vector<int>::iterator,
- vector<int>::iterator, int, int));
+ const vector<int>::iterator&, const vector<int>::iterator&,
+ const vector<int>::iterator&, const vector<int>::iterator&, int, int));
};
diff --git a/extractor/mocks/mock_data_array.h b/extractor/mocks/mock_data_array.h
index 54497cf5..004e8906 100644
--- a/extractor/mocks/mock_data_array.h
+++ b/extractor/mocks/mock_data_array.h
@@ -6,10 +6,14 @@ class MockDataArray : public DataArray {
public:
MOCK_CONST_METHOD0(GetData, const vector<int>&());
MOCK_CONST_METHOD1(AtIndex, int(int index));
+ MOCK_CONST_METHOD1(GetWordAtIndex, string(int index));
MOCK_CONST_METHOD0(GetSize, int());
MOCK_CONST_METHOD0(GetVocabularySize, int());
MOCK_CONST_METHOD1(HasWord, bool(const string& word));
MOCK_CONST_METHOD1(GetWordId, int(const string& word));
MOCK_CONST_METHOD1(GetWord, string(int word_id));
+ MOCK_CONST_METHOD1(GetSentenceLength, int(int sentence_id));
+ MOCK_CONST_METHOD0(GetNumSentences, int());
+ MOCK_CONST_METHOD1(GetSentenceStart, int(int sentence_id));
MOCK_CONST_METHOD1(GetSentenceId, int(int position));
};
diff --git a/extractor/mocks/mock_feature.h b/extractor/mocks/mock_feature.h
new file mode 100644
index 00000000..d2137629
--- /dev/null
+++ b/extractor/mocks/mock_feature.h
@@ -0,0 +1,9 @@
+#include <gmock/gmock.h>
+
+#include "../features/feature.h"
+
+class MockFeature : public Feature {
+ public:
+ MOCK_CONST_METHOD1(Score, double(const FeatureContext& context));
+ MOCK_CONST_METHOD0(GetName, string());
+};
diff --git a/extractor/mocks/mock_intersector.h b/extractor/mocks/mock_intersector.h
new file mode 100644
index 00000000..372fa7ea
--- /dev/null
+++ b/extractor/mocks/mock_intersector.h
@@ -0,0 +1,11 @@
+#include <gmock/gmock.h>
+
+#include "../intersector.h"
+#include "../phrase.h"
+#include "../phrase_location.h"
+
+class MockIntersector : public Intersector {
+ public:
+ MOCK_METHOD5(Intersect, PhraseLocation(const Phrase&, PhraseLocation&,
+ const Phrase&, PhraseLocation&, const Phrase&));
+};
diff --git a/extractor/mocks/mock_linear_merger.h b/extractor/mocks/mock_linear_merger.h
index 82243428..522c1f31 100644
--- a/extractor/mocks/mock_linear_merger.h
+++ b/extractor/mocks/mock_linear_merger.h
@@ -9,7 +9,7 @@ using namespace std;
class MockLinearMerger: public LinearMerger {
public:
- MOCK_CONST_METHOD9(Merge, void(vector<int>&, const Phrase&, const Phrase&,
+ MOCK_METHOD9(Merge, void(vector<int>&, const Phrase&, const Phrase&,
vector<int>::iterator, vector<int>::iterator, vector<int>::iterator,
vector<int>::iterator, int, int));
};
diff --git a/extractor/mocks/mock_matchings_finder.h b/extractor/mocks/mock_matchings_finder.h
new file mode 100644
index 00000000..3e80d266
--- /dev/null
+++ b/extractor/mocks/mock_matchings_finder.h
@@ -0,0 +1,9 @@
+#include <gmock/gmock.h>
+
+#include "../matchings_finder.h"
+#include "../phrase_location.h"
+
+class MockMatchingsFinder : public MatchingsFinder {
+ public:
+ MOCK_METHOD3(Find, PhraseLocation(PhraseLocation&, const string&, int));
+};
diff --git a/extractor/mocks/mock_rule_extractor.h b/extractor/mocks/mock_rule_extractor.h
new file mode 100644
index 00000000..f18e009a
--- /dev/null
+++ b/extractor/mocks/mock_rule_extractor.h
@@ -0,0 +1,12 @@
+#include <gmock/gmock.h>
+
+#include "../phrase.h"
+#include "../phrase_builder.h"
+#include "../rule.h"
+#include "../rule_extractor.h"
+
+class MockRuleExtractor : public RuleExtractor {
+ public:
+ MOCK_CONST_METHOD2(ExtractRules, vector<Rule>(const Phrase&,
+ const PhraseLocation&));
+};
diff --git a/extractor/mocks/mock_rule_extractor_helper.h b/extractor/mocks/mock_rule_extractor_helper.h
new file mode 100644
index 00000000..63ff1048
--- /dev/null
+++ b/extractor/mocks/mock_rule_extractor_helper.h
@@ -0,0 +1,78 @@
+#include <gmock/gmock.h>
+
+#include <vector>
+
+#include "../rule_extractor_helper.h"
+
+using namespace std;
+
+typedef unordered_map<int, int> Indexes;
+
+class MockRuleExtractorHelper : public RuleExtractorHelper {
+ public:
+ MOCK_CONST_METHOD5(GetLinksSpans, void(vector<int>&, vector<int>&,
+ vector<int>&, vector<int>&, int));
+ MOCK_CONST_METHOD3(CheckAlignedTerminals, bool(const vector<int>&,
+ const vector<int>&, const vector<int>&));
+ MOCK_CONST_METHOD3(CheckTightPhrases, bool(const vector<int>&,
+ const vector<int>&, const vector<int>&));
+ MOCK_CONST_METHOD1(GetGapOrder, vector<int>(const vector<pair<int, int> >&));
+ MOCK_CONST_METHOD3(GetSourceIndexes, Indexes(const vector<int>&,
+ const vector<int>&, int));
+
+ // We need to implement these methods, because Google Mock doesn't support
+ // methods with more than 10 arguments.
+ bool FindFixPoint(
+ int, int, const vector<int>&, const vector<int>&, int& target_phrase_low,
+ int& target_phrase_high, const vector<int>&, const vector<int>&,
+ int& source_back_low, int& source_back_high, int, int, int, int, bool,
+ bool, bool) const {
+ target_phrase_low = this->target_phrase_low;
+ target_phrase_high = this->target_phrase_high;
+ source_back_low = this->source_back_low;
+ source_back_high = this->source_back_high;
+ return find_fix_point;
+ }
+
+ bool GetGaps(vector<pair<int, int> >& source_gaps,
+ vector<pair<int, int> >& target_gaps,
+ const vector<int>&, const vector<int>&, const vector<int>&,
+ const vector<int>&, const vector<int>&, const vector<int>&,
+ int, int, int, int, int& num_symbols,
+ bool& met_constraints) const {
+ source_gaps = this->source_gaps;
+ target_gaps = this->target_gaps;
+ num_symbols = this->num_symbols;
+ met_constraints = this->met_constraints;
+ return get_gaps;
+ }
+
+ void SetUp(
+ int target_phrase_low, int target_phrase_high, int source_back_low,
+ int source_back_high, bool find_fix_point,
+ vector<pair<int, int> > source_gaps, vector<pair<int, int> > target_gaps,
+ int num_symbols, bool met_constraints, bool get_gaps) {
+ this->target_phrase_low = target_phrase_low;
+ this->target_phrase_high = target_phrase_high;
+ this->source_back_low = source_back_low;
+ this->source_back_high = source_back_high;
+ this->find_fix_point = find_fix_point;
+ this->source_gaps = source_gaps;
+ this->target_gaps = target_gaps;
+ this->num_symbols = num_symbols;
+ this->met_constraints = met_constraints;
+ this->get_gaps = get_gaps;
+ }
+
+ private:
+ int target_phrase_low;
+ int target_phrase_high;
+ int source_back_low;
+ int source_back_high;
+ bool find_fix_point;
+ vector<pair<int, int> > source_gaps;
+ vector<pair<int, int> > target_gaps;
+ int num_symbols;
+ bool met_constraints;
+ bool get_gaps;
+};
diff --git a/extractor/mocks/mock_rule_factory.h b/extractor/mocks/mock_rule_factory.h
new file mode 100644
index 00000000..2a96be93
--- /dev/null
+++ b/extractor/mocks/mock_rule_factory.h
@@ -0,0 +1,9 @@
+#include <gmock/gmock.h>
+
+#include "../grammar.h"
+#include "../rule_factory.h"
+
+class MockHieroCachingRuleFactory : public HieroCachingRuleFactory {
+ public:
+ MOCK_METHOD1(GetGrammar, Grammar(const vector<int>& word_ids));
+};
diff --git a/extractor/mocks/mock_sampler.h b/extractor/mocks/mock_sampler.h
new file mode 100644
index 00000000..b2306109
--- /dev/null
+++ b/extractor/mocks/mock_sampler.h
@@ -0,0 +1,9 @@
+#include <gmock/gmock.h>
+
+#include "../phrase_location.h"
+#include "../sampler.h"
+
+class MockSampler : public Sampler {
+ public:
+ MOCK_CONST_METHOD1(Sample, PhraseLocation(const PhraseLocation& location));
+};
diff --git a/extractor/mocks/mock_scorer.h b/extractor/mocks/mock_scorer.h
new file mode 100644
index 00000000..48115ef4
--- /dev/null
+++ b/extractor/mocks/mock_scorer.h
@@ -0,0 +1,10 @@
+#include <gmock/gmock.h>
+
+#include "../scorer.h"
+#include "../features/feature.h"
+
+class MockScorer : public Scorer {
+ public:
+ MOCK_CONST_METHOD1(Score, vector<double>(const FeatureContext& context));
+ MOCK_CONST_METHOD0(GetFeatureNames, vector<string>());
+};
diff --git a/extractor/mocks/mock_target_phrase_extractor.h b/extractor/mocks/mock_target_phrase_extractor.h
new file mode 100644
index 00000000..6dc6bba6
--- /dev/null
+++ b/extractor/mocks/mock_target_phrase_extractor.h
@@ -0,0 +1,12 @@
+#include <gmock/gmock.h>
+
+#include "../target_phrase_extractor.h"
+
+typedef pair<Phrase, PhraseAlignment> PhraseExtract;
+
+class MockTargetPhraseExtractor : public TargetPhraseExtractor {
+ public:
+ MOCK_CONST_METHOD6(ExtractPhrases, vector<PhraseExtract>(
+ const vector<pair<int, int> > &, const vector<int>&, int, int,
+ const unordered_map<int, int>&, int));
+};
diff --git a/extractor/mocks/mock_translation_table.h b/extractor/mocks/mock_translation_table.h
new file mode 100644
index 00000000..a35c9327
--- /dev/null
+++ b/extractor/mocks/mock_translation_table.h
@@ -0,0 +1,9 @@
+#include <gmock/gmock.h>
+
+#include "../translation_table.h"
+
+class MockTranslationTable : public TranslationTable {
+ public:
+ MOCK_METHOD2(GetSourceGivenTargetScore, double(const string&, const string&));
+ MOCK_METHOD2(GetTargetGivenSourceScore, double(const string&, const string&));
+};
diff --git a/extractor/phrase_builder.cc b/extractor/phrase_builder.cc
index c4e0c2ed..4325390c 100644
--- a/extractor/phrase_builder.cc
+++ b/extractor/phrase_builder.cc
@@ -9,10 +9,9 @@ PhraseBuilder::PhraseBuilder(shared_ptr<Vocabulary> vocabulary) :
Phrase PhraseBuilder::Build(const vector<int>& symbols) {
Phrase phrase;
phrase.symbols = symbols;
- phrase.words.resize(symbols.size());
for (size_t i = 0; i < symbols.size(); ++i) {
if (vocabulary->IsTerminal(symbols[i])) {
- phrase.words[i] = vocabulary->GetTerminalValue(symbols[i]);
+ phrase.words.push_back(vocabulary->GetTerminalValue(symbols[i]));
} else {
phrase.var_pos.push_back(i);
}
@@ -30,7 +29,7 @@ Phrase PhraseBuilder::Extend(const Phrase& phrase, bool start_x, bool end_x) {
}
for (size_t i = start_x; i < symbols.size(); ++i) {
- if (vocabulary->IsTerminal(symbols[i])) {
+ if (!vocabulary->IsTerminal(symbols[i])) {
++num_nonterminals;
symbols[i] = vocabulary->GetNonterminalIndex(num_nonterminals);
}
diff --git a/extractor/phrase_location.cc b/extractor/phrase_location.cc
index 984407c5..62f1e714 100644
--- a/extractor/phrase_location.cc
+++ b/extractor/phrase_location.cc
@@ -10,7 +10,11 @@ PhraseLocation::PhraseLocation(const vector<int>& matchings,
num_subpatterns(num_subpatterns) {}
bool PhraseLocation::IsEmpty() {
- return sa_low >= sa_high || (num_subpatterns > 0 && matchings->size() == 0);
+ if (num_subpatterns > 0) {
+ return matchings->size() == 0;
+ } else {
+ return sa_low >= sa_high;
+ }
}
bool operator==(const PhraseLocation& a, const PhraseLocation& b) {
diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc
index 9a167976..8a76beb1 100644
--- a/extractor/precomputation.cc
+++ b/extractor/precomputation.cc
@@ -7,7 +7,6 @@
#include "suffix_array.h"
using namespace std;
-using namespace tr1;
int Precomputation::NON_TERMINAL = -1;
@@ -79,13 +78,16 @@ vector<vector<int> > Precomputation::FindMostFrequentPatterns(
}
vector<vector<int> > frequent_patterns;
- for (size_t i = 0; i < num_frequent_patterns && !heap.empty(); ++i) {
+ while (frequent_patterns.size() < num_frequent_patterns && !heap.empty()) {
int start = heap.top().second.first;
int len = heap.top().second.second;
heap.pop();
vector<int> pattern(data.begin() + start, data.begin() + start + len);
- frequent_patterns.push_back(pattern);
+ if (find(pattern.begin(), pattern.end(), DataArray::END_OF_LINE) ==
+ pattern.end()) {
+ frequent_patterns.push_back(pattern);
+ }
}
return frequent_patterns;
}
diff --git a/extractor/precomputation.h b/extractor/precomputation.h
index 428505d8..28426bfa 100644
--- a/extractor/precomputation.h
+++ b/extractor/precomputation.h
@@ -2,8 +2,8 @@
#define _PRECOMPUTATION_H_
#include <memory>
-#include <tr1/unordered_map>
-#include <tr1/unordered_set>
+#include <unordered_map>
+#include <unordered_set>
#include <tuple>
#include <vector>
@@ -12,7 +12,6 @@
namespace fs = boost::filesystem;
using namespace std;
-using namespace tr1;
class SuffixArray;
diff --git a/extractor/rule_extractor.cc b/extractor/rule_extractor.cc
index 9460020f..92343241 100644
--- a/extractor/rule_extractor.cc
+++ b/extractor/rule_extractor.cc
@@ -1,7 +1,6 @@
#include "rule_extractor.h"
#include <map>
-#include <tr1/unordered_set>
#include "alignment.h"
#include "data_array.h"
@@ -9,11 +8,11 @@
#include "phrase_builder.h"
#include "phrase_location.h"
#include "rule.h"
+#include "rule_extractor_helper.h"
#include "scorer.h"
-#include "vocabulary.h"
+#include "target_phrase_extractor.h"
using namespace std;
-using namespace tr1;
RuleExtractor::RuleExtractor(
shared_ptr<DataArray> source_data_array,
@@ -29,20 +28,50 @@ RuleExtractor::RuleExtractor(
bool require_aligned_terminal,
bool require_aligned_chunks,
bool require_tight_phrases) :
- source_data_array(source_data_array),
target_data_array(target_data_array),
- alignment(alignment),
+ source_data_array(source_data_array),
phrase_builder(phrase_builder),
scorer(scorer),
- vocabulary(vocabulary),
max_rule_span(max_rule_span),
min_gap_size(min_gap_size),
max_nonterminals(max_nonterminals),
max_rule_symbols(max_rule_symbols),
- require_aligned_terminal(require_aligned_terminal),
- require_aligned_chunks(require_aligned_chunks),
+ require_tight_phrases(require_tight_phrases) {
+ helper = make_shared<RuleExtractorHelper>(
+ source_data_array, target_data_array, alignment, max_rule_span,
+ max_rule_symbols, require_aligned_terminal, require_aligned_chunks,
+ require_tight_phrases);
+ target_phrase_extractor = make_shared<TargetPhraseExtractor>(
+ target_data_array, alignment, phrase_builder, helper, vocabulary,
+ max_rule_span, require_tight_phrases);
+}
+
+RuleExtractor::RuleExtractor(
+ shared_ptr<DataArray> source_data_array,
+ shared_ptr<PhraseBuilder> phrase_builder,
+ shared_ptr<Scorer> scorer,
+ shared_ptr<TargetPhraseExtractor> target_phrase_extractor,
+ shared_ptr<RuleExtractorHelper> helper,
+ int max_rule_span,
+ int min_gap_size,
+ int max_nonterminals,
+ int max_rule_symbols,
+ bool require_tight_phrases) :
+ source_data_array(source_data_array),
+ phrase_builder(phrase_builder),
+ scorer(scorer),
+ target_phrase_extractor(target_phrase_extractor),
+ helper(helper),
+ max_rule_span(max_rule_span),
+ min_gap_size(min_gap_size),
+ max_nonterminals(max_nonterminals),
+ max_rule_symbols(max_rule_symbols),
require_tight_phrases(require_tight_phrases) {}
+RuleExtractor::RuleExtractor() {}
+
+RuleExtractor::~RuleExtractor() {}
+
vector<Rule> RuleExtractor::ExtractRules(const Phrase& phrase,
const PhraseLocation& location) const {
int num_subpatterns = location.num_subpatterns;
@@ -60,6 +89,7 @@ vector<Rule> RuleExtractor::ExtractRules(const Phrase& phrase,
}
}
+ int num_samples = matchings.size() / num_subpatterns;
vector<Rule> rules;
for (auto source_phrase_entry: alignments_counter) {
Phrase source_phrase = source_phrase_entry.first;
@@ -77,7 +107,7 @@ vector<Rule> RuleExtractor::ExtractRules(const Phrase& phrase,
}
FeatureContext context(source_phrase, target_phrase,
- source_phrase_counter[source_phrase], num_locations);
+ source_phrase_counter[source_phrase], num_locations, num_samples);
vector<double> scores = scorer->Score(context);
rules.push_back(Rule(source_phrase, target_phrase, scores,
most_frequent_alignment));
@@ -93,7 +123,8 @@ vector<Extract> RuleExtractor::ExtractAlignments(
int source_sent_start = source_data_array->GetSentenceStart(sentence_id);
vector<int> source_low, source_high, target_low, target_high;
- GetLinksSpans(source_low, source_high, target_low, target_high, sentence_id);
+ helper->GetLinksSpans(source_low, source_high, target_low, target_high,
+ sentence_id);
int num_subpatterns = matching.size();
vector<int> chunklen(num_subpatterns);
@@ -101,39 +132,44 @@ vector<Extract> RuleExtractor::ExtractAlignments(
chunklen[i] = phrase.GetChunkLen(i);
}
- if (!CheckAlignedTerminals(matching, chunklen, source_low) ||
- !CheckTightPhrases(matching, chunklen, source_low)) {
+ if (!helper->CheckAlignedTerminals(matching, chunklen, source_low) ||
+ !helper->CheckTightPhrases(matching, chunklen, source_low)) {
return extracts;
}
int source_back_low = -1, source_back_high = -1;
int source_phrase_low = matching[0] - source_sent_start;
- int source_phrase_high = matching.back() + chunklen.back() - source_sent_start;
+ int source_phrase_high = matching.back() + chunklen.back() -
+ source_sent_start;
int target_phrase_low = -1, target_phrase_high = -1;
- if (!FindFixPoint(source_phrase_low, source_phrase_high, source_low,
- source_high, target_phrase_low, target_phrase_high,
- target_low, target_high, source_back_low, source_back_high,
- sentence_id, min_gap_size, 0,
- max_nonterminals - matching.size() + 1, 1, 1, false)) {
+ if (!helper->FindFixPoint(source_phrase_low, source_phrase_high, source_low,
+ source_high, target_phrase_low, target_phrase_high,
+ target_low, target_high, source_back_low,
+ source_back_high, sentence_id, min_gap_size, 0,
+ max_nonterminals - matching.size() + 1, true, true,
+ false)) {
return extracts;
}
bool met_constraints = true;
int num_symbols = phrase.GetNumSymbols();
vector<pair<int, int> > source_gaps, target_gaps;
- if (!CheckGaps(source_gaps, target_gaps, matching, chunklen, source_low,
- source_high, target_low, target_high, source_phrase_low,
- source_phrase_high, source_back_low, source_back_high,
- num_symbols, met_constraints)) {
+ if (!helper->GetGaps(source_gaps, target_gaps, matching, chunklen, source_low,
+ source_high, target_low, target_high, source_phrase_low,
+ source_phrase_high, source_back_low, source_back_high,
+ num_symbols, met_constraints)) {
return extracts;
}
- bool start_x = source_back_low != source_phrase_low;
- bool end_x = source_back_high != source_phrase_high;
- Phrase source_phrase = phrase_builder->Extend(phrase, start_x, end_x);
+ bool starts_with_x = source_back_low != source_phrase_low;
+ bool ends_with_x = source_back_high != source_phrase_high;
+ Phrase source_phrase = phrase_builder->Extend(
+ phrase, starts_with_x, ends_with_x);
+ unordered_map<int, int> source_indexes = helper->GetSourceIndexes(
+ matching, chunklen, starts_with_x);
if (met_constraints) {
- AddExtracts(extracts, source_phrase, target_gaps, target_low,
- target_phrase_low, target_phrase_high, sentence_id);
+ AddExtracts(extracts, source_phrase, source_indexes, target_gaps,
+ target_low, target_phrase_low, target_phrase_high, sentence_id);
}
if (source_gaps.size() >= max_nonterminals ||
@@ -145,317 +181,24 @@ vector<Extract> RuleExtractor::ExtractAlignments(
for (int i = 0; i < 2; ++i) {
for (int j = 1 - i; j < 2; ++j) {
- AddNonterminalExtremities(extracts, source_phrase, source_phrase_low,
- source_phrase_high, source_back_low, source_back_high, source_low,
- source_high, target_low, target_high, target_gaps, sentence_id, i, j);
+ AddNonterminalExtremities(extracts, matching, chunklen, source_phrase,
+ source_back_low, source_back_high, source_low, source_high,
+ target_low, target_high, target_gaps, sentence_id, starts_with_x,
+ ends_with_x, i, j);
}
}
return extracts;
}
-void RuleExtractor::GetLinksSpans(
- vector<int>& source_low, vector<int>& source_high,
- vector<int>& target_low, vector<int>& target_high, int sentence_id) const {
- // Ignore end of line markers.
- int source_sent_len = source_data_array->GetSentenceStart(sentence_id + 1) -
- source_data_array->GetSentenceStart(sentence_id) - 1;
- int target_sent_len = target_data_array->GetSentenceStart(sentence_id + 1) -
- target_data_array->GetSentenceStart(sentence_id) - 1;
- source_low = vector<int>(source_sent_len, -1);
- source_high = vector<int>(source_sent_len, -1);
-
- // TODO(pauldb): Adam Lopez claims this part is really inefficient. See if we
- // can speed it up.
- target_low = vector<int>(target_sent_len, -1);
- target_high = vector<int>(target_sent_len, -1);
- const vector<pair<int, int> >& links = alignment->GetLinks(sentence_id);
- for (auto link: links) {
- if (source_low[link.first] == -1 || source_low[link.first] > link.second) {
- source_low[link.first] = link.second;
- }
- source_high[link.first] = max(source_high[link.first], link.second + 1);
-
- if (target_low[link.second] == -1 || target_low[link.second] > link.first) {
- target_low[link.second] = link.first;
- }
- target_high[link.second] = max(target_high[link.second], link.first + 1);
- }
-}
-
-bool RuleExtractor::CheckAlignedTerminals(const vector<int>& matching,
- const vector<int>& chunklen,
- const vector<int>& source_low) const {
- if (!require_aligned_terminal) {
- return true;
- }
-
- int sentence_id = source_data_array->GetSentenceId(matching[0]);
- int source_sent_start = source_data_array->GetSentenceStart(sentence_id);
-
- int num_aligned_chunks = 0;
- for (size_t i = 0; i < chunklen.size(); ++i) {
- for (size_t j = 0; j < chunklen[i]; ++j) {
- int sent_index = matching[i] - source_sent_start + j;
- if (source_low[sent_index] != -1) {
- ++num_aligned_chunks;
- break;
- }
- }
- }
-
- if (num_aligned_chunks == 0) {
- return false;
- }
-
- return !require_aligned_chunks || num_aligned_chunks == chunklen.size();
-}
-
-bool RuleExtractor::CheckTightPhrases(const vector<int>& matching,
- const vector<int>& chunklen,
- const vector<int>& source_low) const {
- if (!require_tight_phrases) {
- return true;
- }
-
- int sentence_id = source_data_array->GetSentenceId(matching[0]);
- int source_sent_start = source_data_array->GetSentenceStart(sentence_id);
- for (size_t i = 0; i + 1 < chunklen.size(); ++i) {
- int gap_start = matching[i] + chunklen[i] - source_sent_start;
- int gap_end = matching[i + 1] - 1 - source_sent_start;
- if (source_low[gap_start] == -1 || source_low[gap_end] == -1) {
- return false;
- }
- }
-
- return true;
-}
-
-bool RuleExtractor::FindFixPoint(
- int source_phrase_low, int source_phrase_high,
- const vector<int>& source_low, const vector<int>& source_high,
- int& target_phrase_low, int& target_phrase_high,
- const vector<int>& target_low, const vector<int>& target_high,
- int& source_back_low, int& source_back_high, int sentence_id,
- int min_source_gap_size, int min_target_gap_size,
- int max_new_x, int max_low_x, int max_high_x,
- bool allow_arbitrary_expansion) const {
- int source_sent_len = source_data_array->GetSentenceStart(sentence_id + 1) -
- source_data_array->GetSentenceStart(sentence_id) - 1;
- int target_sent_len = target_data_array->GetSentenceStart(sentence_id + 1) -
- target_data_array->GetSentenceStart(sentence_id) - 1;
-
- int prev_target_low = target_phrase_low;
- int prev_target_high = target_phrase_high;
- FindProjection(source_phrase_low, source_phrase_high, source_low,
- source_high, target_phrase_low, target_phrase_high);
-
- if (target_phrase_low == -1) {
- // TODO(pauldb): Low priority corner case inherited from Adam's code:
- // If w is unaligned, but we don't require aligned terminals, returning an
- // error here prevents the extraction of the allowed rule
- // X -> X_1 w X_2 / X_1 X_2
- return false;
- }
-
- if (prev_target_low != -1 && target_phrase_low != prev_target_low) {
- if (prev_target_low - target_phrase_low < min_target_gap_size) {
- target_phrase_low = prev_target_low - min_target_gap_size;
- if (target_phrase_low < 0) {
- return false;
- }
- }
- }
-
- if (prev_target_high != -1 && target_phrase_high != prev_target_high) {
- if (target_phrase_high - prev_target_high < min_target_gap_size) {
- target_phrase_high = prev_target_high + min_target_gap_size;
- if (target_phrase_high > target_sent_len) {
- return false;
- }
- }
- }
-
- if (target_phrase_high - target_phrase_low > max_rule_span) {
- return false;
- }
-
- source_back_low = source_back_high = -1;
- FindProjection(target_phrase_low, target_phrase_high, target_low, target_high,
- source_back_low, source_back_high);
- int new_x = 0, new_low_x = 0, new_high_x = 0;
-
- while (true) {
- source_back_low = min(source_back_low, source_phrase_low);
- source_back_high = max(source_back_high, source_phrase_high);
-
- if (source_back_low == source_phrase_low &&
- source_back_high == source_phrase_high) {
- return true;
- }
-
- if (new_low_x >= max_low_x && source_back_low < source_phrase_low) {
- // Extension on the left side not allowed.
- return false;
- }
- if (new_high_x >= max_high_x && source_back_high > source_phrase_high) {
- // Extension on the right side not allowed.
- return false;
- }
-
- // Extend left side.
- if (source_back_low < source_phrase_low) {
- if (new_x >= max_new_x) {
- return false;
- }
- ++new_x; ++new_low_x;
- if (source_phrase_low - source_back_low < min_source_gap_size) {
- source_back_low = source_phrase_low - min_source_gap_size;
- if (source_back_low < 0) {
- return false;
- }
- }
- }
-
- // Extend right side.
- if (source_back_high > source_phrase_high) {
- if (new_x >= max_new_x) {
- return false;
- }
- ++new_x; ++new_high_x;
- if (source_back_high - source_phrase_high < min_source_gap_size) {
- source_back_high = source_phrase_high + min_source_gap_size;
- if (source_back_high > source_sent_len) {
- return false;
- }
- }
- }
-
- if (source_back_high - source_back_low > max_rule_span) {
- // Rule span too wide.
- return false;
- }
-
- prev_target_low = target_phrase_low;
- prev_target_high = target_phrase_high;
- FindProjection(source_back_low, source_phrase_low, source_low, source_high,
- target_phrase_low, target_phrase_high);
- FindProjection(source_phrase_high, source_back_high, source_low,
- source_high, target_phrase_low, target_phrase_high);
- if (prev_target_low == target_phrase_low &&
- prev_target_high == target_phrase_high) {
- return true;
- }
-
- if (!allow_arbitrary_expansion) {
- // Arbitrary expansion not allowed.
- return false;
- }
- if (target_phrase_high - target_phrase_low > max_rule_span) {
- // Target side too wide.
- return false;
- }
-
- source_phrase_low = source_back_low;
- source_phrase_high = source_back_high;
- FindProjection(target_phrase_low, prev_target_low, target_low, target_high,
- source_back_low, source_back_high);
- FindProjection(prev_target_high, target_phrase_high, target_low,
- target_high, source_back_low, source_back_high);
- }
-
- return false;
-}
-
-void RuleExtractor::FindProjection(
- int source_phrase_low, int source_phrase_high,
- const vector<int>& source_low, const vector<int>& source_high,
- int& target_phrase_low, int& target_phrase_high) const {
- for (size_t i = source_phrase_low; i < source_phrase_high; ++i) {
- if (source_low[i] != -1) {
- if (target_phrase_low == -1 || source_low[i] < target_phrase_low) {
- target_phrase_low = source_low[i];
- }
- target_phrase_high = max(target_phrase_high, source_high[i]);
- }
- }
-}
-
-bool RuleExtractor::CheckGaps(
- vector<pair<int, int> >& source_gaps, vector<pair<int, int> >& target_gaps,
- const vector<int>& matching, const vector<int>& chunklen,
- const vector<int>& source_low, const vector<int>& source_high,
- const vector<int>& target_low, const vector<int>& target_high,
- int source_phrase_low, int source_phrase_high, int source_back_low,
- int source_back_high, int& num_symbols, bool& met_constraints) const {
- int sentence_id = source_data_array->GetSentenceId(matching[0]);
- int source_sent_start = source_data_array->GetSentenceStart(sentence_id);
-
- if (source_back_low < source_phrase_low) {
- source_gaps.push_back(make_pair(source_back_low, source_phrase_low));
- if (num_symbols >= max_rule_symbols) {
- // Source side contains too many symbols.
- return false;
- }
- ++num_symbols;
- if (require_tight_phrases && (source_low[source_back_low] == -1 ||
- source_low[source_phrase_low - 1] == -1)) {
- // Inside edges of preceding gap are not tight.
- return false;
- }
- } else if (require_tight_phrases && source_low[source_phrase_low] == -1) {
- // This is not a hard error. We can't extract this phrase, but we might
- // still be able to extract a superphrase.
- met_constraints = false;
- }
-
- for (size_t i = 0; i + 1 < chunklen.size(); ++i) {
- int gap_start = matching[i] + chunklen[i] - source_sent_start;
- int gap_end = matching[i + 1] - source_sent_start;
- source_gaps.push_back(make_pair(gap_start, gap_end));
- }
-
- if (source_phrase_high < source_back_high) {
- source_gaps.push_back(make_pair(source_phrase_high, source_back_high));
- if (num_symbols >= max_rule_symbols) {
- // Source side contains too many symbols.
- return false;
- }
- ++num_symbols;
- if (require_tight_phrases && (source_low[source_phrase_high] == -1 ||
- source_low[source_back_high - 1] == -1)) {
- // Inside edges of following gap are not tight.
- return false;
- }
- } else if (require_tight_phrases &&
- source_low[source_phrase_high - 1] == -1) {
- // This is not a hard error. We can't extract this phrase, but we might
- // still be able to extract a superphrase.
- met_constraints = false;
- }
-
- target_gaps.resize(source_gaps.size(), make_pair(-1, -1));
- for (size_t i = 0; i < source_gaps.size(); ++i) {
- if (!FindFixPoint(source_gaps[i].first, source_gaps[i].second, source_low,
- source_high, target_gaps[i].first, target_gaps[i].second,
- target_low, target_high, source_gaps[i].first,
- source_gaps[i].second, sentence_id, 0, 0, 0, 0, 0,
- false)) {
- // Gap fails integrity check.
- return false;
- }
- }
-
- return true;
-}
-
void RuleExtractor::AddExtracts(
vector<Extract>& extracts, const Phrase& source_phrase,
+ const unordered_map<int, int>& source_indexes,
const vector<pair<int, int> >& target_gaps, const vector<int>& target_low,
int target_phrase_low, int target_phrase_high, int sentence_id) const {
- vector<pair<Phrase, PhraseAlignment> > target_phrases = ExtractTargetPhrases(
+ auto target_phrases = target_phrase_extractor->ExtractPhrases(
target_gaps, target_low, target_phrase_low, target_phrase_high,
- sentence_id);
+ source_indexes, sentence_id);
if (target_phrases.size() > 0) {
double pairs_count = 1.0 / target_phrases.size();
@@ -466,147 +209,29 @@ void RuleExtractor::AddExtracts(
}
}
-vector<pair<Phrase, PhraseAlignment> > RuleExtractor::ExtractTargetPhrases(
- const vector<pair<int, int> >& target_gaps, const vector<int>& target_low,
- int target_phrase_low, int target_phrase_high, int sentence_id) const {
- int target_sent_len = target_data_array->GetSentenceStart(sentence_id + 1) -
- target_data_array->GetSentenceStart(sentence_id) - 1;
-
- vector<int> target_gap_order(target_gaps.size());
- for (size_t i = 0; i < target_gap_order.size(); ++i) {
- for (size_t j = 0; j < i; ++j) {
- if (target_gaps[target_gap_order[j]] < target_gaps[i]) {
- ++target_gap_order[i];
- } else {
- ++target_gap_order[j];
- }
- }
- }
-
- int target_x_low = target_phrase_low, target_x_high = target_phrase_high;
- if (!require_tight_phrases) {
- while (target_x_low > 0 &&
- target_phrase_high - target_x_low < max_rule_span &&
- target_low[target_x_low - 1] == -1) {
- --target_x_low;
- }
- while (target_x_high + 1 < target_sent_len &&
- target_x_high - target_phrase_low < max_rule_span &&
- target_low[target_x_high + 1] == -1) {
- ++target_x_high;
- }
- }
-
- vector<pair<int, int> > gaps(target_gaps.size());
- for (size_t i = 0; i < gaps.size(); ++i) {
- gaps[i] = target_gaps[target_gap_order[i]];
- if (!require_tight_phrases) {
- while (gaps[i].first > target_x_low &&
- target_low[gaps[i].first] == -1) {
- --gaps[i].first;
- }
- while (gaps[i].second < target_x_high &&
- target_low[gaps[i].second] == -1) {
- ++gaps[i].second;
- }
- }
- }
-
- vector<pair<int, int> > ranges(2 * gaps.size() + 2);
- ranges.front() = make_pair(target_x_low, target_phrase_low);
- ranges.back() = make_pair(target_phrase_high, target_x_high);
- for (size_t i = 0; i < gaps.size(); ++i) {
- ranges[i * 2 + 1] = make_pair(gaps[i].first, target_gaps[i].first);
- ranges[i * 2 + 2] = make_pair(target_gaps[i].second, gaps[i].second);
- }
-
- vector<pair<Phrase, PhraseAlignment> > target_phrases;
- vector<int> subpatterns(ranges.size());
- GeneratePhrases(target_phrases, ranges, 0, subpatterns, target_gap_order,
- target_phrase_low, target_phrase_high, sentence_id);
- return target_phrases;
-}
-
-void RuleExtractor::GeneratePhrases(
- vector<pair<Phrase, PhraseAlignment> >& target_phrases,
- const vector<pair<int, int> >& ranges, int index, vector<int>& subpatterns,
- const vector<int>& target_gap_order, int target_phrase_low,
- int target_phrase_high, int sentence_id) const {
- if (index >= ranges.size()) {
- if (subpatterns.back() - subpatterns.front() > max_rule_span) {
+void RuleExtractor::AddNonterminalExtremities(
+ vector<Extract>& extracts, const vector<int>& matching,
+ const vector<int>& chunklen, const Phrase& source_phrase,
+ int source_back_low, int source_back_high, const vector<int>& source_low,
+ const vector<int>& source_high, const vector<int>& target_low,
+ const vector<int>& target_high, vector<pair<int, int> > target_gaps,
+ int sentence_id, int starts_with_x, int ends_with_x, int extend_left,
+ int extend_right) const {
+ int source_x_low = source_back_low, source_x_high = source_back_high;
+
+ if (require_tight_phrases) {
+ if (source_low[source_back_low - extend_left] == -1 ||
+ source_low[source_back_high + extend_right - 1] == -1) {
return;
}
-
- vector<int> symbols;
- unordered_set<int> target_indexes;
- int offset = 1;
- if (subpatterns.front() != target_phrase_low) {
- offset = 2;
- symbols.push_back(vocabulary->GetNonterminalIndex(1));
- }
-
- int target_sent_start = target_data_array->GetSentenceStart(sentence_id);
- for (size_t i = 0; i * 2 < subpatterns.size(); ++i) {
- for (size_t j = subpatterns[i * 2]; j < subpatterns[i * 2 + 1]; ++j) {
- symbols.push_back(target_data_array->AtIndex(target_sent_start + j));
- target_indexes.insert(j);
- }
- if (i < target_gap_order.size()) {
- symbols.push_back(vocabulary->GetNonterminalIndex(
- target_gap_order[i] + offset));
- }
- }
-
- if (subpatterns.back() != target_phrase_high) {
- symbols.push_back(target_gap_order.size() + offset);
- }
-
- const vector<pair<int, int> >& links = alignment->GetLinks(sentence_id);
- vector<pair<int, int> > alignment;
- for (pair<int, int> link: links) {
- if (target_indexes.count(link.second)) {
- alignment.push_back(link);
- }
- }
-
- target_phrases.push_back(make_pair(phrase_builder->Build(symbols),
- alignment));
- return;
- }
-
- subpatterns[index] = ranges[index].first;
- if (index > 0) {
- subpatterns[index] = max(subpatterns[index], subpatterns[index - 1]);
}
- while (subpatterns[index] <= ranges[index].second) {
- subpatterns[index + 1] = max(subpatterns[index], ranges[index + 1].first);
- while (subpatterns[index + 1] <= ranges[index + 1].second) {
- GeneratePhrases(target_phrases, ranges, index + 2, subpatterns,
- target_gap_order, target_phrase_low, target_phrase_high,
- sentence_id);
- ++subpatterns[index + 1];
- }
- ++subpatterns[index];
- }
-}
-void RuleExtractor::AddNonterminalExtremities(
- vector<Extract>& extracts, const Phrase& source_phrase,
- int source_phrase_low, int source_phrase_high, int source_back_low,
- int source_back_high, const vector<int>& source_low,
- const vector<int>& source_high, const vector<int>& target_low,
- const vector<int>& target_high, const vector<pair<int, int> >& target_gaps,
- int sentence_id, int extend_left, int extend_right) const {
- int source_x_low = source_phrase_low, source_x_high = source_phrase_high;
if (extend_left) {
- if (source_back_low != source_phrase_low ||
- source_phrase_low < min_gap_size ||
- (require_tight_phrases && (source_low[source_phrase_low - 1] == -1 ||
- source_low[source_back_high - 1] == -1))) {
+ if (starts_with_x || source_back_low < min_gap_size) {
return;
}
- source_x_low = source_phrase_low - min_gap_size;
+ source_x_low = source_back_low - min_gap_size;
if (require_tight_phrases) {
while (source_x_low >= 0 && source_low[source_x_low] == -1) {
--source_x_low;
@@ -618,15 +243,11 @@ void RuleExtractor::AddNonterminalExtremities(
}
if (extend_right) {
- int source_sent_len = source_data_array->GetSentenceStart(sentence_id + 1) -
- source_data_array->GetSentenceStart(sentence_id) - 1;
- if (source_back_high != source_phrase_high ||
- source_phrase_high + min_gap_size > source_sent_len ||
- (require_tight_phrases && (source_low[source_phrase_low] == -1 ||
- source_low[source_phrase_high] == -1))) {
+ int source_sent_len = source_data_array->GetSentenceLength(sentence_id);
+ if (ends_with_x || source_back_high + min_gap_size > source_sent_len) {
return;
}
- source_x_high = source_phrase_high + min_gap_size;
+ source_x_high = source_back_high + min_gap_size;
if (require_tight_phrases) {
while (source_x_high <= source_sent_len &&
source_low[source_x_high - 1] == -1) {
@@ -639,41 +260,56 @@ void RuleExtractor::AddNonterminalExtremities(
}
}
+ int new_nonterminals = extend_left + extend_right;
if (source_x_high - source_x_low > max_rule_span ||
- target_gaps.size() + extend_left + extend_right > max_nonterminals) {
+ target_gaps.size() + new_nonterminals > max_nonterminals ||
+ source_phrase.GetNumSymbols() + new_nonterminals > max_rule_symbols) {
return;
}
int target_x_low = -1, target_x_high = -1;
- if (!FindFixPoint(source_x_low, source_x_high, source_low, source_high,
- target_x_low, target_x_high, target_low, target_high,
- source_x_low, source_x_high, sentence_id, 1, 1,
- extend_left + extend_right, extend_left, extend_right,
- true)) {
+ if (!helper->FindFixPoint(source_x_low, source_x_high, source_low,
+ source_high, target_x_low, target_x_high,
+ target_low, target_high, source_x_low,
+ source_x_high, sentence_id, 1, 1,
+ new_nonterminals, extend_left, extend_right,
+ true)) {
return;
}
- int source_gap_low = -1, source_gap_high = -1, target_gap_low = -1,
- target_gap_high = -1;
- if (extend_left &&
- ((require_tight_phrases && source_low[source_x_low] == -1) ||
- !FindFixPoint(source_x_low, source_phrase_low, source_low, source_high,
- target_gap_low, target_gap_high, target_low, target_high,
- source_gap_low, source_gap_high, sentence_id,
- 0, 0, 0, 0, 0, false))) {
- return;
+ if (extend_left) {
+ int source_gap_low = -1, source_gap_high = -1;
+ int target_gap_low = -1, target_gap_high = -1;
+ if ((require_tight_phrases && source_low[source_x_low] == -1) ||
+ !helper->FindFixPoint(source_x_low, source_back_low, source_low,
+ source_high, target_gap_low, target_gap_high,
+ target_low, target_high, source_gap_low,
+ source_gap_high, sentence_id, 0, 0, 0, false,
+ false, false)) {
+ return;
+ }
+ target_gaps.insert(target_gaps.begin(),
+ make_pair(target_gap_low, target_gap_high));
}
- if (extend_right &&
- ((require_tight_phrases && source_low[source_x_high - 1] == -1) ||
- !FindFixPoint(source_phrase_high, source_x_high, source_low, source_high,
- target_gap_low, target_gap_high, target_low, target_high,
- source_gap_low, source_gap_high, sentence_id,
- 0, 0, 0, 0, 0, false))) {
- return;
+
+ if (extend_right) {
+ int target_gap_low = -1, target_gap_high = -1;
+ int source_gap_low = -1, source_gap_high = -1;
+ if ((require_tight_phrases && source_low[source_x_high - 1] == -1) ||
+ !helper->FindFixPoint(source_back_high, source_x_high, source_low,
+ source_high, target_gap_low, target_gap_high,
+ target_low, target_high, source_gap_low,
+ source_gap_high, sentence_id, 0, 0, 0, false,
+ false, false)) {
+ return;
+ }
+ target_gaps.push_back(make_pair(target_gap_low, target_gap_high));
}
Phrase new_source_phrase = phrase_builder->Extend(source_phrase, extend_left,
extend_right);
- AddExtracts(extracts, new_source_phrase, target_gaps, target_low,
- target_x_low, target_x_high, sentence_id);
+ unordered_map<int, int> source_indexes = helper->GetSourceIndexes(
+ matching, chunklen, extend_left || starts_with_x);
+ AddExtracts(extracts, new_source_phrase, source_indexes, target_gaps,
+ target_low, target_x_low, target_x_high, sentence_id);
}
diff --git a/extractor/rule_extractor.h b/extractor/rule_extractor.h
index f668de24..a087dc6d 100644
--- a/extractor/rule_extractor.h
+++ b/extractor/rule_extractor.h
@@ -2,6 +2,7 @@
#define _RULE_EXTRACTOR_H_
#include <memory>
+#include <unordered_map>
#include <vector>
#include "phrase.h"
@@ -13,8 +14,9 @@ class DataArray;
class PhraseBuilder;
class PhraseLocation;
class Rule;
+class RuleExtractorHelper;
class Scorer;
-class Vocabulary;
+class TargetPhraseExtractor;
typedef vector<pair<int, int> > PhraseAlignment;
@@ -46,84 +48,56 @@ class RuleExtractor {
bool require_aligned_chunks,
bool require_tight_phrases);
- vector<Rule> ExtractRules(const Phrase& phrase,
- const PhraseLocation& location) const;
+ // For testing only.
+ RuleExtractor(shared_ptr<DataArray> source_data_array,
+ shared_ptr<PhraseBuilder> phrase_builder,
+ shared_ptr<Scorer> scorer,
+ shared_ptr<TargetPhraseExtractor> target_phrase_extractor,
+ shared_ptr<RuleExtractorHelper> helper,
+ int max_rule_span,
+ int min_gap_size,
+ int max_nonterminals,
+ int max_rule_symbols,
+ bool require_tight_phrases);
+
+ virtual ~RuleExtractor();
+
+ virtual vector<Rule> ExtractRules(const Phrase& phrase,
+ const PhraseLocation& location) const;
+
+ protected:
+ RuleExtractor();
private:
vector<Extract> ExtractAlignments(const Phrase& phrase,
const vector<int>& matching) const;
- void GetLinksSpans(vector<int>& source_low, vector<int>& source_high,
- vector<int>& target_low, vector<int>& target_high,
- int sentence_id) const;
-
- bool CheckAlignedTerminals(const vector<int>& matching,
- const vector<int>& chunklen,
- const vector<int>& source_low) const;
-
- bool CheckTightPhrases(const vector<int>& matching,
- const vector<int>& chunklen,
- const vector<int>& source_low) const;
-
- bool FindFixPoint(
- int source_phrase_start, int source_phrase_end,
- const vector<int>& source_low, const vector<int>& source_high,
- int& target_phrase_start, int& target_phrase_end,
- const vector<int>& target_low, const vector<int>& target_high,
- int& source_back_low, int& source_back_high, int sentence_id,
- int min_source_gap_size, int min_target_gap_size,
- int max_new_x, int max_low_x, int max_high_x,
- bool allow_arbitrary_expansion) const;
-
- void FindProjection(
- int source_phrase_start, int source_phrase_end,
- const vector<int>& source_low, const vector<int>& source_high,
- int& target_phrase_low, int& target_phrase_end) const;
-
- bool CheckGaps(
- vector<pair<int, int> >& source_gaps, vector<pair<int, int> >& target_gaps,
- const vector<int>& matching, const vector<int>& chunklen,
- const vector<int>& source_low, const vector<int>& source_high,
- const vector<int>& target_low, const vector<int>& target_high,
- int source_phrase_low, int source_phrase_high, int source_back_low,
- int source_back_high, int& num_symbols, bool& met_constraints) const;
-
void AddExtracts(
vector<Extract>& extracts, const Phrase& source_phrase,
+ const unordered_map<int, int>& source_indexes,
const vector<pair<int, int> >& target_gaps, const vector<int>& target_low,
int target_phrase_low, int target_phrase_high, int sentence_id) const;
- vector<pair<Phrase, PhraseAlignment> > ExtractTargetPhrases(
- const vector<pair<int, int> >& target_gaps, const vector<int>& target_low,
- int target_phrase_low, int target_phrase_high, int sentence_id) const;
-
- void GeneratePhrases(
- vector<pair<Phrase, PhraseAlignment> >& target_phrases,
- const vector<pair<int, int> >& ranges, int index,
- vector<int>& subpatterns, const vector<int>& target_gap_order,
- int target_phrase_low, int target_phrase_high, int sentence_id) const;
-
void AddNonterminalExtremities(
- vector<Extract>& extracts, const Phrase& source_phrase,
- int source_phrase_low, int source_phrase_high, int source_back_low,
- int source_back_high, const vector<int>& source_low,
+ vector<Extract>& extracts, const vector<int>& matching,
+ const vector<int>& chunklen, const Phrase& source_phrase,
+ int source_back_low, int source_back_high, const vector<int>& source_low,
const vector<int>& source_high, const vector<int>& target_low,
- const vector<int>& target_high,
- const vector<pair<int, int> >& target_gaps, int sentence_id,
- int extend_left, int extend_right) const;
+ const vector<int>& target_high, vector<pair<int, int> > target_gaps,
+ int sentence_id, int starts_with_x, int ends_with_x, int extend_left,
+ int extend_right) const;
- shared_ptr<DataArray> source_data_array;
+ private:
shared_ptr<DataArray> target_data_array;
- shared_ptr<Alignment> alignment;
+ shared_ptr<DataArray> source_data_array;
shared_ptr<PhraseBuilder> phrase_builder;
shared_ptr<Scorer> scorer;
- shared_ptr<Vocabulary> vocabulary;
+ shared_ptr<TargetPhraseExtractor> target_phrase_extractor;
+ shared_ptr<RuleExtractorHelper> helper;
int max_rule_span;
int min_gap_size;
int max_nonterminals;
int max_rule_symbols;
- bool require_aligned_terminal;
- bool require_aligned_chunks;
bool require_tight_phrases;
};
diff --git a/extractor/rule_extractor_helper.cc b/extractor/rule_extractor_helper.cc
new file mode 100644
index 00000000..ed6ae3a1
--- /dev/null
+++ b/extractor/rule_extractor_helper.cc
@@ -0,0 +1,356 @@
+#include "rule_extractor_helper.h"
+
+#include "data_array.h"
+#include "alignment.h"
+
+RuleExtractorHelper::RuleExtractorHelper(
+ shared_ptr<DataArray> source_data_array,
+ shared_ptr<DataArray> target_data_array,
+ shared_ptr<Alignment> alignment,
+ int max_rule_span,
+ int max_rule_symbols,
+ bool require_aligned_terminal,
+ bool require_aligned_chunks,
+ bool require_tight_phrases) :
+ source_data_array(source_data_array),
+ target_data_array(target_data_array),
+ alignment(alignment),
+ max_rule_span(max_rule_span),
+ max_rule_symbols(max_rule_symbols),
+ require_aligned_terminal(require_aligned_terminal),
+ require_aligned_chunks(require_aligned_chunks),
+ require_tight_phrases(require_tight_phrases) {}
+
+RuleExtractorHelper::RuleExtractorHelper() {}
+
+RuleExtractorHelper::~RuleExtractorHelper() {}
+
+void RuleExtractorHelper::GetLinksSpans(
+ vector<int>& source_low, vector<int>& source_high,
+ vector<int>& target_low, vector<int>& target_high, int sentence_id) const {
+ int source_sent_len = source_data_array->GetSentenceLength(sentence_id);
+ int target_sent_len = target_data_array->GetSentenceLength(sentence_id);
+ source_low = vector<int>(source_sent_len, -1);
+ source_high = vector<int>(source_sent_len, -1);
+
+ // TODO(pauldb): Adam Lopez claims this part is really inefficient. See if we
+ // can speed it up.
+ target_low = vector<int>(target_sent_len, -1);
+ target_high = vector<int>(target_sent_len, -1);
+ vector<pair<int, int> > links = alignment->GetLinks(sentence_id);
+ for (auto link: links) {
+ if (source_low[link.first] == -1 || source_low[link.first] > link.second) {
+ source_low[link.first] = link.second;
+ }
+ source_high[link.first] = max(source_high[link.first], link.second + 1);
+
+ if (target_low[link.second] == -1 || target_low[link.second] > link.first) {
+ target_low[link.second] = link.first;
+ }
+ target_high[link.second] = max(target_high[link.second], link.first + 1);
+ }
+}
+
+bool RuleExtractorHelper::CheckAlignedTerminals(
+ const vector<int>& matching,
+ const vector<int>& chunklen,
+ const vector<int>& source_low) const {
+ if (!require_aligned_terminal) {
+ return true;
+ }
+
+ int sentence_id = source_data_array->GetSentenceId(matching[0]);
+ int source_sent_start = source_data_array->GetSentenceStart(sentence_id);
+
+ int num_aligned_chunks = 0;
+ for (size_t i = 0; i < chunklen.size(); ++i) {
+ for (size_t j = 0; j < chunklen[i]; ++j) {
+ int sent_index = matching[i] - source_sent_start + j;
+ if (source_low[sent_index] != -1) {
+ ++num_aligned_chunks;
+ break;
+ }
+ }
+ }
+
+ if (num_aligned_chunks == 0) {
+ return false;
+ }
+
+ return !require_aligned_chunks || num_aligned_chunks == chunklen.size();
+}
+
+bool RuleExtractorHelper::CheckTightPhrases(
+ const vector<int>& matching,
+ const vector<int>& chunklen,
+ const vector<int>& source_low) const {
+ if (!require_tight_phrases) {
+ return true;
+ }
+
+ int sentence_id = source_data_array->GetSentenceId(matching[0]);
+ int source_sent_start = source_data_array->GetSentenceStart(sentence_id);
+ for (size_t i = 0; i + 1 < chunklen.size(); ++i) {
+ int gap_start = matching[i] + chunklen[i] - source_sent_start;
+ int gap_end = matching[i + 1] - 1 - source_sent_start;
+ if (source_low[gap_start] == -1 || source_low[gap_end] == -1) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool RuleExtractorHelper::FindFixPoint(
+ int source_phrase_low, int source_phrase_high,
+ const vector<int>& source_low, const vector<int>& source_high,
+ int& target_phrase_low, int& target_phrase_high,
+ const vector<int>& target_low, const vector<int>& target_high,
+ int& source_back_low, int& source_back_high, int sentence_id,
+ int min_source_gap_size, int min_target_gap_size,
+ int max_new_x, bool allow_low_x, bool allow_high_x,
+ bool allow_arbitrary_expansion) const {
+ int prev_target_low = target_phrase_low;
+ int prev_target_high = target_phrase_high;
+
+ FindProjection(source_phrase_low, source_phrase_high, source_low,
+ source_high, target_phrase_low, target_phrase_high);
+
+ if (target_phrase_low == -1) {
+ // TODO(pauldb): Low priority corner case inherited from Adam's code:
+ // If w is unaligned, but we don't require aligned terminals, returning an
+ // error here prevents the extraction of the allowed rule
+ // X -> X_1 w X_2 / X_1 X_2
+ return false;
+ }
+
+ int source_sent_len = source_data_array->GetSentenceLength(sentence_id);
+ int target_sent_len = target_data_array->GetSentenceLength(sentence_id);
+ if (prev_target_low != -1 && target_phrase_low != prev_target_low) {
+ if (prev_target_low - target_phrase_low < min_target_gap_size) {
+ target_phrase_low = prev_target_low - min_target_gap_size;
+ if (target_phrase_low < 0) {
+ return false;
+ }
+ }
+ }
+
+ if (prev_target_high != -1 && target_phrase_high != prev_target_high) {
+ if (target_phrase_high - prev_target_high < min_target_gap_size) {
+ target_phrase_high = prev_target_high + min_target_gap_size;
+ if (target_phrase_high > target_sent_len) {
+ return false;
+ }
+ }
+ }
+
+ if (target_phrase_high - target_phrase_low > max_rule_span) {
+ return false;
+ }
+
+ source_back_low = source_back_high = -1;
+ FindProjection(target_phrase_low, target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high);
+ int new_x = 0;
+ bool new_low_x = false, new_high_x = false;
+ while (true) {
+ source_back_low = min(source_back_low, source_phrase_low);
+ source_back_high = max(source_back_high, source_phrase_high);
+
+ if (source_back_low == source_phrase_low &&
+ source_back_high == source_phrase_high) {
+ return true;
+ }
+
+ if (!allow_low_x && source_back_low < source_phrase_low) {
+ // Extension on the left side not allowed.
+ return false;
+ }
+ if (!allow_high_x && source_back_high > source_phrase_high) {
+ // Extension on the right side not allowed.
+ return false;
+ }
+
+ // Extend left side.
+ if (source_back_low < source_phrase_low) {
+ if (new_low_x == false) {
+ if (new_x >= max_new_x) {
+ return false;
+ }
+ new_low_x = true;
+ ++new_x;
+ }
+ if (source_phrase_low - source_back_low < min_source_gap_size) {
+ source_back_low = source_phrase_low - min_source_gap_size;
+ if (source_back_low < 0) {
+ return false;
+ }
+ }
+ }
+
+ // Extend right side.
+ if (source_back_high > source_phrase_high) {
+ if (new_high_x == false) {
+ if (new_x >= max_new_x) {
+ return false;
+ }
+ new_high_x = true;
+ ++new_x;
+ }
+ if (source_back_high - source_phrase_high < min_source_gap_size) {
+ source_back_high = source_phrase_high + min_source_gap_size;
+ if (source_back_high > source_sent_len) {
+ return false;
+ }
+ }
+ }
+
+ if (source_back_high - source_back_low > max_rule_span) {
+ // Rule span too wide.
+ return false;
+ }
+
+ prev_target_low = target_phrase_low;
+ prev_target_high = target_phrase_high;
+ FindProjection(source_back_low, source_phrase_low, source_low, source_high,
+ target_phrase_low, target_phrase_high);
+ FindProjection(source_phrase_high, source_back_high, source_low,
+ source_high, target_phrase_low, target_phrase_high);
+ if (prev_target_low == target_phrase_low &&
+ prev_target_high == target_phrase_high) {
+ return true;
+ }
+
+ if (!allow_arbitrary_expansion) {
+ // Arbitrary expansion not allowed.
+ return false;
+ }
+ if (target_phrase_high - target_phrase_low > max_rule_span) {
+ // Target side too wide.
+ return false;
+ }
+
+ source_phrase_low = source_back_low;
+ source_phrase_high = source_back_high;
+ FindProjection(target_phrase_low, prev_target_low, target_low, target_high,
+ source_back_low, source_back_high);
+ FindProjection(prev_target_high, target_phrase_high, target_low,
+ target_high, source_back_low, source_back_high);
+ }
+
+ return false;
+}
+
+void RuleExtractorHelper::FindProjection(
+ int source_phrase_low, int source_phrase_high,
+ const vector<int>& source_low, const vector<int>& source_high,
+ int& target_phrase_low, int& target_phrase_high) const {
+ for (size_t i = source_phrase_low; i < source_phrase_high; ++i) {
+ if (source_low[i] != -1) {
+ if (target_phrase_low == -1 || source_low[i] < target_phrase_low) {
+ target_phrase_low = source_low[i];
+ }
+ target_phrase_high = max(target_phrase_high, source_high[i]);
+ }
+ }
+}
+
+bool RuleExtractorHelper::GetGaps(
+ vector<pair<int, int> >& source_gaps, vector<pair<int, int> >& target_gaps,
+ const vector<int>& matching, const vector<int>& chunklen,
+ const vector<int>& source_low, const vector<int>& source_high,
+ const vector<int>& target_low, const vector<int>& target_high,
+ int source_phrase_low, int source_phrase_high, int source_back_low,
+ int source_back_high, int& num_symbols, bool& met_constraints) const {
+ int sentence_id = source_data_array->GetSentenceId(matching[0]);
+ int source_sent_start = source_data_array->GetSentenceStart(sentence_id);
+
+ if (source_back_low < source_phrase_low) {
+ source_gaps.push_back(make_pair(source_back_low, source_phrase_low));
+ if (num_symbols >= max_rule_symbols) {
+ // Source side contains too many symbols.
+ return false;
+ }
+ ++num_symbols;
+ if (require_tight_phrases && (source_low[source_back_low] == -1 ||
+ source_low[source_phrase_low - 1] == -1)) {
+ // Inside edges of preceding gap are not tight.
+ return false;
+ }
+ } else if (require_tight_phrases && source_low[source_phrase_low] == -1) {
+ // This is not a hard error. We can't extract this phrase, but we might
+ // still be able to extract a superphrase.
+ met_constraints = false;
+ }
+
+ for (size_t i = 0; i + 1 < chunklen.size(); ++i) {
+ int gap_start = matching[i] + chunklen[i] - source_sent_start;
+ int gap_end = matching[i + 1] - source_sent_start;
+ source_gaps.push_back(make_pair(gap_start, gap_end));
+ }
+
+ if (source_phrase_high < source_back_high) {
+ source_gaps.push_back(make_pair(source_phrase_high, source_back_high));
+ if (num_symbols >= max_rule_symbols) {
+ // Source side contains too many symbols.
+ return false;
+ }
+ ++num_symbols;
+ if (require_tight_phrases && (source_low[source_phrase_high] == -1 ||
+ source_low[source_back_high - 1] == -1)) {
+ // Inside edges of following gap are not tight.
+ return false;
+ }
+ } else if (require_tight_phrases &&
+ source_low[source_phrase_high - 1] == -1) {
+ // This is not a hard error. We can't extract this phrase, but we might
+ // still be able to extract a superphrase.
+ met_constraints = false;
+ }
+
+ target_gaps.resize(source_gaps.size(), make_pair(-1, -1));
+ for (size_t i = 0; i < source_gaps.size(); ++i) {
+ if (!FindFixPoint(source_gaps[i].first, source_gaps[i].second, source_low,
+ source_high, target_gaps[i].first, target_gaps[i].second,
+ target_low, target_high, source_gaps[i].first,
+ source_gaps[i].second, sentence_id, 0, 0, 0, false, false,
+ false)) {
+ // Gap fails integrity check.
+ return false;
+ }
+ }
+
+ return true;
+}
+
+vector<int> RuleExtractorHelper::GetGapOrder(
+ const vector<pair<int, int> >& gaps) const {
+ vector<int> gap_order(gaps.size());
+ for (size_t i = 0; i < gap_order.size(); ++i) {
+ for (size_t j = 0; j < i; ++j) {
+ if (gaps[gap_order[j]] < gaps[i]) {
+ ++gap_order[i];
+ } else {
+ ++gap_order[j];
+ }
+ }
+ }
+ return gap_order;
+}
+
+unordered_map<int, int> RuleExtractorHelper::GetSourceIndexes(
+ const vector<int>& matching, const vector<int>& chunklen,
+ int starts_with_x) const {
+ unordered_map<int, int> source_indexes;
+ int sentence_id = source_data_array->GetSentenceId(matching[0]);
+ int source_sent_start = source_data_array->GetSentenceStart(sentence_id);
+ int num_symbols = starts_with_x;
+ for (size_t i = 0; i < matching.size(); ++i) {
+ for (size_t j = 0; j < chunklen[i]; ++j) {
+ source_indexes[matching[i] + j - source_sent_start] = num_symbols;
+ ++num_symbols;
+ }
+ ++num_symbols;
+ }
+ return source_indexes;
+}
diff --git a/extractor/rule_extractor_helper.h b/extractor/rule_extractor_helper.h
new file mode 100644
index 00000000..3478bfc8
--- /dev/null
+++ b/extractor/rule_extractor_helper.h
@@ -0,0 +1,82 @@
+#ifndef _RULE_EXTRACTOR_HELPER_H_
+#define _RULE_EXTRACTOR_HELPER_H_
+
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+using namespace std;
+
+class Alignment;
+class DataArray;
+
+class RuleExtractorHelper {
+ public:
+ RuleExtractorHelper(shared_ptr<DataArray> source_data_array,
+ shared_ptr<DataArray> target_data_array,
+ shared_ptr<Alignment> alignment,
+ int max_rule_span,
+ int max_rule_symbols,
+ bool require_aligned_terminal,
+ bool require_aligned_chunks,
+ bool require_tight_phrases);
+
+ virtual ~RuleExtractorHelper();
+
+ virtual void GetLinksSpans(vector<int>& source_low, vector<int>& source_high,
+ vector<int>& target_low, vector<int>& target_high,
+ int sentence_id) const;
+
+ virtual bool CheckAlignedTerminals(const vector<int>& matching,
+ const vector<int>& chunklen,
+ const vector<int>& source_low) const;
+
+ virtual bool CheckTightPhrases(const vector<int>& matching,
+ const vector<int>& chunklen,
+ const vector<int>& source_low) const;
+
+ virtual bool FindFixPoint(
+ int source_phrase_low, int source_phrase_high,
+ const vector<int>& source_low, const vector<int>& source_high,
+ int& target_phrase_low, int& target_phrase_high,
+ const vector<int>& target_low, const vector<int>& target_high,
+ int& source_back_low, int& source_back_high, int sentence_id,
+ int min_source_gap_size, int min_target_gap_size,
+ int max_new_x, bool allow_low_x, bool allow_high_x,
+ bool allow_arbitrary_expansion) const;
+
+ virtual bool GetGaps(
+ vector<pair<int, int> >& source_gaps, vector<pair<int, int> >& target_gaps,
+ const vector<int>& matching, const vector<int>& chunklen,
+ const vector<int>& source_low, const vector<int>& source_high,
+ const vector<int>& target_low, const vector<int>& target_high,
+ int source_phrase_low, int source_phrase_high, int source_back_low,
+ int source_back_high, int& num_symbols, bool& met_constraints) const;
+
+ virtual vector<int> GetGapOrder(const vector<pair<int, int> >& gaps) const;
+
+ // TODO(pauldb): Add unit tests.
+ virtual unordered_map<int, int> GetSourceIndexes(
+ const vector<int>& matching, const vector<int>& chunklen,
+ int starts_with_x) const;
+
+ protected:
+ RuleExtractorHelper();
+
+ private:
+ void FindProjection(
+ int source_phrase_low, int source_phrase_high,
+ const vector<int>& source_low, const vector<int>& source_high,
+ int& target_phrase_low, int& target_phrase_high) const;
+
+ shared_ptr<DataArray> source_data_array;
+ shared_ptr<DataArray> target_data_array;
+ shared_ptr<Alignment> alignment;
+ int max_rule_span;
+ int max_rule_symbols;
+ bool require_aligned_terminal;
+ bool require_aligned_chunks;
+ bool require_tight_phrases;
+};
+
+#endif
diff --git a/extractor/rule_extractor_helper_test.cc b/extractor/rule_extractor_helper_test.cc
new file mode 100644
index 00000000..29213312
--- /dev/null
+++ b/extractor/rule_extractor_helper_test.cc
@@ -0,0 +1,622 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "mocks/mock_alignment.h"
+#include "mocks/mock_data_array.h"
+#include "rule_extractor_helper.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace {
+
+class RuleExtractorHelperTest : public Test {
+ protected:
+ virtual void SetUp() {
+ source_data_array = make_shared<MockDataArray>();
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(12));
+ EXPECT_CALL(*source_data_array, GetSentenceId(_))
+ .WillRepeatedly(Return(5));
+ EXPECT_CALL(*source_data_array, GetSentenceStart(_))
+ .WillRepeatedly(Return(10));
+
+ target_data_array = make_shared<MockDataArray>();
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(12));
+
+ vector<pair<int, int> > links = {
+ make_pair(0, 0), make_pair(0, 1), make_pair(2, 2), make_pair(3, 1)
+ };
+ alignment = make_shared<MockAlignment>();
+ EXPECT_CALL(*alignment, GetLinks(_)).WillRepeatedly(Return(links));
+ }
+
+ shared_ptr<MockDataArray> source_data_array;
+ shared_ptr<MockDataArray> target_data_array;
+ shared_ptr<MockAlignment> alignment;
+ shared_ptr<RuleExtractorHelper> helper;
+};
+
+TEST_F(RuleExtractorHelperTest, TestGetLinksSpans) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(4));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(3));
+
+ vector<int> source_low, source_high, target_low, target_high;
+ helper->GetLinksSpans(source_low, source_high, target_low, target_high, 0);
+
+ vector<int> expected_source_low = {0, -1, 2, 1};
+ EXPECT_EQ(expected_source_low, source_low);
+ vector<int> expected_source_high = {2, -1, 3, 2};
+ EXPECT_EQ(expected_source_high, source_high);
+ vector<int> expected_target_low = {0, 0, 2};
+ EXPECT_EQ(expected_target_low, target_low);
+ vector<int> expected_target_high = {1, 4, 3};
+ EXPECT_EQ(expected_target_high, target_high);
+}
+
+TEST_F(RuleExtractorHelperTest, TestCheckAlignedFalse) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, false, false, true);
+ EXPECT_CALL(*source_data_array, GetSentenceId(_)).Times(0);
+ EXPECT_CALL(*source_data_array, GetSentenceStart(_)).Times(0);
+
+ vector<int> matching, chunklen, source_low;
+ EXPECT_TRUE(helper->CheckAlignedTerminals(matching, chunklen, source_low));
+}
+
+TEST_F(RuleExtractorHelperTest, TestCheckAlignedTerminal) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, false, true);
+
+ vector<int> matching = {10, 12};
+ vector<int> chunklen = {1, 3};
+ vector<int> source_low = {-1, 1, -1, 3, -1};
+ EXPECT_TRUE(helper->CheckAlignedTerminals(matching, chunklen, source_low));
+ source_low = {-1, 1, -1, -1, -1};
+ EXPECT_FALSE(helper->CheckAlignedTerminals(matching, chunklen, source_low));
+}
+
+TEST_F(RuleExtractorHelperTest, TestCheckAlignedChunks) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+
+ vector<int> matching = {10, 12};
+ vector<int> chunklen = {1, 3};
+ vector<int> source_low = {2, 1, -1, 3, -1};
+ EXPECT_TRUE(helper->CheckAlignedTerminals(matching, chunklen, source_low));
+ source_low = {-1, 1, -1, 3, -1};
+ EXPECT_FALSE(helper->CheckAlignedTerminals(matching, chunklen, source_low));
+ source_low = {2, 1, -1, -1, -1};
+ EXPECT_FALSE(helper->CheckAlignedTerminals(matching, chunklen, source_low));
+}
+
+
+TEST_F(RuleExtractorHelperTest, TestCheckTightPhrasesFalse) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, false);
+ EXPECT_CALL(*source_data_array, GetSentenceId(_)).Times(0);
+ EXPECT_CALL(*source_data_array, GetSentenceStart(_)).Times(0);
+
+ vector<int> matching, chunklen, source_low;
+ EXPECT_TRUE(helper->CheckTightPhrases(matching, chunklen, source_low));
+}
+
+TEST_F(RuleExtractorHelperTest, TestCheckTightPhrases) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+
+ vector<int> matching = {10, 14, 18};
+ vector<int> chunklen = {2, 3, 1};
+ // No missing links.
+ vector<int> source_low = {0, 1, 2, 3, 4, 5, 6, 7, 8};
+ EXPECT_TRUE(helper->CheckTightPhrases(matching, chunklen, source_low));
+
+ // Missing link at the beginning or ending of a gap.
+ source_low = {0, 1, -1, 3, 4, 5, 6, 7, 8};
+ EXPECT_FALSE(helper->CheckTightPhrases(matching, chunklen, source_low));
+ source_low = {0, 1, 2, -1, 4, 5, 6, 7, 8};
+ EXPECT_FALSE(helper->CheckTightPhrases(matching, chunklen, source_low));
+ source_low = {0, 1, 2, 3, 4, 5, 6, -1, 8};
+ EXPECT_FALSE(helper->CheckTightPhrases(matching, chunklen, source_low));
+
+ // Missing link inside the gap.
+ chunklen = {1, 3, 1};
+ source_low = {0, 1, -1, 3, 4, 5, 6, 7, 8};
+ EXPECT_TRUE(helper->CheckTightPhrases(matching, chunklen, source_low));
+}
+
+TEST_F(RuleExtractorHelperTest, TestFindFixPointBadEdgeCase) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+
+ vector<int> source_low = {0, -1, 2};
+ vector<int> source_high = {1, -1, 3};
+ vector<int> target_low = {0, -1, 2};
+ vector<int> target_high = {1, -1, 3};
+ int source_phrase_low = 1, source_phrase_high = 2;
+ int source_back_low, source_back_high;
+ int target_phrase_low = -1, target_phrase_high = 1;
+
+ // This should be in fact true. See comment about the inherited bug.
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 0, 0,
+ 0, false, false, false));
+}
+
+TEST_F(RuleExtractorHelperTest, TestFindFixPointTargetSentenceOutOfBounds) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(3));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(3));
+
+ vector<int> source_low = {0, 0, 2};
+ vector<int> source_high = {1, 2, 3};
+ vector<int> target_low = {0, 1, 2};
+ vector<int> target_high = {2, 2, 3};
+ int source_phrase_low = 1, source_phrase_high = 2;
+ int source_back_low, source_back_high;
+ int target_phrase_low = 1, target_phrase_high = 2;
+
+ // Extend out of sentence to left.
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 2, 2,
+ 0, false, false, false));
+ source_low = {0, 1, 2};
+ source_high = {1, 3, 3};
+ target_low = {0, 1, 1};
+ target_high = {1, 2, 3};
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 2, 2,
+ 0, false, false, false));
+}
+
+TEST_F(RuleExtractorHelperTest, TestFindFixPointTargetTooWide) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 5, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+
+ vector<int> source_low = {0, 0, 0, 0, 0, 0, 0};
+ vector<int> source_high = {7, 7, 7, 7, 7, 7, 7};
+ vector<int> target_low = {0, -1, -1, -1, -1, -1, 0};
+ vector<int> target_high = {7, -1, -1, -1, -1, -1, 7};
+ int source_phrase_low = 2, source_phrase_high = 5;
+ int source_back_low, source_back_high;
+ int target_phrase_low = -1, target_phrase_high = -1;
+
+ // Projection is too wide.
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 1, 1,
+ 0, false, false, false));
+}
+
+TEST_F(RuleExtractorHelperTest, TestFindFixPoint) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+
+ vector<int> source_low = {1, 1, 1, 3, 4, 5, 5};
+ vector<int> source_high = {2, 2, 3, 4, 6, 6, 6};
+ vector<int> target_low = {-1, 0, 2, 3, 4, 4, -1};
+ vector<int> target_high = {-1, 3, 3, 4, 5, 7, -1};
+ int source_phrase_low = 2, source_phrase_high = 5;
+ int source_back_low, source_back_high;
+ int target_phrase_low = 2, target_phrase_high = 5;
+
+ EXPECT_TRUE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 1, 1, 1,
+ 2, true, true, false));
+ EXPECT_EQ(1, target_phrase_low);
+ EXPECT_EQ(6, target_phrase_high);
+ EXPECT_EQ(0, source_back_low);
+ EXPECT_EQ(7, source_back_high);
+
+ source_low = {0, -1, 1, 3, 4, -1, 6};
+ source_high = {1, -1, 3, 4, 6, -1, 7};
+ target_low = {0, 2, 2, 3, 4, 4, 6};
+ target_high = {1, 3, 3, 4, 5, 5, 7};
+ source_phrase_low = 2, source_phrase_high = 5;
+ target_phrase_low = -1, target_phrase_high = -1;
+ EXPECT_TRUE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 1, 1, 1,
+ 2, true, true, false));
+ EXPECT_EQ(1, target_phrase_low);
+ EXPECT_EQ(6, target_phrase_high);
+ EXPECT_EQ(2, source_back_low);
+ EXPECT_EQ(5, source_back_high);
+}
+
+TEST_F(RuleExtractorHelperTest, TestFindFixPointExtensionsNotAllowed) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(3));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(3));
+
+ vector<int> source_low = {0, 0, 2};
+ vector<int> source_high = {1, 2, 3};
+ vector<int> target_low = {0, 1, 2};
+ vector<int> target_high = {2, 2, 3};
+ int source_phrase_low = 1, source_phrase_high = 2;
+ int source_back_low, source_back_high;
+ int target_phrase_low = -1, target_phrase_high = -1;
+
+ // Extension on the left side not allowed.
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 1, 1,
+ 1, false, true, false));
+ // Extension on the left side is allowed, but we can't add anymore X.
+ target_phrase_low = -1, target_phrase_high = -1;
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 1, 1,
+ 0, true, true, false));
+ source_low = {0, 1, 2};
+ source_high = {1, 3, 3};
+ target_low = {0, 1, 1};
+ target_high = {1, 2, 3};
+ // Extension on the right side not allowed.
+ target_phrase_low = -1, target_phrase_high = -1;
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 1, 1,
+ 1, true, false, false));
+ // Extension on the right side is allowed, but we can't add anymore X.
+ target_phrase_low = -1, target_phrase_high = -1;
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 1, 1,
+ 0, true, true, false));
+}
+
+TEST_F(RuleExtractorHelperTest, TestFindFixPointSourceSentenceOutOfBounds) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(3));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(3));
+
+ vector<int> source_low = {0, 0, 2};
+ vector<int> source_high = {1, 2, 3};
+ vector<int> target_low = {0, 1, 2};
+ vector<int> target_high = {2, 2, 3};
+ int source_phrase_low = 1, source_phrase_high = 2;
+ int source_back_low, source_back_high;
+ int target_phrase_low = 1, target_phrase_high = 2;
+ // Extend out of sentence to left.
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 2, 1,
+ 1, true, true, false));
+ source_low = {0, 1, 2};
+ source_high = {1, 3, 3};
+ target_low = {0, 1, 1};
+ target_high = {1, 2, 3};
+ target_phrase_low = 1, target_phrase_high = 2;
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 2, 1,
+ 1, true, true, false));
+}
+
+TEST_F(RuleExtractorHelperTest, TestFindFixPointTargetSourceWide) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 5, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+
+ vector<int> source_low = {2, -1, 2, 3, 4, -1, 4};
+ vector<int> source_high = {3, -1, 3, 4, 5, -1, 5};
+ vector<int> target_low = {-1, -1, 0, 3, 4, -1, -1};
+ vector<int> target_high = {-1, -1, 3, 4, 7, -1, -1};
+ int source_phrase_low = 2, source_phrase_high = 5;
+ int source_back_low, source_back_high;
+ int target_phrase_low = -1, target_phrase_high = -1;
+
+ // Second projection (on source side) is too wide.
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 1, 1,
+ 2, true, true, false));
+}
+
+TEST_F(RuleExtractorHelperTest, TestFindFixPointArbitraryExpansion) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 20, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(11));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(11));
+
+ vector<int> source_low = {1, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9};
+ vector<int> source_high = {2, 3, 4, 5, 5, 6, 7, 8, 9, 10, 10};
+ vector<int> target_low = {-1, 0, 1, 2, 3, 5, 6, 7, 8, 9, -1};
+ vector<int> target_high = {-1, 2, 3, 4, 5, 6, 8, 9, 10, 11, -1};
+ int source_phrase_low = 4, source_phrase_high = 7;
+ int source_back_low, source_back_high;
+ int target_phrase_low = -1, target_phrase_high = -1;
+ EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 1, 1,
+ 10, true, true, false));
+
+ source_phrase_low = 4, source_phrase_high = 7;
+ target_phrase_low = -1, target_phrase_high = -1;
+ EXPECT_TRUE(helper->FindFixPoint(source_phrase_low, source_phrase_high,
+ source_low, source_high, target_phrase_low,
+ target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high, 0, 1, 1,
+ 10, true, true, true));
+}
+
+TEST_F(RuleExtractorHelperTest, TestGetGapOrder) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+
+ vector<pair<int, int> > gaps =
+ {make_pair(0, 3), make_pair(5, 8), make_pair(11, 12), make_pair(15, 17)};
+ vector<int> expected_gap_order = {0, 1, 2, 3};
+ EXPECT_EQ(expected_gap_order, helper->GetGapOrder(gaps));
+
+ gaps = {make_pair(15, 17), make_pair(8, 9), make_pair(5, 6), make_pair(0, 3)};
+ expected_gap_order = {3, 2, 1, 0};
+ EXPECT_EQ(expected_gap_order, helper->GetGapOrder(gaps));
+
+ gaps = {make_pair(8, 9), make_pair(5, 6), make_pair(0, 3), make_pair(15, 17)};
+ expected_gap_order = {2, 1, 0, 3};
+ EXPECT_EQ(expected_gap_order, helper->GetGapOrder(gaps));
+}
+
+TEST_F(RuleExtractorHelperTest, TestGetGapsExceedNumSymbols) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+
+ bool met_constraints = true;
+ vector<int> source_low = {1, 1, 2, 3, 4, 5, 6};
+ vector<int> source_high = {2, 2, 3, 4, 5, 6, 7};
+ vector<int> target_low = {-1, 0, 2, 3, 4, 5, 6};
+ vector<int> target_high = {-1, 2, 3, 4, 5, 6, 7};
+ int source_phrase_low = 1, source_phrase_high = 6;
+ int source_back_low = 0, source_back_high = 6;
+ vector<int> matching = {11, 13, 15};
+ vector<int> chunklen = {1, 1, 1};
+ vector<pair<int, int> > source_gaps, target_gaps;
+ int num_symbols = 5;
+ EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen,
+ source_low, source_high, target_low, target_high,
+ source_phrase_low, source_phrase_high,
+ source_back_low, source_back_high, num_symbols,
+ met_constraints));
+
+ source_low = {0, 1, 2, 3, 4, 5, 5};
+ source_high = {1, 2, 3, 4, 5, 6, 6};
+ target_low = {0, 1, 2, 3, 4, 5, -1};
+ target_high = {1, 2, 3, 4, 5, 7, -1};
+ source_phrase_low = 1, source_phrase_high = 6;
+ source_back_low = 1, source_back_high = 7;
+ num_symbols = 5;
+ EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen,
+ source_low, source_high, target_low, target_high,
+ source_phrase_low, source_phrase_high,
+ source_back_low, source_back_high, num_symbols,
+ met_constraints));
+}
+
+TEST_F(RuleExtractorHelperTest, TestGetGapsExtensionsNotTight) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 7, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+
+ bool met_constraints = true;
+ vector<int> source_low = {-1, 1, 2, 3, 4, 5, -1};
+ vector<int> source_high = {-1, 2, 3, 4, 5, 6, -1};
+ vector<int> target_low = {-1, 1, 2, 3, 4, 5, -1};
+ vector<int> target_high = {-1, 2, 3, 4, 5, 6, -1};
+ int source_phrase_low = 1, source_phrase_high = 6;
+ int source_back_low = 0, source_back_high = 6;
+ vector<int> matching = {11, 13, 15};
+ vector<int> chunklen = {1, 1, 1};
+ vector<pair<int, int> > source_gaps, target_gaps;
+ int num_symbols = 5;
+ EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen,
+ source_low, source_high, target_low, target_high,
+ source_phrase_low, source_phrase_high,
+ source_back_low, source_back_high, num_symbols,
+ met_constraints));
+
+ source_phrase_low = 1, source_phrase_high = 6;
+ source_back_low = 1, source_back_high = 7;
+ num_symbols = 5;
+ EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen,
+ source_low, source_high, target_low, target_high,
+ source_phrase_low, source_phrase_high,
+ source_back_low, source_back_high, num_symbols,
+ met_constraints));
+}
+
+TEST_F(RuleExtractorHelperTest, TestGetGapsNotTightExtremities) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 7, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+
+ bool met_constraints = true;
+ vector<int> source_low = {-1, -1, 2, 3, 4, 5, 6};
+ vector<int> source_high = {-1, -1, 3, 4, 5, 6, 7};
+ vector<int> target_low = {-1, -1, 2, 3, 4, 5, 6};
+ vector<int> target_high = {-1, -1, 3, 4, 5, 6, 7};
+ int source_phrase_low = 1, source_phrase_high = 6;
+ int source_back_low = 1, source_back_high = 6;
+ vector<int> matching = {11, 13, 15};
+ vector<int> chunklen = {1, 1, 1};
+ vector<pair<int, int> > source_gaps, target_gaps;
+ int num_symbols = 5;
+ EXPECT_TRUE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen,
+ source_low, source_high, target_low, target_high,
+ source_phrase_low, source_phrase_high,
+ source_back_low, source_back_high, num_symbols,
+ met_constraints));
+ EXPECT_FALSE(met_constraints);
+ vector<pair<int, int> > expected_gaps = {make_pair(2, 3), make_pair(4, 5)};
+ EXPECT_EQ(expected_gaps, source_gaps);
+ EXPECT_EQ(expected_gaps, target_gaps);
+
+ source_low = {-1, 1, 2, 3, 4, -1, 6};
+ source_high = {-1, 2, 3, 4, 5, -1, 7};
+ target_low = {-1, 1, 2, 3, 4, -1, 6};
+ target_high = {-1, 2, 3, 4, 5, -1, 7};
+ met_constraints = true;
+ source_gaps.clear();
+ target_gaps.clear();
+ EXPECT_TRUE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen,
+ source_low, source_high, target_low, target_high,
+ source_phrase_low, source_phrase_high,
+ source_back_low, source_back_high, num_symbols,
+ met_constraints));
+ EXPECT_FALSE(met_constraints);
+ EXPECT_EQ(expected_gaps, source_gaps);
+ EXPECT_EQ(expected_gaps, target_gaps);
+}
+
+TEST_F(RuleExtractorHelperTest, TestGetGapsWithExtensions) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+
+ bool met_constraints = true;
+ vector<int> source_low = {-1, 5, 2, 3, 4, 1, -1};
+ vector<int> source_high = {-1, 6, 3, 4, 5, 2, -1};
+ vector<int> target_low = {-1, 5, 2, 3, 4, 1, -1};
+ vector<int> target_high = {-1, 6, 3, 4, 5, 2, -1};
+ int source_phrase_low = 2, source_phrase_high = 5;
+ int source_back_low = 1, source_back_high = 6;
+ vector<int> matching = {12, 14};
+ vector<int> chunklen = {1, 1};
+ vector<pair<int, int> > source_gaps, target_gaps;
+ int num_symbols = 3;
+ EXPECT_TRUE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen,
+ source_low, source_high, target_low, target_high,
+ source_phrase_low, source_phrase_high,
+ source_back_low, source_back_high, num_symbols,
+ met_constraints));
+ vector<pair<int, int> > expected_source_gaps = {
+ make_pair(1, 2), make_pair(3, 4), make_pair(5, 6)
+ };
+ EXPECT_EQ(expected_source_gaps, source_gaps);
+ vector<pair<int, int> > expected_target_gaps = {
+ make_pair(5, 6), make_pair(3, 4), make_pair(1, 2)
+ };
+ EXPECT_EQ(expected_target_gaps, target_gaps);
+}
+
+TEST_F(RuleExtractorHelperTest, TestGetGaps) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+
+ bool met_constraints = true;
+ vector<int> source_low = {-1, 1, 4, 3, 2, 5, -1};
+ vector<int> source_high = {-1, 2, 5, 4, 3, 6, -1};
+ vector<int> target_low = {-1, 1, 4, 3, 2, 5, -1};
+ vector<int> target_high = {-1, 2, 5, 4, 3, 6, -1};
+ int source_phrase_low = 1, source_phrase_high = 6;
+ int source_back_low = 1, source_back_high = 6;
+ vector<int> matching = {11, 13, 15};
+ vector<int> chunklen = {1, 1, 1};
+ vector<pair<int, int> > source_gaps, target_gaps;
+ int num_symbols = 5;
+ EXPECT_TRUE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen,
+ source_low, source_high, target_low, target_high,
+ source_phrase_low, source_phrase_high,
+ source_back_low, source_back_high, num_symbols,
+ met_constraints));
+ vector<pair<int, int> > expected_source_gaps = {
+ make_pair(2, 3), make_pair(4, 5)
+ };
+ EXPECT_EQ(expected_source_gaps, source_gaps);
+ vector<pair<int, int> > expected_target_gaps = {
+ make_pair(4, 5), make_pair(2, 3)
+ };
+ EXPECT_EQ(expected_target_gaps, target_gaps);
+}
+
+TEST_F(RuleExtractorHelperTest, TestGetGapIntegrityChecksFailed) {
+ helper = make_shared<RuleExtractorHelper>(source_data_array,
+ target_data_array, alignment, 10, 5, true, true, true);
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+ EXPECT_CALL(*target_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(7));
+
+ bool met_constraints = true;
+ vector<int> source_low = {-1, 3, 2, 3, 4, 3, -1};
+ vector<int> source_high = {-1, 4, 3, 4, 5, 4, -1};
+ vector<int> target_low = {-1, -1, 2, 1, 4, -1, -1};
+ vector<int> target_high = {-1, -1, 3, 6, 5, -1, -1};
+ int source_phrase_low = 2, source_phrase_high = 5;
+ int source_back_low = 2, source_back_high = 5;
+ vector<int> matching = {12, 14};
+ vector<int> chunklen = {1, 1};
+ vector<pair<int, int> > source_gaps, target_gaps;
+ int num_symbols = 3;
+ EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen,
+ source_low, source_high, target_low, target_high,
+ source_phrase_low, source_phrase_high,
+ source_back_low, source_back_high, num_symbols,
+ met_constraints));
+}
+
+} // namespace
diff --git a/extractor/rule_extractor_test.cc b/extractor/rule_extractor_test.cc
new file mode 100644
index 00000000..0be44d4d
--- /dev/null
+++ b/extractor/rule_extractor_test.cc
@@ -0,0 +1,166 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "mocks/mock_alignment.h"
+#include "mocks/mock_data_array.h"
+#include "mocks/mock_rule_extractor_helper.h"
+#include "mocks/mock_scorer.h"
+#include "mocks/mock_target_phrase_extractor.h"
+#include "mocks/mock_vocabulary.h"
+#include "phrase.h"
+#include "phrase_builder.h"
+#include "phrase_location.h"
+#include "rule_extractor.h"
+#include "rule.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace {
+
+class RuleExtractorTest : public Test {
+ protected:
+ virtual void SetUp() {
+ source_data_array = make_shared<MockDataArray>();
+ EXPECT_CALL(*source_data_array, GetSentenceId(_))
+ .WillRepeatedly(Return(0));
+ EXPECT_CALL(*source_data_array, GetSentenceStart(_))
+ .WillRepeatedly(Return(0));
+ EXPECT_CALL(*source_data_array, GetSentenceLength(_))
+ .WillRepeatedly(Return(10));
+
+ helper = make_shared<MockRuleExtractorHelper>();
+ EXPECT_CALL(*helper, CheckAlignedTerminals(_, _, _))
+ .WillRepeatedly(Return(true));
+ EXPECT_CALL(*helper, CheckTightPhrases(_, _, _))
+ .WillRepeatedly(Return(true));
+ unordered_map<int, int> source_indexes;
+ EXPECT_CALL(*helper, GetSourceIndexes(_, _, _))
+ .WillRepeatedly(Return(source_indexes));
+
+ vocabulary = make_shared<MockVocabulary>();
+ EXPECT_CALL(*vocabulary, GetTerminalValue(87))
+ .WillRepeatedly(Return("a"));
+ phrase_builder = make_shared<PhraseBuilder>(vocabulary);
+ vector<int> symbols = {87};
+ Phrase target_phrase = phrase_builder->Build(symbols);
+ PhraseAlignment phrase_alignment = {make_pair(0, 0)};
+
+ target_phrase_extractor = make_shared<MockTargetPhraseExtractor>();
+ vector<pair<Phrase, PhraseAlignment> > target_phrases = {
+ make_pair(target_phrase, phrase_alignment)
+ };
+ EXPECT_CALL(*target_phrase_extractor, ExtractPhrases(_, _, _, _, _, _))
+ .WillRepeatedly(Return(target_phrases));
+
+ scorer = make_shared<MockScorer>();
+ vector<double> scores = {0.3, 7.2};
+ EXPECT_CALL(*scorer, Score(_)).WillRepeatedly(Return(scores));
+
+ extractor = make_shared<RuleExtractor>(source_data_array, phrase_builder,
+ scorer, target_phrase_extractor, helper, 10, 1, 3, 5, false);
+ }
+
+ shared_ptr<MockDataArray> source_data_array;
+ shared_ptr<MockVocabulary> vocabulary;
+ shared_ptr<PhraseBuilder> phrase_builder;
+ shared_ptr<MockRuleExtractorHelper> helper;
+ shared_ptr<MockScorer> scorer;
+ shared_ptr<MockTargetPhraseExtractor> target_phrase_extractor;
+ shared_ptr<RuleExtractor> extractor;
+};
+
+TEST_F(RuleExtractorTest, TestExtractRulesAlignedTerminalsFail) {
+ vector<int> symbols = {87};
+ Phrase phrase = phrase_builder->Build(symbols);
+ vector<int> matching = {2};
+ PhraseLocation phrase_location(matching, 1);
+ EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1);
+ EXPECT_CALL(*helper, CheckAlignedTerminals(_, _, _))
+ .WillRepeatedly(Return(false));
+ vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
+ EXPECT_EQ(0, rules.size());
+}
+
+TEST_F(RuleExtractorTest, TestExtractRulesTightPhrasesFail) {
+ vector<int> symbols = {87};
+ Phrase phrase = phrase_builder->Build(symbols);
+ vector<int> matching = {2};
+ PhraseLocation phrase_location(matching, 1);
+ EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1);
+ EXPECT_CALL(*helper, CheckTightPhrases(_, _, _))
+ .WillRepeatedly(Return(false));
+ vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
+ EXPECT_EQ(0, rules.size());
+}
+
+TEST_F(RuleExtractorTest, TestExtractRulesNoFixPoint) {
+ vector<int> symbols = {87};
+ Phrase phrase = phrase_builder->Build(symbols);
+ vector<int> matching = {2};
+ PhraseLocation phrase_location(matching, 1);
+
+ EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1);
+ // Set FindFixPoint to return false.
+ vector<pair<int, int> > gaps;
+ helper->SetUp(0, 0, 0, 0, false, gaps, gaps, 0, true, true);
+
+ vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
+ EXPECT_EQ(0, rules.size());
+}
+
+TEST_F(RuleExtractorTest, TestExtractRulesGapsFail) {
+ vector<int> symbols = {87};
+ Phrase phrase = phrase_builder->Build(symbols);
+ vector<int> matching = {2};
+ PhraseLocation phrase_location(matching, 1);
+
+ EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1);
+ // Set CheckGaps to return false.
+ vector<pair<int, int> > gaps;
+ helper->SetUp(0, 0, 0, 0, true, gaps, gaps, 0, true, false);
+
+ vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
+ EXPECT_EQ(0, rules.size());
+}
+
+TEST_F(RuleExtractorTest, TestExtractRulesNoExtremities) {
+ vector<int> symbols = {87};
+ Phrase phrase = phrase_builder->Build(symbols);
+ vector<int> matching = {2};
+ PhraseLocation phrase_location(matching, 1);
+
+ EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1);
+ vector<pair<int, int> > gaps(3);
+ // Set FindFixPoint to return true. The number of gaps equals the number of
+ // nonterminals, so we won't add any extremities.
+ helper->SetUp(0, 0, 0, 0, true, gaps, gaps, 0, true, true);
+
+ vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
+ EXPECT_EQ(1, rules.size());
+}
+
+TEST_F(RuleExtractorTest, TestExtractRulesAddExtremities) {
+ vector<int> symbols = {87};
+ Phrase phrase = phrase_builder->Build(symbols);
+ vector<int> matching = {2};
+ PhraseLocation phrase_location(matching, 1);
+
+ vector<int> links(10, -1);
+ EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).WillOnce(DoAll(
+ SetArgReferee<0>(links),
+ SetArgReferee<1>(links),
+ SetArgReferee<2>(links),
+ SetArgReferee<3>(links)));
+
+ vector<pair<int, int> > gaps;
+ // Set FindFixPoint to return true. The number of gaps equals the number of
+ // nonterminals, so we won't add any extremities.
+ helper->SetUp(0, 0, 2, 3, true, gaps, gaps, 0, true, true);
+
+ vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location);
+ EXPECT_EQ(4, rules.size());
+}
+
+} // namespace
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc
index c22f9b48..374a0db1 100644
--- a/extractor/rule_factory.cc
+++ b/extractor/rule_factory.cc
@@ -1,6 +1,6 @@
#include "rule_factory.h"
-#include <cassert>
+#include <chrono>
#include <memory>
#include <queue>
#include <vector>
@@ -18,7 +18,9 @@
#include "vocabulary.h"
using namespace std;
-using namespace tr1;
+using namespace std::chrono;
+
+typedef high_resolution_clock Clock;
struct State {
State(int start, int end, const vector<int>& phrase,
@@ -68,8 +70,44 @@ HieroCachingRuleFactory::HieroCachingRuleFactory(
sampler = make_shared<Sampler>(source_suffix_array, max_samples);
}
+HieroCachingRuleFactory::HieroCachingRuleFactory(
+ shared_ptr<MatchingsFinder> finder,
+ shared_ptr<Intersector> intersector,
+ shared_ptr<PhraseBuilder> phrase_builder,
+ shared_ptr<RuleExtractor> rule_extractor,
+ shared_ptr<Vocabulary> vocabulary,
+ shared_ptr<Sampler> sampler,
+ shared_ptr<Scorer> scorer,
+ int min_gap_size,
+ int max_rule_span,
+ int max_nonterminals,
+ int max_chunks,
+ int max_rule_symbols) :
+ matchings_finder(finder),
+ intersector(intersector),
+ phrase_builder(phrase_builder),
+ rule_extractor(rule_extractor),
+ vocabulary(vocabulary),
+ sampler(sampler),
+ scorer(scorer),
+ min_gap_size(min_gap_size),
+ max_rule_span(max_rule_span),
+ max_nonterminals(max_nonterminals),
+ max_chunks(max_chunks),
+ max_rule_symbols(max_rule_symbols) {}
+
+HieroCachingRuleFactory::HieroCachingRuleFactory() {}
+
+HieroCachingRuleFactory::~HieroCachingRuleFactory() {}
Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
+ intersector->binary_merge_time = 0;
+ intersector->linear_merge_time = 0;
+ intersector->sort_time = 0;
+ Clock::time_point start_time = Clock::now();
+ double total_extract_time = 0;
+ double total_intersect_time = 0;
+ double total_lookup_time = 0;
// Clear cache for every new sentence.
trie.Reset();
shared_ptr<TrieNode> root = trie.GetRoot();
@@ -107,34 +145,42 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
}
if (RequiresLookup(node, word_id)) {
- shared_ptr<TrieNode> next_suffix_link =
- node->suffix_link->GetChild(word_id);
+ shared_ptr<TrieNode> next_suffix_link = node->suffix_link == NULL ?
+ trie.GetRoot() : node->suffix_link->GetChild(word_id);
if (state.starts_with_x) {
// If the phrase starts with a non terminal, we simply use the matchings
// from the suffix link.
- next_node = shared_ptr<TrieNode>(new TrieNode(
- next_suffix_link, next_phrase, next_suffix_link->matchings));
+ next_node = make_shared<TrieNode>(
+ next_suffix_link, next_phrase, next_suffix_link->matchings);
} else {
PhraseLocation phrase_location;
if (next_phrase.Arity() > 0) {
+ Clock::time_point intersect_start_time = Clock::now();
phrase_location = intersector->Intersect(
node->phrase,
node->matchings,
next_suffix_link->phrase,
next_suffix_link->matchings,
next_phrase);
+ Clock::time_point intersect_stop_time = Clock::now();
+ total_intersect_time += duration_cast<milliseconds>(
+ intersect_stop_time - intersect_start_time).count();
} else {
+ Clock::time_point lookup_start_time = Clock::now();
phrase_location = matchings_finder->Find(
node->matchings,
vocabulary->GetTerminalValue(word_id),
state.phrase.size());
+ Clock::time_point lookup_stop_time = Clock::now();
+ total_lookup_time += duration_cast<milliseconds>(
+ lookup_stop_time - lookup_start_time).count();
}
if (phrase_location.IsEmpty()) {
continue;
}
- next_node = shared_ptr<TrieNode>(new TrieNode(
- next_suffix_link, next_phrase, phrase_location));
+ next_node = make_shared<TrieNode>(
+ next_suffix_link, next_phrase, phrase_location);
}
node->AddChild(word_id, next_node);
@@ -143,12 +189,16 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
AddTrailingNonterminal(phrase, next_phrase, next_node,
state.starts_with_x);
+ Clock::time_point extract_start_time = Clock::now();
if (!state.starts_with_x) {
PhraseLocation sample = sampler->Sample(next_node->matchings);
vector<Rule> new_rules =
rule_extractor->ExtractRules(next_phrase, sample);
rules.insert(rules.end(), new_rules.begin(), new_rules.end());
}
+ Clock::time_point extract_stop_time = Clock::now();
+ total_extract_time += duration_cast<milliseconds>(
+ extract_stop_time - extract_start_time).count();
} else {
next_node = node->GetChild(word_id);
}
@@ -160,6 +210,16 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
}
}
+ Clock::time_point stop_time = Clock::now();
+ milliseconds ms = duration_cast<milliseconds>(stop_time - start_time);
+ cerr << "Total time for rule lookup, extraction, and scoring = "
+ << ms.count() / 1000.0 << endl;
+ cerr << "Extract time = " << total_extract_time / 1000.0 << endl;
+ cerr << "Intersect time = " << total_intersect_time / 1000.0 << endl;
+ cerr << "Sort time = " << intersector->sort_time / 1000.0 << endl;
+ cerr << "Linear merge time = " << intersector->linear_merge_time / 1000.0 << endl;
+ cerr << "Binary merge time = " << intersector->binary_merge_time / 1000.0 << endl;
+ // cerr << "Lookup time = " << total_lookup_time / 1000.0 << endl;
return Grammar(rules, scorer->GetFeatureNames());
}
@@ -192,12 +252,12 @@ void HieroCachingRuleFactory::AddTrailingNonterminal(
Phrase var_phrase = phrase_builder->Build(symbols);
int suffix_var_id = vocabulary->GetNonterminalIndex(
- prefix.Arity() + starts_with_x == 0);
+ prefix.Arity() + (starts_with_x == 0));
shared_ptr<TrieNode> var_suffix_link =
prefix_node->suffix_link->GetChild(suffix_var_id);
- prefix_node->AddChild(var_id, shared_ptr<TrieNode>(new TrieNode(
- var_suffix_link, var_phrase, prefix_node->matchings)));
+ prefix_node->AddChild(var_id, make_shared<TrieNode>(
+ var_suffix_link, var_phrase, prefix_node->matchings));
}
vector<State> HieroCachingRuleFactory::ExtendState(
@@ -216,7 +276,7 @@ vector<State> HieroCachingRuleFactory::ExtendState(
new_states.push_back(State(state.start, state.end + 1, symbols,
state.subpatterns_start, node, state.starts_with_x));
- int num_subpatterns = phrase.Arity() + state.starts_with_x == 0;
+ int num_subpatterns = phrase.Arity() + (state.starts_with_x == 0);
if (symbols.size() + 1 >= max_rule_symbols ||
phrase.Arity() >= max_nonterminals ||
num_subpatterns >= max_chunks) {
diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h
index a47b6d16..cf344667 100644
--- a/extractor/rule_factory.h
+++ b/extractor/rule_factory.h
@@ -40,7 +40,27 @@ class HieroCachingRuleFactory {
bool use_beaza_yates,
bool require_tight_phrases);
- Grammar GetGrammar(const vector<int>& word_ids);
+ // For testing only.
+ HieroCachingRuleFactory(
+ shared_ptr<MatchingsFinder> finder,
+ shared_ptr<Intersector> intersector,
+ shared_ptr<PhraseBuilder> phrase_builder,
+ shared_ptr<RuleExtractor> rule_extractor,
+ shared_ptr<Vocabulary> vocabulary,
+ shared_ptr<Sampler> sampler,
+ shared_ptr<Scorer> scorer,
+ int min_gap_size,
+ int max_rule_span,
+ int max_nonterminals,
+ int max_chunks,
+ int max_rule_symbols);
+
+ virtual ~HieroCachingRuleFactory();
+
+ virtual Grammar GetGrammar(const vector<int>& word_ids);
+
+ protected:
+ HieroCachingRuleFactory();
private:
bool CannotHaveMatchings(shared_ptr<TrieNode> node, int word_id);
diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc
new file mode 100644
index 00000000..d6fbab74
--- /dev/null
+++ b/extractor/rule_factory_test.cc
@@ -0,0 +1,98 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "grammar.h"
+#include "mocks/mock_intersector.h"
+#include "mocks/mock_matchings_finder.h"
+#include "mocks/mock_rule_extractor.h"
+#include "mocks/mock_sampler.h"
+#include "mocks/mock_scorer.h"
+#include "mocks/mock_vocabulary.h"
+#include "phrase_builder.h"
+#include "phrase_location.h"
+#include "rule_factory.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace {
+
+class RuleFactoryTest : public Test {
+ protected:
+ virtual void SetUp() {
+ finder = make_shared<MockMatchingsFinder>();
+ intersector = make_shared<MockIntersector>();
+
+ vocabulary = make_shared<MockVocabulary>();
+ EXPECT_CALL(*vocabulary, GetTerminalValue(2)).WillRepeatedly(Return("a"));
+ EXPECT_CALL(*vocabulary, GetTerminalValue(3)).WillRepeatedly(Return("b"));
+ EXPECT_CALL(*vocabulary, GetTerminalValue(4)).WillRepeatedly(Return("c"));
+
+ phrase_builder = make_shared<PhraseBuilder>(vocabulary);
+
+ scorer = make_shared<MockScorer>();
+ feature_names = {"f1"};
+ EXPECT_CALL(*scorer, GetFeatureNames())
+ .WillRepeatedly(Return(feature_names));
+
+ sampler = make_shared<MockSampler>();
+ EXPECT_CALL(*sampler, Sample(_))
+ .WillRepeatedly(Return(PhraseLocation(0, 1)));
+
+ Phrase phrase;
+ vector<double> scores = {0.5};
+ vector<pair<int, int> > phrase_alignment = {make_pair(0, 0)};
+ vector<Rule> rules = {Rule(phrase, phrase, scores, phrase_alignment)};
+ extractor = make_shared<MockRuleExtractor>();
+ EXPECT_CALL(*extractor, ExtractRules(_, _))
+ .WillRepeatedly(Return(rules));
+
+ factory = make_shared<HieroCachingRuleFactory>(finder, intersector,
+ phrase_builder, extractor, vocabulary, sampler, scorer, 1, 10, 2, 3, 5);
+ }
+
+ vector<string> feature_names;
+ shared_ptr<MockMatchingsFinder> finder;
+ shared_ptr<MockIntersector> intersector;
+ shared_ptr<MockVocabulary> vocabulary;
+ shared_ptr<PhraseBuilder> phrase_builder;
+ shared_ptr<MockScorer> scorer;
+ shared_ptr<MockSampler> sampler;
+ shared_ptr<MockRuleExtractor> extractor;
+ shared_ptr<HieroCachingRuleFactory> factory;
+};
+
+TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) {
+ EXPECT_CALL(*finder, Find(_, _, _))
+ .Times(6)
+ .WillRepeatedly(Return(PhraseLocation(0, 1)));
+
+ EXPECT_CALL(*intersector, Intersect(_, _, _, _, _))
+ .Times(1)
+ .WillRepeatedly(Return(PhraseLocation(0, 1)));
+
+ vector<int> word_ids = {2, 3, 4};
+ Grammar grammar = factory->GetGrammar(word_ids);
+ EXPECT_EQ(feature_names, grammar.GetFeatureNames());
+ EXPECT_EQ(7, grammar.GetRules().size());
+}
+
+TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) {
+ EXPECT_CALL(*finder, Find(_, _, _))
+ .Times(12)
+ .WillRepeatedly(Return(PhraseLocation(0, 1)));
+
+ EXPECT_CALL(*intersector, Intersect(_, _, _, _, _))
+ .Times(16)
+ .WillRepeatedly(Return(PhraseLocation(0, 1)));
+
+ vector<int> word_ids = {2, 3, 4, 2, 3};
+ Grammar grammar = factory->GetGrammar(word_ids);
+ EXPECT_EQ(feature_names, grammar.GetFeatureNames());
+ EXPECT_EQ(28, grammar.GetRules().size());
+}
+
+} // namespace
diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc
index 37a9cba0..ed30e6fe 100644
--- a/extractor/run_extractor.cc
+++ b/extractor/run_extractor.cc
@@ -114,8 +114,8 @@ int main(int argc, char** argv) {
make_shared<TargetGivenSourceCoherent>(),
make_shared<SampleSourceCount>(),
make_shared<CountSourceTarget>(),
- make_shared<MaxLexTargetGivenSource>(table),
make_shared<MaxLexSourceGivenTarget>(table),
+ make_shared<MaxLexTargetGivenSource>(table),
make_shared<IsSourceSingleton>(),
make_shared<IsSourceTargetSingleton>()
};
@@ -138,6 +138,10 @@ int main(int argc, char** argv) {
int grammar_id = 0;
fs::path grammar_path = vm["grammars"].as<string>();
+ if (!fs::is_directory(grammar_path)) {
+ fs::create_directory(grammar_path);
+ }
+
string sentence, delimiter = "|||";
while (getline(cin, sentence)) {
string suffix = "";
@@ -148,7 +152,8 @@ int main(int argc, char** argv) {
}
Grammar grammar = extractor.GetGrammar(sentence);
- fs::path grammar_file = grammar_path / to_string(grammar_id);
+ string file_name = "grammar." + to_string(grammar_id);
+ fs::path grammar_file = grammar_path / file_name;
ofstream output(grammar_file.c_str());
output << grammar;
diff --git a/extractor/sample_alignment.txt b/extractor/sample_alignment.txt
new file mode 100644
index 00000000..80b446a4
--- /dev/null
+++ b/extractor/sample_alignment.txt
@@ -0,0 +1,2 @@
+0-0 1-1 2-2
+1-0 2-1
diff --git a/extractor/sampler.cc b/extractor/sampler.cc
index d8e0f49e..5067ca8a 100644
--- a/extractor/sampler.cc
+++ b/extractor/sampler.cc
@@ -6,6 +6,10 @@
Sampler::Sampler(shared_ptr<SuffixArray> suffix_array, int max_samples) :
suffix_array(suffix_array), max_samples(max_samples) {}
+Sampler::Sampler() {}
+
+Sampler::~Sampler() {}
+
PhraseLocation Sampler::Sample(const PhraseLocation& location) const {
vector<int> sample;
int num_subpatterns;
@@ -32,5 +36,6 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location) const {
}
int Sampler::Round(double x) const {
- return x + 0.5;
+ // TODO(pauldb): Remove EPS.
+ return x + 0.5 + 1e-8;
}
diff --git a/extractor/sampler.h b/extractor/sampler.h
index 3b3e3a4d..9cf321fb 100644
--- a/extractor/sampler.h
+++ b/extractor/sampler.h
@@ -12,7 +12,12 @@ class Sampler {
public:
Sampler(shared_ptr<SuffixArray> suffix_array, int max_samples);
- PhraseLocation Sample(const PhraseLocation& location) const;
+ virtual ~Sampler();
+
+ virtual PhraseLocation Sample(const PhraseLocation& location) const;
+
+ protected:
+ Sampler();
private:
int Round(double x) const;
diff --git a/extractor/scorer.cc b/extractor/scorer.cc
index c87e179d..f28b3181 100644
--- a/extractor/scorer.cc
+++ b/extractor/scorer.cc
@@ -5,6 +5,10 @@
Scorer::Scorer(const vector<shared_ptr<Feature> >& features) :
features(features) {}
+Scorer::Scorer() {}
+
+Scorer::~Scorer() {}
+
vector<double> Scorer::Score(const FeatureContext& context) const {
vector<double> scores;
for (auto feature: features) {
diff --git a/extractor/scorer.h b/extractor/scorer.h
index 5b328fb4..ba71a6ee 100644
--- a/extractor/scorer.h
+++ b/extractor/scorer.h
@@ -14,9 +14,14 @@ class Scorer {
public:
Scorer(const vector<shared_ptr<Feature> >& features);
- vector<double> Score(const FeatureContext& context) const;
+ virtual ~Scorer();
- vector<string> GetFeatureNames() const;
+ virtual vector<double> Score(const FeatureContext& context) const;
+
+ virtual vector<string> GetFeatureNames() const;
+
+ protected:
+ Scorer();
private:
vector<shared_ptr<Feature> > features;
diff --git a/extractor/scorer_test.cc b/extractor/scorer_test.cc
new file mode 100644
index 00000000..56a85762
--- /dev/null
+++ b/extractor/scorer_test.cc
@@ -0,0 +1,47 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "mocks/mock_feature.h"
+#include "scorer.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace {
+
+class ScorerTest : public Test {
+ protected:
+ virtual void SetUp() {
+ feature1 = make_shared<MockFeature>();
+ EXPECT_CALL(*feature1, Score(_)).WillRepeatedly(Return(0.5));
+ EXPECT_CALL(*feature1, GetName()).WillRepeatedly(Return("f1"));
+
+ feature2 = make_shared<MockFeature>();
+ EXPECT_CALL(*feature2, Score(_)).WillRepeatedly(Return(-1.3));
+ EXPECT_CALL(*feature2, GetName()).WillRepeatedly(Return("f2"));
+
+ vector<shared_ptr<Feature> > features = {feature1, feature2};
+ scorer = make_shared<Scorer>(features);
+ }
+
+ shared_ptr<MockFeature> feature1;
+ shared_ptr<MockFeature> feature2;
+ shared_ptr<Scorer> scorer;
+};
+
+TEST_F(ScorerTest, TestScore) {
+ vector<double> expected_scores = {0.5, -1.3};
+ Phrase phrase;
+ FeatureContext context(phrase, phrase, 0.3, 2, 11);
+ EXPECT_EQ(expected_scores, scorer->Score(context));
+}
+
+TEST_F(ScorerTest, TestGetNames) {
+ vector<string> expected_names = {"f1", "f2"};
+ EXPECT_EQ(expected_names, scorer->GetFeatureNames());
+}
+
+} // namespace
diff --git a/extractor/suffix_array.cc b/extractor/suffix_array.cc
index d13eacd5..9815996f 100644
--- a/extractor/suffix_array.cc
+++ b/extractor/suffix_array.cc
@@ -22,9 +22,9 @@ SuffixArray::~SuffixArray() {}
void SuffixArray::BuildSuffixArray() {
vector<int> groups = data_array->GetData();
groups.reserve(groups.size() + 1);
- groups.push_back(data_array->GetVocabularySize());
+ groups.push_back(DataArray::NULL_WORD);
suffix_array.resize(groups.size());
- word_start.resize(data_array->GetVocabularySize() + 2);
+ word_start.resize(data_array->GetVocabularySize() + 1);
InitialBucketSort(groups);
@@ -112,6 +112,8 @@ void SuffixArray::TernaryQuicksort(int left, int right, int step,
}
}
+ TernaryQuicksort(left, mid_left - 1, step, groups);
+
if (mid_left == mid_right) {
groups[suffix_array[mid_left]] = mid_left;
suffix_array[mid_left] = -1;
@@ -121,7 +123,6 @@ void SuffixArray::TernaryQuicksort(int left, int right, int step,
}
}
- TernaryQuicksort(left, mid_left - 1, step, groups);
TernaryQuicksort(mid_right + 1, right, step, groups);
}
@@ -201,7 +202,7 @@ int SuffixArray::LookupRangeStart(int low, int high, int word_id,
int result = high;
while (low < high) {
int middle = low + (high - low) / 2;
- if (suffix_array[middle] + offset < data_array->GetSize() &&
+ if (suffix_array[middle] + offset >= data_array->GetSize() ||
data_array->AtIndex(suffix_array[middle] + offset) < word_id) {
low = middle + 1;
} else {
diff --git a/extractor/suffix_array_test.cc b/extractor/suffix_array_test.cc
index d891933c..60295567 100644
--- a/extractor/suffix_array_test.cc
+++ b/extractor/suffix_array_test.cc
@@ -14,10 +14,10 @@ namespace {
class SuffixArrayTest : public Test {
protected:
virtual void SetUp() {
- data = vector<int>{5, 3, 0, 1, 3, 4, 2, 3, 5, 5, 3, 0, 1};
+ data = {6, 4, 1, 2, 4, 5, 3, 4, 6, 6, 4, 1, 2};
data_array = make_shared<MockDataArray>();
EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data));
- EXPECT_CALL(*data_array, GetVocabularySize()).WillRepeatedly(Return(6));
+ EXPECT_CALL(*data_array, GetVocabularySize()).WillRepeatedly(Return(7));
EXPECT_CALL(*data_array, GetSize()).WillRepeatedly(Return(13));
suffix_array = make_shared<SuffixArray>(data_array);
}
@@ -33,14 +33,15 @@ TEST_F(SuffixArrayTest, TestData) {
}
TEST_F(SuffixArrayTest, TestBuildSuffixArray) {
- vector<int> expected_suffix_array{2, 11, 3, 12, 6, 1, 10, 4, 7, 5, 0, 9, 8};
+ vector<int> expected_suffix_array =
+ {13, 11, 2, 12, 3, 6, 10, 1, 4, 7, 5, 9, 0, 8};
for (size_t i = 0; i < expected_suffix_array.size(); ++i) {
EXPECT_EQ(expected_suffix_array[i], suffix_array->GetSuffix(i));
}
}
TEST_F(SuffixArrayTest, TestBuildLCP) {
- vector<int> expected_lcp{-1, 2, 0, 1, 0, 0, 3, 1, 1, 0, 0, 4, 1, 0};
+ vector<int> expected_lcp = {-1, 0, 2, 0, 1, 0, 0, 3, 1, 1, 0, 0, 4, 1};
EXPECT_EQ(expected_lcp, suffix_array->BuildLCPArray());
}
@@ -50,26 +51,26 @@ TEST_F(SuffixArrayTest, TestLookup) {
}
EXPECT_CALL(*data_array, HasWord("word1")).WillRepeatedly(Return(true));
- EXPECT_CALL(*data_array, GetWordId("word1")).WillRepeatedly(Return(5));
- EXPECT_EQ(PhraseLocation(10, 13), suffix_array->Lookup(0, 14, "word1", 0));
+ EXPECT_CALL(*data_array, GetWordId("word1")).WillRepeatedly(Return(6));
+ EXPECT_EQ(PhraseLocation(11, 14), suffix_array->Lookup(0, 14, "word1", 0));
EXPECT_CALL(*data_array, HasWord("word2")).WillRepeatedly(Return(false));
EXPECT_EQ(PhraseLocation(0, 0), suffix_array->Lookup(0, 14, "word2", 0));
EXPECT_CALL(*data_array, HasWord("word3")).WillRepeatedly(Return(true));
- EXPECT_CALL(*data_array, GetWordId("word3")).WillRepeatedly(Return(3));
- EXPECT_EQ(PhraseLocation(10, 12), suffix_array->Lookup(10, 13, "word3", 1));
+ EXPECT_CALL(*data_array, GetWordId("word3")).WillRepeatedly(Return(4));
+ EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 14, "word3", 1));
EXPECT_CALL(*data_array, HasWord("word4")).WillRepeatedly(Return(true));
- EXPECT_CALL(*data_array, GetWordId("word4")).WillRepeatedly(Return(0));
- EXPECT_EQ(PhraseLocation(10, 12), suffix_array->Lookup(10, 12, "word4", 2));
+ EXPECT_CALL(*data_array, GetWordId("word4")).WillRepeatedly(Return(1));
+ EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 13, "word4", 2));
EXPECT_CALL(*data_array, HasWord("word5")).WillRepeatedly(Return(true));
- EXPECT_CALL(*data_array, GetWordId("word5")).WillRepeatedly(Return(1));
- EXPECT_EQ(PhraseLocation(10, 12), suffix_array->Lookup(10, 12, "word5", 3));
+ EXPECT_CALL(*data_array, GetWordId("word5")).WillRepeatedly(Return(2));
+ EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 13, "word5", 3));
- EXPECT_EQ(PhraseLocation(10, 11), suffix_array->Lookup(10, 12, "word3", 4));
- EXPECT_EQ(PhraseLocation(10, 10), suffix_array->Lookup(10, 12, "word5", 1));
+ EXPECT_EQ(PhraseLocation(12, 13), suffix_array->Lookup(11, 13, "word3", 4));
+ EXPECT_EQ(PhraseLocation(11, 11), suffix_array->Lookup(11, 13, "word5", 1));
}
} // namespace
diff --git a/extractor/target_phrase_extractor.cc b/extractor/target_phrase_extractor.cc
new file mode 100644
index 00000000..ac583953
--- /dev/null
+++ b/extractor/target_phrase_extractor.cc
@@ -0,0 +1,144 @@
+#include "target_phrase_extractor.h"
+
+#include <unordered_set>
+
+#include "alignment.h"
+#include "data_array.h"
+#include "phrase.h"
+#include "phrase_builder.h"
+#include "rule_extractor_helper.h"
+#include "vocabulary.h"
+
+using namespace std;
+
+TargetPhraseExtractor::TargetPhraseExtractor(
+ shared_ptr<DataArray> target_data_array,
+ shared_ptr<Alignment> alignment,
+ shared_ptr<PhraseBuilder> phrase_builder,
+ shared_ptr<RuleExtractorHelper> helper,
+ shared_ptr<Vocabulary> vocabulary,
+ int max_rule_span,
+ bool require_tight_phrases) :
+ target_data_array(target_data_array),
+ alignment(alignment),
+ phrase_builder(phrase_builder),
+ helper(helper),
+ vocabulary(vocabulary),
+ max_rule_span(max_rule_span),
+ require_tight_phrases(require_tight_phrases) {}
+
+TargetPhraseExtractor::TargetPhraseExtractor() {}
+
+TargetPhraseExtractor::~TargetPhraseExtractor() {}
+
+vector<pair<Phrase, PhraseAlignment> > TargetPhraseExtractor::ExtractPhrases(
+ const vector<pair<int, int> >& target_gaps, const vector<int>& target_low,
+ int target_phrase_low, int target_phrase_high,
+ const unordered_map<int, int>& source_indexes, int sentence_id) const {
+ int target_sent_len = target_data_array->GetSentenceLength(sentence_id);
+
+ vector<int> target_gap_order = helper->GetGapOrder(target_gaps);
+
+ int target_x_low = target_phrase_low, target_x_high = target_phrase_high;
+ if (!require_tight_phrases) {
+ while (target_x_low > 0 &&
+ target_phrase_high - target_x_low < max_rule_span &&
+ target_low[target_x_low - 1] == -1) {
+ --target_x_low;
+ }
+ while (target_x_high < target_sent_len &&
+ target_x_high - target_phrase_low < max_rule_span &&
+ target_low[target_x_high] == -1) {
+ ++target_x_high;
+ }
+ }
+
+ vector<pair<int, int> > gaps(target_gaps.size());
+ for (size_t i = 0; i < gaps.size(); ++i) {
+ gaps[i] = target_gaps[target_gap_order[i]];
+ if (!require_tight_phrases) {
+ while (gaps[i].first > target_x_low &&
+ target_low[gaps[i].first - 1] == -1) {
+ --gaps[i].first;
+ }
+ while (gaps[i].second < target_x_high &&
+ target_low[gaps[i].second] == -1) {
+ ++gaps[i].second;
+ }
+ }
+ }
+
+ vector<pair<int, int> > ranges(2 * gaps.size() + 2);
+ ranges.front() = make_pair(target_x_low, target_phrase_low);
+ ranges.back() = make_pair(target_phrase_high, target_x_high);
+ for (size_t i = 0; i < gaps.size(); ++i) {
+ int j = target_gap_order[i];
+ ranges[i * 2 + 1] = make_pair(gaps[i].first, target_gaps[j].first);
+ ranges[i * 2 + 2] = make_pair(target_gaps[j].second, gaps[i].second);
+ }
+
+ vector<pair<Phrase, PhraseAlignment> > target_phrases;
+ vector<int> subpatterns(ranges.size());
+ GeneratePhrases(target_phrases, ranges, 0, subpatterns, target_gap_order,
+ target_phrase_low, target_phrase_high, source_indexes,
+ sentence_id);
+ return target_phrases;
+}
+
+void TargetPhraseExtractor::GeneratePhrases(
+ vector<pair<Phrase, PhraseAlignment> >& target_phrases,
+ const vector<pair<int, int> >& ranges, int index, vector<int>& subpatterns,
+ const vector<int>& target_gap_order, int target_phrase_low,
+ int target_phrase_high, const unordered_map<int, int>& source_indexes,
+ int sentence_id) const {
+ if (index >= ranges.size()) {
+ if (subpatterns.back() - subpatterns.front() > max_rule_span) {
+ return;
+ }
+
+ vector<int> symbols;
+ unordered_map<int, int> target_indexes;
+
+ int target_sent_start = target_data_array->GetSentenceStart(sentence_id);
+ for (size_t i = 0; i * 2 < subpatterns.size(); ++i) {
+ for (size_t j = subpatterns[i * 2]; j < subpatterns[i * 2 + 1]; ++j) {
+ target_indexes[j] = symbols.size();
+ string target_word = target_data_array->GetWordAtIndex(
+ target_sent_start + j);
+ symbols.push_back(vocabulary->GetTerminalIndex(target_word));
+ }
+ if (i < target_gap_order.size()) {
+ symbols.push_back(vocabulary->GetNonterminalIndex(
+ target_gap_order[i] + 1));
+ }
+ }
+
+ vector<pair<int, int> > links = alignment->GetLinks(sentence_id);
+ vector<pair<int, int> > alignment;
+ for (pair<int, int> link: links) {
+ if (target_indexes.count(link.second)) {
+ alignment.push_back(make_pair(source_indexes.find(link.first)->second,
+ target_indexes[link.second]));
+ }
+ }
+
+ Phrase target_phrase = phrase_builder->Build(symbols);
+ target_phrases.push_back(make_pair(target_phrase, alignment));
+ return;
+ }
+
+ subpatterns[index] = ranges[index].first;
+ if (index > 0) {
+ subpatterns[index] = max(subpatterns[index], subpatterns[index - 1]);
+ }
+ while (subpatterns[index] <= ranges[index].second) {
+ subpatterns[index + 1] = max(subpatterns[index], ranges[index + 1].first);
+ while (subpatterns[index + 1] <= ranges[index + 1].second) {
+ GeneratePhrases(target_phrases, ranges, index + 2, subpatterns,
+ target_gap_order, target_phrase_low, target_phrase_high,
+ source_indexes, sentence_id);
+ ++subpatterns[index + 1];
+ }
+ ++subpatterns[index];
+ }
+}
diff --git a/extractor/target_phrase_extractor.h b/extractor/target_phrase_extractor.h
new file mode 100644
index 00000000..134f24cc
--- /dev/null
+++ b/extractor/target_phrase_extractor.h
@@ -0,0 +1,56 @@
+#ifndef _TARGET_PHRASE_EXTRACTOR_H_
+#define _TARGET_PHRASE_EXTRACTOR_H_
+
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+using namespace std;
+
+class Alignment;
+class DataArray;
+class Phrase;
+class PhraseBuilder;
+class RuleExtractorHelper;
+class Vocabulary;
+
+typedef vector<pair<int, int> > PhraseAlignment;
+
+class TargetPhraseExtractor {
+ public:
+ TargetPhraseExtractor(shared_ptr<DataArray> target_data_array,
+ shared_ptr<Alignment> alignment,
+ shared_ptr<PhraseBuilder> phrase_builder,
+ shared_ptr<RuleExtractorHelper> helper,
+ shared_ptr<Vocabulary> vocabulary,
+ int max_rule_span,
+ bool require_tight_phrases);
+
+ virtual ~TargetPhraseExtractor();
+
+ virtual vector<pair<Phrase, PhraseAlignment> > ExtractPhrases(
+ const vector<pair<int, int> >& target_gaps, const vector<int>& target_low,
+ int target_phrase_low, int target_phrase_high,
+ const unordered_map<int, int>& source_indexes, int sentence_id) const;
+
+ protected:
+ TargetPhraseExtractor();
+
+ private:
+ void GeneratePhrases(
+ vector<pair<Phrase, PhraseAlignment> >& target_phrases,
+ const vector<pair<int, int> >& ranges, int index,
+ vector<int>& subpatterns, const vector<int>& target_gap_order,
+ int target_phrase_low, int target_phrase_high,
+ const unordered_map<int, int>& source_indexes, int sentence_id) const;
+
+ shared_ptr<DataArray> target_data_array;
+ shared_ptr<Alignment> alignment;
+ shared_ptr<PhraseBuilder> phrase_builder;
+ shared_ptr<RuleExtractorHelper> helper;
+ shared_ptr<Vocabulary> vocabulary;
+ int max_rule_span;
+ bool require_tight_phrases;
+};
+
+#endif
diff --git a/extractor/target_phrase_extractor_test.cc b/extractor/target_phrase_extractor_test.cc
new file mode 100644
index 00000000..7394f4d9
--- /dev/null
+++ b/extractor/target_phrase_extractor_test.cc
@@ -0,0 +1,116 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <vector>
+
+#include "mocks/mock_alignment.h"
+#include "mocks/mock_data_array.h"
+#include "mocks/mock_rule_extractor_helper.h"
+#include "mocks/mock_vocabulary.h"
+#include "phrase.h"
+#include "phrase_builder.h"
+#include "target_phrase_extractor.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace {
+
+class TargetPhraseExtractorTest : public Test {
+ protected:
+ virtual void SetUp() {
+ data_array = make_shared<MockDataArray>();
+ alignment = make_shared<MockAlignment>();
+ vocabulary = make_shared<MockVocabulary>();
+ phrase_builder = make_shared<PhraseBuilder>(vocabulary);
+ helper = make_shared<MockRuleExtractorHelper>();
+ }
+
+ shared_ptr<MockDataArray> data_array;
+ shared_ptr<MockAlignment> alignment;
+ shared_ptr<MockVocabulary> vocabulary;
+ shared_ptr<PhraseBuilder> phrase_builder;
+ shared_ptr<MockRuleExtractorHelper> helper;
+ shared_ptr<TargetPhraseExtractor> extractor;
+};
+
+TEST_F(TargetPhraseExtractorTest, TestExtractTightPhrasesTrue) {
+ EXPECT_CALL(*data_array, GetSentenceLength(1)).WillRepeatedly(Return(5));
+ EXPECT_CALL(*data_array, GetSentenceStart(1)).WillRepeatedly(Return(3));
+
+ vector<string> target_words = {"a", "b", "c", "d", "e"};
+ vector<int> target_symbols = {20, 21, 22, 23, 24};
+ for (size_t i = 0; i < target_words.size(); ++i) {
+ EXPECT_CALL(*data_array, GetWordAtIndex(i + 3))
+ .WillRepeatedly(Return(target_words[i]));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex(target_words[i]))
+ .WillRepeatedly(Return(target_symbols[i]));
+ EXPECT_CALL(*vocabulary, GetTerminalValue(target_symbols[i]))
+ .WillRepeatedly(Return(target_words[i]));
+ }
+
+ vector<pair<int, int> > links = {
+ make_pair(0, 0), make_pair(1, 3), make_pair(2, 2), make_pair(3, 1),
+ make_pair(4, 4)
+ };
+ EXPECT_CALL(*alignment, GetLinks(1)).WillRepeatedly(Return(links));
+
+ vector<int> gap_order = {1, 0};
+ EXPECT_CALL(*helper, GetGapOrder(_)).WillRepeatedly(Return(gap_order));
+
+ extractor = make_shared<TargetPhraseExtractor>(
+ data_array, alignment, phrase_builder, helper, vocabulary, 10, true);
+
+ vector<pair<int, int> > target_gaps = {make_pair(3, 4), make_pair(1, 2)};
+ vector<int> target_low = {0, 3, 2, 1, 4};
+ unordered_map<int, int> source_indexes = {{0, 0}, {2, 2}, {4, 4}};
+
+ vector<pair<Phrase, PhraseAlignment> > results = extractor->ExtractPhrases(
+ target_gaps, target_low, 0, 5, source_indexes, 1);
+ EXPECT_EQ(1, results.size());
+ vector<int> expected_symbols = {20, -2, 22, -1, 24};
+ EXPECT_EQ(expected_symbols, results[0].first.Get());
+ vector<string> expected_words = {"a", "c", "e"};
+ EXPECT_EQ(expected_words, results[0].first.GetWords());
+ vector<pair<int, int> > expected_alignment = {
+ make_pair(0, 0), make_pair(2, 2), make_pair(4, 4)
+ };
+ EXPECT_EQ(expected_alignment, results[0].second);
+}
+
+TEST_F(TargetPhraseExtractorTest, TestExtractPhrasesTightPhrasesFalse) {
+ vector<string> target_words = {"a", "b", "c", "d", "e", "f"};
+ vector<int> target_symbols = {20, 21, 22, 23, 24, 25, 26};
+ EXPECT_CALL(*data_array, GetSentenceLength(0)).WillRepeatedly(Return(6));
+ EXPECT_CALL(*data_array, GetSentenceStart(0)).WillRepeatedly(Return(0));
+
+ for (size_t i = 0; i < target_words.size(); ++i) {
+ EXPECT_CALL(*data_array, GetWordAtIndex(i))
+ .WillRepeatedly(Return(target_words[i]));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex(target_words[i]))
+ .WillRepeatedly(Return(target_symbols[i]));
+ EXPECT_CALL(*vocabulary, GetTerminalValue(target_symbols[i]))
+ .WillRepeatedly(Return(target_words[i]));
+ }
+
+ vector<pair<int, int> > links = {make_pair(1, 1)};
+ EXPECT_CALL(*alignment, GetLinks(0)).WillRepeatedly(Return(links));
+
+ vector<int> gap_order = {0};
+ EXPECT_CALL(*helper, GetGapOrder(_)).WillRepeatedly(Return(gap_order));
+
+ extractor = make_shared<TargetPhraseExtractor>(
+ data_array, alignment, phrase_builder, helper, vocabulary, 10, false);
+
+ vector<pair<int, int> > target_gaps = {make_pair(2, 4)};
+ vector<int> target_low = {-1, 1, -1, -1, -1, -1};
+ unordered_map<int, int> source_indexes = {{1, 1}};
+
+ vector<pair<Phrase, PhraseAlignment> > results = extractor->ExtractPhrases(
+ target_gaps, target_low, 1, 5, source_indexes, 0);
+ EXPECT_EQ(10, results.size());
+ // TODO(pauldb): Finish unit test once it's clear how these alignments should
+ // look like.
+}
+
+} // namespace
diff --git a/extractor/translation_table.cc b/extractor/translation_table.cc
index 10f1b9ed..a48c0657 100644
--- a/extractor/translation_table.cc
+++ b/extractor/translation_table.cc
@@ -9,7 +9,6 @@
#include "data_array.h"
using namespace std;
-using namespace tr1;
TranslationTable::TranslationTable(shared_ptr<DataArray> source_data_array,
shared_ptr<DataArray> target_data_array,
@@ -20,14 +19,15 @@ TranslationTable::TranslationTable(shared_ptr<DataArray> source_data_array,
unordered_map<int, int> source_links_count;
unordered_map<int, int> target_links_count;
- unordered_map<pair<int, int>, int, PairHash > links_count;
+ unordered_map<pair<int, int>, int, PairHash> links_count;
for (size_t i = 0; i < source_data_array->GetNumSentences(); ++i) {
- const vector<pair<int, int> >& links = alignment->GetLinks(i);
+ vector<pair<int, int> > links = alignment->GetLinks(i);
int source_start = source_data_array->GetSentenceStart(i);
- int next_source_start = source_data_array->GetSentenceStart(i + 1);
int target_start = target_data_array->GetSentenceStart(i);
- int next_target_start = target_data_array->GetSentenceStart(i + 1);
+ // Ignore END_OF_LINE markers.
+ int next_source_start = source_data_array->GetSentenceStart(i + 1) - 1;
+ int next_target_start = target_data_array->GetSentenceStart(i + 1) - 1;
vector<int> source_sentence(source_data.begin() + source_start,
source_data.begin() + next_source_start);
vector<int> target_sentence(target_data.begin() + target_start,
@@ -38,15 +38,23 @@ TranslationTable::TranslationTable(shared_ptr<DataArray> source_data_array,
for (pair<int, int> link: links) {
source_linked_words[link.first] = 1;
target_linked_words[link.second] = 1;
- int source_word = source_sentence[link.first];
- int target_word = target_sentence[link.second];
+ IncreaseLinksCount(source_links_count, target_links_count, links_count,
+ source_sentence[link.first], target_sentence[link.second]);
+ }
- ++source_links_count[source_word];
- ++target_links_count[target_word];
- ++links_count[make_pair(source_word, target_word)];
+ for (size_t i = 0; i < source_sentence.size(); ++i) {
+ if (!source_linked_words[i]) {
+ IncreaseLinksCount(source_links_count, target_links_count, links_count,
+ source_sentence[i], DataArray::NULL_WORD);
+ }
}
- // TODO(pauldb): Something seems wrong here. No NULL word?
+ for (size_t i = 0; i < target_sentence.size(); ++i) {
+ if (!target_linked_words[i]) {
+ IncreaseLinksCount(source_links_count, target_links_count, links_count,
+ DataArray::NULL_WORD, target_sentence[i]);
+ }
+ }
}
for (pair<pair<int, int>, int> link_count: links_count) {
@@ -58,6 +66,21 @@ TranslationTable::TranslationTable(shared_ptr<DataArray> source_data_array,
}
}
+TranslationTable::TranslationTable() {}
+
+TranslationTable::~TranslationTable() {}
+
+void TranslationTable::IncreaseLinksCount(
+ unordered_map<int, int>& source_links_count,
+ unordered_map<int, int>& target_links_count,
+ unordered_map<pair<int, int>, int, PairHash>& links_count,
+ int source_word_id,
+ int target_word_id) const {
+ ++source_links_count[source_word_id];
+ ++target_links_count[target_word_id];
+ ++links_count[make_pair(source_word_id, target_word_id)];
+}
+
double TranslationTable::GetTargetGivenSourceScore(
const string& source_word, const string& target_word) {
if (!source_data_array->HasWord(source_word) ||
@@ -73,7 +96,7 @@ double TranslationTable::GetTargetGivenSourceScore(
double TranslationTable::GetSourceGivenTargetScore(
const string& source_word, const string& target_word) {
if (!source_data_array->HasWord(source_word) ||
- !target_data_array->HasWord(target_word) == 0) {
+ !target_data_array->HasWord(target_word)) {
return -1;
}
diff --git a/extractor/translation_table.h b/extractor/translation_table.h
index acf94af7..157ad3af 100644
--- a/extractor/translation_table.h
+++ b/extractor/translation_table.h
@@ -3,13 +3,12 @@
#include <memory>
#include <string>
-#include <tr1/unordered_map>
+#include <unordered_map>
#include <boost/filesystem.hpp>
#include <boost/functional/hash.hpp>
using namespace std;
-using namespace tr1;
namespace fs = boost::filesystem;
class Alignment;
@@ -24,15 +23,27 @@ class TranslationTable {
shared_ptr<DataArray> target_data_array,
shared_ptr<Alignment> alignment);
- double GetTargetGivenSourceScore(const string& source_word,
- const string& target_word);
+ virtual ~TranslationTable();
- double GetSourceGivenTargetScore(const string& source_word,
- const string& target_word);
+ virtual double GetTargetGivenSourceScore(const string& source_word,
+ const string& target_word);
+
+ virtual double GetSourceGivenTargetScore(const string& source_word,
+ const string& target_word);
void WriteBinary(const fs::path& filepath) const;
+ protected:
+ TranslationTable();
+
private:
+ void IncreaseLinksCount(
+ unordered_map<int, int>& source_links_count,
+ unordered_map<int, int>& target_links_count,
+ unordered_map<pair<int, int>, int, PairHash>& links_count,
+ int source_word_id,
+ int target_word_id) const;
+
shared_ptr<DataArray> source_data_array;
shared_ptr<DataArray> target_data_array;
unordered_map<pair<int, int>, pair<double, double>, PairHash>
diff --git a/extractor/translation_table_test.cc b/extractor/translation_table_test.cc
new file mode 100644
index 00000000..c99f3f93
--- /dev/null
+++ b/extractor/translation_table_test.cc
@@ -0,0 +1,82 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "mocks/mock_alignment.h"
+#include "mocks/mock_data_array.h"
+#include "translation_table.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace {
+
+TEST(TranslationTableTest, TestScores) {
+ vector<string> words = {"a", "b", "c"};
+
+ vector<int> source_data = {2, 3, 2, 3, 4, 0, 2, 3, 6, 0, 2, 3, 6, 0};
+ vector<int> source_sentence_start = {0, 6, 10, 14};
+ shared_ptr<MockDataArray> source_data_array = make_shared<MockDataArray>();
+ EXPECT_CALL(*source_data_array, GetData())
+ .WillRepeatedly(ReturnRef(source_data));
+ EXPECT_CALL(*source_data_array, GetNumSentences())
+ .WillRepeatedly(Return(3));
+ for (size_t i = 0; i < source_sentence_start.size(); ++i) {
+ EXPECT_CALL(*source_data_array, GetSentenceStart(i))
+ .WillRepeatedly(Return(source_sentence_start[i]));
+ }
+ for (size_t i = 0; i < words.size(); ++i) {
+ EXPECT_CALL(*source_data_array, HasWord(words[i]))
+ .WillRepeatedly(Return(true));
+ EXPECT_CALL(*source_data_array, GetWordId(words[i]))
+ .WillRepeatedly(Return(i + 2));
+ }
+ EXPECT_CALL(*source_data_array, HasWord("d"))
+ .WillRepeatedly(Return(false));
+
+ vector<int> target_data = {2, 3, 2, 3, 4, 5, 0, 3, 6, 0, 2, 7, 0};
+ vector<int> target_sentence_start = {0, 7, 10, 13};
+ shared_ptr<MockDataArray> target_data_array = make_shared<MockDataArray>();
+ EXPECT_CALL(*target_data_array, GetData())
+ .WillRepeatedly(ReturnRef(target_data));
+ for (size_t i = 0; i < target_sentence_start.size(); ++i) {
+ EXPECT_CALL(*target_data_array, GetSentenceStart(i))
+ .WillRepeatedly(Return(target_sentence_start[i]));
+ }
+ for (size_t i = 0; i < words.size(); ++i) {
+ EXPECT_CALL(*target_data_array, HasWord(words[i]))
+ .WillRepeatedly(Return(true));
+ EXPECT_CALL(*target_data_array, GetWordId(words[i]))
+ .WillRepeatedly(Return(i + 2));
+ }
+ EXPECT_CALL(*target_data_array, HasWord("d"))
+ .WillRepeatedly(Return(false));
+
+ vector<pair<int, int> > links1 = {
+ make_pair(0, 0), make_pair(1, 1), make_pair(2, 2), make_pair(3, 3),
+ make_pair(4, 4), make_pair(4, 5)
+ };
+ vector<pair<int, int> > links2 = {make_pair(1, 0), make_pair(2, 1)};
+ vector<pair<int, int> > links3 = {make_pair(0, 0), make_pair(2, 1)};
+ shared_ptr<MockAlignment> alignment = make_shared<MockAlignment>();
+ EXPECT_CALL(*alignment, GetLinks(0)).WillRepeatedly(Return(links1));
+ EXPECT_CALL(*alignment, GetLinks(1)).WillRepeatedly(Return(links2));
+ EXPECT_CALL(*alignment, GetLinks(2)).WillRepeatedly(Return(links3));
+
+ shared_ptr<TranslationTable> table = make_shared<TranslationTable>(
+ source_data_array, target_data_array, alignment);
+
+ EXPECT_EQ(0.75, table->GetTargetGivenSourceScore("a", "a"));
+ EXPECT_EQ(0, table->GetTargetGivenSourceScore("a", "b"));
+ EXPECT_EQ(0.5, table->GetTargetGivenSourceScore("c", "c"));
+ EXPECT_EQ(-1, table->GetTargetGivenSourceScore("c", "d"));
+
+ EXPECT_EQ(1, table->GetSourceGivenTargetScore("a", "a"));
+ EXPECT_EQ(0, table->GetSourceGivenTargetScore("a", "b"));
+ EXPECT_EQ(1, table->GetSourceGivenTargetScore("c", "c"));
+ EXPECT_EQ(-1, table->GetSourceGivenTargetScore("c", "d"));
+}
+
+} // namespace
diff --git a/extractor/vocabulary.h b/extractor/vocabulary.h
index ed55e5e4..c6a8b3e8 100644
--- a/extractor/vocabulary.h
+++ b/extractor/vocabulary.h
@@ -2,11 +2,10 @@
#define _VOCABULARY_H_
#include <string>
-#include <tr1/unordered_map>
+#include <unordered_map>
#include <vector>
using namespace std;
-using namespace tr1;
class Vocabulary {
public: