From bed3e4b867e4132917fa0640956e8ce713f0e451 Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Tue, 26 Nov 2013 15:01:14 +0000 Subject: Script for grammar extraction only. --- extractor/extract.cc | 253 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 253 insertions(+) create mode 100644 extractor/extract.cc (limited to 'extractor/extract.cc') diff --git a/extractor/extract.cc b/extractor/extract.cc new file mode 100644 index 00000000..2d5831fa --- /dev/null +++ b/extractor/extract.cc @@ -0,0 +1,253 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "alignment.h" +#include "data_array.h" +#include "features/count_source_target.h" +#include "features/feature.h" +#include "features/is_source_singleton.h" +#include "features/is_source_target_singleton.h" +#include "features/max_lex_source_given_target.h" +#include "features/max_lex_target_given_source.h" +#include "features/sample_source_count.h" +#include "features/target_given_source_coherent.h" +#include "grammar.h" +#include "grammar_extractor.h" +#include "precomputation.h" +#include "rule.h" +#include "scorer.h" +#include "suffix_array.h" +#include "time_util.h" +#include "translation_table.h" +#include "vocabulary.h" + +namespace ar = boost::archive; +namespace fs = boost::filesystem; +namespace po = boost::program_options; +using namespace extractor; +using namespace features; +using namespace std; + +// Returns the file path in which a given grammar should be written. +fs::path GetGrammarFilePath(const fs::path& grammar_path, int file_number) { + string file_name = "grammar." + to_string(file_number); + return grammar_path / file_name; +} + +int main(int argc, char** argv) { + po::options_description general_options("General options"); + int max_threads = 1; + #pragma omp parallel + max_threads = omp_get_num_threads(); + string threads_option = "Number of threads used for grammar extraction " + "max(" + to_string(max_threads) + ")"; + general_options.add_options() + ("threads,t", po::value()->required()->default_value(1), + threads_option.c_str()) + ("grammars,g", po::value()->required(), "Grammars output path") + ("max_rule_span", po::value()->default_value(15), + "Maximum rule span") + ("max_rule_symbols", po::value()->default_value(5), + "Maximum number of symbols (terminals + nontermals) in a rule") + ("min_gap_size", po::value()->default_value(1), "Minimum gap size") + ("max_nonterminals", po::value()->default_value(2), + "Maximum number of nonterminals in a rule") + ("max_samples", po::value()->default_value(300), + "Maximum number of samples") + ("tight_phrases", po::value()->default_value(true), + "False if phrases may be loose (better, but slower)") + ("leave_one_out", po::value()->zero_tokens(), + "do leave-one-out estimation of grammars " + "(e.g. for extracting grammars for the training set"); + + po::options_description cmdline_options("Command line options"); + cmdline_options.add_options() + ("help", "Show available options") + ("config", po::value()->required(), "Path to config file"); + cmdline_options.add(general_options); + + po::options_description config_options("Config file options"); + config_options.add_options() + ("target", po::value()->required(), + "Path to target data file in binary format") + ("source", po::value()->required(), + "Path to source suffix array file in binary format") + ("alignment", po::value()->required(), + "Path to alignment file in binary format") + ("precomputation", po::value()->required(), + "Path to precomputation file in binary format") + ("vocabulary", po::value()->required(), + "Path to vocabulary file in binary format") + ("ttable", po::value()->required(), + "Path to translation table in binary format"); + config_options.add(general_options); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, cmdline_options), vm); + if (vm.count("help")) { + po::options_description all_options; + all_options.add(cmdline_options).add(config_options); + cout << all_options << endl; + return 0; + } + + po::notify(vm); + + ifstream config_stream(vm["config"].as()); + po::store(po::parse_config_file(config_stream, config_options), vm); + po::notify(vm); + + int num_threads = vm["threads"].as(); + cerr << "Grammar extraction will use " << num_threads << " threads." << endl; + + Clock::time_point read_start_time = Clock::now(); + + Clock::time_point start_time = Clock::now(); + cerr << "Reading target data in binary format..." << endl; + shared_ptr target_data_array = make_shared(); + ifstream target_fstream(vm["target"].as()); + ar::binary_iarchive target_stream(target_fstream); + target_stream >> *target_data_array; + Clock::time_point end_time = Clock::now(); + cerr << "Reading target data took " << GetDuration(start_time, end_time) + << " seconds" << endl; + + start_time = Clock::now(); + cerr << "Reading source suffix array in binary format..." << endl; + shared_ptr source_suffix_array = make_shared(); + ifstream source_fstream(vm["source"].as()); + ar::binary_iarchive source_stream(source_fstream); + source_stream >> *source_suffix_array; + end_time = Clock::now(); + cerr << "Reading source suffix array took " + << GetDuration(start_time, end_time) << " seconds" << endl; + + start_time = Clock::now(); + cerr << "Reading alignment in binary format..." << endl; + shared_ptr alignment = make_shared(); + ifstream alignment_fstream(vm["alignment"].as()); + ar::binary_iarchive alignment_stream(alignment_fstream); + alignment_stream >> *alignment; + end_time = Clock::now(); + cerr << "Reading alignment took " << GetDuration(start_time, end_time) + << " seconds" << endl; + + start_time = Clock::now(); + cerr << "Reading precomputation in binary format..." << endl; + shared_ptr precomputation = make_shared(); + ifstream precomputation_fstream(vm["precomputation"].as()); + ar::binary_iarchive precomputation_stream(precomputation_fstream); + precomputation_stream >> *precomputation; + end_time = Clock::now(); + cerr << "Reading precomputation took " << GetDuration(start_time, end_time) + << " seconds" << endl; + + start_time = Clock::now(); + cerr << "Reading vocabulary in binary format..." << endl; + shared_ptr vocabulary = make_shared(); + ifstream vocabulary_fstream(vm["vocabulary"].as()); + ar::binary_iarchive vocabulary_stream(vocabulary_fstream); + vocabulary_stream >> *vocabulary; + end_time = Clock::now(); + cerr << "Reading vocabulary took " << GetDuration(start_time, end_time) + << " seconds" << endl; + + start_time = Clock::now(); + cerr << "Reading translation table in binary format..." << endl; + shared_ptr table = make_shared(); + ifstream ttable_fstream(vm["ttable"].as()); + ar::binary_iarchive ttable_stream(ttable_fstream); + ttable_stream >> *table; + end_time = Clock::now(); + cerr << "Reading translation table took " << GetDuration(start_time, end_time) + << " seconds" << endl; + + Clock::time_point read_end_time = Clock::now(); + cerr << "Total time spent loading data structures into memory: " + << GetDuration(read_start_time, read_end_time) << " seconds" << endl; + + Clock::time_point extraction_start_time = Clock::now(); + // Features used to score each grammar rule. + vector> features = { + make_shared(), + make_shared(), + make_shared(), + make_shared(table), + make_shared(table), + make_shared(), + make_shared() + }; + shared_ptr scorer = make_shared(features); + + GrammarExtractor extractor( + source_suffix_array, + target_data_array, + alignment, + precomputation, + scorer, + vocabulary, + vm["min_gap_size"].as(), + vm["max_rule_span"].as(), + vm["max_nonterminals"].as(), + vm["max_rule_symbols"].as(), + vm["max_samples"].as(), + vm["tight_phrases"].as()); + + // Creates the grammars directory if it doesn't exist. + fs::path grammar_path = vm["grammars"].as(); + if (!fs::is_directory(grammar_path)) { + fs::create_directory(grammar_path); + } + + // Reads all sentences for which we extract grammar rules (the paralellization + // is simplified if we read all sentences upfront). + string sentence; + vector sentences; + while (getline(cin, sentence)) { + sentences.push_back(sentence); + } + + // Extracts the grammar for each sentence and saves it to a file. + vector suffixes(sentences.size()); + bool leave_one_out = vm.count("leave_one_out"); + #pragma omp parallel for schedule(dynamic) num_threads(num_threads) + for (size_t i = 0; i < sentences.size(); ++i) { + string suffix; + int position = sentences[i].find("|||"); + if (position != sentences[i].npos) { + suffix = sentences[i].substr(position); + sentences[i] = sentences[i].substr(0, position); + } + suffixes[i] = suffix; + + unordered_set blacklisted_sentence_ids; + if (leave_one_out) { + blacklisted_sentence_ids.insert(i); + } + Grammar grammar = extractor.GetGrammar( + sentences[i], blacklisted_sentence_ids); + ofstream output(GetGrammarFilePath(grammar_path, i).c_str()); + // output << grammar; + } + + for (size_t i = 0; i < sentences.size(); ++i) { + cout << " " << sentences[i] << " " << suffixes[i] << endl; + } + + Clock::time_point extraction_stop_time = Clock::now(); + cerr << "Overall extraction step took " + << GetDuration(extraction_start_time, extraction_stop_time) + << " seconds" << endl; + + return 0; +} -- cgit v1.2.3 From 304103565d3b79cc9c98c1ee0356a8824fc982c2 Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Tue, 26 Nov 2013 16:03:16 +0000 Subject: Write config file after compiling data structures. --- extractor/compile.cc | 30 +++++++++++++++++++++++------- extractor/extract.cc | 4 ++-- 2 files changed, 25 insertions(+), 9 deletions(-) (limited to 'extractor/extract.cc') diff --git a/extractor/compile.cc b/extractor/compile.cc index 9e8044ad..3ee668ce 100644 --- a/extractor/compile.cc +++ b/extractor/compile.cc @@ -30,6 +30,8 @@ int main(int argc, char** argv) { ("bitext,b", po::value(), "Parallel text (source ||| target)") ("alignment,a", po::value()->required(), "Bitext word alignment") ("output,o", po::value()->required(), "Output path") + ("config,c", po::value()->required(), + "Path where the config file will be generated") ("frequent", po::value()->default_value(100), "Number of precomputed frequent patterns") ("super_frequent", po::value()->default_value(10), @@ -82,8 +84,12 @@ int main(int argc, char** argv) { target_data_array = make_shared(vm["target"].as()); } + ofstream config_stream(vm["config"].as()); + Clock::time_point start_write = Clock::now(); - ofstream target_fstream((output_dir / fs::path("target.bin")).string()); + string target_path = (output_dir / fs::path("target.bin")).string(); + config_stream << "target = " << target_path << endl; + ofstream target_fstream(target_path); ar::binary_oarchive target_stream(target_fstream); target_stream << *target_data_array; Clock::time_point stop_write = Clock::now(); @@ -100,7 +106,9 @@ int main(int argc, char** argv) { make_shared(source_data_array); start_write = Clock::now(); - ofstream source_fstream((output_dir / fs::path("source.bin")).string()); + string source_path = (output_dir / fs::path("source.bin")).string(); + config_stream << "source = " << source_path << endl; + ofstream source_fstream(source_path); ar::binary_oarchive output_stream(source_fstream); output_stream << *source_suffix_array; stop_write = Clock::now(); @@ -116,7 +124,9 @@ int main(int argc, char** argv) { make_shared(vm["alignment"].as()); start_write = Clock::now(); - ofstream alignment_fstream((output_dir / fs::path("alignment.bin")).string()); + string alignment_path = (output_dir / fs::path("alignment.bin")).string(); + config_stream << "alignment = " << alignment_path << endl; + ofstream alignment_fstream(alignment_path); ar::binary_oarchive alignment_stream(alignment_fstream); alignment_stream << *alignment; stop_write = Clock::now(); @@ -126,7 +136,7 @@ int main(int argc, char** argv) { cerr << "Reading alignment took " << GetDuration(start_time, stop_time) << " seconds" << endl; - shared_ptr vocabulary; + shared_ptr vocabulary = make_shared(); start_time = Clock::now(); cerr << "Precomputing collocations..." << endl; @@ -142,11 +152,15 @@ int main(int argc, char** argv) { vm["min_frequency"].as()); start_write = Clock::now(); - ofstream precomp_fstream((output_dir / fs::path("precomp.bin")).string()); + string precomputation_path = (output_dir / fs::path("precomp.bin")).string(); + config_stream << "precomputation = " << precomputation_path << endl; + ofstream precomp_fstream(precomputation_path); ar::binary_oarchive precomp_stream(precomp_fstream); precomp_stream << precomputation; - ofstream vocab_fstream((output_dir / fs::path("vocab.bin")).string()); + string vocabulary_path = (output_dir / fs::path("vocab.bin")).string(); + config_stream << "vocabulary = " << vocabulary_path << endl; + ofstream vocab_fstream(vocabulary_path); ar::binary_oarchive vocab_stream(vocab_fstream); vocab_stream << *vocabulary; stop_write = Clock::now(); @@ -161,7 +175,9 @@ int main(int argc, char** argv) { TranslationTable table(source_data_array, target_data_array, alignment); start_write = Clock::now(); - ofstream table_fstream((output_dir / fs::path("bilex.bin")).string()); + string table_path = (output_dir / fs::path("bilex.bin")).string(); + config_stream << "ttable = " << table_path << endl; + ofstream table_fstream(table_path); ar::binary_oarchive table_stream(table_fstream); table_stream << table; stop_write = Clock::now(); diff --git a/extractor/extract.cc b/extractor/extract.cc index 2d5831fa..387cbe9b 100644 --- a/extractor/extract.cc +++ b/extractor/extract.cc @@ -72,7 +72,7 @@ int main(int argc, char** argv) { po::options_description cmdline_options("Command line options"); cmdline_options.add_options() ("help", "Show available options") - ("config", po::value()->required(), "Path to config file"); + ("config,c", po::value()->required(), "Path to config file"); cmdline_options.add(general_options); po::options_description config_options("Config file options"); @@ -236,7 +236,7 @@ int main(int argc, char** argv) { Grammar grammar = extractor.GetGrammar( sentences[i], blacklisted_sentence_ids); ofstream output(GetGrammarFilePath(grammar_path, i).c_str()); - // output << grammar; + output << grammar; } for (size_t i = 0; i < sentences.size(); ++i) { -- cgit v1.2.3