diff options
author | Chris Dyer <cdyer@allegro.clab.cs.cmu.edu> | 2013-04-23 19:35:18 -0400 |
---|---|---|
committer | Chris Dyer <cdyer@allegro.clab.cs.cmu.edu> | 2013-04-23 19:35:18 -0400 |
commit | 6d347f1ce078dede3da0e1498f75e357351c6543 (patch) | |
tree | 8e872b8747c530e741e55e25e9917c1bd8b32c5b /extractor/rule_extractor_helper.cc | |
parent | d11b76def6899790161c47a73018146311356d8b (diff) | |
parent | 5e9605b65202f4e5fc59843b197d88c4774f0ac8 (diff) |
merge paul's extractor code
Diffstat (limited to 'extractor/rule_extractor_helper.cc')
-rw-r--r-- | extractor/rule_extractor_helper.cc | 362 |
1 files changed, 362 insertions, 0 deletions
diff --git a/extractor/rule_extractor_helper.cc b/extractor/rule_extractor_helper.cc new file mode 100644 index 00000000..8a9516f2 --- /dev/null +++ b/extractor/rule_extractor_helper.cc @@ -0,0 +1,362 @@ +#include "rule_extractor_helper.h" + +#include "data_array.h" +#include "alignment.h" + +namespace extractor { + +RuleExtractorHelper::RuleExtractorHelper( + shared_ptr<DataArray> source_data_array, + shared_ptr<DataArray> target_data_array, + shared_ptr<Alignment> alignment, + int max_rule_span, + int max_rule_symbols, + bool require_aligned_terminal, + bool require_aligned_chunks, + bool require_tight_phrases) : + source_data_array(source_data_array), + target_data_array(target_data_array), + alignment(alignment), + max_rule_span(max_rule_span), + max_rule_symbols(max_rule_symbols), + require_aligned_terminal(require_aligned_terminal), + require_aligned_chunks(require_aligned_chunks), + require_tight_phrases(require_tight_phrases) {} + +RuleExtractorHelper::RuleExtractorHelper() {} + +RuleExtractorHelper::~RuleExtractorHelper() {} + +void RuleExtractorHelper::GetLinksSpans( + vector<int>& source_low, vector<int>& source_high, + vector<int>& target_low, vector<int>& target_high, int sentence_id) const { + int source_sent_len = source_data_array->GetSentenceLength(sentence_id); + int target_sent_len = target_data_array->GetSentenceLength(sentence_id); + source_low = vector<int>(source_sent_len, -1); + source_high = vector<int>(source_sent_len, -1); + + target_low = vector<int>(target_sent_len, -1); + target_high = vector<int>(target_sent_len, -1); + vector<pair<int, int> > links = alignment->GetLinks(sentence_id); + for (auto link: links) { + if (source_low[link.first] == -1 || source_low[link.first] > link.second) { + source_low[link.first] = link.second; + } + source_high[link.first] = max(source_high[link.first], link.second + 1); + + if (target_low[link.second] == -1 || target_low[link.second] > link.first) { + target_low[link.second] = link.first; + } + target_high[link.second] = max(target_high[link.second], link.first + 1); + } +} + +bool RuleExtractorHelper::CheckAlignedTerminals( + const vector<int>& matching, + const vector<int>& chunklen, + const vector<int>& source_low, + int source_sent_start) const { + if (!require_aligned_terminal) { + return true; + } + + int num_aligned_chunks = 0; + for (size_t i = 0; i < chunklen.size(); ++i) { + for (size_t j = 0; j < chunklen[i]; ++j) { + int sent_index = matching[i] - source_sent_start + j; + if (source_low[sent_index] != -1) { + ++num_aligned_chunks; + break; + } + } + } + + if (num_aligned_chunks == 0) { + return false; + } + + return !require_aligned_chunks || num_aligned_chunks == chunklen.size(); +} + +bool RuleExtractorHelper::CheckTightPhrases( + const vector<int>& matching, + const vector<int>& chunklen, + const vector<int>& source_low, + int source_sent_start) const { + if (!require_tight_phrases) { + return true; + } + + // Check if the chunk extremities are aligned. + for (size_t i = 0; i + 1 < chunklen.size(); ++i) { + int gap_start = matching[i] + chunklen[i] - source_sent_start; + int gap_end = matching[i + 1] - 1 - source_sent_start; + if (source_low[gap_start] == -1 || source_low[gap_end] == -1) { + return false; + } + } + + return true; +} + +bool RuleExtractorHelper::FindFixPoint( + int source_phrase_low, int source_phrase_high, + const vector<int>& source_low, const vector<int>& source_high, + int& target_phrase_low, int& target_phrase_high, + const vector<int>& target_low, const vector<int>& target_high, + int& source_back_low, int& source_back_high, int sentence_id, + int min_source_gap_size, int min_target_gap_size, + int max_new_x, bool allow_low_x, bool allow_high_x, + bool allow_arbitrary_expansion) const { + int prev_target_low = target_phrase_low; + int prev_target_high = target_phrase_high; + + FindProjection(source_phrase_low, source_phrase_high, source_low, + source_high, target_phrase_low, target_phrase_high); + + if (target_phrase_low == -1) { + // Note: Low priority corner case inherited from Adam's code: + // If w is unaligned, but we don't require aligned terminals, returning an + // error here prevents the extraction of the allowed rule + // X -> X_1 w X_2 / X_1 X_2 + return false; + } + + int source_sent_len = source_data_array->GetSentenceLength(sentence_id); + int target_sent_len = target_data_array->GetSentenceLength(sentence_id); + // Extend the target span to the left. + if (prev_target_low != -1 && target_phrase_low != prev_target_low) { + if (prev_target_low - target_phrase_low < min_target_gap_size) { + target_phrase_low = prev_target_low - min_target_gap_size; + if (target_phrase_low < 0) { + return false; + } + } + } + + // Extend the target span to the right. + if (prev_target_high != -1 && target_phrase_high != prev_target_high) { + if (target_phrase_high - prev_target_high < min_target_gap_size) { + target_phrase_high = prev_target_high + min_target_gap_size; + if (target_phrase_high > target_sent_len) { + return false; + } + } + } + + // Check target span length. + if (target_phrase_high - target_phrase_low > max_rule_span) { + return false; + } + + // Find the initial reflected source span. + source_back_low = source_back_high = -1; + FindProjection(target_phrase_low, target_phrase_high, target_low, target_high, + source_back_low, source_back_high); + int new_x = 0; + bool new_low_x = false, new_high_x = false; + while (true) { + source_back_low = min(source_back_low, source_phrase_low); + source_back_high = max(source_back_high, source_phrase_high); + + // Stop if the reflected source span matches the previous source span. + if (source_back_low == source_phrase_low && + source_back_high == source_phrase_high) { + return true; + } + + if (!allow_low_x && source_back_low < source_phrase_low) { + // Extension on the left side not allowed. + return false; + } + if (!allow_high_x && source_back_high > source_phrase_high) { + // Extension on the right side not allowed. + return false; + } + + // Extend left side. + if (source_back_low < source_phrase_low) { + if (new_low_x == false) { + if (new_x >= max_new_x) { + return false; + } + new_low_x = true; + ++new_x; + } + if (source_phrase_low - source_back_low < min_source_gap_size) { + source_back_low = source_phrase_low - min_source_gap_size; + if (source_back_low < 0) { + return false; + } + } + } + + // Extend right side. + if (source_back_high > source_phrase_high) { + if (new_high_x == false) { + if (new_x >= max_new_x) { + return false; + } + new_high_x = true; + ++new_x; + } + if (source_back_high - source_phrase_high < min_source_gap_size) { + source_back_high = source_phrase_high + min_source_gap_size; + if (source_back_high > source_sent_len) { + return false; + } + } + } + + if (source_back_high - source_back_low > max_rule_span) { + // Rule span too wide. + return false; + } + + prev_target_low = target_phrase_low; + prev_target_high = target_phrase_high; + // Find the reflection including the left gap (if one was added). + FindProjection(source_back_low, source_phrase_low, source_low, source_high, + target_phrase_low, target_phrase_high); + // Find the reflection including the right gap (if one was added). + FindProjection(source_phrase_high, source_back_high, source_low, + source_high, target_phrase_low, target_phrase_high); + // Stop if the new re-reflected target span matches the previous target + // span. + if (prev_target_low == target_phrase_low && + prev_target_high == target_phrase_high) { + return true; + } + + if (!allow_arbitrary_expansion) { + // Arbitrary expansion not allowed. + return false; + } + if (target_phrase_high - target_phrase_low > max_rule_span) { + // Target side too wide. + return false; + } + + source_phrase_low = source_back_low; + source_phrase_high = source_back_high; + // Re-reflect the target span. + FindProjection(target_phrase_low, prev_target_low, target_low, target_high, + source_back_low, source_back_high); + FindProjection(prev_target_high, target_phrase_high, target_low, + target_high, source_back_low, source_back_high); + } + + return false; +} + +void RuleExtractorHelper::FindProjection( + int source_phrase_low, int source_phrase_high, + const vector<int>& source_low, const vector<int>& source_high, + int& target_phrase_low, int& target_phrase_high) const { + for (size_t i = source_phrase_low; i < source_phrase_high; ++i) { + if (source_low[i] != -1) { + if (target_phrase_low == -1 || source_low[i] < target_phrase_low) { + target_phrase_low = source_low[i]; + } + target_phrase_high = max(target_phrase_high, source_high[i]); + } + } +} + +bool RuleExtractorHelper::GetGaps( + vector<pair<int, int> >& source_gaps, vector<pair<int, int> >& target_gaps, + const vector<int>& matching, const vector<int>& chunklen, + const vector<int>& source_low, const vector<int>& source_high, + const vector<int>& target_low, const vector<int>& target_high, + int source_phrase_low, int source_phrase_high, int source_back_low, + int source_back_high, int sentence_id, int source_sent_start, + int& num_symbols, bool& met_constraints) const { + if (source_back_low < source_phrase_low) { + source_gaps.push_back(make_pair(source_back_low, source_phrase_low)); + if (num_symbols >= max_rule_symbols) { + // Source side contains too many symbols. + return false; + } + ++num_symbols; + if (require_tight_phrases && (source_low[source_back_low] == -1 || + source_low[source_phrase_low - 1] == -1)) { + // Inside edges of preceding gap are not tight. + return false; + } + } else if (require_tight_phrases && source_low[source_phrase_low] == -1) { + // This is not a hard error. We can't extract this phrase, but we might + // still be able to extract a superphrase. + met_constraints = false; + } + + for (size_t i = 0; i + 1 < chunklen.size(); ++i) { + int gap_start = matching[i] + chunklen[i] - source_sent_start; + int gap_end = matching[i + 1] - source_sent_start; + source_gaps.push_back(make_pair(gap_start, gap_end)); + } + + if (source_phrase_high < source_back_high) { + source_gaps.push_back(make_pair(source_phrase_high, source_back_high)); + if (num_symbols >= max_rule_symbols) { + // Source side contains too many symbols. + return false; + } + ++num_symbols; + if (require_tight_phrases && (source_low[source_phrase_high] == -1 || + source_low[source_back_high - 1] == -1)) { + // Inside edges of following gap are not tight. + return false; + } + } else if (require_tight_phrases && + source_low[source_phrase_high - 1] == -1) { + // This is not a hard error. We can't extract this phrase, but we might + // still be able to extract a superphrase. + met_constraints = false; + } + + target_gaps.resize(source_gaps.size(), make_pair(-1, -1)); + for (size_t i = 0; i < source_gaps.size(); ++i) { + if (!FindFixPoint(source_gaps[i].first, source_gaps[i].second, source_low, + source_high, target_gaps[i].first, target_gaps[i].second, + target_low, target_high, source_gaps[i].first, + source_gaps[i].second, sentence_id, 0, 0, 0, false, false, + false)) { + // Gap fails integrity check. + return false; + } + } + + return true; +} + +vector<int> RuleExtractorHelper::GetGapOrder( + const vector<pair<int, int> >& gaps) const { + vector<int> gap_order(gaps.size()); + for (size_t i = 0; i < gap_order.size(); ++i) { + for (size_t j = 0; j < i; ++j) { + if (gaps[gap_order[j]] < gaps[i]) { + ++gap_order[i]; + } else { + ++gap_order[j]; + } + } + } + return gap_order; +} + +unordered_map<int, int> RuleExtractorHelper::GetSourceIndexes( + const vector<int>& matching, const vector<int>& chunklen, + int starts_with_x, int source_sent_start) const { + unordered_map<int, int> source_indexes; + int num_symbols = starts_with_x; + for (size_t i = 0; i < matching.size(); ++i) { + for (size_t j = 0; j < chunklen[i]; ++j) { + source_indexes[matching[i] + j - source_sent_start] = num_symbols; + ++num_symbols; + } + ++num_symbols; + } + return source_indexes; +} + +} // namespace extractor |