summaryrefslogtreecommitdiff
path: root/extractor/target_phrase_extractor.cc
diff options
context:
space:
mode:
Diffstat (limited to 'extractor/target_phrase_extractor.cc')
-rw-r--r--extractor/target_phrase_extractor.cc144
1 files changed, 144 insertions, 0 deletions
diff --git a/extractor/target_phrase_extractor.cc b/extractor/target_phrase_extractor.cc
new file mode 100644
index 00000000..ac583953
--- /dev/null
+++ b/extractor/target_phrase_extractor.cc
@@ -0,0 +1,144 @@
+#include "target_phrase_extractor.h"
+
+#include <unordered_set>
+
+#include "alignment.h"
+#include "data_array.h"
+#include "phrase.h"
+#include "phrase_builder.h"
+#include "rule_extractor_helper.h"
+#include "vocabulary.h"
+
+using namespace std;
+
+TargetPhraseExtractor::TargetPhraseExtractor(
+ shared_ptr<DataArray> target_data_array,
+ shared_ptr<Alignment> alignment,
+ shared_ptr<PhraseBuilder> phrase_builder,
+ shared_ptr<RuleExtractorHelper> helper,
+ shared_ptr<Vocabulary> vocabulary,
+ int max_rule_span,
+ bool require_tight_phrases) :
+ target_data_array(target_data_array),
+ alignment(alignment),
+ phrase_builder(phrase_builder),
+ helper(helper),
+ vocabulary(vocabulary),
+ max_rule_span(max_rule_span),
+ require_tight_phrases(require_tight_phrases) {}
+
+TargetPhraseExtractor::TargetPhraseExtractor() {}
+
+TargetPhraseExtractor::~TargetPhraseExtractor() {}
+
+vector<pair<Phrase, PhraseAlignment> > TargetPhraseExtractor::ExtractPhrases(
+ const vector<pair<int, int> >& target_gaps, const vector<int>& target_low,
+ int target_phrase_low, int target_phrase_high,
+ const unordered_map<int, int>& source_indexes, int sentence_id) const {
+ int target_sent_len = target_data_array->GetSentenceLength(sentence_id);
+
+ vector<int> target_gap_order = helper->GetGapOrder(target_gaps);
+
+ int target_x_low = target_phrase_low, target_x_high = target_phrase_high;
+ if (!require_tight_phrases) {
+ while (target_x_low > 0 &&
+ target_phrase_high - target_x_low < max_rule_span &&
+ target_low[target_x_low - 1] == -1) {
+ --target_x_low;
+ }
+ while (target_x_high < target_sent_len &&
+ target_x_high - target_phrase_low < max_rule_span &&
+ target_low[target_x_high] == -1) {
+ ++target_x_high;
+ }
+ }
+
+ vector<pair<int, int> > gaps(target_gaps.size());
+ for (size_t i = 0; i < gaps.size(); ++i) {
+ gaps[i] = target_gaps[target_gap_order[i]];
+ if (!require_tight_phrases) {
+ while (gaps[i].first > target_x_low &&
+ target_low[gaps[i].first - 1] == -1) {
+ --gaps[i].first;
+ }
+ while (gaps[i].second < target_x_high &&
+ target_low[gaps[i].second] == -1) {
+ ++gaps[i].second;
+ }
+ }
+ }
+
+ vector<pair<int, int> > ranges(2 * gaps.size() + 2);
+ ranges.front() = make_pair(target_x_low, target_phrase_low);
+ ranges.back() = make_pair(target_phrase_high, target_x_high);
+ for (size_t i = 0; i < gaps.size(); ++i) {
+ int j = target_gap_order[i];
+ ranges[i * 2 + 1] = make_pair(gaps[i].first, target_gaps[j].first);
+ ranges[i * 2 + 2] = make_pair(target_gaps[j].second, gaps[i].second);
+ }
+
+ vector<pair<Phrase, PhraseAlignment> > target_phrases;
+ vector<int> subpatterns(ranges.size());
+ GeneratePhrases(target_phrases, ranges, 0, subpatterns, target_gap_order,
+ target_phrase_low, target_phrase_high, source_indexes,
+ sentence_id);
+ return target_phrases;
+}
+
+void TargetPhraseExtractor::GeneratePhrases(
+ vector<pair<Phrase, PhraseAlignment> >& target_phrases,
+ const vector<pair<int, int> >& ranges, int index, vector<int>& subpatterns,
+ const vector<int>& target_gap_order, int target_phrase_low,
+ int target_phrase_high, const unordered_map<int, int>& source_indexes,
+ int sentence_id) const {
+ if (index >= ranges.size()) {
+ if (subpatterns.back() - subpatterns.front() > max_rule_span) {
+ return;
+ }
+
+ vector<int> symbols;
+ unordered_map<int, int> target_indexes;
+
+ int target_sent_start = target_data_array->GetSentenceStart(sentence_id);
+ for (size_t i = 0; i * 2 < subpatterns.size(); ++i) {
+ for (size_t j = subpatterns[i * 2]; j < subpatterns[i * 2 + 1]; ++j) {
+ target_indexes[j] = symbols.size();
+ string target_word = target_data_array->GetWordAtIndex(
+ target_sent_start + j);
+ symbols.push_back(vocabulary->GetTerminalIndex(target_word));
+ }
+ if (i < target_gap_order.size()) {
+ symbols.push_back(vocabulary->GetNonterminalIndex(
+ target_gap_order[i] + 1));
+ }
+ }
+
+ vector<pair<int, int> > links = alignment->GetLinks(sentence_id);
+ vector<pair<int, int> > alignment;
+ for (pair<int, int> link: links) {
+ if (target_indexes.count(link.second)) {
+ alignment.push_back(make_pair(source_indexes.find(link.first)->second,
+ target_indexes[link.second]));
+ }
+ }
+
+ Phrase target_phrase = phrase_builder->Build(symbols);
+ target_phrases.push_back(make_pair(target_phrase, alignment));
+ return;
+ }
+
+ subpatterns[index] = ranges[index].first;
+ if (index > 0) {
+ subpatterns[index] = max(subpatterns[index], subpatterns[index - 1]);
+ }
+ while (subpatterns[index] <= ranges[index].second) {
+ subpatterns[index + 1] = max(subpatterns[index], ranges[index + 1].first);
+ while (subpatterns[index + 1] <= ranges[index + 1].second) {
+ GeneratePhrases(target_phrases, ranges, index + 2, subpatterns,
+ target_gap_order, target_phrase_low, target_phrase_high,
+ source_indexes, sentence_id);
+ ++subpatterns[index + 1];
+ }
+ ++subpatterns[index];
+ }
+}