From 3973a7e4a8302b4a02fee7d2950bb469b37e2452 Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Sun, 24 Nov 2013 13:19:28 +0000 Subject: Reduce memory overhead for constructing the intersector. --- extractor/grammar_extractor.h | 1 + 1 file changed, 1 insertion(+) (limited to 'extractor/grammar_extractor.h') diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h index ae407b47..8f570df2 100644 --- a/extractor/grammar_extractor.h +++ b/extractor/grammar_extractor.h @@ -32,6 +32,7 @@ class GrammarExtractor { shared_ptr alignment, shared_ptr precomputation, shared_ptr scorer, + shared_ptr vocabulary, int min_gap_size, int max_rule_span, int max_nonterminals, -- cgit v1.2.3 From 3c73e472444ff0cd436b12f3679440a6969cbf2d Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Mon, 25 Nov 2013 23:56:31 +0000 Subject: Clean up leave-one-out sampling. --- extractor/grammar_extractor.cc | 6 ++++-- extractor/grammar_extractor.h | 4 +++- extractor/grammar_extractor_test.cc | 4 ++-- extractor/mocks/mock_rule_factory.h | 6 +++--- extractor/mocks/mock_sampler.h | 4 +++- extractor/rule_factory.cc | 7 +++++-- extractor/rule_factory.h | 3 +-- extractor/rule_factory_test.cc | 8 +++----- extractor/run_extractor.cc | 3 ++- extractor/sampler.cc | 12 ++++++++---- extractor/sampler.h | 4 +++- extractor/sampler_test.cc | 30 +++++++++++++++++++++--------- 12 files changed, 58 insertions(+), 33 deletions(-) (limited to 'extractor/grammar_extractor.h') diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc index 4d0738f7..1dc94c25 100644 --- a/extractor/grammar_extractor.cc +++ b/extractor/grammar_extractor.cc @@ -35,10 +35,12 @@ GrammarExtractor::GrammarExtractor( vocabulary(vocabulary), rule_factory(rule_factory) {} -Grammar GrammarExtractor::GetGrammar(const string& sentence, const unordered_set& blacklisted_sentence_ids, const shared_ptr source_data_array) { +Grammar GrammarExtractor::GetGrammar( + const string& sentence, + const unordered_set& blacklisted_sentence_ids) { vector words = TokenizeSentence(sentence); vector word_ids = AnnotateWords(words); - return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array); + return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids); } vector GrammarExtractor::TokenizeSentence(const string& sentence) { diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h index 8f570df2..eb79f53c 100644 --- a/extractor/grammar_extractor.h +++ b/extractor/grammar_extractor.h @@ -46,7 +46,9 @@ class GrammarExtractor { // Converts the sentence to a vector of word ids and uses the RuleFactory to // extract the SCFG rules which may be used to decode the sentence. - Grammar GetGrammar(const string& sentence, const unordered_set& blacklisted_sentence_ids, const shared_ptr source_data_array); + Grammar GetGrammar( + const string& sentence, + const unordered_set& blacklisted_sentence_ids); private: // Splits the sentence in a vector of words. diff --git a/extractor/grammar_extractor_test.cc b/extractor/grammar_extractor_test.cc index f32a9599..719e90ff 100644 --- a/extractor/grammar_extractor_test.cc +++ b/extractor/grammar_extractor_test.cc @@ -41,13 +41,13 @@ TEST(GrammarExtractorTest, TestAnnotatingWords) { Grammar grammar(rules, feature_names); unordered_set blacklisted_sentence_ids; shared_ptr source_data_array; - EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array)) + EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids)) .WillOnce(Return(grammar)); GrammarExtractor extractor(vocabulary, factory); string sentence = "Anna has many many apples ."; - extractor.GetGrammar(sentence, blacklisted_sentence_ids, source_data_array); + extractor.GetGrammar(sentence, blacklisted_sentence_ids); } } // namespace diff --git a/extractor/mocks/mock_rule_factory.h b/extractor/mocks/mock_rule_factory.h index 6b7b6586..53eb5022 100644 --- a/extractor/mocks/mock_rule_factory.h +++ b/extractor/mocks/mock_rule_factory.h @@ -7,9 +7,9 @@ namespace extractor { class MockHieroCachingRuleFactory : public HieroCachingRuleFactory { public: - MOCK_METHOD3(GetGrammar, Grammar(const vector& word_ids, const - unordered_set& blacklisted_sentence_ids, - const shared_ptr source_data_array)); + MOCK_METHOD2(GetGrammar, Grammar( + const vector& word_ids, + const unordered_set& blacklisted_sentence_ids)); }; } // namespace extractor diff --git a/extractor/mocks/mock_sampler.h b/extractor/mocks/mock_sampler.h index 75c43c27..b2742f62 100644 --- a/extractor/mocks/mock_sampler.h +++ b/extractor/mocks/mock_sampler.h @@ -7,7 +7,9 @@ namespace extractor { class MockSampler : public Sampler { public: - MOCK_CONST_METHOD1(Sample, PhraseLocation(const PhraseLocation& location)); + MOCK_CONST_METHOD2(Sample, PhraseLocation( + const PhraseLocation& location, + const unordered_set& blacklisted_sentence_ids)); }; } // namespace extractor diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index 6ae2d792..5b66f685 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -101,7 +101,9 @@ HieroCachingRuleFactory::HieroCachingRuleFactory() {} HieroCachingRuleFactory::~HieroCachingRuleFactory() {} -Grammar HieroCachingRuleFactory::GetGrammar(const vector& word_ids, const unordered_set& blacklisted_sentence_ids, const shared_ptr source_data_array) { +Grammar HieroCachingRuleFactory::GetGrammar( + const vector& word_ids, + const unordered_set& blacklisted_sentence_ids) { Clock::time_point start_time = Clock::now(); double total_extract_time = 0; double total_intersect_time = 0; @@ -193,7 +195,8 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector& word_ids, const u Clock::time_point extract_start = Clock::now(); if (!state.starts_with_x) { // Extract rules for the sampled set of occurrences. - PhraseLocation sample = sampler->Sample(next_node->matchings, blacklisted_sentence_ids, source_data_array); + PhraseLocation sample = sampler->Sample( + next_node->matchings, blacklisted_sentence_ids); vector new_rules = rule_extractor->ExtractRules(next_phrase, sample); rules.insert(rules.end(), new_rules.begin(), new_rules.end()); diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h index a1ff76e4..1a9fa2af 100644 --- a/extractor/rule_factory.h +++ b/extractor/rule_factory.h @@ -74,8 +74,7 @@ class HieroCachingRuleFactory { // (See class description for more details.) virtual Grammar GetGrammar( const vector& word_ids, - const unordered_set& blacklisted_sentence_ids, - const shared_ptr source_data_array); + const unordered_set& blacklisted_sentence_ids); protected: HieroCachingRuleFactory(); diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc index f26cc567..332c5959 100644 --- a/extractor/rule_factory_test.cc +++ b/extractor/rule_factory_test.cc @@ -40,7 +40,7 @@ class RuleFactoryTest : public Test { .WillRepeatedly(Return(feature_names)); sampler = make_shared(); - EXPECT_CALL(*sampler, Sample(_)) + EXPECT_CALL(*sampler, Sample(_, _)) .WillRepeatedly(Return(PhraseLocation(0, 1))); Phrase phrase; @@ -77,8 +77,7 @@ TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) { vector word_ids = {2, 3, 4}; unordered_set blacklisted_sentence_ids; - shared_ptr source_data_array; - Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array); + Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids); EXPECT_EQ(feature_names, grammar.GetFeatureNames()); EXPECT_EQ(7, grammar.GetRules().size()); } @@ -97,8 +96,7 @@ TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) { vector word_ids = {2, 3, 4, 2, 3}; unordered_set blacklisted_sentence_ids; - shared_ptr source_data_array; - Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array); + Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids); EXPECT_EQ(feature_names, grammar.GetFeatureNames()); EXPECT_EQ(28, grammar.GetRules().size()); } diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index 85c8a422..6b22a302 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -237,7 +237,8 @@ int main(int argc, char** argv) { unordered_set blacklisted_sentence_ids; if (leave_one_out) blacklisted_sentence_ids.insert(i); - Grammar grammar = extractor.GetGrammar(sentences[i], blacklisted_sentence_ids, source_data_array); + Grammar grammar = extractor.GetGrammar( + sentences[i], blacklisted_sentence_ids); ofstream output(GetGrammarFilePath(grammar_path, i).c_str()); output << grammar; } diff --git a/extractor/sampler.cc b/extractor/sampler.cc index 963afa7a..fc386ed1 100644 --- a/extractor/sampler.cc +++ b/extractor/sampler.cc @@ -12,7 +12,9 @@ Sampler::Sampler() {} Sampler::~Sampler() {} -PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_set& blacklisted_sentence_ids, const shared_ptr source_data_array) const { +PhraseLocation Sampler::Sample( + const PhraseLocation& location, + const unordered_set& blacklisted_sentence_ids) const { vector sample; int num_subpatterns; if (location.matchings == NULL) { @@ -22,10 +24,11 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_s double step = max(1.0, (double) (high - low) / max_samples); double i = low, last = i; bool found; + shared_ptr source_data_array = suffix_array->GetData(); while (sample.size() < max_samples && i < high) { int x = suffix_array->GetSuffix(Round(i)); int id = source_data_array->GetSentenceId(x); - if (find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) != blacklisted_sentence_ids.end()) { + if (blacklisted_sentence_ids.count(id)) { found = false; double backoff_step = 1; while (true) { @@ -33,13 +36,14 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_s double j = i - backoff_step; x = suffix_array->GetSuffix(Round(j)); id = source_data_array->GetSentenceId(x); - if (x >= 0 && j > last && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) { + if (x >= 0 && j > last && !blacklisted_sentence_ids.count(id)) { found = true; last = i; break; } double k = i + backoff_step; x = suffix_array->GetSuffix(Round(k)); id = source_data_array->GetSentenceId(x); - if (k < min(i+step, (double)high) && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) { + if (k < min(i+step, (double)high) && + !blacklisted_sentence_ids.count(id)) { found = true; last = k; break; } if (j <= last && k >= high) break; diff --git a/extractor/sampler.h b/extractor/sampler.h index de450c48..bd8a5876 100644 --- a/extractor/sampler.h +++ b/extractor/sampler.h @@ -23,7 +23,9 @@ class Sampler { virtual ~Sampler(); // Samples uniformly at most max_samples phrase occurrences. - virtual PhraseLocation Sample(const PhraseLocation& location, const unordered_set& blacklisted_sentence_ids, const shared_ptr source_data_array) const; + virtual PhraseLocation Sample( + const PhraseLocation& location, + const unordered_set& blacklisted_sentence_ids) const; protected: Sampler(); diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc index 965567ba..14e72780 100644 --- a/extractor/sampler_test.cc +++ b/extractor/sampler_test.cc @@ -19,6 +19,8 @@ class SamplerTest : public Test { source_data_array = make_shared(); EXPECT_CALL(*source_data_array, GetSentenceId(_)).WillRepeatedly(Return(9999)); suffix_array = make_shared(); + EXPECT_CALL(*suffix_array, GetData()) + .WillRepeatedly(Return(source_data_array)); for (int i = 0; i < 10; ++i) { EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i)); } @@ -35,23 +37,29 @@ TEST_F(SamplerTest, TestSuffixArrayRange) { sampler = make_shared(suffix_array, 1); vector expected_locations = {0}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler->Sample(location, blacklist)); + return; sampler = make_shared(suffix_array, 2); expected_locations = {0, 5}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler->Sample(location, blacklist)); sampler = make_shared(suffix_array, 3); expected_locations = {0, 3, 7}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler->Sample(location, blacklist)); sampler = make_shared(suffix_array, 4); expected_locations = {0, 3, 5, 8}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler->Sample(location, blacklist)); sampler = make_shared(suffix_array, 100); expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler->Sample(location, blacklist)); } TEST_F(SamplerTest, TestSubstringsSample) { @@ -61,19 +69,23 @@ TEST_F(SamplerTest, TestSubstringsSample) { sampler = make_shared(suffix_array, 1); vector expected_locations = {0, 1}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklist)); sampler = make_shared(suffix_array, 2); expected_locations = {0, 1, 6, 7}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklist)); sampler = make_shared(suffix_array, 3); expected_locations = {0, 1, 4, 5, 6, 7}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklist)); sampler = make_shared(suffix_array, 7); expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklist)); } } // namespace -- cgit v1.2.3 From bed3e4b867e4132917fa0640956e8ce713f0e451 Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Tue, 26 Nov 2013 15:01:14 +0000 Subject: Script for grammar extraction only. --- .gitignore | 1 + extractor/Makefile.am | 40 +------ extractor/extract.cc | 253 ++++++++++++++++++++++++++++++++++++++++++ extractor/grammar_extractor.h | 1 - extractor/run_extractor.cc | 20 ++-- extractor/sampler.cc | 23 ++-- 6 files changed, 278 insertions(+), 60 deletions(-) create mode 100644 extractor/extract.cc (limited to 'extractor/grammar_extractor.h') diff --git a/.gitignore b/.gitignore index 942539cb..f964fa0c 100644 --- a/.gitignore +++ b/.gitignore @@ -72,6 +72,7 @@ extools/score_grammar extools/sg_lexer.cc extractor/*_test extractor/compile +extractor/extract extractor/run_extractor gi/clda/src/clda gi/markov_al/ml diff --git a/extractor/Makefile.am b/extractor/Makefile.am index 64a5a2b5..7825012c 100644 --- a/extractor/Makefile.am +++ b/extractor/Makefile.am @@ -1,5 +1,5 @@ -bin_PROGRAMS = compile run_extractor +bin_PROGRAMS = compile run_extractor extract if HAVE_CXX11 @@ -105,44 +105,14 @@ translation_table_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $ vocabulary_test_SOURCES = vocabulary_test.cc vocabulary_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a -noinst_LIBRARIES = libextractor.a libcompile.a +noinst_LIBRARIES = libextractor.a compile_SOURCES = compile.cc -compile_LDADD = libcompile.a +compile_LDADD = libextractor.a run_extractor_SOURCES = run_extractor.cc run_extractor_LDADD = libextractor.a - -libcompile_a_SOURCES = \ - alignment.cc \ - data_array.cc \ - phrase_location.cc \ - precomputation.cc \ - suffix_array.cc \ - time_util.cc \ - translation_table.cc \ - vocabulary.cc \ - alignment.h \ - data_array.h \ - fast_intersector.h \ - grammar.h \ - grammar_extractor.h \ - matchings_finder.h \ - matchings_trie.h \ - phrase.h \ - phrase_builder.h \ - phrase_location.h \ - precomputation.h \ - rule.h \ - rule_extractor.h \ - rule_extractor_helper.h \ - rule_factory.h \ - sampler.h \ - scorer.h \ - suffix_array.h \ - target_phrase_extractor.h \ - time_util.h \ - translation_table.h \ - vocabulary.h +extract_SOURCES = extract.cc +extract_LDADD = libextractor.a libextractor_a_SOURCES = \ alignment.cc \ diff --git a/extractor/extract.cc b/extractor/extract.cc new file mode 100644 index 00000000..2d5831fa --- /dev/null +++ b/extractor/extract.cc @@ -0,0 +1,253 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "alignment.h" +#include "data_array.h" +#include "features/count_source_target.h" +#include "features/feature.h" +#include "features/is_source_singleton.h" +#include "features/is_source_target_singleton.h" +#include "features/max_lex_source_given_target.h" +#include "features/max_lex_target_given_source.h" +#include "features/sample_source_count.h" +#include "features/target_given_source_coherent.h" +#include "grammar.h" +#include "grammar_extractor.h" +#include "precomputation.h" +#include "rule.h" +#include "scorer.h" +#include "suffix_array.h" +#include "time_util.h" +#include "translation_table.h" +#include "vocabulary.h" + +namespace ar = boost::archive; +namespace fs = boost::filesystem; +namespace po = boost::program_options; +using namespace extractor; +using namespace features; +using namespace std; + +// Returns the file path in which a given grammar should be written. +fs::path GetGrammarFilePath(const fs::path& grammar_path, int file_number) { + string file_name = "grammar." + to_string(file_number); + return grammar_path / file_name; +} + +int main(int argc, char** argv) { + po::options_description general_options("General options"); + int max_threads = 1; + #pragma omp parallel + max_threads = omp_get_num_threads(); + string threads_option = "Number of threads used for grammar extraction " + "max(" + to_string(max_threads) + ")"; + general_options.add_options() + ("threads,t", po::value()->required()->default_value(1), + threads_option.c_str()) + ("grammars,g", po::value()->required(), "Grammars output path") + ("max_rule_span", po::value()->default_value(15), + "Maximum rule span") + ("max_rule_symbols", po::value()->default_value(5), + "Maximum number of symbols (terminals + nontermals) in a rule") + ("min_gap_size", po::value()->default_value(1), "Minimum gap size") + ("max_nonterminals", po::value()->default_value(2), + "Maximum number of nonterminals in a rule") + ("max_samples", po::value()->default_value(300), + "Maximum number of samples") + ("tight_phrases", po::value()->default_value(true), + "False if phrases may be loose (better, but slower)") + ("leave_one_out", po::value()->zero_tokens(), + "do leave-one-out estimation of grammars " + "(e.g. for extracting grammars for the training set"); + + po::options_description cmdline_options("Command line options"); + cmdline_options.add_options() + ("help", "Show available options") + ("config", po::value()->required(), "Path to config file"); + cmdline_options.add(general_options); + + po::options_description config_options("Config file options"); + config_options.add_options() + ("target", po::value()->required(), + "Path to target data file in binary format") + ("source", po::value()->required(), + "Path to source suffix array file in binary format") + ("alignment", po::value()->required(), + "Path to alignment file in binary format") + ("precomputation", po::value()->required(), + "Path to precomputation file in binary format") + ("vocabulary", po::value()->required(), + "Path to vocabulary file in binary format") + ("ttable", po::value()->required(), + "Path to translation table in binary format"); + config_options.add(general_options); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, cmdline_options), vm); + if (vm.count("help")) { + po::options_description all_options; + all_options.add(cmdline_options).add(config_options); + cout << all_options << endl; + return 0; + } + + po::notify(vm); + + ifstream config_stream(vm["config"].as()); + po::store(po::parse_config_file(config_stream, config_options), vm); + po::notify(vm); + + int num_threads = vm["threads"].as(); + cerr << "Grammar extraction will use " << num_threads << " threads." << endl; + + Clock::time_point read_start_time = Clock::now(); + + Clock::time_point start_time = Clock::now(); + cerr << "Reading target data in binary format..." << endl; + shared_ptr target_data_array = make_shared(); + ifstream target_fstream(vm["target"].as()); + ar::binary_iarchive target_stream(target_fstream); + target_stream >> *target_data_array; + Clock::time_point end_time = Clock::now(); + cerr << "Reading target data took " << GetDuration(start_time, end_time) + << " seconds" << endl; + + start_time = Clock::now(); + cerr << "Reading source suffix array in binary format..." << endl; + shared_ptr source_suffix_array = make_shared(); + ifstream source_fstream(vm["source"].as()); + ar::binary_iarchive source_stream(source_fstream); + source_stream >> *source_suffix_array; + end_time = Clock::now(); + cerr << "Reading source suffix array took " + << GetDuration(start_time, end_time) << " seconds" << endl; + + start_time = Clock::now(); + cerr << "Reading alignment in binary format..." << endl; + shared_ptr alignment = make_shared(); + ifstream alignment_fstream(vm["alignment"].as()); + ar::binary_iarchive alignment_stream(alignment_fstream); + alignment_stream >> *alignment; + end_time = Clock::now(); + cerr << "Reading alignment took " << GetDuration(start_time, end_time) + << " seconds" << endl; + + start_time = Clock::now(); + cerr << "Reading precomputation in binary format..." << endl; + shared_ptr precomputation = make_shared(); + ifstream precomputation_fstream(vm["precomputation"].as()); + ar::binary_iarchive precomputation_stream(precomputation_fstream); + precomputation_stream >> *precomputation; + end_time = Clock::now(); + cerr << "Reading precomputation took " << GetDuration(start_time, end_time) + << " seconds" << endl; + + start_time = Clock::now(); + cerr << "Reading vocabulary in binary format..." << endl; + shared_ptr vocabulary = make_shared(); + ifstream vocabulary_fstream(vm["vocabulary"].as()); + ar::binary_iarchive vocabulary_stream(vocabulary_fstream); + vocabulary_stream >> *vocabulary; + end_time = Clock::now(); + cerr << "Reading vocabulary took " << GetDuration(start_time, end_time) + << " seconds" << endl; + + start_time = Clock::now(); + cerr << "Reading translation table in binary format..." << endl; + shared_ptr table = make_shared(); + ifstream ttable_fstream(vm["ttable"].as()); + ar::binary_iarchive ttable_stream(ttable_fstream); + ttable_stream >> *table; + end_time = Clock::now(); + cerr << "Reading translation table took " << GetDuration(start_time, end_time) + << " seconds" << endl; + + Clock::time_point read_end_time = Clock::now(); + cerr << "Total time spent loading data structures into memory: " + << GetDuration(read_start_time, read_end_time) << " seconds" << endl; + + Clock::time_point extraction_start_time = Clock::now(); + // Features used to score each grammar rule. + vector> features = { + make_shared(), + make_shared(), + make_shared(), + make_shared(table), + make_shared(table), + make_shared(), + make_shared() + }; + shared_ptr scorer = make_shared(features); + + GrammarExtractor extractor( + source_suffix_array, + target_data_array, + alignment, + precomputation, + scorer, + vocabulary, + vm["min_gap_size"].as(), + vm["max_rule_span"].as(), + vm["max_nonterminals"].as(), + vm["max_rule_symbols"].as(), + vm["max_samples"].as(), + vm["tight_phrases"].as()); + + // Creates the grammars directory if it doesn't exist. + fs::path grammar_path = vm["grammars"].as(); + if (!fs::is_directory(grammar_path)) { + fs::create_directory(grammar_path); + } + + // Reads all sentences for which we extract grammar rules (the paralellization + // is simplified if we read all sentences upfront). + string sentence; + vector sentences; + while (getline(cin, sentence)) { + sentences.push_back(sentence); + } + + // Extracts the grammar for each sentence and saves it to a file. + vector suffixes(sentences.size()); + bool leave_one_out = vm.count("leave_one_out"); + #pragma omp parallel for schedule(dynamic) num_threads(num_threads) + for (size_t i = 0; i < sentences.size(); ++i) { + string suffix; + int position = sentences[i].find("|||"); + if (position != sentences[i].npos) { + suffix = sentences[i].substr(position); + sentences[i] = sentences[i].substr(0, position); + } + suffixes[i] = suffix; + + unordered_set blacklisted_sentence_ids; + if (leave_one_out) { + blacklisted_sentence_ids.insert(i); + } + Grammar grammar = extractor.GetGrammar( + sentences[i], blacklisted_sentence_ids); + ofstream output(GetGrammarFilePath(grammar_path, i).c_str()); + // output << grammar; + } + + for (size_t i = 0; i < sentences.size(); ++i) { + cout << " " << sentences[i] << " " << suffixes[i] << endl; + } + + Clock::time_point extraction_stop_time = Clock::now(); + cerr << "Overall extraction step took " + << GetDuration(extraction_start_time, extraction_stop_time) + << " seconds" << endl; + + return 0; +} diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h index eb79f53c..0f3069b0 100644 --- a/extractor/grammar_extractor.h +++ b/extractor/grammar_extractor.h @@ -15,7 +15,6 @@ class DataArray; class Grammar; class HieroCachingRuleFactory; class Precomputation; -class Rule; class Scorer; class SuffixArray; class Vocabulary; diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index 6b22a302..f1aa5e35 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -5,10 +5,10 @@ #include #include -#include #include #include #include +#include #include "alignment.h" #include "data_array.h" @@ -78,7 +78,8 @@ int main(int argc, char** argv) { ("tight_phrases", po::value()->default_value(true), "False if phrases may be loose (better, but slower)") ("leave_one_out", po::value()->zero_tokens(), - "do leave-one-out estimation of grammars (e.g. for extracting grammars for the training set"); + "do leave-one-out estimation of grammars " + "(e.g. for extracting grammars for the training set"); po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); @@ -99,11 +100,6 @@ int main(int argc, char** argv) { return 1; } - bool leave_one_out = false; - if (vm.count("leave_one_out")) { - leave_one_out = true; - } - int num_threads = vm["threads"].as(); cerr << "Grammar extraction will use " << num_threads << " threads." << endl; @@ -178,8 +174,8 @@ int main(int argc, char** argv) { << GetDuration(preprocess_start_time, preprocess_stop_time) << " seconds" << endl; - // Features used to score each grammar rule. Clock::time_point extraction_start_time = Clock::now(); + // Features used to score each grammar rule. vector> features = { make_shared(), make_shared(), @@ -206,9 +202,6 @@ int main(int argc, char** argv) { vm["max_samples"].as(), vm["tight_phrases"].as()); - // Releases extra memory used by the initial precomputation. - precomputation.reset(); - // Creates the grammars directory if it doesn't exist. fs::path grammar_path = vm["grammars"].as(); if (!fs::is_directory(grammar_path)) { @@ -224,6 +217,7 @@ int main(int argc, char** argv) { } // Extracts the grammar for each sentence and saves it to a file. + bool leave_one_out = vm.count("leave_one_out"); vector suffixes(sentences.size()); #pragma omp parallel for schedule(dynamic) num_threads(num_threads) for (size_t i = 0; i < sentences.size(); ++i) { @@ -236,7 +230,9 @@ int main(int argc, char** argv) { suffixes[i] = suffix; unordered_set blacklisted_sentence_ids; - if (leave_one_out) blacklisted_sentence_ids.insert(i); + if (leave_one_out) { + blacklisted_sentence_ids.insert(i); + } Grammar grammar = extractor.GetGrammar( sentences[i], blacklisted_sentence_ids); ofstream output(GetGrammarFilePath(grammar_path, i).c_str()); diff --git a/extractor/sampler.cc b/extractor/sampler.cc index fc386ed1..887aaec1 100644 --- a/extractor/sampler.cc +++ b/extractor/sampler.cc @@ -15,6 +15,7 @@ Sampler::~Sampler() {} PhraseLocation Sampler::Sample( const PhraseLocation& location, const unordered_set& blacklisted_sentence_ids) const { + shared_ptr source_data_array = suffix_array->GetData(); vector sample; int num_subpatterns; if (location.matchings == NULL) { @@ -22,32 +23,30 @@ PhraseLocation Sampler::Sample( num_subpatterns = 1; int low = location.sa_low, high = location.sa_high; double step = max(1.0, (double) (high - low) / max_samples); - double i = low, last = i; - bool found; - shared_ptr source_data_array = suffix_array->GetData(); + double i = low, last = i - 1; while (sample.size() < max_samples && i < high) { int x = suffix_array->GetSuffix(Round(i)); int id = source_data_array->GetSentenceId(x); + bool found = false; if (blacklisted_sentence_ids.count(id)) { - found = false; - double backoff_step = 1; - while (true) { - if ((double)backoff_step >= step) break; + for (int backoff_step = 1; backoff_step <= step; ++backoff_step) { double j = i - backoff_step; x = suffix_array->GetSuffix(Round(j)); id = source_data_array->GetSentenceId(x); if (x >= 0 && j > last && !blacklisted_sentence_ids.count(id)) { - found = true; last = i; break; + found = true; + last = i; + break; } double k = i + backoff_step; x = suffix_array->GetSuffix(Round(k)); id = source_data_array->GetSentenceId(x); - if (k < min(i+step, (double)high) && + if (k < min(i+step, (double) high) && !blacklisted_sentence_ids.count(id)) { - found = true; last = k; break; + found = true; + last = k; + break; } - if (j <= last && k >= high) break; - backoff_step++; } } else { found = true; -- cgit v1.2.3