diff options
Diffstat (limited to 'rst_parser')
-rw-r--r-- | rst_parser/Makefile.am | 20 | ||||
-rw-r--r-- | rst_parser/arc_factored.cc | 128 | ||||
-rw-r--r-- | rst_parser/arc_factored.h | 82 | ||||
-rw-r--r-- | rst_parser/arc_factored_marginals.cc | 58 | ||||
-rw-r--r-- | rst_parser/arc_ff.cc | 183 | ||||
-rw-r--r-- | rst_parser/arc_ff.h | 28 | ||||
-rw-r--r-- | rst_parser/dep_training.cc | 76 | ||||
-rw-r--r-- | rst_parser/dep_training.h | 19 | ||||
-rw-r--r-- | rst_parser/global_ff.cc | 44 | ||||
-rw-r--r-- | rst_parser/global_ff.h | 18 | ||||
-rw-r--r-- | rst_parser/mst_train.cc | 220 | ||||
-rw-r--r-- | rst_parser/picojson.h | 979 | ||||
-rw-r--r-- | rst_parser/rst.cc | 77 | ||||
-rw-r--r-- | rst_parser/rst.h | 15 | ||||
-rw-r--r-- | rst_parser/rst_parse.cc | 111 | ||||
-rw-r--r-- | rst_parser/rst_test.cc | 33 | ||||
-rw-r--r-- | rst_parser/rst_train.cc | 144 |
17 files changed, 2159 insertions, 76 deletions
diff --git a/rst_parser/Makefile.am b/rst_parser/Makefile.am index e97ab5c5..4977f584 100644 --- a/rst_parser/Makefile.am +++ b/rst_parser/Makefile.am @@ -1,19 +1,17 @@ bin_PROGRAMS = \ - mst_train - -noinst_PROGRAMS = \ - rst_test - -TESTS = rst_test + mst_train rst_train rst_parse noinst_LIBRARIES = librst.a -librst_a_SOURCES = arc_factored.cc rst.cc +librst_a_SOURCES = arc_factored.cc arc_factored_marginals.cc rst.cc arc_ff.cc dep_training.cc global_ff.cc mst_train_SOURCES = mst_train.cc -mst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a ../training/optimize.o -lz + +rst_train_SOURCES = rst_train.cc +rst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz -rst_test_SOURCES = rst_test.cc -rst_test_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +rst_parse_SOURCES = rst_parse.cc +rst_parse_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz -AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/training -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm diff --git a/rst_parser/arc_factored.cc b/rst_parser/arc_factored.cc index 1e75600b..74bf7516 100644 --- a/rst_parser/arc_factored.cc +++ b/rst_parser/arc_factored.cc @@ -1,31 +1,151 @@ #include "arc_factored.h" #include <set> +#include <tr1/unordered_set> #include <boost/pending/disjoint_sets.hpp> +#include <boost/functional/hash.hpp> + +#include "arc_ff.h" using namespace std; +using namespace std::tr1; using namespace boost; +void EdgeSubset::ExtractFeatures(const TaggedSentence& sentence, + const ArcFeatureFunctions& ffs, + SparseVector<double>* features) const { + SparseVector<weight_t> efmap; + for (int j = 0; j < h_m_pairs.size(); ++j) { + efmap.clear(); + ffs.EdgeFeatures(sentence, h_m_pairs[j].first, + h_m_pairs[j].second, + &efmap); + (*features) += efmap; + } + for (int j = 0; j < roots.size(); ++j) { + efmap.clear(); + ffs.EdgeFeatures(sentence, -1, roots[j], &efmap); + (*features) += efmap; + } +} + +void ArcFactoredForest::ExtractFeatures(const TaggedSentence& sentence, + const ArcFeatureFunctions& ffs) { + for (int m = 0; m < num_words_; ++m) { + for (int h = 0; h < num_words_; ++h) { + ffs.EdgeFeatures(sentence, h, m, &edges_(h,m).features); + } + ffs.EdgeFeatures(sentence, -1, m, &root_edges_[m].features); + } +} + +void ArcFactoredForest::PickBestParentForEachWord(EdgeSubset* st) const { + for (int m = 0; m < num_words_; ++m) { + int best_head = -2; + prob_t best_score; + for (int h = -1; h < num_words_; ++h) { + const Edge& edge = (*this)(h,m); + if (best_head < -1 || edge.edge_prob > best_score) { + best_score = edge.edge_prob; + best_head = h; + } + } + assert(best_head >= -1); + if (best_head >= 0) + st->h_m_pairs.push_back(make_pair<short,short>(best_head, m)); + else + st->roots.push_back(m); + } +} + +struct WeightedEdge { + WeightedEdge() : h(), m(), weight() {} + WeightedEdge(short hh, short mm, float w) : h(hh), m(mm), weight(w) {} + short h, m; + float weight; + inline bool operator==(const WeightedEdge& o) const { + return h == o.h && m == o.m && weight == o.weight; + } + inline bool operator!=(const WeightedEdge& o) const { + return h != o.h || m != o.m || weight != o.weight; + } +}; +inline bool operator<(const WeightedEdge& l, const WeightedEdge& o) { return l.weight < o.weight; } +inline size_t hash_value(const WeightedEdge& e) { return reinterpret_cast<const size_t&>(e); } + + +struct PriorityQueue { + void push(const WeightedEdge& e) {} + const WeightedEdge& top() const { + static WeightedEdge w(1,2,3); + return w; + } + void pop() {} + void increment_all(float p) {} +}; + // based on Trajan 1977 -void ArcFactoredForest::MaximumSpanningTree(SpanningTree* st) const { +void ArcFactoredForest::MaximumSpanningTree(EdgeSubset* st) const { typedef disjoint_sets_with_storage<identity_property_map, identity_property_map, find_with_full_path_compression> DisjointSet; DisjointSet strongly(num_words_ + 1); DisjointSet weakly(num_words_ + 1); - set<unsigned> roots, h, rset; - vector<pair<short, short> > enter(num_words_ + 1); + set<unsigned> roots, rset; + unordered_set<WeightedEdge, boost::hash<WeightedEdge> > h; + vector<PriorityQueue> qs(num_words_ + 1); + vector<WeightedEdge> enter(num_words_ + 1); + vector<unsigned> mins(num_words_ + 1); + const WeightedEdge kDUMMY(0,0,0.0f); for (unsigned i = 0; i <= num_words_; ++i) { + if (i > 0) { + // I(i) incidence on i -- all incoming edges + for (unsigned j = 0; j <= num_words_; ++j) { + qs[i].push(WeightedEdge(j, i, Weight(j,i))); + } + } strongly.make_set(i); weakly.make_set(i); roots.insert(i); + enter[i] = kDUMMY; + mins[i] = i; } while(!roots.empty()) { set<unsigned>::iterator it = roots.begin(); const unsigned k = *it; roots.erase(it); cerr << "k=" << k << endl; - pair<short,short> ij; // TODO = Max(k); + WeightedEdge ij = qs[k].top(); // MAX(k) + qs[k].pop(); + if (ij.weight <= 0) { + rset.insert(k); + } else { + if (strongly.find_set(ij.h) == k) { + roots.insert(k); + } else { + h.insert(ij); + if (weakly.find_set(ij.h) != weakly.find_set(ij.m)) { + weakly.union_set(ij.h, ij.m); + enter[k] = ij; + } else { + unsigned vertex = 0; + float val = 99999999999; + WeightedEdge xy = ij; + while(xy != kDUMMY) { + if (xy.weight < val) { + val = xy.weight; + vertex = strongly.find_set(xy.m); + } + xy = enter[strongly.find_set(xy.h)]; + } + qs[k].increment_all(val - ij.weight); + mins[k] = mins[vertex]; + xy = enter[strongly.find_set(ij.h)]; + while (xy != kDUMMY) { + } + } + } + } } } diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h index e99be482..c5481d80 100644 --- a/rst_parser/arc_factored.h +++ b/rst_parser/arc_factored.h @@ -5,37 +5,65 @@ #include <cassert> #include <vector> #include <utility> +#include <boost/shared_ptr.hpp> #include "array2d.h" #include "sparse_vector.h" #include "prob.h" #include "weights.h" +#include "wordid.h" -struct SpanningTree { - SpanningTree() : roots(1, -1) {} +struct TaggedSentence { + std::vector<WordID> words; + std::vector<WordID> pos; +}; + +struct ArcFeatureFunctions; +struct EdgeSubset { + EdgeSubset() {} std::vector<short> roots; // unless multiroot trees are supported, this // will have a single member - std::vector<std::pair<short, short> > h_m_pairs; + std::vector<std::pair<short, short> > h_m_pairs; // h,m start at 0 + // assumes ArcFeatureFunction::PrepareForInput has already been called + void ExtractFeatures(const TaggedSentence& sentence, + const ArcFeatureFunctions& ffs, + SparseVector<double>* features) const; }; class ArcFactoredForest { public: - explicit ArcFactoredForest(short num_words) : - num_words_(num_words), - root_edges_(num_words), - edges_(num_words, num_words) { + ArcFactoredForest() : num_words_() {} + explicit ArcFactoredForest(short num_words) : num_words_(num_words) { + resize(num_words); + } + + unsigned size() const { return num_words_; } + + void resize(unsigned num_words) { + num_words_ = num_words; + root_edges_.clear(); + edges_.clear(); + root_edges_.resize(num_words); + edges_.resize(num_words, num_words); for (int h = 0; h < num_words; ++h) { for (int m = 0; m < num_words; ++m) { - edges_(h, m).h = h + 1; - edges_(h, m).m = m + 1; + edges_(h, m).h = h; + edges_(h, m).m = m; } - root_edges_[h].h = 0; - root_edges_[h].m = h + 1; + root_edges_[h].h = -1; + root_edges_[h].m = h; } } // compute the maximum spanning tree based on the current weighting // using the O(n^2) CLE algorithm - void MaximumSpanningTree(SpanningTree* st) const; + void MaximumSpanningTree(EdgeSubset* st) const; + + // Reweight edges so that edge_prob is the edge's marginals + // optionally returns log partition + void EdgeMarginals(prob_t* p_log_z = NULL); + + // This may not return a tree + void PickBestParentForEachWord(EdgeSubset* st) const; struct Edge { Edge() : h(), m(), features(), edge_prob(prob_t::Zero()) {} @@ -45,20 +73,20 @@ class ArcFactoredForest { prob_t edge_prob; }; + // set eges_[*].features + void ExtractFeatures(const TaggedSentence& sentence, + const ArcFeatureFunctions& ffs); + const Edge& operator()(short h, short m) const { - assert(m > 0); - assert(m <= num_words_); - assert(h >= 0); - assert(h <= num_words_); - return h ? edges_(h - 1, m - 1) : root_edges_[m - 1]; + return h >= 0 ? edges_(h, m) : root_edges_[m]; } Edge& operator()(short h, short m) { - assert(m > 0); - assert(m <= num_words_); - assert(h >= 0); - assert(h <= num_words_); - return h ? edges_(h - 1, m - 1) : root_edges_[m - 1]; + return h >= 0 ? edges_(h, m) : root_edges_[m]; + } + + float Weight(short h, short m) const { + return log((*this)(h,m).edge_prob); } template <class V> @@ -76,7 +104,7 @@ class ArcFactoredForest { } private: - unsigned num_words_; + int num_words_; std::vector<Edge> root_edges_; Array2D<Edge> edges_; }; @@ -85,4 +113,12 @@ inline std::ostream& operator<<(std::ostream& os, const ArcFactoredForest::Edge& return os << "(" << edge.h << " < " << edge.m << ")"; } +inline std::ostream& operator<<(std::ostream& os, const EdgeSubset& ss) { + for (unsigned i = 0; i < ss.roots.size(); ++i) + os << "ROOT < " << ss.roots[i] << std::endl; + for (unsigned i = 0; i < ss.h_m_pairs.size(); ++i) + os << ss.h_m_pairs[i].first << " < " << ss.h_m_pairs[i].second << std::endl; + return os; +} + #endif diff --git a/rst_parser/arc_factored_marginals.cc b/rst_parser/arc_factored_marginals.cc new file mode 100644 index 00000000..3e8c9f86 --- /dev/null +++ b/rst_parser/arc_factored_marginals.cc @@ -0,0 +1,58 @@ +#include "arc_factored.h" + +#include <iostream> + +#include "config.h" + +using namespace std; + +#if HAVE_EIGEN + +#include <Eigen/Dense> +typedef Eigen::Matrix<prob_t, Eigen::Dynamic, Eigen::Dynamic> ArcMatrix; +typedef Eigen::Matrix<prob_t, Eigen::Dynamic, 1> RootVector; + +void ArcFactoredForest::EdgeMarginals(prob_t *plog_z) { + ArcMatrix A(num_words_,num_words_); + RootVector r(num_words_); + for (int h = 0; h < num_words_; ++h) { + for (int m = 0; m < num_words_; ++m) { + if (h != m) + A(h,m) = edges_(h,m).edge_prob; + else + A(h,m) = prob_t::Zero(); + } + r(h) = root_edges_[h].edge_prob; + } + + ArcMatrix L = -A; + L.diagonal() = A.colwise().sum(); + L.row(0) = r; + ArcMatrix Linv = L.inverse(); + if (plog_z) *plog_z = Linv.determinant(); + RootVector rootMarginals = r.cwiseProduct(Linv.col(0)); + static const prob_t ZERO(0); + static const prob_t ONE(1); +// ArcMatrix T = Linv; + for (int h = 0; h < num_words_; ++h) { + for (int m = 0; m < num_words_; ++m) { + const prob_t marginal = (m == 0 ? ZERO : ONE) * A(h,m) * Linv(m,m) - + (h == 0 ? ZERO : ONE) * A(h,m) * Linv(m,h); + edges_(h,m).edge_prob = marginal; +// T(h,m) = marginal; + } + root_edges_[h].edge_prob = rootMarginals(h); + } +// cerr << "ROOT MARGINALS: " << rootMarginals.transpose() << endl; +// cerr << "M:\n" << T << endl; +} + +#else + +void ArcFactoredForest::EdgeMarginals(prob_t *) { + cerr << "EdgeMarginals() requires --with-eigen!\n"; + abort(); +} + +#endif + diff --git a/rst_parser/arc_ff.cc b/rst_parser/arc_ff.cc new file mode 100644 index 00000000..c4e5aa17 --- /dev/null +++ b/rst_parser/arc_ff.cc @@ -0,0 +1,183 @@ +#include "arc_ff.h" + +#include <iostream> +#include <sstream> + +#include "stringlib.h" +#include "tdict.h" +#include "fdict.h" +#include "sentence_metadata.h" + +using namespace std; + +struct ArcFFImpl { + ArcFFImpl() : kROOT("ROOT"), kLEFT_POS("LEFT"), kRIGHT_POS("RIGHT") {} + const string kROOT; + const string kLEFT_POS; + const string kRIGHT_POS; + map<WordID, vector<int> > pcs; + + void PrepareForInput(const TaggedSentence& sent) { + pcs.clear(); + for (int i = 0; i < sent.pos.size(); ++i) + pcs[sent.pos[i]].resize(1, 0); + pcs[sent.pos[0]][0] = 1; + for (int i = 1; i < sent.pos.size(); ++i) { + const WordID posi = sent.pos[i]; + for (map<WordID, vector<int> >::iterator j = pcs.begin(); j != pcs.end(); ++j) { + const WordID posj = j->first; + vector<int>& cs = j->second; + cs.push_back(cs.back() + (posj == posi ? 1 : 0)); + } + } + } + + template <typename A> + static void Fire(SparseVector<weight_t>* v, const A& a) { + ostringstream os; + os << a; + v->set_value(FD::Convert(os.str()), 1); + } + + template <typename A, typename B> + static void Fire(SparseVector<weight_t>* v, const A& a, const B& b) { + ostringstream os; + os << a << ':' << b; + v->set_value(FD::Convert(os.str()), 1); + } + + template <typename A, typename B, typename C> + static void Fire(SparseVector<weight_t>* v, const A& a, const B& b, const C& c) { + ostringstream os; + os << a << ':' << b << '_' << c; + v->set_value(FD::Convert(os.str()), 1); + } + + template <typename A, typename B, typename C, typename D> + static void Fire(SparseVector<weight_t>* v, const A& a, const B& b, const C& c, const D& d) { + ostringstream os; + os << a << ':' << b << '_' << c << '_' << d; + v->set_value(FD::Convert(os.str()), 1); + } + + template <typename A, typename B, typename C, typename D, typename E> + static void Fire(SparseVector<weight_t>* v, const A& a, const B& b, const C& c, const D& d, const E& e) { + ostringstream os; + os << a << ':' << b << '_' << c << '_' << d << '_' << e; + v->set_value(FD::Convert(os.str()), 1); + } + + static void AddConjoin(const SparseVector<double>& v, const string& feat, SparseVector<double>* pf) { + for (SparseVector<double>::const_iterator it = v.begin(); it != v.end(); ++it) + pf->set_value(FD::Convert(FD::Convert(it->first) + "_" + feat), it->second); + } + + static inline string Fixup(const string& str) { + string res = LowercaseString(str); + if (res.size() < 6) return res; + return res.substr(0, 5) + "*"; + } + + static inline string Suffix(const string& str) { + if (str.size() < 4) return ""; else return str.substr(str.size() - 3); + } + + void EdgeFeatures(const TaggedSentence& sent, + short h, + short m, + SparseVector<weight_t>* features) const { + const bool is_root = (h == -1); + const string head_word = (is_root ? kROOT : Fixup(TD::Convert(sent.words[h]))); + int num_words = sent.words.size(); + const string& head_pos = (is_root ? kROOT : TD::Convert(sent.pos[h])); + const string mod_word = Fixup(TD::Convert(sent.words[m])); + const string& mod_pos = TD::Convert(sent.pos[m]); + const string& mod_pos_L = (m > 0 ? TD::Convert(sent.pos[m-1]) : kLEFT_POS); + const string& mod_pos_R = (m < sent.pos.size() - 1 ? TD::Convert(sent.pos[m]) : kRIGHT_POS); + const bool bdir = m < h; + const string dir = (bdir ? "MLeft" : "MRight"); + int v = m - h; + if (v < 0) { + v= -1 - int(log(-v) / log(1.6)); + } else { + v= int(log(v) / log(1.6)) + 1; + } + ostringstream os; + if (v < 0) os << "LenL" << -v; else os << "LenR" << v; + const string lenstr = os.str(); + Fire(features, dir); + Fire(features, lenstr); + // dir, lenstr + if (is_root) { + Fire(features, "wROOT", mod_word); + Fire(features, "pROOT", mod_pos); + Fire(features, "wpROOT", mod_word, mod_pos); + Fire(features, "DROOT", mod_pos, lenstr); + Fire(features, "LROOT", mod_pos_L); + Fire(features, "RROOT", mod_pos_R); + Fire(features, "LROOT", mod_pos_L, mod_pos); + Fire(features, "RROOT", mod_pos_R, mod_pos); + Fire(features, "LDist", m); + Fire(features, "RDist", num_words - m); + } else { // not root + const string& head_pos_L = (h > 0 ? TD::Convert(sent.pos[h-1]) : kLEFT_POS); + const string& head_pos_R = (h < sent.pos.size() - 1 ? TD::Convert(sent.pos[h]) : kRIGHT_POS); + SparseVector<double> fv; + SparseVector<double>* f = &fv; + Fire(f, "H", head_pos); + Fire(f, "M", mod_pos); + Fire(f, "HM", head_pos, mod_pos); + + // surrounders + Fire(f, "posLL", head_pos, mod_pos, head_pos_L, mod_pos_L); + Fire(f, "posRR", head_pos, mod_pos, head_pos_R, mod_pos_R); + Fire(f, "posLR", head_pos, mod_pos, head_pos_L, mod_pos_R); + Fire(f, "posRL", head_pos, mod_pos, head_pos_R, mod_pos_L); + + // between features + int left = min(h,m); + int right = max(h,m); + if (right - left >= 2) { + if (bdir) --right; else ++left; + for (map<WordID, vector<int> >::const_iterator it = pcs.begin(); it != pcs.end(); ++it) { + if (it->second[left] != it->second[right]) { + Fire(f, "BT", head_pos, TD::Convert(it->first), mod_pos); + } + } + } + + Fire(f, "wH", head_word); + Fire(f, "wM", mod_word); + Fire(f, "wpH", head_word, head_pos); + Fire(f, "wpM", mod_word, mod_pos); + Fire(f, "pHwM", head_pos, mod_word); + Fire(f, "wHpM", head_word, mod_pos); + + Fire(f, "wHM", head_word, mod_word); + Fire(f, "pHMwH", head_pos, mod_pos, head_word); + Fire(f, "pHMwM", head_pos, mod_pos, mod_word); + Fire(f, "wHMpH", head_word, mod_word, head_pos); + Fire(f, "wHMpM", head_word, mod_word, mod_pos); + Fire(f, "wHMpHM", head_word, mod_word, head_pos, mod_pos); + + AddConjoin(fv, dir, features); + AddConjoin(fv, lenstr, features); + (*features) += fv; + } + } +}; + +ArcFeatureFunctions::ArcFeatureFunctions() : pimpl(new ArcFFImpl) {} +ArcFeatureFunctions::~ArcFeatureFunctions() { delete pimpl; } + +void ArcFeatureFunctions::PrepareForInput(const TaggedSentence& sentence) { + pimpl->PrepareForInput(sentence); +} + +void ArcFeatureFunctions::EdgeFeatures(const TaggedSentence& sentence, + short h, + short m, + SparseVector<weight_t>* features) const { + pimpl->EdgeFeatures(sentence, h, m, features); +} + diff --git a/rst_parser/arc_ff.h b/rst_parser/arc_ff.h new file mode 100644 index 00000000..52f311d2 --- /dev/null +++ b/rst_parser/arc_ff.h @@ -0,0 +1,28 @@ +#ifndef _ARC_FF_H_ +#define _ARC_FF_H_ + +#include <string> +#include "sparse_vector.h" +#include "weights.h" +#include "arc_factored.h" + +struct TaggedSentence; +struct ArcFFImpl; +class ArcFeatureFunctions { + public: + ArcFeatureFunctions(); + ~ArcFeatureFunctions(); + + // called once, per input, before any calls to EdgeFeatures + // used to initialize sentence-specific data structures + void PrepareForInput(const TaggedSentence& sentence); + + void EdgeFeatures(const TaggedSentence& sentence, + short h, + short m, + SparseVector<weight_t>* features) const; + private: + ArcFFImpl* pimpl; +}; + +#endif diff --git a/rst_parser/dep_training.cc b/rst_parser/dep_training.cc new file mode 100644 index 00000000..ef97798b --- /dev/null +++ b/rst_parser/dep_training.cc @@ -0,0 +1,76 @@ +#include "dep_training.h" + +#include <vector> +#include <iostream> + +#include "stringlib.h" +#include "filelib.h" +#include "tdict.h" +#include "picojson.h" + +using namespace std; + +static void ParseInstance(const string& line, int start, TrainingInstance* out, int lc = 0) { + picojson::value obj; + string err; + picojson::parse(obj, line.begin() + start, line.end(), &err); + if (err.size() > 0) { cerr << "JSON parse error in " << lc << ": " << err << endl; abort(); } + TrainingInstance& cur = *out; + TaggedSentence& ts = cur.ts; + EdgeSubset& tree = cur.tree; + ts.pos.clear(); + ts.words.clear(); + tree.roots.clear(); + tree.h_m_pairs.clear(); + assert(obj.is<picojson::object>()); + const picojson::object& d = obj.get<picojson::object>(); + const picojson::array& ta = d.find("tokens")->second.get<picojson::array>(); + for (unsigned i = 0; i < ta.size(); ++i) { + ts.words.push_back(TD::Convert(ta[i].get<picojson::array>()[0].get<string>())); + ts.pos.push_back(TD::Convert(ta[i].get<picojson::array>()[1].get<string>())); + } + if (d.find("deps") != d.end()) { + const picojson::array& da = d.find("deps")->second.get<picojson::array>(); + for (unsigned i = 0; i < da.size(); ++i) { + const picojson::array& thm = da[i].get<picojson::array>(); + // get dep type here + short h = thm[2].get<double>(); + short m = thm[1].get<double>(); + if (h < 0) + tree.roots.push_back(m); + else + tree.h_m_pairs.push_back(make_pair(h,m)); + } + } + //cerr << TD::GetString(ts.words) << endl << TD::GetString(ts.pos) << endl << tree << endl; +} + +bool TrainingInstance::ReadInstance(std::istream* in, TrainingInstance* instance) { + string line; + if (!getline(*in, line)) return false; + size_t pos = line.rfind('\t'); + assert(pos != string::npos); + static int lc = 0; ++lc; + ParseInstance(line, pos + 1, instance, lc); + return true; +} + +void TrainingInstance::ReadTrainingCorpus(const string& fname, vector<TrainingInstance>* corpus, int rank, int size) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + int lc = 0; + bool flag = false; + while(getline(in, line)) { + ++lc; + if ((lc-1) % size != rank) continue; + if (rank == 0 && lc % 10 == 0) { cerr << '.' << flush; flag = true; } + if (rank == 0 && lc % 400 == 0) { cerr << " [" << lc << "]\n"; flag = false; } + size_t pos = line.rfind('\t'); + assert(pos != string::npos); + corpus->push_back(TrainingInstance()); + ParseInstance(line, pos + 1, &corpus->back(), lc); + } + if (flag) cerr << "\nRead " << lc << " training instances\n"; +} + diff --git a/rst_parser/dep_training.h b/rst_parser/dep_training.h new file mode 100644 index 00000000..3eeee22e --- /dev/null +++ b/rst_parser/dep_training.h @@ -0,0 +1,19 @@ +#ifndef _DEP_TRAINING_H_ +#define _DEP_TRAINING_H_ + +#include <iostream> +#include <string> +#include <vector> +#include "arc_factored.h" +#include "weights.h" + +struct TrainingInstance { + TaggedSentence ts; + EdgeSubset tree; + SparseVector<weight_t> features; + // reads a "Jsent" formatted dependency file + static bool ReadInstance(std::istream* in, TrainingInstance* instance); // returns false at EOF + static void ReadTrainingCorpus(const std::string& fname, std::vector<TrainingInstance>* corpus, int rank = 0, int size = 1); +}; + +#endif diff --git a/rst_parser/global_ff.cc b/rst_parser/global_ff.cc new file mode 100644 index 00000000..ae410875 --- /dev/null +++ b/rst_parser/global_ff.cc @@ -0,0 +1,44 @@ +#include "global_ff.h" + +#include <iostream> +#include <sstream> + +#include "tdict.h" + +using namespace std; + +struct GFFImpl { + void PrepareForInput(const TaggedSentence& sentence) { + } + void Features(const TaggedSentence& sentence, + const EdgeSubset& tree, + SparseVector<double>* feats) const { + const vector<WordID>& words = sentence.words; + const vector<WordID>& tags = sentence.pos; + const vector<pair<short,short> >& hms = tree.h_m_pairs; + assert(words.size() == tags.size()); + vector<int> mods(words.size()); + for (int i = 0; i < hms.size(); ++i) { + mods[hms[i].first]++; // first = head, second = modifier + } + for (int i = 0; i < mods.size(); ++i) { + ostringstream os; + os << "NM:" << TD::Convert(tags[i]) << "_" << mods[i]; + feats->add_value(FD::Convert(os.str()), 1.0); + } + } +}; + +GlobalFeatureFunctions::GlobalFeatureFunctions() : pimpl(new GFFImpl) {} +GlobalFeatureFunctions::~GlobalFeatureFunctions() { delete pimpl; } + +void GlobalFeatureFunctions::PrepareForInput(const TaggedSentence& sentence) { + pimpl->PrepareForInput(sentence); +} + +void GlobalFeatureFunctions::Features(const TaggedSentence& sentence, + const EdgeSubset& tree, + SparseVector<double>* feats) const { + pimpl->Features(sentence, tree, feats); +} + diff --git a/rst_parser/global_ff.h b/rst_parser/global_ff.h new file mode 100644 index 00000000..d71d0fa1 --- /dev/null +++ b/rst_parser/global_ff.h @@ -0,0 +1,18 @@ +#ifndef _GLOBAL_FF_H_ +#define _GLOBAL_FF_H_ + +#include "arc_factored.h" + +struct GFFImpl; +struct GlobalFeatureFunctions { + GlobalFeatureFunctions(); + ~GlobalFeatureFunctions(); + void PrepareForInput(const TaggedSentence& sentence); + void Features(const TaggedSentence& sentence, + const EdgeSubset& tree, + SparseVector<double>* feats) const; + private: + GFFImpl* pimpl; +}; + +#endif diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc index 7b5af4c1..6332693e 100644 --- a/rst_parser/mst_train.cc +++ b/rst_parser/mst_train.cc @@ -1,12 +1,228 @@ #include "arc_factored.h" +#include <vector> #include <iostream> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> +// #define HAVE_THREAD 1 +#if HAVE_THREAD +#include <boost/thread.hpp> +#endif + +#include "arc_ff.h" +#include "stringlib.h" +#include "filelib.h" +#include "tdict.h" +#include "dep_training.h" +#include "optimize.h" +#include "weights.h" using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + string cfg_file; + opts.add_options() + ("training_data,t",po::value<string>()->default_value("-"), "File containing training data (jsent format)") + ("weights,w",po::value<string>(), "Optional starting weights") + ("output_every_i_iterations,I",po::value<unsigned>()->default_value(1), "Write weights every I iterations") + ("regularization_strength,C",po::value<double>()->default_value(1.0), "Regularization strength") +#ifdef HAVE_CMPH + ("cmph_perfect_feature_hash,h", po::value<string>(), "Load perfect hash function for features") +#endif +#if HAVE_THREAD + ("threads,T",po::value<unsigned>()->default_value(1), "Number of threads") +#endif + ("correction_buffers,m", po::value<int>()->default_value(10), "LBFGS correction buffers"); + po::options_description clo("Command line options"); + clo.add_options() + ("config,c", po::value<string>(&cfg_file), "Configuration file") + ("help,?", "Print this help message and exit"); + + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(dconfig_options).add(clo); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (cfg_file.size() > 0) { + ReadFile rf(cfg_file); + po::store(po::parse_config_file(*rf.stream(), dconfig_options), *conf); + } + if (conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +void AddFeatures(double prob, const SparseVector<double>& fmap, vector<double>* g) { + for (SparseVector<double>::const_iterator it = fmap.begin(); it != fmap.end(); ++it) + (*g)[it->first] += it->second * prob; +} + +double ApplyRegularizationTerms(const double C, + const vector<double>& weights, + vector<double>* g) { + assert(weights.size() == g->size()); + double reg = 0; + for (size_t i = 0; i < weights.size(); ++i) { +// const double prev_w_i = (i < prev_weights.size() ? prev_weights[i] : 0.0); + const double& w_i = weights[i]; + double& g_i = (*g)[i]; + reg += C * w_i * w_i; + g_i += 2 * C * w_i; + +// reg += T * (w_i - prev_w_i) * (w_i - prev_w_i); +// g_i += 2 * T * (w_i - prev_w_i); + } + return reg; +} + +struct GradientWorker { + GradientWorker(int f, + int t, + vector<double>* w, + vector<TrainingInstance>* c, + vector<ArcFactoredForest>* fs) : obj(), weights(*w), from(f), to(t), corpus(*c), forests(*fs), g(w->size()) {} + void operator()() { + int every = (to - from) / 20; + if (!every) every++; + for (int i = from; i < to; ++i) { + if ((from == 0) && (i + 1) % every == 0) cerr << '.' << flush; + const int num_words = corpus[i].ts.words.size(); + forests[i].Reweight(weights); + prob_t z; + forests[i].EdgeMarginals(&z); + obj -= log(z); + //cerr << " O = " << (-corpus[i].features.dot(weights)) << " D=" << -lz << " OO= " << (-corpus[i].features.dot(weights) - lz) << endl; + //cerr << " ZZ = " << zz << endl; + for (int h = -1; h < num_words; ++h) { + for (int m = 0; m < num_words; ++m) { + if (h == m) continue; + const ArcFactoredForest::Edge& edge = forests[i](h,m); + const SparseVector<weight_t>& fmap = edge.features; + double prob = edge.edge_prob.as_float(); + if (prob < -0.000001) { cerr << "Prob < 0: " << prob << endl; prob = 0; } + if (prob > 1.000001) { cerr << "Prob > 1: " << prob << endl; prob = 1; } + AddFeatures(prob, fmap, &g); + //mfm += fmap * prob; // DE + } + } + } + } + double obj; + vector<double>& weights; + const int from, to; + vector<TrainingInstance>& corpus; + vector<ArcFactoredForest>& forests; + vector<double> g; // local gradient +}; int main(int argc, char** argv) { - ArcFactoredForest af(5); - cerr << af(0,3) << endl; + int rank = 0; + int size = 1; + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + if (conf.count("cmph_perfect_feature_hash")) { + cerr << "Loading perfect hash function from " << conf["cmph_perfect_feature_hash"].as<string>() << " ...\n"; + FD::EnableHash(conf["cmph_perfect_feature_hash"].as<string>()); + cerr << " " << FD::NumFeats() << " features in map\n"; + } + ArcFeatureFunctions ffs; + vector<TrainingInstance> corpus; + TrainingInstance::ReadTrainingCorpus(conf["training_data"].as<string>(), &corpus, rank, size); + vector<weight_t> weights; + Weights::InitFromFile(conf["weights"].as<string>(), &weights); + vector<ArcFactoredForest> forests(corpus.size()); + SparseVector<double> empirical; + cerr << "Extracting features...\n"; + bool flag = false; + for (int i = 0; i < corpus.size(); ++i) { + TrainingInstance& cur = corpus[i]; + if (rank == 0 && (i+1) % 10 == 0) { cerr << '.' << flush; flag = true; } + if (rank == 0 && (i+1) % 400 == 0) { cerr << " [" << (i+1) << "]\n"; flag = false; } + ffs.PrepareForInput(cur.ts); + SparseVector<weight_t> efmap; + for (int j = 0; j < cur.tree.h_m_pairs.size(); ++j) { + efmap.clear(); + ffs.EdgeFeatures(cur.ts, cur.tree.h_m_pairs[j].first, + cur.tree.h_m_pairs[j].second, + &efmap); + cur.features += efmap; + } + for (int j = 0; j < cur.tree.roots.size(); ++j) { + efmap.clear(); + ffs.EdgeFeatures(cur.ts, -1, cur.tree.roots[j], &efmap); + cur.features += efmap; + } + empirical += cur.features; + forests[i].resize(cur.ts.words.size()); + forests[i].ExtractFeatures(cur.ts, ffs); + } + if (flag) cerr << endl; + //cerr << "EMP: " << empirical << endl; //DE + weights.resize(FD::NumFeats(), 0.0); + vector<weight_t> g(FD::NumFeats(), 0.0); + cerr << "features initialized\noptimizing...\n"; + boost::shared_ptr<BatchOptimizer> o; +#if HAVE_THREAD + unsigned threads = conf["threads"].as<unsigned>(); + if (threads > corpus.size()) threads = corpus.size(); +#else + const unsigned threads = 1; +#endif + int chunk = corpus.size() / threads; + o.reset(new LBFGSOptimizer(g.size(), conf["correction_buffers"].as<int>())); + int iterations = 1000; + for (int iter = 0; iter < iterations; ++iter) { + cerr << "ITERATION " << iter << " " << flush; + fill(g.begin(), g.end(), 0.0); + for (SparseVector<double>::const_iterator it = empirical.begin(); it != empirical.end(); ++it) + g[it->first] = -it->second; + double obj = -empirical.dot(weights); + vector<boost::shared_ptr<GradientWorker> > jobs; + for (int from = 0; from < corpus.size(); from += chunk) { + int to = from + chunk; + if (to > corpus.size()) to = corpus.size(); + jobs.push_back(boost::shared_ptr<GradientWorker>(new GradientWorker(from, to, &weights, &corpus, &forests))); + } +#if HAVE_THREAD + boost::thread_group tg; + for (int i = 0; i < threads; ++i) + tg.create_thread(boost::ref(*jobs[i])); + tg.join_all(); +#else + (*jobs[0])(); +#endif + for (int i = 0; i < threads; ++i) { + obj += jobs[i]->obj; + vector<double>& tg = jobs[i]->g; + for (unsigned j = 0; j < g.size(); ++j) + g[j] += tg[j]; + } + // SparseVector<double> mfm; //DE + //cerr << endl << "E: " << empirical << endl; // DE + //cerr << "M: " << mfm << endl; // DE + double r = ApplyRegularizationTerms(conf["regularization_strength"].as<double>(), weights, &g); + double gnorm = 0; + for (int i = 0; i < g.size(); ++i) + gnorm += g[i]*g[i]; + ostringstream ll; + ll << "ITER=" << (iter+1) << "\tOBJ=" << (obj+r) << "\t[F=" << obj << " R=" << r << "]\tGnorm=" << sqrt(gnorm); + cerr << ' ' << ll.str().substr(ll.str().find('\t')+1) << endl; + obj += r; + assert(obj >= 0); + o->Optimize(obj, g, &weights); + Weights::ShowLargestFeatures(weights); + const bool converged = o->HasConverged(); + const char* ofname = converged ? "weights.final.gz" : "weights.cur.gz"; + if (converged || ((iter+1) % conf["output_every_i_iterations"].as<unsigned>()) == 0) { + cerr << "writing..." << flush; + const string sl = ll.str(); + Weights::WriteToFile(ofname, weights, true, &sl); + cerr << "done" << endl; + } + if (converged) { cerr << "CONVERGED\n"; break; } + } return 0; } diff --git a/rst_parser/picojson.h b/rst_parser/picojson.h new file mode 100644 index 00000000..bdb26057 --- /dev/null +++ b/rst_parser/picojson.h @@ -0,0 +1,979 @@ +/* + * Copyright 2009-2010 Cybozu Labs, Inc. + * Copyright 2011 Kazuho Oku + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY CYBOZU LABS, INC. ``AS IS'' AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO + * EVENT SHALL CYBOZU LABS, INC. OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, + * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF + * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * The views and conclusions contained in the software and documentation are + * those of the authors and should not be interpreted as representing official + * policies, either expressed or implied, of Cybozu Labs, Inc. + * + */ +#ifndef picojson_h +#define picojson_h + +#include <cassert> +#include <cmath> +#include <cstdio> +#include <cstdlib> +#include <cstring> +#include <iostream> +#include <iterator> +#include <map> +#include <string> +#include <vector> + +#ifdef _MSC_VER + #define SNPRINTF _snprintf_s + #pragma warning(push) + #pragma warning(disable : 4244) // conversion from int to char +#else + #define SNPRINTF snprintf +#endif + +namespace picojson { + + enum { + null_type, + boolean_type, + number_type, + string_type, + array_type, + object_type + }; + + struct null {}; + + class value { + public: + typedef std::vector<value> array; + typedef std::map<std::string, value> object; + protected: + int type_; + union { + bool boolean_; + double number_; + std::string* string_; + array* array_; + object* object_; + }; + public: + value(); + value(int type, bool); + explicit value(bool b); + explicit value(double n); + explicit value(const std::string& s); + explicit value(const array& a); + explicit value(const object& o); + explicit value(const char* s); + value(const char* s, size_t len); + ~value(); + value(const value& x); + value& operator=(const value& x); + template <typename T> bool is() const; + template <typename T> const T& get() const; + template <typename T> T& get(); + bool evaluate_as_boolean() const; + const value& get(size_t idx) const; + const value& get(const std::string& key) const; + bool contains(size_t idx) const; + bool contains(const std::string& key) const; + std::string to_str() const; + template <typename Iter> void serialize(Iter os) const; + std::string serialize() const; + private: + template <typename T> value(const T*); // intentionally defined to block implicit conversion of pointer to bool + }; + + typedef value::array array; + typedef value::object object; + + inline value::value() : type_(null_type) {} + + inline value::value(int type, bool) : type_(type) { + switch (type) { +#define INIT(p, v) case p##type: p = v; break + INIT(boolean_, false); + INIT(number_, 0.0); + INIT(string_, new std::string()); + INIT(array_, new array()); + INIT(object_, new object()); +#undef INIT + default: break; + } + } + + inline value::value(bool b) : type_(boolean_type) { + boolean_ = b; + } + + inline value::value(double n) : type_(number_type) { + number_ = n; + } + + inline value::value(const std::string& s) : type_(string_type) { + string_ = new std::string(s); + } + + inline value::value(const array& a) : type_(array_type) { + array_ = new array(a); + } + + inline value::value(const object& o) : type_(object_type) { + object_ = new object(o); + } + + inline value::value(const char* s) : type_(string_type) { + string_ = new std::string(s); + } + + inline value::value(const char* s, size_t len) : type_(string_type) { + string_ = new std::string(s, len); + } + + inline value::~value() { + switch (type_) { +#define DEINIT(p) case p##type: delete p; break + DEINIT(string_); + DEINIT(array_); + DEINIT(object_); +#undef DEINIT + default: break; + } + } + + inline value::value(const value& x) : type_(x.type_) { + switch (type_) { +#define INIT(p, v) case p##type: p = v; break + INIT(boolean_, x.boolean_); + INIT(number_, x.number_); + INIT(string_, new std::string(*x.string_)); + INIT(array_, new array(*x.array_)); + INIT(object_, new object(*x.object_)); +#undef INIT + default: break; + } + } + + inline value& value::operator=(const value& x) { + if (this != &x) { + this->~value(); + new (this) value(x); + } + return *this; + } + +#define IS(ctype, jtype) \ + template <> inline bool value::is<ctype>() const { \ + return type_ == jtype##_type; \ + } + IS(null, null) + IS(bool, boolean) + IS(int, number) + IS(double, number) + IS(std::string, string) + IS(array, array) + IS(object, object) +#undef IS + +#define GET(ctype, var) \ + template <> inline const ctype& value::get<ctype>() const { \ + assert("type mismatch! call vis<type>() before get<type>()" \ + && is<ctype>()); \ + return var; \ + } \ + template <> inline ctype& value::get<ctype>() { \ + assert("type mismatch! call is<type>() before get<type>()" \ + && is<ctype>()); \ + return var; \ + } + GET(bool, boolean_) + GET(double, number_) + GET(std::string, *string_) + GET(array, *array_) + GET(object, *object_) +#undef GET + + inline bool value::evaluate_as_boolean() const { + switch (type_) { + case null_type: + return false; + case boolean_type: + return boolean_; + case number_type: + return number_ != 0; + case string_type: + return ! string_->empty(); + default: + return true; + } + } + + inline const value& value::get(size_t idx) const { + static value s_null; + assert(is<array>()); + return idx < array_->size() ? (*array_)[idx] : s_null; + } + + inline const value& value::get(const std::string& key) const { + static value s_null; + assert(is<object>()); + object::const_iterator i = object_->find(key); + return i != object_->end() ? i->second : s_null; + } + + inline bool value::contains(size_t idx) const { + assert(is<array>()); + return idx < array_->size(); + } + + inline bool value::contains(const std::string& key) const { + assert(is<object>()); + object::const_iterator i = object_->find(key); + return i != object_->end(); + } + + inline std::string value::to_str() const { + switch (type_) { + case null_type: return "null"; + case boolean_type: return boolean_ ? "true" : "false"; + case number_type: { + char buf[256]; + double tmp; + SNPRINTF(buf, sizeof(buf), modf(number_, &tmp) == 0 ? "%.f" : "%f", number_); + return buf; + } + case string_type: return *string_; + case array_type: return "array"; + case object_type: return "object"; + default: assert(0); +#ifdef _MSC_VER + __assume(0); +#endif + } + } + + template <typename Iter> void copy(const std::string& s, Iter oi) { + std::copy(s.begin(), s.end(), oi); + } + + template <typename Iter> void serialize_str(const std::string& s, Iter oi) { + *oi++ = '"'; + for (std::string::const_iterator i = s.begin(); i != s.end(); ++i) { + switch (*i) { +#define MAP(val, sym) case val: copy(sym, oi); break + MAP('"', "\\\""); + MAP('\\', "\\\\"); + MAP('/', "\\/"); + MAP('\b', "\\b"); + MAP('\f', "\\f"); + MAP('\n', "\\n"); + MAP('\r', "\\r"); + MAP('\t', "\\t"); +#undef MAP + default: + if ((unsigned char)*i < 0x20 || *i == 0x7f) { + char buf[7]; + SNPRINTF(buf, sizeof(buf), "\\u%04x", *i & 0xff); + copy(buf, buf + 6, oi); + } else { + *oi++ = *i; + } + break; + } + } + *oi++ = '"'; + } + + template <typename Iter> void value::serialize(Iter oi) const { + switch (type_) { + case string_type: + serialize_str(*string_, oi); + break; + case array_type: { + *oi++ = '['; + for (array::const_iterator i = array_->begin(); i != array_->end(); ++i) { + if (i != array_->begin()) { + *oi++ = ','; + } + i->serialize(oi); + } + *oi++ = ']'; + break; + } + case object_type: { + *oi++ = '{'; + for (object::const_iterator i = object_->begin(); + i != object_->end(); + ++i) { + if (i != object_->begin()) { + *oi++ = ','; + } + serialize_str(i->first, oi); + *oi++ = ':'; + i->second.serialize(oi); + } + *oi++ = '}'; + break; + } + default: + copy(to_str(), oi); + break; + } + } + + inline std::string value::serialize() const { + std::string s; + serialize(std::back_inserter(s)); + return s; + } + + template <typename Iter> class input { + protected: + Iter cur_, end_; + int last_ch_; + bool ungot_; + int line_; + public: + input(const Iter& first, const Iter& last) : cur_(first), end_(last), last_ch_(-1), ungot_(false), line_(1) {} + int getc() { + if (ungot_) { + ungot_ = false; + return last_ch_; + } + if (cur_ == end_) { + last_ch_ = -1; + return -1; + } + if (last_ch_ == '\n') { + line_++; + } + last_ch_ = *cur_++ & 0xff; + return last_ch_; + } + void ungetc() { + if (last_ch_ != -1) { + assert(! ungot_); + ungot_ = true; + } + } + Iter cur() const { return cur_; } + int line() const { return line_; } + void skip_ws() { + while (1) { + int ch = getc(); + if (! (ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r')) { + ungetc(); + break; + } + } + } + int expect(int expect) { + skip_ws(); + if (getc() != expect) { + ungetc(); + return false; + } + return true; + } + bool match(const std::string& pattern) { + for (std::string::const_iterator pi(pattern.begin()); + pi != pattern.end(); + ++pi) { + if (getc() != *pi) { + ungetc(); + return false; + } + } + return true; + } + }; + + template<typename Iter> inline int _parse_quadhex(input<Iter> &in) { + int uni_ch = 0, hex; + for (int i = 0; i < 4; i++) { + if ((hex = in.getc()) == -1) { + return -1; + } + if ('0' <= hex && hex <= '9') { + hex -= '0'; + } else if ('A' <= hex && hex <= 'F') { + hex -= 'A' - 0xa; + } else if ('a' <= hex && hex <= 'f') { + hex -= 'a' - 0xa; + } else { + in.ungetc(); + return -1; + } + uni_ch = uni_ch * 16 + hex; + } + return uni_ch; + } + + template<typename String, typename Iter> inline bool _parse_codepoint(String& out, input<Iter>& in) { + int uni_ch; + if ((uni_ch = _parse_quadhex(in)) == -1) { + return false; + } + if (0xd800 <= uni_ch && uni_ch <= 0xdfff) { + if (0xdc00 <= uni_ch) { + // a second 16-bit of a surrogate pair appeared + return false; + } + // first 16-bit of surrogate pair, get the next one + if (in.getc() != '\\' || in.getc() != 'u') { + in.ungetc(); + return false; + } + int second = _parse_quadhex(in); + if (! (0xdc00 <= second && second <= 0xdfff)) { + return false; + } + uni_ch = ((uni_ch - 0xd800) << 10) | ((second - 0xdc00) & 0x3ff); + uni_ch += 0x10000; + } + if (uni_ch < 0x80) { + out.push_back(uni_ch); + } else { + if (uni_ch < 0x800) { + out.push_back(0xc0 | (uni_ch >> 6)); + } else { + if (uni_ch < 0x10000) { + out.push_back(0xe0 | (uni_ch >> 12)); + } else { + out.push_back(0xf0 | (uni_ch >> 18)); + out.push_back(0x80 | ((uni_ch >> 12) & 0x3f)); + } + out.push_back(0x80 | ((uni_ch >> 6) & 0x3f)); + } + out.push_back(0x80 | (uni_ch & 0x3f)); + } + return true; + } + + template<typename String, typename Iter> inline bool _parse_string(String& out, input<Iter>& in) { + while (1) { + int ch = in.getc(); + if (ch < ' ') { + in.ungetc(); + return false; + } else if (ch == '"') { + return true; + } else if (ch == '\\') { + if ((ch = in.getc()) == -1) { + return false; + } + switch (ch) { +#define MAP(sym, val) case sym: out.push_back(val); break + MAP('"', '\"'); + MAP('\\', '\\'); + MAP('/', '/'); + MAP('b', '\b'); + MAP('f', '\f'); + MAP('n', '\n'); + MAP('r', '\r'); + MAP('t', '\t'); +#undef MAP + case 'u': + if (! _parse_codepoint(out, in)) { + return false; + } + break; + default: + return false; + } + } else { + out.push_back(ch); + } + } + return false; + } + + template <typename Context, typename Iter> inline bool _parse_array(Context& ctx, input<Iter>& in) { + if (! ctx.parse_array_start()) { + return false; + } + if (in.expect(']')) { + return true; + } + size_t idx = 0; + do { + if (! ctx.parse_array_item(in, idx)) { + return false; + } + idx++; + } while (in.expect(',')); + return in.expect(']'); + } + + template <typename Context, typename Iter> inline bool _parse_object(Context& ctx, input<Iter>& in) { + if (! ctx.parse_object_start()) { + return false; + } + if (in.expect('}')) { + return true; + } + do { + std::string key; + if (! in.expect('"') + || ! _parse_string(key, in) + || ! in.expect(':')) { + return false; + } + if (! ctx.parse_object_item(in, key)) { + return false; + } + } while (in.expect(',')); + return in.expect('}'); + } + + template <typename Iter> inline bool _parse_number(double& out, input<Iter>& in) { + std::string num_str; + while (1) { + int ch = in.getc(); + if (('0' <= ch && ch <= '9') || ch == '+' || ch == '-' || ch == '.' + || ch == 'e' || ch == 'E') { + num_str.push_back(ch); + } else { + in.ungetc(); + break; + } + } + char* endp; + out = strtod(num_str.c_str(), &endp); + return endp == num_str.c_str() + num_str.size(); + } + + template <typename Context, typename Iter> inline bool _parse(Context& ctx, input<Iter>& in) { + in.skip_ws(); + int ch = in.getc(); + switch (ch) { +#define IS(ch, text, op) case ch: \ + if (in.match(text) && op) { \ + return true; \ + } else { \ + return false; \ + } + IS('n', "ull", ctx.set_null()); + IS('f', "alse", ctx.set_bool(false)); + IS('t', "rue", ctx.set_bool(true)); +#undef IS + case '"': + return ctx.parse_string(in); + case '[': + return _parse_array(ctx, in); + case '{': + return _parse_object(ctx, in); + default: + if (('0' <= ch && ch <= '9') || ch == '-') { + in.ungetc(); + double f; + if (_parse_number(f, in)) { + ctx.set_number(f); + return true; + } else { + return false; + } + } + break; + } + in.ungetc(); + return false; + } + + class deny_parse_context { + public: + bool set_null() { return false; } + bool set_bool(bool) { return false; } + bool set_number(double) { return false; } + template <typename Iter> bool parse_string(input<Iter>&) { return false; } + bool parse_array_start() { return false; } + template <typename Iter> bool parse_array_item(input<Iter>&, size_t) { + return false; + } + bool parse_object_start() { return false; } + template <typename Iter> bool parse_object_item(input<Iter>&, const std::string&) { + return false; + } + }; + + class default_parse_context { + protected: + value* out_; + public: + default_parse_context(value* out) : out_(out) {} + bool set_null() { + *out_ = value(); + return true; + } + bool set_bool(bool b) { + *out_ = value(b); + return true; + } + bool set_number(double f) { + *out_ = value(f); + return true; + } + template<typename Iter> bool parse_string(input<Iter>& in) { + *out_ = value(string_type, false); + return _parse_string(out_->get<std::string>(), in); + } + bool parse_array_start() { + *out_ = value(array_type, false); + return true; + } + template <typename Iter> bool parse_array_item(input<Iter>& in, size_t) { + array& a = out_->get<array>(); + a.push_back(value()); + default_parse_context ctx(&a.back()); + return _parse(ctx, in); + } + bool parse_object_start() { + *out_ = value(object_type, false); + return true; + } + template <typename Iter> bool parse_object_item(input<Iter>& in, const std::string& key) { + object& o = out_->get<object>(); + default_parse_context ctx(&o[key]); + return _parse(ctx, in); + } + private: + default_parse_context(const default_parse_context&); + default_parse_context& operator=(const default_parse_context&); + }; + + class null_parse_context { + public: + struct dummy_str { + void push_back(int) {} + }; + public: + null_parse_context() {} + bool set_null() { return true; } + bool set_bool(bool) { return true; } + bool set_number(double) { return true; } + template <typename Iter> bool parse_string(input<Iter>& in) { + dummy_str s; + return _parse_string(s, in); + } + bool parse_array_start() { return true; } + template <typename Iter> bool parse_array_item(input<Iter>& in, size_t) { + return _parse(*this, in); + } + bool parse_object_start() { return true; } + template <typename Iter> bool parse_object_item(input<Iter>& in, const std::string&) { + return _parse(*this, in); + } + private: + null_parse_context(const null_parse_context&); + null_parse_context& operator=(const null_parse_context&); + }; + + // obsolete, use the version below + template <typename Iter> inline std::string parse(value& out, Iter& pos, const Iter& last) { + std::string err; + pos = parse(out, pos, last, &err); + return err; + } + + template <typename Context, typename Iter> inline Iter _parse(Context& ctx, const Iter& first, const Iter& last, std::string* err) { + input<Iter> in(first, last); + if (! _parse(ctx, in) && err != NULL) { + char buf[64]; + SNPRINTF(buf, sizeof(buf), "syntax error at line %d near: ", in.line()); + *err = buf; + while (1) { + int ch = in.getc(); + if (ch == -1 || ch == '\n') { + break; + } else if (ch >= ' ') { + err->push_back(ch); + } + } + } + return in.cur(); + } + + template <typename Iter> inline Iter parse(value& out, const Iter& first, const Iter& last, std::string* err) { + default_parse_context ctx(&out); + return _parse(ctx, first, last, err); + } + + inline std::string parse(value& out, std::istream& is) { + std::string err; + parse(out, std::istreambuf_iterator<char>(is.rdbuf()), + std::istreambuf_iterator<char>(), &err); + return err; + } + + template <typename T> struct last_error_t { + static std::string s; + }; + template <typename T> std::string last_error_t<T>::s; + + inline void set_last_error(const std::string& s) { + last_error_t<bool>::s = s; + } + + inline const std::string& get_last_error() { + return last_error_t<bool>::s; + } + + inline bool operator==(const value& x, const value& y) { + if (x.is<null>()) + return y.is<null>(); +#define PICOJSON_CMP(type) \ + if (x.is<type>()) \ + return y.is<type>() && x.get<type>() == y.get<type>() + PICOJSON_CMP(bool); + PICOJSON_CMP(double); + PICOJSON_CMP(std::string); + PICOJSON_CMP(array); + PICOJSON_CMP(object); +#undef PICOJSON_CMP + assert(0); +#ifdef _MSC_VER + __assume(0); +#endif + return false; + } + + inline bool operator!=(const value& x, const value& y) { + return ! (x == y); + } +} + +inline std::istream& operator>>(std::istream& is, picojson::value& x) +{ + picojson::set_last_error(std::string()); + std::string err = picojson::parse(x, is); + if (! err.empty()) { + picojson::set_last_error(err); + is.setstate(std::ios::failbit); + } + return is; +} + +inline std::ostream& operator<<(std::ostream& os, const picojson::value& x) +{ + x.serialize(std::ostream_iterator<char>(os)); + return os; +} +#ifdef _MSC_VER + #pragma warning(pop) +#endif + +#endif +#ifdef TEST_PICOJSON +#ifdef _MSC_VER + #pragma warning(disable : 4127) // conditional expression is constant +#endif + +using namespace std; + +static void plan(int num) +{ + printf("1..%d\n", num); +} + +static bool success = true; + +static void ok(bool b, const char* name = "") +{ + static int n = 1; + if (! b) + success = false; + printf("%s %d - %s\n", b ? "ok" : "ng", n++, name); +} + +template <typename T> void is(const T& x, const T& y, const char* name = "") +{ + if (x == y) { + ok(true, name); + } else { + ok(false, name); + } +} + +#include <algorithm> + +int main(void) +{ + plan(75); + + // constructors +#define TEST(expr, expected) \ + is(picojson::value expr .serialize(), string(expected), "picojson::value" #expr) + + TEST( (true), "true"); + TEST( (false), "false"); + TEST( (42.0), "42"); + TEST( (string("hello")), "\"hello\""); + TEST( ("hello"), "\"hello\""); + TEST( ("hello", 4), "\"hell\""); + +#undef TEST + +#define TEST(in, type, cmp, serialize_test) { \ + picojson::value v; \ + const char* s = in; \ + string err = picojson::parse(v, s, s + strlen(s)); \ + ok(err.empty(), in " no error"); \ + ok(v.is<type>(), in " check type"); \ + is<type>(v.get<type>(), cmp, in " correct output"); \ + is(*s, '\0', in " read to eof"); \ + if (serialize_test) { \ + is(v.serialize(), string(in), in " serialize"); \ + } \ + } + TEST("false", bool, false, true); + TEST("true", bool, true, true); + TEST("90.5", double, 90.5, false); + TEST("\"hello\"", string, string("hello"), true); + TEST("\"\\\"\\\\\\/\\b\\f\\n\\r\\t\"", string, string("\"\\/\b\f\n\r\t"), + true); + TEST("\"\\u0061\\u30af\\u30ea\\u30b9\"", string, + string("a\xe3\x82\xaf\xe3\x83\xaa\xe3\x82\xb9"), false); + TEST("\"\\ud840\\udc0b\"", string, string("\xf0\xa0\x80\x8b"), false); +#undef TEST + +#define TEST(type, expr) { \ + picojson::value v; \ + const char *s = expr; \ + string err = picojson::parse(v, s, s + strlen(s)); \ + ok(err.empty(), "empty " #type " no error"); \ + ok(v.is<picojson::type>(), "empty " #type " check type"); \ + ok(v.get<picojson::type>().empty(), "check " #type " array size"); \ + } + TEST(array, "[]"); + TEST(object, "{}"); +#undef TEST + + { + picojson::value v; + const char *s = "[1,true,\"hello\"]"; + string err = picojson::parse(v, s, s + strlen(s)); + ok(err.empty(), "array no error"); + ok(v.is<picojson::array>(), "array check type"); + is(v.get<picojson::array>().size(), size_t(3), "check array size"); + ok(v.contains(0), "check contains array[0]"); + ok(v.get(0).is<double>(), "check array[0] type"); + is(v.get(0).get<double>(), 1.0, "check array[0] value"); + ok(v.contains(1), "check contains array[1]"); + ok(v.get(1).is<bool>(), "check array[1] type"); + ok(v.get(1).get<bool>(), "check array[1] value"); + ok(v.contains(2), "check contains array[2]"); + ok(v.get(2).is<string>(), "check array[2] type"); + is(v.get(2).get<string>(), string("hello"), "check array[2] value"); + ok(!v.contains(3), "check not contains array[3]"); + } + + { + picojson::value v; + const char *s = "{ \"a\": true }"; + string err = picojson::parse(v, s, s + strlen(s)); + ok(err.empty(), "object no error"); + ok(v.is<picojson::object>(), "object check type"); + is(v.get<picojson::object>().size(), size_t(1), "check object size"); + ok(v.contains("a"), "check contains property"); + ok(v.get("a").is<bool>(), "check bool property exists"); + is(v.get("a").get<bool>(), true, "check bool property value"); + is(v.serialize(), string("{\"a\":true}"), "serialize object"); + ok(!v.contains("z"), "check not contains property"); + } + +#define TEST(json, msg) do { \ + picojson::value v; \ + const char *s = json; \ + string err = picojson::parse(v, s, s + strlen(s)); \ + is(err, string("syntax error at line " msg), msg); \ + } while (0) + TEST("falsoa", "1 near: oa"); + TEST("{]", "1 near: ]"); + TEST("\n\bbell", "2 near: bell"); + TEST("\"abc\nd\"", "1 near: "); +#undef TEST + + { + picojson::value v1, v2; + const char *s; + string err; + s = "{ \"b\": true, \"a\": [1,2,\"three\"], \"d\": 2 }"; + err = picojson::parse(v1, s, s + strlen(s)); + s = "{ \"d\": 2.0, \"b\": true, \"a\": [1,2,\"three\"] }"; + err = picojson::parse(v2, s, s + strlen(s)); + ok((v1 == v2), "check == operator in deep comparison"); + } + + { + picojson::value v1, v2; + const char *s; + string err; + s = "{ \"b\": true, \"a\": [1,2,\"three\"], \"d\": 2 }"; + err = picojson::parse(v1, s, s + strlen(s)); + s = "{ \"d\": 2.0, \"a\": [1,\"three\"], \"b\": true }"; + err = picojson::parse(v2, s, s + strlen(s)); + ok((v1 != v2), "check != operator for array in deep comparison"); + } + + { + picojson::value v1, v2; + const char *s; + string err; + s = "{ \"b\": true, \"a\": [1,2,\"three\"], \"d\": 2 }"; + err = picojson::parse(v1, s, s + strlen(s)); + s = "{ \"d\": 2.0, \"a\": [1,2,\"three\"], \"b\": false }"; + err = picojson::parse(v2, s, s + strlen(s)); + ok((v1 != v2), "check != operator for object in deep comparison"); + } + + { + picojson::value v1, v2; + const char *s; + string err; + s = "{ \"b\": true, \"a\": [1,2,\"three\"], \"d\": 2 }"; + err = picojson::parse(v1, s, s + strlen(s)); + picojson::object& o = v1.get<picojson::object>(); + o.erase("b"); + picojson::array& a = o["a"].get<picojson::array>(); + picojson::array::iterator i; + i = std::remove(a.begin(), a.end(), picojson::value(std::string("three"))); + a.erase(i, a.end()); + s = "{ \"a\": [1,2], \"d\": 2 }"; + err = picojson::parse(v2, s, s + strlen(s)); + ok((v1 == v2), "check erase()"); + } + + ok(picojson::value(3.0).serialize() == "3", + "integral number should be serialized as a integer"); + + { + const char* s = "{ \"a\": [1,2], \"d\": 2 }"; + picojson::null_parse_context ctx; + string err; + picojson::_parse(ctx, s, s + strlen(s), &err); + ok(err.empty(), "null_parse_context"); + } + + return success ? 0 : 1; +} + +#endif diff --git a/rst_parser/rst.cc b/rst_parser/rst.cc index f6b295b3..bc91330b 100644 --- a/rst_parser/rst.cc +++ b/rst_parser/rst.cc @@ -2,6 +2,81 @@ using namespace std; -StochasticForest::StochasticForest(const ArcFactoredForest& af) { +// David B. Wilson. Generating Random Spanning Trees More Quickly than the Cover Time. +// this is an awesome algorithm +TreeSampler::TreeSampler(const ArcFactoredForest& af) : forest(af), usucc(af.size() + 1) { + // edges are directed from modifiers to heads, and finally to the root + vector<double> p; + for (int m = 1; m <= forest.size(); ++m) { +#if USE_ALIAS_SAMPLER + p.clear(); +#else + SampleSet<double>& ss = usucc[m]; +#endif + double z = 0; + for (int h = 0; h <= forest.size(); ++h) { + double u = forest(h-1,m-1).edge_prob.as_float(); + z += u; +#if USE_ALIAS_SAMPLER + p.push_back(u); +#else + ss.add(u); +#endif + } +#if USE_ALIAS_SAMPLER + for (int i = 0; i < p.size(); ++i) { p[i] /= z; } + usucc[m].Init(p); +#endif + } } +void TreeSampler::SampleRandomSpanningTree(EdgeSubset* tree, MT19937* prng) { + MT19937& rng = *prng; + const int r = 0; + bool success = false; + while (!success) { + int roots = 0; + tree->h_m_pairs.clear(); + tree->roots.clear(); + vector<int> next(forest.size() + 1, -1); + vector<char> in_tree(forest.size() + 1, 0); + in_tree[r] = 1; + //cerr << "Forest size: " << forest.size() << endl; + for (int i = 0; i <= forest.size(); ++i) { + //cerr << "Sampling starting at u=" << i << endl; + int u = i; + if (in_tree[u]) continue; + while(!in_tree[u]) { +#if USE_ALIAS_SAMPLER + next[u] = usucc[u].Draw(rng); +#else + next[u] = rng.SelectSample(usucc[u]); +#endif + u = next[u]; + } + u = i; + //cerr << (u-1); + int prev = u-1; + while(!in_tree[u]) { + in_tree[u] = true; + u = next[u]; + //cerr << " > " << (u-1); + if (u == r) { + ++roots; + tree->roots.push_back(prev); + } else { + tree->h_m_pairs.push_back(make_pair<short,short>(u-1,prev)); + } + prev = u-1; + } + //cerr << endl; + } + assert(roots > 0); + if (roots > 1) { + //cerr << "FAILURE\n"; + } else { + success = true; + } + } +}; + diff --git a/rst_parser/rst.h b/rst_parser/rst.h index 865871eb..8bf389f7 100644 --- a/rst_parser/rst.h +++ b/rst_parser/rst.h @@ -1,10 +1,21 @@ #ifndef _RST_H_ #define _RST_H_ +#include <vector> +#include "sampler.h" #include "arc_factored.h" +#include "alias_sampler.h" -struct StochasticForest { - explicit StochasticForest(const ArcFactoredForest& af); +struct TreeSampler { + explicit TreeSampler(const ArcFactoredForest& af); + void SampleRandomSpanningTree(EdgeSubset* tree, MT19937* rng); + const ArcFactoredForest& forest; +#define USE_ALIAS_SAMPLER 1 +#if USE_ALIAS_SAMPLER + std::vector<AliasSampler> usucc; +#else + std::vector<SampleSet<double> > usucc; +#endif }; #endif diff --git a/rst_parser/rst_parse.cc b/rst_parser/rst_parse.cc new file mode 100644 index 00000000..9c42a8f4 --- /dev/null +++ b/rst_parser/rst_parse.cc @@ -0,0 +1,111 @@ +#include "arc_factored.h" + +#include <vector> +#include <iostream> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "timing_stats.h" +#include "arc_ff.h" +#include "dep_training.h" +#include "stringlib.h" +#include "filelib.h" +#include "tdict.h" +#include "weights.h" +#include "rst.h" +#include "global_ff.h" + +using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + string cfg_file; + opts.add_options() + ("input,i",po::value<string>()->default_value("-"), "File containing test data (jsent format)") + ("q_weights,q",po::value<string>(), "Arc-factored weights for proposal distribution (mandatory)") + ("p_weights,p",po::value<string>(), "Weights for target distribution (optional)") + ("samples,n",po::value<unsigned>()->default_value(1000), "Number of samples"); + po::options_description clo("Command line options"); + clo.add_options() + ("config,c", po::value<string>(&cfg_file), "Configuration file") + ("help,?", "Print this help message and exit"); + + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(dconfig_options).add(clo); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (cfg_file.size() > 0) { + ReadFile rf(cfg_file); + po::store(po::parse_config_file(*rf.stream(), dconfig_options), *conf); + } + if (conf->count("help") || conf->count("q_weights") == 0) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + vector<weight_t> qweights, pweights; + Weights::InitFromFile(conf["q_weights"].as<string>(), &qweights); + if (conf.count("p_weights")) + Weights::InitFromFile(conf["p_weights"].as<string>(), &pweights); + const bool global = pweights.size() > 0; + ArcFeatureFunctions ffs; + GlobalFeatureFunctions gff; + ReadFile rf(conf["input"].as<string>()); + istream* in = rf.stream(); + TrainingInstance sent; + MT19937 rng; + int samples = conf["samples"].as<unsigned>(); + int totroot = 0, root_right = 0, tot = 0, cor = 0; + while(TrainingInstance::ReadInstance(in, &sent)) { + ffs.PrepareForInput(sent.ts); + if (global) gff.PrepareForInput(sent.ts); + ArcFactoredForest forest(sent.ts.pos.size()); + forest.ExtractFeatures(sent.ts, ffs); + forest.Reweight(qweights); + TreeSampler ts(forest); + double best_score = -numeric_limits<double>::infinity(); + EdgeSubset best_tree; + for (int n = 0; n < samples; ++n) { + EdgeSubset tree; + ts.SampleRandomSpanningTree(&tree, &rng); + SparseVector<double> qfeats, gfeats; + tree.ExtractFeatures(sent.ts, ffs, &qfeats); + double score = 0; + if (global) { + gff.Features(sent.ts, tree, &gfeats); + score = (qfeats + gfeats).dot(pweights); + } else { + score = qfeats.dot(qweights); + } + if (score > best_score) { + best_tree = tree; + best_score = score; + } + } + cerr << "BEST SCORE: " << best_score << endl; + cout << best_tree << endl; + const bool sent_has_ref = sent.tree.h_m_pairs.size() > 0; + if (sent_has_ref) { + map<pair<short,short>, bool> ref; + for (int i = 0; i < sent.tree.h_m_pairs.size(); ++i) + ref[sent.tree.h_m_pairs[i]] = true; + int ref_root = sent.tree.roots.front(); + if (ref_root == best_tree.roots.front()) { ++root_right; } + ++totroot; + for (int i = 0; i < best_tree.h_m_pairs.size(); ++i) { + if (ref[best_tree.h_m_pairs[i]]) { + ++cor; + } + ++tot; + } + } + } + cerr << "F = " << (double(cor + root_right) / (tot + totroot)) << endl; + return 0; +} + diff --git a/rst_parser/rst_test.cc b/rst_parser/rst_test.cc deleted file mode 100644 index e8fe706e..00000000 --- a/rst_parser/rst_test.cc +++ /dev/null @@ -1,33 +0,0 @@ -#include "arc_factored.h" - -#include <iostream> - -using namespace std; - -int main(int argc, char** argv) { - // John saw Mary - // (H -> M) - // (1 -> 2) 20 - // (1 -> 3) 3 - // (2 -> 1) 20 - // (2 -> 3) 30 - // (3 -> 2) 0 - // (3 -> 1) 11 - // (0, 2) 10 - // (0, 1) 9 - // (0, 3) 9 - ArcFactoredForest af(3); - af(1,2).edge_prob.logeq(20); - af(1,3).edge_prob.logeq(3); - af(2,1).edge_prob.logeq(20); - af(2,3).edge_prob.logeq(30); - af(3,2).edge_prob.logeq(0); - af(3,1).edge_prob.logeq(11); - af(0,2).edge_prob.logeq(10); - af(0,1).edge_prob.logeq(9); - af(0,3).edge_prob.logeq(9); - SpanningTree tree; - af.MaximumSpanningTree(&tree); - return 0; -} - diff --git a/rst_parser/rst_train.cc b/rst_parser/rst_train.cc new file mode 100644 index 00000000..9b730f3d --- /dev/null +++ b/rst_parser/rst_train.cc @@ -0,0 +1,144 @@ +#include "arc_factored.h" + +#include <vector> +#include <iostream> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "timing_stats.h" +#include "arc_ff.h" +#include "dep_training.h" +#include "stringlib.h" +#include "filelib.h" +#include "tdict.h" +#include "weights.h" +#include "rst.h" +#include "global_ff.h" + +using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + string cfg_file; + opts.add_options() + ("training_data,t",po::value<string>()->default_value("-"), "File containing training data (jsent format)") + ("q_weights,q",po::value<string>(), "Arc-factored weights for proposal distribution") + ("samples,n",po::value<unsigned>()->default_value(1000), "Number of samples"); + po::options_description clo("Command line options"); + clo.add_options() + ("config,c", po::value<string>(&cfg_file), "Configuration file") + ("help,?", "Print this help message and exit"); + + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(dconfig_options).add(clo); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (cfg_file.size() > 0) { + ReadFile rf(cfg_file); + po::store(po::parse_config_file(*rf.stream(), dconfig_options), *conf); + } + if (conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + vector<weight_t> qweights(FD::NumFeats(), 0.0); + Weights::InitFromFile(conf["q_weights"].as<string>(), &qweights); + vector<TrainingInstance> corpus; + ArcFeatureFunctions ffs; + GlobalFeatureFunctions gff; + TrainingInstance::ReadTrainingCorpus(conf["training_data"].as<string>(), &corpus); + vector<ArcFactoredForest> forests(corpus.size()); + vector<prob_t> zs(corpus.size()); + SparseVector<double> empirical; + bool flag = false; + for (int i = 0; i < corpus.size(); ++i) { + TrainingInstance& cur = corpus[i]; + if ((i+1) % 10 == 0) { cerr << '.' << flush; flag = true; } + if ((i+1) % 400 == 0) { cerr << " [" << (i+1) << "]\n"; flag = false; } + SparseVector<weight_t> efmap; + ffs.PrepareForInput(cur.ts); + gff.PrepareForInput(cur.ts); + for (int j = 0; j < cur.tree.h_m_pairs.size(); ++j) { + efmap.clear(); + ffs.EdgeFeatures(cur.ts, cur.tree.h_m_pairs[j].first, + cur.tree.h_m_pairs[j].second, + &efmap); + cur.features += efmap; + } + for (int j = 0; j < cur.tree.roots.size(); ++j) { + efmap.clear(); + ffs.EdgeFeatures(cur.ts, -1, cur.tree.roots[j], &efmap); + cur.features += efmap; + } + efmap.clear(); + gff.Features(cur.ts, cur.tree, &efmap); + cur.features += efmap; + empirical += cur.features; + forests[i].resize(cur.ts.words.size()); + forests[i].ExtractFeatures(cur.ts, ffs); + forests[i].Reweight(qweights); + forests[i].EdgeMarginals(&zs[i]); + zs[i] = prob_t::One() / zs[i]; + // cerr << zs[i] << endl; + forests[i].Reweight(qweights); // EdgeMarginals overwrites edge_prob + } + if (flag) cerr << endl; + MT19937 rng; + SparseVector<double> model_exp; + SparseVector<double> weights; + Weights::InitSparseVector(qweights, &weights); + int samples = conf["samples"].as<unsigned>(); + for (int i = 0; i < corpus.size(); ++i) { +#if 0 + forests[i].EdgeMarginals(); + model_exp.clear(); + for (int h = -1; h < num_words; ++h) { + for (int m = 0; m < num_words; ++m) { + if (h == m) continue; + const ArcFactoredForest::Edge& edge = forests[i](h,m); + const SparseVector<weight_t>& fmap = edge.features; + double prob = edge.edge_prob.as_float(); + model_exp += fmap * prob; + } + } + cerr << "TRUE EXP: " << model_exp << endl; + forests[i].Reweight(weights); +#endif + + TreeSampler ts(forests[i]); + prob_t zhat = prob_t::Zero(); + SparseVector<prob_t> sampled_exp; + for (int n = 0; n < samples; ++n) { + EdgeSubset tree; + ts.SampleRandomSpanningTree(&tree, &rng); + SparseVector<double> qfeats, gfeats; + tree.ExtractFeatures(corpus[i].ts, ffs, &qfeats); + prob_t u; u.logeq(qfeats.dot(qweights)); + const prob_t q = u / zs[i]; // proposal mass + gff.Features(corpus[i].ts, tree, &gfeats); + SparseVector<double> tot_feats = qfeats + gfeats; + u.logeq(tot_feats.dot(weights)); + prob_t w = u / q; + zhat += w; + for (SparseVector<double>::const_iterator it = tot_feats.begin(); it != tot_feats.end(); ++it) + sampled_exp.add_value(it->first, w * prob_t(it->second)); + } + sampled_exp /= zhat; + SparseVector<double> tot_m; + for (SparseVector<prob_t>::const_iterator it = sampled_exp.begin(); it != sampled_exp.end(); ++it) + tot_m.add_value(it->first, it->second.as_float()); + //cerr << "DIFF: " << (tot_m - corpus[i].features) << endl; + const double eta = 0.03; + weights -= (tot_m - corpus[i].features) * eta; + } + cerr << "WEIGHTS.\n"; + cerr << weights << endl; + return 0; +} + |