summaryrefslogtreecommitdiff
path: root/extractor/rule_extractor_helper.cc
diff options
context:
space:
mode:
authorPaul Baltescu <pauldb89@gmail.com>2013-02-14 23:17:15 +0000
committerPaul Baltescu <pauldb89@gmail.com>2013-02-14 23:17:15 +0000
commit9a026ba2db8fa7723374109e6a4a8dcaff8733cd (patch)
tree34a60703a53ada76e7213da5940e86d6f476f1e4 /extractor/rule_extractor_helper.cc
parent252fb164c208ec8f3005f8a652eb3b48c0644e3d (diff)
Working version of the grammar extractor.
Diffstat (limited to 'extractor/rule_extractor_helper.cc')
-rw-r--r--extractor/rule_extractor_helper.cc356
1 files changed, 356 insertions, 0 deletions
diff --git a/extractor/rule_extractor_helper.cc b/extractor/rule_extractor_helper.cc
new file mode 100644
index 00000000..ed6ae3a1
--- /dev/null
+++ b/extractor/rule_extractor_helper.cc
@@ -0,0 +1,356 @@
+#include "rule_extractor_helper.h"
+
+#include "data_array.h"
+#include "alignment.h"
+
+RuleExtractorHelper::RuleExtractorHelper(
+ shared_ptr<DataArray> source_data_array,
+ shared_ptr<DataArray> target_data_array,
+ shared_ptr<Alignment> alignment,
+ int max_rule_span,
+ int max_rule_symbols,
+ bool require_aligned_terminal,
+ bool require_aligned_chunks,
+ bool require_tight_phrases) :
+ source_data_array(source_data_array),
+ target_data_array(target_data_array),
+ alignment(alignment),
+ max_rule_span(max_rule_span),
+ max_rule_symbols(max_rule_symbols),
+ require_aligned_terminal(require_aligned_terminal),
+ require_aligned_chunks(require_aligned_chunks),
+ require_tight_phrases(require_tight_phrases) {}
+
+RuleExtractorHelper::RuleExtractorHelper() {}
+
+RuleExtractorHelper::~RuleExtractorHelper() {}
+
+void RuleExtractorHelper::GetLinksSpans(
+ vector<int>& source_low, vector<int>& source_high,
+ vector<int>& target_low, vector<int>& target_high, int sentence_id) const {
+ int source_sent_len = source_data_array->GetSentenceLength(sentence_id);
+ int target_sent_len = target_data_array->GetSentenceLength(sentence_id);
+ source_low = vector<int>(source_sent_len, -1);
+ source_high = vector<int>(source_sent_len, -1);
+
+ // TODO(pauldb): Adam Lopez claims this part is really inefficient. See if we
+ // can speed it up.
+ target_low = vector<int>(target_sent_len, -1);
+ target_high = vector<int>(target_sent_len, -1);
+ vector<pair<int, int> > links = alignment->GetLinks(sentence_id);
+ for (auto link: links) {
+ if (source_low[link.first] == -1 || source_low[link.first] > link.second) {
+ source_low[link.first] = link.second;
+ }
+ source_high[link.first] = max(source_high[link.first], link.second + 1);
+
+ if (target_low[link.second] == -1 || target_low[link.second] > link.first) {
+ target_low[link.second] = link.first;
+ }
+ target_high[link.second] = max(target_high[link.second], link.first + 1);
+ }
+}
+
+bool RuleExtractorHelper::CheckAlignedTerminals(
+ const vector<int>& matching,
+ const vector<int>& chunklen,
+ const vector<int>& source_low) const {
+ if (!require_aligned_terminal) {
+ return true;
+ }
+
+ int sentence_id = source_data_array->GetSentenceId(matching[0]);
+ int source_sent_start = source_data_array->GetSentenceStart(sentence_id);
+
+ int num_aligned_chunks = 0;
+ for (size_t i = 0; i < chunklen.size(); ++i) {
+ for (size_t j = 0; j < chunklen[i]; ++j) {
+ int sent_index = matching[i] - source_sent_start + j;
+ if (source_low[sent_index] != -1) {
+ ++num_aligned_chunks;
+ break;
+ }
+ }
+ }
+
+ if (num_aligned_chunks == 0) {
+ return false;
+ }
+
+ return !require_aligned_chunks || num_aligned_chunks == chunklen.size();
+}
+
+bool RuleExtractorHelper::CheckTightPhrases(
+ const vector<int>& matching,
+ const vector<int>& chunklen,
+ const vector<int>& source_low) const {
+ if (!require_tight_phrases) {
+ return true;
+ }
+
+ int sentence_id = source_data_array->GetSentenceId(matching[0]);
+ int source_sent_start = source_data_array->GetSentenceStart(sentence_id);
+ for (size_t i = 0; i + 1 < chunklen.size(); ++i) {
+ int gap_start = matching[i] + chunklen[i] - source_sent_start;
+ int gap_end = matching[i + 1] - 1 - source_sent_start;
+ if (source_low[gap_start] == -1 || source_low[gap_end] == -1) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool RuleExtractorHelper::FindFixPoint(
+ int source_phrase_low, int source_phrase_high,
+ const vector<int>& source_low, const vector<int>& source_high,
+ int& target_phrase_low, int& target_phrase_high,
+ const vector<int>& target_low, const vector<int>& target_high,
+ int& source_back_low, int& source_back_high, int sentence_id,
+ int min_source_gap_size, int min_target_gap_size,
+ int max_new_x, bool allow_low_x, bool allow_high_x,
+ bool allow_arbitrary_expansion) const {
+ int prev_target_low = target_phrase_low;
+ int prev_target_high = target_phrase_high;
+
+ FindProjection(source_phrase_low, source_phrase_high, source_low,
+ source_high, target_phrase_low, target_phrase_high);
+
+ if (target_phrase_low == -1) {
+ // TODO(pauldb): Low priority corner case inherited from Adam's code:
+ // If w is unaligned, but we don't require aligned terminals, returning an
+ // error here prevents the extraction of the allowed rule
+ // X -> X_1 w X_2 / X_1 X_2
+ return false;
+ }
+
+ int source_sent_len = source_data_array->GetSentenceLength(sentence_id);
+ int target_sent_len = target_data_array->GetSentenceLength(sentence_id);
+ if (prev_target_low != -1 && target_phrase_low != prev_target_low) {
+ if (prev_target_low - target_phrase_low < min_target_gap_size) {
+ target_phrase_low = prev_target_low - min_target_gap_size;
+ if (target_phrase_low < 0) {
+ return false;
+ }
+ }
+ }
+
+ if (prev_target_high != -1 && target_phrase_high != prev_target_high) {
+ if (target_phrase_high - prev_target_high < min_target_gap_size) {
+ target_phrase_high = prev_target_high + min_target_gap_size;
+ if (target_phrase_high > target_sent_len) {
+ return false;
+ }
+ }
+ }
+
+ if (target_phrase_high - target_phrase_low > max_rule_span) {
+ return false;
+ }
+
+ source_back_low = source_back_high = -1;
+ FindProjection(target_phrase_low, target_phrase_high, target_low, target_high,
+ source_back_low, source_back_high);
+ int new_x = 0;
+ bool new_low_x = false, new_high_x = false;
+ while (true) {
+ source_back_low = min(source_back_low, source_phrase_low);
+ source_back_high = max(source_back_high, source_phrase_high);
+
+ if (source_back_low == source_phrase_low &&
+ source_back_high == source_phrase_high) {
+ return true;
+ }
+
+ if (!allow_low_x && source_back_low < source_phrase_low) {
+ // Extension on the left side not allowed.
+ return false;
+ }
+ if (!allow_high_x && source_back_high > source_phrase_high) {
+ // Extension on the right side not allowed.
+ return false;
+ }
+
+ // Extend left side.
+ if (source_back_low < source_phrase_low) {
+ if (new_low_x == false) {
+ if (new_x >= max_new_x) {
+ return false;
+ }
+ new_low_x = true;
+ ++new_x;
+ }
+ if (source_phrase_low - source_back_low < min_source_gap_size) {
+ source_back_low = source_phrase_low - min_source_gap_size;
+ if (source_back_low < 0) {
+ return false;
+ }
+ }
+ }
+
+ // Extend right side.
+ if (source_back_high > source_phrase_high) {
+ if (new_high_x == false) {
+ if (new_x >= max_new_x) {
+ return false;
+ }
+ new_high_x = true;
+ ++new_x;
+ }
+ if (source_back_high - source_phrase_high < min_source_gap_size) {
+ source_back_high = source_phrase_high + min_source_gap_size;
+ if (source_back_high > source_sent_len) {
+ return false;
+ }
+ }
+ }
+
+ if (source_back_high - source_back_low > max_rule_span) {
+ // Rule span too wide.
+ return false;
+ }
+
+ prev_target_low = target_phrase_low;
+ prev_target_high = target_phrase_high;
+ FindProjection(source_back_low, source_phrase_low, source_low, source_high,
+ target_phrase_low, target_phrase_high);
+ FindProjection(source_phrase_high, source_back_high, source_low,
+ source_high, target_phrase_low, target_phrase_high);
+ if (prev_target_low == target_phrase_low &&
+ prev_target_high == target_phrase_high) {
+ return true;
+ }
+
+ if (!allow_arbitrary_expansion) {
+ // Arbitrary expansion not allowed.
+ return false;
+ }
+ if (target_phrase_high - target_phrase_low > max_rule_span) {
+ // Target side too wide.
+ return false;
+ }
+
+ source_phrase_low = source_back_low;
+ source_phrase_high = source_back_high;
+ FindProjection(target_phrase_low, prev_target_low, target_low, target_high,
+ source_back_low, source_back_high);
+ FindProjection(prev_target_high, target_phrase_high, target_low,
+ target_high, source_back_low, source_back_high);
+ }
+
+ return false;
+}
+
+void RuleExtractorHelper::FindProjection(
+ int source_phrase_low, int source_phrase_high,
+ const vector<int>& source_low, const vector<int>& source_high,
+ int& target_phrase_low, int& target_phrase_high) const {
+ for (size_t i = source_phrase_low; i < source_phrase_high; ++i) {
+ if (source_low[i] != -1) {
+ if (target_phrase_low == -1 || source_low[i] < target_phrase_low) {
+ target_phrase_low = source_low[i];
+ }
+ target_phrase_high = max(target_phrase_high, source_high[i]);
+ }
+ }
+}
+
+bool RuleExtractorHelper::GetGaps(
+ vector<pair<int, int> >& source_gaps, vector<pair<int, int> >& target_gaps,
+ const vector<int>& matching, const vector<int>& chunklen,
+ const vector<int>& source_low, const vector<int>& source_high,
+ const vector<int>& target_low, const vector<int>& target_high,
+ int source_phrase_low, int source_phrase_high, int source_back_low,
+ int source_back_high, int& num_symbols, bool& met_constraints) const {
+ int sentence_id = source_data_array->GetSentenceId(matching[0]);
+ int source_sent_start = source_data_array->GetSentenceStart(sentence_id);
+
+ if (source_back_low < source_phrase_low) {
+ source_gaps.push_back(make_pair(source_back_low, source_phrase_low));
+ if (num_symbols >= max_rule_symbols) {
+ // Source side contains too many symbols.
+ return false;
+ }
+ ++num_symbols;
+ if (require_tight_phrases && (source_low[source_back_low] == -1 ||
+ source_low[source_phrase_low - 1] == -1)) {
+ // Inside edges of preceding gap are not tight.
+ return false;
+ }
+ } else if (require_tight_phrases && source_low[source_phrase_low] == -1) {
+ // This is not a hard error. We can't extract this phrase, but we might
+ // still be able to extract a superphrase.
+ met_constraints = false;
+ }
+
+ for (size_t i = 0; i + 1 < chunklen.size(); ++i) {
+ int gap_start = matching[i] + chunklen[i] - source_sent_start;
+ int gap_end = matching[i + 1] - source_sent_start;
+ source_gaps.push_back(make_pair(gap_start, gap_end));
+ }
+
+ if (source_phrase_high < source_back_high) {
+ source_gaps.push_back(make_pair(source_phrase_high, source_back_high));
+ if (num_symbols >= max_rule_symbols) {
+ // Source side contains too many symbols.
+ return false;
+ }
+ ++num_symbols;
+ if (require_tight_phrases && (source_low[source_phrase_high] == -1 ||
+ source_low[source_back_high - 1] == -1)) {
+ // Inside edges of following gap are not tight.
+ return false;
+ }
+ } else if (require_tight_phrases &&
+ source_low[source_phrase_high - 1] == -1) {
+ // This is not a hard error. We can't extract this phrase, but we might
+ // still be able to extract a superphrase.
+ met_constraints = false;
+ }
+
+ target_gaps.resize(source_gaps.size(), make_pair(-1, -1));
+ for (size_t i = 0; i < source_gaps.size(); ++i) {
+ if (!FindFixPoint(source_gaps[i].first, source_gaps[i].second, source_low,
+ source_high, target_gaps[i].first, target_gaps[i].second,
+ target_low, target_high, source_gaps[i].first,
+ source_gaps[i].second, sentence_id, 0, 0, 0, false, false,
+ false)) {
+ // Gap fails integrity check.
+ return false;
+ }
+ }
+
+ return true;
+}
+
+vector<int> RuleExtractorHelper::GetGapOrder(
+ const vector<pair<int, int> >& gaps) const {
+ vector<int> gap_order(gaps.size());
+ for (size_t i = 0; i < gap_order.size(); ++i) {
+ for (size_t j = 0; j < i; ++j) {
+ if (gaps[gap_order[j]] < gaps[i]) {
+ ++gap_order[i];
+ } else {
+ ++gap_order[j];
+ }
+ }
+ }
+ return gap_order;
+}
+
+unordered_map<int, int> RuleExtractorHelper::GetSourceIndexes(
+ const vector<int>& matching, const vector<int>& chunklen,
+ int starts_with_x) const {
+ unordered_map<int, int> source_indexes;
+ int sentence_id = source_data_array->GetSentenceId(matching[0]);
+ int source_sent_start = source_data_array->GetSentenceStart(sentence_id);
+ int num_symbols = starts_with_x;
+ for (size_t i = 0; i < matching.size(); ++i) {
+ for (size_t j = 0; j < chunklen[i]; ++j) {
+ source_indexes[matching[i] + j - source_sent_start] = num_symbols;
+ ++num_symbols;
+ }
+ ++num_symbols;
+ }
+ return source_indexes;
+}