diff options
Diffstat (limited to 'extractor')
| -rw-r--r-- | extractor/Makefile.am | 27 | ||||
| -rw-r--r-- | extractor/grammar_extractor.cc | 6 | ||||
| -rw-r--r-- | extractor/grammar_extractor.h | 3 | ||||
| -rw-r--r-- | extractor/grammar_extractor_test.cc | 7 | ||||
| -rw-r--r-- | extractor/mocks/mock_rule_factory.h | 2 | ||||
| -rw-r--r-- | extractor/rule_factory.cc | 5 | ||||
| -rw-r--r-- | extractor/rule_factory.h | 3 | ||||
| -rw-r--r-- | extractor/rule_factory_test.cc | 8 | ||||
| -rw-r--r-- | extractor/run_extractor.cc | 13 | ||||
| -rw-r--r-- | extractor/sample_source.txt | 2 | ||||
| -rw-r--r-- | extractor/sampler.cc | 35 | ||||
| -rw-r--r-- | extractor/sampler.h | 5 | ||||
| -rw-r--r-- | extractor/sampler_test.cc | 24 | ||||
| -rw-r--r-- | extractor/sampler_test_blacklist.cc | 102 | 
14 files changed, 214 insertions, 28 deletions
diff --git a/extractor/Makefile.am b/extractor/Makefile.am index e94a9b91..65a3d436 100644 --- a/extractor/Makefile.am +++ b/extractor/Makefile.am @@ -1,7 +1,8 @@ -if HAVE_CXX11  bin_PROGRAMS = compile run_extractor +if HAVE_CXX11 +  EXTRA_PROGRAMS = alignment_test \      data_array_test \      fast_intersector_test \ @@ -113,7 +114,29 @@ libcompile_a_SOURCES = \    precomputation.cc \    suffix_array.cc \    time_util.cc \ -  translation_table.cc +  translation_table.cc \ +  alignment.h \ +  data_array.h \ +  fast_intersector.h \ +  grammar.h \ +  grammar_extractor.h \ +  matchings_finder.h \ +  matchings_trie.h \ +  phrase.h \ +  phrase_builder.h \ +  phrase_location.h \ +  precomputation.h \ +  rule.h \ +  rule_extractor.h \ +  rule_extractor_helper.h \ +  rule_factory.h \ +  sampler.h \ +  scorer.h \ +  suffix_array.h \ +  target_phrase_extractor.h \ +  time_util.h \ +  translation_table.h \ +  vocabulary.h  libextractor_a_SOURCES = \    alignment.cc \ diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc index 8050ce7b..487abcaf 100644 --- a/extractor/grammar_extractor.cc +++ b/extractor/grammar_extractor.cc @@ -3,11 +3,13 @@  #include <iterator>  #include <sstream>  #include <vector> +#include <unordered_set>  #include "grammar.h"  #include "rule.h"  #include "rule_factory.h"  #include "vocabulary.h" +#include "data_array.h"  using namespace std; @@ -32,10 +34,10 @@ GrammarExtractor::GrammarExtractor(      vocabulary(vocabulary),      rule_factory(rule_factory) {} -Grammar GrammarExtractor::GetGrammar(const string& sentence) { +Grammar GrammarExtractor::GetGrammar(const string& sentence, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) {    vector<string> words = TokenizeSentence(sentence);    vector<int> word_ids = AnnotateWords(words); -  return rule_factory->GetGrammar(word_ids); +  return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);  }  vector<string> GrammarExtractor::TokenizeSentence(const string& sentence) { diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h index b36ceeb9..ae407b47 100644 --- a/extractor/grammar_extractor.h +++ b/extractor/grammar_extractor.h @@ -4,6 +4,7 @@  #include <memory>  #include <string>  #include <vector> +#include <unordered_set>  using namespace std; @@ -44,7 +45,7 @@ class GrammarExtractor {    // Converts the sentence to a vector of word ids and uses the RuleFactory to    // extract the SCFG rules which may be used to decode the sentence. -  Grammar GetGrammar(const string& sentence); +  Grammar GetGrammar(const string& sentence, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array);   private:    // Splits the sentence in a vector of words. diff --git a/extractor/grammar_extractor_test.cc b/extractor/grammar_extractor_test.cc index 823bb8b4..f32a9599 100644 --- a/extractor/grammar_extractor_test.cc +++ b/extractor/grammar_extractor_test.cc @@ -39,12 +39,15 @@ TEST(GrammarExtractorTest, TestAnnotatingWords) {    vector<Rule> rules;    vector<string> feature_names;    Grammar grammar(rules, feature_names); -  EXPECT_CALL(*factory, GetGrammar(word_ids)) +  unordered_set<int> blacklisted_sentence_ids; +  shared_ptr<DataArray> source_data_array; +  EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array))        .WillOnce(Return(grammar));    GrammarExtractor extractor(vocabulary, factory);    string sentence = "Anna has many many apples ."; -  extractor.GetGrammar(sentence); + +  extractor.GetGrammar(sentence, blacklisted_sentence_ids, source_data_array);  }  } // namespace diff --git a/extractor/mocks/mock_rule_factory.h b/extractor/mocks/mock_rule_factory.h index 7389b396..86a084b5 100644 --- a/extractor/mocks/mock_rule_factory.h +++ b/extractor/mocks/mock_rule_factory.h @@ -7,7 +7,7 @@ namespace extractor {  class MockHieroCachingRuleFactory : public HieroCachingRuleFactory {   public: -  MOCK_METHOD1(GetGrammar, Grammar(const vector<int>& word_ids)); +  MOCK_METHOD3(GetGrammar, Grammar(const vector<int>& word_ids, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array));  };  } // namespace extractor diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index 8c30fb9e..6ae2d792 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -17,6 +17,7 @@  #include "suffix_array.h"  #include "time_util.h"  #include "vocabulary.h" +#include "data_array.h"  using namespace std;  using namespace chrono; @@ -100,7 +101,7 @@ HieroCachingRuleFactory::HieroCachingRuleFactory() {}  HieroCachingRuleFactory::~HieroCachingRuleFactory() {} -Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { +Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) {    Clock::time_point start_time = Clock::now();    double total_extract_time = 0;    double total_intersect_time = 0; @@ -192,7 +193,7 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {        Clock::time_point extract_start = Clock::now();        if (!state.starts_with_x) {          // Extract rules for the sampled set of occurrences. -        PhraseLocation sample = sampler->Sample(next_node->matchings); +        PhraseLocation sample = sampler->Sample(next_node->matchings, blacklisted_sentence_ids, source_data_array);          vector<Rule> new_rules =              rule_extractor->ExtractRules(next_phrase, sample);          rules.insert(rules.end(), new_rules.begin(), new_rules.end()); diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h index 52e8712a..df63a9d8 100644 --- a/extractor/rule_factory.h +++ b/extractor/rule_factory.h @@ -3,6 +3,7 @@  #include <memory>  #include <vector> +#include <unordered_set>  #include "matchings_trie.h" @@ -71,7 +72,7 @@ class HieroCachingRuleFactory {    // Constructs SCFG rules for a given sentence.    // (See class description for more details.) -  virtual Grammar GetGrammar(const vector<int>& word_ids); +  virtual Grammar GetGrammar(const vector<int>& word_ids, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array);   protected:    HieroCachingRuleFactory(); diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc index 08af3dcd..f26cc567 100644 --- a/extractor/rule_factory_test.cc +++ b/extractor/rule_factory_test.cc @@ -76,7 +76,9 @@ TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) {        .WillRepeatedly(Return(PhraseLocation(0, 1)));    vector<int> word_ids = {2, 3, 4}; -  Grammar grammar = factory->GetGrammar(word_ids); +  unordered_set<int> blacklisted_sentence_ids; +  shared_ptr<DataArray> source_data_array; +  Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);    EXPECT_EQ(feature_names, grammar.GetFeatureNames());    EXPECT_EQ(7, grammar.GetRules().size());  } @@ -94,7 +96,9 @@ TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) {        .WillRepeatedly(Return(PhraseLocation(0, 1)));    vector<int> word_ids = {2, 3, 4, 2, 3}; -  Grammar grammar = factory->GetGrammar(word_ids); +  unordered_set<int> blacklisted_sentence_ids; +  shared_ptr<DataArray> source_data_array; +  Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);    EXPECT_EQ(feature_names, grammar.GetFeatureNames());    EXPECT_EQ(28, grammar.GetRules().size());  } diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index 8a9ca89d..6eb55073 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -75,7 +75,9 @@ int main(int argc, char** argv) {      ("max_samples", po::value<int>()->default_value(300),          "Maximum number of samples")      ("tight_phrases", po::value<bool>()->default_value(true), -        "False if phrases may be loose (better, but slower)"); +        "False if phrases may be loose (better, but slower)") +    ("leave_one_out", po::value<bool>()->zero_tokens(), +        "do leave-one-out estimation of grammars (e.g. for extracting grammars for the training set");    po::variables_map vm;    po::store(po::parse_command_line(argc, argv, desc), vm); @@ -96,6 +98,11 @@ int main(int argc, char** argv) {      return 1;    } +  bool leave_one_out = false; +  if (vm.count("leave_one_out")) { +    leave_one_out = true; +  } +    int num_threads = vm["threads"].as<int>();    cerr << "Grammar extraction will use " << num_threads << " threads." << endl; @@ -223,7 +230,9 @@ int main(int argc, char** argv) {      }      suffixes[i] = suffix; -    Grammar grammar = extractor.GetGrammar(sentences[i]); +    unordered_set<int> blacklisted_sentence_ids; +    if (leave_one_out) blacklisted_sentence_ids.insert(i); +    Grammar grammar = extractor.GetGrammar(sentences[i], blacklisted_sentence_ids, source_data_array);      ofstream output(GetGrammarFilePath(grammar_path, i).c_str());      output << grammar;    } diff --git a/extractor/sample_source.txt b/extractor/sample_source.txt new file mode 100644 index 00000000..971baf6d --- /dev/null +++ b/extractor/sample_source.txt @@ -0,0 +1,2 @@ +ana are mere . +ana bea mult lapte . diff --git a/extractor/sampler.cc b/extractor/sampler.cc index d81956b5..963afa7a 100644 --- a/extractor/sampler.cc +++ b/extractor/sampler.cc @@ -12,7 +12,7 @@ Sampler::Sampler() {}  Sampler::~Sampler() {} -PhraseLocation Sampler::Sample(const PhraseLocation& location) const { +PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const {    vector<int> sample;    int num_subpatterns;    if (location.matchings == NULL) { @@ -20,8 +20,37 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location) const {      num_subpatterns = 1;      int low = location.sa_low, high = location.sa_high;      double step = max(1.0, (double) (high - low) / max_samples); -    for (double i = low; i < high && sample.size() < max_samples; i += step) { -      sample.push_back(suffix_array->GetSuffix(Round(i))); +    double i = low, last = i; +    bool found; +    while (sample.size() < max_samples && i < high) { +      int x = suffix_array->GetSuffix(Round(i)); +      int id = source_data_array->GetSentenceId(x); +      if (find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) != blacklisted_sentence_ids.end()) { +        found = false; +        double backoff_step = 1; +        while (true) { +          if ((double)backoff_step >= step) break; +          double j = i - backoff_step; +          x = suffix_array->GetSuffix(Round(j)); +          id = source_data_array->GetSentenceId(x); +          if (x >= 0 && j > last && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) { +            found = true; last = i; break; +          } +          double k = i + backoff_step; +          x = suffix_array->GetSuffix(Round(k)); +          id = source_data_array->GetSentenceId(x); +          if (k < min(i+step, (double)high) && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) { +            found = true; last = k; break; +          } +          if (j <= last && k >= high) break; +          backoff_step++; +        } +      } else { +        found = true; +        last = i; +      } +      if (found) sample.push_back(x); +      i += step;      }    } else {      // Sample vector of occurrences. diff --git a/extractor/sampler.h b/extractor/sampler.h index be4aa1bb..de450c48 100644 --- a/extractor/sampler.h +++ b/extractor/sampler.h @@ -2,6 +2,9 @@  #define _SAMPLER_H_  #include <memory> +#include <unordered_set> + +#include "data_array.h"  using namespace std; @@ -20,7 +23,7 @@ class Sampler {    virtual ~Sampler();    // Samples uniformly at most max_samples phrase occurrences. -  virtual PhraseLocation Sample(const PhraseLocation& location) const; +  virtual PhraseLocation Sample(const PhraseLocation& location, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const;   protected:    Sampler(); diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc index e9abebfa..965567ba 100644 --- a/extractor/sampler_test.cc +++ b/extractor/sampler_test.cc @@ -3,6 +3,7 @@  #include <memory>  #include "mocks/mock_suffix_array.h" +#include "mocks/mock_data_array.h"  #include "phrase_location.h"  #include "sampler.h" @@ -15,6 +16,8 @@ namespace {  class SamplerTest : public Test {   protected:    virtual void SetUp() { +    source_data_array = make_shared<MockDataArray>(); +    EXPECT_CALL(*source_data_array, GetSentenceId(_)).WillRepeatedly(Return(9999));      suffix_array = make_shared<MockSuffixArray>();      for (int i = 0; i < 10; ++i) {        EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i)); @@ -23,51 +26,54 @@ class SamplerTest : public Test {    shared_ptr<MockSuffixArray> suffix_array;    shared_ptr<Sampler> sampler; +  shared_ptr<MockDataArray> source_data_array;  };  TEST_F(SamplerTest, TestSuffixArrayRange) {    PhraseLocation location(0, 10); +  unordered_set<int> blacklist;    sampler = make_shared<Sampler>(suffix_array, 1);    vector<int> expected_locations = {0}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); +  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));    sampler = make_shared<Sampler>(suffix_array, 2);    expected_locations = {0, 5}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); +  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));    sampler = make_shared<Sampler>(suffix_array, 3);    expected_locations = {0, 3, 7}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); +  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));    sampler = make_shared<Sampler>(suffix_array, 4);    expected_locations = {0, 3, 5, 8}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); +  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));    sampler = make_shared<Sampler>(suffix_array, 100);    expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); +  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));  }  TEST_F(SamplerTest, TestSubstringsSample) {    vector<int> locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; +  unordered_set<int> blacklist;    PhraseLocation location(locations, 2);    sampler = make_shared<Sampler>(suffix_array, 1);    vector<int> expected_locations = {0, 1}; -  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location)); +  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));    sampler = make_shared<Sampler>(suffix_array, 2);    expected_locations = {0, 1, 6, 7}; -  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location)); +  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));    sampler = make_shared<Sampler>(suffix_array, 3);    expected_locations = {0, 1, 4, 5, 6, 7}; -  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location)); +  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));    sampler = make_shared<Sampler>(suffix_array, 7);    expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; -  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location)); +  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));  }  } // namespace diff --git a/extractor/sampler_test_blacklist.cc b/extractor/sampler_test_blacklist.cc new file mode 100644 index 00000000..3305b990 --- /dev/null +++ b/extractor/sampler_test_blacklist.cc @@ -0,0 +1,102 @@ +#include <gtest/gtest.h> + +#include <memory> + +#include "mocks/mock_suffix_array.h" +#include "mocks/mock_data_array.h" +#include "phrase_location.h" +#include "sampler.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class SamplerTestBlacklist : public Test { + protected: +  virtual void SetUp() { +    source_data_array = make_shared<MockDataArray>(); +    for (int i = 0; i < 10; ++i) { +      EXPECT_CALL(*source_data_array, GetSentenceId(i)).WillRepeatedly(Return(i)); +    } +    for (int i = -10; i < 0; ++i) { +      EXPECT_CALL(*source_data_array, GetSentenceId(i)).WillRepeatedly(Return(0)); +    } +    suffix_array = make_shared<MockSuffixArray>(); +    for (int i = -10; i < 10; ++i) { +      EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i)); +    } +  } + +  shared_ptr<MockSuffixArray> suffix_array; +  shared_ptr<Sampler> sampler; +  shared_ptr<MockDataArray> source_data_array; +}; + +TEST_F(SamplerTestBlacklist, TestSuffixArrayRange) { +  PhraseLocation location(0, 10); +  unordered_set<int> blacklist; +  vector<int> expected_locations; +    +  blacklist.insert(0); +  sampler = make_shared<Sampler>(suffix_array, 1); +  expected_locations = {1}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); +  blacklist.clear(); +   +  for (int i = 0; i < 9; i++) { +    blacklist.insert(i); +  } +  sampler = make_shared<Sampler>(suffix_array, 1); +  expected_locations = {9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); +  blacklist.clear(); + +  blacklist.insert(0); +  blacklist.insert(5); +  sampler = make_shared<Sampler>(suffix_array, 2); +  expected_locations = {1, 4}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); +  blacklist.clear(); + +  blacklist.insert(0); +  blacklist.insert(1); +  blacklist.insert(2); +  blacklist.insert(3); +  sampler = make_shared<Sampler>(suffix_array, 2); +  expected_locations = {4, 5}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); +  blacklist.clear(); + +  blacklist.insert(0); +  blacklist.insert(3); +  blacklist.insert(7); +  sampler = make_shared<Sampler>(suffix_array, 3); +  expected_locations = {1, 2, 6}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); +  blacklist.clear(); + +  blacklist.insert(0); +  blacklist.insert(3); +  blacklist.insert(5); +  blacklist.insert(8); +  sampler = make_shared<Sampler>(suffix_array, 4); +  expected_locations = {1, 2, 4, 7}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); +  blacklist.clear(); +   +  blacklist.insert(0); +  sampler = make_shared<Sampler>(suffix_array, 100); +  expected_locations = {1, 2, 3, 4, 5, 6, 7, 8, 9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); +  blacklist.clear(); +  +  blacklist.insert(9); +  sampler = make_shared<Sampler>(suffix_array, 100); +  expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); +} + +} // namespace +} // namespace extractor  | 
