diff options
author | Chris Dyer <cdyer@allegro.clab.cs.cmu.edu> | 2013-04-23 19:35:18 -0400 |
---|---|---|
committer | Chris Dyer <cdyer@allegro.clab.cs.cmu.edu> | 2013-04-23 19:35:18 -0400 |
commit | c164dc0ed8a32e4095ba1b36495e0f743b8cc1ea (patch) | |
tree | 78b81e4c63adfa67adb7b8f80c3e6be87b4a2b2a /extractor/target_phrase_extractor.cc | |
parent | 0e46089cafa4e8e2f060e370d7afaceeda6b90a9 (diff) | |
parent | d467e14b28085809c31431be0478eb3d9322fe96 (diff) |
merge paul's extractor code
Diffstat (limited to 'extractor/target_phrase_extractor.cc')
-rw-r--r-- | extractor/target_phrase_extractor.cc | 158 |
1 files changed, 158 insertions, 0 deletions
diff --git a/extractor/target_phrase_extractor.cc b/extractor/target_phrase_extractor.cc new file mode 100644 index 00000000..2b8a2e4a --- /dev/null +++ b/extractor/target_phrase_extractor.cc @@ -0,0 +1,158 @@ +#include "target_phrase_extractor.h" + +#include <unordered_set> + +#include "alignment.h" +#include "data_array.h" +#include "phrase.h" +#include "phrase_builder.h" +#include "rule_extractor_helper.h" +#include "vocabulary.h" + +using namespace std; + +namespace extractor { + +TargetPhraseExtractor::TargetPhraseExtractor( + shared_ptr<DataArray> target_data_array, + shared_ptr<Alignment> alignment, + shared_ptr<PhraseBuilder> phrase_builder, + shared_ptr<RuleExtractorHelper> helper, + shared_ptr<Vocabulary> vocabulary, + int max_rule_span, + bool require_tight_phrases) : + target_data_array(target_data_array), + alignment(alignment), + phrase_builder(phrase_builder), + helper(helper), + vocabulary(vocabulary), + max_rule_span(max_rule_span), + require_tight_phrases(require_tight_phrases) {} + +TargetPhraseExtractor::TargetPhraseExtractor() {} + +TargetPhraseExtractor::~TargetPhraseExtractor() {} + +vector<pair<Phrase, PhraseAlignment> > TargetPhraseExtractor::ExtractPhrases( + const vector<pair<int, int> >& target_gaps, const vector<int>& target_low, + int target_phrase_low, int target_phrase_high, + const unordered_map<int, int>& source_indexes, int sentence_id) const { + int target_sent_len = target_data_array->GetSentenceLength(sentence_id); + + vector<int> target_gap_order = helper->GetGapOrder(target_gaps); + + int target_x_low = target_phrase_low, target_x_high = target_phrase_high; + if (!require_tight_phrases) { + // Extend loose target phrase to the left. + while (target_x_low > 0 && + target_phrase_high - target_x_low < max_rule_span && + target_low[target_x_low - 1] == -1) { + --target_x_low; + } + // Extend loose target phrase to the right. + while (target_x_high < target_sent_len && + target_x_high - target_phrase_low < max_rule_span && + target_low[target_x_high] == -1) { + ++target_x_high; + } + } + + vector<pair<int, int> > gaps(target_gaps.size()); + for (size_t i = 0; i < gaps.size(); ++i) { + gaps[i] = target_gaps[target_gap_order[i]]; + if (!require_tight_phrases) { + // Extend gap to the left. + while (gaps[i].first > target_x_low && + target_low[gaps[i].first - 1] == -1) { + --gaps[i].first; + } + // Extend gap to the right. + while (gaps[i].second < target_x_high && + target_low[gaps[i].second] == -1) { + ++gaps[i].second; + } + } + } + + // Compute the range in which each chunk may start or end. (Even indexes + // represent the range in which the chunk may start, odd indexes represent the + // range in which the chunk may end.) + vector<pair<int, int> > ranges(2 * gaps.size() + 2); + ranges.front() = make_pair(target_x_low, target_phrase_low); + ranges.back() = make_pair(target_phrase_high, target_x_high); + for (size_t i = 0; i < gaps.size(); ++i) { + int j = target_gap_order[i]; + ranges[i * 2 + 1] = make_pair(gaps[i].first, target_gaps[j].first); + ranges[i * 2 + 2] = make_pair(target_gaps[j].second, gaps[i].second); + } + + vector<pair<Phrase, PhraseAlignment> > target_phrases; + vector<int> subpatterns(ranges.size()); + GeneratePhrases(target_phrases, ranges, 0, subpatterns, target_gap_order, + target_phrase_low, target_phrase_high, source_indexes, + sentence_id); + return target_phrases; +} + +void TargetPhraseExtractor::GeneratePhrases( + vector<pair<Phrase, PhraseAlignment> >& target_phrases, + const vector<pair<int, int> >& ranges, int index, vector<int>& subpatterns, + const vector<int>& target_gap_order, int target_phrase_low, + int target_phrase_high, const unordered_map<int, int>& source_indexes, + int sentence_id) const { + if (index >= ranges.size()) { + if (subpatterns.back() - subpatterns.front() > max_rule_span) { + return; + } + + vector<int> symbols; + unordered_map<int, int> target_indexes; + + // Construct target phrase chunk by chunk. + int target_sent_start = target_data_array->GetSentenceStart(sentence_id); + for (size_t i = 0; i * 2 < subpatterns.size(); ++i) { + for (size_t j = subpatterns[i * 2]; j < subpatterns[i * 2 + 1]; ++j) { + target_indexes[j] = symbols.size(); + string target_word = target_data_array->GetWordAtIndex( + target_sent_start + j); + symbols.push_back(vocabulary->GetTerminalIndex(target_word)); + } + if (i < target_gap_order.size()) { + symbols.push_back(vocabulary->GetNonterminalIndex( + target_gap_order[i] + 1)); + } + } + + // Construct the alignment between the source and the target phrase. + vector<pair<int, int> > links = alignment->GetLinks(sentence_id); + vector<pair<int, int> > alignment; + for (pair<int, int> link: links) { + if (target_indexes.count(link.second)) { + alignment.push_back(make_pair(source_indexes.find(link.first)->second, + target_indexes[link.second])); + } + } + + Phrase target_phrase = phrase_builder->Build(symbols); + target_phrases.push_back(make_pair(target_phrase, alignment)); + return; + } + + subpatterns[index] = ranges[index].first; + if (index > 0) { + subpatterns[index] = max(subpatterns[index], subpatterns[index - 1]); + } + // Choose every possible combination of [start, end) for the current chunk. + while (subpatterns[index] <= ranges[index].second) { + subpatterns[index + 1] = max(subpatterns[index], ranges[index + 1].first); + while (subpatterns[index + 1] <= ranges[index + 1].second) { + GeneratePhrases(target_phrases, ranges, index + 2, subpatterns, + target_gap_order, target_phrase_low, target_phrase_high, + source_indexes, sentence_id); + ++subpatterns[index + 1]; + } + ++subpatterns[index]; + } +} + +} // namespace extractor |