summaryrefslogtreecommitdiff
path: root/extractor/run_extractor.cc
diff options
context:
space:
mode:
authorChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2013-04-23 19:35:18 -0400
committerChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2013-04-23 19:35:18 -0400
commit6d347f1ce078dede3da0e1498f75e357351c6543 (patch)
tree8e872b8747c530e741e55e25e9917c1bd8b32c5b /extractor/run_extractor.cc
parentd11b76def6899790161c47a73018146311356d8b (diff)
parent5e9605b65202f4e5fc59843b197d88c4774f0ac8 (diff)
merge paul's extractor code
Diffstat (limited to 'extractor/run_extractor.cc')
-rw-r--r--extractor/run_extractor.cc242
1 files changed, 242 insertions, 0 deletions
diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc
new file mode 100644
index 00000000..aec83e3b
--- /dev/null
+++ b/extractor/run_extractor.cc
@@ -0,0 +1,242 @@
+#include <chrono>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <omp.h>
+#include <boost/filesystem.hpp>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "alignment.h"
+#include "data_array.h"
+#include "features/count_source_target.h"
+#include "features/feature.h"
+#include "features/is_source_singleton.h"
+#include "features/is_source_target_singleton.h"
+#include "features/max_lex_source_given_target.h"
+#include "features/max_lex_target_given_source.h"
+#include "features/sample_source_count.h"
+#include "features/target_given_source_coherent.h"
+#include "grammar.h"
+#include "grammar_extractor.h"
+#include "precomputation.h"
+#include "rule.h"
+#include "scorer.h"
+#include "suffix_array.h"
+#include "time_util.h"
+#include "translation_table.h"
+
+namespace fs = boost::filesystem;
+namespace po = boost::program_options;
+using namespace std;
+using namespace extractor;
+using namespace features;
+
+// Returns the file path in which a given grammar should be written.
+fs::path GetGrammarFilePath(const fs::path& grammar_path, int file_number) {
+ string file_name = "grammar." + to_string(file_number);
+ return grammar_path / file_name;
+}
+
+int main(int argc, char** argv) {
+ int num_threads_default = 1;
+ #pragma omp parallel
+ num_threads_default = omp_get_num_threads();
+
+ // Sets up the command line arguments map.
+ po::options_description desc("Command line options");
+ desc.add_options()
+ ("help,h", "Show available options")
+ ("source,f", po::value<string>(), "Source language corpus")
+ ("target,e", po::value<string>(), "Target language corpus")
+ ("bitext,b", po::value<string>(), "Parallel text (source ||| target)")
+ ("alignment,a", po::value<string>()->required(), "Bitext word alignment")
+ ("grammars,g", po::value<string>()->required(), "Grammars output path")
+ ("threads,t", po::value<int>()->default_value(num_threads_default),
+ "Number of parallel extractors")
+ ("frequent", po::value<int>()->default_value(100),
+ "Number of precomputed frequent patterns")
+ ("super_frequent", po::value<int>()->default_value(10),
+ "Number of precomputed super frequent patterns")
+ ("max_rule_span", po::value<int>()->default_value(15),
+ "Maximum rule span")
+ ("max_rule_symbols", po::value<int>()->default_value(5),
+ "Maximum number of symbols (terminals + nontermals) in a rule")
+ ("min_gap_size", po::value<int>()->default_value(1), "Minimum gap size")
+ ("max_phrase_len", po::value<int>()->default_value(4),
+ "Maximum frequent phrase length")
+ ("max_nonterminals", po::value<int>()->default_value(2),
+ "Maximum number of nonterminals in a rule")
+ ("min_frequency", po::value<int>()->default_value(1000),
+ "Minimum number of occurrences for a pharse to be considered frequent")
+ ("max_samples", po::value<int>()->default_value(300),
+ "Maximum number of samples")
+ ("tight_phrases", po::value<bool>()->default_value(true),
+ "False if phrases may be loose (better, but slower)");
+
+ po::variables_map vm;
+ po::store(po::parse_command_line(argc, argv, desc), vm);
+
+ // Checks for the help option before calling notify, so the we don't get an
+ // exception for missing required arguments.
+ if (vm.count("help")) {
+ cout << desc << endl;
+ return 0;
+ }
+
+ po::notify(vm);
+
+ if (!((vm.count("source") && vm.count("target")) || vm.count("bitext"))) {
+ cerr << "A paralel corpus is required. "
+ << "Use -f (source) with -e (target) or -b (bitext)."
+ << endl;
+ return 1;
+ }
+
+ int num_threads = vm["threads"].as<int>();
+ cout << "Grammar extraction will use " << num_threads << " threads." << endl;
+
+ // Reads the parallel corpus.
+ Clock::time_point preprocess_start_time = Clock::now();
+ cerr << "Reading source and target data..." << endl;
+ Clock::time_point start_time = Clock::now();
+ shared_ptr<DataArray> source_data_array, target_data_array;
+ if (vm.count("bitext")) {
+ source_data_array = make_shared<DataArray>(
+ vm["bitext"].as<string>(), SOURCE);
+ target_data_array = make_shared<DataArray>(
+ vm["bitext"].as<string>(), TARGET);
+ } else {
+ source_data_array = make_shared<DataArray>(vm["source"].as<string>());
+ target_data_array = make_shared<DataArray>(vm["target"].as<string>());
+ }
+ Clock::time_point stop_time = Clock::now();
+ cerr << "Reading data took " << GetDuration(start_time, stop_time)
+ << " seconds" << endl;
+
+ // Constructs the suffix array for the source data.
+ cerr << "Creating source suffix array..." << endl;
+ start_time = Clock::now();
+ shared_ptr<SuffixArray> source_suffix_array =
+ make_shared<SuffixArray>(source_data_array);
+ stop_time = Clock::now();
+ cerr << "Creating suffix array took "
+ << GetDuration(start_time, stop_time) << " seconds" << endl;
+
+ // Reads the alignment.
+ cerr << "Reading alignment..." << endl;
+ start_time = Clock::now();
+ shared_ptr<Alignment> alignment =
+ make_shared<Alignment>(vm["alignment"].as<string>());
+ stop_time = Clock::now();
+ cerr << "Reading alignment took "
+ << GetDuration(start_time, stop_time) << " seconds" << endl;
+
+ // Constructs an index storing the occurrences in the source data for each
+ // frequent collocation.
+ cerr << "Precomputing collocations..." << endl;
+ start_time = Clock::now();
+ shared_ptr<Precomputation> precomputation = make_shared<Precomputation>(
+ source_suffix_array,
+ vm["frequent"].as<int>(),
+ vm["super_frequent"].as<int>(),
+ vm["max_rule_span"].as<int>(),
+ vm["max_rule_symbols"].as<int>(),
+ vm["min_gap_size"].as<int>(),
+ vm["max_phrase_len"].as<int>(),
+ vm["min_frequency"].as<int>());
+ stop_time = Clock::now();
+ cerr << "Precomputing collocations took "
+ << GetDuration(start_time, stop_time) << " seconds" << endl;
+
+ // Constructs a table storing p(e | f) and p(f | e) for every pair of source
+ // and target words.
+ cerr << "Precomputing conditional probabilities..." << endl;
+ start_time = Clock::now();
+ shared_ptr<TranslationTable> table = make_shared<TranslationTable>(
+ source_data_array, target_data_array, alignment);
+ stop_time = Clock::now();
+ cerr << "Precomputing conditional probabilities took "
+ << GetDuration(start_time, stop_time) << " seconds" << endl;
+
+ Clock::time_point preprocess_stop_time = Clock::now();
+ cerr << "Overall preprocessing step took "
+ << GetDuration(preprocess_start_time, preprocess_stop_time)
+ << " seconds" << endl;
+
+ // Features used to score each grammar rule.
+ Clock::time_point extraction_start_time = Clock::now();
+ vector<shared_ptr<Feature> > features = {
+ make_shared<TargetGivenSourceCoherent>(),
+ make_shared<SampleSourceCount>(),
+ make_shared<CountSourceTarget>(),
+ make_shared<MaxLexSourceGivenTarget>(table),
+ make_shared<MaxLexTargetGivenSource>(table),
+ make_shared<IsSourceSingleton>(),
+ make_shared<IsSourceTargetSingleton>()
+ };
+ shared_ptr<Scorer> scorer = make_shared<Scorer>(features);
+
+ // Sets up the grammar extractor.
+ GrammarExtractor extractor(
+ source_suffix_array,
+ target_data_array,
+ alignment,
+ precomputation,
+ scorer,
+ vm["min_gap_size"].as<int>(),
+ vm["max_rule_span"].as<int>(),
+ vm["max_nonterminals"].as<int>(),
+ vm["max_rule_symbols"].as<int>(),
+ vm["max_samples"].as<int>(),
+ vm["tight_phrases"].as<bool>());
+
+ // Releases extra memory used by the initial precomputation.
+ precomputation.reset();
+
+ // Creates the grammars directory if it doesn't exist.
+ fs::path grammar_path = vm["grammars"].as<string>();
+ if (!fs::is_directory(grammar_path)) {
+ fs::create_directory(grammar_path);
+ }
+
+ // Reads all sentences for which we extract grammar rules (the paralellization
+ // is simplified if we read all sentences upfront).
+ string sentence;
+ vector<string> sentences;
+ while (getline(cin, sentence)) {
+ sentences.push_back(sentence);
+ }
+
+ // Extracts the grammar for each sentence and saves it to a file.
+ vector<string> suffixes(sentences.size());
+ #pragma omp parallel for schedule(dynamic) num_threads(num_threads)
+ for (size_t i = 0; i < sentences.size(); ++i) {
+ string suffix;
+ int position = sentences[i].find("|||");
+ if (position != sentences[i].npos) {
+ suffix = sentences[i].substr(position);
+ sentences[i] = sentences[i].substr(0, position);
+ }
+ suffixes[i] = suffix;
+
+ Grammar grammar = extractor.GetGrammar(sentences[i]);
+ ofstream output(GetGrammarFilePath(grammar_path, i).c_str());
+ output << grammar;
+ }
+
+ for (size_t i = 0; i < sentences.size(); ++i) {
+ cout << "<seg grammar=\"" << GetGrammarFilePath(grammar_path, i) << "\" id=\""
+ << i << "\"> " << sentences[i] << " </seg> " << suffixes[i] << endl;
+ }
+
+ Clock::time_point extraction_stop_time = Clock::now();
+ cerr << "Overall extraction step took "
+ << GetDuration(extraction_start_time, extraction_stop_time)
+ << " seconds" << endl;
+
+ return 0;
+}