From 0a53f7eca74c165b5ce1c238f1999ddf1febea55 Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Fri, 1 Feb 2013 16:11:10 +0000 Subject: Second working commit. --- extractor/rule_factory.cc | 56 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 16 deletions(-) (limited to 'extractor/rule_factory.cc') diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index 7a8356b8..c22f9b48 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -5,8 +5,15 @@ #include #include +#include "grammar.h" +#include "intersector.h" +#include "matchings_finder.h" #include "matching_comparator.h" #include "phrase.h" +#include "rule.h" +#include "rule_extractor.h" +#include "sampler.h" +#include "scorer.h" #include "suffix_array.h" #include "vocabulary.h" @@ -30,28 +37,39 @@ struct State { HieroCachingRuleFactory::HieroCachingRuleFactory( shared_ptr source_suffix_array, shared_ptr target_data_array, - const Alignment& alignment, + shared_ptr alignment, const shared_ptr& vocabulary, - const Precomputation& precomputation, + shared_ptr precomputation, + shared_ptr scorer, int min_gap_size, int max_rule_span, int max_nonterminals, int max_rule_symbols, - bool use_baeza_yates) : - matchings_finder(source_suffix_array), - intersector(vocabulary, precomputation, source_suffix_array, - make_shared(min_gap_size, max_rule_span), - use_baeza_yates), - phrase_builder(vocabulary), - rule_extractor(source_suffix_array, target_data_array, alignment), + int max_samples, + bool use_baeza_yates, + bool require_tight_phrases) : vocabulary(vocabulary), + scorer(scorer), min_gap_size(min_gap_size), max_rule_span(max_rule_span), max_nonterminals(max_nonterminals), max_chunks(max_nonterminals + 1), - max_rule_symbols(max_rule_symbols) {} + max_rule_symbols(max_rule_symbols) { + matchings_finder = make_shared(source_suffix_array); + shared_ptr comparator = + make_shared(min_gap_size, max_rule_span); + intersector = make_shared(vocabulary, precomputation, + source_suffix_array, comparator, use_baeza_yates); + phrase_builder = make_shared(vocabulary); + rule_extractor = make_shared(source_suffix_array->GetData(), + target_data_array, alignment, phrase_builder, scorer, vocabulary, + max_rule_span, min_gap_size, max_nonterminals, max_rule_symbols, true, + false, require_tight_phrases); + sampler = make_shared(source_suffix_array, max_samples); +} + -void HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { +Grammar HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { // Clear cache for every new sentence. trie.Reset(); shared_ptr root = trie.GetRoot(); @@ -69,6 +87,7 @@ void HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { vector(1, i), x_root, true)); } + vector rules; while (!states.empty()) { State state = states.front(); states.pop(); @@ -77,7 +96,7 @@ void HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { vector phrase = state.phrase; int word_id = word_ids[state.end]; phrase.push_back(word_id); - Phrase next_phrase = phrase_builder.Build(phrase); + Phrase next_phrase = phrase_builder->Build(phrase); shared_ptr next_node; if (CannotHaveMatchings(node, word_id)) { @@ -98,14 +117,14 @@ void HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { } else { PhraseLocation phrase_location; if (next_phrase.Arity() > 0) { - phrase_location = intersector.Intersect( + phrase_location = intersector->Intersect( node->phrase, node->matchings, next_suffix_link->phrase, next_suffix_link->matchings, next_phrase); } else { - phrase_location = matchings_finder.Find( + phrase_location = matchings_finder->Find( node->matchings, vocabulary->GetTerminalValue(word_id), state.phrase.size()); @@ -125,7 +144,10 @@ void HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { state.starts_with_x); if (!state.starts_with_x) { - rule_extractor.ExtractRules(); + PhraseLocation sample = sampler->Sample(next_node->matchings); + vector new_rules = + rule_extractor->ExtractRules(next_phrase, sample); + rules.insert(rules.end(), new_rules.begin(), new_rules.end()); } } else { next_node = node->GetChild(word_id); @@ -137,6 +159,8 @@ void HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { states.push(new_state); } } + + return Grammar(rules, scorer->GetFeatureNames()); } bool HieroCachingRuleFactory::CannotHaveMatchings( @@ -165,7 +189,7 @@ void HieroCachingRuleFactory::AddTrailingNonterminal( int var_id = vocabulary->GetNonterminalIndex(prefix.Arity() + 1); symbols.push_back(var_id); - Phrase var_phrase = phrase_builder.Build(symbols); + Phrase var_phrase = phrase_builder->Build(symbols); int suffix_var_id = vocabulary->GetNonterminalIndex( prefix.Arity() + starts_with_x == 0); -- cgit v1.2.3