summaryrefslogtreecommitdiff
path: root/extractor
diff options
context:
space:
mode:
Diffstat (limited to 'extractor')
-rw-r--r--extractor/Makefile.am23
-rw-r--r--extractor/alignment.cc2
-rw-r--r--extractor/alignment.h2
-rw-r--r--extractor/binary_search_merger.cc4
-rw-r--r--extractor/binary_search_merger.h7
-rw-r--r--extractor/binary_search_merger_test.cc4
-rw-r--r--extractor/compile.cc5
-rw-r--r--extractor/data_array.h7
-rw-r--r--extractor/features/count_source_target.cc11
-rw-r--r--extractor/features/count_source_target.h13
-rw-r--r--extractor/features/feature.cc3
-rw-r--r--extractor/features/feature.h32
-rw-r--r--extractor/features/is_source_singleton.cc11
-rw-r--r--extractor/features/is_source_singleton.h13
-rw-r--r--extractor/features/is_source_target_singleton.cc11
-rw-r--r--extractor/features/is_source_target_singleton.h13
-rw-r--r--extractor/features/max_lex_source_given_target.cc30
-rw-r--r--extractor/features/max_lex_source_given_target.h24
-rw-r--r--extractor/features/max_lex_target_given_source.cc30
-rw-r--r--extractor/features/max_lex_target_given_source.h24
-rw-r--r--extractor/features/sample_source_count.cc11
-rw-r--r--extractor/features/sample_source_count.h13
-rw-r--r--extractor/features/target_given_source_coherent.cc12
-rw-r--r--extractor/features/target_given_source_coherent.h13
-rw-r--r--extractor/grammar.cc24
-rw-r--r--extractor/grammar.h23
-rw-r--r--extractor/grammar_extractor.cc20
-rw-r--r--extractor/grammar_extractor.h15
-rw-r--r--extractor/intersector.cc50
-rw-r--r--extractor/intersector.h22
-rw-r--r--extractor/intersector_test.cc193
-rw-r--r--extractor/linear_merger.cc2
-rw-r--r--extractor/linear_merger.h3
-rw-r--r--extractor/mocks/mock_binary_search_merger.h15
-rw-r--r--extractor/mocks/mock_data_array.h1
-rw-r--r--extractor/mocks/mock_linear_merger.h10
-rw-r--r--extractor/mocks/mock_precomputation.h9
-rw-r--r--extractor/mocks/mock_suffix_array.h6
-rw-r--r--extractor/mocks/mock_vocabulary.h1
-rw-r--r--extractor/phrase.cc29
-rw-r--r--extractor/phrase.h12
-rw-r--r--extractor/phrase_builder.cc24
-rw-r--r--extractor/phrase_builder.h2
-rw-r--r--extractor/phrase_location.cc10
-rw-r--r--extractor/phrase_location.h2
-rw-r--r--extractor/precomputation.cc18
-rw-r--r--extractor/precomputation.h15
-rw-r--r--extractor/precomputation_test.cc138
-rw-r--r--extractor/rule.cc10
-rw-r--r--extractor/rule.h20
-rw-r--r--extractor/rule_extractor.cc675
-rw-r--r--extractor/rule_extractor.h120
-rw-r--r--extractor/rule_factory.cc56
-rw-r--r--extractor/rule_factory.h31
-rw-r--r--extractor/run_extractor.cc70
-rw-r--r--extractor/sampler.cc36
-rw-r--r--extractor/sampler.h24
-rw-r--r--extractor/sampler_test.cc72
-rw-r--r--extractor/scorer.cc21
-rw-r--r--extractor/scorer.h16
-rw-r--r--extractor/suffix_array.cc2
-rw-r--r--extractor/suffix_array.h9
-rw-r--r--extractor/translation_table.cc10
-rw-r--r--extractor/translation_table.h18
-rw-r--r--extractor/vocabulary.h2
65 files changed, 2005 insertions, 149 deletions
diff --git a/extractor/Makefile.am b/extractor/Makefile.am
index 844c0ef3..ded06239 100644
--- a/extractor/Makefile.am
+++ b/extractor/Makefile.am
@@ -3,22 +3,27 @@ bin_PROGRAMS = compile run_extractor
noinst_PROGRAMS = \
binary_search_merger_test \
data_array_test \
+ intersector_test \
linear_merger_test \
matching_comparator_test \
matching_test \
matchings_finder_test \
phrase_test \
precomputation_test \
+ sampler_test \
suffix_array_test \
veb_test
-TESTS = precomputation_test
+TESTS = sampler_test
#TESTS = binary_search_merger_test \
# data_array_test \
+# intersector_test \
# linear_merger_test \
# matching_comparator_test \
# matching_test \
+# matchings_finder_test \
# phrase_test \
+# precomputation_test \
# suffix_array_test \
# veb_test
@@ -26,6 +31,8 @@ binary_search_merger_test_SOURCES = binary_search_merger_test.cc
binary_search_merger_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
data_array_test_SOURCES = data_array_test.cc
data_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
+intersector_test_SOURCES = intersector_test.cc
+intersector_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
linear_merger_test_SOURCES = linear_merger_test.cc
linear_merger_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
matching_comparator_test_SOURCES = matching_comparator_test.cc
@@ -40,6 +47,8 @@ precomputation_test_SOURCES = precomputation_test.cc
precomputation_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
suffix_array_test_SOURCES = suffix_array_test.cc
suffix_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+sampler_test_SOURCES = sampler_test.cc
+sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
veb_test_SOURCES = veb_test.cc
veb_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
@@ -62,6 +71,15 @@ libextractor_a_SOURCES = \
alignment.cc \
binary_search_merger.cc \
data_array.cc \
+ features/count_source_target.cc \
+ features/feature.cc \
+ features/is_source_singleton.cc \
+ features/is_source_target_singleton.cc \
+ features/max_lex_source_given_target.cc \
+ features/max_lex_target_given_source.cc \
+ features/sample_source_count.cc \
+ features/target_given_source_coherent.cc \
+ grammar.cc \
grammar_extractor.cc \
matching.cc \
matching_comparator.cc \
@@ -73,8 +91,11 @@ libextractor_a_SOURCES = \
phrase_builder.cc \
phrase_location.cc \
precomputation.cc \
+ rule.cc \
rule_extractor.cc \
rule_factory.cc \
+ sampler.cc \
+ scorer.cc \
suffix_array.cc \
translation_table.cc \
veb.cc \
diff --git a/extractor/alignment.cc b/extractor/alignment.cc
index cad28a72..2fa0abac 100644
--- a/extractor/alignment.cc
+++ b/extractor/alignment.cc
@@ -31,7 +31,7 @@ Alignment::Alignment(const string& filename) {
alignments.shrink_to_fit();
}
-vector<pair<int, int> > Alignment::GetLinks(int sentence_index) const {
+const vector<pair<int, int> >& Alignment::GetLinks(int sentence_index) const {
return alignments[sentence_index];
}
diff --git a/extractor/alignment.h b/extractor/alignment.h
index e357e468..290d6015 100644
--- a/extractor/alignment.h
+++ b/extractor/alignment.h
@@ -13,7 +13,7 @@ class Alignment {
public:
Alignment(const string& filename);
- vector<pair<int, int> > GetLinks(int sentence_index) const;
+ const vector<pair<int, int> >& GetLinks(int sentence_index) const;
void WriteBinary(const fs::path& filepath);
diff --git a/extractor/binary_search_merger.cc b/extractor/binary_search_merger.cc
index 7b018876..43d2f734 100644
--- a/extractor/binary_search_merger.cc
+++ b/extractor/binary_search_merger.cc
@@ -19,6 +19,10 @@ BinarySearchMerger::BinarySearchMerger(
data_array(data_array), comparator(comparator),
force_binary_search_merge(force_binary_search_merge) {}
+BinarySearchMerger::BinarySearchMerger() {}
+
+BinarySearchMerger::~BinarySearchMerger() {}
+
void BinarySearchMerger::Merge(
vector<int>& locations, const Phrase& phrase, const Phrase& suffix,
vector<int>::iterator prefix_start, vector<int>::iterator prefix_end,
diff --git a/extractor/binary_search_merger.h b/extractor/binary_search_merger.h
index 0e229b3b..ffa47c8e 100644
--- a/extractor/binary_search_merger.h
+++ b/extractor/binary_search_merger.h
@@ -20,7 +20,9 @@ class BinarySearchMerger {
shared_ptr<MatchingComparator> comparator,
bool force_binary_search_merge = false);
- void Merge(
+ virtual ~BinarySearchMerger();
+
+ virtual void Merge(
vector<int>& locations, const Phrase& phrase, const Phrase& suffix,
vector<int>::iterator prefix_start, vector<int>::iterator prefix_end,
vector<int>::iterator suffix_start, vector<int>::iterator suffix_end,
@@ -28,6 +30,9 @@ class BinarySearchMerger {
static double BAEZA_YATES_FACTOR;
+ protected:
+ BinarySearchMerger();
+
private:
bool IsIntersectionVoid(
vector<int>::iterator prefix_start, vector<int>::iterator prefix_end,
diff --git a/extractor/binary_search_merger_test.cc b/extractor/binary_search_merger_test.cc
index 20350b1e..b1baa62f 100644
--- a/extractor/binary_search_merger_test.cc
+++ b/extractor/binary_search_merger_test.cc
@@ -34,8 +34,8 @@ class BinarySearchMergerTest : public Test {
// We are going to force the binary_search_merger to do all the work, so we
// need to check that the linear_merger never gets called.
- shared_ptr<MockLinearMerger> linear_merger = make_shared<MockLinearMerger>(
- vocabulary, data_array, comparator);
+ shared_ptr<MockLinearMerger> linear_merger =
+ make_shared<MockLinearMerger>();
EXPECT_CALL(*linear_merger, Merge(_, _, _, _, _, _, _, _, _)).Times(0);
binary_search_merger = make_shared<BinarySearchMerger>(
diff --git a/extractor/compile.cc b/extractor/compile.cc
index c3ea3c8d..f5cd41f4 100644
--- a/extractor/compile.cc
+++ b/extractor/compile.cc
@@ -77,8 +77,9 @@ int main(int argc, char** argv) {
source_suffix_array->WriteBinary(output_dir / fs::path("f.bin"));
target_data_array->WriteBinary(output_dir / fs::path("e.bin"));
- Alignment alignment(vm["alignment"].as<string>());
- alignment.WriteBinary(output_dir / fs::path("a.bin"));
+ shared_ptr<Alignment> alignment =
+ make_shared<Alignment>(vm["alignment"].as<string>());
+ alignment->WriteBinary(output_dir / fs::path("a.bin"));
Precomputation precomputation(
source_suffix_array,
diff --git a/extractor/data_array.h b/extractor/data_array.h
index 6d3e99d5..19fbff88 100644
--- a/extractor/data_array.h
+++ b/extractor/data_array.h
@@ -23,8 +23,6 @@ class DataArray {
static string END_OF_FILE_STR;
static string END_OF_LINE_STR;
- DataArray();
-
DataArray(const string& filename);
DataArray(const string& filename, const Side& side);
@@ -43,7 +41,7 @@ class DataArray {
virtual int GetWordId(const string& word) const;
- string GetWord(int word_id) const;
+ virtual string GetWord(int word_id) const;
int GetNumSentences() const;
@@ -55,6 +53,9 @@ class DataArray {
void WriteBinary(FILE* file) const;
+ protected:
+ DataArray();
+
private:
void InitializeDataArray();
void CreateDataArray(const vector<string>& lines);
diff --git a/extractor/features/count_source_target.cc b/extractor/features/count_source_target.cc
new file mode 100644
index 00000000..9441b451
--- /dev/null
+++ b/extractor/features/count_source_target.cc
@@ -0,0 +1,11 @@
+#include "count_source_target.h"
+
+#include <cmath>
+
+double CountSourceTarget::Score(const FeatureContext& context) const {
+ return log10(1 + context.pair_count);
+}
+
+string CountSourceTarget::GetName() const {
+ return "CountEF";
+}
diff --git a/extractor/features/count_source_target.h b/extractor/features/count_source_target.h
new file mode 100644
index 00000000..a2481944
--- /dev/null
+++ b/extractor/features/count_source_target.h
@@ -0,0 +1,13 @@
+#ifndef _COUNT_SOURCE_TARGET_H_
+#define _COUNT_SOURCE_TARGET_H_
+
+#include "feature.h"
+
+class CountSourceTarget : public Feature {
+ public:
+ double Score(const FeatureContext& context) const;
+
+ string GetName() const;
+};
+
+#endif
diff --git a/extractor/features/feature.cc b/extractor/features/feature.cc
new file mode 100644
index 00000000..7381c35a
--- /dev/null
+++ b/extractor/features/feature.cc
@@ -0,0 +1,3 @@
+#include "feature.h"
+
+const double Feature::MAX_SCORE = 99.0;
diff --git a/extractor/features/feature.h b/extractor/features/feature.h
new file mode 100644
index 00000000..ad22d3e7
--- /dev/null
+++ b/extractor/features/feature.h
@@ -0,0 +1,32 @@
+#ifndef _FEATURE_H_
+#define _FEATURE_H_
+
+#include <string>
+
+//TODO(pauldb): include headers nicely.
+#include "../phrase.h"
+
+using namespace std;
+
+struct FeatureContext {
+ FeatureContext(const Phrase& source_phrase, const Phrase& target_phrase,
+ double sample_source_count, int pair_count) :
+ source_phrase(source_phrase), target_phrase(target_phrase),
+ sample_source_count(sample_source_count), pair_count(pair_count) {}
+
+ Phrase source_phrase;
+ Phrase target_phrase;
+ double sample_source_count;
+ int pair_count;
+};
+
+class Feature {
+ public:
+ virtual double Score(const FeatureContext& context) const = 0;
+
+ virtual string GetName() const = 0;
+
+ static const double MAX_SCORE;
+};
+
+#endif
diff --git a/extractor/features/is_source_singleton.cc b/extractor/features/is_source_singleton.cc
new file mode 100644
index 00000000..754df3bf
--- /dev/null
+++ b/extractor/features/is_source_singleton.cc
@@ -0,0 +1,11 @@
+#include "is_source_singleton.h"
+
+#include <cmath>
+
+double IsSourceSingleton::Score(const FeatureContext& context) const {
+ return context.sample_source_count == 1;
+}
+
+string IsSourceSingleton::GetName() const {
+ return "IsSingletonF";
+}
diff --git a/extractor/features/is_source_singleton.h b/extractor/features/is_source_singleton.h
new file mode 100644
index 00000000..7cc72828
--- /dev/null
+++ b/extractor/features/is_source_singleton.h
@@ -0,0 +1,13 @@
+#ifndef _IS_SOURCE_SINGLETON_H_
+#define _IS_SOURCE_SINGLETON_H_
+
+#include "feature.h"
+
+class IsSourceSingleton : public Feature {
+ public:
+ double Score(const FeatureContext& context) const;
+
+ string GetName() const;
+};
+
+#endif
diff --git a/extractor/features/is_source_target_singleton.cc b/extractor/features/is_source_target_singleton.cc
new file mode 100644
index 00000000..ec816509
--- /dev/null
+++ b/extractor/features/is_source_target_singleton.cc
@@ -0,0 +1,11 @@
+#include "is_source_target_singleton.h"
+
+#include <cmath>
+
+double IsSourceTargetSingleton::Score(const FeatureContext& context) const {
+ return context.pair_count == 1;
+}
+
+string IsSourceTargetSingleton::GetName() const {
+ return "IsSingletonEF";
+}
diff --git a/extractor/features/is_source_target_singleton.h b/extractor/features/is_source_target_singleton.h
new file mode 100644
index 00000000..58913b74
--- /dev/null
+++ b/extractor/features/is_source_target_singleton.h
@@ -0,0 +1,13 @@
+#ifndef _IS_SOURCE_TARGET_SINGLETON_H_
+#define _IS_SOURCE_TARGET_SINGLETON_H_
+
+#include "feature.h"
+
+class IsSourceTargetSingleton : public Feature {
+ public:
+ double Score(const FeatureContext& context) const;
+
+ string GetName() const;
+};
+
+#endif
diff --git a/extractor/features/max_lex_source_given_target.cc b/extractor/features/max_lex_source_given_target.cc
new file mode 100644
index 00000000..c4792d49
--- /dev/null
+++ b/extractor/features/max_lex_source_given_target.cc
@@ -0,0 +1,30 @@
+#include "max_lex_source_given_target.h"
+
+#include <cmath>
+
+#include "../translation_table.h"
+
+MaxLexSourceGivenTarget::MaxLexSourceGivenTarget(
+ shared_ptr<TranslationTable> table) :
+ table(table) {}
+
+double MaxLexSourceGivenTarget::Score(const FeatureContext& context) const {
+ vector<string> source_words = context.source_phrase.GetWords();
+ // TODO(pauldb): Add NULL to target_words, after fixing translation table.
+ vector<string> target_words = context.target_phrase.GetWords();
+
+ double score = 0;
+ for (string source_word: source_words) {
+ double max_score = 0;
+ for (string target_word: target_words) {
+ max_score = max(max_score,
+ table->GetSourceGivenTargetScore(source_word, target_word));
+ }
+ score += max_score > 0 ? -log10(max_score) : MAX_SCORE;
+ }
+ return score;
+}
+
+string MaxLexSourceGivenTarget::GetName() const {
+ return "MaxLexFGivenE";
+}
diff --git a/extractor/features/max_lex_source_given_target.h b/extractor/features/max_lex_source_given_target.h
new file mode 100644
index 00000000..e87c1c8e
--- /dev/null
+++ b/extractor/features/max_lex_source_given_target.h
@@ -0,0 +1,24 @@
+#ifndef _MAX_LEX_SOURCE_GIVEN_TARGET_H_
+#define _MAX_LEX_SOURCE_GIVEN_TARGET_H_
+
+#include <memory>
+
+#include "feature.h"
+
+using namespace std;
+
+class TranslationTable;
+
+class MaxLexSourceGivenTarget : public Feature {
+ public:
+ MaxLexSourceGivenTarget(shared_ptr<TranslationTable> table);
+
+ double Score(const FeatureContext& context) const;
+
+ string GetName() const;
+
+ private:
+ shared_ptr<TranslationTable> table;
+};
+
+#endif
diff --git a/extractor/features/max_lex_target_given_source.cc b/extractor/features/max_lex_target_given_source.cc
new file mode 100644
index 00000000..d82182fe
--- /dev/null
+++ b/extractor/features/max_lex_target_given_source.cc
@@ -0,0 +1,30 @@
+#include "max_lex_target_given_source.h"
+
+#include <cmath>
+
+#include "../translation_table.h"
+
+MaxLexTargetGivenSource::MaxLexTargetGivenSource(
+ shared_ptr<TranslationTable> table) :
+ table(table) {}
+
+double MaxLexTargetGivenSource::Score(const FeatureContext& context) const {
+ // TODO(pauldb): Add NULL to source_words, after fixing translation table.
+ vector<string> source_words = context.source_phrase.GetWords();
+ vector<string> target_words = context.target_phrase.GetWords();
+
+ double score = 0;
+ for (string target_word: target_words) {
+ double max_score = 0;
+ for (string source_word: source_words) {
+ max_score = max(max_score,
+ table->GetTargetGivenSourceScore(source_word, target_word));
+ }
+ score += max_score > 0 ? -log10(max_score) : MAX_SCORE;
+ }
+ return score;
+}
+
+string MaxLexTargetGivenSource::GetName() const {
+ return "MaxLexEGivenF";
+}
diff --git a/extractor/features/max_lex_target_given_source.h b/extractor/features/max_lex_target_given_source.h
new file mode 100644
index 00000000..9585ff04
--- /dev/null
+++ b/extractor/features/max_lex_target_given_source.h
@@ -0,0 +1,24 @@
+#ifndef _MAX_LEX_TARGET_GIVEN_SOURCE_H_
+#define _MAX_LEX_TARGET_GIVEN_SOURCE_H_
+
+#include <memory>
+
+#include "feature.h"
+
+using namespace std;
+
+class TranslationTable;
+
+class MaxLexTargetGivenSource : public Feature {
+ public:
+ MaxLexTargetGivenSource(shared_ptr<TranslationTable> table);
+
+ double Score(const FeatureContext& context) const;
+
+ string GetName() const;
+
+ private:
+ shared_ptr<TranslationTable> table;
+};
+
+#endif
diff --git a/extractor/features/sample_source_count.cc b/extractor/features/sample_source_count.cc
new file mode 100644
index 00000000..c8124cfb
--- /dev/null
+++ b/extractor/features/sample_source_count.cc
@@ -0,0 +1,11 @@
+#include "sample_source_count.h"
+
+#include <cmath>
+
+double SampleSourceCount::Score(const FeatureContext& context) const {
+ return log10(1 + context.sample_source_count);
+}
+
+string SampleSourceCount::GetName() const {
+ return "SampleCountF";
+}
diff --git a/extractor/features/sample_source_count.h b/extractor/features/sample_source_count.h
new file mode 100644
index 00000000..62d236c8
--- /dev/null
+++ b/extractor/features/sample_source_count.h
@@ -0,0 +1,13 @@
+#ifndef _SAMPLE_SOURCE_COUNT_H_
+#define _SAMPLE_SOURCE_COUNT_H_
+
+#include "feature.h"
+
+class SampleSourceCount : public Feature {
+ public:
+ double Score(const FeatureContext& context) const;
+
+ string GetName() const;
+};
+
+#endif
diff --git a/extractor/features/target_given_source_coherent.cc b/extractor/features/target_given_source_coherent.cc
new file mode 100644
index 00000000..748413c3
--- /dev/null
+++ b/extractor/features/target_given_source_coherent.cc
@@ -0,0 +1,12 @@
+#include "target_given_source_coherent.h"
+
+#include <cmath>
+
+double TargetGivenSourceCoherent::Score(const FeatureContext& context) const {
+ double prob = context.pair_count / context.sample_source_count;
+ return prob > 0 ? -log10(prob) : MAX_SCORE;
+}
+
+string TargetGivenSourceCoherent::GetName() const {
+ return "EGivenFCoherent";
+}
diff --git a/extractor/features/target_given_source_coherent.h b/extractor/features/target_given_source_coherent.h
new file mode 100644
index 00000000..09c8edb1
--- /dev/null
+++ b/extractor/features/target_given_source_coherent.h
@@ -0,0 +1,13 @@
+#ifndef _TARGET_GIVEN_SOURCE_COHERENT_H_
+#define _TARGET_GIVEN_SOURCE_COHERENT_H_
+
+#include "feature.h"
+
+class TargetGivenSourceCoherent : public Feature {
+ public:
+ double Score(const FeatureContext& context) const;
+
+ string GetName() const;
+};
+
+#endif
diff --git a/extractor/grammar.cc b/extractor/grammar.cc
new file mode 100644
index 00000000..79a0541d
--- /dev/null
+++ b/extractor/grammar.cc
@@ -0,0 +1,24 @@
+#include "grammar.h"
+
+#include "rule.h"
+
+Grammar::Grammar(const vector<Rule>& rules,
+ const vector<string>& feature_names) :
+ rules(rules), feature_names(feature_names) {}
+
+ostream& operator<<(ostream& os, const Grammar& grammar) {
+ for (Rule rule: grammar.rules) {
+ os << "[X] ||| " << rule.source_phrase << " ||| "
+ << rule.target_phrase << " |||";
+ for (size_t i = 0; i < rule.scores.size(); ++i) {
+ os << " " << grammar.feature_names[i] << "=" << rule.scores[i];
+ }
+ os << " |||";
+ for (auto link: rule.alignment) {
+ os << " " << link.first << "-" << link.second;
+ }
+ os << endl;
+ }
+
+ return os;
+}
diff --git a/extractor/grammar.h b/extractor/grammar.h
new file mode 100644
index 00000000..db15fa7e
--- /dev/null
+++ b/extractor/grammar.h
@@ -0,0 +1,23 @@
+#ifndef _GRAMMAR_H_
+#define _GRAMMAR_H_
+
+#include <iostream>
+#include <string>
+#include <vector>
+
+using namespace std;
+
+class Rule;
+
+class Grammar {
+ public:
+ Grammar(const vector<Rule>& rules, const vector<string>& feature_names);
+
+ friend ostream& operator<<(ostream& os, const Grammar& grammar);
+
+ private:
+ vector<Rule> rules;
+ vector<string> feature_names;
+};
+
+#endif
diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc
index 3014c2e9..15268165 100644
--- a/extractor/grammar_extractor.cc
+++ b/extractor/grammar_extractor.cc
@@ -4,6 +4,10 @@
#include <sstream>
#include <vector>
+#include "grammar.h"
+#include "rule.h"
+#include "vocabulary.h"
+
using namespace std;
vector<string> Tokenize(const string& sentence) {
@@ -22,18 +26,20 @@ vector<string> Tokenize(const string& sentence) {
GrammarExtractor::GrammarExtractor(
shared_ptr<SuffixArray> source_suffix_array,
shared_ptr<DataArray> target_data_array,
- const Alignment& alignment, const Precomputation& precomputation,
- int min_gap_size, int max_rule_span, int max_nonterminals,
- int max_rule_symbols, bool use_baeza_yates) :
+ shared_ptr<Alignment> alignment, shared_ptr<Precomputation> precomputation,
+ shared_ptr<Scorer> scorer, int min_gap_size, int max_rule_span,
+ int max_nonterminals, int max_rule_symbols, int max_samples,
+ bool use_baeza_yates, bool require_tight_phrases) :
vocabulary(make_shared<Vocabulary>()),
rule_factory(source_suffix_array, target_data_array, alignment,
- vocabulary, precomputation, min_gap_size, max_rule_span,
- max_nonterminals, max_rule_symbols, use_baeza_yates) {}
+ vocabulary, precomputation, scorer, min_gap_size, max_rule_span,
+ max_nonterminals, max_rule_symbols, max_samples, use_baeza_yates,
+ require_tight_phrases) {}
-void GrammarExtractor::GetGrammar(const string& sentence) {
+Grammar GrammarExtractor::GetGrammar(const string& sentence) {
vector<string> words = Tokenize(sentence);
vector<int> word_ids = AnnotateWords(words);
- rule_factory.GetGrammar(word_ids);
+ return rule_factory.GetGrammar(word_ids);
}
vector<int> GrammarExtractor::AnnotateWords(const vector<string>& words) {
diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h
index 05e153fc..243f33cf 100644
--- a/extractor/grammar_extractor.h
+++ b/extractor/grammar_extractor.h
@@ -5,29 +5,34 @@
#include <vector>
#include "rule_factory.h"
-#include "vocabulary.h"
using namespace std;
class Alignment;
class DataArray;
+class Grammar;
class Precomputation;
+class Rule;
class SuffixArray;
+class Vocabulary;
class GrammarExtractor {
public:
GrammarExtractor(
shared_ptr<SuffixArray> source_suffix_array,
shared_ptr<DataArray> target_data_array,
- const Alignment& alignment,
- const Precomputation& precomputation,
+ shared_ptr<Alignment> alignment,
+ shared_ptr<Precomputation> precomputation,
+ shared_ptr<Scorer> scorer,
int min_gap_size,
int max_rule_span,
int max_nonterminals,
int max_rule_symbols,
- bool use_baeza_yates);
+ int max_samples,
+ bool use_baeza_yates,
+ bool require_tight_phrases);
- void GetGrammar(const string& sentence);
+ Grammar GetGrammar(const string& sentence);
private:
vector<int> AnnotateWords(const vector<string>& words);
diff --git a/extractor/intersector.cc b/extractor/intersector.cc
index 9d9b54c0..b53479af 100644
--- a/extractor/intersector.cc
+++ b/extractor/intersector.cc
@@ -10,35 +10,51 @@
#include "vocabulary.h"
Intersector::Intersector(shared_ptr<Vocabulary> vocabulary,
- const Precomputation& precomputation,
+ shared_ptr<Precomputation> precomputation,
shared_ptr<SuffixArray> suffix_array,
shared_ptr<MatchingComparator> comparator,
bool use_baeza_yates) :
vocabulary(vocabulary),
suffix_array(suffix_array),
use_baeza_yates(use_baeza_yates) {
- linear_merger = make_shared<LinearMerger>(
- vocabulary, suffix_array->GetData(), comparator);
+ shared_ptr<DataArray> data_array = suffix_array->GetData();
+ linear_merger = make_shared<LinearMerger>(vocabulary, data_array, comparator);
binary_search_merger = make_shared<BinarySearchMerger>(
- vocabulary, linear_merger, suffix_array->GetData(), comparator);
+ vocabulary, linear_merger, data_array, comparator);
+ ConvertIndexes(precomputation, data_array);
+}
- shared_ptr<DataArray> source_data_array = suffix_array->GetData();
+Intersector::Intersector(shared_ptr<Vocabulary> vocabulary,
+ shared_ptr<Precomputation> precomputation,
+ shared_ptr<SuffixArray> suffix_array,
+ shared_ptr<LinearMerger> linear_merger,
+ shared_ptr<BinarySearchMerger> binary_search_merger,
+ bool use_baeza_yates) :
+ vocabulary(vocabulary),
+ suffix_array(suffix_array),
+ linear_merger(linear_merger),
+ binary_search_merger(binary_search_merger),
+ use_baeza_yates(use_baeza_yates) {
+ ConvertIndexes(precomputation, suffix_array->GetData());
+}
- const Index& precomputed_index = precomputation.GetInvertedIndex();
+void Intersector::ConvertIndexes(shared_ptr<Precomputation> precomputation,
+ shared_ptr<DataArray> data_array) {
+ const Index& precomputed_index = precomputation->GetInvertedIndex();
for (pair<vector<int>, vector<int> > entry: precomputed_index) {
- vector<int> phrase = Convert(entry.first, source_data_array);
+ vector<int> phrase = ConvertPhrase(entry.first, data_array);
inverted_index[phrase] = entry.second;
}
- const Index& precomputed_collocations = precomputation.GetCollocations();
+ const Index& precomputed_collocations = precomputation->GetCollocations();
for (pair<vector<int>, vector<int> > entry: precomputed_collocations) {
- vector<int> phrase = Convert(entry.first, source_data_array);
+ vector<int> phrase = ConvertPhrase(entry.first, data_array);
collocations[phrase] = entry.second;
}
}
-vector<int> Intersector::Convert(
- const vector<int>& old_phrase, shared_ptr<DataArray> source_data_array) {
+vector<int> Intersector::ConvertPhrase(const vector<int>& old_phrase,
+ shared_ptr<DataArray> data_array) {
vector<int> new_phrase;
new_phrase.reserve(old_phrase.size());
@@ -49,7 +65,7 @@ vector<int> Intersector::Convert(
new_phrase.push_back(vocabulary->GetNonterminalIndex(arity));
} else {
new_phrase.push_back(
- vocabulary->GetTerminalIndex(source_data_array->GetWord(word_id)));
+ vocabulary->GetTerminalIndex(data_array->GetWord(word_id)));
}
}
@@ -70,8 +86,7 @@ PhraseLocation Intersector::Intersect(
&& vocabulary->IsTerminal(symbols.back()));
if (collocations.count(symbols)) {
- return PhraseLocation(make_shared<vector<int> >(collocations[symbols]),
- phrase.Arity());
+ return PhraseLocation(collocations[symbols], phrase.Arity() + 1);
}
vector<int> locations;
@@ -91,19 +106,18 @@ PhraseLocation Intersector::Intersect(
prefix_matchings->end(), suffix_matchings->begin(),
suffix_matchings->end(), prefix_subpatterns, suffix_subpatterns);
}
- return PhraseLocation(shared_ptr<vector<int> >(new vector<int>(locations)),
- phrase.Arity() + 1);
+ return PhraseLocation(locations, phrase.Arity() + 1);
}
void Intersector::ExtendPhraseLocation(
const Phrase& phrase, PhraseLocation& phrase_location) {
int low = phrase_location.sa_low, high = phrase_location.sa_high;
- if (phrase.Arity() || phrase_location.num_subpatterns ||
- phrase_location.IsEmpty()) {
+ if (phrase_location.matchings != NULL) {
return;
}
phrase_location.num_subpatterns = 1;
+ phrase_location.sa_low = phrase_location.sa_high = 0;
vector<int> symbols = phrase.Get();
if (inverted_index.count(symbols)) {
diff --git a/extractor/intersector.h b/extractor/intersector.h
index 874ffc1b..f023cc96 100644
--- a/extractor/intersector.h
+++ b/extractor/intersector.h
@@ -13,8 +13,8 @@
using namespace std;
using namespace tr1;
-typedef boost::hash<vector<int> > vector_hash;
-typedef unordered_map<vector<int>, vector<int>, vector_hash> Index;
+typedef boost::hash<vector<int> > VectorHash;
+typedef unordered_map<vector<int>, vector<int>, VectorHash> Index;
class DataArray;
class MatchingComparator;
@@ -28,19 +28,31 @@ class Intersector {
public:
Intersector(
shared_ptr<Vocabulary> vocabulary,
- const Precomputation& precomputaiton,
+ shared_ptr<Precomputation> precomputation,
shared_ptr<SuffixArray> source_suffix_array,
shared_ptr<MatchingComparator> comparator,
bool use_baeza_yates);
+ // For testing.
+ Intersector(
+ shared_ptr<Vocabulary> vocabulary,
+ shared_ptr<Precomputation> precomputation,
+ shared_ptr<SuffixArray> source_suffix_array,
+ shared_ptr<LinearMerger> linear_merger,
+ shared_ptr<BinarySearchMerger> binary_search_merger,
+ bool use_baeza_yates);
+
PhraseLocation Intersect(
const Phrase& prefix, PhraseLocation& prefix_location,
const Phrase& suffix, PhraseLocation& suffix_location,
const Phrase& phrase);
private:
- vector<int> Convert(const vector<int>& old_phrase,
- shared_ptr<DataArray> source_data_array);
+ void ConvertIndexes(shared_ptr<Precomputation> precomputation,
+ shared_ptr<DataArray> data_array);
+
+ vector<int> ConvertPhrase(const vector<int>& old_phrase,
+ shared_ptr<DataArray> data_array);
void ExtendPhraseLocation(const Phrase& phrase,
PhraseLocation& phrase_location);
diff --git a/extractor/intersector_test.cc b/extractor/intersector_test.cc
new file mode 100644
index 00000000..a3756902
--- /dev/null
+++ b/extractor/intersector_test.cc
@@ -0,0 +1,193 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <vector>
+
+#include "intersector.h"
+#include "mocks/mock_binary_search_merger.h"
+#include "mocks/mock_data_array.h"
+#include "mocks/mock_linear_merger.h"
+#include "mocks/mock_precomputation.h"
+#include "mocks/mock_suffix_array.h"
+#include "mocks/mock_vocabulary.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace {
+
+// Test fixture for Intersector. SetUp() wires mock vocabulary, data array,
+// suffix array and precomputation objects around a tiny six-symbol corpus, and
+// creates mock mergers so each test can assert which merger gets invoked.
+class IntersectorTest : public Test {
+ protected:
+ virtual void SetUp() {
+ // Symbol ids of the toy corpus and the word each position maps to.
+ data = {2, 3, 4, 3, 4, 3};
+ vector<string> words = {"a", "b", "c", "b", "c", "b"};
+ data_array = make_shared<MockDataArray>();
+ EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data));
+
+ vocabulary = make_shared<MockVocabulary>();
+ // Make the data array and the vocabulary agree on the id <-> word mapping.
+ for (size_t i = 0; i < data.size(); ++i) {
+ EXPECT_CALL(*data_array, GetWord(data[i]))
+ .WillRepeatedly(Return(words[i]));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex(words[i]))
+ .WillRepeatedly(Return(data[i]));
+ EXPECT_CALL(*vocabulary, GetTerminalValue(data[i]))
+ .WillRepeatedly(Return(words[i]));
+ }
+
+ // Suffix positions served by the mock suffix array, one per GetSuffix(i).
+ vector<int> suffixes = {0, 1, 3, 5, 2, 4, 6};
+ suffix_array = make_shared<MockSuffixArray>();
+ EXPECT_CALL(*suffix_array, GetData())
+ .WillRepeatedly(Return(data_array));
+ EXPECT_CALL(*suffix_array, GetSize())
+ .WillRepeatedly(Return(suffixes.size()));
+ for (size_t i = 0; i < suffixes.size(); ++i) {
+ EXPECT_CALL(*suffix_array, GetSuffix(i))
+ .WillRepeatedly(Return(suffixes[i]));
+ }
+
+ // A single precomputed collocation (2 X 4) with locations {0, 2}; the
+ // inverted index is left empty.
+ vector<int> key = {2, -1, 4};
+ vector<int> values = {0, 2};
+ collocations[key] = values;
+ precomputation = make_shared<MockPrecomputation>();
+ EXPECT_CALL(*precomputation, GetInvertedIndex())
+ .WillRepeatedly(ReturnRef(inverted_index));
+ EXPECT_CALL(*precomputation, GetCollocations())
+ .WillRepeatedly(ReturnRef(collocations));
+
+ linear_merger = make_shared<MockLinearMerger>();
+ binary_search_merger = make_shared<MockBinarySearchMerger>();
+
+ phrase_builder = make_shared<PhraseBuilder>(vocabulary);
+ }
+
+ // The indexes and data are members because the mocks return them by
+ // reference (ReturnRef) and they must outlive SetUp().
+ Index inverted_index;
+ Index collocations;
+ vector<int> data;
+ shared_ptr<MockVocabulary> vocabulary;
+ shared_ptr<MockDataArray> data_array;
+ shared_ptr<MockSuffixArray> suffix_array;
+ shared_ptr<MockPrecomputation> precomputation;
+ shared_ptr<MockLinearMerger> linear_merger;
+ shared_ptr<MockBinarySearchMerger> binary_search_merger;
+ shared_ptr<PhraseBuilder> phrase_builder;
+ // Constructed by each test, after its merger expectations are in place.
+ shared_ptr<Intersector> intersector;
+};
+
+// Intersecting "2 X" with "X 4" yields the phrase "2 X 4", which SetUp()
+// registered in the mocked collocations index. The result is expected to be
+// the stored locations {0, 2} with two subpatterns, and the prefix and suffix
+// locations are expected to be left unchanged.
+TEST_F(IntersectorTest, TestCachedCollocation) {
+ intersector = make_shared<Intersector>(vocabulary, precomputation,
+ suffix_array, linear_merger, binary_search_merger, false);
+
+ vector<int> prefix_symbols = {2, -1};
+ Phrase prefix = phrase_builder->Build(prefix_symbols);
+ vector<int> suffix_symbols = {-1, 4};
+ Phrase suffix = phrase_builder->Build(suffix_symbols);
+ vector<int> symbols = {2, -1, 4};
+ Phrase phrase = phrase_builder->Build(symbols);
+ PhraseLocation prefix_locs(0, 1), suffix_locs(2, 3);
+
+ PhraseLocation result = intersector->Intersect(
+ prefix, prefix_locs, suffix, suffix_locs, phrase);
+
+ vector<int> expected_locs = {0, 2};
+ PhraseLocation expected_result(expected_locs, 2);
+
+ EXPECT_EQ(expected_result, result);
+ // The suffix-array ranges passed in must come back untouched.
+ EXPECT_EQ(PhraseLocation(0, 1), prefix_locs);
+ EXPECT_EQ(PhraseLocation(2, 3), suffix_locs);
+}
+
+// With use_baeza_yates == false, intersecting "3 X" with "X 4" must call the
+// linear merger exactly once and never the binary search merger. The prefix
+// and suffix locations, given as suffix-array ranges, are expected to come
+// back converted to explicit matchings.
+TEST_F(IntersectorTest, TestLinearMergeaXb) {
+ vector<int> prefix_symbols = {3, -1};
+ Phrase prefix = phrase_builder->Build(prefix_symbols);
+ vector<int> suffix_symbols = {-1, 4};
+ Phrase suffix = phrase_builder->Build(suffix_symbols);
+ vector<int> symbols = {3, -1, 4};
+ Phrase phrase = phrase_builder->Build(symbols);
+ PhraseLocation prefix_locs(1, 4), suffix_locs(4, 6);
+
+ // Expected contents of prefix_locs / suffix_locs after Intersect() returns.
+ vector<int> ex_prefix_locs = {1, 3, 5};
+ PhraseLocation extended_prefix_locs(ex_prefix_locs, 1);
+ vector<int> ex_suffix_locs = {2, 4};
+ PhraseLocation extended_suffix_locs(ex_suffix_locs, 1);
+
+ // The mocked merger writes expected_locs into its first (output) argument.
+ vector<int> expected_locs = {1, 4};
+ EXPECT_CALL(*linear_merger, Merge(_, _, _, _, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(SetArgReferee<0>(expected_locs));
+ EXPECT_CALL(*binary_search_merger, Merge(_, _, _, _, _, _, _, _, _)).Times(0);
+
+ intersector = make_shared<Intersector>(vocabulary, precomputation,
+ suffix_array, linear_merger, binary_search_merger, false);
+
+ PhraseLocation result = intersector->Intersect(
+ prefix, prefix_locs, suffix, suffix_locs, phrase);
+ PhraseLocation expected_result(expected_locs, 2);
+
+ EXPECT_EQ(expected_result, result);
+ EXPECT_EQ(extended_prefix_locs, prefix_locs);
+ EXPECT_EQ(extended_suffix_locs, suffix_locs);
+}
+
+// Same scenario as the linear-merge test, but with use_baeza_yates == true:
+// the binary search merger must be called exactly once and the linear merger
+// must not be called at all.
+TEST_F(IntersectorTest, TestBinarySearchMergeaXb) {
+ vector<int> prefix_symbols = {3, -1};
+ Phrase prefix = phrase_builder->Build(prefix_symbols);
+ vector<int> suffix_symbols = {-1, 4};
+ Phrase suffix = phrase_builder->Build(suffix_symbols);
+ vector<int> symbols = {3, -1, 4};
+ Phrase phrase = phrase_builder->Build(symbols);
+ PhraseLocation prefix_locs(1, 4), suffix_locs(4, 6);
+
+ // Expected contents of prefix_locs / suffix_locs after Intersect() returns.
+ vector<int> ex_prefix_locs = {1, 3, 5};
+ PhraseLocation extended_prefix_locs(ex_prefix_locs, 1);
+ vector<int> ex_suffix_locs = {2, 4};
+ PhraseLocation extended_suffix_locs(ex_suffix_locs, 1);
+
+ // The mocked merger writes expected_locs into its first (output) argument.
+ vector<int> expected_locs = {1, 4};
+ EXPECT_CALL(*binary_search_merger, Merge(_, _, _, _, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(SetArgReferee<0>(expected_locs));
+ EXPECT_CALL(*linear_merger, Merge(_, _, _, _, _, _, _, _, _)).Times(0);
+
+ intersector = make_shared<Intersector>(vocabulary, precomputation,
+ suffix_array, linear_merger, binary_search_merger, true);
+
+ PhraseLocation result = intersector->Intersect(
+ prefix, prefix_locs, suffix, suffix_locs, phrase);
+ PhraseLocation expected_result(expected_locs, 2);
+
+ EXPECT_EQ(expected_result, result);
+ EXPECT_EQ(extended_prefix_locs, prefix_locs);
+ EXPECT_EQ(extended_suffix_locs, suffix_locs);
+}
+
+// Intersects two phrases whose locations are already explicit matchings with
+// two subpatterns each. The linear merger must be called once, the result must
+// carry three subpatterns, and the input matchings must be left unchanged.
+TEST_F(IntersectorTest, TestMergeaXbXc) {
+ vector<int> prefix_symbols = {2, -1, 4, -1};
+ Phrase prefix = phrase_builder->Build(prefix_symbols);
+ vector<int> suffix_symbols = {-1, 4, -1, 4};
+ Phrase suffix = phrase_builder->Build(suffix_symbols);
+ vector<int> symbols = {2, -1, 4, -1, 4};
+ Phrase phrase = phrase_builder->Build(symbols);
+
+ vector<int> ex_prefix_locs = {0, 2, 0, 4};
+ PhraseLocation extended_prefix_locs(ex_prefix_locs, 2);
+ vector<int> ex_suffix_locs = {2, 4};
+ PhraseLocation extended_suffix_locs(ex_suffix_locs, 2);
+ // The mocked merger writes expected_locs into its first (output) argument.
+ vector<int> expected_locs = {0, 2, 4};
+ EXPECT_CALL(*linear_merger, Merge(_, _, _, _, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(SetArgReferee<0>(expected_locs));
+ EXPECT_CALL(*binary_search_merger, Merge(_, _, _, _, _, _, _, _, _)).Times(0);
+
+ intersector = make_shared<Intersector>(vocabulary, precomputation,
+ suffix_array, linear_merger, binary_search_merger, false);
+
+ PhraseLocation result = intersector->Intersect(
+ prefix, extended_prefix_locs, suffix, extended_suffix_locs, phrase);
+ PhraseLocation expected_result(expected_locs, 3);
+
+ EXPECT_EQ(expected_result, result);
+ EXPECT_EQ(ex_prefix_locs, *extended_prefix_locs.matchings);
+ EXPECT_EQ(ex_suffix_locs, *extended_suffix_locs.matchings);
+}
+
+} // namespace
diff --git a/extractor/linear_merger.cc b/extractor/linear_merger.cc
index 59e5f34c..666f8d87 100644
--- a/extractor/linear_merger.cc
+++ b/extractor/linear_merger.cc
@@ -14,6 +14,8 @@ LinearMerger::LinearMerger(shared_ptr<Vocabulary> vocabulary,
shared_ptr<MatchingComparator> comparator) :
vocabulary(vocabulary), data_array(data_array), comparator(comparator) {}
+LinearMerger::LinearMerger() {}
+
LinearMerger::~LinearMerger() {}
void LinearMerger::Merge(
diff --git a/extractor/linear_merger.h b/extractor/linear_merger.h
index 7bfb9246..6a69b804 100644
--- a/extractor/linear_merger.h
+++ b/extractor/linear_merger.h
@@ -26,6 +26,9 @@ class LinearMerger {
vector<int>::iterator suffix_start, vector<int>::iterator suffix_end,
int prefix_subpatterns, int suffix_subpatterns) const;
+ protected:
+ LinearMerger();
+
private:
shared_ptr<Vocabulary> vocabulary;
shared_ptr<DataArray> data_array;
diff --git a/extractor/mocks/mock_binary_search_merger.h b/extractor/mocks/mock_binary_search_merger.h
new file mode 100644
index 00000000..e1375ee3
--- /dev/null
+++ b/extractor/mocks/mock_binary_search_merger.h
@@ -0,0 +1,15 @@
+#include <gmock/gmock.h>
+
+#include <vector>
+
+#include "../binary_search_merger.h"
+#include "../phrase.h"
+
+using namespace std;
+
+// gMock test double for BinarySearchMerger: replaces the const, nine-argument
+// Merge() method with a mock so tests can set expectations on it.
+class MockBinarySearchMerger: public BinarySearchMerger {
+ public:
+ MOCK_CONST_METHOD9(Merge, void(vector<int>&, const Phrase&, const Phrase&,
+ vector<int>::iterator, vector<int>::iterator, vector<int>::iterator,
+ vector<int>::iterator, int, int));
+};
diff --git a/extractor/mocks/mock_data_array.h b/extractor/mocks/mock_data_array.h
index cda8f7a6..54497cf5 100644
--- a/extractor/mocks/mock_data_array.h
+++ b/extractor/mocks/mock_data_array.h
@@ -10,5 +10,6 @@ class MockDataArray : public DataArray {
MOCK_CONST_METHOD0(GetVocabularySize, int());
MOCK_CONST_METHOD1(HasWord, bool(const string& word));
MOCK_CONST_METHOD1(GetWordId, int(const string& word));
+ MOCK_CONST_METHOD1(GetWord, string(int word_id));
MOCK_CONST_METHOD1(GetSentenceId, int(int position));
};
diff --git a/extractor/mocks/mock_linear_merger.h b/extractor/mocks/mock_linear_merger.h
index 0defa88a..82243428 100644
--- a/extractor/mocks/mock_linear_merger.h
+++ b/extractor/mocks/mock_linear_merger.h
@@ -2,19 +2,13 @@
#include <vector>
-#include "linear_merger.h"
-#include "phrase.h"
+#include "../linear_merger.h"
+#include "../phrase.h"
using namespace std;
class MockLinearMerger: public LinearMerger {
public:
- MockLinearMerger(shared_ptr<Vocabulary> vocabulary,
- shared_ptr<DataArray> data_array,
- shared_ptr<MatchingComparator> comparator) :
- LinearMerger(vocabulary, data_array, comparator) {}
-
-
MOCK_CONST_METHOD9(Merge, void(vector<int>&, const Phrase&, const Phrase&,
vector<int>::iterator, vector<int>::iterator, vector<int>::iterator,
vector<int>::iterator, int, int));
diff --git a/extractor/mocks/mock_precomputation.h b/extractor/mocks/mock_precomputation.h
new file mode 100644
index 00000000..987bdb2f
--- /dev/null
+++ b/extractor/mocks/mock_precomputation.h
@@ -0,0 +1,9 @@
+#include <gmock/gmock.h>
+
+#include "../precomputation.h"
+
+// gMock test double for Precomputation: mocks the two const index accessors,
+// both of which return a const reference to an Index.
+class MockPrecomputation : public Precomputation {
+ public:
+ MOCK_CONST_METHOD0(GetInvertedIndex, const Index&());
+ MOCK_CONST_METHOD0(GetCollocations, const Index&());
+};
diff --git a/extractor/mocks/mock_suffix_array.h b/extractor/mocks/mock_suffix_array.h
index 38d8bad6..11a3a443 100644
--- a/extractor/mocks/mock_suffix_array.h
+++ b/extractor/mocks/mock_suffix_array.h
@@ -1,5 +1,6 @@
#include <gmock/gmock.h>
+#include <memory>
#include <string>
#include "../data_array.h"
@@ -10,8 +11,9 @@ using namespace std;
class MockSuffixArray : public SuffixArray {
public:
- MockSuffixArray() : SuffixArray(make_shared<DataArray>()) {}
-
MOCK_CONST_METHOD0(GetSize, int());
+ MOCK_CONST_METHOD0(GetData, shared_ptr<DataArray>());
+ MOCK_CONST_METHOD0(BuildLCPArray, vector<int>());
+ MOCK_CONST_METHOD1(GetSuffix, int(int));
MOCK_CONST_METHOD4(Lookup, PhraseLocation(int, int, const string& word, int));
};
diff --git a/extractor/mocks/mock_vocabulary.h b/extractor/mocks/mock_vocabulary.h
index 06dea10f..e5c191f5 100644
--- a/extractor/mocks/mock_vocabulary.h
+++ b/extractor/mocks/mock_vocabulary.h
@@ -5,4 +5,5 @@
class MockVocabulary : public Vocabulary {
public:
MOCK_METHOD1(GetTerminalValue, string(int word_id));
+ MOCK_METHOD1(GetTerminalIndex, int(const string& word));
};
diff --git a/extractor/phrase.cc b/extractor/phrase.cc
index f9bd9908..6dc242db 100644
--- a/extractor/phrase.cc
+++ b/extractor/phrase.cc
@@ -23,3 +23,32 @@ vector<int> Phrase::Get() const {
int Phrase::GetSymbol(int position) const {
return symbols[position];
}
+
+// Returns the number of entries in the phrase's symbol vector.
+int Phrase::GetNumSymbols() const {
+ return symbols.size();
+}
+
+// Returns a copy of the phrase's word vector.
+vector<string> Phrase::GetWords() const {
+ return words;
+}
+
+// Orders phrases by lexicographic comparison of their symbol vectors.
+// NOTE(review): the return type is int although the comparison yields a bool;
+// switching to bool would also require changing the declaration in phrase.h.
+int Phrase::operator<(const Phrase& other) const {
+ return symbols < other.symbols;
+}
+
+// Writes the phrase to `os` as space-separated tokens. A negative symbol is
+// printed as "[X,k]" where k is its absolute value; every other symbol is
+// printed as the next unused entry of the phrase's word list.
+ostream& operator<<(ostream& os, const Phrase& phrase) {
+  const vector<int>& syms = phrase.symbols;
+  size_t next_word = 0;
+  for (size_t pos = 0; pos < syms.size(); ++pos) {
+    // Separator goes between tokens only, never before the first one.
+    if (pos > 0) {
+      os << " ";
+    }
+
+    const int symbol = syms[pos];
+    if (symbol >= 0) {
+      os << phrase.words[next_word];
+      ++next_word;
+      continue;
+    }
+    os << "[X," << -symbol << "]";
+  }
+  return os;
+}
diff --git a/extractor/phrase.h b/extractor/phrase.h
index 5a5124d9..f40a8169 100644
--- a/extractor/phrase.h
+++ b/extractor/phrase.h
@@ -1,6 +1,7 @@
#ifndef _PHRASE_H_
#define _PHRASE_H_
+#include <iostream>
#include <string>
#include <vector>
@@ -20,6 +21,17 @@ class Phrase {
int GetSymbol(int position) const;
+ //TODO(pauldb): Unit test this method.
+ int GetNumSymbols() const;
+
+ //TODO(pauldb): Add unit tests.
+ vector<string> GetWords() const;
+
+ //TODO(pauldb): Add unit tests.
+ int operator<(const Phrase& other) const;
+
+ friend ostream& operator<<(ostream& os, const Phrase& phrase);
+
private:
vector<int> symbols;
vector<int> var_pos;
diff --git a/extractor/phrase_builder.cc b/extractor/phrase_builder.cc
index 7f3447e5..c4e0c2ed 100644
--- a/extractor/phrase_builder.cc
+++ b/extractor/phrase_builder.cc
@@ -19,3 +19,27 @@ Phrase PhraseBuilder::Build(const vector<int>& symbols) {
}
return phrase;
}
+
+// Returns a copy of `phrase` optionally extended with a nonterminal at the
+// start (start_x) and/or at the end (end_x). All nonterminals in the result
+// are renumbered consecutively from left to right, starting at 1.
+Phrase PhraseBuilder::Extend(const Phrase& phrase, bool start_x, bool end_x) {
+  vector<int> symbols = phrase.Get();
+  int num_nonterminals = 0;
+  if (start_x) {
+    num_nonterminals = 1;
+    symbols.insert(symbols.begin(),
+        vocabulary->GetNonterminalIndex(num_nonterminals));
+  }
+
+  // Start after the nonterminal inserted above (if any); it is already
+  // numbered.
+  for (size_t i = start_x; i < symbols.size(); ++i) {
+    // Only nonterminals are renumbered. The terminals must be preserved,
+    // otherwise every word of the phrase would be replaced by a gap.
+    if (!vocabulary->IsTerminal(symbols[i])) {
+      ++num_nonterminals;
+      symbols[i] = vocabulary->GetNonterminalIndex(num_nonterminals);
+    }
+  }
+
+  if (end_x) {
+    ++num_nonterminals;
+    symbols.push_back(vocabulary->GetNonterminalIndex(num_nonterminals));
+  }
+
+  return Build(symbols);
+}
diff --git a/extractor/phrase_builder.h b/extractor/phrase_builder.h
index f01cb23b..a49af457 100644
--- a/extractor/phrase_builder.h
+++ b/extractor/phrase_builder.h
@@ -15,6 +15,8 @@ class PhraseBuilder {
Phrase Build(const vector<int>& symbols);
+ Phrase Extend(const Phrase& phrase, bool start_x, bool end_x);
+
private:
shared_ptr<Vocabulary> vocabulary;
};
diff --git a/extractor/phrase_location.cc b/extractor/phrase_location.cc
index b5b68549..984407c5 100644
--- a/extractor/phrase_location.cc
+++ b/extractor/phrase_location.cc
@@ -1,16 +1,12 @@
#include "phrase_location.h"
-#include <cstdio>
-
PhraseLocation::PhraseLocation(int sa_low, int sa_high) :
- sa_low(sa_low), sa_high(sa_high),
- matchings(shared_ptr<vector<int> >()),
- num_subpatterns(0) {}
+ sa_low(sa_low), sa_high(sa_high), num_subpatterns(0) {}
-PhraseLocation::PhraseLocation(shared_ptr<vector<int> > matchings,
+PhraseLocation::PhraseLocation(const vector<int>& matchings,
int num_subpatterns) :
sa_high(0), sa_low(0),
- matchings(matchings),
+ matchings(make_shared<vector<int> >(matchings)),
num_subpatterns(num_subpatterns) {}
bool PhraseLocation::IsEmpty() {
diff --git a/extractor/phrase_location.h b/extractor/phrase_location.h
index 96004b33..e04d8628 100644
--- a/extractor/phrase_location.h
+++ b/extractor/phrase_location.h
@@ -9,7 +9,7 @@ using namespace std;
struct PhraseLocation {
PhraseLocation(int sa_low = -1, int sa_high = -1);
- PhraseLocation(shared_ptr<vector<int> > matchings, int num_subpatterns);
+ PhraseLocation(const vector<int>& matchings, int num_subpatterns);
bool IsEmpty();
diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc
index 97a70554..9a167976 100644
--- a/extractor/precomputation.cc
+++ b/extractor/precomputation.cc
@@ -2,11 +2,6 @@
#include <iostream>
#include <queue>
-#include <tr1/unordered_set>
-#include <tuple>
-#include <vector>
-
-#include <boost/functional/hash.hpp>
#include "data_array.h"
#include "suffix_array.h"
@@ -26,9 +21,8 @@ Precomputation::Precomputation(
suffix_array, data, num_frequent_patterns, max_frequent_phrase_len,
min_frequency);
- unordered_set<vector<int>, boost::hash<vector<int> > > frequent_patterns_set;
- unordered_set<vector<int>, boost::hash<vector<int> > >
- super_frequent_patterns_set;
+ unordered_set<vector<int>, VectorHash> frequent_patterns_set;
+ unordered_set<vector<int>, VectorHash> super_frequent_patterns_set;
for (size_t i = 0; i < frequent_patterns.size(); ++i) {
frequent_patterns_set.insert(frequent_patterns[i]);
if (i < num_super_frequent_patterns) {
@@ -60,6 +54,10 @@ Precomputation::Precomputation(
}
}
+Precomputation::Precomputation() {}
+
+Precomputation::~Precomputation() {}
+
vector<vector<int> > Precomputation::FindMostFrequentPatterns(
shared_ptr<SuffixArray> suffix_array, const vector<int>& data,
int num_frequent_patterns, int max_frequent_phrase_len, int min_frequency) {
@@ -107,7 +105,7 @@ void Precomputation::AddCollocations(
}
if (start2 - start1 - size1 >= min_gap_size
- && start2 + size2 - size1 <= max_rule_span
+ && start2 + size2 - start1 <= max_rule_span
&& size1 + size2 + 1 <= max_rule_symbols) {
vector<int> pattern(data.begin() + start1,
data.begin() + start1 + size1);
@@ -126,7 +124,7 @@ void Precomputation::AddCollocations(
}
if (start3 - start2 - size2 >= min_gap_size
- && start3 + size3 - size1 <= max_rule_span
+ && start3 + size3 - start1 <= max_rule_span
&& size1 + size2 + size3 + 2 <= max_rule_symbols
&& (is_super1 || is_super3)) {
pattern.insert(pattern.end(), data.begin() + start3,
diff --git a/extractor/precomputation.h b/extractor/precomputation.h
index 0d1b269f..428505d8 100644
--- a/extractor/precomputation.h
+++ b/extractor/precomputation.h
@@ -16,8 +16,8 @@ using namespace tr1;
class SuffixArray;
-typedef boost::hash<vector<int> > vector_hash;
-typedef unordered_map<vector<int>, vector<int>, vector_hash> Index;
+typedef boost::hash<vector<int> > VectorHash;
+typedef unordered_map<vector<int>, vector<int>, VectorHash> Index;
class Precomputation {
public:
@@ -27,20 +27,25 @@ class Precomputation {
int max_rule_symbols, int min_gap_size,
int max_frequent_phrase_len, int min_frequency);
+ virtual ~Precomputation();
+
void WriteBinary(const fs::path& filepath) const;
- const Index& GetInvertedIndex() const;
- const Index& GetCollocations() const;
+ virtual const Index& GetInvertedIndex() const;
+ virtual const Index& GetCollocations() const;
static int NON_TERMINAL;
+ protected:
+ Precomputation();
+
private:
vector<vector<int> > FindMostFrequentPatterns(
shared_ptr<SuffixArray> suffix_array, const vector<int>& data,
int num_frequent_patterns, int max_frequent_phrase_len,
int min_frequency);
void AddCollocations(
- const vector<tuple<int, int, int> >& matchings, const vector<int>& data,
+ const vector<std::tuple<int, int, int> >& matchings, const vector<int>& data,
int max_rule_span, int min_gap_size, int max_rule_symbols);
void AddStartPositions(vector<int>& positions, int pos1, int pos2);
void AddStartPositions(vector<int>& positions, int pos1, int pos2, int pos3);
diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc
new file mode 100644
index 00000000..9edb29db
--- /dev/null
+++ b/extractor/precomputation_test.cc
@@ -0,0 +1,138 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <vector>
+
+#include "mocks/mock_data_array.h"
+#include "mocks/mock_suffix_array.h"
+#include "precomputation.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace {
+
+// Test fixture for Precomputation. SetUp() backs a mock data array with a
+// 13-symbol corpus and a mock suffix array with hand-written suffix and LCP
+// arrays.
+class PrecomputationTest : public Test {
+ protected:
+ virtual void SetUp() {
+ data = {4, 2, 3, 5, 7, 2, 3, 5, 2, 3, 4, 2, 1};
+ data_array = make_shared<MockDataArray>();
+ EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data));
+
+ // Values served by the mocked GetSuffix(i) and BuildLCPArray() calls.
+ vector<int> suffixes{12, 8, 5, 1, 9, 6, 2, 0, 10, 7, 3, 4, 13};
+ vector<int> lcp{-1, 0, 2, 3, 1, 0, 1, 2, 0, 2, 0, 1, 0, 0};
+ suffix_array = make_shared<MockSuffixArray>();
+ EXPECT_CALL(*suffix_array, GetData()).WillRepeatedly(Return(data_array));
+ for (size_t i = 0; i < suffixes.size(); ++i) {
+ EXPECT_CALL(*suffix_array,
+ GetSuffix(i)).WillRepeatedly(Return(suffixes[i]));
+ }
+ EXPECT_CALL(*suffix_array, BuildLCPArray()).WillRepeatedly(Return(lcp));
+ }
+
+ // `data` is a member because the mock returns it by reference (ReturnRef).
+ vector<int> data;
+ shared_ptr<MockDataArray> data_array;
+ shared_ptr<MockSuffixArray> suffix_array;
+};
+
+// Checks that the inverted index built over the fixture corpus holds exactly
+// eight patterns, each mapped to the expected list of start positions, and
+// that a pattern not expected to be indexed ({2, 4}) is absent.
+TEST_F(PrecomputationTest, TestInvertedIndex) {
+ Precomputation precomputation(suffix_array, 100, 3, 10, 5, 1, 4, 2);
+ // Copied (not bound by reference) so the non-const operator[] can be used.
+ Index inverted_index = precomputation.GetInvertedIndex();
+
+ EXPECT_EQ(8, inverted_index.size());
+ vector<int> key = {2};
+ vector<int> expected_value = {1, 5, 8, 11};
+ EXPECT_EQ(expected_value, inverted_index[key]);
+ key = {3};
+ expected_value = {2, 6, 9};
+ EXPECT_EQ(expected_value, inverted_index[key]);
+ key = {4};
+ expected_value = {0, 10};
+ EXPECT_EQ(expected_value, inverted_index[key]);
+ key = {5};
+ expected_value = {3, 7};
+ EXPECT_EQ(expected_value, inverted_index[key]);
+ key = {4, 2};
+ expected_value = {0, 10};
+ EXPECT_EQ(expected_value, inverted_index[key]);
+ key = {2, 3};
+ expected_value = {1, 5, 8};
+ EXPECT_EQ(expected_value, inverted_index[key]);
+ key = {3, 5};
+ expected_value = {2, 6};
+ EXPECT_EQ(expected_value, inverted_index[key]);
+ key = {2, 3, 5};
+ expected_value = {1, 5};
+ EXPECT_EQ(expected_value, inverted_index[key]);
+
+ key = {2, 4};
+ EXPECT_EQ(0, inverted_index.count(key));
+}
+
+// Checks the collocations index built over the fixture corpus. Keys are
+// patterns whose gaps are marked with -1 (the value asserted for
+// Precomputation::NON_TERMINAL below); each value is a flat list holding one
+// start position per contiguous chunk of the pattern, for every occurrence.
+TEST_F(PrecomputationTest, TestCollocations) {
+ Precomputation precomputation(suffix_array, 3, 3, 10, 5, 1, 4, 2);
+ // Copied (not bound by reference) so the non-const operator[] can be used.
+ Index collocations = precomputation.GetCollocations();
+
+ EXPECT_EQ(-1, precomputation.NON_TERMINAL);
+ // Patterns with a single gap: values hold pairs of start positions.
+ vector<int> key = {2, 3, -1, 2};
+ vector<int> expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {2, 3, -1, 2, 3};
+ expected_value = {1, 5, 1, 8, 5, 8};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {2, 3, -1, 3};
+ expected_value = {1, 6, 1, 9, 5, 9};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {3, -1, 2};
+ expected_value = {2, 5, 2, 8, 2, 11, 6, 8, 6, 11, 9, 11};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {3, -1, 3};
+ expected_value = {2, 6, 2, 9, 6, 9};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {3, -1, 2, 3};
+ expected_value = {2, 5, 2, 8, 6, 8};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {2, -1, 2};
+ expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {2, -1, 2, 3};
+ expected_value = {1, 5, 1, 8, 5, 8};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {2, -1, 3};
+ expected_value = {1, 6, 1, 9, 5, 9};
+ EXPECT_EQ(expected_value, collocations[key]);
+
+ // Patterns with two gaps: values hold triples of start positions.
+ key = {2, -1, 2, -1, 2};
+ expected_value = {1, 5, 8, 5, 8, 11};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {2, -1, 2, -1, 3};
+ expected_value = {1, 5, 9};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {2, -1, 3, -1, 2};
+ expected_value = {1, 6, 8, 5, 9, 11};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {2, -1, 3, -1, 3};
+ expected_value = {1, 6, 9};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {3, -1, 2, -1, 2};
+ expected_value = {2, 5, 8, 2, 5, 11, 2, 8, 11, 6, 8, 11};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {3, -1, 2, -1, 3};
+ expected_value = {2, 5, 9};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {3, -1, 3, -1, 2};
+ expected_value = {2, 6, 8, 2, 6, 11, 2, 9, 11, 6, 9, 11};
+ EXPECT_EQ(expected_value, collocations[key]);
+ key = {3, -1, 3, -1, 3};
+ expected_value = {2, 6, 9};
+ EXPECT_EQ(expected_value, collocations[key]);
+
+ // Exceeds max_rule_symbols.
+ key = {2, -1, 2, -1, 2, 3};
+ EXPECT_EQ(0, collocations.count(key));
+ // Contains non frequent pattern.
+ key = {2, -1, 5};
+ EXPECT_EQ(0, collocations.count(key));
+}
+
+} // namespace
diff --git a/extractor/rule.cc b/extractor/rule.cc
new file mode 100644
index 00000000..9c7ac9b5
--- /dev/null
+++ b/extractor/rule.cc
@@ -0,0 +1,10 @@
+#include "rule.h"
+
+// Constructs a rule by copying the given source phrase, target phrase,
+// scores and alignment into the corresponding members.
+Rule::Rule(const Phrase& source_phrase,
+ const Phrase& target_phrase,
+ const vector<double>& scores,
+ const vector<pair<int, int> >& alignment) :
+ source_phrase(source_phrase),
+ target_phrase(target_phrase),
+ scores(scores),
+ alignment(alignment) {}
diff --git a/extractor/rule.h b/extractor/rule.h
new file mode 100644
index 00000000..64ff8794
--- /dev/null
+++ b/extractor/rule.h
@@ -0,0 +1,20 @@
+#ifndef _RULE_H_
+#define _RULE_H_
+
+#include <vector>
+
+#include "phrase.h"
+
+using namespace std;
+
+// Plain aggregate describing one extracted rule: a source phrase, a target
+// phrase, a list of scores and an alignment stored as a list of int pairs.
+struct Rule {
+ Rule(const Phrase& source_phrase, const Phrase& target_phrase,
+ const vector<double>& scores, const vector<pair<int, int> >& alignment);
+
+ Phrase source_phrase;
+ Phrase target_phrase;
+ // presumably one score per feature, in the scorer's feature order — TODO
+ // confirm against Scorer::Score.
+ vector<double> scores;
+ // presumably (source index, target index) links — TODO confirm.
+ vector<pair<int, int> > alignment;
+};
+
+#endif
diff --git a/extractor/rule_extractor.cc b/extractor/rule_extractor.cc
index 48b39b63..9460020f 100644
--- a/extractor/rule_extractor.cc
+++ b/extractor/rule_extractor.cc
@@ -1,10 +1,679 @@
#include "rule_extractor.h"
+#include <map>
+#include <tr1/unordered_set>
+
+#include "alignment.h"
+#include "data_array.h"
+#include "features/feature.h"
+#include "phrase_builder.h"
+#include "phrase_location.h"
+#include "rule.h"
+#include "scorer.h"
+#include "vocabulary.h"
+
+using namespace std;
+using namespace tr1;
+
RuleExtractor::RuleExtractor(
- shared_ptr<SuffixArray> source_suffix_array,
+ shared_ptr<DataArray> source_data_array,
shared_ptr<DataArray> target_data_array,
- const Alignment& alingment) {
+ shared_ptr<Alignment> alignment,
+ shared_ptr<PhraseBuilder> phrase_builder,
+ shared_ptr<Scorer> scorer,
+ shared_ptr<Vocabulary> vocabulary,
+ int max_rule_span,
+ int min_gap_size,
+ int max_nonterminals,
+ int max_rule_symbols,
+ bool require_aligned_terminal,
+ bool require_aligned_chunks,
+ bool require_tight_phrases) :
+ source_data_array(source_data_array),
+ target_data_array(target_data_array),
+ alignment(alignment),
+ phrase_builder(phrase_builder),
+ scorer(scorer),
+ vocabulary(vocabulary),
+ max_rule_span(max_rule_span),
+ min_gap_size(min_gap_size),
+ max_nonterminals(max_nonterminals),
+ max_rule_symbols(max_rule_symbols),
+ require_aligned_terminal(require_aligned_terminal),
+ require_aligned_chunks(require_aligned_chunks),
+ require_tight_phrases(require_tight_phrases) {}
+
+// Extracts the rules for `phrase` from the occurrences listed in `location`.
+//
+// `location.matchings` is read as consecutive groups of
+// `location.num_subpatterns` positions, one group per occurrence. For every
+// occurrence the aligned extracts are collected; then, for each
+// (source phrase, target phrase) pair seen, one Rule is emitted carrying the
+// scores computed by the scorer and the alignment that was observed most
+// often for that pair.
+//
+// Returns an empty vector when `location` has no explicit matchings or a
+// non-positive number of subpatterns. A trailing incomplete group of
+// positions is ignored.
+vector<Rule> RuleExtractor::ExtractRules(const Phrase& phrase,
+                                         const PhraseLocation& location) const {
+  vector<Rule> rules;
+  const int num_subpatterns = location.num_subpatterns;
+  // Guard the dereference below and keep the stride strictly positive so the
+  // loop over the matchings always terminates.
+  if (!location.matchings || num_subpatterns <= 0) {
+    return rules;
+  }
+  // Bind by reference: the matchings list can be large and is only read.
+  const vector<int>& matchings = *location.matchings;
+  const size_t stride = num_subpatterns;
+
+  map<Phrase, double> source_phrase_counter;
+  map<Phrase, map<Phrase, map<PhraseAlignment, int> > > alignments_counter;
+  // Index-based stepping never moves past the end, even if the number of
+  // positions is not a multiple of the stride.
+  for (size_t start = 0; start + stride <= matchings.size(); start += stride) {
+    vector<int> matching(matchings.begin() + start,
+                         matchings.begin() + start + stride);
+    const vector<Extract> extracts = ExtractAlignments(phrase, matching);
+
+    for (const Extract& e : extracts) {
+      source_phrase_counter[e.source_phrase] += e.pairs_count;
+      alignments_counter[e.source_phrase][e.target_phrase][e.alignment] += 1;
+    }
+  }
+
+  // Iterate by const reference to avoid deep-copying the nested maps.
+  for (const auto& source_phrase_entry : alignments_counter) {
+    const Phrase& source_phrase = source_phrase_entry.first;
+    for (const auto& target_phrase_entry : source_phrase_entry.second) {
+      const Phrase& target_phrase = target_phrase_entry.first;
+
+      // Total occurrences of the pair, and its most frequent alignment.
+      int max_locations = 0, num_locations = 0;
+      PhraseAlignment most_frequent_alignment;
+      for (const auto& alignment_entry : target_phrase_entry.second) {
+        num_locations += alignment_entry.second;
+        if (alignment_entry.second > max_locations) {
+          most_frequent_alignment = alignment_entry.first;
+          max_locations = alignment_entry.second;
+        }
+      }
+
+      FeatureContext context(source_phrase, target_phrase,
+          source_phrase_counter[source_phrase], num_locations);
+      vector<double> scores = scorer->Score(context);
+      rules.push_back(Rule(source_phrase, target_phrase, scores,
+                           most_frequent_alignment));
+    }
+  }
+  return rules;
+}
+
+// Collects the extracts for a single occurrence of `phrase`.
+//
+// `matching` holds one corpus position per contiguous chunk of the phrase.
+// An empty vector is returned as soon as one of the checks fails: the aligned
+// terminal / tight phrase checks, the fix point search, or the gap check.
+// Otherwise extracts are added for the phrase itself (when the gap check
+// reports that the constraints were met) and, if the limits on nonterminals,
+// symbols and span still allow it, for the phrase extended with extra
+// nonterminals at its extremities.
+vector<Extract> RuleExtractor::ExtractAlignments(
+ const Phrase& phrase, const vector<int>& matching) const {
+ vector<Extract> extracts;
+ int sentence_id = source_data_array->GetSentenceId(matching[0]);
+ int source_sent_start = source_data_array->GetSentenceStart(sentence_id);
+
+ // Per-word alignment spans for this sentence pair, filled by GetLinksSpans.
+ vector<int> source_low, source_high, target_low, target_high;
+ GetLinksSpans(source_low, source_high, target_low, target_high, sentence_id);
+
+ int num_subpatterns = matching.size();
+ vector<int> chunklen(num_subpatterns);
+ for (size_t i = 0; i < num_subpatterns; ++i) {
+ chunklen[i] = phrase.GetChunkLen(i);
+ }
+
+ if (!CheckAlignedTerminals(matching, chunklen, source_low) ||
+ !CheckTightPhrases(matching, chunklen, source_low)) {
+ return extracts;
+ }
+
+ // Sentence-relative span covered by the occurrence: from the start of the
+ // first chunk to the end of the last chunk.
+ int source_back_low = -1, source_back_high = -1;
+ int source_phrase_low = matching[0] - source_sent_start;
+ int source_phrase_high = matching.back() + chunklen.back() - source_sent_start;
+ int target_phrase_low = -1, target_phrase_high = -1;
+ // NOTE(review): matching.size() is unsigned, so the max_new_x argument below
+ // is computed in unsigned arithmetic before being converted back to int —
+ // confirm this is intended when matching.size() > max_nonterminals + 1.
+ if (!FindFixPoint(source_phrase_low, source_phrase_high, source_low,
+ source_high, target_phrase_low, target_phrase_high,
+ target_low, target_high, source_back_low, source_back_high,
+ sentence_id, min_gap_size, 0,
+ max_nonterminals - matching.size() + 1, 1, 1, false)) {
+ return extracts;
+ }
+
+ bool met_constraints = true;
+ int num_symbols = phrase.GetNumSymbols();
+ vector<pair<int, int> > source_gaps, target_gaps;
+ if (!CheckGaps(source_gaps, target_gaps, matching, chunklen, source_low,
+ source_high, target_low, target_high, source_phrase_low,
+ source_phrase_high, source_back_low, source_back_high,
+ num_symbols, met_constraints)) {
+ return extracts;
+ }
+
+ // A nonterminal is added on a side when the back-projected source span
+ // differs from the original phrase span on that side.
+ bool start_x = source_back_low != source_phrase_low;
+ bool end_x = source_back_high != source_phrase_high;
+ Phrase source_phrase = phrase_builder->Extend(phrase, start_x, end_x);
+ if (met_constraints) {
+ AddExtracts(extracts, source_phrase, target_gaps, target_low,
+ target_phrase_low, target_phrase_high, sentence_id);
+ }
+
+ if (source_gaps.size() >= max_nonterminals ||
+ source_phrase.GetNumSymbols() >= max_rule_symbols ||
+ source_back_high - source_back_low + min_gap_size > max_rule_span) {
+ // Cannot add any more nonterminals.
+ return extracts;
+ }
+
+ // Visits (i, j) = (0, 1), (1, 0) and (1, 1); the (0, 0) combination is
+ // skipped.
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 1 - i; j < 2; ++j) {
+ AddNonterminalExtremities(extracts, source_phrase, source_phrase_low,
+ source_phrase_high, source_back_low, source_back_high, source_low,
+ source_high, target_low, target_high, target_gaps, sentence_id, i, j);
+ }
+ }
+
+ return extracts;
+}
+
+// Computes, for sentence pair `sentence_id`, the span of aligned positions of
+// every word on both sides.
+//
+// On return source_low[i] / source_high[i] hold the smallest aligned target
+// index and one past the largest aligned target index for source word i;
+// target_low / target_high hold the same information for target words, in
+// terms of source indexes. Words without any link keep -1 in both vectors.
+void RuleExtractor::GetLinksSpans(
+    vector<int>& source_low, vector<int>& source_high,
+    vector<int>& target_low, vector<int>& target_high, int sentence_id) const {
+  // Sentence lengths, not counting the end of line marker.
+  int source_sent_len = source_data_array->GetSentenceStart(sentence_id + 1) -
+                        source_data_array->GetSentenceStart(sentence_id) - 1;
+  int target_sent_len = target_data_array->GetSentenceStart(sentence_id + 1) -
+                        target_data_array->GetSentenceStart(sentence_id) - 1;
+
+  source_low.assign(source_sent_len, -1);
+  source_high.assign(source_sent_len, -1);
+  target_low.assign(target_sent_len, -1);
+  target_high.assign(target_sent_len, -1);
+
+  // TODO(pauldb): Adam Lopez claims this part is really inefficient. See if we
+  // can speed it up.
+  const vector<pair<int, int> >& links = alignment->GetLinks(sentence_id);
+  for (const pair<int, int>& link : links) {
+    const int src = link.first;
+    const int tgt = link.second;
+
+    int& src_lo = source_low[src];
+    if (src_lo == -1 || tgt < src_lo) {
+      src_lo = tgt;
+    }
+    int& src_hi = source_high[src];
+    src_hi = max(src_hi, tgt + 1);
+
+    int& tgt_lo = target_low[tgt];
+    if (tgt_lo == -1 || src < tgt_lo) {
+      tgt_lo = src;
+    }
+    int& tgt_hi = target_high[tgt];
+    tgt_hi = max(tgt_hi, src + 1);
+  }
+}
+
+// Checks the aligned-terminal constraints for one occurrence of a phrase.
+//
+// `matching[i]` is the corpus position of chunk i and `chunklen[i]` its
+// length; `source_low` is indexed by sentence-relative source position and
+// holds -1 for unaligned words. A chunk counts as aligned when at least one of
+// its words is aligned.
+//
+// Returns true when require_aligned_terminal is off. Otherwise at least one
+// chunk must be aligned, and all of them when require_aligned_chunks is on.
+bool RuleExtractor::CheckAlignedTerminals(const vector<int>& matching,
+                                          const vector<int>& chunklen,
+                                          const vector<int>& source_low) const {
+  if (!require_aligned_terminal) {
+    return true;
+  }
+
+  int sentence_id = source_data_array->GetSentenceId(matching[0]);
+  int source_sent_start = source_data_array->GetSentenceStart(sentence_id);
+
+  // Unsigned counter so the comparison with chunklen.size() below does not
+  // mix signedness.
+  size_t num_aligned_chunks = 0;
+  for (size_t i = 0; i < chunklen.size(); ++i) {
+    // Signed index: chunklen holds ints, and a non-positive length must yield
+    // zero iterations instead of wrapping around to a huge unsigned bound.
+    for (int j = 0; j < chunklen[i]; ++j) {
+      int sent_index = matching[i] - source_sent_start + j;
+      if (source_low[sent_index] != -1) {
+        ++num_aligned_chunks;
+        break;
+      }
+    }
+  }
+
+  if (num_aligned_chunks == 0) {
+    return false;
+  }
+
+  return !require_aligned_chunks || num_aligned_chunks == chunklen.size();
+}
+
+// When require_tight_phrases is set, verifies that every gap between two
+// consecutive chunks starts and ends on a word whose source_low entry is not
+// -1. Succeeds unconditionally otherwise.
+bool RuleExtractor::CheckTightPhrases(const vector<int>& matching,
+                                      const vector<int>& chunklen,
+                                      const vector<int>& source_low) const {
+  if (require_tight_phrases) {
+    int source_sent_start = source_data_array->GetSentenceStart(
+        source_data_array->GetSentenceId(matching[0]));
+    for (size_t next = 1; next < chunklen.size(); ++next) {
+      // Sentence-relative positions of the first and last word of the gap
+      // between chunk (next - 1) and chunk next.
+      int first_gap_word =
+          matching[next - 1] + chunklen[next - 1] - source_sent_start;
+      int last_gap_word = matching[next] - 1 - source_sent_start;
+      bool tight_edges = source_low[first_gap_word] != -1 &&
+                         source_low[last_gap_word] != -1;
+      if (!tight_edges) {
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
+bool RuleExtractor::FindFixPoint(  // Projects a source span onto the target side and back until neither span changes; returns false when a limit is violated.
+ int source_phrase_low, int source_phrase_high,  // Source span to project; FindProjection scans it as [low, high).
+ const vector<int>& source_low, const vector<int>& source_high,  // Per source word: span of linked target words (low == -1 means no links).
+ int& target_phrase_low, int& target_phrase_high,  // In: previous target span, or -1 if none. Out: projected target span.
+ const vector<int>& target_low, const vector<int>& target_high,  // Per target word: span of linked source words.
+ int& source_back_low, int& source_back_high, int sentence_id,  // Out: source span obtained by projecting the target span back.
+ int min_source_gap_size, int min_target_gap_size,  // Minimum growth enforced whenever a span is extended.
+ int max_new_x, int max_low_x, int max_high_x,  // Caps on the number of extensions: overall, on the left, on the right.
+ bool allow_arbitrary_expansion) const {  // If false, fail as soon as an extension changes the target span.
+ int source_sent_len = source_data_array->GetSentenceStart(sentence_id + 1) -  // Same length computation as GetLinksSpans (end of line marker excluded).
+ source_data_array->GetSentenceStart(sentence_id) - 1;
+ int target_sent_len = target_data_array->GetSentenceStart(sentence_id + 1) -
+ target_data_array->GetSentenceStart(sentence_id) - 1;
+
+ int prev_target_low = target_phrase_low;  // Remember the target span handed in by the caller.
+ int prev_target_high = target_phrase_high;
+ FindProjection(source_phrase_low, source_phrase_high, source_low,  // Widen the target span with the links of every word in the source span.
+ source_high, target_phrase_low, target_phrase_high);
+
+ if (target_phrase_low == -1) {  // Still -1: the caller gave no span and no word in the source span is linked.
+ // TODO(pauldb): Low priority corner case inherited from Adam's code:
+ // If w is unaligned, but we don't require aligned terminals, returning an
+ // error here prevents the extraction of the allowed rule
+ // X -> X_1 w X_2 / X_1 X_2
+ return false;
+ }
+
+ if (prev_target_low != -1 && target_phrase_low != prev_target_low) {  // The span grew to the left of the caller's span.
+ if (prev_target_low - target_phrase_low < min_target_gap_size) {  // Growth smaller than the minimum: stretch it.
+ target_phrase_low = prev_target_low - min_target_gap_size;
+ if (target_phrase_low < 0) {  // Stretched past the start of the sentence.
+ return false;
+ }
+ }
+ }
+
+ if (prev_target_high != -1 && target_phrase_high != prev_target_high) {  // The span grew to the right of the caller's span.
+ if (target_phrase_high - prev_target_high < min_target_gap_size) {  // Growth smaller than the minimum: stretch it.
+ target_phrase_high = prev_target_high + min_target_gap_size;
+ if (target_phrase_high > target_sent_len) {  // Stretched past the end of the sentence.
+ return false;
+ }
+ }
+ }
+
+ if (target_phrase_high - target_phrase_low > max_rule_span) {  // Target span too wide.
+ return false;
+ }
+
+ source_back_low = source_back_high = -1;  // The back projection is recomputed from scratch.
+ FindProjection(target_phrase_low, target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high);
+ int new_x = 0, new_low_x = 0, new_high_x = 0;  // Extensions performed so far: overall, left, right.
+
+ while (true) {
+ source_back_low = min(source_back_low, source_phrase_low);  // The back projection must at least cover the source span.
+ source_back_high = max(source_back_high, source_phrase_high);
+
+ if (source_back_low == source_phrase_low &&  // Fix point: the target span links to nothing outside the source span.
+ source_back_high == source_phrase_high) {
+ return true;
+ }
+
+ if (new_low_x >= max_low_x && source_back_low < source_phrase_low) {
+ // Extension on the left side not allowed.
+ return false;
+ }
+ if (new_high_x >= max_high_x && source_back_high > source_phrase_high) {
+ // Extension on the right side not allowed.
+ return false;
+ }
+
+ // Extend left side.
+ if (source_back_low < source_phrase_low) {
+ if (new_x >= max_new_x) {  // Overall extension budget exhausted.
+ return false;
+ }
+ ++new_x; ++new_low_x;
+ if (source_phrase_low - source_back_low < min_source_gap_size) {  // Growth smaller than the minimum: stretch it.
+ source_back_low = source_phrase_low - min_source_gap_size;
+ if (source_back_low < 0) {  // Stretched past the start of the sentence.
+ return false;
+ }
+ }
+ }
+
+ // Extend right side.
+ if (source_back_high > source_phrase_high) {
+ if (new_x >= max_new_x) {  // Overall extension budget exhausted.
+ return false;
+ }
+ ++new_x; ++new_high_x;
+ if (source_back_high - source_phrase_high < min_source_gap_size) {  // Growth smaller than the minimum: stretch it.
+ source_back_high = source_phrase_high + min_source_gap_size;
+ if (source_back_high > source_sent_len) {  // Stretched past the end of the sentence.
+ return false;
+ }
+ }
+ }
+
+ if (source_back_high - source_back_low > max_rule_span) {
+ // Rule span too wide.
+ return false;
+ }
+
+ prev_target_low = target_phrase_low;
+ prev_target_high = target_phrase_high;
+ FindProjection(source_back_low, source_phrase_low, source_low, source_high,  // Project only the source words added on the left ...
+ target_phrase_low, target_phrase_high);
+ FindProjection(source_phrase_high, source_back_high, source_low,  // ... and on the right; the old span is already covered.
+ source_high, target_phrase_low, target_phrase_high);
+ if (prev_target_low == target_phrase_low &&  // The added source words brought no new target words: fix point.
+ prev_target_high == target_phrase_high) {
+ return true;
+ }
+
+ if (!allow_arbitrary_expansion) {
+ // Arbitrary expansion not allowed.
+ return false;
+ }
+ if (target_phrase_high - target_phrase_low > max_rule_span) {
+ // Target side too wide.
+ return false;
+ }
+
+ source_phrase_low = source_back_low;  // Adopt the extended source span for the next round.
+ source_phrase_high = source_back_high;
+ FindProjection(target_phrase_low, prev_target_low, target_low, target_high,  // Project back only the target words added on the left ...
+ source_back_low, source_back_high);
+ FindProjection(prev_target_high, target_phrase_high, target_low,  // ... and on the right.
+ target_high, source_back_low, source_back_high);
+ }
+
+ return false;  // Not reached: the loop above only exits through a return.
+}
+
+// Widens [target_phrase_low, target_phrase_high) so that it covers the link
+// span of every word in [source_phrase_low, source_phrase_high). Words whose
+// source_low entry is -1 are skipped. The target bounds are only extended, so
+// a caller can accumulate several projections into the same pair of bounds.
+void RuleExtractor::FindProjection(
+    int source_phrase_low, int source_phrase_high,
+    const vector<int>& source_low, const vector<int>& source_high,
+    int& target_phrase_low, int& target_phrase_high) const {
+  for (size_t word = source_phrase_low; word < source_phrase_high; ++word) {
+    if (source_low[word] == -1) {
+      continue;
+    }
+
+    if (target_phrase_low == -1 || source_low[word] < target_phrase_low) {
+      target_phrase_low = source_low[word];
+    }
+    if (target_phrase_high < source_high[word]) {
+      target_phrase_high = source_high[word];
+    }
+  }
}
-void RuleExtractor::ExtractRules() {
+// Collects the source-side gaps of a matched phrase (an optional leading gap,
+// the gaps between consecutive chunks, an optional trailing gap) and computes
+// the target span of each of them into target_gaps. Returns false on a hard
+// failure. met_constraints is cleared when only this exact phrase is ruled
+// out, which still leaves room for a superphrase.
+bool RuleExtractor::CheckGaps(
+    vector<pair<int, int> >& source_gaps, vector<pair<int, int> >& target_gaps,
+    const vector<int>& matching, const vector<int>& chunklen,
+    const vector<int>& source_low, const vector<int>& source_high,
+    const vector<int>& target_low, const vector<int>& target_high,
+    int source_phrase_low, int source_phrase_high, int source_back_low,
+    int source_back_high, int& num_symbols, bool& met_constraints) const {
+  int sentence_id = source_data_array->GetSentenceId(matching[0]);
+  int source_sent_start = source_data_array->GetSentenceStart(sentence_id);
+
+  // Leading gap: present when the back projection starts before the phrase.
+  if (source_back_low >= source_phrase_low) {
+    if (require_tight_phrases && source_low[source_phrase_low] == -1) {
+      // This is not a hard error. We can't extract this phrase, but we might
+      // still be able to extract a superphrase.
+      met_constraints = false;
+    }
+  } else {
+    source_gaps.push_back(make_pair(source_back_low, source_phrase_low));
+    if (num_symbols >= max_rule_symbols) {
+      // Source side contains too many symbols.
+      return false;
+    }
+    ++num_symbols;
+    if (require_tight_phrases) {
+      bool first_word_loose = source_low[source_back_low] == -1;
+      if (first_word_loose || source_low[source_phrase_low - 1] == -1) {
+        // Inside edges of preceding gap are not tight.
+        return false;
+      }
+    }
+  }
+
+  // Gaps between consecutive chunks, in sentence-relative coordinates.
+  for (size_t chunk = 1; chunk < chunklen.size(); ++chunk) {
+    int gap_start =
+        matching[chunk - 1] + chunklen[chunk - 1] - source_sent_start;
+    int gap_end = matching[chunk] - source_sent_start;
+    source_gaps.push_back(make_pair(gap_start, gap_end));
+  }
+
+  // Trailing gap: present when the back projection ends after the phrase.
+  if (source_phrase_high >= source_back_high) {
+    if (require_tight_phrases && source_low[source_phrase_high - 1] == -1) {
+      // This is not a hard error. We can't extract this phrase, but we might
+      // still be able to extract a superphrase.
+      met_constraints = false;
+    }
+  } else {
+    source_gaps.push_back(make_pair(source_phrase_high, source_back_high));
+    if (num_symbols >= max_rule_symbols) {
+      // Source side contains too many symbols.
+      return false;
+    }
+    ++num_symbols;
+    if (require_tight_phrases) {
+      bool first_word_loose = source_low[source_phrase_high] == -1;
+      if (first_word_loose || source_low[source_back_high - 1] == -1) {
+        // Inside edges of following gap are not tight.
+        return false;
+      }
+    }
+  }
+
+  // Every gap must project to a target span without needing any extension.
+  target_gaps.resize(source_gaps.size(), make_pair(-1, -1));
+  for (size_t i = 0; i < source_gaps.size(); ++i) {
+    pair<int, int>& source_gap = source_gaps[i];
+    pair<int, int>& target_gap = target_gaps[i];
+    bool gap_is_consistent = FindFixPoint(
+        source_gap.first, source_gap.second, source_low, source_high,
+        target_gap.first, target_gap.second, target_low, target_high,
+        source_gap.first, source_gap.second, sentence_id, 0, 0, 0, 0, 0,
+        false);
+    if (!gap_is_consistent) {
+      // Gap fails integrity check.
+      return false;
+    }
+  }
+
+  return true;
+}
+
+// Appends to extracts one Extract per target phrase that can be paired with
+// source_phrase. The candidates share a single count equally, so each one
+// receives 1 / (number of candidates).
+void RuleExtractor::AddExtracts(
+    vector<Extract>& extracts, const Phrase& source_phrase,
+    const vector<pair<int, int> >& target_gaps, const vector<int>& target_low,
+    int target_phrase_low, int target_phrase_high, int sentence_id) const {
+  vector<pair<Phrase, PhraseAlignment> > candidates = ExtractTargetPhrases(
+      target_gaps, target_low, target_phrase_low, target_phrase_high,
+      sentence_id);
+  if (candidates.empty()) {
+    return;
+  }
+
+  double pairs_count = 1.0 / candidates.size();
+  for (size_t i = 0; i < candidates.size(); ++i) {
+    extracts.push_back(Extract(source_phrase, candidates[i].first,
+                               pairs_count, candidates[i].second));
+  }
+}
+
+// Builds every target phrase (with its alignment links) that can be paired
+// with a source phrase whose target span is
+// [target_phrase_low, target_phrase_high) and whose gaps project to
+// target_gaps. All spans are half-open and sentence-relative. When tight
+// phrases are not required, the phrase and its gaps may additionally absorb
+// adjacent unaligned target words (target_low entry equal to -1), which is
+// what produces more than one candidate.
+vector<pair<Phrase, PhraseAlignment> > RuleExtractor::ExtractTargetPhrases(
+    const vector<pair<int, int> >& target_gaps, const vector<int>& target_low,
+    int target_phrase_low, int target_phrase_high, int sentence_id) const {
+  // Sentence length without the end of line marker.
+  int target_sent_len = target_data_array->GetSentenceStart(sentence_id + 1) -
+      target_data_array->GetSentenceStart(sentence_id) - 1;
+
+  // target_gap_order[i] counts the gaps that compare smaller than gap i.
+  vector<int> target_gap_order(target_gaps.size());
+  for (size_t i = 0; i < target_gap_order.size(); ++i) {
+    for (size_t j = 0; j < i; ++j) {
+      if (target_gaps[target_gap_order[j]] < target_gaps[i]) {
+        ++target_gap_order[i];
+      } else {
+        ++target_gap_order[j];
+      }
+    }
+  }
+
+  int target_x_low = target_phrase_low, target_x_high = target_phrase_high;
+  if (!require_tight_phrases) {
+    // Absorb unaligned words immediately to the left of the phrase.
+    while (target_x_low > 0 &&
+           target_phrase_high - target_x_low < max_rule_span &&
+           target_low[target_x_low - 1] == -1) {
+      --target_x_low;
+    }
+    // Absorb unaligned words immediately to the right of the phrase. The high
+    // bound is exclusive, so the next candidate word is target_x_high itself
+    // (testing target_x_high + 1 skipped the adjacent word).
+    while (target_x_high < target_sent_len &&
+           target_x_high - target_phrase_low < max_rule_span &&
+           target_low[target_x_high] == -1) {
+      ++target_x_high;
+    }
+  }
+
+  vector<pair<int, int> > gaps(target_gaps.size());
+  for (size_t i = 0; i < gaps.size(); ++i) {
+    gaps[i] = target_gaps[target_gap_order[i]];
+    if (!require_tight_phrases) {
+      // Extend the gap to the left over unaligned words: the candidate is the
+      // word just before the gap, not the gap's own first word.
+      while (gaps[i].first > target_x_low &&
+             target_low[gaps[i].first - 1] == -1) {
+        --gaps[i].first;
+      }
+      // Extend the gap to the right over unaligned words.
+      while (gaps[i].second < target_x_high &&
+             target_low[gaps[i].second] == -1) {
+        ++gaps[i].second;
+      }
+    }
+  }
+
+  // ranges[k] is the interval in which chunk boundary k may fall: even k is a
+  // chunk start, odd k a chunk end. Each loose gap must be paired with the
+  // tight gap it was derived from, i.e. target_gaps[target_gap_order[i]].
+  vector<pair<int, int> > ranges(2 * gaps.size() + 2);
+  ranges.front() = make_pair(target_x_low, target_phrase_low);
+  ranges.back() = make_pair(target_phrase_high, target_x_high);
+  for (size_t i = 0; i < gaps.size(); ++i) {
+    const pair<int, int>& tight_gap = target_gaps[target_gap_order[i]];
+    ranges[i * 2 + 1] = make_pair(gaps[i].first, tight_gap.first);
+    ranges[i * 2 + 2] = make_pair(tight_gap.second, gaps[i].second);
+  }
+
+  vector<pair<Phrase, PhraseAlignment> > target_phrases;
+  vector<int> subpatterns(ranges.size());
+  GeneratePhrases(target_phrases, ranges, 0, subpatterns, target_gap_order,
+                  target_phrase_low, target_phrase_high, sentence_id);
+  return target_phrases;
+}
+
+// Recursively enumerates every combination of chunk boundaries allowed by
+// ranges and appends one (target phrase, alignment) pair per combination to
+// target_phrases. subpatterns holds the boundaries chosen so far: entry 2k is
+// the start and entry 2k + 1 the end of chunk k. index is the next boundary
+// to choose and is advanced two at a time (one whole chunk per level).
+void RuleExtractor::GeneratePhrases(
+    vector<pair<Phrase, PhraseAlignment> >& target_phrases,
+    const vector<pair<int, int> >& ranges, int index, vector<int>& subpatterns,
+    const vector<int>& target_gap_order, int target_phrase_low,
+    int target_phrase_high, int sentence_id) const {
+  if (index >= ranges.size()) {
+    // All boundaries are fixed: build the phrase, unless it is too wide.
+    if (subpatterns.back() - subpatterns.front() > max_rule_span) {
+      return;
+    }
+
+    vector<int> symbols;
+    unordered_set<int> target_indexes;
+    // offset is the label given to the lowest-ranked inner gap: 1 normally,
+    // 2 when a leading nonterminal labelled 1 is emitted first.
+    int offset = 1;
+    if (subpatterns.front() != target_phrase_low) {
+      offset = 2;
+      symbols.push_back(vocabulary->GetNonterminalIndex(1));
+    }
+
+    int target_sent_start = target_data_array->GetSentenceStart(sentence_id);
+    for (size_t i = 0; i * 2 < subpatterns.size(); ++i) {
+      // Terminals of chunk i, remembering which target words were used.
+      for (size_t j = subpatterns[i * 2]; j < subpatterns[i * 2 + 1]; ++j) {
+        symbols.push_back(target_data_array->AtIndex(target_sent_start + j));
+        target_indexes.insert(j);
+      }
+      // Every chunk except the last one is followed by a gap.
+      if (i < target_gap_order.size()) {
+        symbols.push_back(vocabulary->GetNonterminalIndex(
+            target_gap_order[i] + offset));
+      }
+    }
+
+    if (subpatterns.back() != target_phrase_high) {
+      // The trailing nonterminal must be encoded through the vocabulary like
+      // the other nonterminals; pushing the raw label would insert it as an
+      // ordinary symbol id.
+      symbols.push_back(vocabulary->GetNonterminalIndex(
+          target_gap_order.size() + offset));
+    }
+
+    // Keep only the links whose target word is part of the phrase.
+    const vector<pair<int, int> >& links = alignment->GetLinks(sentence_id);
+    vector<pair<int, int> > phrase_alignment;
+    for (const pair<int, int>& link: links) {
+      if (target_indexes.count(link.second)) {
+        phrase_alignment.push_back(link);
+      }
+    }
+
+    target_phrases.push_back(make_pair(phrase_builder->Build(symbols),
+                                       phrase_alignment));
+    return;
+  }
+
+  // Choose the start of the current chunk (not before the previous boundary)
+  // and then its end (not before its start), and recurse on the next chunk.
+  subpatterns[index] = ranges[index].first;
+  if (index > 0) {
+    subpatterns[index] = max(subpatterns[index], subpatterns[index - 1]);
+  }
+  while (subpatterns[index] <= ranges[index].second) {
+    subpatterns[index + 1] = max(subpatterns[index], ranges[index + 1].first);
+    while (subpatterns[index + 1] <= ranges[index + 1].second) {
+      GeneratePhrases(target_phrases, ranges, index + 2, subpatterns,
+                      target_gap_order, target_phrase_low, target_phrase_high,
+                      sentence_id);
+      ++subpatterns[index + 1];
+    }
+    ++subpatterns[index];
+  }
+}
+
+void RuleExtractor::AddNonterminalExtremities(  // Tries to add the extracts obtained by attaching a nonterminal to the left and/or right end of source_phrase.
+ vector<Extract>& extracts, const Phrase& source_phrase,
+ int source_phrase_low, int source_phrase_high, int source_back_low,
+ int source_back_high, const vector<int>& source_low,
+ const vector<int>& source_high, const vector<int>& target_low,
+ const vector<int>& target_high, const vector<pair<int, int> >& target_gaps,
+ int sentence_id, int extend_left, int extend_right) const {  // extend_left / extend_right are used both as flags and as counts (added to sizes below).
+ int source_x_low = source_phrase_low, source_x_high = source_phrase_high;  // Source span including the new nonterminal(s).
+ if (extend_left) {
+ if (source_back_low != source_phrase_low ||  // Bail out if the back projection differs from the phrase on the left,
+ source_phrase_low < min_gap_size ||  // if fewer than min_gap_size words precede the phrase,
+ (require_tight_phrases && (source_low[source_phrase_low - 1] == -1 ||  // or, for tight phrases, if one of the two probed words is unaligned.
+ source_low[source_back_high - 1] == -1))) {
+ return;
+ }
+
+ source_x_low = source_phrase_low - min_gap_size;  // Smallest allowed gap on the left.
+ if (require_tight_phrases) {
+ while (source_x_low >= 0 && source_low[source_x_low] == -1) {  // Grow the gap until it starts on an aligned word.
+ --source_x_low;
+ }
+ }
+ if (source_x_low < 0) {  // Ran past the start of the sentence.
+ return;
+ }
+ }
+
+ if (extend_right) {
+ int source_sent_len = source_data_array->GetSentenceStart(sentence_id + 1) -  // Same length computation as GetLinksSpans (end of line marker excluded).
+ source_data_array->GetSentenceStart(sentence_id) - 1;
+ if (source_back_high != source_phrase_high ||  // Bail out if the back projection differs from the phrase on the right,
+ source_phrase_high + min_gap_size > source_sent_len ||  // if fewer than min_gap_size words follow the phrase,
+ (require_tight_phrases && (source_low[source_phrase_low] == -1 ||  // or, for tight phrases, if one of the two probed words is unaligned.
+ source_low[source_phrase_high] == -1))) {
+ return;
+ }
+ source_x_high = source_phrase_high + min_gap_size;  // Smallest allowed gap on the right.
+ if (require_tight_phrases) {
+ while (source_x_high <= source_sent_len &&  // Grow the gap until it ends on an aligned word.
+ source_low[source_x_high - 1] == -1) {
+ ++source_x_high;
+ }
+ }
+
+ if (source_x_high > source_sent_len) {  // Ran past the end of the sentence.
+ return;
+ }
+ }
+
+ if (source_x_high - source_x_low > max_rule_span ||  // Respect the span limit and the nonterminal limit.
+ target_gaps.size() + extend_left + extend_right > max_nonterminals) {
+ return;
+ }
+
+ int target_x_low = -1, target_x_high = -1;
+ if (!FindFixPoint(source_x_low, source_x_high, source_low, source_high,  // Target span of the extended source span.
+ target_x_low, target_x_high, target_low, target_high,
+ source_x_low, source_x_high, sentence_id, 1, 1,  // Minimum source and target gap sizes are both 1 here.
+ extend_left + extend_right, extend_left, extend_right,  // At most one further extension per requested side.
+ true)) {
+ return;
+ }
+
+ int source_gap_low = -1, source_gap_high = -1, target_gap_low = -1,
+ target_gap_high = -1;
+ if (extend_left &&  // The new left gap must be tight (if required) and reach a fix point with no extension allowed.
+ ((require_tight_phrases && source_low[source_x_low] == -1) ||
+ !FindFixPoint(source_x_low, source_phrase_low, source_low, source_high,
+ target_gap_low, target_gap_high, target_low, target_high,
+ source_gap_low, source_gap_high, sentence_id,
+ 0, 0, 0, 0, 0, false))) {
+ return;
+ }
+ if (extend_right &&  // NOTE(review): the *_gap_* variables are not reset, so with both extensions this check starts from the left gap's target span; confirm this is intended.
+ ((require_tight_phrases && source_low[source_x_high - 1] == -1) ||
+ !FindFixPoint(source_phrase_high, source_x_high, source_low, source_high,
+ target_gap_low, target_gap_high, target_low, target_high,
+ source_gap_low, source_gap_high, sentence_id,
+ 0, 0, 0, 0, 0, false))) {
+ return;
+ }
+
+ Phrase new_source_phrase = phrase_builder->Extend(source_phrase, extend_left,
+ extend_right);
+ AddExtracts(extracts, new_source_phrase, target_gaps, target_low,  // target_gaps is passed on unchanged; the gap spans computed above are only used as checks.
+ target_x_low, target_x_high, sentence_id);
}
diff --git a/extractor/rule_extractor.h b/extractor/rule_extractor.h
index 13b5447a..f668de24 100644
--- a/extractor/rule_extractor.h
+++ b/extractor/rule_extractor.h
@@ -2,21 +2,129 @@
#define _RULE_EXTRACTOR_H_
#include <memory>
+#include <vector>
+
+#include "phrase.h"
using namespace std;
class Alignment;
class DataArray;
-class SuffixArray;
+class PhraseBuilder;
+class PhraseLocation;
+class Rule;
+class Scorer;
+class Vocabulary;
+
+typedef vector<pair<int, int> > PhraseAlignment;
+
+// One extracted source/target phrase pair, together with the share of a count
+// it contributes and the word alignment links kept for the target phrase.
+struct Extract {
+  Extract(const Phrase& source_phrase,
+          const Phrase& target_phrase,
+          double pairs_count,
+          const PhraseAlignment& alignment)
+      : source_phrase(source_phrase),
+        target_phrase(target_phrase),
+        pairs_count(pairs_count),
+        alignment(alignment) {}
+
+  Phrase source_phrase;
+  Phrase target_phrase;
+  // Fractional count assigned to this pair.
+  double pairs_count;
+  // (source index, target index) links.
+  PhraseAlignment alignment;
+};
class RuleExtractor {
public:
- RuleExtractor(
- shared_ptr<SuffixArray> source_suffix_array,
- shared_ptr<DataArray> target_data_array,
- const Alignment& alingment);
+ RuleExtractor(shared_ptr<DataArray> source_data_array,
+ shared_ptr<DataArray> target_data_array,
+ shared_ptr<Alignment> alingment,  // NOTE(review): parameter name is misspelt ("alingment"); harmless in a declaration.
+ shared_ptr<PhraseBuilder> phrase_builder,
+ shared_ptr<Scorer> scorer,
+ shared_ptr<Vocabulary> vocabulary,
+ int min_gap_size,  // NOTE(review): rule_factory.cc passes max_rule_span before min_gap_size at this position; confirm against the constructor definition.
+ int max_rule_span,
+ int max_nonterminals,
+ int max_rule_symbols,
+ bool require_aligned_terminal,
+ bool require_aligned_chunks,
+ bool require_tight_phrases);
+
+ vector<Rule> ExtractRules(const Phrase& phrase,  // Public entry point: rules for the given occurrences of a source phrase.
+ const PhraseLocation& location) const;
+
+ private:
+ vector<Extract> ExtractAlignments(const Phrase& phrase,  // Extracts for a single occurrence (matching) of the phrase.
+ const vector<int>& matching) const;
+
+ void GetLinksSpans(vector<int>& source_low, vector<int>& source_high,  // Per-word link spans of one sentence pair; -1 marks words without links.
+ vector<int>& target_low, vector<int>& target_high,
+ int sentence_id) const;
+
+ bool CheckAlignedTerminals(const vector<int>& matching,  // Enforces require_aligned_terminal / require_aligned_chunks.
+ const vector<int>& chunklen,
+ const vector<int>& source_low) const;
+
+ bool CheckTightPhrases(const vector<int>& matching,  // Enforces require_tight_phrases on the gaps between chunks.
+ const vector<int>& chunklen,
+ const vector<int>& source_low) const;
+
+ bool FindFixPoint(  // Projects a source span to the target side and back until both spans are stable.
+ int source_phrase_start, int source_phrase_end,
+ const vector<int>& source_low, const vector<int>& source_high,
+ int& target_phrase_start, int& target_phrase_end,
+ const vector<int>& target_low, const vector<int>& target_high,
+ int& source_back_low, int& source_back_high, int sentence_id,
+ int min_source_gap_size, int min_target_gap_size,
+ int max_new_x, int max_low_x, int max_high_x,
+ bool allow_arbitrary_expansion) const;
+
+ void FindProjection(  // Widens the target bounds with the link spans of the words in a source span.
+ int source_phrase_start, int source_phrase_end,
+ const vector<int>& source_low, const vector<int>& source_high,
+ int& target_phrase_low, int& target_phrase_end) const;
+
+ bool CheckGaps(  // Collects the source gaps of a matching and their target spans.
+ vector<pair<int, int> >& source_gaps, vector<pair<int, int> >& target_gaps,
+ const vector<int>& matching, const vector<int>& chunklen,
+ const vector<int>& source_low, const vector<int>& source_high,
+ const vector<int>& target_low, const vector<int>& target_high,
+ int source_phrase_low, int source_phrase_high, int source_back_low,
+ int source_back_high, int& num_symbols, bool& met_constraints) const;
+
+ void AddExtracts(  // Appends one Extract per target phrase found for source_phrase.
+ vector<Extract>& extracts, const Phrase& source_phrase,
+ const vector<pair<int, int> >& target_gaps, const vector<int>& target_low,
+ int target_phrase_low, int target_phrase_high, int sentence_id) const;
+
+ vector<pair<Phrase, PhraseAlignment> > ExtractTargetPhrases(  // All target phrases (with links) compatible with the given target span and gaps.
+ const vector<pair<int, int> >& target_gaps, const vector<int>& target_low,
+ int target_phrase_low, int target_phrase_high, int sentence_id) const;
+
+ void GeneratePhrases(  // Recursive helper of ExtractTargetPhrases enumerating chunk boundaries.
+ vector<pair<Phrase, PhraseAlignment> >& target_phrases,
+ const vector<pair<int, int> >& ranges, int index,
+ vector<int>& subpatterns, const vector<int>& target_gap_order,
+ int target_phrase_low, int target_phrase_high, int sentence_id) const;
+
+ void AddNonterminalExtremities(  // Tries to extend the source phrase with a nonterminal on the left and/or right.
+ vector<Extract>& extracts, const Phrase& source_phrase,
+ int source_phrase_low, int source_phrase_high, int source_back_low,
+ int source_back_high, const vector<int>& source_low,
+ const vector<int>& source_high, const vector<int>& target_low,
+ const vector<int>& target_high,
+ const vector<pair<int, int> >& target_gaps, int sentence_id,
+ int extend_left, int extend_right) const;
- void ExtractRules();
+ shared_ptr<DataArray> source_data_array;
+ shared_ptr<DataArray> target_data_array;
+ shared_ptr<Alignment> alignment;
+ shared_ptr<PhraseBuilder> phrase_builder;
+ shared_ptr<Scorer> scorer;
+ shared_ptr<Vocabulary> vocabulary;
+ int max_rule_span;  // NOTE(review): declared before min_gap_size, the reverse of the constructor's parameter order.
+ int min_gap_size;
+ int max_nonterminals;
+ int max_rule_symbols;
+ bool require_aligned_terminal;
+ bool require_aligned_chunks;
+ bool require_tight_phrases;
};
#endif
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc
index 7a8356b8..c22f9b48 100644
--- a/extractor/rule_factory.cc
+++ b/extractor/rule_factory.cc
@@ -5,8 +5,15 @@
#include <queue>
#include <vector>
+#include "grammar.h"
+#include "intersector.h"
+#include "matchings_finder.h"
#include "matching_comparator.h"
#include "phrase.h"
+#include "rule.h"
+#include "rule_extractor.h"
+#include "sampler.h"
+#include "scorer.h"
#include "suffix_array.h"
#include "vocabulary.h"
@@ -30,28 +37,39 @@ struct State {
HieroCachingRuleFactory::HieroCachingRuleFactory(
shared_ptr<SuffixArray> source_suffix_array,
shared_ptr<DataArray> target_data_array,
- const Alignment& alignment,
+ shared_ptr<Alignment> alignment,
const shared_ptr<Vocabulary>& vocabulary,
- const Precomputation& precomputation,
+ shared_ptr<Precomputation> precomputation,
+ shared_ptr<Scorer> scorer,
int min_gap_size,
int max_rule_span,
int max_nonterminals,
int max_rule_symbols,
- bool use_baeza_yates) :
- matchings_finder(source_suffix_array),
- intersector(vocabulary, precomputation, source_suffix_array,
- make_shared<MatchingComparator>(min_gap_size, max_rule_span),
- use_baeza_yates),
- phrase_builder(vocabulary),
- rule_extractor(source_suffix_array, target_data_array, alignment),
+ int max_samples,
+ bool use_baeza_yates,
+ bool require_tight_phrases) :
vocabulary(vocabulary),
+ scorer(scorer),
min_gap_size(min_gap_size),
max_rule_span(max_rule_span),
max_nonterminals(max_nonterminals),
max_chunks(max_nonterminals + 1),
- max_rule_symbols(max_rule_symbols) {}
+ max_rule_symbols(max_rule_symbols) {
+ matchings_finder = make_shared<MatchingsFinder>(source_suffix_array);
+ shared_ptr<MatchingComparator> comparator =
+ make_shared<MatchingComparator>(min_gap_size, max_rule_span);
+ intersector = make_shared<Intersector>(vocabulary, precomputation,
+ source_suffix_array, comparator, use_baeza_yates);
+ phrase_builder = make_shared<PhraseBuilder>(vocabulary);
+ rule_extractor = make_shared<RuleExtractor>(source_suffix_array->GetData(),
+ target_data_array, alignment, phrase_builder, scorer, vocabulary,
+ max_rule_span, min_gap_size, max_nonterminals, max_rule_symbols, true,
+ false, require_tight_phrases);
+ sampler = make_shared<Sampler>(source_suffix_array, max_samples);
+}
+
-void HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
+Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
// Clear cache for every new sentence.
trie.Reset();
shared_ptr<TrieNode> root = trie.GetRoot();
@@ -69,6 +87,7 @@ void HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
vector<int>(1, i), x_root, true));
}
+ vector<Rule> rules;
while (!states.empty()) {
State state = states.front();
states.pop();
@@ -77,7 +96,7 @@ void HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
vector<int> phrase = state.phrase;
int word_id = word_ids[state.end];
phrase.push_back(word_id);
- Phrase next_phrase = phrase_builder.Build(phrase);
+ Phrase next_phrase = phrase_builder->Build(phrase);
shared_ptr<TrieNode> next_node;
if (CannotHaveMatchings(node, word_id)) {
@@ -98,14 +117,14 @@ void HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
} else {
PhraseLocation phrase_location;
if (next_phrase.Arity() > 0) {
- phrase_location = intersector.Intersect(
+ phrase_location = intersector->Intersect(
node->phrase,
node->matchings,
next_suffix_link->phrase,
next_suffix_link->matchings,
next_phrase);
} else {
- phrase_location = matchings_finder.Find(
+ phrase_location = matchings_finder->Find(
node->matchings,
vocabulary->GetTerminalValue(word_id),
state.phrase.size());
@@ -125,7 +144,10 @@ void HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
state.starts_with_x);
if (!state.starts_with_x) {
- rule_extractor.ExtractRules();
+ PhraseLocation sample = sampler->Sample(next_node->matchings);
+ vector<Rule> new_rules =
+ rule_extractor->ExtractRules(next_phrase, sample);
+ rules.insert(rules.end(), new_rules.begin(), new_rules.end());
}
} else {
next_node = node->GetChild(word_id);
@@ -137,6 +159,8 @@ void HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
states.push(new_state);
}
}
+
+ return Grammar(rules, scorer->GetFeatureNames());
}
bool HieroCachingRuleFactory::CannotHaveMatchings(
@@ -165,7 +189,7 @@ void HieroCachingRuleFactory::AddTrailingNonterminal(
int var_id = vocabulary->GetNonterminalIndex(prefix.Arity() + 1);
symbols.push_back(var_id);
- Phrase var_phrase = phrase_builder.Build(symbols);
+ Phrase var_phrase = phrase_builder->Build(symbols);
int suffix_var_id = vocabulary->GetNonterminalIndex(
prefix.Arity() + starts_with_x == 0);
diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h
index 8fe8bf30..a47b6d16 100644
--- a/extractor/rule_factory.h
+++ b/extractor/rule_factory.h
@@ -4,17 +4,21 @@
#include <memory>
#include <vector>
-#include "matchings_finder.h"
-#include "intersector.h"
#include "matchings_trie.h"
#include "phrase_builder.h"
-#include "rule_extractor.h"
using namespace std;
class Alignment;
class DataArray;
+class Grammar;
+class MatchingsFinder;
+class Intersector;
class Precomputation;
+class Rule;
+class RuleExtractor;
+class Sampler;
+class Scorer;
class State;
class SuffixArray;
class Vocabulary;
@@ -24,16 +28,19 @@ class HieroCachingRuleFactory {
HieroCachingRuleFactory(
shared_ptr<SuffixArray> source_suffix_array,
shared_ptr<DataArray> target_data_array,
- const Alignment& alignment,
+ shared_ptr<Alignment> alignment,
const shared_ptr<Vocabulary>& vocabulary,
- const Precomputation& precomputation,
+ shared_ptr<Precomputation> precomputation,
+ shared_ptr<Scorer> scorer,
int min_gap_size,
int max_rule_span,
int max_nonterminals,
int max_rule_symbols,
- bool use_beaza_yates);
+ int max_samples,
+ bool use_beaza_yates,
+ bool require_tight_phrases);
- void GetGrammar(const vector<int>& word_ids);
+ Grammar GetGrammar(const vector<int>& word_ids);
private:
bool CannotHaveMatchings(shared_ptr<TrieNode> node, int word_id);
@@ -51,12 +58,14 @@ class HieroCachingRuleFactory {
const Phrase& phrase,
const shared_ptr<TrieNode>& node);
- MatchingsFinder matchings_finder;
- Intersector intersector;
+ shared_ptr<MatchingsFinder> matchings_finder;
+ shared_ptr<Intersector> intersector;
MatchingsTrie trie;
- PhraseBuilder phrase_builder;
- RuleExtractor rule_extractor;
+ shared_ptr<PhraseBuilder> phrase_builder;
+ shared_ptr<RuleExtractor> rule_extractor;
shared_ptr<Vocabulary> vocabulary;
+ shared_ptr<Sampler> sampler;
+ shared_ptr<Scorer> scorer;
int min_gap_size;
int max_rule_span;
int max_nonterminals;
diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc
index 4f841864..37a9cba0 100644
--- a/extractor/run_extractor.cc
+++ b/extractor/run_extractor.cc
@@ -1,16 +1,31 @@
+#include <fstream>
#include <iostream>
#include <string>
+#include <vector>
+#include <boost/filesystem.hpp>
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
#include "alignment.h"
#include "data_array.h"
+#include "features/count_source_target.h"
+#include "features/feature.h"
+#include "features/is_source_singleton.h"
+#include "features/is_source_target_singleton.h"
+#include "features/max_lex_source_given_target.h"
+#include "features/max_lex_target_given_source.h"
+#include "features/sample_source_count.h"
+#include "features/target_given_source_coherent.h"
+#include "grammar.h"
#include "grammar_extractor.h"
#include "precomputation.h"
+#include "rule.h"
+#include "scorer.h"
#include "suffix_array.h"
#include "translation_table.h"
+namespace fs = boost::filesystem;
namespace po = boost::program_options;
using namespace std;
@@ -23,21 +38,26 @@ int main(int argc, char** argv) {
("target,e", po::value<string>(), "Target language corpus")
("bitext,b", po::value<string>(), "Parallel text (source ||| target)")
("alignment,a", po::value<string>()->required(), "Bitext word alignment")
+ ("grammars,g", po::value<string>()->required(), "Grammars output path")
("frequent", po::value<int>()->default_value(100),
"Number of precomputed frequent patterns")
("super_frequent", po::value<int>()->default_value(10),
"Number of precomputed super frequent patterns")
- ("max_rule_span,s", po::value<int>()->default_value(15),
+ ("max_rule_span", po::value<int>()->default_value(15),
"Maximum rule span")
("max_rule_symbols,l", po::value<int>()->default_value(5),
"Maximum number of symbols (terminals + nontermals) in a rule")
- ("min_gap_size,g", po::value<int>()->default_value(1), "Minimum gap size")
- ("max_phrase_len,p", po::value<int>()->default_value(4),
+ ("min_gap_size", po::value<int>()->default_value(1), "Minimum gap size")
+ ("max_phrase_len", po::value<int>()->default_value(4),
"Maximum frequent phrase length")
("max_nonterminals", po::value<int>()->default_value(2),
"Maximum number of nonterminals in a rule")
("min_frequency", po::value<int>()->default_value(1000),
"Minimum number of occurences for a pharse to be considered frequent")
+ ("max_samples", po::value<int>()->default_value(300),
+ "Maximum number of samples")
+ ("tight_phrases", po::value<bool>()->default_value(true),
+ "False if phrases may be loose (better, but slower)")
("baeza_yates", po::value<bool>()->default_value(true),
"Use double binary search");
@@ -74,9 +94,10 @@ int main(int argc, char** argv) {
make_shared<SuffixArray>(source_data_array);
- Alignment alignment(vm["alignment"].as<string>());
+ shared_ptr<Alignment> alignment =
+ make_shared<Alignment>(vm["alignment"].as<string>());
- Precomputation precomputation(
+ shared_ptr<Precomputation> precomputation = make_shared<Precomputation>(
source_suffix_array,
vm["frequent"].as<int>(),
vm["super_frequent"].as<int>(),
@@ -86,7 +107,19 @@ int main(int argc, char** argv) {
vm["max_phrase_len"].as<int>(),
vm["min_frequency"].as<int>());
- TranslationTable table(source_data_array, target_data_array, alignment);
+ shared_ptr<TranslationTable> table = make_shared<TranslationTable>(
+ source_data_array, target_data_array, alignment);
+
+ vector<shared_ptr<Feature> > features = {
+ make_shared<TargetGivenSourceCoherent>(),
+ make_shared<SampleSourceCount>(),
+ make_shared<CountSourceTarget>(),
+ make_shared<MaxLexTargetGivenSource>(table),
+ make_shared<MaxLexSourceGivenTarget>(table),
+ make_shared<IsSourceSingleton>(),
+ make_shared<IsSourceTargetSingleton>()
+ };
+ shared_ptr<Scorer> scorer = make_shared<Scorer>(features);
// TODO(pauldb): Add parallelization.
GrammarExtractor extractor(
@@ -94,15 +127,34 @@ int main(int argc, char** argv) {
target_data_array,
alignment,
precomputation,
+ scorer,
vm["min_gap_size"].as<int>(),
vm["max_rule_span"].as<int>(),
vm["max_nonterminals"].as<int>(),
vm["max_rule_symbols"].as<int>(),
- vm["baeza_yates"].as<bool>());
+ vm["max_samples"].as<int>(),
+ vm["baeza_yates"].as<bool>(),
+ vm["tight_phrases"].as<bool>());
- string sentence;
+ int grammar_id = 0;
+ fs::path grammar_path = vm["grammars"].as<string>();
+ string sentence, delimiter = "|||";
while (getline(cin, sentence)) {
- extractor.GetGrammar(sentence);
+ string suffix = "";
+ int position = sentence.find(delimiter);
+ if (position != sentence.npos) {
+ suffix = sentence.substr(position);
+ sentence = sentence.substr(0, position);
+ }
+
+ Grammar grammar = extractor.GetGrammar(sentence);
+ fs::path grammar_file = grammar_path / to_string(grammar_id);
+ ofstream output(grammar_file.c_str());
+ output << grammar;
+
+ cout << "<seg grammar=\"" << grammar_file << "\" id=\"" << grammar_id
+ << "\"> " << sentence << " </seg> " << suffix << endl;
+ ++grammar_id;
}
return 0;
diff --git a/extractor/sampler.cc b/extractor/sampler.cc
new file mode 100644
index 00000000..d8e0f49e
--- /dev/null
+++ b/extractor/sampler.cc
@@ -0,0 +1,36 @@
+#include "sampler.h"
+
+#include "phrase_location.h"
+#include "suffix_array.h"
+
+// Stores the suffix array that Sample() reads suffixes from and the bound on
+// how many samples a single Sample() call collects.
+Sampler::Sampler(shared_ptr<SuffixArray> suffix_array, int max_samples) :
+ suffix_array(suffix_array), max_samples(max_samples) {}
+
+// Returns a PhraseLocation built from occurrences taken out of `location` at
+// a roughly even stride, bounded by max_samples. Two input forms are handled:
+// a suffix array range (location.matchings == NULL) and an explicit matchings
+// vector.
+// NOTE(review): the bound assumes max_samples is positive; the range branch
+// compares an unsigned size() against the int max_samples -- confirm callers
+// never pass a negative value.
+PhraseLocation Sampler::Sample(const PhraseLocation& location) const {
+ vector<int> sample;
+ int num_subpatterns;
+ if (location.matchings == NULL) {
+ // Range form: walk [sa_low, sa_high) and map every chosen rank through
+ // GetSuffix(); each sample is a single value.
+ num_subpatterns = 1;
+ int low = location.sa_low, high = location.sa_high;
+ // The step never drops below 1.0, so a range holding no more than
+ // max_samples ranks is visited rank by rank.
+ double step = max(1.0, (double) (high - low) / max_samples);
+ for (double i = low; i < high && sample.size() < max_samples; i += step) {
+ sample.push_back(suffix_array->GetSuffix(Round(i)));
+ }
+ } else {
+ // Vector form: matchings is read as consecutive groups of num_subpatterns
+ // values, and whole groups are copied into the sample.
+ num_subpatterns = location.num_subpatterns;
+ int num_matchings = location.matchings->size() / num_subpatterns;
+ double step = max(1.0, (double) num_matchings / max_samples);
+ // num_samples counts the groups copied so far and enforces the bound.
+ for (double i = 0, num_samples = 0;
+ i < num_matchings && num_samples < max_samples;
+ i += step, ++num_samples) {
+ // Index of the first value of the chosen group inside matchings.
+ int start = Round(i) * num_subpatterns;
+ sample.insert(sample.end(), location.matchings->begin() + start,
+ location.matchings->begin() + start + num_subpatterns);
+ }
+ }
+ return PhraseLocation(sample, num_subpatterns);
+}
+
+// Adds 0.5 to x and lets the int return type truncate the result; for the
+// non-negative values passed by Sample() this rounds to the nearest integer,
+// with halves going up.
+int Sampler::Round(double x) const {
+ return x + 0.5;
+}
diff --git a/extractor/sampler.h b/extractor/sampler.h
new file mode 100644
index 00000000..3b3e3a4d
--- /dev/null
+++ b/extractor/sampler.h
@@ -0,0 +1,24 @@
+#ifndef _SAMPLER_H_
+#define _SAMPLER_H_
+
+#include <memory>
+
+using namespace std;
+
+class PhraseLocation;
+class SuffixArray;
+
+// Reduces a PhraseLocation to a bounded number of its occurrences.
+class Sampler {
+ public:
+ // Both arguments are stored as members and used by Sample().
+ Sampler(shared_ptr<SuffixArray> suffix_array, int max_samples);
+
+ // Returns a new PhraseLocation holding the occurrences picked from location.
+ PhraseLocation Sample(const PhraseLocation& location) const;
+
+ private:
+ // Helper converting a fractional position into an integer index.
+ int Round(double x) const;
+
+ shared_ptr<SuffixArray> suffix_array;
+ int max_samples;
+};
+
+#endif
diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc
new file mode 100644
index 00000000..4f91965b
--- /dev/null
+++ b/extractor/sampler_test.cc
@@ -0,0 +1,72 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "mocks/mock_suffix_array.h"
+#include "phrase_location.h"
+#include "sampler.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace {
+
+// Fixture shared by the Sampler tests: provides a mocked suffix array and a
+// slot for the Sampler that each test case constructs itself.
+class SamplerTest : public Test {
+ protected:
+ virtual void SetUp() {
+ suffix_array = make_shared<MockSuffixArray>();
+ // Make GetSuffix(i) return i for ranks 0..9, so sampled values can be
+ // compared directly against the ranks that were picked.
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i));
+ }
+ }
+
+ shared_ptr<MockSuffixArray> suffix_array;
+ shared_ptr<Sampler> sampler;
+};
+
+// Samples the suffix array range [0, 10) with growing max_samples values and
+// checks which ranks are picked each time.
+TEST_F(SamplerTest, TestSuffixArrayRange) {
+ PhraseLocation location(0, 10);
+
+ sampler = make_shared<Sampler>(suffix_array, 1);
+ vector<int> expected_locations = {0};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+
+ // Step of 5 between picks.
+ sampler = make_shared<Sampler>(suffix_array, 2);
+ expected_locations = {0, 5};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+
+ // Fractional step (10 / 3); positions are rounded to 0, 3 and 7.
+ sampler = make_shared<Sampler>(suffix_array, 3);
+ expected_locations = {0, 3, 7};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+
+ // Step of 2.5; positions 2.5 and 7.5 round up to 3 and 8.
+ sampler = make_shared<Sampler>(suffix_array, 4);
+ expected_locations = {0, 3, 5, 8};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+
+ // More samples allowed than ranks available: every rank is returned.
+ sampler = make_shared<Sampler>(suffix_array, 100);
+ expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+}
+
+// Samples an explicit matchings vector of ten values with two subpatterns per
+// matching (five matchings) and checks that whole pairs are selected.
+TEST_F(SamplerTest, TestSubstringsSample) {
+ vector<int> locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ PhraseLocation location(locations, 2);
+
+ sampler = make_shared<Sampler>(suffix_array, 1);
+ vector<int> expected_locations = {0, 1};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location));
+
+ // Step of 2.5 matchings; the second pick rounds to matching 3, i.e. {6, 7}.
+ sampler = make_shared<Sampler>(suffix_array, 2);
+ expected_locations = {0, 1, 6, 7};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location));
+
+ // Step of 5 / 3 matchings; picks round to matchings 0, 2 and 3.
+ sampler = make_shared<Sampler>(suffix_array, 3);
+ expected_locations = {0, 1, 4, 5, 6, 7};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location));
+
+ // More samples allowed than matchings available: everything is returned.
+ sampler = make_shared<Sampler>(suffix_array, 7);
+ expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location));
+}
+
+} // namespace
diff --git a/extractor/scorer.cc b/extractor/scorer.cc
index 22d5be1a..c87e179d 100644
--- a/extractor/scorer.cc
+++ b/extractor/scorer.cc
@@ -1,9 +1,22 @@
#include "scorer.h"
-Scorer::Scorer(const vector<Feature*>& features) : features(features) {}
+#include "features/feature.h"
-Scorer::~Scorer() {
- for (Feature* feature: features) {
- delete feature;
+// Copies the vector of shared feature pointers into the scorer.
+Scorer::Scorer(const vector<shared_ptr<Feature> >& features) :
+ features(features) {}
+
+// Calls Score(context) on every feature and returns the results in the same
+// order as the features vector.
+vector<double> Scorer::Score(const FeatureContext& context) const {
+ vector<double> scores;
+ for (auto feature: features) {
+ scores.push_back(feature->Score(context));
+ }
+ return scores;
+}
+
+// Collects GetName() from every feature, in the same order as the features
+// vector.
+vector<string> Scorer::GetFeatureNames() const {
+ vector<string> feature_names;
+ for (auto feature: features) {
+ feature_names.push_back(feature->GetName());
}
+ return feature_names;
}
diff --git a/extractor/scorer.h b/extractor/scorer.h
index 57405a6c..5b328fb4 100644
--- a/extractor/scorer.h
+++ b/extractor/scorer.h
@@ -1,19 +1,25 @@
#ifndef _SCORER_H_
#define _SCORER_H_
+#include <memory>
+#include <string>
#include <vector>
-#include "features/feature.h"
-
using namespace std;
+class Feature;
+class FeatureContext;
+
+// Applies a fixed list of features to a FeatureContext.
class Scorer {
public:
- Scorer(const vector<Feature*>& features);
- ~Scorer();
+ // The feature pointers are copied, so the scorer shares ownership of them.
+ Scorer(const vector<shared_ptr<Feature> >& features);
+
+ // Returns one score per feature for the given context.
+ vector<double> Score(const FeatureContext& context) const;
+
+ // Returns the name reported by each feature.
+ vector<string> GetFeatureNames() const;
private:
- vector<Feature*> features;
+ vector<shared_ptr<Feature> > features;
};
#endif
diff --git a/extractor/suffix_array.cc b/extractor/suffix_array.cc
index 76f00ace..d13eacd5 100644
--- a/extractor/suffix_array.cc
+++ b/extractor/suffix_array.cc
@@ -15,6 +15,8 @@ SuffixArray::SuffixArray(shared_ptr<DataArray> data_array) :
BuildSuffixArray();
}
+SuffixArray::SuffixArray() {}
+
SuffixArray::~SuffixArray() {}
void SuffixArray::BuildSuffixArray() {
diff --git a/extractor/suffix_array.h b/extractor/suffix_array.h
index 7708f5a2..79a22694 100644
--- a/extractor/suffix_array.h
+++ b/extractor/suffix_array.h
@@ -21,17 +21,20 @@ class SuffixArray {
virtual int GetSize() const;
- shared_ptr<DataArray> GetData() const;
+ virtual shared_ptr<DataArray> GetData() const;
- vector<int> BuildLCPArray() const;
+ virtual vector<int> BuildLCPArray() const;
- int GetSuffix(int rank) const;
+ virtual int GetSuffix(int rank) const;
virtual PhraseLocation Lookup(int low, int high, const string& word,
int offset) const;
void WriteBinary(const fs::path& filepath) const;
+ protected:
+ SuffixArray();
+
private:
void BuildSuffixArray();
diff --git a/extractor/translation_table.cc b/extractor/translation_table.cc
index 5eb4ffdc..10f1b9ed 100644
--- a/extractor/translation_table.cc
+++ b/extractor/translation_table.cc
@@ -13,17 +13,17 @@ using namespace tr1;
TranslationTable::TranslationTable(shared_ptr<DataArray> source_data_array,
shared_ptr<DataArray> target_data_array,
- const Alignment& alignment) :
+ shared_ptr<Alignment> alignment) :
source_data_array(source_data_array), target_data_array(target_data_array) {
const vector<int>& source_data = source_data_array->GetData();
const vector<int>& target_data = target_data_array->GetData();
unordered_map<int, int> source_links_count;
unordered_map<int, int> target_links_count;
- unordered_map<pair<int, int>, int, boost::hash<pair<int, int> > > links_count;
+ unordered_map<pair<int, int>, int, PairHash > links_count;
for (size_t i = 0; i < source_data_array->GetNumSentences(); ++i) {
- vector<pair<int, int> > links = alignment.GetLinks(i);
+ const vector<pair<int, int> >& links = alignment->GetLinks(i);
int source_start = source_data_array->GetSentenceStart(i);
int next_source_start = source_data_array->GetSentenceStart(i + 1);
int target_start = target_data_array->GetSentenceStart(i);
@@ -58,7 +58,7 @@ TranslationTable::TranslationTable(shared_ptr<DataArray> source_data_array,
}
}
-double TranslationTable::GetEgivenFScore(
+double TranslationTable::GetTargetGivenSourceScore(
const string& source_word, const string& target_word) {
if (!source_data_array->HasWord(source_word) ||
!target_data_array->HasWord(target_word)) {
@@ -70,7 +70,7 @@ double TranslationTable::GetEgivenFScore(
return translation_probabilities[make_pair(source_id, target_id)].first;
}
-double TranslationTable::GetFgivenEScore(
+double TranslationTable::GetSourceGivenTargetScore(
const string& source_word, const string& target_word) {
if (!source_data_array->HasWord(source_word) ||
!target_data_array->HasWord(target_word) == 0) {
diff --git a/extractor/translation_table.h b/extractor/translation_table.h
index 6004eca0..acf94af7 100644
--- a/extractor/translation_table.h
+++ b/extractor/translation_table.h
@@ -15,24 +15,28 @@ namespace fs = boost::filesystem;
class Alignment;
class DataArray;
+typedef boost::hash<pair<int, int> > PairHash;
+
class TranslationTable {
public:
TranslationTable(
shared_ptr<DataArray> source_data_array,
shared_ptr<DataArray> target_data_array,
- const Alignment& alignment);
+ shared_ptr<Alignment> alignment);
- double GetEgivenFScore(const string& source_word, const string& target_word);
+ double GetTargetGivenSourceScore(const string& source_word,
+ const string& target_word);
- double GetFgivenEScore(const string& source_word, const string& target_word);
+ double GetSourceGivenTargetScore(const string& source_word,
+ const string& target_word);
void WriteBinary(const fs::path& filepath) const;
private:
- shared_ptr<DataArray> source_data_array;
- shared_ptr<DataArray> target_data_array;
- unordered_map<pair<int, int>, pair<double, double>,
- boost::hash<pair<int, int> > > translation_probabilities;
+ shared_ptr<DataArray> source_data_array;
+ shared_ptr<DataArray> target_data_array;
+ unordered_map<pair<int, int>, pair<double, double>, PairHash>
+ translation_probabilities;
};
#endif
diff --git a/extractor/vocabulary.h b/extractor/vocabulary.h
index 05744269..ed55e5e4 100644
--- a/extractor/vocabulary.h
+++ b/extractor/vocabulary.h
@@ -12,7 +12,7 @@ class Vocabulary {
public:
virtual ~Vocabulary();
- int GetTerminalIndex(const string& word);
+ virtual int GetTerminalIndex(const string& word);
int GetNonterminalIndex(int position);