From 4b38556c88c739de82b9c298261a262ec620280e Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 16 Apr 2012 18:20:33 -0400 Subject: rst sampler --- rst_parser/dep_training.cc | 56 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 rst_parser/dep_training.cc (limited to 'rst_parser/dep_training.cc') diff --git a/rst_parser/dep_training.cc b/rst_parser/dep_training.cc new file mode 100644 index 00000000..de431ebc --- /dev/null +++ b/rst_parser/dep_training.cc @@ -0,0 +1,56 @@ +#include "dep_training.h" + +#include +#include + +#include "stringlib.h" +#include "filelib.h" +#include "tdict.h" +#include "picojson.h" + +using namespace std; + +void TrainingInstance::ReadTraining(const string& fname, vector* corpus, int rank, int size) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + string err; + int lc = 0; + bool flag = false; + while(getline(in, line)) { + ++lc; + if ((lc-1) % size != rank) continue; + if (rank == 0 && lc % 10 == 0) { cerr << '.' << flush; flag = true; } + if (rank == 0 && lc % 400 == 0) { cerr << " [" << lc << "]\n"; flag = false; } + size_t pos = line.rfind('\t'); + assert(pos != string::npos); + picojson::value obj; + picojson::parse(obj, line.begin() + pos, line.end(), &err); + if (err.size() > 0) { cerr << "JSON parse error in " << lc << ": " << err << endl; abort(); } + corpus->push_back(TrainingInstance()); + TrainingInstance& cur = corpus->back(); + TaggedSentence& ts = cur.ts; + EdgeSubset& tree = cur.tree; + assert(obj.is()); + const picojson::object& d = obj.get(); + const picojson::array& ta = d.find("tokens")->second.get(); + for (unsigned i = 0; i < ta.size(); ++i) { + ts.words.push_back(TD::Convert(ta[i].get()[0].get())); + ts.pos.push_back(TD::Convert(ta[i].get()[1].get())); + } + const picojson::array& da = d.find("deps")->second.get(); + for (unsigned i = 0; i < da.size(); ++i) { + const picojson::array& thm = da[i].get(); + // get dep type here + short h = thm[2].get(); + short m = thm[1].get(); + if (h < 0) + tree.roots.push_back(m); + else + tree.h_m_pairs.push_back(make_pair(h,m)); + } + //cerr << TD::GetString(ts.words) << endl << TD::GetString(ts.pos) << endl << tree << endl; + } + if (flag) cerr << "\nRead " << lc << " training instances\n"; +} + -- cgit v1.2.3 From f4570f262c10534b335568e1d69fb3a8dfbf38ed Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 18 Apr 2012 22:37:08 -0400 Subject: refactor --- rst_parser/dep_training.cc | 70 ++++++++++++++++++++++++++++------------------ rst_parser/dep_training.h | 4 ++- rst_parser/mst_train.cc | 2 +- rst_parser/rst_train.cc | 2 +- 4 files changed, 48 insertions(+), 30 deletions(-) (limited to 'rst_parser/dep_training.cc') diff --git a/rst_parser/dep_training.cc b/rst_parser/dep_training.cc index de431ebc..e26505ec 100644 --- a/rst_parser/dep_training.cc +++ b/rst_parser/dep_training.cc @@ -10,11 +10,51 @@ using namespace std; -void TrainingInstance::ReadTraining(const string& fname, vector* corpus, int rank, int size) { +static void ParseInstance(const string& line, int start, TrainingInstance* out, int lc = 0) { + picojson::value obj; + string err; + picojson::parse(obj, line.begin() + start, line.end(), &err); + if (err.size() > 0) { cerr << "JSON parse error in " << lc << ": " << err << endl; abort(); } + TrainingInstance& cur = *out; + TaggedSentence& ts = cur.ts; + EdgeSubset& tree = cur.tree; + assert(obj.is()); + const picojson::object& d = obj.get(); + const picojson::array& ta = d.find("tokens")->second.get(); + for (unsigned i = 0; i < ta.size(); ++i) { + ts.words.push_back(TD::Convert(ta[i].get()[0].get())); + ts.pos.push_back(TD::Convert(ta[i].get()[1].get())); + } + if (d.find("deps") != d.end()) { + const picojson::array& da = d.find("deps")->second.get(); + for (unsigned i = 0; i < da.size(); ++i) { + const picojson::array& thm = da[i].get(); + // get dep type here + short h = thm[2].get(); + short m = thm[1].get(); + if (h < 0) + tree.roots.push_back(m); + else + tree.h_m_pairs.push_back(make_pair(h,m)); + } + } + //cerr << TD::GetString(ts.words) << endl << TD::GetString(ts.pos) << endl << tree << endl; +} + +bool TrainingInstance::ReadInstance(std::istream* in, TrainingInstance* instance) { + string line; + if (!getline(*in, line)) return false; + size_t pos = line.rfind('\t'); + assert(pos != string::npos); + static int lc = 0; ++lc; + ParseInstance(line, pos + 1, instance, lc); + return true; +} + +void TrainingInstance::ReadTrainingCorpus(const string& fname, vector* corpus, int rank, int size) { ReadFile rf(fname); istream& in = *rf.stream(); string line; - string err; int lc = 0; bool flag = false; while(getline(in, line)) { @@ -24,32 +64,8 @@ void TrainingInstance::ReadTraining(const string& fname, vector 0) { cerr << "JSON parse error in " << lc << ": " << err << endl; abort(); } corpus->push_back(TrainingInstance()); - TrainingInstance& cur = corpus->back(); - TaggedSentence& ts = cur.ts; - EdgeSubset& tree = cur.tree; - assert(obj.is()); - const picojson::object& d = obj.get(); - const picojson::array& ta = d.find("tokens")->second.get(); - for (unsigned i = 0; i < ta.size(); ++i) { - ts.words.push_back(TD::Convert(ta[i].get()[0].get())); - ts.pos.push_back(TD::Convert(ta[i].get()[1].get())); - } - const picojson::array& da = d.find("deps")->second.get(); - for (unsigned i = 0; i < da.size(); ++i) { - const picojson::array& thm = da[i].get(); - // get dep type here - short h = thm[2].get(); - short m = thm[1].get(); - if (h < 0) - tree.roots.push_back(m); - else - tree.h_m_pairs.push_back(make_pair(h,m)); - } - //cerr << TD::GetString(ts.words) << endl << TD::GetString(ts.pos) << endl << tree << endl; + ParseInstance(line, pos + 1, &corpus->back(), lc); } if (flag) cerr << "\nRead " << lc << " training instances\n"; } diff --git a/rst_parser/dep_training.h b/rst_parser/dep_training.h index 73ffd298..3eeee22e 100644 --- a/rst_parser/dep_training.h +++ b/rst_parser/dep_training.h @@ -1,6 +1,7 @@ #ifndef _DEP_TRAINING_H_ #define _DEP_TRAINING_H_ +#include #include #include #include "arc_factored.h" @@ -11,7 +12,8 @@ struct TrainingInstance { EdgeSubset tree; SparseVector features; // reads a "Jsent" formatted dependency file - static void ReadTraining(const std::string& fname, std::vector* corpus, int rank = 0, int size = 1); + static bool ReadInstance(std::istream* in, TrainingInstance* instance); // returns false at EOF + static void ReadTrainingCorpus(const std::string& fname, std::vector* corpus, int rank = 0, int size = 1); }; #endif diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc index 0709e7c9..e414f450 100644 --- a/rst_parser/mst_train.cc +++ b/rst_parser/mst_train.cc @@ -74,7 +74,7 @@ int main(int argc, char** argv) { InitCommandLine(argc, argv, &conf); ArcFeatureFunctions ffs; vector corpus; - TrainingInstance::ReadTraining(conf["training_data"].as(), &corpus, rank, size); + TrainingInstance::ReadTrainingCorpus(conf["training_data"].as(), &corpus, rank, size); vector forests(corpus.size()); SparseVector empirical; bool flag = false; diff --git a/rst_parser/rst_train.cc b/rst_parser/rst_train.cc index 16673cdc..9b730f3d 100644 --- a/rst_parser/rst_train.cc +++ b/rst_parser/rst_train.cc @@ -52,7 +52,7 @@ int main(int argc, char** argv) { vector corpus; ArcFeatureFunctions ffs; GlobalFeatureFunctions gff; - TrainingInstance::ReadTraining(conf["training_data"].as(), &corpus); + TrainingInstance::ReadTrainingCorpus(conf["training_data"].as(), &corpus); vector forests(corpus.size()); vector zs(corpus.size()); SparseVector empirical; -- cgit v1.2.3 From d016f7f28510f822b89c921da38006eae3877872 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 19 Apr 2012 02:45:27 -0400 Subject: compute f --- rst_parser/Makefile.am | 5 +- rst_parser/dep_training.cc | 4 ++ rst_parser/rst_parse.cc | 111 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 119 insertions(+), 1 deletion(-) create mode 100644 rst_parser/rst_parse.cc (limited to 'rst_parser/dep_training.cc') diff --git a/rst_parser/Makefile.am b/rst_parser/Makefile.am index 876c2237..4977f584 100644 --- a/rst_parser/Makefile.am +++ b/rst_parser/Makefile.am @@ -1,5 +1,5 @@ bin_PROGRAMS = \ - mst_train rst_train + mst_train rst_train rst_parse noinst_LIBRARIES = librst.a @@ -11,4 +11,7 @@ mst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/ rst_train_SOURCES = rst_train.cc rst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +rst_parse_SOURCES = rst_parse.cc +rst_parse_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz + AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/training -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm diff --git a/rst_parser/dep_training.cc b/rst_parser/dep_training.cc index e26505ec..ef97798b 100644 --- a/rst_parser/dep_training.cc +++ b/rst_parser/dep_training.cc @@ -18,6 +18,10 @@ static void ParseInstance(const string& line, int start, TrainingInstance* out, TrainingInstance& cur = *out; TaggedSentence& ts = cur.ts; EdgeSubset& tree = cur.tree; + ts.pos.clear(); + ts.words.clear(); + tree.roots.clear(); + tree.h_m_pairs.clear(); assert(obj.is()); const picojson::object& d = obj.get(); const picojson::array& ta = d.find("tokens")->second.get(); diff --git a/rst_parser/rst_parse.cc b/rst_parser/rst_parse.cc new file mode 100644 index 00000000..9c42a8f4 --- /dev/null +++ b/rst_parser/rst_parse.cc @@ -0,0 +1,111 @@ +#include "arc_factored.h" + +#include +#include +#include +#include + +#include "timing_stats.h" +#include "arc_ff.h" +#include "dep_training.h" +#include "stringlib.h" +#include "filelib.h" +#include "tdict.h" +#include "weights.h" +#include "rst.h" +#include "global_ff.h" + +using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + string cfg_file; + opts.add_options() + ("input,i",po::value()->default_value("-"), "File containing test data (jsent format)") + ("q_weights,q",po::value(), "Arc-factored weights for proposal distribution (mandatory)") + ("p_weights,p",po::value(), "Weights for target distribution (optional)") + ("samples,n",po::value()->default_value(1000), "Number of samples"); + po::options_description clo("Command line options"); + clo.add_options() + ("config,c", po::value(&cfg_file), "Configuration file") + ("help,?", "Print this help message and exit"); + + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(dconfig_options).add(clo); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (cfg_file.size() > 0) { + ReadFile rf(cfg_file); + po::store(po::parse_config_file(*rf.stream(), dconfig_options), *conf); + } + if (conf->count("help") || conf->count("q_weights") == 0) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + vector qweights, pweights; + Weights::InitFromFile(conf["q_weights"].as(), &qweights); + if (conf.count("p_weights")) + Weights::InitFromFile(conf["p_weights"].as(), &pweights); + const bool global = pweights.size() > 0; + ArcFeatureFunctions ffs; + GlobalFeatureFunctions gff; + ReadFile rf(conf["input"].as()); + istream* in = rf.stream(); + TrainingInstance sent; + MT19937 rng; + int samples = conf["samples"].as(); + int totroot = 0, root_right = 0, tot = 0, cor = 0; + while(TrainingInstance::ReadInstance(in, &sent)) { + ffs.PrepareForInput(sent.ts); + if (global) gff.PrepareForInput(sent.ts); + ArcFactoredForest forest(sent.ts.pos.size()); + forest.ExtractFeatures(sent.ts, ffs); + forest.Reweight(qweights); + TreeSampler ts(forest); + double best_score = -numeric_limits::infinity(); + EdgeSubset best_tree; + for (int n = 0; n < samples; ++n) { + EdgeSubset tree; + ts.SampleRandomSpanningTree(&tree, &rng); + SparseVector qfeats, gfeats; + tree.ExtractFeatures(sent.ts, ffs, &qfeats); + double score = 0; + if (global) { + gff.Features(sent.ts, tree, &gfeats); + score = (qfeats + gfeats).dot(pweights); + } else { + score = qfeats.dot(qweights); + } + if (score > best_score) { + best_tree = tree; + best_score = score; + } + } + cerr << "BEST SCORE: " << best_score << endl; + cout << best_tree << endl; + const bool sent_has_ref = sent.tree.h_m_pairs.size() > 0; + if (sent_has_ref) { + map, bool> ref; + for (int i = 0; i < sent.tree.h_m_pairs.size(); ++i) + ref[sent.tree.h_m_pairs[i]] = true; + int ref_root = sent.tree.roots.front(); + if (ref_root == best_tree.roots.front()) { ++root_right; } + ++totroot; + for (int i = 0; i < best_tree.h_m_pairs.size(); ++i) { + if (ref[best_tree.h_m_pairs[i]]) { + ++cor; + } + ++tot; + } + } + } + cerr << "F = " << (double(cor + root_right) / (tot + totroot)) << endl; + return 0; +} + -- cgit v1.2.3