diff options
author | Chris Dyer <cdyer@allegro.clab.cs.cmu.edu> | 2013-04-23 19:35:18 -0400 |
---|---|---|
committer | Chris Dyer <cdyer@allegro.clab.cs.cmu.edu> | 2013-04-23 19:35:18 -0400 |
commit | 6d347f1ce078dede3da0e1498f75e357351c6543 (patch) | |
tree | 8e872b8747c530e741e55e25e9917c1bd8b32c5b /extractor/rule_extractor.cc | |
parent | d11b76def6899790161c47a73018146311356d8b (diff) | |
parent | 5e9605b65202f4e5fc59843b197d88c4774f0ac8 (diff) |
merge paul's extractor code
Diffstat (limited to 'extractor/rule_extractor.cc')
-rw-r--r-- | extractor/rule_extractor.cc | 343 |
1 files changed, 343 insertions, 0 deletions
diff --git a/extractor/rule_extractor.cc b/extractor/rule_extractor.cc new file mode 100644 index 00000000..fa7386a4 --- /dev/null +++ b/extractor/rule_extractor.cc @@ -0,0 +1,343 @@ +#include "rule_extractor.h" + +#include <map> + +#include "alignment.h" +#include "data_array.h" +#include "features/feature.h" +#include "phrase_builder.h" +#include "phrase_location.h" +#include "rule.h" +#include "rule_extractor_helper.h" +#include "scorer.h" +#include "target_phrase_extractor.h" + +using namespace std; + +namespace extractor { + +RuleExtractor::RuleExtractor( + shared_ptr<DataArray> source_data_array, + shared_ptr<DataArray> target_data_array, + shared_ptr<Alignment> alignment, + shared_ptr<PhraseBuilder> phrase_builder, + shared_ptr<Scorer> scorer, + shared_ptr<Vocabulary> vocabulary, + int max_rule_span, + int min_gap_size, + int max_nonterminals, + int max_rule_symbols, + bool require_aligned_terminal, + bool require_aligned_chunks, + bool require_tight_phrases) : + target_data_array(target_data_array), + source_data_array(source_data_array), + phrase_builder(phrase_builder), + scorer(scorer), + max_rule_span(max_rule_span), + min_gap_size(min_gap_size), + max_nonterminals(max_nonterminals), + max_rule_symbols(max_rule_symbols), + require_tight_phrases(require_tight_phrases) { + helper = make_shared<RuleExtractorHelper>( + source_data_array, target_data_array, alignment, max_rule_span, + max_rule_symbols, require_aligned_terminal, require_aligned_chunks, + require_tight_phrases); + target_phrase_extractor = make_shared<TargetPhraseExtractor>( + target_data_array, alignment, phrase_builder, helper, vocabulary, + max_rule_span, require_tight_phrases); +} + +RuleExtractor::RuleExtractor( + shared_ptr<DataArray> source_data_array, + shared_ptr<PhraseBuilder> phrase_builder, + shared_ptr<Scorer> scorer, + shared_ptr<TargetPhraseExtractor> target_phrase_extractor, + shared_ptr<RuleExtractorHelper> helper, + int max_rule_span, + int min_gap_size, + int max_nonterminals, + int max_rule_symbols, + bool require_tight_phrases) : + source_data_array(source_data_array), + phrase_builder(phrase_builder), + scorer(scorer), + target_phrase_extractor(target_phrase_extractor), + helper(helper), + max_rule_span(max_rule_span), + min_gap_size(min_gap_size), + max_nonterminals(max_nonterminals), + max_rule_symbols(max_rule_symbols), + require_tight_phrases(require_tight_phrases) {} + +RuleExtractor::RuleExtractor() {} + +RuleExtractor::~RuleExtractor() {} + +vector<Rule> RuleExtractor::ExtractRules(const Phrase& phrase, + const PhraseLocation& location) const { + int num_subpatterns = location.num_subpatterns; + vector<int> matchings = *location.matchings; + + // Calculate statistics for the (sampled) occurrences of the source phrase. + map<Phrase, double> source_phrase_counter; + map<Phrase, map<Phrase, map<PhraseAlignment, int> > > alignments_counter; + for (auto i = matchings.begin(); i != matchings.end(); i += num_subpatterns) { + vector<int> matching(i, i + num_subpatterns); + vector<Extract> extracts = ExtractAlignments(phrase, matching); + + for (Extract e: extracts) { + source_phrase_counter[e.source_phrase] += e.pairs_count; + alignments_counter[e.source_phrase][e.target_phrase][e.alignment] += 1; + } + } + + // Compute the feature scores and find the most likely (frequent) alignment + // for each pair of source-target phrases. + int num_samples = matchings.size() / num_subpatterns; + vector<Rule> rules; + for (auto source_phrase_entry: alignments_counter) { + Phrase source_phrase = source_phrase_entry.first; + for (auto target_phrase_entry: source_phrase_entry.second) { + Phrase target_phrase = target_phrase_entry.first; + + int max_locations = 0, num_locations = 0; + PhraseAlignment most_frequent_alignment; + for (auto alignment_entry: target_phrase_entry.second) { + num_locations += alignment_entry.second; + if (alignment_entry.second > max_locations) { + most_frequent_alignment = alignment_entry.first; + max_locations = alignment_entry.second; + } + } + + features::FeatureContext context(source_phrase, target_phrase, + source_phrase_counter[source_phrase], num_locations, num_samples); + vector<double> scores = scorer->Score(context); + rules.push_back(Rule(source_phrase, target_phrase, scores, + most_frequent_alignment)); + } + } + return rules; +} + +vector<Extract> RuleExtractor::ExtractAlignments( + const Phrase& phrase, const vector<int>& matching) const { + vector<Extract> extracts; + int sentence_id = source_data_array->GetSentenceId(matching[0]); + int source_sent_start = source_data_array->GetSentenceStart(sentence_id); + + // Get the span in the opposite sentence for each word in the source-target + // sentece pair. + vector<int> source_low, source_high, target_low, target_high; + helper->GetLinksSpans(source_low, source_high, target_low, target_high, + sentence_id); + + int num_subpatterns = matching.size(); + vector<int> chunklen(num_subpatterns); + for (size_t i = 0; i < num_subpatterns; ++i) { + chunklen[i] = phrase.GetChunkLen(i); + } + + // Basic checks to see if we can extract phrase pairs for this occurrence. + if (!helper->CheckAlignedTerminals(matching, chunklen, source_low, + source_sent_start) || + !helper->CheckTightPhrases(matching, chunklen, source_low, + source_sent_start)) { + return extracts; + } + + int source_back_low = -1, source_back_high = -1; + int source_phrase_low = matching[0] - source_sent_start; + int source_phrase_high = matching.back() + chunklen.back() - + source_sent_start; + int target_phrase_low = -1, target_phrase_high = -1; + // Find target span and reflected source span for the source phrase. + if (!helper->FindFixPoint(source_phrase_low, source_phrase_high, source_low, + source_high, target_phrase_low, target_phrase_high, + target_low, target_high, source_back_low, + source_back_high, sentence_id, min_gap_size, 0, + max_nonterminals - matching.size() + 1, true, true, + false)) { + return extracts; + } + + // Get spans for nonterminal gaps. + bool met_constraints = true; + int num_symbols = phrase.GetNumSymbols(); + vector<pair<int, int> > source_gaps, target_gaps; + if (!helper->GetGaps(source_gaps, target_gaps, matching, chunklen, source_low, + source_high, target_low, target_high, source_phrase_low, + source_phrase_high, source_back_low, source_back_high, + sentence_id, source_sent_start, num_symbols, + met_constraints)) { + return extracts; + } + + // Find target phrases aligned with the initial source phrase. + bool starts_with_x = source_back_low != source_phrase_low; + bool ends_with_x = source_back_high != source_phrase_high; + Phrase source_phrase = phrase_builder->Extend( + phrase, starts_with_x, ends_with_x); + unordered_map<int, int> source_indexes = helper->GetSourceIndexes( + matching, chunklen, starts_with_x, source_sent_start); + if (met_constraints) { + AddExtracts(extracts, source_phrase, source_indexes, target_gaps, + target_low, target_phrase_low, target_phrase_high, sentence_id); + } + + if (source_gaps.size() >= max_nonterminals || + source_phrase.GetNumSymbols() >= max_rule_symbols || + source_back_high - source_back_low + min_gap_size > max_rule_span) { + // Cannot add any more nonterminals. + return extracts; + } + + // Extend the source phrase by adding a leading and/or trailing nonterminal + // and find target phrases aligned with the extended source phrase. + for (int i = 0; i < 2; ++i) { + for (int j = 1 - i; j < 2; ++j) { + AddNonterminalExtremities(extracts, matching, chunklen, source_phrase, + source_back_low, source_back_high, source_low, source_high, + target_low, target_high, target_gaps, sentence_id, source_sent_start, + starts_with_x, ends_with_x, i, j); + } + } + + return extracts; +} + +void RuleExtractor::AddExtracts( + vector<Extract>& extracts, const Phrase& source_phrase, + const unordered_map<int, int>& source_indexes, + const vector<pair<int, int> >& target_gaps, const vector<int>& target_low, + int target_phrase_low, int target_phrase_high, int sentence_id) const { + auto target_phrases = target_phrase_extractor->ExtractPhrases( + target_gaps, target_low, target_phrase_low, target_phrase_high, + source_indexes, sentence_id); + + if (target_phrases.size() > 0) { + // Split the probability equally across all target phrases that can be + // aligned with a single occurrence of the source phrase. + double pairs_count = 1.0 / target_phrases.size(); + for (auto target_phrase: target_phrases) { + extracts.push_back(Extract(source_phrase, target_phrase.first, + pairs_count, target_phrase.second)); + } + } +} + +void RuleExtractor::AddNonterminalExtremities( + vector<Extract>& extracts, const vector<int>& matching, + const vector<int>& chunklen, const Phrase& source_phrase, + int source_back_low, int source_back_high, const vector<int>& source_low, + const vector<int>& source_high, const vector<int>& target_low, + const vector<int>& target_high, vector<pair<int, int> > target_gaps, + int sentence_id, int source_sent_start, int starts_with_x, int ends_with_x, + int extend_left, int extend_right) const { + int source_x_low = source_back_low, source_x_high = source_back_high; + + // Check if the extended source phrase will remain tight. + if (require_tight_phrases) { + if (source_low[source_back_low - extend_left] == -1 || + source_low[source_back_high + extend_right - 1] == -1) { + return; + } + } + + // Check if we can add a nonterminal to the left. + if (extend_left) { + if (starts_with_x || source_back_low < min_gap_size) { + return; + } + + source_x_low = source_back_low - min_gap_size; + if (require_tight_phrases) { + while (source_x_low >= 0 && source_low[source_x_low] == -1) { + --source_x_low; + } + } + if (source_x_low < 0) { + return; + } + } + + // Check if we can add a nonterminal to the right. + if (extend_right) { + int source_sent_len = source_data_array->GetSentenceLength(sentence_id); + if (ends_with_x || source_back_high + min_gap_size > source_sent_len) { + return; + } + source_x_high = source_back_high + min_gap_size; + if (require_tight_phrases) { + while (source_x_high <= source_sent_len && + source_low[source_x_high - 1] == -1) { + ++source_x_high; + } + } + + if (source_x_high > source_sent_len) { + return; + } + } + + // More length checks. + int new_nonterminals = extend_left + extend_right; + if (source_x_high - source_x_low > max_rule_span || + target_gaps.size() + new_nonterminals > max_nonterminals || + source_phrase.GetNumSymbols() + new_nonterminals > max_rule_symbols) { + return; + } + + // Find the target span for the extended phrase and the reflected source span. + int target_x_low = -1, target_x_high = -1; + if (!helper->FindFixPoint(source_x_low, source_x_high, source_low, + source_high, target_x_low, target_x_high, + target_low, target_high, source_x_low, + source_x_high, sentence_id, 1, 1, + new_nonterminals, extend_left, extend_right, + true)) { + return; + } + + // Check gap integrity for the leading nonterminal. + if (extend_left) { + int source_gap_low = -1, source_gap_high = -1; + int target_gap_low = -1, target_gap_high = -1; + if ((require_tight_phrases && source_low[source_x_low] == -1) || + !helper->FindFixPoint(source_x_low, source_back_low, source_low, + source_high, target_gap_low, target_gap_high, + target_low, target_high, source_gap_low, + source_gap_high, sentence_id, 0, 0, 0, false, + false, false)) { + return; + } + target_gaps.insert(target_gaps.begin(), + make_pair(target_gap_low, target_gap_high)); + } + + // Check gap integrity for the trailing nonterminal. + if (extend_right) { + int target_gap_low = -1, target_gap_high = -1; + int source_gap_low = -1, source_gap_high = -1; + if ((require_tight_phrases && source_low[source_x_high - 1] == -1) || + !helper->FindFixPoint(source_back_high, source_x_high, source_low, + source_high, target_gap_low, target_gap_high, + target_low, target_high, source_gap_low, + source_gap_high, sentence_id, 0, 0, 0, false, + false, false)) { + return; + } + target_gaps.push_back(make_pair(target_gap_low, target_gap_high)); + } + + // Find target phrases aligned with the extended source phrase. + Phrase new_source_phrase = phrase_builder->Extend(source_phrase, extend_left, + extend_right); + unordered_map<int, int> source_indexes = helper->GetSourceIndexes( + matching, chunklen, extend_left || starts_with_x, source_sent_start); + AddExtracts(extracts, new_source_phrase, source_indexes, target_gaps, + target_low, target_x_low, target_x_high, sentence_id); +} + +} // namespace extractor |