summaryrefslogtreecommitdiff
path: root/extractor
diff options
context:
space:
mode:
authorPaul Baltescu <pauldb89@gmail.com>2013-11-26 15:01:14 +0000
committerPaul Baltescu <pauldb89@gmail.com>2013-11-26 15:01:14 +0000
commita3826db61847a55f59bb9666f61fd1bb88888085 (patch)
tree022475bafbf71ba6aaeb98efdbafcde24f7e60a5 /extractor
parent1cd86c44e1799c441cdcda2a022be0ee6e52d38c (diff)
Script for grammar extraction only.
Diffstat (limited to 'extractor')
-rw-r--r--extractor/Makefile.am40
-rw-r--r--extractor/extract.cc253
-rw-r--r--extractor/grammar_extractor.h1
-rw-r--r--extractor/run_extractor.cc20
-rw-r--r--extractor/sampler.cc23
5 files changed, 277 insertions, 60 deletions
diff --git a/extractor/Makefile.am b/extractor/Makefile.am
index 64a5a2b5..7825012c 100644
--- a/extractor/Makefile.am
+++ b/extractor/Makefile.am
@@ -1,5 +1,5 @@
-bin_PROGRAMS = compile run_extractor
+bin_PROGRAMS = compile run_extractor extract
if HAVE_CXX11
@@ -105,44 +105,14 @@ translation_table_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $
vocabulary_test_SOURCES = vocabulary_test.cc
vocabulary_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
-noinst_LIBRARIES = libextractor.a libcompile.a
+noinst_LIBRARIES = libextractor.a
compile_SOURCES = compile.cc
-compile_LDADD = libcompile.a
+compile_LDADD = libextractor.a
run_extractor_SOURCES = run_extractor.cc
run_extractor_LDADD = libextractor.a
-
-libcompile_a_SOURCES = \
- alignment.cc \
- data_array.cc \
- phrase_location.cc \
- precomputation.cc \
- suffix_array.cc \
- time_util.cc \
- translation_table.cc \
- vocabulary.cc \
- alignment.h \
- data_array.h \
- fast_intersector.h \
- grammar.h \
- grammar_extractor.h \
- matchings_finder.h \
- matchings_trie.h \
- phrase.h \
- phrase_builder.h \
- phrase_location.h \
- precomputation.h \
- rule.h \
- rule_extractor.h \
- rule_extractor_helper.h \
- rule_factory.h \
- sampler.h \
- scorer.h \
- suffix_array.h \
- target_phrase_extractor.h \
- time_util.h \
- translation_table.h \
- vocabulary.h
+extract_SOURCES = extract.cc
+extract_LDADD = libextractor.a
libextractor_a_SOURCES = \
alignment.cc \
diff --git a/extractor/extract.cc b/extractor/extract.cc
new file mode 100644
index 00000000..2d5831fa
--- /dev/null
+++ b/extractor/extract.cc
@@ -0,0 +1,253 @@
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <boost/archive/binary_iarchive.hpp>
+#include <boost/filesystem.hpp>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+#include <omp.h>
+
+#include "alignment.h"
+#include "data_array.h"
+#include "features/count_source_target.h"
+#include "features/feature.h"
+#include "features/is_source_singleton.h"
+#include "features/is_source_target_singleton.h"
+#include "features/max_lex_source_given_target.h"
+#include "features/max_lex_target_given_source.h"
+#include "features/sample_source_count.h"
+#include "features/target_given_source_coherent.h"
+#include "grammar.h"
+#include "grammar_extractor.h"
+#include "precomputation.h"
+#include "rule.h"
+#include "scorer.h"
+#include "suffix_array.h"
+#include "time_util.h"
+#include "translation_table.h"
+#include "vocabulary.h"
+
+namespace ar = boost::archive;
+namespace fs = boost::filesystem;
+namespace po = boost::program_options;
+using namespace extractor;
+using namespace features;
+using namespace std;
+
+// Returns the file path in which a given grammar should be written.
+fs::path GetGrammarFilePath(const fs::path& grammar_path, int file_number) {
+ string file_name = "grammar." + to_string(file_number);
+ return grammar_path / file_name;
+}
+
+int main(int argc, char** argv) {
+ po::options_description general_options("General options");
+ int max_threads = 1;
+ #pragma omp parallel
+ max_threads = omp_get_num_threads();
+ string threads_option = "Number of threads used for grammar extraction "
+ "max(" + to_string(max_threads) + ")";
+ general_options.add_options()
+ ("threads,t", po::value<int>()->required()->default_value(1),
+ threads_option.c_str())
+ ("grammars,g", po::value<string>()->required(), "Grammars output path")
+ ("max_rule_span", po::value<int>()->default_value(15),
+ "Maximum rule span")
+ ("max_rule_symbols", po::value<int>()->default_value(5),
+ "Maximum number of symbols (terminals + nontermals) in a rule")
+ ("min_gap_size", po::value<int>()->default_value(1), "Minimum gap size")
+ ("max_nonterminals", po::value<int>()->default_value(2),
+ "Maximum number of nonterminals in a rule")
+ ("max_samples", po::value<int>()->default_value(300),
+ "Maximum number of samples")
+ ("tight_phrases", po::value<bool>()->default_value(true),
+ "False if phrases may be loose (better, but slower)")
+ ("leave_one_out", po::value<bool>()->zero_tokens(),
+ "do leave-one-out estimation of grammars "
+ "(e.g. for extracting grammars for the training set");
+
+ po::options_description cmdline_options("Command line options");
+ cmdline_options.add_options()
+ ("help", "Show available options")
+ ("config", po::value<string>()->required(), "Path to config file");
+ cmdline_options.add(general_options);
+
+ po::options_description config_options("Config file options");
+ config_options.add_options()
+ ("target", po::value<string>()->required(),
+ "Path to target data file in binary format")
+ ("source", po::value<string>()->required(),
+ "Path to source suffix array file in binary format")
+ ("alignment", po::value<string>()->required(),
+ "Path to alignment file in binary format")
+ ("precomputation", po::value<string>()->required(),
+ "Path to precomputation file in binary format")
+ ("vocabulary", po::value<string>()->required(),
+ "Path to vocabulary file in binary format")
+ ("ttable", po::value<string>()->required(),
+ "Path to translation table in binary format");
+ config_options.add(general_options);
+
+ po::variables_map vm;
+ po::store(po::parse_command_line(argc, argv, cmdline_options), vm);
+ if (vm.count("help")) {
+ po::options_description all_options;
+ all_options.add(cmdline_options).add(config_options);
+ cout << all_options << endl;
+ return 0;
+ }
+
+ po::notify(vm);
+
+ ifstream config_stream(vm["config"].as<string>());
+ po::store(po::parse_config_file(config_stream, config_options), vm);
+ po::notify(vm);
+
+ int num_threads = vm["threads"].as<int>();
+ cerr << "Grammar extraction will use " << num_threads << " threads." << endl;
+
+ Clock::time_point read_start_time = Clock::now();
+
+ Clock::time_point start_time = Clock::now();
+ cerr << "Reading target data in binary format..." << endl;
+ shared_ptr<DataArray> target_data_array = make_shared<DataArray>();
+ ifstream target_fstream(vm["target"].as<string>());
+ ar::binary_iarchive target_stream(target_fstream);
+ target_stream >> *target_data_array;
+ Clock::time_point end_time = Clock::now();
+ cerr << "Reading target data took " << GetDuration(start_time, end_time)
+ << " seconds" << endl;
+
+ start_time = Clock::now();
+ cerr << "Reading source suffix array in binary format..." << endl;
+ shared_ptr<SuffixArray> source_suffix_array = make_shared<SuffixArray>();
+ ifstream source_fstream(vm["source"].as<string>());
+ ar::binary_iarchive source_stream(source_fstream);
+ source_stream >> *source_suffix_array;
+ end_time = Clock::now();
+ cerr << "Reading source suffix array took "
+ << GetDuration(start_time, end_time) << " seconds" << endl;
+
+ start_time = Clock::now();
+ cerr << "Reading alignment in binary format..." << endl;
+ shared_ptr<Alignment> alignment = make_shared<Alignment>();
+ ifstream alignment_fstream(vm["alignment"].as<string>());
+ ar::binary_iarchive alignment_stream(alignment_fstream);
+ alignment_stream >> *alignment;
+ end_time = Clock::now();
+ cerr << "Reading alignment took " << GetDuration(start_time, end_time)
+ << " seconds" << endl;
+
+ start_time = Clock::now();
+ cerr << "Reading precomputation in binary format..." << endl;
+ shared_ptr<Precomputation> precomputation = make_shared<Precomputation>();
+ ifstream precomputation_fstream(vm["precomputation"].as<string>());
+ ar::binary_iarchive precomputation_stream(precomputation_fstream);
+ precomputation_stream >> *precomputation;
+ end_time = Clock::now();
+ cerr << "Reading precomputation took " << GetDuration(start_time, end_time)
+ << " seconds" << endl;
+
+ start_time = Clock::now();
+ cerr << "Reading vocabulary in binary format..." << endl;
+ shared_ptr<Vocabulary> vocabulary = make_shared<Vocabulary>();
+ ifstream vocabulary_fstream(vm["vocabulary"].as<string>());
+ ar::binary_iarchive vocabulary_stream(vocabulary_fstream);
+ vocabulary_stream >> *vocabulary;
+ end_time = Clock::now();
+ cerr << "Reading vocabulary took " << GetDuration(start_time, end_time)
+ << " seconds" << endl;
+
+ start_time = Clock::now();
+ cerr << "Reading translation table in binary format..." << endl;
+ shared_ptr<TranslationTable> table = make_shared<TranslationTable>();
+ ifstream ttable_fstream(vm["ttable"].as<string>());
+ ar::binary_iarchive ttable_stream(ttable_fstream);
+ ttable_stream >> *table;
+ end_time = Clock::now();
+ cerr << "Reading translation table took " << GetDuration(start_time, end_time)
+ << " seconds" << endl;
+
+ Clock::time_point read_end_time = Clock::now();
+ cerr << "Total time spent loading data structures into memory: "
+ << GetDuration(read_start_time, read_end_time) << " seconds" << endl;
+
+ Clock::time_point extraction_start_time = Clock::now();
+ // Features used to score each grammar rule.
+ vector<shared_ptr<Feature>> features = {
+ make_shared<TargetGivenSourceCoherent>(),
+ make_shared<SampleSourceCount>(),
+ make_shared<CountSourceTarget>(),
+ make_shared<MaxLexSourceGivenTarget>(table),
+ make_shared<MaxLexTargetGivenSource>(table),
+ make_shared<IsSourceSingleton>(),
+ make_shared<IsSourceTargetSingleton>()
+ };
+ shared_ptr<Scorer> scorer = make_shared<Scorer>(features);
+
+ GrammarExtractor extractor(
+ source_suffix_array,
+ target_data_array,
+ alignment,
+ precomputation,
+ scorer,
+ vocabulary,
+ vm["min_gap_size"].as<int>(),
+ vm["max_rule_span"].as<int>(),
+ vm["max_nonterminals"].as<int>(),
+ vm["max_rule_symbols"].as<int>(),
+ vm["max_samples"].as<int>(),
+ vm["tight_phrases"].as<bool>());
+
+ // Creates the grammars directory if it doesn't exist.
+ fs::path grammar_path = vm["grammars"].as<string>();
+ if (!fs::is_directory(grammar_path)) {
+ fs::create_directory(grammar_path);
+ }
+
+ // Reads all sentences for which we extract grammar rules (the paralellization
+ // is simplified if we read all sentences upfront).
+ string sentence;
+ vector<string> sentences;
+ while (getline(cin, sentence)) {
+ sentences.push_back(sentence);
+ }
+
+ // Extracts the grammar for each sentence and saves it to a file.
+ vector<string> suffixes(sentences.size());
+ bool leave_one_out = vm.count("leave_one_out");
+ #pragma omp parallel for schedule(dynamic) num_threads(num_threads)
+ for (size_t i = 0; i < sentences.size(); ++i) {
+ string suffix;
+ int position = sentences[i].find("|||");
+ if (position != sentences[i].npos) {
+ suffix = sentences[i].substr(position);
+ sentences[i] = sentences[i].substr(0, position);
+ }
+ suffixes[i] = suffix;
+
+ unordered_set<int> blacklisted_sentence_ids;
+ if (leave_one_out) {
+ blacklisted_sentence_ids.insert(i);
+ }
+ Grammar grammar = extractor.GetGrammar(
+ sentences[i], blacklisted_sentence_ids);
+ ofstream output(GetGrammarFilePath(grammar_path, i).c_str());
+ // output << grammar;
+ }
+
+ for (size_t i = 0; i < sentences.size(); ++i) {
+ cout << "<seg grammar=" << GetGrammarFilePath(grammar_path, i) << " id=\""
+ << i << "\"> " << sentences[i] << " </seg> " << suffixes[i] << endl;
+ }
+
+ Clock::time_point extraction_stop_time = Clock::now();
+ cerr << "Overall extraction step took "
+ << GetDuration(extraction_start_time, extraction_stop_time)
+ << " seconds" << endl;
+
+ return 0;
+}
diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h
index eb79f53c..0f3069b0 100644
--- a/extractor/grammar_extractor.h
+++ b/extractor/grammar_extractor.h
@@ -15,7 +15,6 @@ class DataArray;
class Grammar;
class HieroCachingRuleFactory;
class Precomputation;
-class Rule;
class Scorer;
class SuffixArray;
class Vocabulary;
diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc
index 6b22a302..f1aa5e35 100644
--- a/extractor/run_extractor.cc
+++ b/extractor/run_extractor.cc
@@ -5,10 +5,10 @@
#include <string>
#include <vector>
-#include <omp.h>
#include <boost/filesystem.hpp>
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
+#include <omp.h>
#include "alignment.h"
#include "data_array.h"
@@ -78,7 +78,8 @@ int main(int argc, char** argv) {
("tight_phrases", po::value<bool>()->default_value(true),
"False if phrases may be loose (better, but slower)")
("leave_one_out", po::value<bool>()->zero_tokens(),
- "do leave-one-out estimation of grammars (e.g. for extracting grammars for the training set");
+ "do leave-one-out estimation of grammars "
+ "(e.g. for extracting grammars for the training set");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
@@ -99,11 +100,6 @@ int main(int argc, char** argv) {
return 1;
}
- bool leave_one_out = false;
- if (vm.count("leave_one_out")) {
- leave_one_out = true;
- }
-
int num_threads = vm["threads"].as<int>();
cerr << "Grammar extraction will use " << num_threads << " threads." << endl;
@@ -178,8 +174,8 @@ int main(int argc, char** argv) {
<< GetDuration(preprocess_start_time, preprocess_stop_time)
<< " seconds" << endl;
- // Features used to score each grammar rule.
Clock::time_point extraction_start_time = Clock::now();
+ // Features used to score each grammar rule.
vector<shared_ptr<Feature>> features = {
make_shared<TargetGivenSourceCoherent>(),
make_shared<SampleSourceCount>(),
@@ -206,9 +202,6 @@ int main(int argc, char** argv) {
vm["max_samples"].as<int>(),
vm["tight_phrases"].as<bool>());
- // Releases extra memory used by the initial precomputation.
- precomputation.reset();
-
// Creates the grammars directory if it doesn't exist.
fs::path grammar_path = vm["grammars"].as<string>();
if (!fs::is_directory(grammar_path)) {
@@ -224,6 +217,7 @@ int main(int argc, char** argv) {
}
// Extracts the grammar for each sentence and saves it to a file.
+ bool leave_one_out = vm.count("leave_one_out");
vector<string> suffixes(sentences.size());
#pragma omp parallel for schedule(dynamic) num_threads(num_threads)
for (size_t i = 0; i < sentences.size(); ++i) {
@@ -236,7 +230,9 @@ int main(int argc, char** argv) {
suffixes[i] = suffix;
unordered_set<int> blacklisted_sentence_ids;
- if (leave_one_out) blacklisted_sentence_ids.insert(i);
+ if (leave_one_out) {
+ blacklisted_sentence_ids.insert(i);
+ }
Grammar grammar = extractor.GetGrammar(
sentences[i], blacklisted_sentence_ids);
ofstream output(GetGrammarFilePath(grammar_path, i).c_str());
diff --git a/extractor/sampler.cc b/extractor/sampler.cc
index fc386ed1..887aaec1 100644
--- a/extractor/sampler.cc
+++ b/extractor/sampler.cc
@@ -15,6 +15,7 @@ Sampler::~Sampler() {}
PhraseLocation Sampler::Sample(
const PhraseLocation& location,
const unordered_set<int>& blacklisted_sentence_ids) const {
+ shared_ptr<DataArray> source_data_array = suffix_array->GetData();
vector<int> sample;
int num_subpatterns;
if (location.matchings == NULL) {
@@ -22,32 +23,30 @@ PhraseLocation Sampler::Sample(
num_subpatterns = 1;
int low = location.sa_low, high = location.sa_high;
double step = max(1.0, (double) (high - low) / max_samples);
- double i = low, last = i;
- bool found;
- shared_ptr<DataArray> source_data_array = suffix_array->GetData();
+ double i = low, last = i - 1;
while (sample.size() < max_samples && i < high) {
int x = suffix_array->GetSuffix(Round(i));
int id = source_data_array->GetSentenceId(x);
+ bool found = false;
if (blacklisted_sentence_ids.count(id)) {
- found = false;
- double backoff_step = 1;
- while (true) {
- if ((double)backoff_step >= step) break;
+ for (int backoff_step = 1; backoff_step <= step; ++backoff_step) {
double j = i - backoff_step;
x = suffix_array->GetSuffix(Round(j));
id = source_data_array->GetSentenceId(x);
if (x >= 0 && j > last && !blacklisted_sentence_ids.count(id)) {
- found = true; last = i; break;
+ found = true;
+ last = i;
+ break;
}
double k = i + backoff_step;
x = suffix_array->GetSuffix(Round(k));
id = source_data_array->GetSentenceId(x);
- if (k < min(i+step, (double)high) &&
+ if (k < min(i+step, (double) high) &&
!blacklisted_sentence_ids.count(id)) {
- found = true; last = k; break;
+ found = true;
+ last = k;
+ break;
}
- if (j <= last && k >= high) break;
- backoff_step++;
}
} else {
found = true;