diff options
Diffstat (limited to 'extractor/rule_extractor.h')
-rw-r--r-- | extractor/rule_extractor.h | 124 |
1 files changed, 124 insertions, 0 deletions
diff --git a/extractor/rule_extractor.h b/extractor/rule_extractor.h new file mode 100644 index 00000000..26e6f21c --- /dev/null +++ b/extractor/rule_extractor.h @@ -0,0 +1,124 @@ +#ifndef _RULE_EXTRACTOR_H_ +#define _RULE_EXTRACTOR_H_ + +#include <memory> +#include <unordered_map> +#include <vector> + +#include "phrase.h" + +using namespace std; + +namespace extractor { + +typedef vector<pair<int, int> > PhraseAlignment; + +class Alignment; +class DataArray; +class PhraseBuilder; +class PhraseLocation; +class Rule; +class RuleExtractorHelper; +class Scorer; +class TargetPhraseExtractor; + +/** + * Structure containing data about the occurrences of a source-target phrase pair + * in the parallel corpus. + */ +struct Extract { + Extract(const Phrase& source_phrase, const Phrase& target_phrase, + double pairs_count, const PhraseAlignment& alignment) : + source_phrase(source_phrase), target_phrase(target_phrase), + pairs_count(pairs_count), alignment(alignment) {} + + Phrase source_phrase; + Phrase target_phrase; + double pairs_count; + PhraseAlignment alignment; +}; + +/** + * Component for extracting SCFG rules. + */ +class RuleExtractor { + public: + RuleExtractor(shared_ptr<DataArray> source_data_array, + shared_ptr<DataArray> target_data_array, + shared_ptr<Alignment> alingment, + shared_ptr<PhraseBuilder> phrase_builder, + shared_ptr<Scorer> scorer, + shared_ptr<Vocabulary> vocabulary, + int min_gap_size, + int max_rule_span, + int max_nonterminals, + int max_rule_symbols, + bool require_aligned_terminal, + bool require_aligned_chunks, + bool require_tight_phrases); + + // For testing only. + RuleExtractor(shared_ptr<DataArray> source_data_array, + shared_ptr<PhraseBuilder> phrase_builder, + shared_ptr<Scorer> scorer, + shared_ptr<TargetPhraseExtractor> target_phrase_extractor, + shared_ptr<RuleExtractorHelper> helper, + int max_rule_span, + int min_gap_size, + int max_nonterminals, + int max_rule_symbols, + bool require_tight_phrases); + + virtual ~RuleExtractor(); + + // Extracts SCFG rules given a source phrase and a set of its occurrences + // in the source data. + virtual vector<Rule> ExtractRules(const Phrase& phrase, + const PhraseLocation& location) const; + + protected: + RuleExtractor(); + + private: + // Finds all target phrases that can be aligned with the source phrase for a + // particular occurrence in the data. + vector<Extract> ExtractAlignments(const Phrase& phrase, + const vector<int>& matching) const; + + // Extracts all target phrases for a given occurrence of the source phrase in + // the data. Constructs a vector of Extracts using these target phrases. + void AddExtracts( + vector<Extract>& extracts, const Phrase& source_phrase, + const unordered_map<int, int>& source_indexes, + const vector<pair<int, int> >& target_gaps, const vector<int>& target_low, + int target_phrase_low, int target_phrase_high, int sentence_id) const; + + // Adds a leading and/or trailing nonterminal to the source phrase and + // extracts target phrases that can be aligned with the extended source + // phrase. + void AddNonterminalExtremities( + vector<Extract>& extracts, const vector<int>& matching, + const vector<int>& chunklen, const Phrase& source_phrase, + int source_back_low, int source_back_high, const vector<int>& source_low, + const vector<int>& source_high, const vector<int>& target_low, + const vector<int>& target_high, vector<pair<int, int> > target_gaps, + int sentence_id, int source_sent_start, int starts_with_x, + int ends_with_x, int extend_left, int extend_right) const; + + private: + shared_ptr<DataArray> target_data_array; + shared_ptr<DataArray> source_data_array; + shared_ptr<PhraseBuilder> phrase_builder; + shared_ptr<Scorer> scorer; + shared_ptr<TargetPhraseExtractor> target_phrase_extractor; + shared_ptr<RuleExtractorHelper> helper; + int max_rule_span; + int min_gap_size; + int max_nonterminals; + int max_rule_symbols; + bool require_tight_phrases; +}; + +} // namespace extractor + +#endif |