summaryrefslogtreecommitdiff
path: root/extractor/rule_factory.cc
diff options
context:
space:
mode:
Diffstat (limited to 'extractor/rule_factory.cc')
-rw-r--r--extractor/rule_factory.cc56
1 files changed, 40 insertions, 16 deletions
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc
index 7a8356b8..c22f9b48 100644
--- a/extractor/rule_factory.cc
+++ b/extractor/rule_factory.cc
@@ -5,8 +5,15 @@
#include <queue>
#include <vector>
+#include "grammar.h"
+#include "intersector.h"
+#include "matchings_finder.h"
#include "matching_comparator.h"
#include "phrase.h"
+#include "rule.h"
+#include "rule_extractor.h"
+#include "sampler.h"
+#include "scorer.h"
#include "suffix_array.h"
#include "vocabulary.h"
@@ -30,28 +37,39 @@ struct State {
HieroCachingRuleFactory::HieroCachingRuleFactory(
shared_ptr<SuffixArray> source_suffix_array,
shared_ptr<DataArray> target_data_array,
- const Alignment& alignment,
+ shared_ptr<Alignment> alignment,
const shared_ptr<Vocabulary>& vocabulary,
- const Precomputation& precomputation,
+ shared_ptr<Precomputation> precomputation,
+ shared_ptr<Scorer> scorer,
int min_gap_size,
int max_rule_span,
int max_nonterminals,
int max_rule_symbols,
- bool use_baeza_yates) :
- matchings_finder(source_suffix_array),
- intersector(vocabulary, precomputation, source_suffix_array,
- make_shared<MatchingComparator>(min_gap_size, max_rule_span),
- use_baeza_yates),
- phrase_builder(vocabulary),
- rule_extractor(source_suffix_array, target_data_array, alignment),
+ int max_samples,
+ bool use_baeza_yates,
+ bool require_tight_phrases) :
vocabulary(vocabulary),
+ scorer(scorer),
min_gap_size(min_gap_size),
max_rule_span(max_rule_span),
max_nonterminals(max_nonterminals),
max_chunks(max_nonterminals + 1),
- max_rule_symbols(max_rule_symbols) {}
+ max_rule_symbols(max_rule_symbols) {
+ matchings_finder = make_shared<MatchingsFinder>(source_suffix_array);
+ shared_ptr<MatchingComparator> comparator =
+ make_shared<MatchingComparator>(min_gap_size, max_rule_span);
+ intersector = make_shared<Intersector>(vocabulary, precomputation,
+ source_suffix_array, comparator, use_baeza_yates);
+ phrase_builder = make_shared<PhraseBuilder>(vocabulary);
+ rule_extractor = make_shared<RuleExtractor>(source_suffix_array->GetData(),
+ target_data_array, alignment, phrase_builder, scorer, vocabulary,
+ max_rule_span, min_gap_size, max_nonterminals, max_rule_symbols, true,
+ false, require_tight_phrases);
+ sampler = make_shared<Sampler>(source_suffix_array, max_samples);
+}
+
-void HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
+Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
// Clear cache for every new sentence.
trie.Reset();
shared_ptr<TrieNode> root = trie.GetRoot();
@@ -69,6 +87,7 @@ void HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
vector<int>(1, i), x_root, true));
}
+ vector<Rule> rules;
while (!states.empty()) {
State state = states.front();
states.pop();
@@ -77,7 +96,7 @@ void HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
vector<int> phrase = state.phrase;
int word_id = word_ids[state.end];
phrase.push_back(word_id);
- Phrase next_phrase = phrase_builder.Build(phrase);
+ Phrase next_phrase = phrase_builder->Build(phrase);
shared_ptr<TrieNode> next_node;
if (CannotHaveMatchings(node, word_id)) {
@@ -98,14 +117,14 @@ void HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
} else {
PhraseLocation phrase_location;
if (next_phrase.Arity() > 0) {
- phrase_location = intersector.Intersect(
+ phrase_location = intersector->Intersect(
node->phrase,
node->matchings,
next_suffix_link->phrase,
next_suffix_link->matchings,
next_phrase);
} else {
- phrase_location = matchings_finder.Find(
+ phrase_location = matchings_finder->Find(
node->matchings,
vocabulary->GetTerminalValue(word_id),
state.phrase.size());
@@ -125,7 +144,10 @@ void HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
state.starts_with_x);
if (!state.starts_with_x) {
- rule_extractor.ExtractRules();
+ PhraseLocation sample = sampler->Sample(next_node->matchings);
+ vector<Rule> new_rules =
+ rule_extractor->ExtractRules(next_phrase, sample);
+ rules.insert(rules.end(), new_rules.begin(), new_rules.end());
}
} else {
next_node = node->GetChild(word_id);
@@ -137,6 +159,8 @@ void HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
states.push(new_state);
}
}
+
+ return Grammar(rules, scorer->GetFeatureNames());
}
bool HieroCachingRuleFactory::CannotHaveMatchings(
@@ -165,7 +189,7 @@ void HieroCachingRuleFactory::AddTrailingNonterminal(
int var_id = vocabulary->GetNonterminalIndex(prefix.Arity() + 1);
symbols.push_back(var_id);
- Phrase var_phrase = phrase_builder.Build(symbols);
+ Phrase var_phrase = phrase_builder->Build(symbols);
int suffix_var_id = vocabulary->GetNonterminalIndex(
prefix.Arity() + starts_with_x == 0);