summaryrefslogtreecommitdiff
path: root/extractor
diff options
context:
space:
mode:
Diffstat (limited to 'extractor')
-rw-r--r--extractor/alignment.cc2
-rw-r--r--extractor/alignment.h6
-rw-r--r--extractor/compile.cc2
-rw-r--r--extractor/data_array.h36
-rw-r--r--extractor/fast_intersector.cc3
-rw-r--r--extractor/fast_intersector.h27
-rw-r--r--extractor/features/count_source_target.h3
-rw-r--r--extractor/features/feature.h6
-rw-r--r--extractor/features/is_source_singleton.h3
-rw-r--r--extractor/features/is_source_target_singleton.h3
-rw-r--r--extractor/features/max_lex_source_given_target.h3
-rw-r--r--extractor/features/max_lex_target_given_source.h3
-rw-r--r--extractor/features/sample_source_count.h4
-rw-r--r--extractor/features/target_given_source_coherent.h4
-rw-r--r--extractor/grammar.h3
-rw-r--r--extractor/grammar_extractor.h8
-rw-r--r--extractor/matchings_finder.h5
-rw-r--r--extractor/matchings_trie.h12
-rw-r--r--extractor/phrase.h10
-rw-r--r--extractor/phrase_builder.h5
-rw-r--r--extractor/phrase_location.h12
-rw-r--r--extractor/precomputation.cc11
-rw-r--r--extractor/precomputation.h21
-rw-r--r--extractor/rule.h3
-rw-r--r--extractor/rule_extractor.cc21
-rw-r--r--extractor/rule_extractor.h16
-rw-r--r--extractor/rule_extractor_helper.cc11
-rw-r--r--extractor/rule_extractor_helper.h13
-rw-r--r--extractor/rule_factory.cc13
-rw-r--r--extractor/rule_factory.h20
-rw-r--r--extractor/run_extractor.cc27
-rw-r--r--extractor/sampler.cc2
-rw-r--r--extractor/sampler.h5
-rw-r--r--extractor/scorer.h5
-rw-r--r--extractor/suffix_array.h17
-rw-r--r--extractor/target_phrase_extractor.cc10
-rw-r--r--extractor/target_phrase_extractor.h4
-rw-r--r--extractor/time_util.h1
-rw-r--r--extractor/translation_table.cc17
-rw-r--r--extractor/translation_table.h8
-rw-r--r--extractor/vocabulary.h17
41 files changed, 383 insertions, 19 deletions
diff --git a/extractor/alignment.cc b/extractor/alignment.cc
index f9bbcf6a..1aea34b3 100644
--- a/extractor/alignment.cc
+++ b/extractor/alignment.cc
@@ -28,8 +28,6 @@ Alignment::Alignment(const string& filename) {
}
alignments.push_back(alignment);
}
- // Note: shrink_to_fit does nothing for vector<vector<string> > on g++ 4.6.3,
- // but let's hope that the bug will be fixed in a newer version.
alignments.shrink_to_fit();
}
diff --git a/extractor/alignment.h b/extractor/alignment.h
index ef89dc0c..e9292121 100644
--- a/extractor/alignment.h
+++ b/extractor/alignment.h
@@ -11,12 +11,18 @@ using namespace std;
namespace extractor {
+/**
+ * Data structure storing the word alignments for a parallel corpus.
+ */
class Alignment {
public:
+ // Reads alignment from text file.
Alignment(const string& filename);
+ // Returns the alignment for a given sentence.
virtual vector<pair<int, int> > GetLinks(int sentence_index) const;
+ // Writes alignment to file in binary format.
void WriteBinary(const fs::path& filepath);
virtual ~Alignment();
diff --git a/extractor/compile.cc b/extractor/compile.cc
index 7062ef03..a9ae2cef 100644
--- a/extractor/compile.cc
+++ b/extractor/compile.cc
@@ -37,7 +37,7 @@ int main(int argc, char** argv) {
("max_phrase_len,p", po::value<int>()->default_value(4),
"Maximum frequent phrase length")
("min_frequency", po::value<int>()->default_value(1000),
- "Minimum number of occurences for a pharse to be considered frequent");
+ "Minimum number of occurrences for a pharse to be considered frequent");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
diff --git a/extractor/data_array.h b/extractor/data_array.h
index a26bbecf..978a6931 100644
--- a/extractor/data_array.h
+++ b/extractor/data_array.h
@@ -17,9 +17,19 @@ enum Side {
TARGET
};
-// Note: This class has features for both the source and target data arrays.
-// Maybe we can save some memory by having more specific implementations (e.g.
-// sentence_id is only needed for the source data array).
+/**
+ * Data structure storing information about a single side of a parallel corpus.
+ *
+ * Each word is mapped to a unique integer (word_id). The data structure holds
+ * the corpus in the numberized format, together with the hash table mapping
+ * words to word_ids. It also holds additional information such as the starting
+ * index for each sentence and, for each token, the index of the sentence it
+ * belongs to.
+ *
+ * Note: This class has features for both the source and target data arrays.
+ * Maybe we can save some memory by having more specific implementations (not
+ * likely to save a lot of memory tough).
+ */
class DataArray {
public:
static int NULL_WORD;
@@ -27,45 +37,65 @@ class DataArray {
static string NULL_WORD_STR;
static string END_OF_LINE_STR;
+ // Reads data array from text file.
DataArray(const string& filename);
+ // Reads data array from bitext file where the sentences are separated by |||.
DataArray(const string& filename, const Side& side);
virtual ~DataArray();
+ // Returns a vector containing the word ids.
virtual const vector<int>& GetData() const;
+ // Returns the word id at the specified position.
virtual int AtIndex(int index) const;
+ // Returns the original word at the specified position.
virtual string GetWordAtIndex(int index) const;
+ // Returns the size of the data array.
virtual int GetSize() const;
+ // Returns the number of distinct words in the data array.
virtual int GetVocabularySize() const;
+ // Returns whether a word has ever been observed in the data array.
virtual bool HasWord(const string& word) const;
+ // Returns the word id for a given word or -1 if it the word has never been
+ // observed.
virtual int GetWordId(const string& word) const;
+ // Returns the word corresponding to a particular word id.
virtual string GetWord(int word_id) const;
+ // Returns the number of sentences in the data.
virtual int GetNumSentences() const;
+ // Returns the index where the sentence containing the given position starts.
virtual int GetSentenceStart(int position) const;
+ // Returns the length of the sentence.
virtual int GetSentenceLength(int sentence_id) const;
+ // Returns the number of the sentence containing the given position.
virtual int GetSentenceId(int position) const;
+ // Writes data array to file in binary format.
void WriteBinary(const fs::path& filepath) const;
+ // Writes data array to file in binary format.
void WriteBinary(FILE* file) const;
protected:
DataArray();
private:
+ // Sets up specific constants.
void InitializeDataArray();
+
+ // Constructs the data array.
void CreateDataArray(const vector<string>& lines);
unordered_map<string, int> word2id;
diff --git a/extractor/fast_intersector.cc b/extractor/fast_intersector.cc
index 1b8c32b1..2a7693b2 100644
--- a/extractor/fast_intersector.cc
+++ b/extractor/fast_intersector.cc
@@ -107,6 +107,7 @@ PhraseLocation FastIntersector::ExtendPrefixPhraseLocation(
} else {
pattern_end += phrase.GetChunkLen(phrase.Arity()) - 2;
}
+ // Searches for the last symbol in the phrase after each prefix occurrence.
for (int j = range.first; j < range.second; ++j) {
if (pattern_end >= sent_end ||
pattern_end - positions[i] >= max_rule_span) {
@@ -149,6 +150,8 @@ PhraseLocation FastIntersector::ExtendSuffixPhraseLocation(
int pattern_start = positions[i] - range.first;
int pattern_end = positions[i + num_subpatterns - 1] +
phrase.GetChunkLen(phrase.Arity()) - 1;
+ // Searches for the first symbol in the phrase before each suffix
+ // occurrence.
for (int j = range.first; j < range.second; ++j) {
if (pattern_start < sent_start ||
pattern_end - pattern_start >= max_rule_span) {
diff --git a/extractor/fast_intersector.h b/extractor/fast_intersector.h
index 32c88a30..f950a2a9 100644
--- a/extractor/fast_intersector.h
+++ b/extractor/fast_intersector.h
@@ -20,6 +20,18 @@ class Precomputation;
class SuffixArray;
class Vocabulary;
+/**
+ * Component for searching the training data for occurrences of source phrases
+ * containing nonterminals
+ *
+ * Given a source phrase containing a nonterminal, we first query the
+ * precomputed index containing frequent collocations. If the phrase is not
+ * frequent enough, we extend the matchings of either its prefix or its suffix,
+ * depending on which operation seems to require less computations.
+ *
+ * Note: This method for intersecting phrase locations is faster than both
+ * mergers (linear or Baeza Yates) described in Adam Lopez' dissertation.
+ */
class FastIntersector {
public:
FastIntersector(shared_ptr<SuffixArray> suffix_array,
@@ -30,6 +42,8 @@ class FastIntersector {
virtual ~FastIntersector();
+ // Finds the locations of a phrase given the locations of its prefix and
+ // suffix.
virtual PhraseLocation Intersect(PhraseLocation& prefix_location,
PhraseLocation& suffix_location,
const Phrase& phrase);
@@ -38,23 +52,36 @@ class FastIntersector {
FastIntersector();
private:
+ // Uses the vocabulary to convert the phrase from the numberized format
+ // specified by the source data array to the numberized format given by the
+ // vocabulary.
vector<int> ConvertPhrase(const vector<int>& old_phrase);
+ // Estimates the number of computations needed if the prefix/suffix is
+ // extended. If the last/first symbol is separated from the rest of the phrase
+ // by a nonterminal, then for each occurrence of the prefix/suffix we need to
+ // check max_rule_span positions. Otherwise, we only need to check a single
+ // position for each occurrence.
int EstimateNumOperations(const PhraseLocation& phrase_location,
bool has_margin_x) const;
+ // Uses the occurrences of the prefix to find the occurrences of the phrase.
PhraseLocation ExtendPrefixPhraseLocation(PhraseLocation& prefix_location,
const Phrase& phrase,
bool prefix_ends_with_x,
int next_symbol) const;
+ // Uses the occurrences of the suffix to find the occurrences of the phrase.
PhraseLocation ExtendSuffixPhraseLocation(PhraseLocation& suffix_location,
const Phrase& phrase,
bool suffix_starts_with_x,
int prev_symbol) const;
+ // Extends the prefix/suffix location to a list of subpatterns positions if it
+ // represents a suffix array range.
void ExtendPhraseLocation(PhraseLocation& location) const;
+ // Returns the range in which the search should be performed.
pair<int, int> GetSearchRange(bool has_marginal_x) const;
shared_ptr<SuffixArray> suffix_array;
diff --git a/extractor/features/count_source_target.h b/extractor/features/count_source_target.h
index dec78883..8747fa60 100644
--- a/extractor/features/count_source_target.h
+++ b/extractor/features/count_source_target.h
@@ -6,6 +6,9 @@
namespace extractor {
namespace features {
+/**
+ * Feature for the number of times a word pair was found in the bitext.
+ */
class CountSourceTarget : public Feature {
public:
double Score(const FeatureContext& context) const;
diff --git a/extractor/features/feature.h b/extractor/features/feature.h
index 6693ccbf..36ea504a 100644
--- a/extractor/features/feature.h
+++ b/extractor/features/feature.h
@@ -10,6 +10,9 @@ using namespace std;
namespace extractor {
namespace features {
+/**
+ * Structure providing context for computing feature scores.
+ */
struct FeatureContext {
FeatureContext(const Phrase& source_phrase, const Phrase& target_phrase,
double source_phrase_count, int pair_count, int num_samples) :
@@ -24,6 +27,9 @@ struct FeatureContext {
int num_samples;
};
+/**
+ * Base class for features.
+ */
class Feature {
public:
virtual double Score(const FeatureContext& context) const = 0;
diff --git a/extractor/features/is_source_singleton.h b/extractor/features/is_source_singleton.h
index 30f76c6d..b8352d0e 100644
--- a/extractor/features/is_source_singleton.h
+++ b/extractor/features/is_source_singleton.h
@@ -6,6 +6,9 @@
namespace extractor {
namespace features {
+/**
+ * Boolean feature checking if the source phrase occurs only once in the data.
+ */
class IsSourceSingleton : public Feature {
public:
double Score(const FeatureContext& context) const;
diff --git a/extractor/features/is_source_target_singleton.h b/extractor/features/is_source_target_singleton.h
index 12fb6ee6..dacfebba 100644
--- a/extractor/features/is_source_target_singleton.h
+++ b/extractor/features/is_source_target_singleton.h
@@ -6,6 +6,9 @@
namespace extractor {
namespace features {
+/**
+ * Boolean feature checking if the phrase pair occurs only once in the data.
+ */
class IsSourceTargetSingleton : public Feature {
public:
double Score(const FeatureContext& context) const;
diff --git a/extractor/features/max_lex_source_given_target.h b/extractor/features/max_lex_source_given_target.h
index bfa7ef1b..461b0ebf 100644
--- a/extractor/features/max_lex_source_given_target.h
+++ b/extractor/features/max_lex_source_given_target.h
@@ -13,6 +13,9 @@ class TranslationTable;
namespace features {
+/**
+ * Feature computing max(p(f | e)) across all pairs of words in the phrase pair.
+ */
class MaxLexSourceGivenTarget : public Feature {
public:
MaxLexSourceGivenTarget(shared_ptr<TranslationTable> table);
diff --git a/extractor/features/max_lex_target_given_source.h b/extractor/features/max_lex_target_given_source.h
index 66cf0914..c3c87327 100644
--- a/extractor/features/max_lex_target_given_source.h
+++ b/extractor/features/max_lex_target_given_source.h
@@ -13,6 +13,9 @@ class TranslationTable;
namespace features {
+/**
+ * Feature computing max(p(e | f)) across all pairs of words in the phrase pair.
+ */
class MaxLexTargetGivenSource : public Feature {
public:
MaxLexTargetGivenSource(shared_ptr<TranslationTable> table);
diff --git a/extractor/features/sample_source_count.h b/extractor/features/sample_source_count.h
index 53c7f954..ee6e59a0 100644
--- a/extractor/features/sample_source_count.h
+++ b/extractor/features/sample_source_count.h
@@ -6,6 +6,10 @@
namespace extractor {
namespace features {
+/**
+ * Feature scoring the number of times the source phrase occurs in the sampled
+ * set.
+ */
class SampleSourceCount : public Feature {
public:
double Score(const FeatureContext& context) const;
diff --git a/extractor/features/target_given_source_coherent.h b/extractor/features/target_given_source_coherent.h
index 80d9f617..e66d70a5 100644
--- a/extractor/features/target_given_source_coherent.h
+++ b/extractor/features/target_given_source_coherent.h
@@ -6,6 +6,10 @@
namespace extractor {
namespace features {
+/**
+ * Feature computing the ratio of the phrase pair count over all source phrase
+ * occurrences (sampled).
+ */
class TargetGivenSourceCoherent : public Feature {
public:
double Score(const FeatureContext& context) const;
diff --git a/extractor/grammar.h b/extractor/grammar.h
index a424d65a..fed41b16 100644
--- a/extractor/grammar.h
+++ b/extractor/grammar.h
@@ -11,6 +11,9 @@ namespace extractor {
class Rule;
+/**
+ * Grammar class wrapping the set of rules to be extracted.
+ */
class Grammar {
public:
Grammar(const vector<Rule>& rules, const vector<string>& feature_names);
diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h
index 6b1dcf98..b36ceeb9 100644
--- a/extractor/grammar_extractor.h
+++ b/extractor/grammar_extractor.h
@@ -19,6 +19,10 @@ class Scorer;
class SuffixArray;
class Vocabulary;
+/**
+ * Class wrapping all the logic for extracting the synchronous context free
+ * grammars.
+ */
class GrammarExtractor {
public:
GrammarExtractor(
@@ -38,11 +42,15 @@ class GrammarExtractor {
GrammarExtractor(shared_ptr<Vocabulary> vocabulary,
shared_ptr<HieroCachingRuleFactory> rule_factory);
+ // Converts the sentence to a vector of word ids and uses the RuleFactory to
+ // extract the SCFG rules which may be used to decode the sentence.
Grammar GetGrammar(const string& sentence);
private:
+ // Splits the sentence in a vector of words.
vector<string> TokenizeSentence(const string& sentence);
+ // Maps the words to word ids.
vector<int> AnnotateWords(const vector<string>& words);
shared_ptr<Vocabulary> vocabulary;
diff --git a/extractor/matchings_finder.h b/extractor/matchings_finder.h
index fbb504ef..451f4a4c 100644
--- a/extractor/matchings_finder.h
+++ b/extractor/matchings_finder.h
@@ -11,12 +11,17 @@ namespace extractor {
class PhraseLocation;
class SuffixArray;
+/**
+ * Class wrapping the suffix array lookup for a contiguous phrase.
+ */
class MatchingsFinder {
public:
MatchingsFinder(shared_ptr<SuffixArray> suffix_array);
virtual ~MatchingsFinder();
+ // Uses the suffix array to search only for the last word of the phrase
+ // starting from the range in which the prefix of the phrase occurs.
virtual PhraseLocation Find(PhraseLocation& location, const string& word,
int offset);
diff --git a/extractor/matchings_trie.h b/extractor/matchings_trie.h
index f3dcc075..1fb29693 100644
--- a/extractor/matchings_trie.h
+++ b/extractor/matchings_trie.h
@@ -11,20 +11,27 @@ using namespace std;
namespace extractor {
+/**
+ * Trie node containing all the occurrences of the corresponding phrase in the
+ * source data.
+ */
struct TrieNode {
TrieNode(shared_ptr<TrieNode> suffix_link = shared_ptr<TrieNode>(),
Phrase phrase = Phrase(),
PhraseLocation matchings = PhraseLocation()) :
suffix_link(suffix_link), phrase(phrase), matchings(matchings) {}
+ // Adds a trie node as a child of the current node.
void AddChild(int key, shared_ptr<TrieNode> child_node) {
children[key] = child_node;
}
+ // Checks if a child exists for a given key.
bool HasChild(int key) {
return children.count(key);
}
+ // Gets the child corresponding to the given key.
shared_ptr<TrieNode> GetChild(int key) {
return children[key];
}
@@ -35,15 +42,20 @@ struct TrieNode {
unordered_map<int, shared_ptr<TrieNode> > children;
};
+/**
+ * Trie containing all the phrases that can be obtained from a sentence.
+ */
class MatchingsTrie {
public:
MatchingsTrie();
virtual ~MatchingsTrie();
+ // Returns the root of the trie.
shared_ptr<TrieNode> GetRoot() const;
private:
+ // Recursively deletes a subtree of the trie.
void DeleteTree(shared_ptr<TrieNode> root);
shared_ptr<TrieNode> root;
diff --git a/extractor/phrase.h b/extractor/phrase.h
index 6521c438..a8e91e3c 100644
--- a/extractor/phrase.h
+++ b/extractor/phrase.h
@@ -11,20 +11,30 @@ using namespace std;
namespace extractor {
+/**
+ * Structure containing the data for a phrase.
+ */
class Phrase {
public:
friend Phrase PhraseBuilder::Build(const vector<int>& phrase);
+ // Returns the number of nonterminals in the phrase.
int Arity() const;
+ // Returns the number of terminals (length) for the given chunk. (A chunk is a
+ // contiguous sequence of terminals in the phrase).
int GetChunkLen(int index) const;
+ // Returns the symbols (word ids) marking up the phrase.
vector<int> Get() const;
+ // Returns the symbol located at the given position in the phrase.
int GetSymbol(int position) const;
+ // Returns the number of symbols in the phrase.
int GetNumSymbols() const;
+ // Returns the words making up the phrase. (Nonterminals are stripped out.)
vector<string> GetWords() const;
bool operator<(const Phrase& other) const;
diff --git a/extractor/phrase_builder.h b/extractor/phrase_builder.h
index 2956fd35..de86dbae 100644
--- a/extractor/phrase_builder.h
+++ b/extractor/phrase_builder.h
@@ -11,12 +11,17 @@ namespace extractor {
class Phrase;
class Vocabulary;
+/**
+ * Component for constructing phrases.
+ */
class PhraseBuilder {
public:
PhraseBuilder(shared_ptr<Vocabulary> vocabulary);
+ // Constructs a phrase starting from an array of symbols.
Phrase Build(const vector<int>& symbols);
+ // Extends a phrase with a leading and/or trailing nonterminal.
Phrase Extend(const Phrase& phrase, bool start_x, bool end_x);
private:
diff --git a/extractor/phrase_location.h b/extractor/phrase_location.h
index e5f3cf08..91950e03 100644
--- a/extractor/phrase_location.h
+++ b/extractor/phrase_location.h
@@ -8,13 +8,25 @@ using namespace std;
namespace extractor {
+/**
+ * Structure containing information about the occurrences of a phrase in the
+ * source data.
+ *
+ * Every consecutive (disjoint) group of num_subpatterns entries in matchings
+ * vector encodes an occurrence of the phrase. The i-th entry of a group
+ * represents the start of the i-th subpattern of the phrase. If the phrase
+ * doesn't contain any nonterminals, then it may also be represented as the
+ * range in the suffix array which matches the phrase.
+ */
struct PhraseLocation {
PhraseLocation(int sa_low = -1, int sa_high = -1);
PhraseLocation(const vector<int>& matchings, int num_subpatterns);
+ // Checks if a phrase has any occurrences in the source data.
bool IsEmpty() const;
+ // Returns the number of occurrences of a phrase in the source data.
int GetSize() const;
friend bool operator==(const PhraseLocation& a, const PhraseLocation& b);
diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc
index 0fadc95c..b3906943 100644
--- a/extractor/precomputation.cc
+++ b/extractor/precomputation.cc
@@ -23,6 +23,8 @@ Precomputation::Precomputation(
suffix_array, data, num_frequent_patterns, max_frequent_phrase_len,
min_frequency);
+ // Construct sets containing the frequent and superfrequent contiguous
+ // collocations.
unordered_set<vector<int>, VectorHash> frequent_patterns_set;
unordered_set<vector<int>, VectorHash> super_frequent_patterns_set;
for (size_t i = 0; i < frequent_patterns.size(); ++i) {
@@ -34,6 +36,8 @@ Precomputation::Precomputation(
vector<tuple<int, int, int> > matchings;
for (size_t i = 0; i < data.size(); ++i) {
+ // If the sentence is over, add all the discontiguous frequent patterns to
+ // the index.
if (data[i] == DataArray::END_OF_LINE) {
AddCollocations(matchings, data, max_rule_span, min_gap_size,
max_rule_symbols);
@@ -41,6 +45,7 @@ Precomputation::Precomputation(
continue;
}
vector<int> pattern;
+ // Find all the contiguous frequent patterns starting at position i.
for (int j = 1; j <= max_frequent_phrase_len && i + j <= data.size(); ++j) {
pattern.push_back(data[i + j - 1]);
if (frequent_patterns_set.count(pattern)) {
@@ -65,6 +70,7 @@ vector<vector<int> > Precomputation::FindMostFrequentPatterns(
vector<int> lcp = suffix_array->BuildLCPArray();
vector<int> run_start(max_frequent_phrase_len);
+ // Find all the patterns occurring at least min_frequency times.
priority_queue<pair<int, pair<int, int> > > heap;
for (size_t i = 1; i < lcp.size(); ++i) {
for (int len = lcp[i]; len < max_frequent_phrase_len; ++len) {
@@ -77,6 +83,7 @@ vector<vector<int> > Precomputation::FindMostFrequentPatterns(
}
}
+ // Extract the most frequent patterns.
vector<vector<int> > frequent_patterns;
while (frequent_patterns.size() < num_frequent_patterns && !heap.empty()) {
int start = heap.top().second.first;
@@ -95,10 +102,12 @@ vector<vector<int> > Precomputation::FindMostFrequentPatterns(
void Precomputation::AddCollocations(
const vector<tuple<int, int, int> >& matchings, const vector<int>& data,
int max_rule_span, int min_gap_size, int max_rule_symbols) {
+ // Select the leftmost subpattern.
for (size_t i = 0; i < matchings.size(); ++i) {
int start1, size1, is_super1;
tie(start1, size1, is_super1) = matchings[i];
+ // Select the second (middle) subpattern
for (size_t j = i + 1; j < matchings.size(); ++j) {
int start2, size2, is_super2;
tie(start2, size2, is_super2) = matchings[j];
@@ -116,8 +125,10 @@ void Precomputation::AddCollocations(
data.begin() + start2 + size2);
AddStartPositions(collocations[pattern], start1, start2);
+ // Try extending the binary collocation to a ternary collocation.
if (is_super2) {
pattern.push_back(Precomputation::SECOND_NONTERMINAL);
+ // Select the rightmost subpattern.
for (size_t k = j + 1; k < matchings.size(); ++k) {
int start3, size3, is_super3;
tie(start3, size3, is_super3) = matchings[k];
diff --git a/extractor/precomputation.h b/extractor/precomputation.h
index 2c1eccf8..e3c4d26a 100644
--- a/extractor/precomputation.h
+++ b/extractor/precomputation.h
@@ -20,8 +20,19 @@ typedef unordered_map<vector<int>, vector<int>, VectorHash> Index;
class SuffixArray;
+/**
+ * Data structure wrapping an index with all the occurrences of the most
+ * frequent discontiguous collocations in the source data.
+ *
+ * Let a, b, c be contiguous collocations. The index will contain an entry for
+ * every collocation of the form:
+ * - aXb, where a and b are frequent
+ * - aXbXc, where a and b are super-frequent and c is frequent or
+ * b and c are super-frequent and a is frequent.
+ */
class Precomputation {
public:
+ // Constructs the index using the suffix array.
Precomputation(
shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns,
int num_super_frequent_patterns, int max_rule_span,
@@ -32,6 +43,7 @@ class Precomputation {
void WriteBinary(const fs::path& filepath) const;
+ // Returns a reference to the index.
virtual const Index& GetCollocations() const;
static int FIRST_NONTERMINAL;
@@ -41,14 +53,23 @@ class Precomputation {
Precomputation();
private:
+ // Finds the most frequent contiguous collocations.
vector<vector<int> > FindMostFrequentPatterns(
shared_ptr<SuffixArray> suffix_array, const vector<int>& data,
int num_frequent_patterns, int max_frequent_phrase_len,
int min_frequency);
+
+ // Given the locations of the frequent contiguous collocations in a sentence,
+ // it adds new entries to the index for each discontiguous collocation
+ // matching the criteria specified in the class description.
void AddCollocations(
const vector<std::tuple<int, int, int> >& matchings, const vector<int>& data,
int max_rule_span, int min_gap_size, int max_rule_symbols);
+
+ // Adds an occurrence of a binary collocation.
void AddStartPositions(vector<int>& positions, int pos1, int pos2);
+
+ // Adds an occurrence of a ternary collocation.
void AddStartPositions(vector<int>& positions, int pos1, int pos2, int pos3);
Index collocations;
diff --git a/extractor/rule.h b/extractor/rule.h
index b4d45fc1..bc95709e 100644
--- a/extractor/rule.h
+++ b/extractor/rule.h
@@ -9,6 +9,9 @@ using namespace std;
namespace extractor {
+/**
+ * Structure containing the data for a SCFG rule.
+ */
struct Rule {
Rule(const Phrase& source_phrase, const Phrase& target_phrase,
const vector<double>& scores, const vector<pair<int, int> >& alignment);
diff --git a/extractor/rule_extractor.cc b/extractor/rule_extractor.cc
index b9286472..9f5e8e00 100644
--- a/extractor/rule_extractor.cc
+++ b/extractor/rule_extractor.cc
@@ -79,6 +79,7 @@ vector<Rule> RuleExtractor::ExtractRules(const Phrase& phrase,
int num_subpatterns = location.num_subpatterns;
vector<int> matchings = *location.matchings;
+ // Calculate statistics for the (sampled) occurrences of the source phrase.
map<Phrase, double> source_phrase_counter;
map<Phrase, map<Phrase, map<PhraseAlignment, int> > > alignments_counter;
for (auto i = matchings.begin(); i != matchings.end(); i += num_subpatterns) {
@@ -91,6 +92,8 @@ vector<Rule> RuleExtractor::ExtractRules(const Phrase& phrase,
}
}
+ // Compute the feature scores and find the most likely (frequent) alignment
+ // for each pair of source-target phrases.
int num_samples = matchings.size() / num_subpatterns;
vector<Rule> rules;
for (auto source_phrase_entry: alignments_counter) {
@@ -124,6 +127,8 @@ vector<Extract> RuleExtractor::ExtractAlignments(
int sentence_id = source_data_array->GetSentenceId(matching[0]);
int source_sent_start = source_data_array->GetSentenceStart(sentence_id);
+ // Get the span in the opposite sentence for each word in the source-target
+ // sentece pair.
vector<int> source_low, source_high, target_low, target_high;
helper->GetLinksSpans(source_low, source_high, target_low, target_high,
sentence_id);
@@ -134,6 +139,7 @@ vector<Extract> RuleExtractor::ExtractAlignments(
chunklen[i] = phrase.GetChunkLen(i);
}
+ // Basic checks to see if we can extract phrase pairs for this occurrence.
if (!helper->CheckAlignedTerminals(matching, chunklen, source_low) ||
!helper->CheckTightPhrases(matching, chunklen, source_low)) {
return extracts;
@@ -144,6 +150,7 @@ vector<Extract> RuleExtractor::ExtractAlignments(
int source_phrase_high = matching.back() + chunklen.back() -
source_sent_start;
int target_phrase_low = -1, target_phrase_high = -1;
+ // Find target span and reflected source span for the source phrase.
if (!helper->FindFixPoint(source_phrase_low, source_phrase_high, source_low,
source_high, target_phrase_low, target_phrase_high,
target_low, target_high, source_back_low,
@@ -153,6 +160,7 @@ vector<Extract> RuleExtractor::ExtractAlignments(
return extracts;
}
+ // Get spans for nonterminal gaps.
bool met_constraints = true;
int num_symbols = phrase.GetNumSymbols();
vector<pair<int, int> > source_gaps, target_gaps;
@@ -163,6 +171,7 @@ vector<Extract> RuleExtractor::ExtractAlignments(
return extracts;
}
+ // Find target phrases aligned with the initial source phrase.
bool starts_with_x = source_back_low != source_phrase_low;
bool ends_with_x = source_back_high != source_phrase_high;
Phrase source_phrase = phrase_builder->Extend(
@@ -181,6 +190,8 @@ vector<Extract> RuleExtractor::ExtractAlignments(
return extracts;
}
+ // Extend the source phrase by adding a leading and/or trailing nonterminal
+ // and find target phrases aligned with the extended source phrase.
for (int i = 0; i < 2; ++i) {
for (int j = 1 - i; j < 2; ++j) {
AddNonterminalExtremities(extracts, matching, chunklen, source_phrase,
@@ -203,6 +214,8 @@ void RuleExtractor::AddExtracts(
source_indexes, sentence_id);
if (target_phrases.size() > 0) {
+ // Split the probability equally across all target phrases that can be
+ // aligned with a single occurrence of the source phrase.
double pairs_count = 1.0 / target_phrases.size();
for (auto target_phrase: target_phrases) {
extracts.push_back(Extract(source_phrase, target_phrase.first,
@@ -221,6 +234,7 @@ void RuleExtractor::AddNonterminalExtremities(
int extend_right) const {
int source_x_low = source_back_low, source_x_high = source_back_high;
+ // Check if the extended source phrase will remain tight.
if (require_tight_phrases) {
if (source_low[source_back_low - extend_left] == -1 ||
source_low[source_back_high + extend_right - 1] == -1) {
@@ -228,6 +242,7 @@ void RuleExtractor::AddNonterminalExtremities(
}
}
+ // Check if we can add a nonterminal to the left.
if (extend_left) {
if (starts_with_x || source_back_low < min_gap_size) {
return;
@@ -244,6 +259,7 @@ void RuleExtractor::AddNonterminalExtremities(
}
}
+ // Check if we can add a nonterminal to the right.
if (extend_right) {
int source_sent_len = source_data_array->GetSentenceLength(sentence_id);
if (ends_with_x || source_back_high + min_gap_size > source_sent_len) {
@@ -262,6 +278,7 @@ void RuleExtractor::AddNonterminalExtremities(
}
}
+ // More length checks.
int new_nonterminals = extend_left + extend_right;
if (source_x_high - source_x_low > max_rule_span ||
target_gaps.size() + new_nonterminals > max_nonterminals ||
@@ -269,6 +286,7 @@ void RuleExtractor::AddNonterminalExtremities(
return;
}
+ // Find the target span for the extended phrase and the reflected source span.
int target_x_low = -1, target_x_high = -1;
if (!helper->FindFixPoint(source_x_low, source_x_high, source_low,
source_high, target_x_low, target_x_high,
@@ -279,6 +297,7 @@ void RuleExtractor::AddNonterminalExtremities(
return;
}
+ // Check gap integrity for the leading nonterminal.
if (extend_left) {
int source_gap_low = -1, source_gap_high = -1;
int target_gap_low = -1, target_gap_high = -1;
@@ -294,6 +313,7 @@ void RuleExtractor::AddNonterminalExtremities(
make_pair(target_gap_low, target_gap_high));
}
+ // Check gap integrity for the trailing nonterminal.
if (extend_right) {
int target_gap_low = -1, target_gap_high = -1;
int source_gap_low = -1, source_gap_high = -1;
@@ -308,6 +328,7 @@ void RuleExtractor::AddNonterminalExtremities(
target_gaps.push_back(make_pair(target_gap_low, target_gap_high));
}
+ // Find target phrases aligned with the extended source phrase.
Phrase new_source_phrase = phrase_builder->Extend(source_phrase, extend_left,
extend_right);
unordered_map<int, int> source_indexes = helper->GetSourceIndexes(
diff --git a/extractor/rule_extractor.h b/extractor/rule_extractor.h
index 8b6daeea..bfec0225 100644
--- a/extractor/rule_extractor.h
+++ b/extractor/rule_extractor.h
@@ -22,6 +22,10 @@ class RuleExtractorHelper;
class Scorer;
class TargetPhraseExtractor;
+/**
+ * Structure containing data about the occurrences of a source-target phrase pair
+ * in the parallel corpus.
+ */
struct Extract {
Extract(const Phrase& source_phrase, const Phrase& target_phrase,
double pairs_count, const PhraseAlignment& alignment) :
@@ -34,6 +38,9 @@ struct Extract {
PhraseAlignment alignment;
};
+/**
+ * Component for extracting SCFG rules.
+ */
class RuleExtractor {
public:
RuleExtractor(shared_ptr<DataArray> source_data_array,
@@ -64,6 +71,8 @@ class RuleExtractor {
virtual ~RuleExtractor();
+ // Extracts SCFG rules given a source phrase and a set of its occurrences
+ // in the source data.
virtual vector<Rule> ExtractRules(const Phrase& phrase,
const PhraseLocation& location) const;
@@ -71,15 +80,22 @@ class RuleExtractor {
RuleExtractor();
private:
+ // Finds all target phrases that can be aligned with the source phrase for a
+ // particular occurrence in the data.
vector<Extract> ExtractAlignments(const Phrase& phrase,
const vector<int>& matching) const;
+ // Extracts all target phrases for a given occurrence of the source phrase in
+ // the data. Constructs a vector of Extracts using these target phrases.
void AddExtracts(
vector<Extract>& extracts, const Phrase& source_phrase,
const unordered_map<int, int>& source_indexes,
const vector<pair<int, int> >& target_gaps, const vector<int>& target_low,
int target_phrase_low, int target_phrase_high, int sentence_id) const;
+ // Adds a leading and/or trailing nonterminal to the source phrase and
+ // extracts target phrases that can be aligned with the extended source
+ // phrase.
void AddNonterminalExtremities(
vector<Extract>& extracts, const vector<int>& matching,
const vector<int>& chunklen, const Phrase& source_phrase,
diff --git a/extractor/rule_extractor_helper.cc b/extractor/rule_extractor_helper.cc
index 81b522f0..6410d147 100644
--- a/extractor/rule_extractor_helper.cc
+++ b/extractor/rule_extractor_helper.cc
@@ -88,6 +88,7 @@ bool RuleExtractorHelper::CheckTightPhrases(
return true;
}
+ // Check if the chunk extremities are aligned.
int sentence_id = source_data_array->GetSentenceId(matching[0]);
int source_sent_start = source_data_array->GetSentenceStart(sentence_id);
for (size_t i = 0; i + 1 < chunklen.size(); ++i) {
@@ -126,6 +127,7 @@ bool RuleExtractorHelper::FindFixPoint(
int source_sent_len = source_data_array->GetSentenceLength(sentence_id);
int target_sent_len = target_data_array->GetSentenceLength(sentence_id);
+ // Extend the target span to the left.
if (prev_target_low != -1 && target_phrase_low != prev_target_low) {
if (prev_target_low - target_phrase_low < min_target_gap_size) {
target_phrase_low = prev_target_low - min_target_gap_size;
@@ -135,6 +137,7 @@ bool RuleExtractorHelper::FindFixPoint(
}
}
+ // Extend the target span to the right.
if (prev_target_high != -1 && target_phrase_high != prev_target_high) {
if (target_phrase_high - prev_target_high < min_target_gap_size) {
target_phrase_high = prev_target_high + min_target_gap_size;
@@ -144,10 +147,12 @@ bool RuleExtractorHelper::FindFixPoint(
}
}
+ // Check target span length.
if (target_phrase_high - target_phrase_low > max_rule_span) {
return false;
}
+ // Find the initial reflected source span.
source_back_low = source_back_high = -1;
FindProjection(target_phrase_low, target_phrase_high, target_low, target_high,
source_back_low, source_back_high);
@@ -157,6 +162,7 @@ bool RuleExtractorHelper::FindFixPoint(
source_back_low = min(source_back_low, source_phrase_low);
source_back_high = max(source_back_high, source_phrase_high);
+ // Stop if the reflected source span matches the previous source span.
if (source_back_low == source_phrase_low &&
source_back_high == source_phrase_high) {
return true;
@@ -212,10 +218,14 @@ bool RuleExtractorHelper::FindFixPoint(
prev_target_low = target_phrase_low;
prev_target_high = target_phrase_high;
+ // Find the reflection including the left gap (if one was added).
FindProjection(source_back_low, source_phrase_low, source_low, source_high,
target_phrase_low, target_phrase_high);
+ // Find the reflection including the right gap (if one was added).
FindProjection(source_phrase_high, source_back_high, source_low,
source_high, target_phrase_low, target_phrase_high);
+ // Stop if the new re-reflected target span matches the previous target
+ // span.
if (prev_target_low == target_phrase_low &&
prev_target_high == target_phrase_high) {
return true;
@@ -232,6 +242,7 @@ bool RuleExtractorHelper::FindFixPoint(
source_phrase_low = source_back_low;
source_phrase_high = source_back_high;
+ // Re-reflect the target span.
FindProjection(target_phrase_low, prev_target_low, target_low, target_high,
source_back_low, source_back_high);
FindProjection(prev_target_high, target_phrase_high, target_low,
diff --git a/extractor/rule_extractor_helper.h b/extractor/rule_extractor_helper.h
index 7bf80c4b..bea75bc3 100644
--- a/extractor/rule_extractor_helper.h
+++ b/extractor/rule_extractor_helper.h
@@ -12,6 +12,9 @@ namespace extractor {
class Alignment;
class DataArray;
+/**
+ * Helper class for extracting SCFG rules.
+ */
class RuleExtractorHelper {
public:
RuleExtractorHelper(shared_ptr<DataArray> source_data_array,
@@ -25,18 +28,23 @@ class RuleExtractorHelper {
virtual ~RuleExtractorHelper();
+ // Find the alignment span for each word in the source target sentence pair.
virtual void GetLinksSpans(vector<int>& source_low, vector<int>& source_high,
vector<int>& target_low, vector<int>& target_high,
int sentence_id) const;
+ // Check if one chunk (all chunks) is aligned at least in one point.
virtual bool CheckAlignedTerminals(const vector<int>& matching,
const vector<int>& chunklen,
const vector<int>& source_low) const;
+ // Check if the chunks are tight.
virtual bool CheckTightPhrases(const vector<int>& matching,
const vector<int>& chunklen,
const vector<int>& source_low) const;
+ // Find the target span and the reflected source span for a source phrase
+ // occurrence.
virtual bool FindFixPoint(
int source_phrase_low, int source_phrase_high,
const vector<int>& source_low, const vector<int>& source_high,
@@ -47,6 +55,7 @@ class RuleExtractorHelper {
int max_new_x, bool allow_low_x, bool allow_high_x,
bool allow_arbitrary_expansion) const;
+ // Find the gap spans for each nonterminal in the source phrase.
virtual bool GetGaps(
vector<pair<int, int> >& source_gaps, vector<pair<int, int> >& target_gaps,
const vector<int>& matching, const vector<int>& chunklen,
@@ -55,8 +64,10 @@ class RuleExtractorHelper {
int source_phrase_low, int source_phrase_high, int source_back_low,
int source_back_high, int& num_symbols, bool& met_constraints) const;
+ // Get the order of the nonterminals in the target phrase.
virtual vector<int> GetGapOrder(const vector<pair<int, int> >& gaps) const;
+ // Map each terminal symbol with its position in the source phrase.
virtual unordered_map<int, int> GetSourceIndexes(
const vector<int>& matching, const vector<int>& chunklen,
int starts_with_x) const;
@@ -65,6 +76,8 @@ class RuleExtractorHelper {
RuleExtractorHelper();
private:
+ // Find the projection of a source phrase in the target sentence. May also be
+ // used to find the projection of a target phrase in the source sentence.
void FindProjection(
int source_phrase_low, int source_phrase_high,
const vector<int>& source_low, const vector<int>& source_high,
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc
index fbc62e50..8c30fb9e 100644
--- a/extractor/rule_factory.cc
+++ b/extractor/rule_factory.cc
@@ -152,12 +152,18 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
} else {
PhraseLocation phrase_location;
if (next_phrase.Arity() > 0) {
+ // For phrases containing a nonterminal, we use either the occurrences
+ // of the prefix or the suffix to determine the occurrences of the
+ // phrase.
Clock::time_point intersect_start = Clock::now();
phrase_location = fast_intersector->Intersect(
node->matchings, next_suffix_link->matchings, next_phrase);
Clock::time_point intersect_stop = Clock::now();
total_intersect_time += GetDuration(intersect_start, intersect_stop);
} else {
+ // For phrases not containing any nonterminals, we simply query the
+ // suffix array using the suffix array range of the prefix as a
+ // starting point.
Clock::time_point lookup_start = Clock::now();
phrase_location = matchings_finder->Find(
node->matchings,
@@ -170,9 +176,12 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
if (phrase_location.IsEmpty()) {
continue;
}
+
+ // Create new trie node to store data about the current phrase.
next_node = make_shared<TrieNode>(
next_suffix_link, next_phrase, phrase_location);
}
+ // Add the new trie node to the trie cache.
node->AddChild(word_id, next_node);
// Automatically adds a trailing non terminal if allowed. Simply copy the
@@ -182,6 +191,7 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
Clock::time_point extract_start = Clock::now();
if (!state.starts_with_x) {
+ // Extract rules for the sampled set of occurrences.
PhraseLocation sample = sampler->Sample(next_node->matchings);
vector<Rule> new_rules =
rule_extractor->ExtractRules(next_phrase, sample);
@@ -193,6 +203,7 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
next_node = node->GetChild(word_id);
}
+ // Create more states (phrases) to be analyzed.
vector<State> new_states = ExtendState(word_ids, state, phrase, next_phrase,
next_node);
for (State new_state: new_states) {
@@ -262,6 +273,7 @@ vector<State> HieroCachingRuleFactory::ExtendState(
return new_states;
}
+ // New state for adding the next symbol.
new_states.push_back(State(state.start, state.end + 1, symbols,
state.subpatterns_start, node, state.starts_with_x));
@@ -272,6 +284,7 @@ vector<State> HieroCachingRuleFactory::ExtendState(
return new_states;
}
+ // New states for adding a nonterminal followed by a new symbol.
int var_id = vocabulary->GetNonterminalIndex(phrase.Arity() + 1);
symbols.push_back(var_id);
vector<int> subpatterns_start = state.subpatterns_start;
diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h
index d8dc2ccc..52e8712a 100644
--- a/extractor/rule_factory.h
+++ b/extractor/rule_factory.h
@@ -25,6 +25,17 @@ class State;
class SuffixArray;
class Vocabulary;
+/**
+ * Component containing most of the logic for extracting SCFG rules for a given
+ * sentence.
+ *
+ * Given a sentence (as a vector of word ids), this class constructs all the
+ * possible source phrases starting from this sentence. For each source phrase,
+ * it finds all its occurrences in the source data and samples some of these
+ * occurrences to extract aligned source-target phrase pairs. A trie cache is
+ * used to avoid unnecessary computations if a source phrase can be constructed
+ * more than once (e.g. some words occur more than once in the sentence).
+ */
class HieroCachingRuleFactory {
public:
HieroCachingRuleFactory(
@@ -58,21 +69,30 @@ class HieroCachingRuleFactory {
virtual ~HieroCachingRuleFactory();
+ // Constructs SCFG rules for a given sentence.
+ // (See class description for more details.)
virtual Grammar GetGrammar(const vector<int>& word_ids);
protected:
HieroCachingRuleFactory();
private:
+ // Checks if the phrase (if previously encountered) or its prefix have any
+ // occurrences in the source data.
bool CannotHaveMatchings(shared_ptr<TrieNode> node, int word_id);
+ // Checks if the phrase has previously been analyzed.
bool RequiresLookup(shared_ptr<TrieNode> node, int word_id);
+ // Creates a new state in the trie that corresponds to adding a trailing
+ // nonterminal to the current phrase.
void AddTrailingNonterminal(vector<int> symbols,
const Phrase& prefix,
const shared_ptr<TrieNode>& prefix_node,
bool starts_with_x);
+ // Extends the current state by possibly adding a nonterminal followed by a
+ // terminal.
vector<State> ExtendState(const vector<int>& word_ids,
const State& state,
vector<int> symbols,
diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc
index dba4578c..d5ff23b2 100644
--- a/extractor/run_extractor.cc
+++ b/extractor/run_extractor.cc
@@ -35,6 +35,7 @@ using namespace std;
using namespace extractor;
using namespace features;
+// Returns the file path in which a given grammar should be written.
fs::path GetGrammarFilePath(const fs::path& grammar_path, int file_number) {
string file_name = "grammar." + to_string(file_number);
return grammar_path / file_name;
@@ -45,6 +46,7 @@ int main(int argc, char** argv) {
#pragma omp parallel
num_threads_default = omp_get_num_threads();
+ // Sets up the command line arguments map.
po::options_description desc("Command line options");
desc.add_options()
("help,h", "Show available options")
@@ -69,7 +71,7 @@ int main(int argc, char** argv) {
("max_nonterminals", po::value<int>()->default_value(2),
"Maximum number of nonterminals in a rule")
("min_frequency", po::value<int>()->default_value(1000),
- "Minimum number of occurences for a pharse to be considered frequent")
+ "Minimum number of occurrences for a pharse to be considered frequent")
("max_samples", po::value<int>()->default_value(300),
"Maximum number of samples")
("tight_phrases", po::value<bool>()->default_value(true),
@@ -78,8 +80,8 @@ int main(int argc, char** argv) {
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
- // Check for help argument before notify, so we don't need to pass in the
- // required parameters.
+ // Checks for the help option before calling notify, so the we don't get an
+ // exception for missing required arguments.
if (vm.count("help")) {
cout << desc << endl;
return 0;
@@ -94,6 +96,7 @@ int main(int argc, char** argv) {
return 1;
}
+ // Reads the parallel corpus.
Clock::time_point preprocess_start_time = Clock::now();
cerr << "Reading source and target data..." << endl;
Clock::time_point start_time = Clock::now();
@@ -111,6 +114,7 @@ int main(int argc, char** argv) {
cerr << "Reading data took " << GetDuration(start_time, stop_time)
<< " seconds" << endl;
+ // Constructs the suffix array for the source data.
cerr << "Creating source suffix array..." << endl;
start_time = Clock::now();
shared_ptr<SuffixArray> source_suffix_array =
@@ -119,6 +123,7 @@ int main(int argc, char** argv) {
cerr << "Creating suffix array took "
<< GetDuration(start_time, stop_time) << " seconds" << endl;
+ // Reads the alignment.
cerr << "Reading alignment..." << endl;
start_time = Clock::now();
shared_ptr<Alignment> alignment =
@@ -127,6 +132,8 @@ int main(int argc, char** argv) {
cerr << "Reading alignment took "
<< GetDuration(start_time, stop_time) << " seconds" << endl;
+ // Constructs an index storing the occurrences in the source data for each
+ // frequent collocation.
cerr << "Precomputing collocations..." << endl;
start_time = Clock::now();
shared_ptr<Precomputation> precomputation = make_shared<Precomputation>(
@@ -142,6 +149,8 @@ int main(int argc, char** argv) {
cerr << "Precomputing collocations took "
<< GetDuration(start_time, stop_time) << " seconds" << endl;
+ // Constructs a table storing p(e | f) and p(f | e) for every pair of source
+ // and target words.
cerr << "Precomputing conditional probabilities..." << endl;
start_time = Clock::now();
shared_ptr<TranslationTable> table = make_shared<TranslationTable>(
@@ -155,6 +164,7 @@ int main(int argc, char** argv) {
<< GetDuration(preprocess_start_time, preprocess_stop_time)
<< " seconds" << endl;
+ // Features used to score each grammar rule.
Clock::time_point extraction_start_time = Clock::now();
vector<shared_ptr<Feature> > features = {
make_shared<TargetGivenSourceCoherent>(),
@@ -167,6 +177,7 @@ int main(int argc, char** argv) {
};
shared_ptr<Scorer> scorer = make_shared<Scorer>(features);
+ // Sets up the grammar extractor.
GrammarExtractor extractor(
source_suffix_array,
target_data_array,
@@ -180,26 +191,30 @@ int main(int argc, char** argv) {
vm["max_samples"].as<int>(),
vm["tight_phrases"].as<bool>());
- // Release extra memory used by the initial precomputation.
+ // Releases extra memory used by the initial precomputation.
precomputation.reset();
+ // Creates the grammars directory if it doesn't exist.
fs::path grammar_path = vm["grammars"].as<string>();
if (!fs::is_directory(grammar_path)) {
fs::create_directory(grammar_path);
}
+ // Reads all sentences for which we extract grammar rules (the paralellization
+ // is simplified if we read all sentences upfront).
string sentence;
vector<string> sentences;
while (getline(cin, sentence)) {
sentences.push_back(sentence);
}
+ // Extracts the grammar for each sentence and saves it to a file.
vector<string> suffixes(sentences.size());
#pragma omp parallel for schedule(dynamic) \
num_threads(vm["threads"].as<int>())
for (size_t i = 0; i < sentences.size(); ++i) {
- string delimiter = "|||", suffix;
- int position = sentences[i].find(delimiter);
+ string suffix;
+ int position = sentences[i].find("|||");
if (position != sentences[i].npos) {
suffix = sentences[i].substr(position);
sentences[i] = sentences[i].substr(0, position);
diff --git a/extractor/sampler.cc b/extractor/sampler.cc
index f64a408c..d81956b5 100644
--- a/extractor/sampler.cc
+++ b/extractor/sampler.cc
@@ -16,6 +16,7 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location) const {
vector<int> sample;
int num_subpatterns;
if (location.matchings == NULL) {
+ // Sample suffix array range.
num_subpatterns = 1;
int low = location.sa_low, high = location.sa_high;
double step = max(1.0, (double) (high - low) / max_samples);
@@ -23,6 +24,7 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location) const {
sample.push_back(suffix_array->GetSuffix(Round(i)));
}
} else {
+ // Sample vector of occurrences.
num_subpatterns = location.num_subpatterns;
int num_matchings = location.matchings->size() / num_subpatterns;
double step = max(1.0, (double) num_matchings / max_samples);
diff --git a/extractor/sampler.h b/extractor/sampler.h
index cda28b10..be4aa1bb 100644
--- a/extractor/sampler.h
+++ b/extractor/sampler.h
@@ -10,18 +10,23 @@ namespace extractor {
class PhraseLocation;
class SuffixArray;
+/**
+ * Provides uniform sampling for a PhraseLocation.
+ */
class Sampler {
public:
Sampler(shared_ptr<SuffixArray> suffix_array, int max_samples);
virtual ~Sampler();
+ // Samples uniformly at most max_samples phrase occurrences.
virtual PhraseLocation Sample(const PhraseLocation& location) const;
protected:
Sampler();
private:
+ // Round floating point number to the nearest integer.
int Round(double x) const;
shared_ptr<SuffixArray> suffix_array;
diff --git a/extractor/scorer.h b/extractor/scorer.h
index c31db0ca..af8a3b10 100644
--- a/extractor/scorer.h
+++ b/extractor/scorer.h
@@ -14,14 +14,19 @@ namespace features {
class FeatureContext;
} // namespace features
+/**
+ * Computes the feature scores for a source-target phrase pair.
+ */
class Scorer {
public:
Scorer(const vector<shared_ptr<features::Feature> >& features);
virtual ~Scorer();
+ // Computes the feature score for the given context.
virtual vector<double> Score(const features::FeatureContext& context) const;
+ // Returns the set of feature names used to score any context.
virtual vector<string> GetFeatureNames() const;
protected:
diff --git a/extractor/suffix_array.h b/extractor/suffix_array.h
index 7a4f1110..bf731d79 100644
--- a/extractor/suffix_array.h
+++ b/extractor/suffix_array.h
@@ -17,18 +17,26 @@ class PhraseLocation;
class SuffixArray {
public:
+ // Creates a suffix array from a data array.
SuffixArray(shared_ptr<DataArray> data_array);
virtual ~SuffixArray();
+ // Returns the size of the suffix array.
virtual int GetSize() const;
+ // Returns the data array on top of which the suffix array is constructed.
virtual shared_ptr<DataArray> GetData() const;
+ // Constructs the longest-common-prefix array using the algorithm of Kasai et
+ // al. (2001).
virtual vector<int> BuildLCPArray() const;
+ // Returns the i-th suffix.
virtual int GetSuffix(int rank) const;
+ // Given the range in which a phrase is located and the next word, returns the
+ // range corresponding to the phrase extended with the next word.
virtual PhraseLocation Lookup(int low, int high, const string& word,
int offset) const;
@@ -38,14 +46,23 @@ class SuffixArray {
SuffixArray();
private:
+ // Constructs the suffix array using the algorithm of Larsson and Sadakane
+ // (1999).
void BuildSuffixArray();
+ // Bucket sort on the data array (used for initializing the construction of
+ // the suffix array.)
void InitialBucketSort(vector<int>& groups);
void TernaryQuicksort(int left, int right, int step, vector<int>& groups);
+ // Constructs the suffix array in log(n) steps by doubling the length of the
+ // suffixes at each step.
void PrefixDoublingSort(vector<int>& groups);
+ // Given a [low, high) range in the suffix array in which all elements have
+ // the first offset-1 values the same, it returns the first position where the
+ // offset value is greater or equal to word_id.
int LookupRangeStart(int low, int high, int word_id, int offset) const;
shared_ptr<DataArray> data_array;
diff --git a/extractor/target_phrase_extractor.cc b/extractor/target_phrase_extractor.cc
index 9f8bc6e2..2b8a2e4a 100644
--- a/extractor/target_phrase_extractor.cc
+++ b/extractor/target_phrase_extractor.cc
@@ -43,11 +43,13 @@ vector<pair<Phrase, PhraseAlignment> > TargetPhraseExtractor::ExtractPhrases(
int target_x_low = target_phrase_low, target_x_high = target_phrase_high;
if (!require_tight_phrases) {
+ // Extend loose target phrase to the left.
while (target_x_low > 0 &&
target_phrase_high - target_x_low < max_rule_span &&
target_low[target_x_low - 1] == -1) {
--target_x_low;
}
+ // Extend loose target phrase to the right.
while (target_x_high < target_sent_len &&
target_x_high - target_phrase_low < max_rule_span &&
target_low[target_x_high] == -1) {
@@ -59,10 +61,12 @@ vector<pair<Phrase, PhraseAlignment> > TargetPhraseExtractor::ExtractPhrases(
for (size_t i = 0; i < gaps.size(); ++i) {
gaps[i] = target_gaps[target_gap_order[i]];
if (!require_tight_phrases) {
+ // Extend gap to the left.
while (gaps[i].first > target_x_low &&
target_low[gaps[i].first - 1] == -1) {
--gaps[i].first;
}
+ // Extend gap to the right.
while (gaps[i].second < target_x_high &&
target_low[gaps[i].second] == -1) {
++gaps[i].second;
@@ -70,6 +74,9 @@ vector<pair<Phrase, PhraseAlignment> > TargetPhraseExtractor::ExtractPhrases(
}
}
+ // Compute the range in which each chunk may start or end. (Even indexes
+ // represent the range in which the chunk may start, odd indexes represent the
+ // range in which the chunk may end.)
vector<pair<int, int> > ranges(2 * gaps.size() + 2);
ranges.front() = make_pair(target_x_low, target_phrase_low);
ranges.back() = make_pair(target_phrase_high, target_x_high);
@@ -101,6 +108,7 @@ void TargetPhraseExtractor::GeneratePhrases(
vector<int> symbols;
unordered_map<int, int> target_indexes;
+ // Construct target phrase chunk by chunk.
int target_sent_start = target_data_array->GetSentenceStart(sentence_id);
for (size_t i = 0; i * 2 < subpatterns.size(); ++i) {
for (size_t j = subpatterns[i * 2]; j < subpatterns[i * 2 + 1]; ++j) {
@@ -115,6 +123,7 @@ void TargetPhraseExtractor::GeneratePhrases(
}
}
+ // Construct the alignment between the source and the target phrase.
vector<pair<int, int> > links = alignment->GetLinks(sentence_id);
vector<pair<int, int> > alignment;
for (pair<int, int> link: links) {
@@ -133,6 +142,7 @@ void TargetPhraseExtractor::GeneratePhrases(
if (index > 0) {
subpatterns[index] = max(subpatterns[index], subpatterns[index - 1]);
}
+ // Choose every possible combination of [start, end) for the current chunk.
while (subpatterns[index] <= ranges[index].second) {
subpatterns[index + 1] = max(subpatterns[index], ranges[index + 1].first);
while (subpatterns[index + 1] <= ranges[index + 1].second) {
diff --git a/extractor/target_phrase_extractor.h b/extractor/target_phrase_extractor.h
index a4b54145..289bae2f 100644
--- a/extractor/target_phrase_extractor.h
+++ b/extractor/target_phrase_extractor.h
@@ -30,6 +30,8 @@ class TargetPhraseExtractor {
virtual ~TargetPhraseExtractor();
+ // Finds all the target phrases that can extracted from a span in the
+ // target sentence (matching the given set of target phrase gaps).
virtual vector<pair<Phrase, PhraseAlignment> > ExtractPhrases(
const vector<pair<int, int> >& target_gaps, const vector<int>& target_low,
int target_phrase_low, int target_phrase_high,
@@ -39,6 +41,8 @@ class TargetPhraseExtractor {
TargetPhraseExtractor();
private:
+ // Computes the cartesian product over the sets of possible target phrase
+ // chunks.
void GeneratePhrases(
vector<pair<Phrase, PhraseAlignment> >& target_phrases,
const vector<pair<int, int> >& ranges, int index,
diff --git a/extractor/time_util.h b/extractor/time_util.h
index 45f79199..f7fd51d3 100644
--- a/extractor/time_util.h
+++ b/extractor/time_util.h
@@ -10,6 +10,7 @@ namespace extractor {
typedef high_resolution_clock Clock;
+// Computes the duration in seconds of the specified time interval.
double GetDuration(const Clock::time_point& start_time,
const Clock::time_point& stop_time);
diff --git a/extractor/translation_table.cc b/extractor/translation_table.cc
index 1852a357..45da707a 100644
--- a/extractor/translation_table.cc
+++ b/extractor/translation_table.cc
@@ -23,6 +23,8 @@ TranslationTable::TranslationTable(shared_ptr<DataArray> source_data_array,
unordered_map<int, int> target_links_count;
unordered_map<pair<int, int>, int, PairHash> links_count;
+ // For each pair of aligned source target words increment their link count by
+ // 1. Unaligned words are paired with the NULL token.
for (size_t i = 0; i < source_data_array->GetNumSentences(); ++i) {
vector<pair<int, int> > links = alignment->GetLinks(i);
int source_start = source_data_array->GetSentenceStart(i);
@@ -40,25 +42,28 @@ TranslationTable::TranslationTable(shared_ptr<DataArray> source_data_array,
for (pair<int, int> link: links) {
source_linked_words[link.first] = 1;
target_linked_words[link.second] = 1;
- IncreaseLinksCount(source_links_count, target_links_count, links_count,
+ IncrementLinksCount(source_links_count, target_links_count, links_count,
source_sentence[link.first], target_sentence[link.second]);
}
for (size_t i = 0; i < source_sentence.size(); ++i) {
if (!source_linked_words[i]) {
- IncreaseLinksCount(source_links_count, target_links_count, links_count,
- source_sentence[i], DataArray::NULL_WORD);
+ IncrementLinksCount(source_links_count, target_links_count, links_count,
+ source_sentence[i], DataArray::NULL_WORD);
}
}
for (size_t i = 0; i < target_sentence.size(); ++i) {
if (!target_linked_words[i]) {
- IncreaseLinksCount(source_links_count, target_links_count, links_count,
- DataArray::NULL_WORD, target_sentence[i]);
+ IncrementLinksCount(source_links_count, target_links_count, links_count,
+ DataArray::NULL_WORD, target_sentence[i]);
}
}
}
+ // Calculating:
+ // p(e | f) = count(e, f) / count(f)
+ // p(f | e) = count(e, f) / count(e)
for (pair<pair<int, int>, int> link_count: links_count) {
int source_word = link_count.first.first;
int target_word = link_count.first.second;
@@ -72,7 +77,7 @@ TranslationTable::TranslationTable() {}
TranslationTable::~TranslationTable() {}
-void TranslationTable::IncreaseLinksCount(
+void TranslationTable::IncrementLinksCount(
unordered_map<int, int>& source_links_count,
unordered_map<int, int>& target_links_count,
unordered_map<pair<int, int>, int, PairHash>& links_count,
diff --git a/extractor/translation_table.h b/extractor/translation_table.h
index a7be26f5..10504d3b 100644
--- a/extractor/translation_table.h
+++ b/extractor/translation_table.h
@@ -18,6 +18,9 @@ typedef boost::hash<pair<int, int> > PairHash;
class Alignment;
class DataArray;
+/**
+ * Bilexical table with conditional probabilities.
+ */
class TranslationTable {
public:
TranslationTable(
@@ -27,9 +30,11 @@ class TranslationTable {
virtual ~TranslationTable();
+ // Returns p(e | f).
virtual double GetTargetGivenSourceScore(const string& source_word,
const string& target_word);
+ // Returns p(f | e).
virtual double GetSourceGivenTargetScore(const string& source_word,
const string& target_word);
@@ -39,7 +44,8 @@ class TranslationTable {
TranslationTable();
private:
- void IncreaseLinksCount(
+ // Increment links count for the given (f, e) word pair.
+ void IncrementLinksCount(
unordered_map<int, int>& source_links_count,
unordered_map<int, int>& target_links_count,
unordered_map<pair<int, int>, int, PairHash>& links_count,
diff --git a/extractor/vocabulary.h b/extractor/vocabulary.h
index 03c7dc66..c8fd9411 100644
--- a/extractor/vocabulary.h
+++ b/extractor/vocabulary.h
@@ -9,16 +9,33 @@ using namespace std;
namespace extractor {
+/**
+ * Data structure for mapping words to word ids.
+ *
+ * This strucure contains words located in the frequent collocations and words
+ * encountered during the grammar extraction time. This dictionary is
+ * considerably smaller than the dictionaries in the data arrays (and so is the
+ * query time). Note that this is the single data structure that changes state
+ * and needs to have thread safe read/write operations.
+ *
+ * Note: For an experiment using different vocabulary instances for each thread,
+ * the running time did not improve implying that the critical regions do not
+ * cause bottlenecks.
+ */
class Vocabulary {
public:
virtual ~Vocabulary();
+ // Returns the word id for the given word.
virtual int GetTerminalIndex(const string& word);
+ // Returns the id for a nonterminal located at the given position in a phrase.
int GetNonterminalIndex(int position);
+ // Checks if a symbol is a nonterminal.
bool IsTerminal(int symbol);
+ // Returns the word corresponding to the given word id.
virtual string GetTerminalValue(int symbol);
private: