summaryrefslogtreecommitdiff
path: root/extractor/rule_factory.cc
diff options
context:
space:
mode:
Diffstat (limited to 'extractor/rule_factory.cc')
-rw-r--r--extractor/rule_factory.cc305
1 files changed, 305 insertions, 0 deletions
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc
new file mode 100644
index 00000000..4101fcfa
--- /dev/null
+++ b/extractor/rule_factory.cc
@@ -0,0 +1,305 @@
+#include "rule_factory.h"
+
+#include <chrono>
+#include <memory>
+#include <queue>
+#include <vector>
+
+#include "grammar.h"
+#include "fast_intersector.h"
+#include "intersector.h"
+#include "matchings_finder.h"
+#include "matching_comparator.h"
+#include "phrase.h"
+#include "rule.h"
+#include "rule_extractor.h"
+#include "sampler.h"
+#include "scorer.h"
+#include "suffix_array.h"
+#include "time_util.h"
+#include "vocabulary.h"
+
+using namespace std;
+using namespace chrono;
+
+typedef high_resolution_clock Clock;
+
+struct State {
+ State(int start, int end, const vector<int>& phrase,
+ const vector<int>& subpatterns_start, shared_ptr<TrieNode> node,
+ bool starts_with_x) :
+ start(start), end(end), phrase(phrase),
+ subpatterns_start(subpatterns_start), node(node),
+ starts_with_x(starts_with_x) {}
+
+ int start, end;
+ vector<int> phrase, subpatterns_start;
+ shared_ptr<TrieNode> node;
+ bool starts_with_x;
+};
+
+HieroCachingRuleFactory::HieroCachingRuleFactory(
+ shared_ptr<SuffixArray> source_suffix_array,
+ shared_ptr<DataArray> target_data_array,
+ shared_ptr<Alignment> alignment,
+ const shared_ptr<Vocabulary>& vocabulary,
+ shared_ptr<Precomputation> precomputation,
+ shared_ptr<Scorer> scorer,
+ int min_gap_size,
+ int max_rule_span,
+ int max_nonterminals,
+ int max_rule_symbols,
+ int max_samples,
+ bool use_fast_intersect,
+ bool use_baeza_yates,
+ bool require_tight_phrases) :
+ vocabulary(vocabulary),
+ scorer(scorer),
+ min_gap_size(min_gap_size),
+ max_rule_span(max_rule_span),
+ max_nonterminals(max_nonterminals),
+ max_chunks(max_nonterminals + 1),
+ max_rule_symbols(max_rule_symbols),
+ use_fast_intersect(use_fast_intersect) {
+ matchings_finder = make_shared<MatchingsFinder>(source_suffix_array);
+ shared_ptr<MatchingComparator> comparator =
+ make_shared<MatchingComparator>(min_gap_size, max_rule_span);
+ intersector = make_shared<Intersector>(vocabulary, precomputation,
+ source_suffix_array, comparator, use_baeza_yates);
+ fast_intersector = make_shared<FastIntersector>(source_suffix_array,
+ precomputation, vocabulary, max_rule_span, min_gap_size);
+ phrase_builder = make_shared<PhraseBuilder>(vocabulary);
+ rule_extractor = make_shared<RuleExtractor>(source_suffix_array->GetData(),
+ target_data_array, alignment, phrase_builder, scorer, vocabulary,
+ max_rule_span, min_gap_size, max_nonterminals, max_rule_symbols, true,
+ false, require_tight_phrases);
+ sampler = make_shared<Sampler>(source_suffix_array, max_samples);
+}
+
+HieroCachingRuleFactory::HieroCachingRuleFactory(
+ shared_ptr<MatchingsFinder> finder,
+ shared_ptr<Intersector> intersector,
+ shared_ptr<FastIntersector> fast_intersector,
+ shared_ptr<PhraseBuilder> phrase_builder,
+ shared_ptr<RuleExtractor> rule_extractor,
+ shared_ptr<Vocabulary> vocabulary,
+ shared_ptr<Sampler> sampler,
+ shared_ptr<Scorer> scorer,
+ int min_gap_size,
+ int max_rule_span,
+ int max_nonterminals,
+ int max_chunks,
+ int max_rule_symbols,
+ bool use_fast_intersect) :
+ matchings_finder(finder),
+ intersector(intersector),
+ fast_intersector(fast_intersector),
+ phrase_builder(phrase_builder),
+ rule_extractor(rule_extractor),
+ vocabulary(vocabulary),
+ sampler(sampler),
+ scorer(scorer),
+ min_gap_size(min_gap_size),
+ max_rule_span(max_rule_span),
+ max_nonterminals(max_nonterminals),
+ max_chunks(max_chunks),
+ max_rule_symbols(max_rule_symbols),
+ use_fast_intersect(use_fast_intersect) {}
+
+HieroCachingRuleFactory::HieroCachingRuleFactory() {}
+
+HieroCachingRuleFactory::~HieroCachingRuleFactory() {}
+
+Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
+ intersector->sort_time = 0;
+ Clock::time_point start_time = Clock::now();
+ double total_extract_time = 0;
+ double total_intersect_time = 0;
+ double total_lookup_time = 0;
+ // Clear cache for every new sentence.
+ trie.Reset();
+ shared_ptr<TrieNode> root = trie.GetRoot();
+
+ int first_x = vocabulary->GetNonterminalIndex(1);
+ shared_ptr<TrieNode> x_root(new TrieNode(root));
+ root->AddChild(first_x, x_root);
+
+ queue<State> states;
+ for (size_t i = 0; i < word_ids.size(); ++i) {
+ states.push(State(i, i, vector<int>(), vector<int>(1, i), root, false));
+ }
+ for (size_t i = min_gap_size; i < word_ids.size(); ++i) {
+ states.push(State(i - min_gap_size, i, vector<int>(1, first_x),
+ vector<int>(1, i), x_root, true));
+ }
+
+ vector<Rule> rules;
+ while (!states.empty()) {
+ State state = states.front();
+ states.pop();
+
+ shared_ptr<TrieNode> node = state.node;
+ vector<int> phrase = state.phrase;
+ int word_id = word_ids[state.end];
+ phrase.push_back(word_id);
+ Phrase next_phrase = phrase_builder->Build(phrase);
+ shared_ptr<TrieNode> next_node;
+
+ if (CannotHaveMatchings(node, word_id)) {
+ if (!node->HasChild(word_id)) {
+ node->AddChild(word_id, shared_ptr<TrieNode>());
+ }
+ continue;
+ }
+
+ if (RequiresLookup(node, word_id)) {
+ shared_ptr<TrieNode> next_suffix_link = node->suffix_link == NULL ?
+ trie.GetRoot() : node->suffix_link->GetChild(word_id);
+ if (state.starts_with_x) {
+ // If the phrase starts with a non terminal, we simply use the matchings
+ // from the suffix link.
+ next_node = make_shared<TrieNode>(
+ next_suffix_link, next_phrase, next_suffix_link->matchings);
+ } else {
+ PhraseLocation phrase_location;
+ if (next_phrase.Arity() > 0) {
+ Clock::time_point intersect_start = Clock::now();
+ if (use_fast_intersect) {
+ phrase_location = fast_intersector->Intersect(
+ node->matchings, next_suffix_link->matchings, next_phrase);
+ } else {
+ phrase_location = intersector->Intersect(
+ node->phrase,
+ node->matchings,
+ next_suffix_link->phrase,
+ next_suffix_link->matchings,
+ next_phrase);
+ }
+ Clock::time_point intersect_stop = Clock::now();
+ total_intersect_time += GetDuration(intersect_start, intersect_stop);
+ } else {
+ Clock::time_point lookup_start = Clock::now();
+ phrase_location = matchings_finder->Find(
+ node->matchings,
+ vocabulary->GetTerminalValue(word_id),
+ state.phrase.size());
+ Clock::time_point lookup_stop = Clock::now();
+ total_lookup_time += GetDuration(lookup_start, lookup_stop);
+ }
+
+ if (phrase_location.IsEmpty()) {
+ continue;
+ }
+ next_node = make_shared<TrieNode>(
+ next_suffix_link, next_phrase, phrase_location);
+ }
+ node->AddChild(word_id, next_node);
+
+ // Automatically adds a trailing non terminal if allowed. Simply copy the
+ // matchings from the prefix node.
+ AddTrailingNonterminal(phrase, next_phrase, next_node,
+ state.starts_with_x);
+
+ Clock::time_point extract_start = Clock::now();
+ if (!state.starts_with_x) {
+ PhraseLocation sample = sampler->Sample(next_node->matchings);
+ vector<Rule> new_rules =
+ rule_extractor->ExtractRules(next_phrase, sample);
+ rules.insert(rules.end(), new_rules.begin(), new_rules.end());
+ }
+ Clock::time_point extract_stop = Clock::now();
+ total_extract_time += GetDuration(extract_start, extract_stop);
+ } else {
+ next_node = node->GetChild(word_id);
+ }
+
+ vector<State> new_states = ExtendState(word_ids, state, phrase, next_phrase,
+ next_node);
+ for (State new_state: new_states) {
+ states.push(new_state);
+ }
+ }
+
+ Clock::time_point stop_time = Clock::now();
+ cerr << "Total time for rule lookup, extraction, and scoring = "
+ << GetDuration(start_time, stop_time) << " seconds" << endl;
+ cerr << "Extract time = " << total_extract_time << " seconds" << endl;
+ cerr << "Intersect time = " << total_intersect_time << " seconds" << endl;
+ cerr << "Lookup time = " << total_lookup_time << " seconds" << endl;
+ return Grammar(rules, scorer->GetFeatureNames());
+}
+
+bool HieroCachingRuleFactory::CannotHaveMatchings(
+ shared_ptr<TrieNode> node, int word_id) {
+ if (node->HasChild(word_id) && node->GetChild(word_id) == NULL) {
+ return true;
+ }
+
+ shared_ptr<TrieNode> suffix_link = node->suffix_link;
+ return suffix_link != NULL && suffix_link->GetChild(word_id) == NULL;
+}
+
+bool HieroCachingRuleFactory::RequiresLookup(
+ shared_ptr<TrieNode> node, int word_id) {
+ return !node->HasChild(word_id);
+}
+
+void HieroCachingRuleFactory::AddTrailingNonterminal(
+ vector<int> symbols,
+ const Phrase& prefix,
+ const shared_ptr<TrieNode>& prefix_node,
+ bool starts_with_x) {
+ if (prefix.Arity() >= max_nonterminals) {
+ return;
+ }
+
+ int var_id = vocabulary->GetNonterminalIndex(prefix.Arity() + 1);
+ symbols.push_back(var_id);
+ Phrase var_phrase = phrase_builder->Build(symbols);
+
+ int suffix_var_id = vocabulary->GetNonterminalIndex(
+ prefix.Arity() + (starts_with_x == 0));
+ shared_ptr<TrieNode> var_suffix_link =
+ prefix_node->suffix_link->GetChild(suffix_var_id);
+
+ prefix_node->AddChild(var_id, make_shared<TrieNode>(
+ var_suffix_link, var_phrase, prefix_node->matchings));
+}
+
+vector<State> HieroCachingRuleFactory::ExtendState(
+ const vector<int>& word_ids,
+ const State& state,
+ vector<int> symbols,
+ const Phrase& phrase,
+ const shared_ptr<TrieNode>& node) {
+ int span = state.end - state.start;
+ vector<State> new_states;
+ if (symbols.size() >= max_rule_symbols || state.end + 1 >= word_ids.size() ||
+ span >= max_rule_span) {
+ return new_states;
+ }
+
+ new_states.push_back(State(state.start, state.end + 1, symbols,
+ state.subpatterns_start, node, state.starts_with_x));
+
+ int num_subpatterns = phrase.Arity() + (state.starts_with_x == 0);
+ if (symbols.size() + 1 >= max_rule_symbols ||
+ phrase.Arity() >= max_nonterminals ||
+ num_subpatterns >= max_chunks) {
+ return new_states;
+ }
+
+ int var_id = vocabulary->GetNonterminalIndex(phrase.Arity() + 1);
+ symbols.push_back(var_id);
+ vector<int> subpatterns_start = state.subpatterns_start;
+ size_t i = state.end + 1 + min_gap_size;
+ while (i < word_ids.size() && i - state.start <= max_rule_span) {
+ subpatterns_start.push_back(i);
+ new_states.push_back(State(state.start, i, symbols, subpatterns_start,
+ node->GetChild(var_id), state.starts_with_x));
+ subpatterns_start.pop_back();
+ ++i;
+ }
+
+ return new_states;
+}