#include "target_phrase_extractor.h"

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "alignment.h"
#include "data_array.h"
#include "phrase.h"
#include "phrase_builder.h"
#include "rule_extractor_helper.h"
#include "vocabulary.h"

using namespace std;

namespace extractor {

// Stores the collaborators and extraction limits used by ExtractPhrases.
// The shared_ptr arguments are moved into the members to avoid redundant
// (atomic) reference-count updates.
TargetPhraseExtractor::TargetPhraseExtractor(
    shared_ptr<DataArray> target_data_array,
    shared_ptr<Alignment> alignment,
    shared_ptr<PhraseBuilder> phrase_builder,
    shared_ptr<RuleExtractorHelper> helper,
    shared_ptr<Vocabulary> vocabulary,
    int max_rule_span,
    bool require_tight_phrases) :
    target_data_array(std::move(target_data_array)),
    alignment(std::move(alignment)),
    phrase_builder(std::move(phrase_builder)),
    helper(std::move(helper)),
    vocabulary(std::move(vocabulary)),
    max_rule_span(max_rule_span),
    require_tight_phrases(require_tight_phrases) {}

// Default constructor (presumably only used for mocking — TODO confirm).
// The scalar settings are given defined values so that no member is left
// indeterminate; the shared_ptr members default to null.
TargetPhraseExtractor::TargetPhraseExtractor() :
    max_rule_span(0), require_tight_phrases(false) {}

TargetPhraseExtractor::~TargetPhraseExtractor() {}

// Enumerates the target phrases (with their alignments to the source phrase)
// that can be built around the target span [target_phrase_low,
// target_phrase_high) of sentence `sentence_id`, leaving out `target_gaps`.
//
// target_gaps:        [first, second) target spans replaced by nonterminals.
// target_low:         per-position values for the target sentence; positions
//                     holding -1 are the only ones over which loose phrases
//                     and gaps are extended (presumably unaligned words —
//                     confirm against the caller).
// source_indexes:     maps a source position to its symbol index in the
//                     source phrase; used to translate alignment links.
//
// When require_tight_phrases is false, the phrase boundaries and every gap are
// widened over consecutive -1 positions, and every combination of chunk
// boundaries inside the widened ranges is emitted (subject to max_rule_span).
vector<pair<Phrase, PhraseAlignment>> TargetPhraseExtractor::ExtractPhrases(
    const vector<pair<int, int>>& target_gaps, const vector<int>& target_low,
    int target_phrase_low, int target_phrase_high,
    const unordered_map<int, int>& source_indexes, int sentence_id) const {
  int target_sent_len = target_data_array->GetSentenceLength(sentence_id);

  vector<int> target_gap_order = helper->GetGapOrder(target_gaps);

  int target_x_low = target_phrase_low;
  int target_x_high = target_phrase_high;
  if (!require_tight_phrases) {
    // Extend loose target phrase to the left.
    while (target_x_low > 0 &&
           target_phrase_high - target_x_low < max_rule_span &&
           target_low[target_x_low - 1] == -1) {
      --target_x_low;
    }
    // Extend loose target phrase to the right.
    while (target_x_high < target_sent_len &&
           target_x_high - target_phrase_low < max_rule_span &&
           target_low[target_x_high] == -1) {
      ++target_x_high;
    }
  }

  // Gaps listed in the order given by target_gap_order, possibly widened.
  vector<pair<int, int>> gaps(target_gaps.size());
  for (size_t i = 0; i < gaps.size(); ++i) {
    gaps[i] = target_gaps[target_gap_order[i]];
    if (!require_tight_phrases) {
      // Extend gap to the left.
      while (gaps[i].first > target_x_low &&
             target_low[gaps[i].first - 1] == -1) {
        --gaps[i].first;
      }
      // Extend gap to the right.
      while (gaps[i].second < target_x_high &&
             target_low[gaps[i].second] == -1) {
        ++gaps[i].second;
      }
    }
  }

  // Compute the range in which each chunk may start or end. (Even indexes
  // represent the range in which the chunk may start, odd indexes represent
  // the range in which the chunk may end.)
  vector<pair<int, int>> ranges(2 * gaps.size() + 2);
  ranges.front() = make_pair(target_x_low, target_phrase_low);
  ranges.back() = make_pair(target_phrase_high, target_x_high);
  for (size_t i = 0; i < gaps.size(); ++i) {
    int j = target_gap_order[i];
    ranges[i * 2 + 1] = make_pair(gaps[i].first, target_gaps[j].first);
    ranges[i * 2 + 2] = make_pair(target_gaps[j].second, gaps[i].second);
  }

  vector<pair<Phrase, PhraseAlignment>> target_phrases;
  vector<int> subpatterns(ranges.size());
  GeneratePhrases(target_phrases, ranges, 0, subpatterns, target_gap_order,
                  target_phrase_low, target_phrase_high, source_indexes,
                  sentence_id);
  return target_phrases;
}

// Recursively chooses a [start, end) pair for each chunk and, once every
// boundary is fixed, appends the resulting phrase to `target_phrases`.
//
// subpatterns[2k] / subpatterns[2k + 1] hold the start / end of chunk k; each
// must lie within ranges[2k] / ranges[2k + 1] and boundaries are kept
// non-decreasing. `index` is the next boundary to choose (always even).
// Phrases whose total span exceeds max_rule_span are discarded.
void TargetPhraseExtractor::GeneratePhrases(
    vector<pair<Phrase, PhraseAlignment>>& target_phrases,
    const vector<pair<int, int>>& ranges, int index, vector<int>& subpatterns,
    const vector<int>& target_gap_order, int target_phrase_low,
    int target_phrase_high, const unordered_map<int, int>& source_indexes,
    int sentence_id) const {
  if (static_cast<size_t>(index) >= ranges.size()) {
    if (subpatterns.back() - subpatterns.front() > max_rule_span) {
      return;
    }

    vector<int> symbols;
    // Sentence-relative target position -> index in `symbols`.
    unordered_map<int, int> target_indexes;

    // Construct target phrase chunk by chunk: the terminals of chunk i,
    // followed (except after the last chunk) by the nonterminal standing for
    // the i-th gap in target order.
    int target_sent_start = target_data_array->GetSentenceStart(sentence_id);
    for (size_t i = 0; i * 2 < subpatterns.size(); ++i) {
      for (int j = subpatterns[i * 2]; j < subpatterns[i * 2 + 1]; ++j) {
        target_indexes[j] = static_cast<int>(symbols.size());
        string target_word = target_data_array->GetWordAtIndex(
            target_sent_start + j);
        symbols.push_back(vocabulary->GetTerminalIndex(target_word));
      }
      if (i < target_gap_order.size()) {
        symbols.push_back(vocabulary->GetNonterminalIndex(
            target_gap_order[i] + 1));
      }
    }

    // Construct the alignment between the source and the target phrase.
    // (Named phrase_alignment to avoid shadowing the `alignment` member.)
    const auto& links = alignment->GetLinks(sentence_id);
    vector<pair<int, int>> phrase_alignment;
    for (const auto& link : links) {
      auto target_it = target_indexes.find(link.second);
      if (target_it == target_indexes.end()) {
        continue;
      }
      // Guard against a source position missing from source_indexes;
      // dereferencing end() here would be undefined behavior.
      auto source_it = source_indexes.find(link.first);
      if (source_it == source_indexes.end()) {
        continue;
      }
      phrase_alignment.push_back(make_pair(source_it->second,
                                           target_it->second));
    }

    target_phrases.emplace_back(phrase_builder->Build(symbols),
                                std::move(phrase_alignment));
    return;
  }

  subpatterns[index] = ranges[index].first;
  if (index > 0) {
    subpatterns[index] = max(subpatterns[index], subpatterns[index - 1]);
  }
  // Choose every possible combination of [start, end) for the current chunk.
  while (subpatterns[index] <= ranges[index].second) {
    subpatterns[index + 1] = max(subpatterns[index], ranges[index + 1].first);
    while (subpatterns[index + 1] <= ranges[index + 1].second) {
      GeneratePhrases(target_phrases, ranges, index + 2, subpatterns,
                      target_gap_order, target_phrase_low, target_phrase_high,
                      source_indexes, sentence_id);
      ++subpatterns[index + 1];
    }
    ++subpatterns[index];
  }
}

} // namespace extractor