diff options
Diffstat (limited to 'extractor')
41 files changed, 383 insertions, 19 deletions
diff --git a/extractor/alignment.cc b/extractor/alignment.cc index f9bbcf6a..1aea34b3 100644 --- a/extractor/alignment.cc +++ b/extractor/alignment.cc @@ -28,8 +28,6 @@ Alignment::Alignment(const string& filename) { } alignments.push_back(alignment); } - // Note: shrink_to_fit does nothing for vector<vector<string> > on g++ 4.6.3, - // but let's hope that the bug will be fixed in a newer version. alignments.shrink_to_fit(); } diff --git a/extractor/alignment.h b/extractor/alignment.h index ef89dc0c..e9292121 100644 --- a/extractor/alignment.h +++ b/extractor/alignment.h @@ -11,12 +11,18 @@ using namespace std; namespace extractor { +/** + * Data structure storing the word alignments for a parallel corpus. + */ class Alignment { public: + // Reads alignment from text file. Alignment(const string& filename); + // Returns the alignment for a given sentence. virtual vector<pair<int, int> > GetLinks(int sentence_index) const; + // Writes alignment to file in binary format. void WriteBinary(const fs::path& filepath); virtual ~Alignment(); diff --git a/extractor/compile.cc b/extractor/compile.cc index 7062ef03..a9ae2cef 100644 --- a/extractor/compile.cc +++ b/extractor/compile.cc @@ -37,7 +37,7 @@ int main(int argc, char** argv) { ("max_phrase_len,p", po::value<int>()->default_value(4), "Maximum frequent phrase length") ("min_frequency", po::value<int>()->default_value(1000), - "Minimum number of occurences for a pharse to be considered frequent"); + "Minimum number of occurrences for a pharse to be considered frequent"); po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); diff --git a/extractor/data_array.h b/extractor/data_array.h index a26bbecf..978a6931 100644 --- a/extractor/data_array.h +++ b/extractor/data_array.h @@ -17,9 +17,19 @@ enum Side { TARGET }; -// Note: This class has features for both the source and target data arrays. -// Maybe we can save some memory by having more specific implementations (e.g. -// sentence_id is only needed for the source data array). +/** + * Data structure storing information about a single side of a parallel corpus. + * + * Each word is mapped to a unique integer (word_id). The data structure holds + * the corpus in the numberized format, together with the hash table mapping + * words to word_ids. It also holds additional information such as the starting + * index for each sentence and, for each token, the index of the sentence it + * belongs to. + * + * Note: This class has features for both the source and target data arrays. + * Maybe we can save some memory by having more specific implementations (not + * likely to save a lot of memory tough). + */ class DataArray { public: static int NULL_WORD; @@ -27,45 +37,65 @@ class DataArray { static string NULL_WORD_STR; static string END_OF_LINE_STR; + // Reads data array from text file. DataArray(const string& filename); + // Reads data array from bitext file where the sentences are separated by |||. DataArray(const string& filename, const Side& side); virtual ~DataArray(); + // Returns a vector containing the word ids. virtual const vector<int>& GetData() const; + // Returns the word id at the specified position. virtual int AtIndex(int index) const; + // Returns the original word at the specified position. virtual string GetWordAtIndex(int index) const; + // Returns the size of the data array. virtual int GetSize() const; + // Returns the number of distinct words in the data array. virtual int GetVocabularySize() const; + // Returns whether a word has ever been observed in the data array. virtual bool HasWord(const string& word) const; + // Returns the word id for a given word or -1 if it the word has never been + // observed. virtual int GetWordId(const string& word) const; + // Returns the word corresponding to a particular word id. virtual string GetWord(int word_id) const; + // Returns the number of sentences in the data. virtual int GetNumSentences() const; + // Returns the index where the sentence containing the given position starts. virtual int GetSentenceStart(int position) const; + // Returns the length of the sentence. virtual int GetSentenceLength(int sentence_id) const; + // Returns the number of the sentence containing the given position. virtual int GetSentenceId(int position) const; + // Writes data array to file in binary format. void WriteBinary(const fs::path& filepath) const; + // Writes data array to file in binary format. void WriteBinary(FILE* file) const; protected: DataArray(); private: + // Sets up specific constants. void InitializeDataArray(); + + // Constructs the data array. void CreateDataArray(const vector<string>& lines); unordered_map<string, int> word2id; diff --git a/extractor/fast_intersector.cc b/extractor/fast_intersector.cc index 1b8c32b1..2a7693b2 100644 --- a/extractor/fast_intersector.cc +++ b/extractor/fast_intersector.cc @@ -107,6 +107,7 @@ PhraseLocation FastIntersector::ExtendPrefixPhraseLocation( } else { pattern_end += phrase.GetChunkLen(phrase.Arity()) - 2; } + // Searches for the last symbol in the phrase after each prefix occurrence. for (int j = range.first; j < range.second; ++j) { if (pattern_end >= sent_end || pattern_end - positions[i] >= max_rule_span) { @@ -149,6 +150,8 @@ PhraseLocation FastIntersector::ExtendSuffixPhraseLocation( int pattern_start = positions[i] - range.first; int pattern_end = positions[i + num_subpatterns - 1] + phrase.GetChunkLen(phrase.Arity()) - 1; + // Searches for the first symbol in the phrase before each suffix + // occurrence. for (int j = range.first; j < range.second; ++j) { if (pattern_start < sent_start || pattern_end - pattern_start >= max_rule_span) { diff --git a/extractor/fast_intersector.h b/extractor/fast_intersector.h index 32c88a30..f950a2a9 100644 --- a/extractor/fast_intersector.h +++ b/extractor/fast_intersector.h @@ -20,6 +20,18 @@ class Precomputation; class SuffixArray; class Vocabulary; +/** + * Component for searching the training data for occurrences of source phrases + * containing nonterminals + * + * Given a source phrase containing a nonterminal, we first query the + * precomputed index containing frequent collocations. If the phrase is not + * frequent enough, we extend the matchings of either its prefix or its suffix, + * depending on which operation seems to require less computations. + * + * Note: This method for intersecting phrase locations is faster than both + * mergers (linear or Baeza Yates) described in Adam Lopez' dissertation. + */ class FastIntersector { public: FastIntersector(shared_ptr<SuffixArray> suffix_array, @@ -30,6 +42,8 @@ class FastIntersector { virtual ~FastIntersector(); + // Finds the locations of a phrase given the locations of its prefix and + // suffix. virtual PhraseLocation Intersect(PhraseLocation& prefix_location, PhraseLocation& suffix_location, const Phrase& phrase); @@ -38,23 +52,36 @@ class FastIntersector { FastIntersector(); private: + // Uses the vocabulary to convert the phrase from the numberized format + // specified by the source data array to the numberized format given by the + // vocabulary. vector<int> ConvertPhrase(const vector<int>& old_phrase); + // Estimates the number of computations needed if the prefix/suffix is + // extended. If the last/first symbol is separated from the rest of the phrase + // by a nonterminal, then for each occurrence of the prefix/suffix we need to + // check max_rule_span positions. Otherwise, we only need to check a single + // position for each occurrence. int EstimateNumOperations(const PhraseLocation& phrase_location, bool has_margin_x) const; + // Uses the occurrences of the prefix to find the occurrences of the phrase. PhraseLocation ExtendPrefixPhraseLocation(PhraseLocation& prefix_location, const Phrase& phrase, bool prefix_ends_with_x, int next_symbol) const; + // Uses the occurrences of the suffix to find the occurrences of the phrase. PhraseLocation ExtendSuffixPhraseLocation(PhraseLocation& suffix_location, const Phrase& phrase, bool suffix_starts_with_x, int prev_symbol) const; + // Extends the prefix/suffix location to a list of subpatterns positions if it + // represents a suffix array range. void ExtendPhraseLocation(PhraseLocation& location) const; + // Returns the range in which the search should be performed. pair<int, int> GetSearchRange(bool has_marginal_x) const; shared_ptr<SuffixArray> suffix_array; diff --git a/extractor/features/count_source_target.h b/extractor/features/count_source_target.h index dec78883..8747fa60 100644 --- a/extractor/features/count_source_target.h +++ b/extractor/features/count_source_target.h @@ -6,6 +6,9 @@ namespace extractor { namespace features { +/** + * Feature for the number of times a word pair was found in the bitext. + */ class CountSourceTarget : public Feature { public: double Score(const FeatureContext& context) const; diff --git a/extractor/features/feature.h b/extractor/features/feature.h index 6693ccbf..36ea504a 100644 --- a/extractor/features/feature.h +++ b/extractor/features/feature.h @@ -10,6 +10,9 @@ using namespace std; namespace extractor { namespace features { +/** + * Structure providing context for computing feature scores. + */ struct FeatureContext { FeatureContext(const Phrase& source_phrase, const Phrase& target_phrase, double source_phrase_count, int pair_count, int num_samples) : @@ -24,6 +27,9 @@ struct FeatureContext { int num_samples; }; +/** + * Base class for features. + */ class Feature { public: virtual double Score(const FeatureContext& context) const = 0; diff --git a/extractor/features/is_source_singleton.h b/extractor/features/is_source_singleton.h index 30f76c6d..b8352d0e 100644 --- a/extractor/features/is_source_singleton.h +++ b/extractor/features/is_source_singleton.h @@ -6,6 +6,9 @@ namespace extractor { namespace features { +/** + * Boolean feature checking if the source phrase occurs only once in the data. + */ class IsSourceSingleton : public Feature { public: double Score(const FeatureContext& context) const; diff --git a/extractor/features/is_source_target_singleton.h b/extractor/features/is_source_target_singleton.h index 12fb6ee6..dacfebba 100644 --- a/extractor/features/is_source_target_singleton.h +++ b/extractor/features/is_source_target_singleton.h @@ -6,6 +6,9 @@ namespace extractor { namespace features { +/** + * Boolean feature checking if the phrase pair occurs only once in the data. + */ class IsSourceTargetSingleton : public Feature { public: double Score(const FeatureContext& context) const; diff --git a/extractor/features/max_lex_source_given_target.h b/extractor/features/max_lex_source_given_target.h index bfa7ef1b..461b0ebf 100644 --- a/extractor/features/max_lex_source_given_target.h +++ b/extractor/features/max_lex_source_given_target.h @@ -13,6 +13,9 @@ class TranslationTable; namespace features { +/** + * Feature computing max(p(f | e)) across all pairs of words in the phrase pair. + */ class MaxLexSourceGivenTarget : public Feature { public: MaxLexSourceGivenTarget(shared_ptr<TranslationTable> table); diff --git a/extractor/features/max_lex_target_given_source.h b/extractor/features/max_lex_target_given_source.h index 66cf0914..c3c87327 100644 --- a/extractor/features/max_lex_target_given_source.h +++ b/extractor/features/max_lex_target_given_source.h @@ -13,6 +13,9 @@ class TranslationTable; namespace features { +/** + * Feature computing max(p(e | f)) across all pairs of words in the phrase pair. + */ class MaxLexTargetGivenSource : public Feature { public: MaxLexTargetGivenSource(shared_ptr<TranslationTable> table); diff --git a/extractor/features/sample_source_count.h b/extractor/features/sample_source_count.h index 53c7f954..ee6e59a0 100644 --- a/extractor/features/sample_source_count.h +++ b/extractor/features/sample_source_count.h @@ -6,6 +6,10 @@ namespace extractor { namespace features { +/** + * Feature scoring the number of times the source phrase occurs in the sampled + * set. + */ class SampleSourceCount : public Feature { public: double Score(const FeatureContext& context) const; diff --git a/extractor/features/target_given_source_coherent.h b/extractor/features/target_given_source_coherent.h index 80d9f617..e66d70a5 100644 --- a/extractor/features/target_given_source_coherent.h +++ b/extractor/features/target_given_source_coherent.h @@ -6,6 +6,10 @@ namespace extractor { namespace features { +/** + * Feature computing the ratio of the phrase pair count over all source phrase + * occurrences (sampled). + */ class TargetGivenSourceCoherent : public Feature { public: double Score(const FeatureContext& context) const; diff --git a/extractor/grammar.h b/extractor/grammar.h index a424d65a..fed41b16 100644 --- a/extractor/grammar.h +++ b/extractor/grammar.h @@ -11,6 +11,9 @@ namespace extractor { class Rule; +/** + * Grammar class wrapping the set of rules to be extracted. + */ class Grammar { public: Grammar(const vector<Rule>& rules, const vector<string>& feature_names); diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h index 6b1dcf98..b36ceeb9 100644 --- a/extractor/grammar_extractor.h +++ b/extractor/grammar_extractor.h @@ -19,6 +19,10 @@ class Scorer; class SuffixArray; class Vocabulary; +/** + * Class wrapping all the logic for extracting the synchronous context free + * grammars. + */ class GrammarExtractor { public: GrammarExtractor( @@ -38,11 +42,15 @@ class GrammarExtractor { GrammarExtractor(shared_ptr<Vocabulary> vocabulary, shared_ptr<HieroCachingRuleFactory> rule_factory); + // Converts the sentence to a vector of word ids and uses the RuleFactory to + // extract the SCFG rules which may be used to decode the sentence. Grammar GetGrammar(const string& sentence); private: + // Splits the sentence in a vector of words. vector<string> TokenizeSentence(const string& sentence); + // Maps the words to word ids. vector<int> AnnotateWords(const vector<string>& words); shared_ptr<Vocabulary> vocabulary; diff --git a/extractor/matchings_finder.h b/extractor/matchings_finder.h index fbb504ef..451f4a4c 100644 --- a/extractor/matchings_finder.h +++ b/extractor/matchings_finder.h @@ -11,12 +11,17 @@ namespace extractor { class PhraseLocation; class SuffixArray; +/** + * Class wrapping the suffix array lookup for a contiguous phrase. + */ class MatchingsFinder { public: MatchingsFinder(shared_ptr<SuffixArray> suffix_array); virtual ~MatchingsFinder(); + // Uses the suffix array to search only for the last word of the phrase + // starting from the range in which the prefix of the phrase occurs. virtual PhraseLocation Find(PhraseLocation& location, const string& word, int offset); diff --git a/extractor/matchings_trie.h b/extractor/matchings_trie.h index f3dcc075..1fb29693 100644 --- a/extractor/matchings_trie.h +++ b/extractor/matchings_trie.h @@ -11,20 +11,27 @@ using namespace std; namespace extractor { +/** + * Trie node containing all the occurrences of the corresponding phrase in the + * source data. + */ struct TrieNode { TrieNode(shared_ptr<TrieNode> suffix_link = shared_ptr<TrieNode>(), Phrase phrase = Phrase(), PhraseLocation matchings = PhraseLocation()) : suffix_link(suffix_link), phrase(phrase), matchings(matchings) {} + // Adds a trie node as a child of the current node. void AddChild(int key, shared_ptr<TrieNode> child_node) { children[key] = child_node; } + // Checks if a child exists for a given key. bool HasChild(int key) { return children.count(key); } + // Gets the child corresponding to the given key. shared_ptr<TrieNode> GetChild(int key) { return children[key]; } @@ -35,15 +42,20 @@ struct TrieNode { unordered_map<int, shared_ptr<TrieNode> > children; }; +/** + * Trie containing all the phrases that can be obtained from a sentence. + */ class MatchingsTrie { public: MatchingsTrie(); virtual ~MatchingsTrie(); + // Returns the root of the trie. shared_ptr<TrieNode> GetRoot() const; private: + // Recursively deletes a subtree of the trie. void DeleteTree(shared_ptr<TrieNode> root); shared_ptr<TrieNode> root; diff --git a/extractor/phrase.h b/extractor/phrase.h index 6521c438..a8e91e3c 100644 --- a/extractor/phrase.h +++ b/extractor/phrase.h @@ -11,20 +11,30 @@ using namespace std; namespace extractor { +/** + * Structure containing the data for a phrase. + */ class Phrase { public: friend Phrase PhraseBuilder::Build(const vector<int>& phrase); + // Returns the number of nonterminals in the phrase. int Arity() const; + // Returns the number of terminals (length) for the given chunk. (A chunk is a + // contiguous sequence of terminals in the phrase). int GetChunkLen(int index) const; + // Returns the symbols (word ids) marking up the phrase. vector<int> Get() const; + // Returns the symbol located at the given position in the phrase. int GetSymbol(int position) const; + // Returns the number of symbols in the phrase. int GetNumSymbols() const; + // Returns the words making up the phrase. (Nonterminals are stripped out.) vector<string> GetWords() const; bool operator<(const Phrase& other) const; diff --git a/extractor/phrase_builder.h b/extractor/phrase_builder.h index 2956fd35..de86dbae 100644 --- a/extractor/phrase_builder.h +++ b/extractor/phrase_builder.h @@ -11,12 +11,17 @@ namespace extractor { class Phrase; class Vocabulary; +/** + * Component for constructing phrases. + */ class PhraseBuilder { public: PhraseBuilder(shared_ptr<Vocabulary> vocabulary); + // Constructs a phrase starting from an array of symbols. Phrase Build(const vector<int>& symbols); + // Extends a phrase with a leading and/or trailing nonterminal. Phrase Extend(const Phrase& phrase, bool start_x, bool end_x); private: diff --git a/extractor/phrase_location.h b/extractor/phrase_location.h index e5f3cf08..91950e03 100644 --- a/extractor/phrase_location.h +++ b/extractor/phrase_location.h @@ -8,13 +8,25 @@ using namespace std; namespace extractor { +/** + * Structure containing information about the occurrences of a phrase in the + * source data. + * + * Every consecutive (disjoint) group of num_subpatterns entries in matchings + * vector encodes an occurrence of the phrase. The i-th entry of a group + * represents the start of the i-th subpattern of the phrase. If the phrase + * doesn't contain any nonterminals, then it may also be represented as the + * range in the suffix array which matches the phrase. + */ struct PhraseLocation { PhraseLocation(int sa_low = -1, int sa_high = -1); PhraseLocation(const vector<int>& matchings, int num_subpatterns); + // Checks if a phrase has any occurrences in the source data. bool IsEmpty() const; + // Returns the number of occurrences of a phrase in the source data. int GetSize() const; friend bool operator==(const PhraseLocation& a, const PhraseLocation& b); diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc index 0fadc95c..b3906943 100644 --- a/extractor/precomputation.cc +++ b/extractor/precomputation.cc @@ -23,6 +23,8 @@ Precomputation::Precomputation( suffix_array, data, num_frequent_patterns, max_frequent_phrase_len, min_frequency); + // Construct sets containing the frequent and superfrequent contiguous + // collocations. unordered_set<vector<int>, VectorHash> frequent_patterns_set; unordered_set<vector<int>, VectorHash> super_frequent_patterns_set; for (size_t i = 0; i < frequent_patterns.size(); ++i) { @@ -34,6 +36,8 @@ Precomputation::Precomputation( vector<tuple<int, int, int> > matchings; for (size_t i = 0; i < data.size(); ++i) { + // If the sentence is over, add all the discontiguous frequent patterns to + // the index. if (data[i] == DataArray::END_OF_LINE) { AddCollocations(matchings, data, max_rule_span, min_gap_size, max_rule_symbols); @@ -41,6 +45,7 @@ Precomputation::Precomputation( continue; } vector<int> pattern; + // Find all the contiguous frequent patterns starting at position i. for (int j = 1; j <= max_frequent_phrase_len && i + j <= data.size(); ++j) { pattern.push_back(data[i + j - 1]); if (frequent_patterns_set.count(pattern)) { @@ -65,6 +70,7 @@ vector<vector<int> > Precomputation::FindMostFrequentPatterns( vector<int> lcp = suffix_array->BuildLCPArray(); vector<int> run_start(max_frequent_phrase_len); + // Find all the patterns occurring at least min_frequency times. priority_queue<pair<int, pair<int, int> > > heap; for (size_t i = 1; i < lcp.size(); ++i) { for (int len = lcp[i]; len < max_frequent_phrase_len; ++len) { @@ -77,6 +83,7 @@ vector<vector<int> > Precomputation::FindMostFrequentPatterns( } } + // Extract the most frequent patterns. vector<vector<int> > frequent_patterns; while (frequent_patterns.size() < num_frequent_patterns && !heap.empty()) { int start = heap.top().second.first; @@ -95,10 +102,12 @@ vector<vector<int> > Precomputation::FindMostFrequentPatterns( void Precomputation::AddCollocations( const vector<tuple<int, int, int> >& matchings, const vector<int>& data, int max_rule_span, int min_gap_size, int max_rule_symbols) { + // Select the leftmost subpattern. for (size_t i = 0; i < matchings.size(); ++i) { int start1, size1, is_super1; tie(start1, size1, is_super1) = matchings[i]; + // Select the second (middle) subpattern for (size_t j = i + 1; j < matchings.size(); ++j) { int start2, size2, is_super2; tie(start2, size2, is_super2) = matchings[j]; @@ -116,8 +125,10 @@ void Precomputation::AddCollocations( data.begin() + start2 + size2); AddStartPositions(collocations[pattern], start1, start2); + // Try extending the binary collocation to a ternary collocation. if (is_super2) { pattern.push_back(Precomputation::SECOND_NONTERMINAL); + // Select the rightmost subpattern. for (size_t k = j + 1; k < matchings.size(); ++k) { int start3, size3, is_super3; tie(start3, size3, is_super3) = matchings[k]; diff --git a/extractor/precomputation.h b/extractor/precomputation.h index 2c1eccf8..e3c4d26a 100644 --- a/extractor/precomputation.h +++ b/extractor/precomputation.h @@ -20,8 +20,19 @@ typedef unordered_map<vector<int>, vector<int>, VectorHash> Index; class SuffixArray; +/** + * Data structure wrapping an index with all the occurrences of the most + * frequent discontiguous collocations in the source data. + * + * Let a, b, c be contiguous collocations. The index will contain an entry for + * every collocation of the form: + * - aXb, where a and b are frequent + * - aXbXc, where a and b are super-frequent and c is frequent or + * b and c are super-frequent and a is frequent. + */ class Precomputation { public: + // Constructs the index using the suffix array. Precomputation( shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns, int num_super_frequent_patterns, int max_rule_span, @@ -32,6 +43,7 @@ class Precomputation { void WriteBinary(const fs::path& filepath) const; + // Returns a reference to the index. virtual const Index& GetCollocations() const; static int FIRST_NONTERMINAL; @@ -41,14 +53,23 @@ class Precomputation { Precomputation(); private: + // Finds the most frequent contiguous collocations. vector<vector<int> > FindMostFrequentPatterns( shared_ptr<SuffixArray> suffix_array, const vector<int>& data, int num_frequent_patterns, int max_frequent_phrase_len, int min_frequency); + + // Given the locations of the frequent contiguous collocations in a sentence, + // it adds new entries to the index for each discontiguous collocation + // matching the criteria specified in the class description. void AddCollocations( const vector<std::tuple<int, int, int> >& matchings, const vector<int>& data, int max_rule_span, int min_gap_size, int max_rule_symbols); + + // Adds an occurrence of a binary collocation. void AddStartPositions(vector<int>& positions, int pos1, int pos2); + + // Adds an occurrence of a ternary collocation. void AddStartPositions(vector<int>& positions, int pos1, int pos2, int pos3); Index collocations; diff --git a/extractor/rule.h b/extractor/rule.h index b4d45fc1..bc95709e 100644 --- a/extractor/rule.h +++ b/extractor/rule.h @@ -9,6 +9,9 @@ using namespace std; namespace extractor { +/** + * Structure containing the data for a SCFG rule. + */ struct Rule { Rule(const Phrase& source_phrase, const Phrase& target_phrase, const vector<double>& scores, const vector<pair<int, int> >& alignment); diff --git a/extractor/rule_extractor.cc b/extractor/rule_extractor.cc index b9286472..9f5e8e00 100644 --- a/extractor/rule_extractor.cc +++ b/extractor/rule_extractor.cc @@ -79,6 +79,7 @@ vector<Rule> RuleExtractor::ExtractRules(const Phrase& phrase, int num_subpatterns = location.num_subpatterns; vector<int> matchings = *location.matchings; + // Calculate statistics for the (sampled) occurrences of the source phrase. map<Phrase, double> source_phrase_counter; map<Phrase, map<Phrase, map<PhraseAlignment, int> > > alignments_counter; for (auto i = matchings.begin(); i != matchings.end(); i += num_subpatterns) { @@ -91,6 +92,8 @@ vector<Rule> RuleExtractor::ExtractRules(const Phrase& phrase, } } + // Compute the feature scores and find the most likely (frequent) alignment + // for each pair of source-target phrases. int num_samples = matchings.size() / num_subpatterns; vector<Rule> rules; for (auto source_phrase_entry: alignments_counter) { @@ -124,6 +127,8 @@ vector<Extract> RuleExtractor::ExtractAlignments( int sentence_id = source_data_array->GetSentenceId(matching[0]); int source_sent_start = source_data_array->GetSentenceStart(sentence_id); + // Get the span in the opposite sentence for each word in the source-target + // sentece pair. vector<int> source_low, source_high, target_low, target_high; helper->GetLinksSpans(source_low, source_high, target_low, target_high, sentence_id); @@ -134,6 +139,7 @@ vector<Extract> RuleExtractor::ExtractAlignments( chunklen[i] = phrase.GetChunkLen(i); } + // Basic checks to see if we can extract phrase pairs for this occurrence. if (!helper->CheckAlignedTerminals(matching, chunklen, source_low) || !helper->CheckTightPhrases(matching, chunklen, source_low)) { return extracts; @@ -144,6 +150,7 @@ vector<Extract> RuleExtractor::ExtractAlignments( int source_phrase_high = matching.back() + chunklen.back() - source_sent_start; int target_phrase_low = -1, target_phrase_high = -1; + // Find target span and reflected source span for the source phrase. if (!helper->FindFixPoint(source_phrase_low, source_phrase_high, source_low, source_high, target_phrase_low, target_phrase_high, target_low, target_high, source_back_low, @@ -153,6 +160,7 @@ vector<Extract> RuleExtractor::ExtractAlignments( return extracts; } + // Get spans for nonterminal gaps. bool met_constraints = true; int num_symbols = phrase.GetNumSymbols(); vector<pair<int, int> > source_gaps, target_gaps; @@ -163,6 +171,7 @@ vector<Extract> RuleExtractor::ExtractAlignments( return extracts; } + // Find target phrases aligned with the initial source phrase. bool starts_with_x = source_back_low != source_phrase_low; bool ends_with_x = source_back_high != source_phrase_high; Phrase source_phrase = phrase_builder->Extend( @@ -181,6 +190,8 @@ vector<Extract> RuleExtractor::ExtractAlignments( return extracts; } + // Extend the source phrase by adding a leading and/or trailing nonterminal + // and find target phrases aligned with the extended source phrase. for (int i = 0; i < 2; ++i) { for (int j = 1 - i; j < 2; ++j) { AddNonterminalExtremities(extracts, matching, chunklen, source_phrase, @@ -203,6 +214,8 @@ void RuleExtractor::AddExtracts( source_indexes, sentence_id); if (target_phrases.size() > 0) { + // Split the probability equally across all target phrases that can be + // aligned with a single occurrence of the source phrase. double pairs_count = 1.0 / target_phrases.size(); for (auto target_phrase: target_phrases) { extracts.push_back(Extract(source_phrase, target_phrase.first, @@ -221,6 +234,7 @@ void RuleExtractor::AddNonterminalExtremities( int extend_right) const { int source_x_low = source_back_low, source_x_high = source_back_high; + // Check if the extended source phrase will remain tight. if (require_tight_phrases) { if (source_low[source_back_low - extend_left] == -1 || source_low[source_back_high + extend_right - 1] == -1) { @@ -228,6 +242,7 @@ void RuleExtractor::AddNonterminalExtremities( } } + // Check if we can add a nonterminal to the left. if (extend_left) { if (starts_with_x || source_back_low < min_gap_size) { return; @@ -244,6 +259,7 @@ void RuleExtractor::AddNonterminalExtremities( } } + // Check if we can add a nonterminal to the right. if (extend_right) { int source_sent_len = source_data_array->GetSentenceLength(sentence_id); if (ends_with_x || source_back_high + min_gap_size > source_sent_len) { @@ -262,6 +278,7 @@ void RuleExtractor::AddNonterminalExtremities( } } + // More length checks. int new_nonterminals = extend_left + extend_right; if (source_x_high - source_x_low > max_rule_span || target_gaps.size() + new_nonterminals > max_nonterminals || @@ -269,6 +286,7 @@ void RuleExtractor::AddNonterminalExtremities( return; } + // Find the target span for the extended phrase and the reflected source span. int target_x_low = -1, target_x_high = -1; if (!helper->FindFixPoint(source_x_low, source_x_high, source_low, source_high, target_x_low, target_x_high, @@ -279,6 +297,7 @@ void RuleExtractor::AddNonterminalExtremities( return; } + // Check gap integrity for the leading nonterminal. if (extend_left) { int source_gap_low = -1, source_gap_high = -1; int target_gap_low = -1, target_gap_high = -1; @@ -294,6 +313,7 @@ void RuleExtractor::AddNonterminalExtremities( make_pair(target_gap_low, target_gap_high)); } + // Check gap integrity for the trailing nonterminal. if (extend_right) { int target_gap_low = -1, target_gap_high = -1; int source_gap_low = -1, source_gap_high = -1; @@ -308,6 +328,7 @@ void RuleExtractor::AddNonterminalExtremities( target_gaps.push_back(make_pair(target_gap_low, target_gap_high)); } + // Find target phrases aligned with the extended source phrase. Phrase new_source_phrase = phrase_builder->Extend(source_phrase, extend_left, extend_right); unordered_map<int, int> source_indexes = helper->GetSourceIndexes( diff --git a/extractor/rule_extractor.h b/extractor/rule_extractor.h index 8b6daeea..bfec0225 100644 --- a/extractor/rule_extractor.h +++ b/extractor/rule_extractor.h @@ -22,6 +22,10 @@ class RuleExtractorHelper; class Scorer; class TargetPhraseExtractor; +/** + * Structure containing data about the occurrences of a source-target phrase pair + * in the parallel corpus. + */ struct Extract { Extract(const Phrase& source_phrase, const Phrase& target_phrase, double pairs_count, const PhraseAlignment& alignment) : @@ -34,6 +38,9 @@ struct Extract { PhraseAlignment alignment; }; +/** + * Component for extracting SCFG rules. + */ class RuleExtractor { public: RuleExtractor(shared_ptr<DataArray> source_data_array, @@ -64,6 +71,8 @@ class RuleExtractor { virtual ~RuleExtractor(); + // Extracts SCFG rules given a source phrase and a set of its occurrences + // in the source data. virtual vector<Rule> ExtractRules(const Phrase& phrase, const PhraseLocation& location) const; @@ -71,15 +80,22 @@ class RuleExtractor { RuleExtractor(); private: + // Finds all target phrases that can be aligned with the source phrase for a + // particular occurrence in the data. vector<Extract> ExtractAlignments(const Phrase& phrase, const vector<int>& matching) const; + // Extracts all target phrases for a given occurrence of the source phrase in + // the data. Constructs a vector of Extracts using these target phrases. void AddExtracts( vector<Extract>& extracts, const Phrase& source_phrase, const unordered_map<int, int>& source_indexes, const vector<pair<int, int> >& target_gaps, const vector<int>& target_low, int target_phrase_low, int target_phrase_high, int sentence_id) const; + // Adds a leading and/or trailing nonterminal to the source phrase and + // extracts target phrases that can be aligned with the extended source + // phrase. void AddNonterminalExtremities( vector<Extract>& extracts, const vector<int>& matching, const vector<int>& chunklen, const Phrase& source_phrase, diff --git a/extractor/rule_extractor_helper.cc b/extractor/rule_extractor_helper.cc index 81b522f0..6410d147 100644 --- a/extractor/rule_extractor_helper.cc +++ b/extractor/rule_extractor_helper.cc @@ -88,6 +88,7 @@ bool RuleExtractorHelper::CheckTightPhrases( return true; } + // Check if the chunk extremities are aligned. int sentence_id = source_data_array->GetSentenceId(matching[0]); int source_sent_start = source_data_array->GetSentenceStart(sentence_id); for (size_t i = 0; i + 1 < chunklen.size(); ++i) { @@ -126,6 +127,7 @@ bool RuleExtractorHelper::FindFixPoint( int source_sent_len = source_data_array->GetSentenceLength(sentence_id); int target_sent_len = target_data_array->GetSentenceLength(sentence_id); + // Extend the target span to the left. if (prev_target_low != -1 && target_phrase_low != prev_target_low) { if (prev_target_low - target_phrase_low < min_target_gap_size) { target_phrase_low = prev_target_low - min_target_gap_size; @@ -135,6 +137,7 @@ bool RuleExtractorHelper::FindFixPoint( } } + // Extend the target span to the right. if (prev_target_high != -1 && target_phrase_high != prev_target_high) { if (target_phrase_high - prev_target_high < min_target_gap_size) { target_phrase_high = prev_target_high + min_target_gap_size; @@ -144,10 +147,12 @@ bool RuleExtractorHelper::FindFixPoint( } } + // Check target span length. if (target_phrase_high - target_phrase_low > max_rule_span) { return false; } + // Find the initial reflected source span. source_back_low = source_back_high = -1; FindProjection(target_phrase_low, target_phrase_high, target_low, target_high, source_back_low, source_back_high); @@ -157,6 +162,7 @@ bool RuleExtractorHelper::FindFixPoint( source_back_low = min(source_back_low, source_phrase_low); source_back_high = max(source_back_high, source_phrase_high); + // Stop if the reflected source span matches the previous source span. if (source_back_low == source_phrase_low && source_back_high == source_phrase_high) { return true; @@ -212,10 +218,14 @@ bool RuleExtractorHelper::FindFixPoint( prev_target_low = target_phrase_low; prev_target_high = target_phrase_high; + // Find the reflection including the left gap (if one was added). FindProjection(source_back_low, source_phrase_low, source_low, source_high, target_phrase_low, target_phrase_high); + // Find the reflection including the right gap (if one was added). FindProjection(source_phrase_high, source_back_high, source_low, source_high, target_phrase_low, target_phrase_high); + // Stop if the new re-reflected target span matches the previous target + // span. if (prev_target_low == target_phrase_low && prev_target_high == target_phrase_high) { return true; @@ -232,6 +242,7 @@ bool RuleExtractorHelper::FindFixPoint( source_phrase_low = source_back_low; source_phrase_high = source_back_high; + // Re-reflect the target span. FindProjection(target_phrase_low, prev_target_low, target_low, target_high, source_back_low, source_back_high); FindProjection(prev_target_high, target_phrase_high, target_low, diff --git a/extractor/rule_extractor_helper.h b/extractor/rule_extractor_helper.h index 7bf80c4b..bea75bc3 100644 --- a/extractor/rule_extractor_helper.h +++ b/extractor/rule_extractor_helper.h @@ -12,6 +12,9 @@ namespace extractor { class Alignment; class DataArray; +/** + * Helper class for extracting SCFG rules. + */ class RuleExtractorHelper { public: RuleExtractorHelper(shared_ptr<DataArray> source_data_array, @@ -25,18 +28,23 @@ class RuleExtractorHelper { virtual ~RuleExtractorHelper(); + // Find the alignment span for each word in the source target sentence pair. virtual void GetLinksSpans(vector<int>& source_low, vector<int>& source_high, vector<int>& target_low, vector<int>& target_high, int sentence_id) const; + // Check if one chunk (all chunks) is aligned at least in one point. virtual bool CheckAlignedTerminals(const vector<int>& matching, const vector<int>& chunklen, const vector<int>& source_low) const; + // Check if the chunks are tight. virtual bool CheckTightPhrases(const vector<int>& matching, const vector<int>& chunklen, const vector<int>& source_low) const; + // Find the target span and the reflected source span for a source phrase + // occurrence. virtual bool FindFixPoint( int source_phrase_low, int source_phrase_high, const vector<int>& source_low, const vector<int>& source_high, @@ -47,6 +55,7 @@ class RuleExtractorHelper { int max_new_x, bool allow_low_x, bool allow_high_x, bool allow_arbitrary_expansion) const; + // Find the gap spans for each nonterminal in the source phrase. virtual bool GetGaps( vector<pair<int, int> >& source_gaps, vector<pair<int, int> >& target_gaps, const vector<int>& matching, const vector<int>& chunklen, @@ -55,8 +64,10 @@ class RuleExtractorHelper { int source_phrase_low, int source_phrase_high, int source_back_low, int source_back_high, int& num_symbols, bool& met_constraints) const; + // Get the order of the nonterminals in the target phrase. virtual vector<int> GetGapOrder(const vector<pair<int, int> >& gaps) const; + // Map each terminal symbol with its position in the source phrase. virtual unordered_map<int, int> GetSourceIndexes( const vector<int>& matching, const vector<int>& chunklen, int starts_with_x) const; @@ -65,6 +76,8 @@ class RuleExtractorHelper { RuleExtractorHelper(); private: + // Find the projection of a source phrase in the target sentence. May also be + // used to find the projection of a target phrase in the source sentence. void FindProjection( int source_phrase_low, int source_phrase_high, const vector<int>& source_low, const vector<int>& source_high, diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index fbc62e50..8c30fb9e 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -152,12 +152,18 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { } else { PhraseLocation phrase_location; if (next_phrase.Arity() > 0) { + // For phrases containing a nonterminal, we use either the occurrences + // of the prefix or the suffix to determine the occurrences of the + // phrase. Clock::time_point intersect_start = Clock::now(); phrase_location = fast_intersector->Intersect( node->matchings, next_suffix_link->matchings, next_phrase); Clock::time_point intersect_stop = Clock::now(); total_intersect_time += GetDuration(intersect_start, intersect_stop); } else { + // For phrases not containing any nonterminals, we simply query the + // suffix array using the suffix array range of the prefix as a + // starting point. Clock::time_point lookup_start = Clock::now(); phrase_location = matchings_finder->Find( node->matchings, @@ -170,9 +176,12 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { if (phrase_location.IsEmpty()) { continue; } + + // Create new trie node to store data about the current phrase. next_node = make_shared<TrieNode>( next_suffix_link, next_phrase, phrase_location); } + // Add the new trie node to the trie cache. node->AddChild(word_id, next_node); // Automatically adds a trailing non terminal if allowed. Simply copy the @@ -182,6 +191,7 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { Clock::time_point extract_start = Clock::now(); if (!state.starts_with_x) { + // Extract rules for the sampled set of occurrences. PhraseLocation sample = sampler->Sample(next_node->matchings); vector<Rule> new_rules = rule_extractor->ExtractRules(next_phrase, sample); @@ -193,6 +203,7 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { next_node = node->GetChild(word_id); } + // Create more states (phrases) to be analyzed. vector<State> new_states = ExtendState(word_ids, state, phrase, next_phrase, next_node); for (State new_state: new_states) { @@ -262,6 +273,7 @@ vector<State> HieroCachingRuleFactory::ExtendState( return new_states; } + // New state for adding the next symbol. new_states.push_back(State(state.start, state.end + 1, symbols, state.subpatterns_start, node, state.starts_with_x)); @@ -272,6 +284,7 @@ vector<State> HieroCachingRuleFactory::ExtendState( return new_states; } + // New states for adding a nonterminal followed by a new symbol. int var_id = vocabulary->GetNonterminalIndex(phrase.Arity() + 1); symbols.push_back(var_id); vector<int> subpatterns_start = state.subpatterns_start; diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h index d8dc2ccc..52e8712a 100644 --- a/extractor/rule_factory.h +++ b/extractor/rule_factory.h @@ -25,6 +25,17 @@ class State; class SuffixArray; class Vocabulary; +/** + * Component containing most of the logic for extracting SCFG rules for a given + * sentence. + * + * Given a sentence (as a vector of word ids), this class constructs all the + * possible source phrases starting from this sentence. For each source phrase, + * it finds all its occurrences in the source data and samples some of these + * occurrences to extract aligned source-target phrase pairs. A trie cache is + * used to avoid unnecessary computations if a source phrase can be constructed + * more than once (e.g. some words occur more than once in the sentence). + */ class HieroCachingRuleFactory { public: HieroCachingRuleFactory( @@ -58,21 +69,30 @@ class HieroCachingRuleFactory { virtual ~HieroCachingRuleFactory(); + // Constructs SCFG rules for a given sentence. + // (See class description for more details.) virtual Grammar GetGrammar(const vector<int>& word_ids); protected: HieroCachingRuleFactory(); private: + // Checks if the phrase (if previously encountered) or its prefix have any + // occurrences in the source data. bool CannotHaveMatchings(shared_ptr<TrieNode> node, int word_id); + // Checks if the phrase has previously been analyzed. bool RequiresLookup(shared_ptr<TrieNode> node, int word_id); + // Creates a new state in the trie that corresponds to adding a trailing + // nonterminal to the current phrase. void AddTrailingNonterminal(vector<int> symbols, const Phrase& prefix, const shared_ptr<TrieNode>& prefix_node, bool starts_with_x); + // Extends the current state by possibly adding a nonterminal followed by a + // terminal. vector<State> ExtendState(const vector<int>& word_ids, const State& state, vector<int> symbols, diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index dba4578c..d5ff23b2 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -35,6 +35,7 @@ using namespace std; using namespace extractor; using namespace features; +// Returns the file path in which a given grammar should be written. fs::path GetGrammarFilePath(const fs::path& grammar_path, int file_number) { string file_name = "grammar." + to_string(file_number); return grammar_path / file_name; @@ -45,6 +46,7 @@ int main(int argc, char** argv) { #pragma omp parallel num_threads_default = omp_get_num_threads(); + // Sets up the command line arguments map. po::options_description desc("Command line options"); desc.add_options() ("help,h", "Show available options") @@ -69,7 +71,7 @@ int main(int argc, char** argv) { ("max_nonterminals", po::value<int>()->default_value(2), "Maximum number of nonterminals in a rule") ("min_frequency", po::value<int>()->default_value(1000), - "Minimum number of occurences for a pharse to be considered frequent") + "Minimum number of occurrences for a pharse to be considered frequent") ("max_samples", po::value<int>()->default_value(300), "Maximum number of samples") ("tight_phrases", po::value<bool>()->default_value(true), @@ -78,8 +80,8 @@ int main(int argc, char** argv) { po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); - // Check for help argument before notify, so we don't need to pass in the - // required parameters. + // Checks for the help option before calling notify, so the we don't get an + // exception for missing required arguments. if (vm.count("help")) { cout << desc << endl; return 0; @@ -94,6 +96,7 @@ int main(int argc, char** argv) { return 1; } + // Reads the parallel corpus. Clock::time_point preprocess_start_time = Clock::now(); cerr << "Reading source and target data..." << endl; Clock::time_point start_time = Clock::now(); @@ -111,6 +114,7 @@ int main(int argc, char** argv) { cerr << "Reading data took " << GetDuration(start_time, stop_time) << " seconds" << endl; + // Constructs the suffix array for the source data. cerr << "Creating source suffix array..." << endl; start_time = Clock::now(); shared_ptr<SuffixArray> source_suffix_array = @@ -119,6 +123,7 @@ int main(int argc, char** argv) { cerr << "Creating suffix array took " << GetDuration(start_time, stop_time) << " seconds" << endl; + // Reads the alignment. cerr << "Reading alignment..." << endl; start_time = Clock::now(); shared_ptr<Alignment> alignment = @@ -127,6 +132,8 @@ int main(int argc, char** argv) { cerr << "Reading alignment took " << GetDuration(start_time, stop_time) << " seconds" << endl; + // Constructs an index storing the occurrences in the source data for each + // frequent collocation. cerr << "Precomputing collocations..." << endl; start_time = Clock::now(); shared_ptr<Precomputation> precomputation = make_shared<Precomputation>( @@ -142,6 +149,8 @@ int main(int argc, char** argv) { cerr << "Precomputing collocations took " << GetDuration(start_time, stop_time) << " seconds" << endl; + // Constructs a table storing p(e | f) and p(f | e) for every pair of source + // and target words. cerr << "Precomputing conditional probabilities..." << endl; start_time = Clock::now(); shared_ptr<TranslationTable> table = make_shared<TranslationTable>( @@ -155,6 +164,7 @@ int main(int argc, char** argv) { << GetDuration(preprocess_start_time, preprocess_stop_time) << " seconds" << endl; + // Features used to score each grammar rule. Clock::time_point extraction_start_time = Clock::now(); vector<shared_ptr<Feature> > features = { make_shared<TargetGivenSourceCoherent>(), @@ -167,6 +177,7 @@ int main(int argc, char** argv) { }; shared_ptr<Scorer> scorer = make_shared<Scorer>(features); + // Sets up the grammar extractor. GrammarExtractor extractor( source_suffix_array, target_data_array, @@ -180,26 +191,30 @@ int main(int argc, char** argv) { vm["max_samples"].as<int>(), vm["tight_phrases"].as<bool>()); - // Release extra memory used by the initial precomputation. + // Releases extra memory used by the initial precomputation. precomputation.reset(); + // Creates the grammars directory if it doesn't exist. fs::path grammar_path = vm["grammars"].as<string>(); if (!fs::is_directory(grammar_path)) { fs::create_directory(grammar_path); } + // Reads all sentences for which we extract grammar rules (the paralellization + // is simplified if we read all sentences upfront). string sentence; vector<string> sentences; while (getline(cin, sentence)) { sentences.push_back(sentence); } + // Extracts the grammar for each sentence and saves it to a file. vector<string> suffixes(sentences.size()); #pragma omp parallel for schedule(dynamic) \ num_threads(vm["threads"].as<int>()) for (size_t i = 0; i < sentences.size(); ++i) { - string delimiter = "|||", suffix; - int position = sentences[i].find(delimiter); + string suffix; + int position = sentences[i].find("|||"); if (position != sentences[i].npos) { suffix = sentences[i].substr(position); sentences[i] = sentences[i].substr(0, position); diff --git a/extractor/sampler.cc b/extractor/sampler.cc index f64a408c..d81956b5 100644 --- a/extractor/sampler.cc +++ b/extractor/sampler.cc @@ -16,6 +16,7 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location) const { vector<int> sample; int num_subpatterns; if (location.matchings == NULL) { + // Sample suffix array range. num_subpatterns = 1; int low = location.sa_low, high = location.sa_high; double step = max(1.0, (double) (high - low) / max_samples); @@ -23,6 +24,7 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location) const { sample.push_back(suffix_array->GetSuffix(Round(i))); } } else { + // Sample vector of occurrences. num_subpatterns = location.num_subpatterns; int num_matchings = location.matchings->size() / num_subpatterns; double step = max(1.0, (double) num_matchings / max_samples); diff --git a/extractor/sampler.h b/extractor/sampler.h index cda28b10..be4aa1bb 100644 --- a/extractor/sampler.h +++ b/extractor/sampler.h @@ -10,18 +10,23 @@ namespace extractor { class PhraseLocation; class SuffixArray; +/** + * Provides uniform sampling for a PhraseLocation. + */ class Sampler { public: Sampler(shared_ptr<SuffixArray> suffix_array, int max_samples); virtual ~Sampler(); + // Samples uniformly at most max_samples phrase occurrences. virtual PhraseLocation Sample(const PhraseLocation& location) const; protected: Sampler(); private: + // Round floating point number to the nearest integer. int Round(double x) const; shared_ptr<SuffixArray> suffix_array; diff --git a/extractor/scorer.h b/extractor/scorer.h index c31db0ca..af8a3b10 100644 --- a/extractor/scorer.h +++ b/extractor/scorer.h @@ -14,14 +14,19 @@ namespace features { class FeatureContext; } // namespace features +/** + * Computes the feature scores for a source-target phrase pair. + */ class Scorer { public: Scorer(const vector<shared_ptr<features::Feature> >& features); virtual ~Scorer(); + // Computes the feature score for the given context. virtual vector<double> Score(const features::FeatureContext& context) const; + // Returns the set of feature names used to score any context. virtual vector<string> GetFeatureNames() const; protected: diff --git a/extractor/suffix_array.h b/extractor/suffix_array.h index 7a4f1110..bf731d79 100644 --- a/extractor/suffix_array.h +++ b/extractor/suffix_array.h @@ -17,18 +17,26 @@ class PhraseLocation; class SuffixArray { public: + // Creates a suffix array from a data array. SuffixArray(shared_ptr<DataArray> data_array); virtual ~SuffixArray(); + // Returns the size of the suffix array. virtual int GetSize() const; + // Returns the data array on top of which the suffix array is constructed. virtual shared_ptr<DataArray> GetData() const; + // Constructs the longest-common-prefix array using the algorithm of Kasai et + // al. (2001). virtual vector<int> BuildLCPArray() const; + // Returns the i-th suffix. virtual int GetSuffix(int rank) const; + // Given the range in which a phrase is located and the next word, returns the + // range corresponding to the phrase extended with the next word. virtual PhraseLocation Lookup(int low, int high, const string& word, int offset) const; @@ -38,14 +46,23 @@ class SuffixArray { SuffixArray(); private: + // Constructs the suffix array using the algorithm of Larsson and Sadakane + // (1999). void BuildSuffixArray(); + // Bucket sort on the data array (used for initializing the construction of + // the suffix array.) void InitialBucketSort(vector<int>& groups); void TernaryQuicksort(int left, int right, int step, vector<int>& groups); + // Constructs the suffix array in log(n) steps by doubling the length of the + // suffixes at each step. void PrefixDoublingSort(vector<int>& groups); + // Given a [low, high) range in the suffix array in which all elements have + // the first offset-1 values the same, it returns the first position where the + // offset value is greater or equal to word_id. int LookupRangeStart(int low, int high, int word_id, int offset) const; shared_ptr<DataArray> data_array; diff --git a/extractor/target_phrase_extractor.cc b/extractor/target_phrase_extractor.cc index 9f8bc6e2..2b8a2e4a 100644 --- a/extractor/target_phrase_extractor.cc +++ b/extractor/target_phrase_extractor.cc @@ -43,11 +43,13 @@ vector<pair<Phrase, PhraseAlignment> > TargetPhraseExtractor::ExtractPhrases( int target_x_low = target_phrase_low, target_x_high = target_phrase_high; if (!require_tight_phrases) { + // Extend loose target phrase to the left. while (target_x_low > 0 && target_phrase_high - target_x_low < max_rule_span && target_low[target_x_low - 1] == -1) { --target_x_low; } + // Extend loose target phrase to the right. while (target_x_high < target_sent_len && target_x_high - target_phrase_low < max_rule_span && target_low[target_x_high] == -1) { @@ -59,10 +61,12 @@ vector<pair<Phrase, PhraseAlignment> > TargetPhraseExtractor::ExtractPhrases( for (size_t i = 0; i < gaps.size(); ++i) { gaps[i] = target_gaps[target_gap_order[i]]; if (!require_tight_phrases) { + // Extend gap to the left. while (gaps[i].first > target_x_low && target_low[gaps[i].first - 1] == -1) { --gaps[i].first; } + // Extend gap to the right. while (gaps[i].second < target_x_high && target_low[gaps[i].second] == -1) { ++gaps[i].second; @@ -70,6 +74,9 @@ vector<pair<Phrase, PhraseAlignment> > TargetPhraseExtractor::ExtractPhrases( } } + // Compute the range in which each chunk may start or end. (Even indexes + // represent the range in which the chunk may start, odd indexes represent the + // range in which the chunk may end.) vector<pair<int, int> > ranges(2 * gaps.size() + 2); ranges.front() = make_pair(target_x_low, target_phrase_low); ranges.back() = make_pair(target_phrase_high, target_x_high); @@ -101,6 +108,7 @@ void TargetPhraseExtractor::GeneratePhrases( vector<int> symbols; unordered_map<int, int> target_indexes; + // Construct target phrase chunk by chunk. int target_sent_start = target_data_array->GetSentenceStart(sentence_id); for (size_t i = 0; i * 2 < subpatterns.size(); ++i) { for (size_t j = subpatterns[i * 2]; j < subpatterns[i * 2 + 1]; ++j) { @@ -115,6 +123,7 @@ void TargetPhraseExtractor::GeneratePhrases( } } + // Construct the alignment between the source and the target phrase. vector<pair<int, int> > links = alignment->GetLinks(sentence_id); vector<pair<int, int> > alignment; for (pair<int, int> link: links) { @@ -133,6 +142,7 @@ void TargetPhraseExtractor::GeneratePhrases( if (index > 0) { subpatterns[index] = max(subpatterns[index], subpatterns[index - 1]); } + // Choose every possible combination of [start, end) for the current chunk. while (subpatterns[index] <= ranges[index].second) { subpatterns[index + 1] = max(subpatterns[index], ranges[index + 1].first); while (subpatterns[index + 1] <= ranges[index + 1].second) { diff --git a/extractor/target_phrase_extractor.h b/extractor/target_phrase_extractor.h index a4b54145..289bae2f 100644 --- a/extractor/target_phrase_extractor.h +++ b/extractor/target_phrase_extractor.h @@ -30,6 +30,8 @@ class TargetPhraseExtractor { virtual ~TargetPhraseExtractor(); + // Finds all the target phrases that can extracted from a span in the + // target sentence (matching the given set of target phrase gaps). virtual vector<pair<Phrase, PhraseAlignment> > ExtractPhrases( const vector<pair<int, int> >& target_gaps, const vector<int>& target_low, int target_phrase_low, int target_phrase_high, @@ -39,6 +41,8 @@ class TargetPhraseExtractor { TargetPhraseExtractor(); private: + // Computes the cartesian product over the sets of possible target phrase + // chunks. void GeneratePhrases( vector<pair<Phrase, PhraseAlignment> >& target_phrases, const vector<pair<int, int> >& ranges, int index, diff --git a/extractor/time_util.h b/extractor/time_util.h index 45f79199..f7fd51d3 100644 --- a/extractor/time_util.h +++ b/extractor/time_util.h @@ -10,6 +10,7 @@ namespace extractor { typedef high_resolution_clock Clock; +// Computes the duration in seconds of the specified time interval. double GetDuration(const Clock::time_point& start_time, const Clock::time_point& stop_time); diff --git a/extractor/translation_table.cc b/extractor/translation_table.cc index 1852a357..45da707a 100644 --- a/extractor/translation_table.cc +++ b/extractor/translation_table.cc @@ -23,6 +23,8 @@ TranslationTable::TranslationTable(shared_ptr<DataArray> source_data_array, unordered_map<int, int> target_links_count; unordered_map<pair<int, int>, int, PairHash> links_count; + // For each pair of aligned source target words increment their link count by + // 1. Unaligned words are paired with the NULL token. for (size_t i = 0; i < source_data_array->GetNumSentences(); ++i) { vector<pair<int, int> > links = alignment->GetLinks(i); int source_start = source_data_array->GetSentenceStart(i); @@ -40,25 +42,28 @@ TranslationTable::TranslationTable(shared_ptr<DataArray> source_data_array, for (pair<int, int> link: links) { source_linked_words[link.first] = 1; target_linked_words[link.second] = 1; - IncreaseLinksCount(source_links_count, target_links_count, links_count, + IncrementLinksCount(source_links_count, target_links_count, links_count, source_sentence[link.first], target_sentence[link.second]); } for (size_t i = 0; i < source_sentence.size(); ++i) { if (!source_linked_words[i]) { - IncreaseLinksCount(source_links_count, target_links_count, links_count, - source_sentence[i], DataArray::NULL_WORD); + IncrementLinksCount(source_links_count, target_links_count, links_count, + source_sentence[i], DataArray::NULL_WORD); } } for (size_t i = 0; i < target_sentence.size(); ++i) { if (!target_linked_words[i]) { - IncreaseLinksCount(source_links_count, target_links_count, links_count, - DataArray::NULL_WORD, target_sentence[i]); + IncrementLinksCount(source_links_count, target_links_count, links_count, + DataArray::NULL_WORD, target_sentence[i]); } } } + // Calculating: + // p(e | f) = count(e, f) / count(f) + // p(f | e) = count(e, f) / count(e) for (pair<pair<int, int>, int> link_count: links_count) { int source_word = link_count.first.first; int target_word = link_count.first.second; @@ -72,7 +77,7 @@ TranslationTable::TranslationTable() {} TranslationTable::~TranslationTable() {} -void TranslationTable::IncreaseLinksCount( +void TranslationTable::IncrementLinksCount( unordered_map<int, int>& source_links_count, unordered_map<int, int>& target_links_count, unordered_map<pair<int, int>, int, PairHash>& links_count, diff --git a/extractor/translation_table.h b/extractor/translation_table.h index a7be26f5..10504d3b 100644 --- a/extractor/translation_table.h +++ b/extractor/translation_table.h @@ -18,6 +18,9 @@ typedef boost::hash<pair<int, int> > PairHash; class Alignment; class DataArray; +/** + * Bilexical table with conditional probabilities. + */ class TranslationTable { public: TranslationTable( @@ -27,9 +30,11 @@ class TranslationTable { virtual ~TranslationTable(); + // Returns p(e | f). virtual double GetTargetGivenSourceScore(const string& source_word, const string& target_word); + // Returns p(f | e). virtual double GetSourceGivenTargetScore(const string& source_word, const string& target_word); @@ -39,7 +44,8 @@ class TranslationTable { TranslationTable(); private: - void IncreaseLinksCount( + // Increment links count for the given (f, e) word pair. + void IncrementLinksCount( unordered_map<int, int>& source_links_count, unordered_map<int, int>& target_links_count, unordered_map<pair<int, int>, int, PairHash>& links_count, diff --git a/extractor/vocabulary.h b/extractor/vocabulary.h index 03c7dc66..c8fd9411 100644 --- a/extractor/vocabulary.h +++ b/extractor/vocabulary.h @@ -9,16 +9,33 @@ using namespace std; namespace extractor { +/** + * Data structure for mapping words to word ids. + * + * This strucure contains words located in the frequent collocations and words + * encountered during the grammar extraction time. This dictionary is + * considerably smaller than the dictionaries in the data arrays (and so is the + * query time). Note that this is the single data structure that changes state + * and needs to have thread safe read/write operations. + * + * Note: For an experiment using different vocabulary instances for each thread, + * the running time did not improve implying that the critical regions do not + * cause bottlenecks. + */ class Vocabulary { public: virtual ~Vocabulary(); + // Returns the word id for the given word. virtual int GetTerminalIndex(const string& word); + // Returns the id for a nonterminal located at the given position in a phrase. int GetNonterminalIndex(int position); + // Checks if a symbol is a nonterminal. bool IsTerminal(int symbol); + // Returns the word corresponding to the given word id. virtual string GetTerminalValue(int symbol); private: |