#include "rule_extractor_helper.h"

#include <utility>

#include "data_array.h"
#include "alignment.h"

namespace extractor {

// Stores the corpus/alignment handles and the extraction constraints used by
// the checks below. The shared_ptr parameters are moved into the members to
// avoid an extra reference-count round trip.
RuleExtractorHelper::RuleExtractorHelper(
    shared_ptr<DataArray> source_data_array,
    shared_ptr<DataArray> target_data_array,
    shared_ptr<Alignment> alignment,
    int max_rule_span,
    int max_rule_symbols,
    bool require_aligned_terminal,
    bool require_aligned_chunks,
    bool require_tight_phrases) :
    source_data_array(std::move(source_data_array)),
    target_data_array(std::move(target_data_array)),
    alignment(std::move(alignment)),
    max_rule_span(max_rule_span),
    max_rule_symbols(max_rule_symbols),
    require_aligned_terminal(require_aligned_terminal),
    require_aligned_chunks(require_aligned_chunks),
    require_tight_phrases(require_tight_phrases) {}

// Constructs a helper with no data attached (presumably for mocks/tests --
// TODO confirm). None of the members are set here.
RuleExtractorHelper::RuleExtractorHelper() {}

RuleExtractorHelper::~RuleExtractorHelper() {}

// For every word of sentence `sentence_id`, records the half-open span of
// words it is aligned to on the other side:
//   source_low[s]  = lowest target index linked to source word s,
//   source_high[s] = one past the highest target index linked to s,
// and symmetrically target_low/target_high for target words. Words without
// any alignment link keep -1 in both their low and high entries. All four
// output vectors are overwritten and resized to the sentence lengths.
void RuleExtractorHelper::GetLinksSpans(
    vector<int>& source_low, vector<int>& source_high,
    vector<int>& target_low, vector<int>& target_high, int sentence_id) const {
  int source_sent_len = source_data_array->GetSentenceLength(sentence_id);
  int target_sent_len = target_data_array->GetSentenceLength(sentence_id);
  source_low.assign(source_sent_len, -1);
  source_high.assign(source_sent_len, -1);
  target_low.assign(target_sent_len, -1);
  target_high.assign(target_sent_len, -1);

  vector<pair<int, int>> links = alignment->GetLinks(sentence_id);
  for (const auto& link : links) {
    int source_word = link.first;
    int target_word = link.second;

    if (source_low[source_word] == -1 || source_low[source_word] > target_word) {
      source_low[source_word] = target_word;
    }
    source_high[source_word] = max(source_high[source_word], target_word + 1);

    if (target_low[target_word] == -1 || target_low[target_word] > source_word) {
      target_low[target_word] = source_word;
    }
    target_high[target_word] = max(target_high[target_word], source_word + 1);
  }
}

// Checks the aligned-terminal constraints for a matched source pattern.
// Chunk i starts at sentence-relative position matching[i] - source_sent_start
// and is chunklen[i] words long; a chunk counts as aligned when at least one of
// its words has an alignment link (source_low != -1).
// Returns true if require_aligned_terminal is off. Otherwise at least one
// chunk must be aligned, and, if require_aligned_chunks is set, every chunk.
bool RuleExtractorHelper::CheckAlignedTerminals(
    const vector<int>& matching,
    const vector<int>& chunklen,
    const vector<int>& source_low,
    int source_sent_start) const {
  if (!require_aligned_terminal) {
    return true;
  }

  int num_aligned_chunks = 0;
  for (size_t i = 0; i < chunklen.size(); ++i) {
    int chunk_start = matching[i] - source_sent_start;
    for (int j = 0; j < chunklen[i]; ++j) {
      if (source_low[chunk_start + j] != -1) {
        ++num_aligned_chunks;
        break;
      }
    }
  }

  if (num_aligned_chunks == 0) {
    return false;
  }
  return !require_aligned_chunks ||
         num_aligned_chunks == static_cast<int>(chunklen.size());
}

// Enforces require_tight_phrases on the holes between consecutive chunks: the
// first and the last word of every such gap must be aligned. Returns true when
// the constraint is disabled or satisfied.
bool RuleExtractorHelper::CheckTightPhrases(
    const vector<int>& matching,
    const vector<int>& chunklen,
    const vector<int>& source_low,
    int source_sent_start) const {
  if (!require_tight_phrases) {
    return true;
  }

  // Check if the chunk extremities are aligned.
  for (size_t i = 0; i + 1 < chunklen.size(); ++i) {
    int gap_start = matching[i] + chunklen[i] - source_sent_start;
    int gap_end = matching[i + 1] - 1 - source_sent_start;
    if (source_low[gap_start] == -1 || source_low[gap_end] == -1) {
      return false;
    }
  }

  return true;
}

// Grows a phrase pair until it is consistent with the word alignment.
//
// Starting from the source span [source_phrase_low, source_phrase_high), the
// span is projected onto the target side, the target span is reflected back
// onto the source side, and the two steps repeat until either side stops
// changing. On success the target span is left in
// [target_phrase_low, target_phrase_high) and the source span in
// [source_back_low, source_back_high). All spans are half-open.
//
// target_phrase_low/high are in-out: on entry they may hold a previous target
// span (-1 when there is none). When the projection moves an edge of that span,
// the moved edge is pushed at least min_target_gap_size words away from the old
// edge; source-side extensions are likewise widened to min_source_gap_size.
// max_new_x bounds how many sides (left and/or right) of the source span may be
// extended, allow_low_x / allow_high_x gate extension on each side, and
// allow_arbitrary_expansion permits more than one reflection round.
//
// Returns false when the initial projection leaves target_phrase_low at -1,
// when a span leaves the sentence, exceeds max_rule_span, or when one of the
// expansion limits above is violated.
bool RuleExtractorHelper::FindFixPoint(
    int source_phrase_low, int source_phrase_high,
    const vector<int>& source_low, const vector<int>& source_high,
    int& target_phrase_low, int& target_phrase_high,
    const vector<int>& target_low, const vector<int>& target_high,
    int& source_back_low, int& source_back_high, int sentence_id,
    int min_source_gap_size, int min_target_gap_size,
    int max_new_x, bool allow_low_x, bool allow_high_x,
    bool allow_arbitrary_expansion) const {
  int prev_target_low = target_phrase_low;
  int prev_target_high = target_phrase_high;

  FindProjection(source_phrase_low, source_phrase_high, source_low,
                 source_high, target_phrase_low, target_phrase_high);

  if (target_phrase_low == -1) {
    // Note: Low priority corner case inherited from Adam's code:
    // If w is unaligned, but we don't require aligned terminals, returning an
    // error here prevents the extraction of the allowed rule
    // X -> X_1 w X_2 / X_1 X_2
    return false;
  }

  int source_sent_len = source_data_array->GetSentenceLength(sentence_id);
  int target_sent_len = target_data_array->GetSentenceLength(sentence_id);

  // Extend the target span to the left.
  if (prev_target_low != -1 && target_phrase_low != prev_target_low) {
    if (prev_target_low - target_phrase_low < min_target_gap_size) {
      target_phrase_low = prev_target_low - min_target_gap_size;
      if (target_phrase_low < 0) {
        return false;
      }
    }
  }

  // Extend the target span to the right.
  if (prev_target_high != -1 && target_phrase_high != prev_target_high) {
    if (target_phrase_high - prev_target_high < min_target_gap_size) {
      target_phrase_high = prev_target_high + min_target_gap_size;
      if (target_phrase_high > target_sent_len) {
        return false;
      }
    }
  }

  // Check target span length.
  if (target_phrase_high - target_phrase_low > max_rule_span) {
    return false;
  }

  // Find the initial reflected source span.
  source_back_low = source_back_high = -1;
  FindProjection(target_phrase_low, target_phrase_high, target_low,
                 target_high, source_back_low, source_back_high);

  // Number of sides on which the source span has been extended so far.
  int new_x = 0;
  bool new_low_x = false, new_high_x = false;
  while (true) {
    source_back_low = min(source_back_low, source_phrase_low);
    source_back_high = max(source_back_high, source_phrase_high);

    // Stop if the reflected source span matches the previous source span.
    if (source_back_low == source_phrase_low &&
        source_back_high == source_phrase_high) {
      return true;
    }

    if (!allow_low_x && source_back_low < source_phrase_low) {
      // Extension on the left side not allowed.
      return false;
    }
    if (!allow_high_x && source_back_high > source_phrase_high) {
      // Extension on the right side not allowed.
      return false;
    }

    // Extend left side.
    if (source_back_low < source_phrase_low) {
      if (new_low_x == false) {
        if (new_x >= max_new_x) {
          return false;
        }
        new_low_x = true;
        ++new_x;
      }
      if (source_phrase_low - source_back_low < min_source_gap_size) {
        source_back_low = source_phrase_low - min_source_gap_size;
        if (source_back_low < 0) {
          return false;
        }
      }
    }

    // Extend right side.
    if (source_back_high > source_phrase_high) {
      if (new_high_x == false) {
        if (new_x >= max_new_x) {
          return false;
        }
        new_high_x = true;
        ++new_x;
      }
      if (source_back_high - source_phrase_high < min_source_gap_size) {
        source_back_high = source_phrase_high + min_source_gap_size;
        if (source_back_high > source_sent_len) {
          return false;
        }
      }
    }

    if (source_back_high - source_back_low > max_rule_span) {
      // Rule span too wide.
      return false;
    }

    prev_target_low = target_phrase_low;
    prev_target_high = target_phrase_high;
    // Find the reflection including the left gap (if one was added).
    FindProjection(source_back_low, source_phrase_low, source_low, source_high,
                   target_phrase_low, target_phrase_high);
    // Find the reflection including the right gap (if one was added).
    FindProjection(source_phrase_high, source_back_high, source_low,
                   source_high, target_phrase_low, target_phrase_high);
    // Stop if the new re-reflected target span matches the previous target
    // span.
    if (prev_target_low == target_phrase_low &&
        prev_target_high == target_phrase_high) {
      return true;
    }

    if (!allow_arbitrary_expansion) {
      // Arbitrary expansion not allowed.
      return false;
    }
    if (target_phrase_high - target_phrase_low > max_rule_span) {
      // Target side too wide.
      return false;
    }

    source_phrase_low = source_back_low;
    source_phrase_high = source_back_high;
    // Re-reflect only the newly added parts of the target span.
    FindProjection(target_phrase_low, prev_target_low, target_low, target_high,
                   source_back_low, source_back_high);
    FindProjection(prev_target_high, target_phrase_high, target_low,
                   target_high, source_back_low, source_back_high);
  }

  return false;
}

// Widens [target_phrase_low, target_phrase_high) so that it covers the
// alignment spans of every aligned word in [source_phrase_low,
// source_phrase_high). Words whose source_low entry is -1 are skipped. A
// target_phrase_low of -1 is treated as "no span yet". The function is used in
// both directions (source->target and target->source); the names only reflect
// the first use.
void RuleExtractorHelper::FindProjection(
    int source_phrase_low, int source_phrase_high,
    const vector<int>& source_low, const vector<int>& source_high,
    int& target_phrase_low, int& target_phrase_high) const {
  for (size_t i = source_phrase_low; i < source_phrase_high; ++i) {
    if (source_low[i] != -1) {
      if (target_phrase_low == -1 || source_low[i] < target_phrase_low) {
        target_phrase_low = source_low[i];
      }
      target_phrase_high = max(target_phrase_high, source_high[i]);
    }
  }
}

// Builds the rule's source gaps and their target counterparts.
//
// The source gaps are, in order:
//   * [source_back_low, source_phrase_low) when the phrase was extended on
//     the left,
//   * the hole between every pair of consecutive chunks,
//   * [source_phrase_high, source_back_high) when the phrase was extended on
//     the right.
// Each extension gap adds one to num_symbols, which may not exceed
// max_rule_symbols. Every gap is then projected with FindFixPoint with all
// expansion disabled, and the resulting target span is stored in target_gaps
// at the same index.
//
// Returns false on a hard failure: too many symbols, a non-tight extension gap
// under require_tight_phrases, or a gap that does not project consistently.
// When the phrase itself has an unaligned edge under require_tight_phrases,
// met_constraints is set to false but extraction continues, so that larger
// phrases can still be considered.
bool RuleExtractorHelper::GetGaps(
    vector<pair<int, int>>& source_gaps, vector<pair<int, int>>& target_gaps,
    const vector<int>& matching, const vector<int>& chunklen,
    const vector<int>& source_low, const vector<int>& source_high,
    const vector<int>& target_low, const vector<int>& target_high,
    int source_phrase_low, int source_phrase_high, int source_back_low,
    int source_back_high, int sentence_id, int source_sent_start,
    int& num_symbols, bool& met_constraints) const {
  if (source_back_low < source_phrase_low) {
    source_gaps.push_back(make_pair(source_back_low, source_phrase_low));
    if (num_symbols >= max_rule_symbols) {
      // Source side contains too many symbols.
      return false;
    }
    ++num_symbols;
    if (require_tight_phrases && (source_low[source_back_low] == -1 ||
        source_low[source_phrase_low - 1] == -1)) {
      // Inside edges of preceding gap are not tight.
      return false;
    }
  } else if (require_tight_phrases && source_low[source_phrase_low] == -1) {
    // This is not a hard error. We can't extract this phrase, but we might
    // still be able to extract a superphrase.
    met_constraints = false;
  }

  // Holes between consecutive chunks, in sentence-relative coordinates.
  for (size_t i = 0; i + 1 < chunklen.size(); ++i) {
    int gap_start = matching[i] + chunklen[i] - source_sent_start;
    int gap_end = matching[i + 1] - source_sent_start;
    source_gaps.push_back(make_pair(gap_start, gap_end));
  }

  if (source_phrase_high < source_back_high) {
    source_gaps.push_back(make_pair(source_phrase_high, source_back_high));
    if (num_symbols >= max_rule_symbols) {
      // Source side contains too many symbols.
      return false;
    }
    ++num_symbols;
    if (require_tight_phrases && (source_low[source_phrase_high] == -1 ||
        source_low[source_back_high - 1] == -1)) {
      // Inside edges of following gap are not tight.
      return false;
    }
  } else if (require_tight_phrases &&
             source_low[source_phrase_high - 1] == -1) {
    // This is not a hard error. We can't extract this phrase, but we might
    // still be able to extract a superphrase.
    met_constraints = false;
  }

  target_gaps.resize(source_gaps.size(), make_pair(-1, -1));
  for (size_t i = 0; i < source_gaps.size(); ++i) {
    // The source gap is both the input span (by value) and the reflected span
    // (by reference); with every expansion disabled it must map onto itself.
    if (!FindFixPoint(source_gaps[i].first, source_gaps[i].second, source_low,
                      source_high, target_gaps[i].first, target_gaps[i].second,
                      target_low, target_high, source_gaps[i].first,
                      source_gaps[i].second, sentence_id, 0, 0, 0, false,
                      false, false)) {
      // Gap fails integrity check.
      return false;
    }
  }

  return true;
}

// Returns the rank of every gap under pair ordering (start first, then end).
// For distinct gaps, gap_order[i] is the number of gaps that compare less than
// gaps[i], so the result is a permutation of 0..gaps.size()-1.
// Every pair (j, i) with j < i is compared exactly once, and the rank of the
// larger gap of the two is incremented.
// NOTE(review): this is the rank of each gap, not the inverse mapping
// (sorted position -> gap index). The two coincide for at most two gaps;
// confirm what callers expect when more than two gaps are possible.
vector<int> RuleExtractorHelper::GetGapOrder(
    const vector<pair<int, int>>& gaps) const {
  vector<int> gap_order(gaps.size(), 0);
  for (size_t i = 0; i < gap_order.size(); ++i) {
    for (size_t j = 0; j < i; ++j) {
      // Compare the gaps themselves; indexing by the partially computed rank
      // (gaps[gap_order[j]]) produced duplicate ranks for three or more gaps.
      if (gaps[j] < gaps[i]) {
        ++gap_order[i];
      } else {
        ++gap_order[j];
      }
    }
  }
  return gap_order;
}

// Maps every sentence-relative source position covered by a chunk to its
// index among the rule's source symbols. Counting starts at starts_with_x
// (presumably 1 when the rule begins with a nonterminal -- TODO confirm), and
// one index is skipped after each chunk, presumably reserved for the
// nonterminal that follows it.
unordered_map<int, int> RuleExtractorHelper::GetSourceIndexes(
    const vector<int>& matching, const vector<int>& chunklen,
    int starts_with_x, int source_sent_start) const {
  unordered_map<int, int> source_indexes;
  int num_symbols = starts_with_x;
  for (size_t i = 0; i < matching.size(); ++i) {
    int chunk_start = matching[i] - source_sent_start;
    for (int j = 0; j < chunklen[i]; ++j) {
      source_indexes[chunk_start + j] = num_symbols;
      ++num_symbols;
    }
    ++num_symbols;
  }
  return source_indexes;
}

}  // namespace extractor