#include "rule_factory.h"

#include <chrono>
#include <iostream>
#include <memory>
#include <queue>
#include <vector>

#include "grammar.h"
#include "fast_intersector.h"
#include "intersector.h"
#include "matchings_finder.h"
#include "matching_comparator.h"
#include "phrase.h"
#include "rule.h"
#include "rule_extractor.h"
#include "sampler.h"
#include "scorer.h"
#include "suffix_array.h"
#include "time_util.h"
#include "vocabulary.h"

using namespace std;
using namespace chrono;

using Clock = high_resolution_clock;

// One pending step of the breadth-first search performed by GetGrammar:
// try to extend `phrase` (whose trie node is `node`) with the word found at
// sentence position `end`.
struct State {
  State(int start, int end, const vector<int>& phrase,
        const vector<int>& subpatterns_start, shared_ptr<TrieNode> node,
        bool starts_with_x) :
      start(start), end(end), phrase(phrase),
      subpatterns_start(subpatterns_start), node(node),
      starts_with_x(starts_with_x) {}

  // `start`: leftmost sentence position covered (including a leading gap).
  // `end`: position of the word that will be appended next.
  int start, end;
  // `phrase`: symbols collected so far.
  // `subpatterns_start`: sentence positions at which terminal chunks begin.
  vector<int> phrase, subpatterns_start;
  // Trie node holding the matchings of `phrase`.
  shared_ptr<TrieNode> node;
  // True if `phrase` begins with a nonterminal.
  bool starts_with_x;
};

// Builds all extraction components from the raw data structures.
// max_chunks is derived as max_nonterminals + 1 (a phrase with k gaps has at
// most k + 1 contiguous terminal chunks).
HieroCachingRuleFactory::HieroCachingRuleFactory(
    shared_ptr<SuffixArray> source_suffix_array,
    shared_ptr<DataArray> target_data_array,
    shared_ptr<Alignment> alignment,
    const shared_ptr<Vocabulary>& vocabulary,
    shared_ptr<Precomputation> precomputation,
    shared_ptr<Scorer> scorer,
    int min_gap_size,
    int max_rule_span,
    int max_nonterminals,
    int max_rule_symbols,
    int max_samples,
    bool use_fast_intersect,
    bool use_baeza_yates,
    bool require_tight_phrases) :
    vocabulary(vocabulary),
    scorer(scorer),
    min_gap_size(min_gap_size),
    max_rule_span(max_rule_span),
    max_nonterminals(max_nonterminals),
    max_chunks(max_nonterminals + 1),
    max_rule_symbols(max_rule_symbols),
    use_fast_intersect(use_fast_intersect) {
  matchings_finder = make_shared<MatchingsFinder>(source_suffix_array);
  shared_ptr<MatchingComparator> comparator =
      make_shared<MatchingComparator>(min_gap_size, max_rule_span);
  intersector = make_shared<Intersector>(vocabulary, precomputation,
      source_suffix_array, comparator, use_baeza_yates);
  fast_intersector = make_shared<FastIntersector>(source_suffix_array,
      precomputation, vocabulary, max_rule_span, min_gap_size);
  phrase_builder = make_shared<PhraseBuilder>(vocabulary);
  // NOTE(review): the two literal booleans (true, false) are RuleExtractor
  // options whose meaning is defined in rule_extractor.h.
  rule_extractor = make_shared<RuleExtractor>(source_suffix_array->GetData(),
      target_data_array, alignment, phrase_builder, scorer, vocabulary,
      max_rule_span, min_gap_size, max_nonterminals, max_rule_symbols, true,
      false, require_tight_phrases);
  sampler = make_shared<Sampler>(source_suffix_array, max_samples);
}

// Dependency-injecting constructor: every collaborator is supplied by the
// caller (e.g. mocks).
HieroCachingRuleFactory::HieroCachingRuleFactory(
    shared_ptr<MatchingsFinder> finder,
    shared_ptr<Intersector> intersector,
    shared_ptr<FastIntersector> fast_intersector,
    shared_ptr<PhraseBuilder> phrase_builder,
    shared_ptr<RuleExtractor> rule_extractor,
    shared_ptr<Vocabulary> vocabulary,
    shared_ptr<Sampler> sampler,
    shared_ptr<Scorer> scorer,
    int min_gap_size,
    int max_rule_span,
    int max_nonterminals,
    int max_chunks,
    int max_rule_symbols,
    bool use_fast_intersect) :
    matchings_finder(finder),
    intersector(intersector),
    fast_intersector(fast_intersector),
    phrase_builder(phrase_builder),
    rule_extractor(rule_extractor),
    vocabulary(vocabulary),
    sampler(sampler),
    scorer(scorer),
    min_gap_size(min_gap_size),
    max_rule_span(max_rule_span),
    max_nonterminals(max_nonterminals),
    max_chunks(max_chunks),
    max_rule_symbols(max_rule_symbols),
    use_fast_intersect(use_fast_intersect) {}

HieroCachingRuleFactory::HieroCachingRuleFactory() {}

HieroCachingRuleFactory::~HieroCachingRuleFactory() {}

// Extracts the grammar for one sentence.
//
// Enumerates every (possibly gapped) source phrase of `word_ids` allowed by
// the span/symbol/nonterminal limits, in FIFO (breadth-first) order. The
// corpus matchings of each phrase are cached in `trie`, which is reset per
// sentence. Rules are sampled and extracted for every newly discovered phrase
// that does not start with a nonterminal. Timing statistics are printed to
// stderr.
Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
  intersector->sort_time = 0;
  Clock::time_point start_time = Clock::now();
  double total_extract_time = 0;
  double total_intersect_time = 0;
  double total_lookup_time = 0;

  // Clear cache for every new sentence.
  trie.Reset();
  shared_ptr<TrieNode> root = trie.GetRoot();
  int first_x = vocabulary->GetNonterminalIndex(1);
  shared_ptr<TrieNode> x_root = make_shared<TrieNode>(root);
  root->AddChild(first_x, x_root);

  // Seed states: one per position for phrases starting with a terminal, and
  // one per position preceded by a min_gap_size leading gap for phrases
  // starting with X.
  queue<State> states;
  for (size_t i = 0; i < word_ids.size(); ++i) {
    states.push(State(i, i, vector<int>(), vector<int>(1, i), root, false));
  }
  for (size_t i = min_gap_size; i < word_ids.size(); ++i) {
    states.push(State(i - min_gap_size, i, vector<int>(1, first_x),
                      vector<int>(1, i), x_root, true));
  }

  vector<Rule> rules;
  while (!states.empty()) {
    State state = std::move(states.front());
    states.pop();

    shared_ptr<TrieNode> node = state.node;
    vector<int> phrase = state.phrase;
    int word_id = word_ids[state.end];
    phrase.push_back(word_id);
    Phrase next_phrase = phrase_builder->Build(phrase);
    shared_ptr<TrieNode> next_node;

    if (CannotHaveMatchings(node, word_id)) {
      // Remember the dead end so later visits short-circuit here too.
      if (!node->HasChild(word_id)) {
        node->AddChild(word_id, shared_ptr<TrieNode>());
      }
      continue;
    }

    if (RequiresLookup(node, word_id)) {
      shared_ptr<TrieNode> next_suffix_link = node->suffix_link == nullptr ?
          trie.GetRoot() : node->suffix_link->GetChild(word_id);
      if (state.starts_with_x) {
        // If the phrase starts with a non terminal, we simply use the
        // matchings from the suffix link.
        next_node = make_shared<TrieNode>(
            next_suffix_link, next_phrase, next_suffix_link->matchings);
      } else {
        PhraseLocation phrase_location;
        if (next_phrase.Arity() > 0) {
          // Gapped phrase: intersect the prefix and suffix matchings.
          Clock::time_point intersect_start = Clock::now();
          if (use_fast_intersect) {
            phrase_location = fast_intersector->Intersect(
                node->matchings, next_suffix_link->matchings, next_phrase);
          } else {
            phrase_location = intersector->Intersect(
                node->phrase, node->matchings,
                next_suffix_link->phrase, next_suffix_link->matchings,
                next_phrase);
          }
          Clock::time_point intersect_stop = Clock::now();
          total_intersect_time += GetDuration(intersect_start, intersect_stop);
        } else {
          // Contiguous phrase: look it up directly.
          Clock::time_point lookup_start = Clock::now();
          phrase_location = matchings_finder->Find(
              node->matchings, vocabulary->GetTerminalValue(word_id),
              state.phrase.size());
          Clock::time_point lookup_stop = Clock::now();
          total_lookup_time += GetDuration(lookup_start, lookup_stop);
        }

        if (phrase_location.IsEmpty()) {
          // The phrase never occurs in the corpus. Cache the miss as a null
          // child, as the CannotHaveMatchings branch does, so that other
          // states reaching this node with the same word do not repeat the
          // same lookup/intersection.
          node->AddChild(word_id, shared_ptr<TrieNode>());
          continue;
        }
        next_node = make_shared<TrieNode>(
            next_suffix_link, next_phrase, phrase_location);
      }
      node->AddChild(word_id, next_node);

      // Automatically adds a trailing non terminal if allowed. Simply copy
      // the matchings from the prefix node.
      AddTrailingNonterminal(phrase, next_phrase, next_node,
                             state.starts_with_x);

      // Phrases starting with X are only used as intermediate trie nodes;
      // no rules are extracted for them.
      if (!state.starts_with_x) {
        Clock::time_point extract_start = Clock::now();
        PhraseLocation sample = sampler->Sample(next_node->matchings);
        vector<Rule> new_rules =
            rule_extractor->ExtractRules(next_phrase, sample);
        rules.insert(rules.end(), new_rules.begin(), new_rules.end());
        Clock::time_point extract_stop = Clock::now();
        total_extract_time += GetDuration(extract_start, extract_stop);
      }
    } else {
      // Matchings were already computed earlier for this sentence.
      next_node = node->GetChild(word_id);
    }

    vector<State> new_states = ExtendState(word_ids, state, phrase,
                                           next_phrase, next_node);
    for (State& new_state : new_states) {
      states.push(std::move(new_state));
    }
  }

  cerr << "Vocabulary size = " << vocabulary->Size() << endl;
  Clock::time_point stop_time = Clock::now();
  cerr << "Total time for rule lookup, extraction, and scoring = "
       << GetDuration(start_time, stop_time) << " seconds" << endl;
  cerr << "Extract time = " << total_extract_time << " seconds" << endl;
  cerr << "Intersect time = " << total_intersect_time << " seconds" << endl;
  cerr << "Lookup time = " << total_lookup_time << " seconds" << endl;
  return Grammar(rules, scorer->GetFeatureNames());
}

// Returns true if extending `node` with `word_id` is already known to be
// fruitless. That is the case when a null child was recorded for `word_id`,
// or when `node` has a suffix link whose child for `word_id` is null.
// NOTE(review): whether TrieNode::GetChild returns null (or inserts a null
// entry) for a missing key is defined in the trie header — confirm.
bool HieroCachingRuleFactory::CannotHaveMatchings(
    shared_ptr<TrieNode> node, int word_id) {
  if (node->HasChild(word_id) && node->GetChild(word_id) == nullptr) {
    return true;
  }

  shared_ptr<TrieNode> suffix_link = node->suffix_link;
  return suffix_link != nullptr && suffix_link->GetChild(word_id) == nullptr;
}

// Returns true if the matchings for `node` extended by `word_id` have not
// been computed yet for the current sentence.
bool HieroCachingRuleFactory::RequiresLookup(
    shared_ptr<TrieNode> node, int word_id) {
  return !node->HasChild(word_id);
}

// Adds to `prefix_node` a child for the phrase `prefix X`, where X is the
// next nonterminal. The child reuses prefix_node's matchings unchanged.
// Does nothing once the prefix already has max_nonterminals gaps.
// `symbols` holds the symbols of `prefix` and is taken by value because it
// is extended locally.
void HieroCachingRuleFactory::AddTrailingNonterminal(
    vector<int> symbols,
    const Phrase& prefix,
    const shared_ptr<TrieNode>& prefix_node,
    bool starts_with_x) {
  if (prefix.Arity() >= max_nonterminals) {
    return;
  }

  int var_id = vocabulary->GetNonterminalIndex(prefix.Arity() + 1);
  symbols.push_back(var_id);
  Phrase var_phrase = phrase_builder->Build(symbols);

  // When the phrase starts with X, its suffix link presumably lacks that
  // leading X. The suffix then has one nonterminal fewer, so its trailing
  // nonterminal is numbered Arity() rather than Arity() + 1.
  int suffix_var_id = vocabulary->GetNonterminalIndex(
      prefix.Arity() + (starts_with_x ? 0 : 1));
  shared_ptr<TrieNode> var_suffix_link =
      prefix_node->suffix_link->GetChild(suffix_var_id);
  prefix_node->AddChild(var_id, make_shared<TrieNode>(
      var_suffix_link, var_phrase, prefix_node->matchings));
}

// Generates the successor states of `state` after its word has been appended.
//  - `symbols`: the extended symbol sequence.
//  - `phrase`: the corresponding Phrase.
//  - `node`: the phrase's trie node.
//
// Two kinds of successor are produced:
//  - one state appending the next adjacent word;
//  - if another gap is permitted, one state per position i that starts a new
//    terminal chunk after a gap of at least min_gap_size words.
//
// No successors are produced once the symbol, sentence-length or span limits
// are reached.
vector<State> HieroCachingRuleFactory::ExtendState(
    const vector<int>& word_ids,
    const State& state,
    vector<int> symbols,
    const Phrase& phrase,
    const shared_ptr<TrieNode>& node) {
  const int span = state.end - state.start;
  const int num_words = static_cast<int>(word_ids.size());
  const int num_symbols = static_cast<int>(symbols.size());
  vector<State> new_states;
  if (num_symbols >= max_rule_symbols || state.end + 1 >= num_words ||
      span >= max_rule_span) {
    return new_states;
  }

  // Extend with the adjacent word.
  new_states.push_back(State(state.start, state.end + 1, symbols,
                             state.subpatterns_start, node,
                             state.starts_with_x));

  // A gap needs room for both the nonterminal and a following word. It also
  // needs a spare nonterminal slot and room for one more terminal chunk.
  const int num_subpatterns = phrase.Arity() + (state.starts_with_x ? 0 : 1);
  if (num_symbols + 1 >= max_rule_symbols ||
      phrase.Arity() >= max_nonterminals ||
      num_subpatterns >= max_chunks) {
    return new_states;
  }

  int var_id = vocabulary->GetNonterminalIndex(phrase.Arity() + 1);
  symbols.push_back(var_id);
  vector<int> subpatterns_start = state.subpatterns_start;
  for (int i = state.end + 1 + min_gap_size;
       i < num_words && i - state.start <= max_rule_span; ++i) {
    subpatterns_start.push_back(i);
    new_states.push_back(State(state.start, i, symbols, subpatterns_start,
                               node->GetChild(var_id), state.starts_with_x));
    subpatterns_start.pop_back();
  }
  return new_states;
}