diff options
author | Chris Dyer <cdyer@allegro.clab.cs.cmu.edu> | 2013-04-23 19:35:18 -0400 |
---|---|---|
committer | Chris Dyer <cdyer@allegro.clab.cs.cmu.edu> | 2013-04-23 19:35:18 -0400 |
commit | 6d347f1ce078dede3da0e1498f75e357351c6543 (patch) | |
tree | 8e872b8747c530e741e55e25e9917c1bd8b32c5b /extractor/rule_factory.h | |
parent | d11b76def6899790161c47a73018146311356d8b (diff) | |
parent | 5e9605b65202f4e5fc59843b197d88c4774f0ac8 (diff) |
merge paul's extractor code
Diffstat (limited to 'extractor/rule_factory.h')
-rw-r--r-- | extractor/rule_factory.h | 118 |
1 files changed, 118 insertions, 0 deletions
diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h new file mode 100644 index 00000000..52e8712a --- /dev/null +++ b/extractor/rule_factory.h @@ -0,0 +1,118 @@ +#ifndef _RULE_FACTORY_H_ +#define _RULE_FACTORY_H_ + +#include <memory> +#include <vector> + +#include "matchings_trie.h" + +using namespace std; + +namespace extractor { + +class Alignment; +class DataArray; +class FastIntersector; +class Grammar; +class MatchingsFinder; +class PhraseBuilder; +class Precomputation; +class Rule; +class RuleExtractor; +class Sampler; +class Scorer; +class State; +class SuffixArray; +class Vocabulary; + +/** + * Component containing most of the logic for extracting SCFG rules for a given + * sentence. + * + * Given a sentence (as a vector of word ids), this class constructs all the + * possible source phrases starting from this sentence. For each source phrase, + * it finds all its occurrences in the source data and samples some of these + * occurrences to extract aligned source-target phrase pairs. A trie cache is + * used to avoid unnecessary computations if a source phrase can be constructed + * more than once (e.g. some words occur more than once in the sentence). + */ +class HieroCachingRuleFactory { + public: + HieroCachingRuleFactory( + shared_ptr<SuffixArray> source_suffix_array, + shared_ptr<DataArray> target_data_array, + shared_ptr<Alignment> alignment, + const shared_ptr<Vocabulary>& vocabulary, + shared_ptr<Precomputation> precomputation, + shared_ptr<Scorer> scorer, + int min_gap_size, + int max_rule_span, + int max_nonterminals, + int max_rule_symbols, + int max_samples, + bool require_tight_phrases); + + // For testing only. + HieroCachingRuleFactory( + shared_ptr<MatchingsFinder> finder, + shared_ptr<FastIntersector> fast_intersector, + shared_ptr<PhraseBuilder> phrase_builder, + shared_ptr<RuleExtractor> rule_extractor, + shared_ptr<Vocabulary> vocabulary, + shared_ptr<Sampler> sampler, + shared_ptr<Scorer> scorer, + int min_gap_size, + int max_rule_span, + int max_nonterminals, + int max_chunks, + int max_rule_symbols); + + virtual ~HieroCachingRuleFactory(); + + // Constructs SCFG rules for a given sentence. + // (See class description for more details.) + virtual Grammar GetGrammar(const vector<int>& word_ids); + + protected: + HieroCachingRuleFactory(); + + private: + // Checks if the phrase (if previously encountered) or its prefix have any + // occurrences in the source data. + bool CannotHaveMatchings(shared_ptr<TrieNode> node, int word_id); + + // Checks if the phrase has previously been analyzed. + bool RequiresLookup(shared_ptr<TrieNode> node, int word_id); + + // Creates a new state in the trie that corresponds to adding a trailing + // nonterminal to the current phrase. + void AddTrailingNonterminal(vector<int> symbols, + const Phrase& prefix, + const shared_ptr<TrieNode>& prefix_node, + bool starts_with_x); + + // Extends the current state by possibly adding a nonterminal followed by a + // terminal. + vector<State> ExtendState(const vector<int>& word_ids, + const State& state, + vector<int> symbols, + const Phrase& phrase, + const shared_ptr<TrieNode>& node); + + shared_ptr<MatchingsFinder> matchings_finder; + shared_ptr<FastIntersector> fast_intersector; + shared_ptr<PhraseBuilder> phrase_builder; + shared_ptr<RuleExtractor> rule_extractor; + shared_ptr<Vocabulary> vocabulary; + shared_ptr<Sampler> sampler; + shared_ptr<Scorer> scorer; + int min_gap_size; + int max_rule_span; + int max_nonterminals; + int max_chunks; + int max_rule_symbols; +}; + +} // namespace extractor + +#endif |