summaryrefslogtreecommitdiff
path: root/extractor
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2013-11-13 11:22:24 -0800
committerChris Dyer <redpony@gmail.com>2013-11-13 11:22:24 -0800
commit9be8d89b5a4065a81b26d8af1f3443d152e7922a (patch)
tree237090ff519a0419c3ba379ec3a6884f05caa6c2 /extractor
parent8a24bb77bc2e9fd17a6f6529a2942cde96a6af49 (diff)
parent4a9449a564e626fe004200b730bfaa44d6152e0f (diff)
Merge pull request #27 from pks/master
Tidying (soft) syntax features; loo for C++ extractor; updates for dtrain
Diffstat (limited to 'extractor')
-rw-r--r--extractor/grammar_extractor.cc6
-rw-r--r--extractor/grammar_extractor.h3
-rw-r--r--extractor/grammar_extractor_test.cc7
-rw-r--r--extractor/mocks/mock_rule_factory.h2
-rw-r--r--extractor/rule_factory.cc5
-rw-r--r--extractor/rule_factory.h3
-rw-r--r--extractor/rule_factory_test.cc8
-rw-r--r--extractor/run_extractor.cc13
-rw-r--r--extractor/sample_source.txt2
-rw-r--r--extractor/sampler.cc35
-rw-r--r--extractor/sampler.h5
-rw-r--r--extractor/sampler_test.cc24
-rw-r--r--extractor/sampler_test_blacklist.cc102
13 files changed, 189 insertions, 26 deletions
diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc
index 8050ce7b..1fbdee5b 100644
--- a/extractor/grammar_extractor.cc
+++ b/extractor/grammar_extractor.cc
@@ -3,11 +3,13 @@
#include <iterator>
#include <sstream>
#include <vector>
+#include <unordered_set>
#include "grammar.h"
#include "rule.h"
#include "rule_factory.h"
#include "vocabulary.h"
+#include "data_array.h"
using namespace std;
@@ -32,10 +34,10 @@ GrammarExtractor::GrammarExtractor(
vocabulary(vocabulary),
rule_factory(rule_factory) {}
-Grammar GrammarExtractor::GetGrammar(const string& sentence) {
+Grammar GrammarExtractor::GetGrammar(const string& sentence, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) {
vector<string> words = TokenizeSentence(sentence);
vector<int> word_ids = AnnotateWords(words);
- return rule_factory->GetGrammar(word_ids);
+ return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
}
vector<string> GrammarExtractor::TokenizeSentence(const string& sentence) {
diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h
index b36ceeb9..6c0aafbf 100644
--- a/extractor/grammar_extractor.h
+++ b/extractor/grammar_extractor.h
@@ -4,6 +4,7 @@
#include <memory>
#include <string>
#include <vector>
+#include <unordered_set>
using namespace std;
@@ -44,7 +45,7 @@ class GrammarExtractor {
// Converts the sentence to a vector of word ids and uses the RuleFactory to
// extract the SCFG rules which may be used to decode the sentence.
- Grammar GetGrammar(const string& sentence);
+ Grammar GetGrammar(const string& sentence, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array);
private:
// Splits the sentence in a vector of words.
diff --git a/extractor/grammar_extractor_test.cc b/extractor/grammar_extractor_test.cc
index 823bb8b4..f32a9599 100644
--- a/extractor/grammar_extractor_test.cc
+++ b/extractor/grammar_extractor_test.cc
@@ -39,12 +39,15 @@ TEST(GrammarExtractorTest, TestAnnotatingWords) {
vector<Rule> rules;
vector<string> feature_names;
Grammar grammar(rules, feature_names);
- EXPECT_CALL(*factory, GetGrammar(word_ids))
+ unordered_set<int> blacklisted_sentence_ids;
+ shared_ptr<DataArray> source_data_array;
+ EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array))
.WillOnce(Return(grammar));
GrammarExtractor extractor(vocabulary, factory);
string sentence = "Anna has many many apples .";
- extractor.GetGrammar(sentence);
+
+ extractor.GetGrammar(sentence, blacklisted_sentence_ids, source_data_array);
}
} // namespace
diff --git a/extractor/mocks/mock_rule_factory.h b/extractor/mocks/mock_rule_factory.h
index 7389b396..86a084b5 100644
--- a/extractor/mocks/mock_rule_factory.h
+++ b/extractor/mocks/mock_rule_factory.h
@@ -7,7 +7,7 @@ namespace extractor {
class MockHieroCachingRuleFactory : public HieroCachingRuleFactory {
public:
- MOCK_METHOD1(GetGrammar, Grammar(const vector<int>& word_ids));
+ MOCK_METHOD3(GetGrammar, Grammar(const vector<int>& word_ids, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array));
};
} // namespace extractor
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc
index 8c30fb9e..e52019ae 100644
--- a/extractor/rule_factory.cc
+++ b/extractor/rule_factory.cc
@@ -17,6 +17,7 @@
#include "suffix_array.h"
#include "time_util.h"
#include "vocabulary.h"
+#include "data_array.h"
using namespace std;
using namespace chrono;
@@ -100,7 +101,7 @@ HieroCachingRuleFactory::HieroCachingRuleFactory() {}
HieroCachingRuleFactory::~HieroCachingRuleFactory() {}
-Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
+Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) {
Clock::time_point start_time = Clock::now();
double total_extract_time = 0;
double total_intersect_time = 0;
@@ -192,7 +193,7 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
Clock::time_point extract_start = Clock::now();
if (!state.starts_with_x) {
// Extract rules for the sampled set of occurrences.
- PhraseLocation sample = sampler->Sample(next_node->matchings);
+ PhraseLocation sample = sampler->Sample(next_node->matchings, blacklisted_sentence_ids, source_data_array);
vector<Rule> new_rules =
rule_extractor->ExtractRules(next_phrase, sample);
rules.insert(rules.end(), new_rules.begin(), new_rules.end());
diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h
index 52e8712a..c7332720 100644
--- a/extractor/rule_factory.h
+++ b/extractor/rule_factory.h
@@ -3,6 +3,7 @@
#include <memory>
#include <vector>
+#include <unordered_set>
#include "matchings_trie.h"
@@ -71,7 +72,7 @@ class HieroCachingRuleFactory {
// Constructs SCFG rules for a given sentence.
// (See class description for more details.)
- virtual Grammar GetGrammar(const vector<int>& word_ids);
+ virtual Grammar GetGrammar(const vector<int>& word_ids, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array);
protected:
HieroCachingRuleFactory();
diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc
index 08af3dcd..f26cc567 100644
--- a/extractor/rule_factory_test.cc
+++ b/extractor/rule_factory_test.cc
@@ -76,7 +76,9 @@ TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) {
.WillRepeatedly(Return(PhraseLocation(0, 1)));
vector<int> word_ids = {2, 3, 4};
- Grammar grammar = factory->GetGrammar(word_ids);
+ unordered_set<int> blacklisted_sentence_ids;
+ shared_ptr<DataArray> source_data_array;
+ Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
EXPECT_EQ(feature_names, grammar.GetFeatureNames());
EXPECT_EQ(7, grammar.GetRules().size());
}
@@ -94,7 +96,9 @@ TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) {
.WillRepeatedly(Return(PhraseLocation(0, 1)));
vector<int> word_ids = {2, 3, 4, 2, 3};
- Grammar grammar = factory->GetGrammar(word_ids);
+ unordered_set<int> blacklisted_sentence_ids;
+ shared_ptr<DataArray> source_data_array;
+ Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
EXPECT_EQ(feature_names, grammar.GetFeatureNames());
EXPECT_EQ(28, grammar.GetRules().size());
}
diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc
index 8a9ca89d..6eb55073 100644
--- a/extractor/run_extractor.cc
+++ b/extractor/run_extractor.cc
@@ -75,7 +75,9 @@ int main(int argc, char** argv) {
("max_samples", po::value<int>()->default_value(300),
"Maximum number of samples")
("tight_phrases", po::value<bool>()->default_value(true),
- "False if phrases may be loose (better, but slower)");
+ "False if phrases may be loose (better, but slower)")
+ ("leave_one_out", po::value<bool>()->zero_tokens(),
+ "do leave-one-out estimation of grammars (e.g. for extracting grammars for the training set");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
@@ -96,6 +98,11 @@ int main(int argc, char** argv) {
return 1;
}
+ bool leave_one_out = false;
+ if (vm.count("leave_one_out")) {
+ leave_one_out = true;
+ }
+
int num_threads = vm["threads"].as<int>();
cerr << "Grammar extraction will use " << num_threads << " threads." << endl;
@@ -223,7 +230,9 @@ int main(int argc, char** argv) {
}
suffixes[i] = suffix;
- Grammar grammar = extractor.GetGrammar(sentences[i]);
+ unordered_set<int> blacklisted_sentence_ids;
+ if (leave_one_out) blacklisted_sentence_ids.insert(i);
+ Grammar grammar = extractor.GetGrammar(sentences[i], blacklisted_sentence_ids, source_data_array);
ofstream output(GetGrammarFilePath(grammar_path, i).c_str());
output << grammar;
}
diff --git a/extractor/sample_source.txt b/extractor/sample_source.txt
new file mode 100644
index 00000000..971baf6d
--- /dev/null
+++ b/extractor/sample_source.txt
@@ -0,0 +1,2 @@
+ana are mere .
+ana bea mult lapte .
diff --git a/extractor/sampler.cc b/extractor/sampler.cc
index d81956b5..d332dd90 100644
--- a/extractor/sampler.cc
+++ b/extractor/sampler.cc
@@ -12,7 +12,7 @@ Sampler::Sampler() {}
Sampler::~Sampler() {}
-PhraseLocation Sampler::Sample(const PhraseLocation& location) const {
+PhraseLocation Sampler::Sample(const PhraseLocation& location, unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const {
vector<int> sample;
int num_subpatterns;
if (location.matchings == NULL) {
@@ -20,8 +20,37 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location) const {
num_subpatterns = 1;
int low = location.sa_low, high = location.sa_high;
double step = max(1.0, (double) (high - low) / max_samples);
- for (double i = low; i < high && sample.size() < max_samples; i += step) {
- sample.push_back(suffix_array->GetSuffix(Round(i)));
+ double i = low, last = i;
+ bool found;
+ while (sample.size() < max_samples && i < high) {
+ int x = suffix_array->GetSuffix(Round(i));
+ int id = source_data_array->GetSentenceId(x);
+ if (find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) != blacklisted_sentence_ids.end()) {
+ found = false;
+ double backoff_step = 1;
+ while (true) {
+ if ((double)backoff_step >= step) break;
+ double j = i - backoff_step;
+ x = suffix_array->GetSuffix(Round(j));
+ id = source_data_array->GetSentenceId(x);
+ if (x >= 0 && j > last && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) {
+ found = true; last = i; break;
+ }
+ double k = i + backoff_step;
+ x = suffix_array->GetSuffix(Round(k));
+ id = source_data_array->GetSentenceId(x);
+ if (k < min(i+step, (double)high) && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) {
+ found = true; last = k; break;
+ }
+ if (j <= last && k >= high) break;
+ backoff_step++;
+ }
+ } else {
+ found = true;
+ last = i;
+ }
+ if (found) sample.push_back(x);
+ i += step;
}
} else {
// Sample vector of occurrences.
diff --git a/extractor/sampler.h b/extractor/sampler.h
index be4aa1bb..30e747fd 100644
--- a/extractor/sampler.h
+++ b/extractor/sampler.h
@@ -2,6 +2,9 @@
#define _SAMPLER_H_
#include <memory>
+#include <unordered_set>
+
+#include "data_array.h"
using namespace std;
@@ -20,7 +23,7 @@ class Sampler {
virtual ~Sampler();
// Samples uniformly at most max_samples phrase occurrences.
- virtual PhraseLocation Sample(const PhraseLocation& location) const;
+ virtual PhraseLocation Sample(const PhraseLocation& location, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const;
protected:
Sampler();
diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc
index e9abebfa..965567ba 100644
--- a/extractor/sampler_test.cc
+++ b/extractor/sampler_test.cc
@@ -3,6 +3,7 @@
#include <memory>
#include "mocks/mock_suffix_array.h"
+#include "mocks/mock_data_array.h"
#include "phrase_location.h"
#include "sampler.h"
@@ -15,6 +16,8 @@ namespace {
class SamplerTest : public Test {
protected:
virtual void SetUp() {
+ source_data_array = make_shared<MockDataArray>();
+ EXPECT_CALL(*source_data_array, GetSentenceId(_)).WillRepeatedly(Return(9999));
suffix_array = make_shared<MockSuffixArray>();
for (int i = 0; i < 10; ++i) {
EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i));
@@ -23,51 +26,54 @@ class SamplerTest : public Test {
shared_ptr<MockSuffixArray> suffix_array;
shared_ptr<Sampler> sampler;
+ shared_ptr<MockDataArray> source_data_array;
};
TEST_F(SamplerTest, TestSuffixArrayRange) {
PhraseLocation location(0, 10);
+ unordered_set<int> blacklist;
sampler = make_shared<Sampler>(suffix_array, 1);
vector<int> expected_locations = {0};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
sampler = make_shared<Sampler>(suffix_array, 2);
expected_locations = {0, 5};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
sampler = make_shared<Sampler>(suffix_array, 3);
expected_locations = {0, 3, 7};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
sampler = make_shared<Sampler>(suffix_array, 4);
expected_locations = {0, 3, 5, 8};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
sampler = make_shared<Sampler>(suffix_array, 100);
expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
}
TEST_F(SamplerTest, TestSubstringsSample) {
vector<int> locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ unordered_set<int> blacklist;
PhraseLocation location(locations, 2);
sampler = make_shared<Sampler>(suffix_array, 1);
vector<int> expected_locations = {0, 1};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
sampler = make_shared<Sampler>(suffix_array, 2);
expected_locations = {0, 1, 6, 7};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
sampler = make_shared<Sampler>(suffix_array, 3);
expected_locations = {0, 1, 4, 5, 6, 7};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
sampler = make_shared<Sampler>(suffix_array, 7);
expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
}
} // namespace
diff --git a/extractor/sampler_test_blacklist.cc b/extractor/sampler_test_blacklist.cc
new file mode 100644
index 00000000..3305b990
--- /dev/null
+++ b/extractor/sampler_test_blacklist.cc
@@ -0,0 +1,102 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "mocks/mock_suffix_array.h"
+#include "mocks/mock_data_array.h"
+#include "phrase_location.h"
+#include "sampler.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+class SamplerTestBlacklist : public Test {
+ protected:
+ virtual void SetUp() {
+ source_data_array = make_shared<MockDataArray>();
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_CALL(*source_data_array, GetSentenceId(i)).WillRepeatedly(Return(i));
+ }
+ for (int i = -10; i < 0; ++i) {
+ EXPECT_CALL(*source_data_array, GetSentenceId(i)).WillRepeatedly(Return(0));
+ }
+ suffix_array = make_shared<MockSuffixArray>();
+ for (int i = -10; i < 10; ++i) {
+ EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i));
+ }
+ }
+
+ shared_ptr<MockSuffixArray> suffix_array;
+ shared_ptr<Sampler> sampler;
+ shared_ptr<MockDataArray> source_data_array;
+};
+
+TEST_F(SamplerTestBlacklist, TestSuffixArrayRange) {
+ PhraseLocation location(0, 10);
+ unordered_set<int> blacklist;
+ vector<int> expected_locations;
+
+ blacklist.insert(0);
+ sampler = make_shared<Sampler>(suffix_array, 1);
+ expected_locations = {1};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ blacklist.clear();
+
+ for (int i = 0; i < 9; i++) {
+ blacklist.insert(i);
+ }
+ sampler = make_shared<Sampler>(suffix_array, 1);
+ expected_locations = {9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ blacklist.clear();
+
+ blacklist.insert(0);
+ blacklist.insert(5);
+ sampler = make_shared<Sampler>(suffix_array, 2);
+ expected_locations = {1, 4};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ blacklist.clear();
+
+ blacklist.insert(0);
+ blacklist.insert(1);
+ blacklist.insert(2);
+ blacklist.insert(3);
+ sampler = make_shared<Sampler>(suffix_array, 2);
+ expected_locations = {4, 5};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ blacklist.clear();
+
+ blacklist.insert(0);
+ blacklist.insert(3);
+ blacklist.insert(7);
+ sampler = make_shared<Sampler>(suffix_array, 3);
+ expected_locations = {1, 2, 6};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ blacklist.clear();
+
+ blacklist.insert(0);
+ blacklist.insert(3);
+ blacklist.insert(5);
+ blacklist.insert(8);
+ sampler = make_shared<Sampler>(suffix_array, 4);
+ expected_locations = {1, 2, 4, 7};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ blacklist.clear();
+
+ blacklist.insert(0);
+ sampler = make_shared<Sampler>(suffix_array, 100);
+ expected_locations = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ blacklist.clear();
+
+ blacklist.insert(9);
+ sampler = make_shared<Sampler>(suffix_array, 100);
+ expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+}
+
+} // namespace
+} // namespace extractor