#ifndef _RULE_FACTORY_H_ #define _RULE_FACTORY_H_ #include #include #include #include "matchings_trie.h" using namespace std; namespace extractor { class Alignment; class DataArray; class FastIntersector; class Grammar; class MatchingsFinder; class PhraseBuilder; class Precomputation; class Rule; class RuleExtractor; class Sampler; class Scorer; class State; class SuffixArray; class Vocabulary; /** * Component containing most of the logic for extracting SCFG rules for a given * sentence. * * Given a sentence (as a vector of word ids), this class constructs all the * possible source phrases starting from this sentence. For each source phrase, * it finds all its occurrences in the source data and samples some of these * occurrences to extract aligned source-target phrase pairs. A trie cache is * used to avoid unnecessary computations if a source phrase can be constructed * more than once (e.g. some words occur more than once in the sentence). */ class HieroCachingRuleFactory { public: HieroCachingRuleFactory( shared_ptr source_suffix_array, shared_ptr target_data_array, shared_ptr alignment, const shared_ptr& vocabulary, shared_ptr precomputation, shared_ptr scorer, int min_gap_size, int max_rule_span, int max_nonterminals, int max_rule_symbols, int max_samples, bool require_tight_phrases); // For testing only. HieroCachingRuleFactory( shared_ptr finder, shared_ptr fast_intersector, shared_ptr phrase_builder, shared_ptr rule_extractor, shared_ptr vocabulary, shared_ptr sampler, shared_ptr scorer, int min_gap_size, int max_rule_span, int max_nonterminals, int max_chunks, int max_rule_symbols); virtual ~HieroCachingRuleFactory(); // Constructs SCFG rules for a given sentence. // (See class description for more details.) virtual Grammar GetGrammar( const vector& word_ids, const unordered_set& blacklisted_sentence_ids, const shared_ptr source_data_array); protected: HieroCachingRuleFactory(); private: // Checks if the phrase (if previously encountered) or its prefix have any // occurrences in the source data. bool CannotHaveMatchings(shared_ptr node, int word_id); // Checks if the phrase has previously been analyzed. bool RequiresLookup(shared_ptr node, int word_id); // Creates a new state in the trie that corresponds to adding a trailing // nonterminal to the current phrase. void AddTrailingNonterminal(vector symbols, const Phrase& prefix, const shared_ptr& prefix_node, bool starts_with_x); // Extends the current state by possibly adding a nonterminal followed by a // terminal. vector ExtendState(const vector& word_ids, const State& state, vector symbols, const Phrase& phrase, const shared_ptr& node); shared_ptr matchings_finder; shared_ptr fast_intersector; shared_ptr phrase_builder; shared_ptr rule_extractor; shared_ptr vocabulary; shared_ptr sampler; shared_ptr scorer; int min_gap_size; int max_rule_span; int max_nonterminals; int max_chunks; int max_rule_symbols; }; } // namespace extractor #endif