diff options
Diffstat (limited to 'extractor/rule_factory.cc')
-rw-r--r-- | extractor/rule_factory.cc | 56 |
1 files changed, 40 insertions, 16 deletions
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index 7a8356b8..c22f9b48 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -5,8 +5,15 @@ #include <queue> #include <vector> +#include "grammar.h" +#include "intersector.h" +#include "matchings_finder.h" #include "matching_comparator.h" #include "phrase.h" +#include "rule.h" +#include "rule_extractor.h" +#include "sampler.h" +#include "scorer.h" #include "suffix_array.h" #include "vocabulary.h" @@ -30,28 +37,39 @@ struct State { HieroCachingRuleFactory::HieroCachingRuleFactory( shared_ptr<SuffixArray> source_suffix_array, shared_ptr<DataArray> target_data_array, - const Alignment& alignment, + shared_ptr<Alignment> alignment, const shared_ptr<Vocabulary>& vocabulary, - const Precomputation& precomputation, + shared_ptr<Precomputation> precomputation, + shared_ptr<Scorer> scorer, int min_gap_size, int max_rule_span, int max_nonterminals, int max_rule_symbols, - bool use_baeza_yates) : - matchings_finder(source_suffix_array), - intersector(vocabulary, precomputation, source_suffix_array, - make_shared<MatchingComparator>(min_gap_size, max_rule_span), - use_baeza_yates), - phrase_builder(vocabulary), - rule_extractor(source_suffix_array, target_data_array, alignment), + int max_samples, + bool use_baeza_yates, + bool require_tight_phrases) : vocabulary(vocabulary), + scorer(scorer), min_gap_size(min_gap_size), max_rule_span(max_rule_span), max_nonterminals(max_nonterminals), max_chunks(max_nonterminals + 1), - max_rule_symbols(max_rule_symbols) {} + max_rule_symbols(max_rule_symbols) { + matchings_finder = make_shared<MatchingsFinder>(source_suffix_array); + shared_ptr<MatchingComparator> comparator = + make_shared<MatchingComparator>(min_gap_size, max_rule_span); + intersector = make_shared<Intersector>(vocabulary, precomputation, + source_suffix_array, comparator, use_baeza_yates); + phrase_builder = make_shared<PhraseBuilder>(vocabulary); + rule_extractor = make_shared<RuleExtractor>(source_suffix_array->GetData(), + target_data_array, alignment, phrase_builder, scorer, vocabulary, + max_rule_span, min_gap_size, max_nonterminals, max_rule_symbols, true, + false, require_tight_phrases); + sampler = make_shared<Sampler>(source_suffix_array, max_samples); +} + -void HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { +Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { // Clear cache for every new sentence. trie.Reset(); shared_ptr<TrieNode> root = trie.GetRoot(); @@ -69,6 +87,7 @@ void HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { vector<int>(1, i), x_root, true)); } + vector<Rule> rules; while (!states.empty()) { State state = states.front(); states.pop(); @@ -77,7 +96,7 @@ void HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { vector<int> phrase = state.phrase; int word_id = word_ids[state.end]; phrase.push_back(word_id); - Phrase next_phrase = phrase_builder.Build(phrase); + Phrase next_phrase = phrase_builder->Build(phrase); shared_ptr<TrieNode> next_node; if (CannotHaveMatchings(node, word_id)) { @@ -98,14 +117,14 @@ void HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { } else { PhraseLocation phrase_location; if (next_phrase.Arity() > 0) { - phrase_location = intersector.Intersect( + phrase_location = intersector->Intersect( node->phrase, node->matchings, next_suffix_link->phrase, next_suffix_link->matchings, next_phrase); } else { - phrase_location = matchings_finder.Find( + phrase_location = matchings_finder->Find( node->matchings, vocabulary->GetTerminalValue(word_id), state.phrase.size()); @@ -125,7 +144,10 @@ void HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { state.starts_with_x); if (!state.starts_with_x) { - rule_extractor.ExtractRules(); + PhraseLocation sample = sampler->Sample(next_node->matchings); + vector<Rule> new_rules = + rule_extractor->ExtractRules(next_phrase, sample); + rules.insert(rules.end(), new_rules.begin(), new_rules.end()); } } else { next_node = node->GetChild(word_id); @@ -137,6 +159,8 @@ void HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { states.push(new_state); } } + + return Grammar(rules, scorer->GetFeatureNames()); } bool HieroCachingRuleFactory::CannotHaveMatchings( @@ -165,7 +189,7 @@ void HieroCachingRuleFactory::AddTrailingNonterminal( int var_id = vocabulary->GetNonterminalIndex(prefix.Arity() + 1); symbols.push_back(var_id); - Phrase var_phrase = phrase_builder.Build(symbols); + Phrase var_phrase = phrase_builder->Build(symbols); int suffix_var_id = vocabulary->GetNonterminalIndex( prefix.Arity() + starts_with_x == 0); |