From fa80a23079d642a3e984712c9dfa9ac47d2457fa Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 16 Apr 2012 22:42:24 -0400 Subject: refactor some code, simplify, fix typos --- rst_parser/Makefile.am | 16 ++--- rst_parser/arc_factored.cc | 40 ++++++------ rst_parser/arc_factored.h | 7 ++- rst_parser/arc_ff.cc | 120 +++++++++++++++++++++--------------- rst_parser/arc_ff.h | 35 +++-------- rst_parser/arc_ff_factory.h | 42 ------------- rst_parser/mst_train.cc | 37 +++++------- rst_parser/rst_parse.cc | 126 -------------------------------------- rst_parser/rst_test.cc | 48 --------------- rst_parser/rst_train.cc | 144 ++++++++++++++++++++++++++++++++++++++++++++ 10 files changed, 264 insertions(+), 351 deletions(-) delete mode 100644 rst_parser/arc_ff_factory.h delete mode 100644 rst_parser/rst_parse.cc delete mode 100644 rst_parser/rst_test.cc create mode 100644 rst_parser/rst_train.cc diff --git a/rst_parser/Makefile.am b/rst_parser/Makefile.am index 6e884f53..876c2237 100644 --- a/rst_parser/Makefile.am +++ b/rst_parser/Makefile.am @@ -1,22 +1,14 @@ bin_PROGRAMS = \ - mst_train rst_parse - -noinst_PROGRAMS = \ - rst_test - -TESTS = rst_test + mst_train rst_train noinst_LIBRARIES = librst.a -librst_a_SOURCES = arc_factored.cc arc_factored_marginals.cc rst.cc arc_ff.cc dep_training.cc +librst_a_SOURCES = arc_factored.cc arc_factored_marginals.cc rst.cc arc_ff.cc dep_training.cc global_ff.cc mst_train_SOURCES = mst_train.cc mst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a ../training/optimize.o -lz -rst_parse_SOURCES = rst_parse.cc -rst_parse_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz - -rst_test_SOURCES = rst_test.cc -rst_test_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +rst_train_SOURCES = rst_train.cc +rst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/training -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm diff --git a/rst_parser/arc_factored.cc b/rst_parser/arc_factored.cc index 34c689f4..74bf7516 100644 --- a/rst_parser/arc_factored.cc +++ b/rst_parser/arc_factored.cc @@ -13,36 +13,30 @@ using namespace std::tr1; using namespace boost; void EdgeSubset::ExtractFeatures(const TaggedSentence& sentence, - const std::vector >& ffs, + const ArcFeatureFunctions& ffs, SparseVector* features) const { SparseVector efmap; - for (int i = 0; i < ffs.size(); ++i) { - const ArcFeatureFunction& ff= *ffs[i]; - for (int j = 0; j < h_m_pairs.size(); ++j) { - efmap.clear(); - ff.EgdeFeatures(sentence, h_m_pairs[j].first, - h_m_pairs[j].second, - &efmap); - (*features) += efmap; - } - for (int j = 0; j < roots.size(); ++j) { - efmap.clear(); - ff.EgdeFeatures(sentence, -1, roots[j], &efmap); - (*features) += efmap; - } + for (int j = 0; j < h_m_pairs.size(); ++j) { + efmap.clear(); + ffs.EdgeFeatures(sentence, h_m_pairs[j].first, + h_m_pairs[j].second, + &efmap); + (*features) += efmap; + } + for (int j = 0; j < roots.size(); ++j) { + efmap.clear(); + ffs.EdgeFeatures(sentence, -1, roots[j], &efmap); + (*features) += efmap; } } void ArcFactoredForest::ExtractFeatures(const TaggedSentence& sentence, - const std::vector >& ffs) { - for (int i = 0; i < ffs.size(); ++i) { - const ArcFeatureFunction& ff = *ffs[i]; - for (int m = 0; m < num_words_; ++m) { - for (int h = 0; h < num_words_; ++h) { - ff.EgdeFeatures(sentence, h, m, &edges_(h,m).features); - } - ff.EgdeFeatures(sentence, -1, m, &root_edges_[m].features); + const ArcFeatureFunctions& ffs) { + for (int m = 0; m < num_words_; ++m) { + for (int h = 0; h < num_words_; ++h) { + ffs.EdgeFeatures(sentence, h, m, &edges_(h,m).features); } + ffs.EdgeFeatures(sentence, -1, m, &root_edges_[m].features); } } diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h index a271c8d4..c5481d80 100644 --- a/rst_parser/arc_factored.h +++ b/rst_parser/arc_factored.h @@ -17,14 +17,15 @@ struct TaggedSentence { std::vector pos; }; -struct ArcFeatureFunction; +struct ArcFeatureFunctions; struct EdgeSubset { EdgeSubset() {} std::vector roots; // unless multiroot trees are supported, this // will have a single member std::vector > h_m_pairs; // h,m start at 0 + // assumes ArcFeatureFunction::PrepareForInput has already been called void ExtractFeatures(const TaggedSentence& sentence, - const std::vector >& ffs, + const ArcFeatureFunctions& ffs, SparseVector* features) const; }; @@ -74,7 +75,7 @@ class ArcFactoredForest { // set eges_[*].features void ExtractFeatures(const TaggedSentence& sentence, - const std::vector >& ffs); + const ArcFeatureFunctions& ffs); const Edge& operator()(short h, short m) const { return h >= 0 ? edges_(h, m) : root_edges_[m]; diff --git a/rst_parser/arc_ff.cc b/rst_parser/arc_ff.cc index f9effbda..10885716 100644 --- a/rst_parser/arc_ff.cc +++ b/rst_parser/arc_ff.cc @@ -6,59 +6,81 @@ using namespace std; -ArcFeatureFunction::~ArcFeatureFunction() {} +struct ArcFFImpl { + ArcFFImpl() : kROOT("ROOT") {} + const string kROOT; -void ArcFeatureFunction::PrepareForInput(const TaggedSentence&) {} + void PrepareForInput(const TaggedSentence& sentence) { + (void) sentence; + } + + void EdgeFeatures(const TaggedSentence& sent, + short h, + short m, + SparseVector* features) const { + const bool is_root = (h == -1); + const string& head_word = (is_root ? kROOT : TD::Convert(sent.words[h])); + const string& head_pos = (is_root ? kROOT : TD::Convert(sent.pos[h])); + const string& mod_word = TD::Convert(sent.words[m]); + const string& mod_pos = TD::Convert(sent.pos[m]); + const bool dir = m < h; + int v = m - h; + if (v < 0) { + v= -1 - int(log(-v) / log(2)); + } else { + v= int(log(v) / log(2)); + } + static map lenmap; + int& lenfid = lenmap[v]; + if (!lenfid) { + ostringstream os; + if (v < 0) os << "LenL" << -v; else os << "LenR" << v; + lenfid = FD::Convert(os.str()); + } + features->set_value(lenfid, 1.0); + const string& lenstr = FD::Convert(lenfid); + if (!is_root) { + static int modl = FD::Convert("ModLeft"); + static int modr = FD::Convert("ModRight"); + if (dir) features->set_value(modl, 1); + else features->set_value(modr, 1); + } + if (is_root) { + ostringstream os; + os << "ROOT:" << mod_pos; + features->set_value(FD::Convert(os.str()), 1.0); + os << "_" << lenstr; + features->set_value(FD::Convert(os.str()), 1.0); + } else { // not root + ostringstream os; + os << "HM:" << head_pos << '_' << mod_pos; + features->set_value(FD::Convert(os.str()), 1.0); + os << '_' << dir; + features->set_value(FD::Convert(os.str()), 1.0); + os << '_' << lenstr; + features->set_value(FD::Convert(os.str()), 1.0); + ostringstream os2; + os2 << "LexHM:" << head_word << '_' << mod_word; + features->set_value(FD::Convert(os2.str()), 1.0); + os2 << '_' << dir; + features->set_value(FD::Convert(os2.str()), 1.0); + os2 << '_' << lenstr; + features->set_value(FD::Convert(os2.str()), 1.0); + } + } +}; -DistancePenalty::DistancePenalty(const string&) : fidw_(FD::Convert("Distance")), fidr_(FD::Convert("RootDistance")) {} +ArcFeatureFunctions::ArcFeatureFunctions() : pimpl(new ArcFFImpl) {} +ArcFeatureFunctions::~ArcFeatureFunctions() { delete pimpl; } + +void ArcFeatureFunctions::PrepareForInput(const TaggedSentence& sentence) { + pimpl->PrepareForInput(sentence); +} -void DistancePenalty::EdgeFeaturesImpl(const TaggedSentence& sent, +void ArcFeatureFunctions::EdgeFeatures(const TaggedSentence& sentence, short h, short m, SparseVector* features) const { - const bool dir = m < h; - const bool is_root = (h == -1); - int v = m - h; - if (v < 0) { - v= -1 - int(log(-v) / log(2)); - } else { - v= int(log(v) / log(2)); - } - static map lenmap; - int& lenfid = lenmap[v]; - if (!lenfid) { - ostringstream os; - if (v < 0) os << "LenL" << -v; else os << "LenR" << v; - lenfid = FD::Convert(os.str()); - } - features->set_value(lenfid, 1.0); - const string& lenstr = FD::Convert(lenfid); - if (!is_root) { - static int modl = FD::Convert("ModLeft"); - static int modr = FD::Convert("ModRight"); - if (dir) features->set_value(modl, 1); - else features->set_value(modr, 1); - } - if (is_root) { - ostringstream os; - os << "ROOT:" << TD::Convert(sent.pos[m]); - features->set_value(FD::Convert(os.str()), 1.0); - os << "_" << lenstr; - features->set_value(FD::Convert(os.str()), 1.0); - } else { // not root - ostringstream os; - os << "HM:" << TD::Convert(sent.pos[h]) << '_' << TD::Convert(sent.pos[m]); - features->set_value(FD::Convert(os.str()), 1.0); - os << '_' << dir; - features->set_value(FD::Convert(os.str()), 1.0); - os << '_' << lenstr; - features->set_value(FD::Convert(os.str()), 1.0); - ostringstream os2; - os2 << "LexHM:" << TD::Convert(sent.words[h]) << '_' << TD::Convert(sent.words[m]); - features->set_value(FD::Convert(os2.str()), 1.0); - os2 << '_' << dir; - features->set_value(FD::Convert(os2.str()), 1.0); - os2 << '_' << lenstr; - features->set_value(FD::Convert(os2.str()), 1.0); - } + pimpl->EdgeFeatures(sentence, h, m, features); } + diff --git a/rst_parser/arc_ff.h b/rst_parser/arc_ff.h index bc51fef4..52f311d2 100644 --- a/rst_parser/arc_ff.h +++ b/rst_parser/arc_ff.h @@ -7,37 +7,22 @@ #include "arc_factored.h" struct TaggedSentence; -class ArcFeatureFunction { +struct ArcFFImpl; +class ArcFeatureFunctions { public: - virtual ~ArcFeatureFunction(); + ArcFeatureFunctions(); + ~ArcFeatureFunctions(); // called once, per input, before any calls to EdgeFeatures // used to initialize sentence-specific data structures - virtual void PrepareForInput(const TaggedSentence& sentence); + void PrepareForInput(const TaggedSentence& sentence); - inline void EgdeFeatures(const TaggedSentence& sentence, - short h, - short m, - SparseVector* features) const { - EdgeFeaturesImpl(sentence, h, m, features); - } - protected: - virtual void EdgeFeaturesImpl(const TaggedSentence& sentence, - short h, - short m, - SparseVector* features) const = 0; -}; - -class DistancePenalty : public ArcFeatureFunction { - public: - DistancePenalty(const std::string& param); - protected: - virtual void EdgeFeaturesImpl(const TaggedSentence& sentence, - short h, - short m, - SparseVector* features) const; + void EdgeFeatures(const TaggedSentence& sentence, + short h, + short m, + SparseVector* features) const; private: - const int fidw_, fidr_; + ArcFFImpl* pimpl; }; #endif diff --git a/rst_parser/arc_ff_factory.h b/rst_parser/arc_ff_factory.h deleted file mode 100644 index 4237fd5d..00000000 --- a/rst_parser/arc_ff_factory.h +++ /dev/null @@ -1,42 +0,0 @@ -#ifndef _ARC_FF_FACTORY_H_ -#define _ARC_FF_FACTORY_H_ - -#include -#include -#include - -struct ArcFFFactoryBase { - virtual boost::shared_ptr Create(const std::string& param) const = 0; -}; - -template -struct ArcFFFactory : public ArcFFFactoryBase { - boost::shared_ptr Create(const std::string& param) const { - return boost::shared_ptr(new FF(param)); - } -}; - -struct ArcFFRegistry { - boost::shared_ptr Create(const std::string& name, const std::string& param) const { - std::map::const_iterator it = facts.find(name); - assert(it != facts.end()); - return it->second->Create(param); - } - - void Register(const std::string& name, ArcFFFactoryBase* fact) { - ArcFFFactoryBase*& f = facts[name]; - assert(f == NULL); - f = fact; - } - std::map facts; -}; - -std::ostream& operator<<(std::ostream& os, const ArcFFRegistry& reg) { - for (std::map::const_iterator it = reg.facts.begin(); - it != reg.facts.end(); ++it) { - os << " " << it->first << std::endl; - } - return os; -} - -#endif diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc index f0403d7e..0709e7c9 100644 --- a/rst_parser/mst_train.cc +++ b/rst_parser/mst_train.cc @@ -6,7 +6,6 @@ #include #include "arc_ff.h" -#include "arc_ff_factory.h" #include "stringlib.h" #include "filelib.h" #include "tdict.h" @@ -22,7 +21,6 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { string cfg_file; opts.add_options() ("training_data,t",po::value()->default_value("-"), "File containing training data (jsent format)") - ("feature_function,F",po::value >()->composing(), "feature function (multiple permitted)") ("weights,w",po::value(), "Optional starting weights") ("output_every_i_iterations,I",po::value()->default_value(1), "Write weights every I iterations") ("regularization_strength,C",po::value()->default_value(1.0), "Regularization strength") @@ -74,12 +72,8 @@ int main(int argc, char** argv) { int size = 1; po::variables_map conf; InitCommandLine(argc, argv, &conf); - ArcFactoredForest af(5); - ArcFFRegistry reg; - reg.Register("DistancePenalty", new ArcFFFactory); + ArcFeatureFunctions ffs; vector corpus; - vector > ffs; - ffs.push_back(boost::shared_ptr(new DistancePenalty(""))); TrainingInstance::ReadTraining(conf["training_data"].as(), &corpus, rank, size); vector forests(corpus.size()); SparseVector empirical; @@ -88,22 +82,19 @@ int main(int argc, char** argv) { TrainingInstance& cur = corpus[i]; if (rank == 0 && (i+1) % 10 == 0) { cerr << '.' << flush; flag = true; } if (rank == 0 && (i+1) % 400 == 0) { cerr << " [" << (i+1) << "]\n"; flag = false; } - for (int fi = 0; fi < ffs.size(); ++fi) { - ArcFeatureFunction& ff = *ffs[fi]; - ff.PrepareForInput(cur.ts); - SparseVector efmap; - for (int j = 0; j < cur.tree.h_m_pairs.size(); ++j) { - efmap.clear(); - ff.EgdeFeatures(cur.ts, cur.tree.h_m_pairs[j].first, - cur.tree.h_m_pairs[j].second, - &efmap); - cur.features += efmap; - } - for (int j = 0; j < cur.tree.roots.size(); ++j) { - efmap.clear(); - ff.EgdeFeatures(cur.ts, -1, cur.tree.roots[j], &efmap); - cur.features += efmap; - } + ffs.PrepareForInput(cur.ts); + SparseVector efmap; + for (int j = 0; j < cur.tree.h_m_pairs.size(); ++j) { + efmap.clear(); + ffs.EdgeFeatures(cur.ts, cur.tree.h_m_pairs[j].first, + cur.tree.h_m_pairs[j].second, + &efmap); + cur.features += efmap; + } + for (int j = 0; j < cur.tree.roots.size(); ++j) { + efmap.clear(); + ffs.EdgeFeatures(cur.ts, -1, cur.tree.roots[j], &efmap); + cur.features += efmap; } empirical += cur.features; forests[i].resize(cur.ts.words.size()); diff --git a/rst_parser/rst_parse.cc b/rst_parser/rst_parse.cc deleted file mode 100644 index 9cc1359a..00000000 --- a/rst_parser/rst_parse.cc +++ /dev/null @@ -1,126 +0,0 @@ -#include "arc_factored.h" - -#include -#include -#include -#include - -#include "timing_stats.h" -#include "arc_ff.h" -#include "arc_ff_factory.h" -#include "dep_training.h" -#include "stringlib.h" -#include "filelib.h" -#include "tdict.h" -#include "weights.h" -#include "rst.h" - -using namespace std; -namespace po = boost::program_options; - -void InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - string cfg_file; - opts.add_options() - ("training_data,t",po::value()->default_value("-"), "File containing training data (jsent format)") - ("feature_function,F",po::value >()->composing(), "feature function (multiple permitted)") - ("q_weights,q",po::value(), "Arc-factored weights for proposal distribution") - ("samples,n",po::value()->default_value(1000), "Number of samples"); - po::options_description clo("Command line options"); - clo.add_options() - ("config,c", po::value(&cfg_file), "Configuration file") - ("help,?", "Print this help message and exit"); - - po::options_description dconfig_options, dcmdline_options; - dconfig_options.add(opts); - dcmdline_options.add(dconfig_options).add(clo); - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - if (cfg_file.size() > 0) { - ReadFile rf(cfg_file); - po::store(po::parse_config_file(*rf.stream(), dconfig_options), *conf); - } - if (conf->count("help")) { - cerr << dcmdline_options << endl; - exit(1); - } -} - -int main(int argc, char** argv) { - po::variables_map conf; - InitCommandLine(argc, argv, &conf); - ArcFactoredForest af(5); - ArcFFRegistry reg; - reg.Register("DistancePenalty", new ArcFFFactory); - vector corpus; - vector > ffs; - ffs.push_back(boost::shared_ptr(new DistancePenalty(""))); - TrainingInstance::ReadTraining(conf["training_data"].as(), &corpus); - vector forests(corpus.size()); - SparseVector empirical; - bool flag = false; - for (int i = 0; i < corpus.size(); ++i) { - TrainingInstance& cur = corpus[i]; - if ((i+1) % 10 == 0) { cerr << '.' << flush; flag = true; } - if ((i+1) % 400 == 0) { cerr << " [" << (i+1) << "]\n"; flag = false; } - for (int fi = 0; fi < ffs.size(); ++fi) { - ArcFeatureFunction& ff = *ffs[fi]; - ff.PrepareForInput(cur.ts); - SparseVector efmap; - for (int j = 0; j < cur.tree.h_m_pairs.size(); ++j) { - efmap.clear(); - ff.EgdeFeatures(cur.ts, cur.tree.h_m_pairs[j].first, - cur.tree.h_m_pairs[j].second, - &efmap); - cur.features += efmap; - } - for (int j = 0; j < cur.tree.roots.size(); ++j) { - efmap.clear(); - ff.EgdeFeatures(cur.ts, -1, cur.tree.roots[j], &efmap); - cur.features += efmap; - } - } - empirical += cur.features; - forests[i].resize(cur.ts.words.size()); - forests[i].ExtractFeatures(cur.ts, ffs); - } - if (flag) cerr << endl; - vector weights(FD::NumFeats(), 0.0); - Weights::InitFromFile(conf["q_weights"].as(), &weights); - MT19937 rng; - SparseVector model_exp; - SparseVector sampled_exp; - int samples = conf["samples"].as(); - for (int i = 0; i < corpus.size(); ++i) { - const int num_words = corpus[i].ts.words.size(); - forests[i].Reweight(weights); - forests[i].EdgeMarginals(); - model_exp.clear(); - for (int h = -1; h < num_words; ++h) { - for (int m = 0; m < num_words; ++m) { - if (h == m) continue; - const ArcFactoredForest::Edge& edge = forests[i](h,m); - const SparseVector& fmap = edge.features; - double prob = edge.edge_prob.as_float(); - model_exp += fmap * prob; - } - } - //cerr << "TRUE EXP: " << model_exp << endl; - - forests[i].Reweight(weights); - TreeSampler ts(forests[i]); - sampled_exp.clear(); - //ostringstream os; os << "Samples_" << samples; - //Timer t(os.str()); - for (int n = 0; n < samples; ++n) { - EdgeSubset tree; - ts.SampleRandomSpanningTree(&tree, &rng); - SparseVector feats; - tree.ExtractFeatures(corpus[i].ts, ffs, &feats); - sampled_exp += feats; - } - sampled_exp /= samples; - cerr << "L2 norm of diff @ " << samples << " samples: " << (model_exp - sampled_exp).l2norm() << endl; - } - return 0; -} - diff --git a/rst_parser/rst_test.cc b/rst_parser/rst_test.cc deleted file mode 100644 index 3bb95759..00000000 --- a/rst_parser/rst_test.cc +++ /dev/null @@ -1,48 +0,0 @@ -#include "arc_factored.h" - -#include - -#include - -using namespace std; - -int main(int argc, char** argv) { - // John saw Mary - // (H -> M) - // (1 -> 2) 20 - // (1 -> 3) 3 - // (2 -> 1) 20 - // (2 -> 3) 30 - // (3 -> 2) 0 - // (3 -> 1) 11 - // (0, 2) 10 - // (0, 1) 9 - // (0, 3) 9 - ArcFactoredForest af(3); - af(0,1).edge_prob.logeq(20); - af(0,2).edge_prob.logeq(3); - af(1,0).edge_prob.logeq(20); - af(1,2).edge_prob.logeq(30); - af(2,1).edge_prob.logeq(0); - af(2,0).edge_prob.logeq(11); - af(-1,1).edge_prob.logeq(10); - af(-1,0).edge_prob.logeq(9); - af(-1,2).edge_prob.logeq(9); - EdgeSubset tree; -// af.MaximumEdgeSubset(&tree); - prob_t z; - af.EdgeMarginals(&z); - cerr << "Z = " << abs(z) << endl; - af.PickBestParentForEachWord(&tree); - cerr << tree << endl; - typedef Eigen::Matrix M3; - M3 A = M3::Zero(); - A(0,0) = prob_t(1); - A(1,0) = prob_t(3); - A(0,1) = prob_t(2); - A(1,1) = prob_t(4); - prob_t det = A.determinant(); - cerr << det.as_float() << endl; - return 0; -} - diff --git a/rst_parser/rst_train.cc b/rst_parser/rst_train.cc new file mode 100644 index 00000000..16673cdc --- /dev/null +++ b/rst_parser/rst_train.cc @@ -0,0 +1,144 @@ +#include "arc_factored.h" + +#include +#include +#include +#include + +#include "timing_stats.h" +#include "arc_ff.h" +#include "dep_training.h" +#include "stringlib.h" +#include "filelib.h" +#include "tdict.h" +#include "weights.h" +#include "rst.h" +#include "global_ff.h" + +using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + string cfg_file; + opts.add_options() + ("training_data,t",po::value()->default_value("-"), "File containing training data (jsent format)") + ("q_weights,q",po::value(), "Arc-factored weights for proposal distribution") + ("samples,n",po::value()->default_value(1000), "Number of samples"); + po::options_description clo("Command line options"); + clo.add_options() + ("config,c", po::value(&cfg_file), "Configuration file") + ("help,?", "Print this help message and exit"); + + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(dconfig_options).add(clo); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (cfg_file.size() > 0) { + ReadFile rf(cfg_file); + po::store(po::parse_config_file(*rf.stream(), dconfig_options), *conf); + } + if (conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + vector qweights(FD::NumFeats(), 0.0); + Weights::InitFromFile(conf["q_weights"].as(), &qweights); + vector corpus; + ArcFeatureFunctions ffs; + GlobalFeatureFunctions gff; + TrainingInstance::ReadTraining(conf["training_data"].as(), &corpus); + vector forests(corpus.size()); + vector zs(corpus.size()); + SparseVector empirical; + bool flag = false; + for (int i = 0; i < corpus.size(); ++i) { + TrainingInstance& cur = corpus[i]; + if ((i+1) % 10 == 0) { cerr << '.' << flush; flag = true; } + if ((i+1) % 400 == 0) { cerr << " [" << (i+1) << "]\n"; flag = false; } + SparseVector efmap; + ffs.PrepareForInput(cur.ts); + gff.PrepareForInput(cur.ts); + for (int j = 0; j < cur.tree.h_m_pairs.size(); ++j) { + efmap.clear(); + ffs.EdgeFeatures(cur.ts, cur.tree.h_m_pairs[j].first, + cur.tree.h_m_pairs[j].second, + &efmap); + cur.features += efmap; + } + for (int j = 0; j < cur.tree.roots.size(); ++j) { + efmap.clear(); + ffs.EdgeFeatures(cur.ts, -1, cur.tree.roots[j], &efmap); + cur.features += efmap; + } + efmap.clear(); + gff.Features(cur.ts, cur.tree, &efmap); + cur.features += efmap; + empirical += cur.features; + forests[i].resize(cur.ts.words.size()); + forests[i].ExtractFeatures(cur.ts, ffs); + forests[i].Reweight(qweights); + forests[i].EdgeMarginals(&zs[i]); + zs[i] = prob_t::One() / zs[i]; + // cerr << zs[i] << endl; + forests[i].Reweight(qweights); // EdgeMarginals overwrites edge_prob + } + if (flag) cerr << endl; + MT19937 rng; + SparseVector model_exp; + SparseVector weights; + Weights::InitSparseVector(qweights, &weights); + int samples = conf["samples"].as(); + for (int i = 0; i < corpus.size(); ++i) { +#if 0 + forests[i].EdgeMarginals(); + model_exp.clear(); + for (int h = -1; h < num_words; ++h) { + for (int m = 0; m < num_words; ++m) { + if (h == m) continue; + const ArcFactoredForest::Edge& edge = forests[i](h,m); + const SparseVector& fmap = edge.features; + double prob = edge.edge_prob.as_float(); + model_exp += fmap * prob; + } + } + cerr << "TRUE EXP: " << model_exp << endl; + forests[i].Reweight(weights); +#endif + + TreeSampler ts(forests[i]); + prob_t zhat = prob_t::Zero(); + SparseVector sampled_exp; + for (int n = 0; n < samples; ++n) { + EdgeSubset tree; + ts.SampleRandomSpanningTree(&tree, &rng); + SparseVector qfeats, gfeats; + tree.ExtractFeatures(corpus[i].ts, ffs, &qfeats); + prob_t u; u.logeq(qfeats.dot(qweights)); + const prob_t q = u / zs[i]; // proposal mass + gff.Features(corpus[i].ts, tree, &gfeats); + SparseVector tot_feats = qfeats + gfeats; + u.logeq(tot_feats.dot(weights)); + prob_t w = u / q; + zhat += w; + for (SparseVector::const_iterator it = tot_feats.begin(); it != tot_feats.end(); ++it) + sampled_exp.add_value(it->first, w * prob_t(it->second)); + } + sampled_exp /= zhat; + SparseVector tot_m; + for (SparseVector::const_iterator it = sampled_exp.begin(); it != sampled_exp.end(); ++it) + tot_m.add_value(it->first, it->second.as_float()); + //cerr << "DIFF: " << (tot_m - corpus[i].features) << endl; + const double eta = 0.03; + weights -= (tot_m - corpus[i].features) * eta; + } + cerr << "WEIGHTS.\n"; + cerr << weights << endl; + return 0; +} + -- cgit v1.2.3