summaryrefslogtreecommitdiff
path: root/extractor/rule_extractor_helper.cc
diff options
context:
space:
mode:
authorChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2013-04-23 19:35:18 -0400
committerChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2013-04-23 19:35:18 -0400
commit6d347f1ce078dede3da0e1498f75e357351c6543 (patch)
tree8e872b8747c530e741e55e25e9917c1bd8b32c5b /extractor/rule_extractor_helper.cc
parentd11b76def6899790161c47a73018146311356d8b (diff)
parent5e9605b65202f4e5fc59843b197d88c4774f0ac8 (diff)
merge paul's extractor code
Diffstat (limited to 'extractor/rule_extractor_helper.cc')
-rw-r--r--extractor/rule_extractor_helper.cc362
1 files changed, 362 insertions, 0 deletions
diff --git a/extractor/rule_extractor_helper.cc b/extractor/rule_extractor_helper.cc
new file mode 100644
index 00000000..8a9516f2
--- /dev/null
+++ b/extractor/rule_extractor_helper.cc
@@ -0,0 +1,362 @@
+#include "rule_extractor_helper.h"
+
+#include "data_array.h"
+#include "alignment.h"
+
+namespace extractor {
+
+RuleExtractorHelper::RuleExtractorHelper(
+ shared_ptr<DataArray> source_data_array,
+ shared_ptr<DataArray> target_data_array,
+ shared_ptr<Alignment> alignment,
+ int max_rule_span,
+ int max_rule_symbols,
+ bool require_aligned_terminal,
+ bool require_aligned_chunks,
+ bool require_tight_phrases) :
+ source_data_array(source_data_array),
+ target_data_array(target_data_array),
+ alignment(alignment),
+ max_rule_span(max_rule_span),
+ max_rule_symbols(max_rule_symbols),
+ require_aligned_terminal(require_aligned_terminal),
+ require_aligned_chunks(require_aligned_chunks),
+ require_tight_phrases(require_tight_phrases) {}
+
+RuleExtractorHelper::RuleExtractorHelper() {}
+
+RuleExtractorHelper::~RuleExtractorHelper() {}
+
+void RuleExtractorHelper::GetLinksSpans(
+ vector<int>& source_low, vector<int>& source_high,
+ vector<int>& target_low, vector<int>& target_high, int sentence_id) const {
+ int source_sent_len = source_data_array->GetSentenceLength(sentence_id);
+ int target_sent_len = target_data_array->GetSentenceLength(sentence_id);
+ source_low = vector<int>(source_sent_len, -1);
+ source_high = vector<int>(source_sent_len, -1);
+
+ target_low = vector<int>(target_sent_len, -1);
+ target_high = vector<int>(target_sent_len, -1);
+ vector<pair<int, int> > links = alignment->GetLinks(sentence_id);
+ for (auto link: links) {
+ if (source_low[link.first] == -1 || source_low[link.first] > link.second) {
+ source_low[link.first] = link.second;
+ }
+ source_high[link.first] = max(source_high[link.first], link.second + 1);
+
+ if (target_low[link.second] == -1 || target_low[link.second] > link.first) {
+ target_low[link.second] = link.first;
+ }
+ target_high[link.second] = max(target_high[link.second], link.first + 1);
+ }
+}
+
+bool RuleExtractorHelper::CheckAlignedTerminals(
+ const vector<int>& matching,
+ const vector<int>& chunklen,
+ const vector<int>& source_low,
+ int source_sent_start) const {
+ if (!require_aligned_terminal) {
+ return true;
+ }
+
+ int num_aligned_chunks = 0;
+ for (size_t i = 0; i < chunklen.size(); ++i) {
+ for (size_t j = 0; j < chunklen[i]; ++j) {
+ int sent_index = matching[i] - source_sent_start + j;
+ if (source_low[sent_index] != -1) {
+ ++num_aligned_chunks;
+ break;
+ }
+ }
+ }
+
+ if (num_aligned_chunks == 0) {
+ return false;
+ }
+
+ return !require_aligned_chunks || num_aligned_chunks == chunklen.size();
+}
+
+bool RuleExtractorHelper::CheckTightPhrases(
+ const vector<int>& matching,
+ const vector<int>& chunklen,
+ const vector<int>& source_low,
+ int source_sent_start) const {
+ if (!require_tight_phrases) {
+ return true;
+ }
+
+ // Check if the chunk extremities are aligned.
+ for (size_t i = 0; i + 1 < chunklen.size(); ++i) {
+ int gap_start = matching[i] + chunklen[i] - source_sent_start;
+ int gap_end = matching[i + 1] - 1 - source_sent_start;
+ if (source_low[gap_start] == -1 || source_low[gap_end] == -1) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool RuleExtractorHelper::FindFixPoint(
+ int source_phrase_low, int source_phrase_high,
+ const vector<int>& source_low, const vector<int>& source_high,
+ int& target_phrase_low, int& target_phrase_high,
+ const vector<int>& target_low, const vector<int>& target_high,
+ int& source_back_low, int& source_back_high, int sentence_id,
+ int min_source_gap_size, int min_target_gap_size,
+ int max_new_x, bool allow_low_x, bool allow_high_x,
+ bool allow_arbitrary_expansion) const {
+ int prev_target_low = target_phrase_low;
+ int prev_target_high = target_phrase_high;
+
+ FindProjection(source_phrase_low, source_phrase_high, source_low,
+ source_high, target_phrase_low, target_phrase_high);
+
+ if (target_phrase_low == -1) {
+ // Note: Low priority corner case inherited from Adam's code:
+ // If w is unaligned, but we don't require aligned terminals, returning an
+ // error here prevents the extraction of the allowed rule
+ // X -> X_1 w X_2 / X_1 X_2
+ return false;
+ }
+
+ int source_sent_len = source_data_array->GetSentenceLength(sentence_id);
+ int target_sent_len = target_data_array->GetSentenceLength(sentence_id);
+ // Extend the target span to the left.
+ if (prev_target_low != -1 && target_phrase_low != prev_target_low) {
+ if (prev_target_low - target_phrase_low < min_target_gap_size) {
+ target_phrase_low = prev_target_low - min_target_gap_size;
+ if (target_phrase_low < 0) {
+ return false;
+ }
+ }
+ }
+
+ // Extend the target span to the right.
+ if (prev_target_high != -1 && target_phrase_high != prev_target_high) {
+ if (target_phrase_high - prev_target_high < min_target_gap_size) {
+ target_phrase_high = prev_target_high + min_target_gap_size;
+ if (target_phrase_high > target_sent_len) {
+ return false;
+ }
+ }
+ }
+
+ // Check target span length.
+ if (target_phrase_high - target_phrase_low > max_rule_span) {
+ return false;
+ }
+
+ // Find the initial reflected source span.
+ source_back_low = source_back_high = -1;
+ FindProjection(target_phrase_low, target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high);
+ int new_x = 0;
+ bool new_low_x = false, new_high_x = false;
+ while (true) {
+ source_back_low = min(source_back_low, source_phrase_low);
+ source_back_high = max(source_back_high, source_phrase_high);
+
+ // Stop if the reflected source span matches the previous source span.
+ if (source_back_low == source_phrase_low &&
+ source_back_high == source_phrase_high) {
+ return true;
+ }
+
+ if (!allow_low_x && source_back_low < source_phrase_low) {
+ // Extension on the left side not allowed.
+ return false;
+ }
+ if (!allow_high_x && source_back_high > source_phrase_high) {
+ // Extension on the right side not allowed.
+ return false;
+ }
+
+ // Extend left side.
+ if (source_back_low < source_phrase_low) {
+ if (new_low_x == false) {
+ if (new_x >= max_new_x) {
+ return false;
+ }
+ new_low_x = true;
+ ++new_x;
+ }
+ if (source_phrase_low - source_back_low < min_source_gap_size) {
+ source_back_low = source_phrase_low - min_source_gap_size;
+ if (source_back_low < 0) {
+ return false;
+ }
+ }
+ }
+
+ // Extend right side.
+ if (source_back_high > source_phrase_high) {
+ if (new_high_x == false) {
+ if (new_x >= max_new_x) {
+ return false;
+ }
+ new_high_x = true;
+ ++new_x;
+ }
+ if (source_back_high - source_phrase_high < min_source_gap_size) {
+ source_back_high = source_phrase_high + min_source_gap_size;
+ if (source_back_high > source_sent_len) {
+ return false;
+ }
+ }
+ }
+
+ if (source_back_high - source_back_low > max_rule_span) {
+ // Rule span too wide.
+ return false;
+ }
+
+ prev_target_low = target_phrase_low;
+ prev_target_high = target_phrase_high;
+ // Find the reflection including the left gap (if one was added).
+ FindProjection(source_back_low, source_phrase_low, source_low, source_high,
+ target_phrase_low, target_phrase_high);
+ // Find the reflection including the right gap (if one was added).
+ FindProjection(source_phrase_high, source_back_high, source_low,
+ source_high, target_phrase_low, target_phrase_high);
+ // Stop if the new re-reflected target span matches the previous target
+ // span.
+ if (prev_target_low == target_phrase_low &&
+ prev_target_high == target_phrase_high) {
+ return true;
+ }
+
+ if (!allow_arbitrary_expansion) {
+ // Arbitrary expansion not allowed.
+ return false;
+ }
+ if (target_phrase_high - target_phrase_low > max_rule_span) {
+ // Target side too wide.
+ return false;
+ }
+
+ source_phrase_low = source_back_low;
+ source_phrase_high = source_back_high;
+ // Re-reflect the target span.
+ FindProjection(target_phrase_low, prev_target_low, target_low, target_high,
+ source_back_low, source_back_high);
+ FindProjection(prev_target_high, target_phrase_high, target_low,
+ target_high, source_back_low, source_back_high);
+ }
+
+ return false;
+}
+
+void RuleExtractorHelper::FindProjection(
+ int source_phrase_low, int source_phrase_high,
+ const vector<int>& source_low, const vector<int>& source_high,
+ int& target_phrase_low, int& target_phrase_high) const {
+ for (size_t i = source_phrase_low; i < source_phrase_high; ++i) {
+ if (source_low[i] != -1) {
+ if (target_phrase_low == -1 || source_low[i] < target_phrase_low) {
+ target_phrase_low = source_low[i];
+ }
+ target_phrase_high = max(target_phrase_high, source_high[i]);
+ }
+ }
+}
+
+bool RuleExtractorHelper::GetGaps(
+ vector<pair<int, int> >& source_gaps, vector<pair<int, int> >& target_gaps,
+ const vector<int>& matching, const vector<int>& chunklen,
+ const vector<int>& source_low, const vector<int>& source_high,
+ const vector<int>& target_low, const vector<int>& target_high,
+ int source_phrase_low, int source_phrase_high, int source_back_low,
+ int source_back_high, int sentence_id, int source_sent_start,
+ int& num_symbols, bool& met_constraints) const {
+ if (source_back_low < source_phrase_low) {
+ source_gaps.push_back(make_pair(source_back_low, source_phrase_low));
+ if (num_symbols >= max_rule_symbols) {
+ // Source side contains too many symbols.
+ return false;
+ }
+ ++num_symbols;
+ if (require_tight_phrases && (source_low[source_back_low] == -1 ||
+ source_low[source_phrase_low - 1] == -1)) {
+ // Inside edges of preceding gap are not tight.
+ return false;
+ }
+ } else if (require_tight_phrases && source_low[source_phrase_low] == -1) {
+ // This is not a hard error. We can't extract this phrase, but we might
+ // still be able to extract a superphrase.
+ met_constraints = false;
+ }
+
+ for (size_t i = 0; i + 1 < chunklen.size(); ++i) {
+ int gap_start = matching[i] + chunklen[i] - source_sent_start;
+ int gap_end = matching[i + 1] - source_sent_start;
+ source_gaps.push_back(make_pair(gap_start, gap_end));
+ }
+
+ if (source_phrase_high < source_back_high) {
+ source_gaps.push_back(make_pair(source_phrase_high, source_back_high));
+ if (num_symbols >= max_rule_symbols) {
+ // Source side contains too many symbols.
+ return false;
+ }
+ ++num_symbols;
+ if (require_tight_phrases && (source_low[source_phrase_high] == -1 ||
+ source_low[source_back_high - 1] == -1)) {
+ // Inside edges of following gap are not tight.
+ return false;
+ }
+ } else if (require_tight_phrases &&
+ source_low[source_phrase_high - 1] == -1) {
+ // This is not a hard error. We can't extract this phrase, but we might
+ // still be able to extract a superphrase.
+ met_constraints = false;
+ }
+
+ target_gaps.resize(source_gaps.size(), make_pair(-1, -1));
+ for (size_t i = 0; i < source_gaps.size(); ++i) {
+ if (!FindFixPoint(source_gaps[i].first, source_gaps[i].second, source_low,
+ source_high, target_gaps[i].first, target_gaps[i].second,
+ target_low, target_high, source_gaps[i].first,
+ source_gaps[i].second, sentence_id, 0, 0, 0, false, false,
+ false)) {
+ // Gap fails integrity check.
+ return false;
+ }
+ }
+
+ return true;
+}
+
+vector<int> RuleExtractorHelper::GetGapOrder(
+ const vector<pair<int, int> >& gaps) const {
+ vector<int> gap_order(gaps.size());
+ for (size_t i = 0; i < gap_order.size(); ++i) {
+ for (size_t j = 0; j < i; ++j) {
+ if (gaps[gap_order[j]] < gaps[i]) {
+ ++gap_order[i];
+ } else {
+ ++gap_order[j];
+ }
+ }
+ }
+ return gap_order;
+}
+
+unordered_map<int, int> RuleExtractorHelper::GetSourceIndexes(
+ const vector<int>& matching, const vector<int>& chunklen,
+ int starts_with_x, int source_sent_start) const {
+ unordered_map<int, int> source_indexes;
+ int num_symbols = starts_with_x;
+ for (size_t i = 0; i < matching.size(); ++i) {
+ for (size_t j = 0; j < chunklen[i]; ++j) {
+ source_indexes[matching[i] + j - source_sent_start] = num_symbols;
+ ++num_symbols;
+ }
+ ++num_symbols;
+ }
+ return source_indexes;
+}
+
+} // namespace extractor