summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--extractor/grammar_extractor.cc6
-rw-r--r--extractor/grammar_extractor.h3
-rw-r--r--extractor/rule_factory.cc5
-rw-r--r--extractor/rule_factory.h3
-rw-r--r--extractor/run_extractor.cc13
-rw-r--r--extractor/sample_alignment.txt3
-rw-r--r--extractor/sample_bitext.txt3
-rw-r--r--extractor/sampler.cc35
-rw-r--r--extractor/sampler.h5
9 files changed, 62 insertions, 14 deletions
diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc
index 8050ce7b..1fbdee5b 100644
--- a/extractor/grammar_extractor.cc
+++ b/extractor/grammar_extractor.cc
@@ -3,11 +3,13 @@
#include <iterator>
#include <sstream>
#include <vector>
+#include <unordered_set>
#include "grammar.h"
#include "rule.h"
#include "rule_factory.h"
#include "vocabulary.h"
+#include "data_array.h"
using namespace std;
@@ -32,10 +34,10 @@ GrammarExtractor::GrammarExtractor(
vocabulary(vocabulary),
rule_factory(rule_factory) {}
-Grammar GrammarExtractor::GetGrammar(const string& sentence) {
+Grammar GrammarExtractor::GetGrammar(const string& sentence, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) {
vector<string> words = TokenizeSentence(sentence);
vector<int> word_ids = AnnotateWords(words);
- return rule_factory->GetGrammar(word_ids);
+ return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
}
vector<string> GrammarExtractor::TokenizeSentence(const string& sentence) {
diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h
index b36ceeb9..6c0aafbf 100644
--- a/extractor/grammar_extractor.h
+++ b/extractor/grammar_extractor.h
@@ -4,6 +4,7 @@
#include <memory>
#include <string>
#include <vector>
+#include <unordered_set>
using namespace std;
@@ -44,7 +45,7 @@ class GrammarExtractor {
// Converts the sentence to a vector of word ids and uses the RuleFactory to
// extract the SCFG rules which may be used to decode the sentence.
- Grammar GetGrammar(const string& sentence);
+ Grammar GetGrammar(const string& sentence, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array);
private:
// Splits the sentence in a vector of words.
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc
index 8c30fb9e..e52019ae 100644
--- a/extractor/rule_factory.cc
+++ b/extractor/rule_factory.cc
@@ -17,6 +17,7 @@
#include "suffix_array.h"
#include "time_util.h"
#include "vocabulary.h"
+#include "data_array.h"
using namespace std;
using namespace chrono;
@@ -100,7 +101,7 @@ HieroCachingRuleFactory::HieroCachingRuleFactory() {}
HieroCachingRuleFactory::~HieroCachingRuleFactory() {}
-Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
+Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) {
Clock::time_point start_time = Clock::now();
double total_extract_time = 0;
double total_intersect_time = 0;
@@ -192,7 +193,7 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
Clock::time_point extract_start = Clock::now();
if (!state.starts_with_x) {
// Extract rules for the sampled set of occurrences.
- PhraseLocation sample = sampler->Sample(next_node->matchings);
+ PhraseLocation sample = sampler->Sample(next_node->matchings, blacklisted_sentence_ids, source_data_array);
vector<Rule> new_rules =
rule_extractor->ExtractRules(next_phrase, sample);
rules.insert(rules.end(), new_rules.begin(), new_rules.end());
diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h
index 52e8712a..c7332720 100644
--- a/extractor/rule_factory.h
+++ b/extractor/rule_factory.h
@@ -3,6 +3,7 @@
#include <memory>
#include <vector>
+#include <unordered_set>
#include "matchings_trie.h"
@@ -71,7 +72,7 @@ class HieroCachingRuleFactory {
// Constructs SCFG rules for a given sentence.
// (See class description for more details.)
- virtual Grammar GetGrammar(const vector<int>& word_ids);
+ virtual Grammar GetGrammar(const vector<int>& word_ids, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array);
protected:
HieroCachingRuleFactory();
diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc
index 8a9ca89d..6eb55073 100644
--- a/extractor/run_extractor.cc
+++ b/extractor/run_extractor.cc
@@ -75,7 +75,9 @@ int main(int argc, char** argv) {
("max_samples", po::value<int>()->default_value(300),
"Maximum number of samples")
("tight_phrases", po::value<bool>()->default_value(true),
- "False if phrases may be loose (better, but slower)");
+ "False if phrases may be loose (better, but slower)")
+ ("leave_one_out", po::value<bool>()->zero_tokens(),
+ "do leave-one-out estimation of grammars (e.g. for extracting grammars for the training set");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
@@ -96,6 +98,11 @@ int main(int argc, char** argv) {
return 1;
}
+ bool leave_one_out = false;
+ if (vm.count("leave_one_out")) {
+ leave_one_out = true;
+ }
+
int num_threads = vm["threads"].as<int>();
cerr << "Grammar extraction will use " << num_threads << " threads." << endl;
@@ -223,7 +230,9 @@ int main(int argc, char** argv) {
}
suffixes[i] = suffix;
- Grammar grammar = extractor.GetGrammar(sentences[i]);
+ unordered_set<int> blacklisted_sentence_ids;
+ if (leave_one_out) blacklisted_sentence_ids.insert(i);
+ Grammar grammar = extractor.GetGrammar(sentences[i], blacklisted_sentence_ids, source_data_array);
ofstream output(GetGrammarFilePath(grammar_path, i).c_str());
output << grammar;
}
diff --git a/extractor/sample_alignment.txt b/extractor/sample_alignment.txt
index 80b446a4..f0292b01 100644
--- a/extractor/sample_alignment.txt
+++ b/extractor/sample_alignment.txt
@@ -1,2 +1,5 @@
0-0 1-1 2-2
1-0 2-1
+0-0
+0-0 1-1
+0-0 1-1
diff --git a/extractor/sample_bitext.txt b/extractor/sample_bitext.txt
index 93d6b39d..2b7c8e40 100644
--- a/extractor/sample_bitext.txt
+++ b/extractor/sample_bitext.txt
@@ -1,2 +1,5 @@
+asdf ||| dontseeme
+qqq asdf ||| zzz fdsa
+asdf qqq ||| fdsa zzz
ana are mere . ||| anna has apples .
ana bea mult lapte . ||| anna drinks a lot of milk .
diff --git a/extractor/sampler.cc b/extractor/sampler.cc
index d81956b5..2f7738db 100644
--- a/extractor/sampler.cc
+++ b/extractor/sampler.cc
@@ -12,18 +12,43 @@ Sampler::Sampler() {}
Sampler::~Sampler() {}
-PhraseLocation Sampler::Sample(const PhraseLocation& location) const {
+PhraseLocation Sampler::Sample(const PhraseLocation& location, unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const {
vector<int> sample;
int num_subpatterns;
if (location.matchings == NULL) {
// Sample suffix array range.
num_subpatterns = 1;
int low = location.sa_low, high = location.sa_high;
- double step = max(1.0, (double) (high - low) / max_samples);
- for (double i = low; i < high && sample.size() < max_samples; i += step) {
- sample.push_back(suffix_array->GetSuffix(Round(i)));
+ double step = Round(max(1.0, (double) (high - low) / max_samples));
+ int i = location.sa_low;
+ bool found = false;
+ while (sample.size() < max_samples && i <= location.sa_high) {
+ int x = suffix_array->GetSuffix(i);
+ int id = source_data_array->GetSentenceId(x);
+ if (find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) != blacklisted_sentence_ids.end()) {
+ int backoff_step = 1;
+ while (true) {
+ int j = i - backoff_step;
+ x = suffix_array->GetSuffix(j);
+ id = source_data_array->GetSentenceId(x);
+ if ((j >= location.sa_low) && (find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end())
+ && (find(sample.begin(), sample.end(), x) == sample.end())) { found = true; break; }
+ int k = i + backoff_step;
+ x = suffix_array->GetSuffix(k);
+ id = source_data_array->GetSentenceId(x);
+ if ((k <= location.sa_high) && (find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end())
+ && (find(sample.begin(), sample.end(), x) == sample.end())) { found = true; break; }
+ if (j <= location.sa_low && k >= location.sa_high) break;
+ backoff_step++;
+ }
+ } else {
+ found = true;
+ }
+ if (found && (find(sample.begin(), sample.end(), x) == sample.end())) sample.push_back(x);
+ i += step;
+ found = false;
}
- } else {
+ } else { // when do we get here?
// Sample vector of occurrences.
num_subpatterns = location.num_subpatterns;
int num_matchings = location.matchings->size() / num_subpatterns;
diff --git a/extractor/sampler.h b/extractor/sampler.h
index be4aa1bb..30e747fd 100644
--- a/extractor/sampler.h
+++ b/extractor/sampler.h
@@ -2,6 +2,9 @@
#define _SAMPLER_H_
#include <memory>
+#include <unordered_set>
+
+#include "data_array.h"
using namespace std;
@@ -20,7 +23,7 @@ class Sampler {
virtual ~Sampler();
// Samples uniformly at most max_samples phrase occurrences.
- virtual PhraseLocation Sample(const PhraseLocation& location) const;
+ virtual PhraseLocation Sample(const PhraseLocation& location, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const;
protected:
Sampler();