diff options
| author | Paul Baltescu <pauldb89@gmail.com> | 2013-11-26 15:01:14 +0000 | 
|---|---|---|
| committer | Paul Baltescu <pauldb89@gmail.com> | 2013-11-26 15:01:14 +0000 | 
| commit | bed3e4b867e4132917fa0640956e8ce713f0e451 (patch) | |
| tree | a442aa2233c85313aebcf364ec0a0804922d2db7 /extractor | |
| parent | e633526bc2ba1f73e88989f495d70c0d2ec84a97 (diff) | |
Script for grammar extraction only.
Diffstat (limited to 'extractor')
| -rw-r--r-- | extractor/Makefile.am | 40 | ||||
| -rw-r--r-- | extractor/extract.cc | 253 | ||||
| -rw-r--r-- | extractor/grammar_extractor.h | 1 | ||||
| -rw-r--r-- | extractor/run_extractor.cc | 20 | ||||
| -rw-r--r-- | extractor/sampler.cc | 23 | 
5 files changed, 277 insertions, 60 deletions
| diff --git a/extractor/Makefile.am b/extractor/Makefile.am index 64a5a2b5..7825012c 100644 --- a/extractor/Makefile.am +++ b/extractor/Makefile.am @@ -1,5 +1,5 @@ -bin_PROGRAMS = compile run_extractor +bin_PROGRAMS = compile run_extractor extract  if HAVE_CXX11 @@ -105,44 +105,14 @@ translation_table_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $  vocabulary_test_SOURCES = vocabulary_test.cc  vocabulary_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a -noinst_LIBRARIES = libextractor.a libcompile.a +noinst_LIBRARIES = libextractor.a  compile_SOURCES = compile.cc -compile_LDADD = libcompile.a +compile_LDADD = libextractor.a  run_extractor_SOURCES = run_extractor.cc  run_extractor_LDADD = libextractor.a - -libcompile_a_SOURCES = \ -  alignment.cc \ -  data_array.cc \ -  phrase_location.cc \ -  precomputation.cc \ -  suffix_array.cc \ -  time_util.cc \ -  translation_table.cc \ -  vocabulary.cc \ -  alignment.h \ -  data_array.h \ -  fast_intersector.h \ -  grammar.h \ -  grammar_extractor.h \ -  matchings_finder.h \ -  matchings_trie.h \ -  phrase.h \ -  phrase_builder.h \ -  phrase_location.h \ -  precomputation.h \ -  rule.h \ -  rule_extractor.h \ -  rule_extractor_helper.h \ -  rule_factory.h \ -  sampler.h \ -  scorer.h \ -  suffix_array.h \ -  target_phrase_extractor.h \ -  time_util.h \ -  translation_table.h \ -  vocabulary.h +extract_SOURCES = extract.cc +extract_LDADD = libextractor.a  libextractor_a_SOURCES = \    alignment.cc \ diff --git a/extractor/extract.cc b/extractor/extract.cc new file mode 100644 index 00000000..2d5831fa --- /dev/null +++ b/extractor/extract.cc @@ -0,0 +1,253 @@ +#include <fstream> +#include <iostream> +#include <memory> +#include <string> +#include <vector> + +#include <boost/archive/binary_iarchive.hpp> +#include <boost/filesystem.hpp> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> +#include <omp.h> + +#include "alignment.h" +#include "data_array.h" +#include "features/count_source_target.h" +#include "features/feature.h" +#include "features/is_source_singleton.h" +#include "features/is_source_target_singleton.h" +#include "features/max_lex_source_given_target.h" +#include "features/max_lex_target_given_source.h" +#include "features/sample_source_count.h" +#include "features/target_given_source_coherent.h" +#include "grammar.h" +#include "grammar_extractor.h" +#include "precomputation.h" +#include "rule.h" +#include "scorer.h" +#include "suffix_array.h" +#include "time_util.h" +#include "translation_table.h" +#include "vocabulary.h" + +namespace ar = boost::archive; +namespace fs = boost::filesystem; +namespace po = boost::program_options; +using namespace extractor; +using namespace features; +using namespace std; + +// Returns the file path in which a given grammar should be written. +fs::path GetGrammarFilePath(const fs::path& grammar_path, int file_number) { +  string file_name = "grammar." + to_string(file_number); +  return grammar_path / file_name; +} + +int main(int argc, char** argv) { +  po::options_description general_options("General options"); +  int max_threads = 1; +  #pragma omp parallel +  max_threads = omp_get_num_threads(); +  string threads_option = "Number of threads used for grammar extraction " +                          "max(" + to_string(max_threads) + ")"; +  general_options.add_options() +    ("threads,t", po::value<int>()->required()->default_value(1), +     threads_option.c_str()) +    ("grammars,g", po::value<string>()->required(), "Grammars output path") +    ("max_rule_span", po::value<int>()->default_value(15), +        "Maximum rule span") +    ("max_rule_symbols", po::value<int>()->default_value(5), +        "Maximum number of symbols (terminals + nontermals) in a rule") +    ("min_gap_size", po::value<int>()->default_value(1), "Minimum gap size") +    ("max_nonterminals", po::value<int>()->default_value(2), +        "Maximum number of nonterminals in a rule") +    ("max_samples", po::value<int>()->default_value(300), +        "Maximum number of samples") +    ("tight_phrases", po::value<bool>()->default_value(true), +        "False if phrases may be loose (better, but slower)") +    ("leave_one_out", po::value<bool>()->zero_tokens(), +        "do leave-one-out estimation of grammars " +        "(e.g. for extracting grammars for the training set"); + +  po::options_description cmdline_options("Command line options"); +  cmdline_options.add_options() +    ("help", "Show available options") +    ("config", po::value<string>()->required(), "Path to config file"); +  cmdline_options.add(general_options); + +  po::options_description config_options("Config file options"); +  config_options.add_options() +    ("target", po::value<string>()->required(), +        "Path to target data file in binary format") +    ("source", po::value<string>()->required(), +        "Path to source suffix array file in binary format") +    ("alignment", po::value<string>()->required(), +        "Path to alignment file in binary format") +    ("precomputation", po::value<string>()->required(), +        "Path to precomputation file in binary format") +    ("vocabulary", po::value<string>()->required(), +        "Path to vocabulary file in binary format") +    ("ttable", po::value<string>()->required(), +        "Path to translation table in binary format"); +  config_options.add(general_options); + +  po::variables_map vm; +  po::store(po::parse_command_line(argc, argv, cmdline_options), vm); +  if (vm.count("help")) { +    po::options_description all_options; +    all_options.add(cmdline_options).add(config_options); +    cout << all_options << endl; +    return 0; +  } + +  po::notify(vm); + +  ifstream config_stream(vm["config"].as<string>()); +  po::store(po::parse_config_file(config_stream, config_options), vm); +  po::notify(vm); + +  int num_threads = vm["threads"].as<int>(); +  cerr << "Grammar extraction will use " << num_threads << " threads." << endl; + +  Clock::time_point read_start_time = Clock::now(); + +  Clock::time_point start_time = Clock::now(); +  cerr << "Reading target data in binary format..." << endl; +  shared_ptr<DataArray> target_data_array = make_shared<DataArray>(); +  ifstream target_fstream(vm["target"].as<string>()); +  ar::binary_iarchive target_stream(target_fstream); +  target_stream >> *target_data_array; +  Clock::time_point end_time = Clock::now(); +  cerr << "Reading target data took " << GetDuration(start_time, end_time) +       << " seconds" << endl; + +  start_time = Clock::now(); +  cerr << "Reading source suffix array in binary format..." << endl; +  shared_ptr<SuffixArray> source_suffix_array = make_shared<SuffixArray>(); +  ifstream source_fstream(vm["source"].as<string>()); +  ar::binary_iarchive source_stream(source_fstream); +  source_stream >> *source_suffix_array; +  end_time = Clock::now(); +  cerr << "Reading source suffix array took " +       << GetDuration(start_time, end_time) << " seconds" << endl; + +  start_time = Clock::now(); +  cerr << "Reading alignment in binary format..." << endl; +  shared_ptr<Alignment> alignment = make_shared<Alignment>(); +  ifstream alignment_fstream(vm["alignment"].as<string>()); +  ar::binary_iarchive alignment_stream(alignment_fstream); +  alignment_stream >> *alignment; +  end_time = Clock::now(); +  cerr << "Reading alignment took " << GetDuration(start_time, end_time) +       << " seconds" << endl; + +  start_time = Clock::now(); +  cerr << "Reading precomputation in binary format..." << endl; +  shared_ptr<Precomputation> precomputation = make_shared<Precomputation>(); +  ifstream precomputation_fstream(vm["precomputation"].as<string>()); +  ar::binary_iarchive precomputation_stream(precomputation_fstream); +  precomputation_stream >> *precomputation; +  end_time = Clock::now(); +  cerr << "Reading precomputation took " << GetDuration(start_time, end_time) +       << " seconds" << endl; + +  start_time = Clock::now(); +  cerr << "Reading vocabulary in binary format..." << endl; +  shared_ptr<Vocabulary> vocabulary = make_shared<Vocabulary>(); +  ifstream vocabulary_fstream(vm["vocabulary"].as<string>()); +  ar::binary_iarchive vocabulary_stream(vocabulary_fstream); +  vocabulary_stream >> *vocabulary; +  end_time = Clock::now(); +  cerr << "Reading vocabulary took " << GetDuration(start_time, end_time) +       << " seconds" << endl; + +  start_time = Clock::now(); +  cerr << "Reading translation table in binary format..." << endl; +  shared_ptr<TranslationTable> table = make_shared<TranslationTable>(); +  ifstream ttable_fstream(vm["ttable"].as<string>()); +  ar::binary_iarchive ttable_stream(ttable_fstream); +  ttable_stream >> *table; +  end_time = Clock::now(); +  cerr << "Reading translation table took " << GetDuration(start_time, end_time) +       << " seconds" << endl; + +  Clock::time_point read_end_time = Clock::now(); +  cerr << "Total time spent loading data structures into memory: " +       << GetDuration(read_start_time, read_end_time) << " seconds" << endl; + +  Clock::time_point extraction_start_time = Clock::now(); +  // Features used to score each grammar rule. +  vector<shared_ptr<Feature>> features = { +      make_shared<TargetGivenSourceCoherent>(), +      make_shared<SampleSourceCount>(), +      make_shared<CountSourceTarget>(), +      make_shared<MaxLexSourceGivenTarget>(table), +      make_shared<MaxLexTargetGivenSource>(table), +      make_shared<IsSourceSingleton>(), +      make_shared<IsSourceTargetSingleton>() +  }; +  shared_ptr<Scorer> scorer = make_shared<Scorer>(features); + +  GrammarExtractor extractor( +      source_suffix_array, +      target_data_array, +      alignment, +      precomputation, +      scorer, +      vocabulary, +      vm["min_gap_size"].as<int>(), +      vm["max_rule_span"].as<int>(), +      vm["max_nonterminals"].as<int>(), +      vm["max_rule_symbols"].as<int>(), +      vm["max_samples"].as<int>(), +      vm["tight_phrases"].as<bool>()); + +  // Creates the grammars directory if it doesn't exist. +  fs::path grammar_path = vm["grammars"].as<string>(); +  if (!fs::is_directory(grammar_path)) { +    fs::create_directory(grammar_path); +  } + +  // Reads all sentences for which we extract grammar rules (the paralellization +  // is simplified if we read all sentences upfront). +  string sentence; +  vector<string> sentences; +  while (getline(cin, sentence)) { +    sentences.push_back(sentence); +  } + +  // Extracts the grammar for each sentence and saves it to a file. +  vector<string> suffixes(sentences.size()); +  bool leave_one_out = vm.count("leave_one_out"); +  #pragma omp parallel for schedule(dynamic) num_threads(num_threads) +  for (size_t i = 0; i < sentences.size(); ++i) { +    string suffix; +    int position = sentences[i].find("|||"); +    if (position != sentences[i].npos) { +      suffix = sentences[i].substr(position); +      sentences[i] = sentences[i].substr(0, position); +    } +    suffixes[i] = suffix; + +    unordered_set<int> blacklisted_sentence_ids; +    if (leave_one_out) { +      blacklisted_sentence_ids.insert(i); +    } +    Grammar grammar = extractor.GetGrammar( +        sentences[i], blacklisted_sentence_ids); +    ofstream output(GetGrammarFilePath(grammar_path, i).c_str()); +    // output << grammar; +  } + +  for (size_t i = 0; i < sentences.size(); ++i) { +    cout << "<seg grammar=" << GetGrammarFilePath(grammar_path, i) << " id=\"" +         << i << "\"> " << sentences[i] << " </seg> " << suffixes[i] << endl; +  } + +  Clock::time_point extraction_stop_time = Clock::now(); +  cerr << "Overall extraction step took " +       << GetDuration(extraction_start_time, extraction_stop_time) +       << " seconds" << endl; + +  return 0; +} diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h index eb79f53c..0f3069b0 100644 --- a/extractor/grammar_extractor.h +++ b/extractor/grammar_extractor.h @@ -15,7 +15,6 @@ class DataArray;  class Grammar;  class HieroCachingRuleFactory;  class Precomputation; -class Rule;  class Scorer;  class SuffixArray;  class Vocabulary; diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index 6b22a302..f1aa5e35 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -5,10 +5,10 @@  #include <string>  #include <vector> -#include <omp.h>  #include <boost/filesystem.hpp>  #include <boost/program_options.hpp>  #include <boost/program_options/variables_map.hpp> +#include <omp.h>  #include "alignment.h"  #include "data_array.h" @@ -78,7 +78,8 @@ int main(int argc, char** argv) {      ("tight_phrases", po::value<bool>()->default_value(true),          "False if phrases may be loose (better, but slower)")      ("leave_one_out", po::value<bool>()->zero_tokens(), -        "do leave-one-out estimation of grammars (e.g. for extracting grammars for the training set"); +        "do leave-one-out estimation of grammars " +        "(e.g. for extracting grammars for the training set");    po::variables_map vm;    po::store(po::parse_command_line(argc, argv, desc), vm); @@ -99,11 +100,6 @@ int main(int argc, char** argv) {      return 1;    } -  bool leave_one_out = false; -  if (vm.count("leave_one_out")) { -    leave_one_out = true; -  } -    int num_threads = vm["threads"].as<int>();    cerr << "Grammar extraction will use " << num_threads << " threads." << endl; @@ -178,8 +174,8 @@ int main(int argc, char** argv) {         << GetDuration(preprocess_start_time, preprocess_stop_time)         << " seconds" << endl; -  // Features used to score each grammar rule.    Clock::time_point extraction_start_time = Clock::now(); +  // Features used to score each grammar rule.    vector<shared_ptr<Feature>> features = {        make_shared<TargetGivenSourceCoherent>(),        make_shared<SampleSourceCount>(), @@ -206,9 +202,6 @@ int main(int argc, char** argv) {        vm["max_samples"].as<int>(),        vm["tight_phrases"].as<bool>()); -  // Releases extra memory used by the initial precomputation. -  precomputation.reset(); -    // Creates the grammars directory if it doesn't exist.    fs::path grammar_path = vm["grammars"].as<string>();    if (!fs::is_directory(grammar_path)) { @@ -224,6 +217,7 @@ int main(int argc, char** argv) {    }    // Extracts the grammar for each sentence and saves it to a file. +  bool leave_one_out = vm.count("leave_one_out");    vector<string> suffixes(sentences.size());    #pragma omp parallel for schedule(dynamic) num_threads(num_threads)    for (size_t i = 0; i < sentences.size(); ++i) { @@ -236,7 +230,9 @@ int main(int argc, char** argv) {      suffixes[i] = suffix;      unordered_set<int> blacklisted_sentence_ids; -    if (leave_one_out) blacklisted_sentence_ids.insert(i); +    if (leave_one_out) { +      blacklisted_sentence_ids.insert(i); +    }      Grammar grammar = extractor.GetGrammar(          sentences[i], blacklisted_sentence_ids);      ofstream output(GetGrammarFilePath(grammar_path, i).c_str()); diff --git a/extractor/sampler.cc b/extractor/sampler.cc index fc386ed1..887aaec1 100644 --- a/extractor/sampler.cc +++ b/extractor/sampler.cc @@ -15,6 +15,7 @@ Sampler::~Sampler() {}  PhraseLocation Sampler::Sample(      const PhraseLocation& location,      const unordered_set<int>& blacklisted_sentence_ids) const { +  shared_ptr<DataArray> source_data_array = suffix_array->GetData();    vector<int> sample;    int num_subpatterns;    if (location.matchings == NULL) { @@ -22,32 +23,30 @@ PhraseLocation Sampler::Sample(      num_subpatterns = 1;      int low = location.sa_low, high = location.sa_high;      double step = max(1.0, (double) (high - low) / max_samples); -    double i = low, last = i; -    bool found; -    shared_ptr<DataArray> source_data_array = suffix_array->GetData(); +    double i = low, last = i - 1;      while (sample.size() < max_samples && i < high) {        int x = suffix_array->GetSuffix(Round(i));        int id = source_data_array->GetSentenceId(x); +      bool found = false;        if (blacklisted_sentence_ids.count(id)) { -        found = false; -        double backoff_step = 1; -        while (true) { -          if ((double)backoff_step >= step) break; +        for (int backoff_step = 1; backoff_step <= step; ++backoff_step) {            double j = i - backoff_step;            x = suffix_array->GetSuffix(Round(j));            id = source_data_array->GetSentenceId(x);            if (x >= 0 && j > last && !blacklisted_sentence_ids.count(id)) { -            found = true; last = i; break; +            found = true; +            last = i; +            break;            }            double k = i + backoff_step;            x = suffix_array->GetSuffix(Round(k));            id = source_data_array->GetSentenceId(x); -          if (k < min(i+step, (double)high) && +          if (k < min(i+step, (double) high) &&                !blacklisted_sentence_ids.count(id)) { -            found = true; last = k; break; +            found = true; +            last = k; +            break;            } -          if (j <= last && k >= high) break; -          backoff_step++;          }        } else {          found = true; | 
