diff options
Diffstat (limited to 'extractor/rule_factory.cc')
-rw-r--r-- | extractor/rule_factory.cc | 305 |
1 files changed, 305 insertions, 0 deletions
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc new file mode 100644 index 00000000..4101fcfa --- /dev/null +++ b/extractor/rule_factory.cc @@ -0,0 +1,305 @@ +#include "rule_factory.h" + +#include <chrono> +#include <memory> +#include <queue> +#include <vector> + +#include "grammar.h" +#include "fast_intersector.h" +#include "intersector.h" +#include "matchings_finder.h" +#include "matching_comparator.h" +#include "phrase.h" +#include "rule.h" +#include "rule_extractor.h" +#include "sampler.h" +#include "scorer.h" +#include "suffix_array.h" +#include "time_util.h" +#include "vocabulary.h" + +using namespace std; +using namespace chrono; + +typedef high_resolution_clock Clock; + +struct State { + State(int start, int end, const vector<int>& phrase, + const vector<int>& subpatterns_start, shared_ptr<TrieNode> node, + bool starts_with_x) : + start(start), end(end), phrase(phrase), + subpatterns_start(subpatterns_start), node(node), + starts_with_x(starts_with_x) {} + + int start, end; + vector<int> phrase, subpatterns_start; + shared_ptr<TrieNode> node; + bool starts_with_x; +}; + +HieroCachingRuleFactory::HieroCachingRuleFactory( + shared_ptr<SuffixArray> source_suffix_array, + shared_ptr<DataArray> target_data_array, + shared_ptr<Alignment> alignment, + const shared_ptr<Vocabulary>& vocabulary, + shared_ptr<Precomputation> precomputation, + shared_ptr<Scorer> scorer, + int min_gap_size, + int max_rule_span, + int max_nonterminals, + int max_rule_symbols, + int max_samples, + bool use_fast_intersect, + bool use_baeza_yates, + bool require_tight_phrases) : + vocabulary(vocabulary), + scorer(scorer), + min_gap_size(min_gap_size), + max_rule_span(max_rule_span), + max_nonterminals(max_nonterminals), + max_chunks(max_nonterminals + 1), + max_rule_symbols(max_rule_symbols), + use_fast_intersect(use_fast_intersect) { + matchings_finder = make_shared<MatchingsFinder>(source_suffix_array); + shared_ptr<MatchingComparator> comparator = + make_shared<MatchingComparator>(min_gap_size, max_rule_span); + intersector = make_shared<Intersector>(vocabulary, precomputation, + source_suffix_array, comparator, use_baeza_yates); + fast_intersector = make_shared<FastIntersector>(source_suffix_array, + precomputation, vocabulary, max_rule_span, min_gap_size); + phrase_builder = make_shared<PhraseBuilder>(vocabulary); + rule_extractor = make_shared<RuleExtractor>(source_suffix_array->GetData(), + target_data_array, alignment, phrase_builder, scorer, vocabulary, + max_rule_span, min_gap_size, max_nonterminals, max_rule_symbols, true, + false, require_tight_phrases); + sampler = make_shared<Sampler>(source_suffix_array, max_samples); +} + +HieroCachingRuleFactory::HieroCachingRuleFactory( + shared_ptr<MatchingsFinder> finder, + shared_ptr<Intersector> intersector, + shared_ptr<FastIntersector> fast_intersector, + shared_ptr<PhraseBuilder> phrase_builder, + shared_ptr<RuleExtractor> rule_extractor, + shared_ptr<Vocabulary> vocabulary, + shared_ptr<Sampler> sampler, + shared_ptr<Scorer> scorer, + int min_gap_size, + int max_rule_span, + int max_nonterminals, + int max_chunks, + int max_rule_symbols, + bool use_fast_intersect) : + matchings_finder(finder), + intersector(intersector), + fast_intersector(fast_intersector), + phrase_builder(phrase_builder), + rule_extractor(rule_extractor), + vocabulary(vocabulary), + sampler(sampler), + scorer(scorer), + min_gap_size(min_gap_size), + max_rule_span(max_rule_span), + max_nonterminals(max_nonterminals), + max_chunks(max_chunks), + max_rule_symbols(max_rule_symbols), + use_fast_intersect(use_fast_intersect) {} + +HieroCachingRuleFactory::HieroCachingRuleFactory() {} + +HieroCachingRuleFactory::~HieroCachingRuleFactory() {} + +Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { + intersector->sort_time = 0; + Clock::time_point start_time = Clock::now(); + double total_extract_time = 0; + double total_intersect_time = 0; + double total_lookup_time = 0; + // Clear cache for every new sentence. + trie.Reset(); + shared_ptr<TrieNode> root = trie.GetRoot(); + + int first_x = vocabulary->GetNonterminalIndex(1); + shared_ptr<TrieNode> x_root(new TrieNode(root)); + root->AddChild(first_x, x_root); + + queue<State> states; + for (size_t i = 0; i < word_ids.size(); ++i) { + states.push(State(i, i, vector<int>(), vector<int>(1, i), root, false)); + } + for (size_t i = min_gap_size; i < word_ids.size(); ++i) { + states.push(State(i - min_gap_size, i, vector<int>(1, first_x), + vector<int>(1, i), x_root, true)); + } + + vector<Rule> rules; + while (!states.empty()) { + State state = states.front(); + states.pop(); + + shared_ptr<TrieNode> node = state.node; + vector<int> phrase = state.phrase; + int word_id = word_ids[state.end]; + phrase.push_back(word_id); + Phrase next_phrase = phrase_builder->Build(phrase); + shared_ptr<TrieNode> next_node; + + if (CannotHaveMatchings(node, word_id)) { + if (!node->HasChild(word_id)) { + node->AddChild(word_id, shared_ptr<TrieNode>()); + } + continue; + } + + if (RequiresLookup(node, word_id)) { + shared_ptr<TrieNode> next_suffix_link = node->suffix_link == NULL ? + trie.GetRoot() : node->suffix_link->GetChild(word_id); + if (state.starts_with_x) { + // If the phrase starts with a non terminal, we simply use the matchings + // from the suffix link. + next_node = make_shared<TrieNode>( + next_suffix_link, next_phrase, next_suffix_link->matchings); + } else { + PhraseLocation phrase_location; + if (next_phrase.Arity() > 0) { + Clock::time_point intersect_start = Clock::now(); + if (use_fast_intersect) { + phrase_location = fast_intersector->Intersect( + node->matchings, next_suffix_link->matchings, next_phrase); + } else { + phrase_location = intersector->Intersect( + node->phrase, + node->matchings, + next_suffix_link->phrase, + next_suffix_link->matchings, + next_phrase); + } + Clock::time_point intersect_stop = Clock::now(); + total_intersect_time += GetDuration(intersect_start, intersect_stop); + } else { + Clock::time_point lookup_start = Clock::now(); + phrase_location = matchings_finder->Find( + node->matchings, + vocabulary->GetTerminalValue(word_id), + state.phrase.size()); + Clock::time_point lookup_stop = Clock::now(); + total_lookup_time += GetDuration(lookup_start, lookup_stop); + } + + if (phrase_location.IsEmpty()) { + continue; + } + next_node = make_shared<TrieNode>( + next_suffix_link, next_phrase, phrase_location); + } + node->AddChild(word_id, next_node); + + // Automatically adds a trailing non terminal if allowed. Simply copy the + // matchings from the prefix node. + AddTrailingNonterminal(phrase, next_phrase, next_node, + state.starts_with_x); + + Clock::time_point extract_start = Clock::now(); + if (!state.starts_with_x) { + PhraseLocation sample = sampler->Sample(next_node->matchings); + vector<Rule> new_rules = + rule_extractor->ExtractRules(next_phrase, sample); + rules.insert(rules.end(), new_rules.begin(), new_rules.end()); + } + Clock::time_point extract_stop = Clock::now(); + total_extract_time += GetDuration(extract_start, extract_stop); + } else { + next_node = node->GetChild(word_id); + } + + vector<State> new_states = ExtendState(word_ids, state, phrase, next_phrase, + next_node); + for (State new_state: new_states) { + states.push(new_state); + } + } + + Clock::time_point stop_time = Clock::now(); + cerr << "Total time for rule lookup, extraction, and scoring = " + << GetDuration(start_time, stop_time) << " seconds" << endl; + cerr << "Extract time = " << total_extract_time << " seconds" << endl; + cerr << "Intersect time = " << total_intersect_time << " seconds" << endl; + cerr << "Lookup time = " << total_lookup_time << " seconds" << endl; + return Grammar(rules, scorer->GetFeatureNames()); +} + +bool HieroCachingRuleFactory::CannotHaveMatchings( + shared_ptr<TrieNode> node, int word_id) { + if (node->HasChild(word_id) && node->GetChild(word_id) == NULL) { + return true; + } + + shared_ptr<TrieNode> suffix_link = node->suffix_link; + return suffix_link != NULL && suffix_link->GetChild(word_id) == NULL; +} + +bool HieroCachingRuleFactory::RequiresLookup( + shared_ptr<TrieNode> node, int word_id) { + return !node->HasChild(word_id); +} + +void HieroCachingRuleFactory::AddTrailingNonterminal( + vector<int> symbols, + const Phrase& prefix, + const shared_ptr<TrieNode>& prefix_node, + bool starts_with_x) { + if (prefix.Arity() >= max_nonterminals) { + return; + } + + int var_id = vocabulary->GetNonterminalIndex(prefix.Arity() + 1); + symbols.push_back(var_id); + Phrase var_phrase = phrase_builder->Build(symbols); + + int suffix_var_id = vocabulary->GetNonterminalIndex( + prefix.Arity() + (starts_with_x == 0)); + shared_ptr<TrieNode> var_suffix_link = + prefix_node->suffix_link->GetChild(suffix_var_id); + + prefix_node->AddChild(var_id, make_shared<TrieNode>( + var_suffix_link, var_phrase, prefix_node->matchings)); +} + +vector<State> HieroCachingRuleFactory::ExtendState( + const vector<int>& word_ids, + const State& state, + vector<int> symbols, + const Phrase& phrase, + const shared_ptr<TrieNode>& node) { + int span = state.end - state.start; + vector<State> new_states; + if (symbols.size() >= max_rule_symbols || state.end + 1 >= word_ids.size() || + span >= max_rule_span) { + return new_states; + } + + new_states.push_back(State(state.start, state.end + 1, symbols, + state.subpatterns_start, node, state.starts_with_x)); + + int num_subpatterns = phrase.Arity() + (state.starts_with_x == 0); + if (symbols.size() + 1 >= max_rule_symbols || + phrase.Arity() >= max_nonterminals || + num_subpatterns >= max_chunks) { + return new_states; + } + + int var_id = vocabulary->GetNonterminalIndex(phrase.Arity() + 1); + symbols.push_back(var_id); + vector<int> subpatterns_start = state.subpatterns_start; + size_t i = state.end + 1 + min_gap_size; + while (i < word_ids.size() && i - state.start <= max_rule_span) { + subpatterns_start.push_back(i); + new_states.push_back(State(state.start, i, symbols, subpatterns_start, + node->GetChild(var_id), state.starts_with_x)); + subpatterns_start.pop_back(); + ++i; + } + + return new_states; +} |