summaryrefslogtreecommitdiff
path: root/extractor/rule_extractor.h
diff options
context:
space:
mode:
Diffstat (limited to 'extractor/rule_extractor.h')
-rw-r--r--extractor/rule_extractor.h124
1 files changed, 124 insertions, 0 deletions
diff --git a/extractor/rule_extractor.h b/extractor/rule_extractor.h
new file mode 100644
index 00000000..26e6f21c
--- /dev/null
+++ b/extractor/rule_extractor.h
@@ -0,0 +1,124 @@
+#ifndef _RULE_EXTRACTOR_H_
+#define _RULE_EXTRACTOR_H_
+
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+#include "phrase.h"
+
+using namespace std;
+
+namespace extractor {
+
+typedef vector<pair<int, int> > PhraseAlignment;
+
+class Alignment;
+class DataArray;
+class PhraseBuilder;
+class PhraseLocation;
+class Rule;
+class RuleExtractorHelper;
+class Scorer;
+class TargetPhraseExtractor;
+
+/**
+ * Structure containing data about the occurrences of a source-target phrase pair
+ * in the parallel corpus.
+ */
+struct Extract {
+ Extract(const Phrase& source_phrase, const Phrase& target_phrase,
+ double pairs_count, const PhraseAlignment& alignment) :
+ source_phrase(source_phrase), target_phrase(target_phrase),
+ pairs_count(pairs_count), alignment(alignment) {}
+
+ Phrase source_phrase;
+ Phrase target_phrase;
+ double pairs_count;
+ PhraseAlignment alignment;
+};
+
+/**
+ * Component for extracting SCFG rules.
+ */
+class RuleExtractor {
+ public:
+ RuleExtractor(shared_ptr<DataArray> source_data_array,
+ shared_ptr<DataArray> target_data_array,
+ shared_ptr<Alignment> alingment,
+ shared_ptr<PhraseBuilder> phrase_builder,
+ shared_ptr<Scorer> scorer,
+ shared_ptr<Vocabulary> vocabulary,
+ int min_gap_size,
+ int max_rule_span,
+ int max_nonterminals,
+ int max_rule_symbols,
+ bool require_aligned_terminal,
+ bool require_aligned_chunks,
+ bool require_tight_phrases);
+
+ // For testing only.
+ RuleExtractor(shared_ptr<DataArray> source_data_array,
+ shared_ptr<PhraseBuilder> phrase_builder,
+ shared_ptr<Scorer> scorer,
+ shared_ptr<TargetPhraseExtractor> target_phrase_extractor,
+ shared_ptr<RuleExtractorHelper> helper,
+ int max_rule_span,
+ int min_gap_size,
+ int max_nonterminals,
+ int max_rule_symbols,
+ bool require_tight_phrases);
+
+ virtual ~RuleExtractor();
+
+ // Extracts SCFG rules given a source phrase and a set of its occurrences
+ // in the source data.
+ virtual vector<Rule> ExtractRules(const Phrase& phrase,
+ const PhraseLocation& location) const;
+
+ protected:
+ RuleExtractor();
+
+ private:
+ // Finds all target phrases that can be aligned with the source phrase for a
+ // particular occurrence in the data.
+ vector<Extract> ExtractAlignments(const Phrase& phrase,
+ const vector<int>& matching) const;
+
+ // Extracts all target phrases for a given occurrence of the source phrase in
+ // the data. Constructs a vector of Extracts using these target phrases.
+ void AddExtracts(
+ vector<Extract>& extracts, const Phrase& source_phrase,
+ const unordered_map<int, int>& source_indexes,
+ const vector<pair<int, int> >& target_gaps, const vector<int>& target_low,
+ int target_phrase_low, int target_phrase_high, int sentence_id) const;
+
+ // Adds a leading and/or trailing nonterminal to the source phrase and
+ // extracts target phrases that can be aligned with the extended source
+ // phrase.
+ void AddNonterminalExtremities(
+ vector<Extract>& extracts, const vector<int>& matching,
+ const vector<int>& chunklen, const Phrase& source_phrase,
+ int source_back_low, int source_back_high, const vector<int>& source_low,
+ const vector<int>& source_high, const vector<int>& target_low,
+ const vector<int>& target_high, vector<pair<int, int> > target_gaps,
+ int sentence_id, int source_sent_start, int starts_with_x,
+ int ends_with_x, int extend_left, int extend_right) const;
+
+ private:
+ shared_ptr<DataArray> target_data_array;
+ shared_ptr<DataArray> source_data_array;
+ shared_ptr<PhraseBuilder> phrase_builder;
+ shared_ptr<Scorer> scorer;
+ shared_ptr<TargetPhraseExtractor> target_phrase_extractor;
+ shared_ptr<RuleExtractorHelper> helper;
+ int max_rule_span;
+ int min_gap_size;
+ int max_nonterminals;
+ int max_rule_symbols;
+ bool require_tight_phrases;
+};
+
+} // namespace extractor
+
+#endif