diff options
-rw-r--r-- | extractor/grammar_extractor.cc | 6 | ||||
-rw-r--r-- | extractor/grammar_extractor.h | 3 | ||||
-rw-r--r-- | extractor/rule_factory.cc | 5 | ||||
-rw-r--r-- | extractor/rule_factory.h | 3 | ||||
-rw-r--r-- | extractor/run_extractor.cc | 13 | ||||
-rw-r--r-- | extractor/sample_alignment.txt | 3 | ||||
-rw-r--r-- | extractor/sample_bitext.txt | 3 | ||||
-rw-r--r-- | extractor/sampler.cc | 35 | ||||
-rw-r--r-- | extractor/sampler.h | 5 |
9 files changed, 62 insertions, 14 deletions
diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc index 8050ce7b..1fbdee5b 100644 --- a/extractor/grammar_extractor.cc +++ b/extractor/grammar_extractor.cc @@ -3,11 +3,13 @@ #include <iterator> #include <sstream> #include <vector> +#include <unordered_set> #include "grammar.h" #include "rule.h" #include "rule_factory.h" #include "vocabulary.h" +#include "data_array.h" using namespace std; @@ -32,10 +34,10 @@ GrammarExtractor::GrammarExtractor( vocabulary(vocabulary), rule_factory(rule_factory) {} -Grammar GrammarExtractor::GetGrammar(const string& sentence) { +Grammar GrammarExtractor::GetGrammar(const string& sentence, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) { vector<string> words = TokenizeSentence(sentence); vector<int> word_ids = AnnotateWords(words); - return rule_factory->GetGrammar(word_ids); + return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array); } vector<string> GrammarExtractor::TokenizeSentence(const string& sentence) { diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h index b36ceeb9..6c0aafbf 100644 --- a/extractor/grammar_extractor.h +++ b/extractor/grammar_extractor.h @@ -4,6 +4,7 @@ #include <memory> #include <string> #include <vector> +#include <unordered_set> using namespace std; @@ -44,7 +45,7 @@ class GrammarExtractor { // Converts the sentence to a vector of word ids and uses the RuleFactory to // extract the SCFG rules which may be used to decode the sentence. - Grammar GetGrammar(const string& sentence); + Grammar GetGrammar(const string& sentence, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array); private: // Splits the sentence in a vector of words. diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index 8c30fb9e..e52019ae 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -17,6 +17,7 @@ #include "suffix_array.h" #include "time_util.h" #include "vocabulary.h" +#include "data_array.h" using namespace std; using namespace chrono; @@ -100,7 +101,7 @@ HieroCachingRuleFactory::HieroCachingRuleFactory() {} HieroCachingRuleFactory::~HieroCachingRuleFactory() {} -Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { +Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) { Clock::time_point start_time = Clock::now(); double total_extract_time = 0; double total_intersect_time = 0; @@ -192,7 +193,7 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { Clock::time_point extract_start = Clock::now(); if (!state.starts_with_x) { // Extract rules for the sampled set of occurrences. - PhraseLocation sample = sampler->Sample(next_node->matchings); + PhraseLocation sample = sampler->Sample(next_node->matchings, blacklisted_sentence_ids, source_data_array); vector<Rule> new_rules = rule_extractor->ExtractRules(next_phrase, sample); rules.insert(rules.end(), new_rules.begin(), new_rules.end()); diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h index 52e8712a..c7332720 100644 --- a/extractor/rule_factory.h +++ b/extractor/rule_factory.h @@ -3,6 +3,7 @@ #include <memory> #include <vector> +#include <unordered_set> #include "matchings_trie.h" @@ -71,7 +72,7 @@ class HieroCachingRuleFactory { // Constructs SCFG rules for a given sentence. // (See class description for more details.) - virtual Grammar GetGrammar(const vector<int>& word_ids); + virtual Grammar GetGrammar(const vector<int>& word_ids, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array); protected: HieroCachingRuleFactory(); diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index 8a9ca89d..6eb55073 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -75,7 +75,9 @@ int main(int argc, char** argv) { ("max_samples", po::value<int>()->default_value(300), "Maximum number of samples") ("tight_phrases", po::value<bool>()->default_value(true), - "False if phrases may be loose (better, but slower)"); + "False if phrases may be loose (better, but slower)") + ("leave_one_out", po::value<bool>()->zero_tokens(), + "do leave-one-out estimation of grammars (e.g. for extracting grammars for the training set"); po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); @@ -96,6 +98,11 @@ int main(int argc, char** argv) { return 1; } + bool leave_one_out = false; + if (vm.count("leave_one_out")) { + leave_one_out = true; + } + int num_threads = vm["threads"].as<int>(); cerr << "Grammar extraction will use " << num_threads << " threads." << endl; @@ -223,7 +230,9 @@ int main(int argc, char** argv) { } suffixes[i] = suffix; - Grammar grammar = extractor.GetGrammar(sentences[i]); + unordered_set<int> blacklisted_sentence_ids; + if (leave_one_out) blacklisted_sentence_ids.insert(i); + Grammar grammar = extractor.GetGrammar(sentences[i], blacklisted_sentence_ids, source_data_array); ofstream output(GetGrammarFilePath(grammar_path, i).c_str()); output << grammar; } diff --git a/extractor/sample_alignment.txt b/extractor/sample_alignment.txt index 80b446a4..f0292b01 100644 --- a/extractor/sample_alignment.txt +++ b/extractor/sample_alignment.txt @@ -1,2 +1,5 @@ 0-0 1-1 2-2 1-0 2-1 +0-0 +0-0 1-1 +0-0 1-1 diff --git a/extractor/sample_bitext.txt b/extractor/sample_bitext.txt index 93d6b39d..2b7c8e40 100644 --- a/extractor/sample_bitext.txt +++ b/extractor/sample_bitext.txt @@ -1,2 +1,5 @@ +asdf ||| dontseeme +qqq asdf ||| zzz fdsa +asdf qqq ||| fdsa zzz ana are mere . ||| anna has apples . ana bea mult lapte . ||| anna drinks a lot of milk . diff --git a/extractor/sampler.cc b/extractor/sampler.cc index d81956b5..2f7738db 100644 --- a/extractor/sampler.cc +++ b/extractor/sampler.cc @@ -12,18 +12,43 @@ Sampler::Sampler() {} Sampler::~Sampler() {} -PhraseLocation Sampler::Sample(const PhraseLocation& location) const { +PhraseLocation Sampler::Sample(const PhraseLocation& location, unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const { vector<int> sample; int num_subpatterns; if (location.matchings == NULL) { // Sample suffix array range. num_subpatterns = 1; int low = location.sa_low, high = location.sa_high; - double step = max(1.0, (double) (high - low) / max_samples); - for (double i = low; i < high && sample.size() < max_samples; i += step) { - sample.push_back(suffix_array->GetSuffix(Round(i))); + double step = Round(max(1.0, (double) (high - low) / max_samples)); + int i = location.sa_low; + bool found = false; + while (sample.size() < max_samples && i <= location.sa_high) { + int x = suffix_array->GetSuffix(i); + int id = source_data_array->GetSentenceId(x); + if (find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) != blacklisted_sentence_ids.end()) { + int backoff_step = 1; + while (true) { + int j = i - backoff_step; + x = suffix_array->GetSuffix(j); + id = source_data_array->GetSentenceId(x); + if ((j >= location.sa_low) && (find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) + && (find(sample.begin(), sample.end(), x) == sample.end())) { found = true; break; } + int k = i + backoff_step; + x = suffix_array->GetSuffix(k); + id = source_data_array->GetSentenceId(x); + if ((k <= location.sa_high) && (find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) + && (find(sample.begin(), sample.end(), x) == sample.end())) { found = true; break; } + if (j <= location.sa_low && k >= location.sa_high) break; + backoff_step++; + } + } else { + found = true; + } + if (found && (find(sample.begin(), sample.end(), x) == sample.end())) sample.push_back(x); + i += step; + found = false; } - } else { + } else { // when do we get here? // Sample vector of occurrences. num_subpatterns = location.num_subpatterns; int num_matchings = location.matchings->size() / num_subpatterns; diff --git a/extractor/sampler.h b/extractor/sampler.h index be4aa1bb..30e747fd 100644 --- a/extractor/sampler.h +++ b/extractor/sampler.h @@ -2,6 +2,9 @@ #define _SAMPLER_H_ #include <memory> +#include <unordered_set> + +#include "data_array.h" using namespace std; @@ -20,7 +23,7 @@ class Sampler { virtual ~Sampler(); // Samples uniformly at most max_samples phrase occurrences. - virtual PhraseLocation Sample(const PhraseLocation& location) const; + virtual PhraseLocation Sample(const PhraseLocation& location, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const; protected: Sampler(); |