summaryrefslogtreecommitdiff
path: root/rst_parser
diff options
context:
space:
mode:
Diffstat (limited to 'rst_parser')
-rw-r--r--rst_parser/Makefile.am20
-rw-r--r--rst_parser/arc_factored.cc128
-rw-r--r--rst_parser/arc_factored.h82
-rw-r--r--rst_parser/arc_factored_marginals.cc58
-rw-r--r--rst_parser/arc_ff.cc183
-rw-r--r--rst_parser/arc_ff.h28
-rw-r--r--rst_parser/dep_training.cc76
-rw-r--r--rst_parser/dep_training.h19
-rw-r--r--rst_parser/global_ff.cc44
-rw-r--r--rst_parser/global_ff.h18
-rw-r--r--rst_parser/mst_train.cc220
-rw-r--r--rst_parser/picojson.h979
-rw-r--r--rst_parser/rst.cc77
-rw-r--r--rst_parser/rst.h15
-rw-r--r--rst_parser/rst_parse.cc111
-rw-r--r--rst_parser/rst_test.cc33
-rw-r--r--rst_parser/rst_train.cc144
17 files changed, 2159 insertions, 76 deletions
diff --git a/rst_parser/Makefile.am b/rst_parser/Makefile.am
index e97ab5c5..4977f584 100644
--- a/rst_parser/Makefile.am
+++ b/rst_parser/Makefile.am
@@ -1,19 +1,17 @@
bin_PROGRAMS = \
- mst_train
-
-noinst_PROGRAMS = \
- rst_test
-
-TESTS = rst_test
+ mst_train rst_train rst_parse
noinst_LIBRARIES = librst.a
-librst_a_SOURCES = arc_factored.cc rst.cc
+librst_a_SOURCES = arc_factored.cc arc_factored_marginals.cc rst.cc arc_ff.cc dep_training.cc global_ff.cc
mst_train_SOURCES = mst_train.cc
-mst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
+mst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a ../training/optimize.o -lz
+
+rst_train_SOURCES = rst_train.cc
+rst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
-rst_test_SOURCES = rst_test.cc
-rst_test_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
+rst_parse_SOURCES = rst_parse.cc
+rst_parse_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
-AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/training -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm
diff --git a/rst_parser/arc_factored.cc b/rst_parser/arc_factored.cc
index 1e75600b..74bf7516 100644
--- a/rst_parser/arc_factored.cc
+++ b/rst_parser/arc_factored.cc
@@ -1,31 +1,151 @@
#include "arc_factored.h"
#include <set>
+#include <tr1/unordered_set>
#include <boost/pending/disjoint_sets.hpp>
+#include <boost/functional/hash.hpp>
+
+#include "arc_ff.h"
using namespace std;
+using namespace std::tr1;
using namespace boost;
+void EdgeSubset::ExtractFeatures(const TaggedSentence& sentence,
+ const ArcFeatureFunctions& ffs,
+ SparseVector<double>* features) const {
+ SparseVector<weight_t> efmap;
+ for (int j = 0; j < h_m_pairs.size(); ++j) {
+ efmap.clear();
+ ffs.EdgeFeatures(sentence, h_m_pairs[j].first,
+ h_m_pairs[j].second,
+ &efmap);
+ (*features) += efmap;
+ }
+ for (int j = 0; j < roots.size(); ++j) {
+ efmap.clear();
+ ffs.EdgeFeatures(sentence, -1, roots[j], &efmap);
+ (*features) += efmap;
+ }
+}
+
+void ArcFactoredForest::ExtractFeatures(const TaggedSentence& sentence,
+ const ArcFeatureFunctions& ffs) {
+ for (int m = 0; m < num_words_; ++m) {
+ for (int h = 0; h < num_words_; ++h) {
+ ffs.EdgeFeatures(sentence, h, m, &edges_(h,m).features);
+ }
+ ffs.EdgeFeatures(sentence, -1, m, &root_edges_[m].features);
+ }
+}
+
+void ArcFactoredForest::PickBestParentForEachWord(EdgeSubset* st) const {
+ for (int m = 0; m < num_words_; ++m) {
+ int best_head = -2;
+ prob_t best_score;
+ for (int h = -1; h < num_words_; ++h) {
+ const Edge& edge = (*this)(h,m);
+ if (best_head < -1 || edge.edge_prob > best_score) {
+ best_score = edge.edge_prob;
+ best_head = h;
+ }
+ }
+ assert(best_head >= -1);
+ if (best_head >= 0)
+ st->h_m_pairs.push_back(make_pair<short,short>(best_head, m));
+ else
+ st->roots.push_back(m);
+ }
+}
+
+struct WeightedEdge {
+ WeightedEdge() : h(), m(), weight() {}
+ WeightedEdge(short hh, short mm, float w) : h(hh), m(mm), weight(w) {}
+ short h, m;
+ float weight;
+ inline bool operator==(const WeightedEdge& o) const {
+ return h == o.h && m == o.m && weight == o.weight;
+ }
+ inline bool operator!=(const WeightedEdge& o) const {
+ return h != o.h || m != o.m || weight != o.weight;
+ }
+};
+inline bool operator<(const WeightedEdge& l, const WeightedEdge& o) { return l.weight < o.weight; }
+inline size_t hash_value(const WeightedEdge& e) { return reinterpret_cast<const size_t&>(e); }
+
+
+struct PriorityQueue {
+ void push(const WeightedEdge& e) {}
+ const WeightedEdge& top() const {
+ static WeightedEdge w(1,2,3);
+ return w;
+ }
+ void pop() {}
+ void increment_all(float p) {}
+};
+
// based on Trajan 1977
-void ArcFactoredForest::MaximumSpanningTree(SpanningTree* st) const {
+void ArcFactoredForest::MaximumSpanningTree(EdgeSubset* st) const {
typedef disjoint_sets_with_storage<identity_property_map, identity_property_map,
find_with_full_path_compression> DisjointSet;
DisjointSet strongly(num_words_ + 1);
DisjointSet weakly(num_words_ + 1);
- set<unsigned> roots, h, rset;
- vector<pair<short, short> > enter(num_words_ + 1);
+ set<unsigned> roots, rset;
+ unordered_set<WeightedEdge, boost::hash<WeightedEdge> > h;
+ vector<PriorityQueue> qs(num_words_ + 1);
+ vector<WeightedEdge> enter(num_words_ + 1);
+ vector<unsigned> mins(num_words_ + 1);
+ const WeightedEdge kDUMMY(0,0,0.0f);
for (unsigned i = 0; i <= num_words_; ++i) {
+ if (i > 0) {
+ // I(i) incidence on i -- all incoming edges
+ for (unsigned j = 0; j <= num_words_; ++j) {
+ qs[i].push(WeightedEdge(j, i, Weight(j,i)));
+ }
+ }
strongly.make_set(i);
weakly.make_set(i);
roots.insert(i);
+ enter[i] = kDUMMY;
+ mins[i] = i;
}
while(!roots.empty()) {
set<unsigned>::iterator it = roots.begin();
const unsigned k = *it;
roots.erase(it);
cerr << "k=" << k << endl;
- pair<short,short> ij; // TODO = Max(k);
+ WeightedEdge ij = qs[k].top(); // MAX(k)
+ qs[k].pop();
+ if (ij.weight <= 0) {
+ rset.insert(k);
+ } else {
+ if (strongly.find_set(ij.h) == k) {
+ roots.insert(k);
+ } else {
+ h.insert(ij);
+ if (weakly.find_set(ij.h) != weakly.find_set(ij.m)) {
+ weakly.union_set(ij.h, ij.m);
+ enter[k] = ij;
+ } else {
+ unsigned vertex = 0;
+ float val = 99999999999;
+ WeightedEdge xy = ij;
+ while(xy != kDUMMY) {
+ if (xy.weight < val) {
+ val = xy.weight;
+ vertex = strongly.find_set(xy.m);
+ }
+ xy = enter[strongly.find_set(xy.h)];
+ }
+ qs[k].increment_all(val - ij.weight);
+ mins[k] = mins[vertex];
+ xy = enter[strongly.find_set(ij.h)];
+ while (xy != kDUMMY) {
+ }
+ }
+ }
+ }
}
}
diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h
index e99be482..c5481d80 100644
--- a/rst_parser/arc_factored.h
+++ b/rst_parser/arc_factored.h
@@ -5,37 +5,65 @@
#include <cassert>
#include <vector>
#include <utility>
+#include <boost/shared_ptr.hpp>
#include "array2d.h"
#include "sparse_vector.h"
#include "prob.h"
#include "weights.h"
+#include "wordid.h"
-struct SpanningTree {
- SpanningTree() : roots(1, -1) {}
+struct TaggedSentence {
+ std::vector<WordID> words;
+ std::vector<WordID> pos;
+};
+
+struct ArcFeatureFunctions;
+struct EdgeSubset {
+ EdgeSubset() {}
std::vector<short> roots; // unless multiroot trees are supported, this
// will have a single member
- std::vector<std::pair<short, short> > h_m_pairs;
+ std::vector<std::pair<short, short> > h_m_pairs; // h,m start at 0
+ // assumes ArcFeatureFunction::PrepareForInput has already been called
+ void ExtractFeatures(const TaggedSentence& sentence,
+ const ArcFeatureFunctions& ffs,
+ SparseVector<double>* features) const;
};
class ArcFactoredForest {
public:
- explicit ArcFactoredForest(short num_words) :
- num_words_(num_words),
- root_edges_(num_words),
- edges_(num_words, num_words) {
+ ArcFactoredForest() : num_words_() {}
+ explicit ArcFactoredForest(short num_words) : num_words_(num_words) {
+ resize(num_words);
+ }
+
+ unsigned size() const { return num_words_; }
+
+ void resize(unsigned num_words) {
+ num_words_ = num_words;
+ root_edges_.clear();
+ edges_.clear();
+ root_edges_.resize(num_words);
+ edges_.resize(num_words, num_words);
for (int h = 0; h < num_words; ++h) {
for (int m = 0; m < num_words; ++m) {
- edges_(h, m).h = h + 1;
- edges_(h, m).m = m + 1;
+ edges_(h, m).h = h;
+ edges_(h, m).m = m;
}
- root_edges_[h].h = 0;
- root_edges_[h].m = h + 1;
+ root_edges_[h].h = -1;
+ root_edges_[h].m = h;
}
}
// compute the maximum spanning tree based on the current weighting
// using the O(n^2) CLE algorithm
- void MaximumSpanningTree(SpanningTree* st) const;
+ void MaximumSpanningTree(EdgeSubset* st) const;
+
+ // Reweight edges so that edge_prob is the edge's marginals
+ // optionally returns log partition
+ void EdgeMarginals(prob_t* p_log_z = NULL);
+
+ // This may not return a tree
+ void PickBestParentForEachWord(EdgeSubset* st) const;
struct Edge {
Edge() : h(), m(), features(), edge_prob(prob_t::Zero()) {}
@@ -45,20 +73,20 @@ class ArcFactoredForest {
prob_t edge_prob;
};
+ // set eges_[*].features
+ void ExtractFeatures(const TaggedSentence& sentence,
+ const ArcFeatureFunctions& ffs);
+
const Edge& operator()(short h, short m) const {
- assert(m > 0);
- assert(m <= num_words_);
- assert(h >= 0);
- assert(h <= num_words_);
- return h ? edges_(h - 1, m - 1) : root_edges_[m - 1];
+ return h >= 0 ? edges_(h, m) : root_edges_[m];
}
Edge& operator()(short h, short m) {
- assert(m > 0);
- assert(m <= num_words_);
- assert(h >= 0);
- assert(h <= num_words_);
- return h ? edges_(h - 1, m - 1) : root_edges_[m - 1];
+ return h >= 0 ? edges_(h, m) : root_edges_[m];
+ }
+
+ float Weight(short h, short m) const {
+ return log((*this)(h,m).edge_prob);
}
template <class V>
@@ -76,7 +104,7 @@ class ArcFactoredForest {
}
private:
- unsigned num_words_;
+ int num_words_;
std::vector<Edge> root_edges_;
Array2D<Edge> edges_;
};
@@ -85,4 +113,12 @@ inline std::ostream& operator<<(std::ostream& os, const ArcFactoredForest::Edge&
return os << "(" << edge.h << " < " << edge.m << ")";
}
+inline std::ostream& operator<<(std::ostream& os, const EdgeSubset& ss) {
+ for (unsigned i = 0; i < ss.roots.size(); ++i)
+ os << "ROOT < " << ss.roots[i] << std::endl;
+ for (unsigned i = 0; i < ss.h_m_pairs.size(); ++i)
+ os << ss.h_m_pairs[i].first << " < " << ss.h_m_pairs[i].second << std::endl;
+ return os;
+}
+
#endif
diff --git a/rst_parser/arc_factored_marginals.cc b/rst_parser/arc_factored_marginals.cc
new file mode 100644
index 00000000..3e8c9f86
--- /dev/null
+++ b/rst_parser/arc_factored_marginals.cc
@@ -0,0 +1,58 @@
+#include "arc_factored.h"
+
+#include <iostream>
+
+#include "config.h"
+
+using namespace std;
+
+#if HAVE_EIGEN
+
+#include <Eigen/Dense>
+typedef Eigen::Matrix<prob_t, Eigen::Dynamic, Eigen::Dynamic> ArcMatrix;
+typedef Eigen::Matrix<prob_t, Eigen::Dynamic, 1> RootVector;
+
+void ArcFactoredForest::EdgeMarginals(prob_t *plog_z) {
+ ArcMatrix A(num_words_,num_words_);
+ RootVector r(num_words_);
+ for (int h = 0; h < num_words_; ++h) {
+ for (int m = 0; m < num_words_; ++m) {
+ if (h != m)
+ A(h,m) = edges_(h,m).edge_prob;
+ else
+ A(h,m) = prob_t::Zero();
+ }
+ r(h) = root_edges_[h].edge_prob;
+ }
+
+ ArcMatrix L = -A;
+ L.diagonal() = A.colwise().sum();
+ L.row(0) = r;
+ ArcMatrix Linv = L.inverse();
+ if (plog_z) *plog_z = Linv.determinant();
+ RootVector rootMarginals = r.cwiseProduct(Linv.col(0));
+ static const prob_t ZERO(0);
+ static const prob_t ONE(1);
+// ArcMatrix T = Linv;
+ for (int h = 0; h < num_words_; ++h) {
+ for (int m = 0; m < num_words_; ++m) {
+ const prob_t marginal = (m == 0 ? ZERO : ONE) * A(h,m) * Linv(m,m) -
+ (h == 0 ? ZERO : ONE) * A(h,m) * Linv(m,h);
+ edges_(h,m).edge_prob = marginal;
+// T(h,m) = marginal;
+ }
+ root_edges_[h].edge_prob = rootMarginals(h);
+ }
+// cerr << "ROOT MARGINALS: " << rootMarginals.transpose() << endl;
+// cerr << "M:\n" << T << endl;
+}
+
+#else
+
+void ArcFactoredForest::EdgeMarginals(prob_t *) {
+ cerr << "EdgeMarginals() requires --with-eigen!\n";
+ abort();
+}
+
+#endif
+
diff --git a/rst_parser/arc_ff.cc b/rst_parser/arc_ff.cc
new file mode 100644
index 00000000..c4e5aa17
--- /dev/null
+++ b/rst_parser/arc_ff.cc
@@ -0,0 +1,183 @@
+#include "arc_ff.h"
+
+#include <iostream>
+#include <sstream>
+
+#include "stringlib.h"
+#include "tdict.h"
+#include "fdict.h"
+#include "sentence_metadata.h"
+
+using namespace std;
+
+struct ArcFFImpl {
+ ArcFFImpl() : kROOT("ROOT"), kLEFT_POS("LEFT"), kRIGHT_POS("RIGHT") {}
+ const string kROOT;
+ const string kLEFT_POS;
+ const string kRIGHT_POS;
+ map<WordID, vector<int> > pcs;
+
+ void PrepareForInput(const TaggedSentence& sent) {
+ pcs.clear();
+ for (int i = 0; i < sent.pos.size(); ++i)
+ pcs[sent.pos[i]].resize(1, 0);
+ pcs[sent.pos[0]][0] = 1;
+ for (int i = 1; i < sent.pos.size(); ++i) {
+ const WordID posi = sent.pos[i];
+ for (map<WordID, vector<int> >::iterator j = pcs.begin(); j != pcs.end(); ++j) {
+ const WordID posj = j->first;
+ vector<int>& cs = j->second;
+ cs.push_back(cs.back() + (posj == posi ? 1 : 0));
+ }
+ }
+ }
+
+ template <typename A>
+ static void Fire(SparseVector<weight_t>* v, const A& a) {
+ ostringstream os;
+ os << a;
+ v->set_value(FD::Convert(os.str()), 1);
+ }
+
+ template <typename A, typename B>
+ static void Fire(SparseVector<weight_t>* v, const A& a, const B& b) {
+ ostringstream os;
+ os << a << ':' << b;
+ v->set_value(FD::Convert(os.str()), 1);
+ }
+
+ template <typename A, typename B, typename C>
+ static void Fire(SparseVector<weight_t>* v, const A& a, const B& b, const C& c) {
+ ostringstream os;
+ os << a << ':' << b << '_' << c;
+ v->set_value(FD::Convert(os.str()), 1);
+ }
+
+ template <typename A, typename B, typename C, typename D>
+ static void Fire(SparseVector<weight_t>* v, const A& a, const B& b, const C& c, const D& d) {
+ ostringstream os;
+ os << a << ':' << b << '_' << c << '_' << d;
+ v->set_value(FD::Convert(os.str()), 1);
+ }
+
+ template <typename A, typename B, typename C, typename D, typename E>
+ static void Fire(SparseVector<weight_t>* v, const A& a, const B& b, const C& c, const D& d, const E& e) {
+ ostringstream os;
+ os << a << ':' << b << '_' << c << '_' << d << '_' << e;
+ v->set_value(FD::Convert(os.str()), 1);
+ }
+
+ static void AddConjoin(const SparseVector<double>& v, const string& feat, SparseVector<double>* pf) {
+ for (SparseVector<double>::const_iterator it = v.begin(); it != v.end(); ++it)
+ pf->set_value(FD::Convert(FD::Convert(it->first) + "_" + feat), it->second);
+ }
+
+ static inline string Fixup(const string& str) {
+ string res = LowercaseString(str);
+ if (res.size() < 6) return res;
+ return res.substr(0, 5) + "*";
+ }
+
+ static inline string Suffix(const string& str) {
+ if (str.size() < 4) return ""; else return str.substr(str.size() - 3);
+ }
+
+ void EdgeFeatures(const TaggedSentence& sent,
+ short h,
+ short m,
+ SparseVector<weight_t>* features) const {
+ const bool is_root = (h == -1);
+ const string head_word = (is_root ? kROOT : Fixup(TD::Convert(sent.words[h])));
+ int num_words = sent.words.size();
+ const string& head_pos = (is_root ? kROOT : TD::Convert(sent.pos[h]));
+ const string mod_word = Fixup(TD::Convert(sent.words[m]));
+ const string& mod_pos = TD::Convert(sent.pos[m]);
+ const string& mod_pos_L = (m > 0 ? TD::Convert(sent.pos[m-1]) : kLEFT_POS);
+ const string& mod_pos_R = (m < sent.pos.size() - 1 ? TD::Convert(sent.pos[m]) : kRIGHT_POS);
+ const bool bdir = m < h;
+ const string dir = (bdir ? "MLeft" : "MRight");
+ int v = m - h;
+ if (v < 0) {
+ v= -1 - int(log(-v) / log(1.6));
+ } else {
+ v= int(log(v) / log(1.6)) + 1;
+ }
+ ostringstream os;
+ if (v < 0) os << "LenL" << -v; else os << "LenR" << v;
+ const string lenstr = os.str();
+ Fire(features, dir);
+ Fire(features, lenstr);
+ // dir, lenstr
+ if (is_root) {
+ Fire(features, "wROOT", mod_word);
+ Fire(features, "pROOT", mod_pos);
+ Fire(features, "wpROOT", mod_word, mod_pos);
+ Fire(features, "DROOT", mod_pos, lenstr);
+ Fire(features, "LROOT", mod_pos_L);
+ Fire(features, "RROOT", mod_pos_R);
+ Fire(features, "LROOT", mod_pos_L, mod_pos);
+ Fire(features, "RROOT", mod_pos_R, mod_pos);
+ Fire(features, "LDist", m);
+ Fire(features, "RDist", num_words - m);
+ } else { // not root
+ const string& head_pos_L = (h > 0 ? TD::Convert(sent.pos[h-1]) : kLEFT_POS);
+ const string& head_pos_R = (h < sent.pos.size() - 1 ? TD::Convert(sent.pos[h]) : kRIGHT_POS);
+ SparseVector<double> fv;
+ SparseVector<double>* f = &fv;
+ Fire(f, "H", head_pos);
+ Fire(f, "M", mod_pos);
+ Fire(f, "HM", head_pos, mod_pos);
+
+ // surrounders
+ Fire(f, "posLL", head_pos, mod_pos, head_pos_L, mod_pos_L);
+ Fire(f, "posRR", head_pos, mod_pos, head_pos_R, mod_pos_R);
+ Fire(f, "posLR", head_pos, mod_pos, head_pos_L, mod_pos_R);
+ Fire(f, "posRL", head_pos, mod_pos, head_pos_R, mod_pos_L);
+
+ // between features
+ int left = min(h,m);
+ int right = max(h,m);
+ if (right - left >= 2) {
+ if (bdir) --right; else ++left;
+ for (map<WordID, vector<int> >::const_iterator it = pcs.begin(); it != pcs.end(); ++it) {
+ if (it->second[left] != it->second[right]) {
+ Fire(f, "BT", head_pos, TD::Convert(it->first), mod_pos);
+ }
+ }
+ }
+
+ Fire(f, "wH", head_word);
+ Fire(f, "wM", mod_word);
+ Fire(f, "wpH", head_word, head_pos);
+ Fire(f, "wpM", mod_word, mod_pos);
+ Fire(f, "pHwM", head_pos, mod_word);
+ Fire(f, "wHpM", head_word, mod_pos);
+
+ Fire(f, "wHM", head_word, mod_word);
+ Fire(f, "pHMwH", head_pos, mod_pos, head_word);
+ Fire(f, "pHMwM", head_pos, mod_pos, mod_word);
+ Fire(f, "wHMpH", head_word, mod_word, head_pos);
+ Fire(f, "wHMpM", head_word, mod_word, mod_pos);
+ Fire(f, "wHMpHM", head_word, mod_word, head_pos, mod_pos);
+
+ AddConjoin(fv, dir, features);
+ AddConjoin(fv, lenstr, features);
+ (*features) += fv;
+ }
+ }
+};
+
+ArcFeatureFunctions::ArcFeatureFunctions() : pimpl(new ArcFFImpl) {}
+ArcFeatureFunctions::~ArcFeatureFunctions() { delete pimpl; }
+
+void ArcFeatureFunctions::PrepareForInput(const TaggedSentence& sentence) {
+ pimpl->PrepareForInput(sentence);
+}
+
+void ArcFeatureFunctions::EdgeFeatures(const TaggedSentence& sentence,
+ short h,
+ short m,
+ SparseVector<weight_t>* features) const {
+ pimpl->EdgeFeatures(sentence, h, m, features);
+}
+
diff --git a/rst_parser/arc_ff.h b/rst_parser/arc_ff.h
new file mode 100644
index 00000000..52f311d2
--- /dev/null
+++ b/rst_parser/arc_ff.h
@@ -0,0 +1,28 @@
+#ifndef _ARC_FF_H_
+#define _ARC_FF_H_
+
+#include <string>
+#include "sparse_vector.h"
+#include "weights.h"
+#include "arc_factored.h"
+
+struct TaggedSentence;
+struct ArcFFImpl;
+class ArcFeatureFunctions {
+ public:
+ ArcFeatureFunctions();
+ ~ArcFeatureFunctions();
+
+ // called once, per input, before any calls to EdgeFeatures
+ // used to initialize sentence-specific data structures
+ void PrepareForInput(const TaggedSentence& sentence);
+
+ void EdgeFeatures(const TaggedSentence& sentence,
+ short h,
+ short m,
+ SparseVector<weight_t>* features) const;
+ private:
+ ArcFFImpl* pimpl;
+};
+
+#endif
diff --git a/rst_parser/dep_training.cc b/rst_parser/dep_training.cc
new file mode 100644
index 00000000..ef97798b
--- /dev/null
+++ b/rst_parser/dep_training.cc
@@ -0,0 +1,76 @@
+#include "dep_training.h"
+
+#include <vector>
+#include <iostream>
+
+#include "stringlib.h"
+#include "filelib.h"
+#include "tdict.h"
+#include "picojson.h"
+
+using namespace std;
+
+static void ParseInstance(const string& line, int start, TrainingInstance* out, int lc = 0) {
+ picojson::value obj;
+ string err;
+ picojson::parse(obj, line.begin() + start, line.end(), &err);
+ if (err.size() > 0) { cerr << "JSON parse error in " << lc << ": " << err << endl; abort(); }
+ TrainingInstance& cur = *out;
+ TaggedSentence& ts = cur.ts;
+ EdgeSubset& tree = cur.tree;
+ ts.pos.clear();
+ ts.words.clear();
+ tree.roots.clear();
+ tree.h_m_pairs.clear();
+ assert(obj.is<picojson::object>());
+ const picojson::object& d = obj.get<picojson::object>();
+ const picojson::array& ta = d.find("tokens")->second.get<picojson::array>();
+ for (unsigned i = 0; i < ta.size(); ++i) {
+ ts.words.push_back(TD::Convert(ta[i].get<picojson::array>()[0].get<string>()));
+ ts.pos.push_back(TD::Convert(ta[i].get<picojson::array>()[1].get<string>()));
+ }
+ if (d.find("deps") != d.end()) {
+ const picojson::array& da = d.find("deps")->second.get<picojson::array>();
+ for (unsigned i = 0; i < da.size(); ++i) {
+ const picojson::array& thm = da[i].get<picojson::array>();
+ // get dep type here
+ short h = thm[2].get<double>();
+ short m = thm[1].get<double>();
+ if (h < 0)
+ tree.roots.push_back(m);
+ else
+ tree.h_m_pairs.push_back(make_pair(h,m));
+ }
+ }
+ //cerr << TD::GetString(ts.words) << endl << TD::GetString(ts.pos) << endl << tree << endl;
+}
+
+bool TrainingInstance::ReadInstance(std::istream* in, TrainingInstance* instance) {
+ string line;
+ if (!getline(*in, line)) return false;
+ size_t pos = line.rfind('\t');
+ assert(pos != string::npos);
+ static int lc = 0; ++lc;
+ ParseInstance(line, pos + 1, instance, lc);
+ return true;
+}
+
+void TrainingInstance::ReadTrainingCorpus(const string& fname, vector<TrainingInstance>* corpus, int rank, int size) {
+ ReadFile rf(fname);
+ istream& in = *rf.stream();
+ string line;
+ int lc = 0;
+ bool flag = false;
+ while(getline(in, line)) {
+ ++lc;
+ if ((lc-1) % size != rank) continue;
+ if (rank == 0 && lc % 10 == 0) { cerr << '.' << flush; flag = true; }
+ if (rank == 0 && lc % 400 == 0) { cerr << " [" << lc << "]\n"; flag = false; }
+ size_t pos = line.rfind('\t');
+ assert(pos != string::npos);
+ corpus->push_back(TrainingInstance());
+ ParseInstance(line, pos + 1, &corpus->back(), lc);
+ }
+ if (flag) cerr << "\nRead " << lc << " training instances\n";
+}
+
diff --git a/rst_parser/dep_training.h b/rst_parser/dep_training.h
new file mode 100644
index 00000000..3eeee22e
--- /dev/null
+++ b/rst_parser/dep_training.h
@@ -0,0 +1,19 @@
+#ifndef _DEP_TRAINING_H_
+#define _DEP_TRAINING_H_
+
+#include <iostream>
+#include <string>
+#include <vector>
+#include "arc_factored.h"
+#include "weights.h"
+
+struct TrainingInstance {
+ TaggedSentence ts;
+ EdgeSubset tree;
+ SparseVector<weight_t> features;
+ // reads a "Jsent" formatted dependency file
+ static bool ReadInstance(std::istream* in, TrainingInstance* instance); // returns false at EOF
+ static void ReadTrainingCorpus(const std::string& fname, std::vector<TrainingInstance>* corpus, int rank = 0, int size = 1);
+};
+
+#endif
diff --git a/rst_parser/global_ff.cc b/rst_parser/global_ff.cc
new file mode 100644
index 00000000..ae410875
--- /dev/null
+++ b/rst_parser/global_ff.cc
@@ -0,0 +1,44 @@
+#include "global_ff.h"
+
+#include <iostream>
+#include <sstream>
+
+#include "tdict.h"
+
+using namespace std;
+
+struct GFFImpl {
+ void PrepareForInput(const TaggedSentence& sentence) {
+ }
+ void Features(const TaggedSentence& sentence,
+ const EdgeSubset& tree,
+ SparseVector<double>* feats) const {
+ const vector<WordID>& words = sentence.words;
+ const vector<WordID>& tags = sentence.pos;
+ const vector<pair<short,short> >& hms = tree.h_m_pairs;
+ assert(words.size() == tags.size());
+ vector<int> mods(words.size());
+ for (int i = 0; i < hms.size(); ++i) {
+ mods[hms[i].first]++; // first = head, second = modifier
+ }
+ for (int i = 0; i < mods.size(); ++i) {
+ ostringstream os;
+ os << "NM:" << TD::Convert(tags[i]) << "_" << mods[i];
+ feats->add_value(FD::Convert(os.str()), 1.0);
+ }
+ }
+};
+
+GlobalFeatureFunctions::GlobalFeatureFunctions() : pimpl(new GFFImpl) {}
+GlobalFeatureFunctions::~GlobalFeatureFunctions() { delete pimpl; }
+
+void GlobalFeatureFunctions::PrepareForInput(const TaggedSentence& sentence) {
+ pimpl->PrepareForInput(sentence);
+}
+
+void GlobalFeatureFunctions::Features(const TaggedSentence& sentence,
+ const EdgeSubset& tree,
+ SparseVector<double>* feats) const {
+ pimpl->Features(sentence, tree, feats);
+}
+
diff --git a/rst_parser/global_ff.h b/rst_parser/global_ff.h
new file mode 100644
index 00000000..d71d0fa1
--- /dev/null
+++ b/rst_parser/global_ff.h
@@ -0,0 +1,18 @@
+#ifndef _GLOBAL_FF_H_
+#define _GLOBAL_FF_H_
+
+#include "arc_factored.h"
+
+struct GFFImpl;
+struct GlobalFeatureFunctions {
+ GlobalFeatureFunctions();
+ ~GlobalFeatureFunctions();
+ void PrepareForInput(const TaggedSentence& sentence);
+ void Features(const TaggedSentence& sentence,
+ const EdgeSubset& tree,
+ SparseVector<double>* feats) const;
+ private:
+ GFFImpl* pimpl;
+};
+
+#endif
diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc
index 7b5af4c1..6332693e 100644
--- a/rst_parser/mst_train.cc
+++ b/rst_parser/mst_train.cc
@@ -1,12 +1,228 @@
#include "arc_factored.h"
+#include <vector>
#include <iostream>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+// #define HAVE_THREAD 1
+#if HAVE_THREAD
+#include <boost/thread.hpp>
+#endif
+
+#include "arc_ff.h"
+#include "stringlib.h"
+#include "filelib.h"
+#include "tdict.h"
+#include "dep_training.h"
+#include "optimize.h"
+#include "weights.h"
using namespace std;
+namespace po = boost::program_options;
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ string cfg_file;
+ opts.add_options()
+ ("training_data,t",po::value<string>()->default_value("-"), "File containing training data (jsent format)")
+ ("weights,w",po::value<string>(), "Optional starting weights")
+ ("output_every_i_iterations,I",po::value<unsigned>()->default_value(1), "Write weights every I iterations")
+ ("regularization_strength,C",po::value<double>()->default_value(1.0), "Regularization strength")
+#ifdef HAVE_CMPH
+ ("cmph_perfect_feature_hash,h", po::value<string>(), "Load perfect hash function for features")
+#endif
+#if HAVE_THREAD
+ ("threads,T",po::value<unsigned>()->default_value(1), "Number of threads")
+#endif
+ ("correction_buffers,m", po::value<int>()->default_value(10), "LBFGS correction buffers");
+ po::options_description clo("Command line options");
+ clo.add_options()
+ ("config,c", po::value<string>(&cfg_file), "Configuration file")
+ ("help,?", "Print this help message and exit");
+
+ po::options_description dconfig_options, dcmdline_options;
+ dconfig_options.add(opts);
+ dcmdline_options.add(dconfig_options).add(clo);
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ if (cfg_file.size() > 0) {
+ ReadFile rf(cfg_file);
+ po::store(po::parse_config_file(*rf.stream(), dconfig_options), *conf);
+ }
+ if (conf->count("help")) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+void AddFeatures(double prob, const SparseVector<double>& fmap, vector<double>* g) {
+ for (SparseVector<double>::const_iterator it = fmap.begin(); it != fmap.end(); ++it)
+ (*g)[it->first] += it->second * prob;
+}
+
+double ApplyRegularizationTerms(const double C,
+ const vector<double>& weights,
+ vector<double>* g) {
+ assert(weights.size() == g->size());
+ double reg = 0;
+ for (size_t i = 0; i < weights.size(); ++i) {
+// const double prev_w_i = (i < prev_weights.size() ? prev_weights[i] : 0.0);
+ const double& w_i = weights[i];
+ double& g_i = (*g)[i];
+ reg += C * w_i * w_i;
+ g_i += 2 * C * w_i;
+
+// reg += T * (w_i - prev_w_i) * (w_i - prev_w_i);
+// g_i += 2 * T * (w_i - prev_w_i);
+ }
+ return reg;
+}
+
+struct GradientWorker {
+ GradientWorker(int f,
+ int t,
+ vector<double>* w,
+ vector<TrainingInstance>* c,
+ vector<ArcFactoredForest>* fs) : obj(), weights(*w), from(f), to(t), corpus(*c), forests(*fs), g(w->size()) {}
+ void operator()() {
+ int every = (to - from) / 20;
+ if (!every) every++;
+ for (int i = from; i < to; ++i) {
+ if ((from == 0) && (i + 1) % every == 0) cerr << '.' << flush;
+ const int num_words = corpus[i].ts.words.size();
+ forests[i].Reweight(weights);
+ prob_t z;
+ forests[i].EdgeMarginals(&z);
+ obj -= log(z);
+ //cerr << " O = " << (-corpus[i].features.dot(weights)) << " D=" << -lz << " OO= " << (-corpus[i].features.dot(weights) - lz) << endl;
+ //cerr << " ZZ = " << zz << endl;
+ for (int h = -1; h < num_words; ++h) {
+ for (int m = 0; m < num_words; ++m) {
+ if (h == m) continue;
+ const ArcFactoredForest::Edge& edge = forests[i](h,m);
+ const SparseVector<weight_t>& fmap = edge.features;
+ double prob = edge.edge_prob.as_float();
+ if (prob < -0.000001) { cerr << "Prob < 0: " << prob << endl; prob = 0; }
+ if (prob > 1.000001) { cerr << "Prob > 1: " << prob << endl; prob = 1; }
+ AddFeatures(prob, fmap, &g);
+ //mfm += fmap * prob; // DE
+ }
+ }
+ }
+ }
+ double obj;
+ vector<double>& weights;
+ const int from, to;
+ vector<TrainingInstance>& corpus;
+ vector<ArcFactoredForest>& forests;
+ vector<double> g; // local gradient
+};
int main(int argc, char** argv) {
- ArcFactoredForest af(5);
- cerr << af(0,3) << endl;
+ int rank = 0;
+ int size = 1;
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+ if (conf.count("cmph_perfect_feature_hash")) {
+ cerr << "Loading perfect hash function from " << conf["cmph_perfect_feature_hash"].as<string>() << " ...\n";
+ FD::EnableHash(conf["cmph_perfect_feature_hash"].as<string>());
+ cerr << " " << FD::NumFeats() << " features in map\n";
+ }
+ ArcFeatureFunctions ffs;
+ vector<TrainingInstance> corpus;
+ TrainingInstance::ReadTrainingCorpus(conf["training_data"].as<string>(), &corpus, rank, size);
+ vector<weight_t> weights;
+ Weights::InitFromFile(conf["weights"].as<string>(), &weights);
+ vector<ArcFactoredForest> forests(corpus.size());
+ SparseVector<double> empirical;
+ cerr << "Extracting features...\n";
+ bool flag = false;
+ for (int i = 0; i < corpus.size(); ++i) {
+ TrainingInstance& cur = corpus[i];
+ if (rank == 0 && (i+1) % 10 == 0) { cerr << '.' << flush; flag = true; }
+ if (rank == 0 && (i+1) % 400 == 0) { cerr << " [" << (i+1) << "]\n"; flag = false; }
+ ffs.PrepareForInput(cur.ts);
+ SparseVector<weight_t> efmap;
+ for (int j = 0; j < cur.tree.h_m_pairs.size(); ++j) {
+ efmap.clear();
+ ffs.EdgeFeatures(cur.ts, cur.tree.h_m_pairs[j].first,
+ cur.tree.h_m_pairs[j].second,
+ &efmap);
+ cur.features += efmap;
+ }
+ for (int j = 0; j < cur.tree.roots.size(); ++j) {
+ efmap.clear();
+ ffs.EdgeFeatures(cur.ts, -1, cur.tree.roots[j], &efmap);
+ cur.features += efmap;
+ }
+ empirical += cur.features;
+ forests[i].resize(cur.ts.words.size());
+ forests[i].ExtractFeatures(cur.ts, ffs);
+ }
+ if (flag) cerr << endl;
+ //cerr << "EMP: " << empirical << endl; //DE
+ weights.resize(FD::NumFeats(), 0.0);
+ vector<weight_t> g(FD::NumFeats(), 0.0);
+ cerr << "features initialized\noptimizing...\n";
+ boost::shared_ptr<BatchOptimizer> o;
+#if HAVE_THREAD
+ unsigned threads = conf["threads"].as<unsigned>();
+ if (threads > corpus.size()) threads = corpus.size();
+#else
+ const unsigned threads = 1;
+#endif
+ int chunk = corpus.size() / threads;
+ o.reset(new LBFGSOptimizer(g.size(), conf["correction_buffers"].as<int>()));
+ int iterations = 1000;
+ for (int iter = 0; iter < iterations; ++iter) {
+ cerr << "ITERATION " << iter << " " << flush;
+ fill(g.begin(), g.end(), 0.0);
+ for (SparseVector<double>::const_iterator it = empirical.begin(); it != empirical.end(); ++it)
+ g[it->first] = -it->second;
+ double obj = -empirical.dot(weights);
+ vector<boost::shared_ptr<GradientWorker> > jobs;
+ for (int from = 0; from < corpus.size(); from += chunk) {
+ int to = from + chunk;
+ if (to > corpus.size()) to = corpus.size();
+ jobs.push_back(boost::shared_ptr<GradientWorker>(new GradientWorker(from, to, &weights, &corpus, &forests)));
+ }
+#if HAVE_THREAD
+ boost::thread_group tg;
+ for (int i = 0; i < threads; ++i)
+ tg.create_thread(boost::ref(*jobs[i]));
+ tg.join_all();
+#else
+ (*jobs[0])();
+#endif
+ for (int i = 0; i < threads; ++i) {
+ obj += jobs[i]->obj;
+ vector<double>& tg = jobs[i]->g;
+ for (unsigned j = 0; j < g.size(); ++j)
+ g[j] += tg[j];
+ }
+ // SparseVector<double> mfm; //DE
+ //cerr << endl << "E: " << empirical << endl; // DE
+ //cerr << "M: " << mfm << endl; // DE
+ double r = ApplyRegularizationTerms(conf["regularization_strength"].as<double>(), weights, &g);
+ double gnorm = 0;
+ for (int i = 0; i < g.size(); ++i)
+ gnorm += g[i]*g[i];
+ ostringstream ll;
+ ll << "ITER=" << (iter+1) << "\tOBJ=" << (obj+r) << "\t[F=" << obj << " R=" << r << "]\tGnorm=" << sqrt(gnorm);
+ cerr << ' ' << ll.str().substr(ll.str().find('\t')+1) << endl;
+ obj += r;
+ assert(obj >= 0);
+ o->Optimize(obj, g, &weights);
+ Weights::ShowLargestFeatures(weights);
+ const bool converged = o->HasConverged();
+ const char* ofname = converged ? "weights.final.gz" : "weights.cur.gz";
+ if (converged || ((iter+1) % conf["output_every_i_iterations"].as<unsigned>()) == 0) {
+ cerr << "writing..." << flush;
+ const string sl = ll.str();
+ Weights::WriteToFile(ofname, weights, true, &sl);
+ cerr << "done" << endl;
+ }
+ if (converged) { cerr << "CONVERGED\n"; break; }
+ }
return 0;
}
diff --git a/rst_parser/picojson.h b/rst_parser/picojson.h
new file mode 100644
index 00000000..bdb26057
--- /dev/null
+++ b/rst_parser/picojson.h
@@ -0,0 +1,979 @@
+/*
+ * Copyright 2009-2010 Cybozu Labs, Inc.
+ * Copyright 2011 Kazuho Oku
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ * this list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY CYBOZU LABS, INC. ``AS IS'' AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
+ * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
+ * EVENT SHALL CYBOZU LABS, INC. OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
+ * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
+ * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * The views and conclusions contained in the software and documentation are
+ * those of the authors and should not be interpreted as representing official
+ * policies, either expressed or implied, of Cybozu Labs, Inc.
+ *
+ */
+#ifndef picojson_h
+#define picojson_h
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <iostream>
+#include <iterator>
+#include <map>
+#include <string>
+#include <vector>
+
+#ifdef _MSC_VER
+ #define SNPRINTF _snprintf_s
+ #pragma warning(push)
+ #pragma warning(disable : 4244) // conversion from int to char
+#else
+ #define SNPRINTF snprintf
+#endif
+
+namespace picojson {
+
+ enum {
+ null_type,
+ boolean_type,
+ number_type,
+ string_type,
+ array_type,
+ object_type
+ };
+
+ struct null {};
+
+ class value {
+ public:
+ typedef std::vector<value> array;
+ typedef std::map<std::string, value> object;
+ protected:
+ int type_;
+ union {
+ bool boolean_;
+ double number_;
+ std::string* string_;
+ array* array_;
+ object* object_;
+ };
+ public:
+ value();
+ value(int type, bool);
+ explicit value(bool b);
+ explicit value(double n);
+ explicit value(const std::string& s);
+ explicit value(const array& a);
+ explicit value(const object& o);
+ explicit value(const char* s);
+ value(const char* s, size_t len);
+ ~value();
+ value(const value& x);
+ value& operator=(const value& x);
+ template <typename T> bool is() const;
+ template <typename T> const T& get() const;
+ template <typename T> T& get();
+ bool evaluate_as_boolean() const;
+ const value& get(size_t idx) const;
+ const value& get(const std::string& key) const;
+ bool contains(size_t idx) const;
+ bool contains(const std::string& key) const;
+ std::string to_str() const;
+ template <typename Iter> void serialize(Iter os) const;
+ std::string serialize() const;
+ private:
+ template <typename T> value(const T*); // intentionally defined to block implicit conversion of pointer to bool
+ };
+
+ typedef value::array array;
+ typedef value::object object;
+
+ inline value::value() : type_(null_type) {}
+
+ inline value::value(int type, bool) : type_(type) {
+ switch (type) {
+#define INIT(p, v) case p##type: p = v; break
+ INIT(boolean_, false);
+ INIT(number_, 0.0);
+ INIT(string_, new std::string());
+ INIT(array_, new array());
+ INIT(object_, new object());
+#undef INIT
+ default: break;
+ }
+ }
+
+ inline value::value(bool b) : type_(boolean_type) {
+ boolean_ = b;
+ }
+
+ inline value::value(double n) : type_(number_type) {
+ number_ = n;
+ }
+
+ inline value::value(const std::string& s) : type_(string_type) {
+ string_ = new std::string(s);
+ }
+
+ inline value::value(const array& a) : type_(array_type) {
+ array_ = new array(a);
+ }
+
+ inline value::value(const object& o) : type_(object_type) {
+ object_ = new object(o);
+ }
+
+ inline value::value(const char* s) : type_(string_type) {
+ string_ = new std::string(s);
+ }
+
+ inline value::value(const char* s, size_t len) : type_(string_type) {
+ string_ = new std::string(s, len);
+ }
+
+ inline value::~value() {
+ switch (type_) {
+#define DEINIT(p) case p##type: delete p; break
+ DEINIT(string_);
+ DEINIT(array_);
+ DEINIT(object_);
+#undef DEINIT
+ default: break;
+ }
+ }
+
+ inline value::value(const value& x) : type_(x.type_) {
+ switch (type_) {
+#define INIT(p, v) case p##type: p = v; break
+ INIT(boolean_, x.boolean_);
+ INIT(number_, x.number_);
+ INIT(string_, new std::string(*x.string_));
+ INIT(array_, new array(*x.array_));
+ INIT(object_, new object(*x.object_));
+#undef INIT
+ default: break;
+ }
+ }
+
+ inline value& value::operator=(const value& x) {
+ if (this != &x) {
+ this->~value();
+ new (this) value(x);
+ }
+ return *this;
+ }
+
+#define IS(ctype, jtype) \
+ template <> inline bool value::is<ctype>() const { \
+ return type_ == jtype##_type; \
+ }
+ IS(null, null)
+ IS(bool, boolean)
+ IS(int, number)
+ IS(double, number)
+ IS(std::string, string)
+ IS(array, array)
+ IS(object, object)
+#undef IS
+
+#define GET(ctype, var) \
+ template <> inline const ctype& value::get<ctype>() const { \
+ assert("type mismatch! call vis<type>() before get<type>()" \
+ && is<ctype>()); \
+ return var; \
+ } \
+ template <> inline ctype& value::get<ctype>() { \
+ assert("type mismatch! call is<type>() before get<type>()" \
+ && is<ctype>()); \
+ return var; \
+ }
+ GET(bool, boolean_)
+ GET(double, number_)
+ GET(std::string, *string_)
+ GET(array, *array_)
+ GET(object, *object_)
+#undef GET
+
+ inline bool value::evaluate_as_boolean() const {
+ switch (type_) {
+ case null_type:
+ return false;
+ case boolean_type:
+ return boolean_;
+ case number_type:
+ return number_ != 0;
+ case string_type:
+ return ! string_->empty();
+ default:
+ return true;
+ }
+ }
+
+ inline const value& value::get(size_t idx) const {
+ static value s_null;
+ assert(is<array>());
+ return idx < array_->size() ? (*array_)[idx] : s_null;
+ }
+
+ inline const value& value::get(const std::string& key) const {
+ static value s_null;
+ assert(is<object>());
+ object::const_iterator i = object_->find(key);
+ return i != object_->end() ? i->second : s_null;
+ }
+
+ inline bool value::contains(size_t idx) const {
+ assert(is<array>());
+ return idx < array_->size();
+ }
+
+ inline bool value::contains(const std::string& key) const {
+ assert(is<object>());
+ object::const_iterator i = object_->find(key);
+ return i != object_->end();
+ }
+
+ inline std::string value::to_str() const {
+ switch (type_) {
+ case null_type: return "null";
+ case boolean_type: return boolean_ ? "true" : "false";
+ case number_type: {
+ char buf[256];
+ double tmp;
+ SNPRINTF(buf, sizeof(buf), modf(number_, &tmp) == 0 ? "%.f" : "%f", number_);
+ return buf;
+ }
+ case string_type: return *string_;
+ case array_type: return "array";
+ case object_type: return "object";
+ default: assert(0);
+#ifdef _MSC_VER
+ __assume(0);
+#endif
+ }
+ }
+
+ template <typename Iter> void copy(const std::string& s, Iter oi) {
+ std::copy(s.begin(), s.end(), oi);
+ }
+
+ template <typename Iter> void serialize_str(const std::string& s, Iter oi) {
+ *oi++ = '"';
+ for (std::string::const_iterator i = s.begin(); i != s.end(); ++i) {
+ switch (*i) {
+#define MAP(val, sym) case val: copy(sym, oi); break
+ MAP('"', "\\\"");
+ MAP('\\', "\\\\");
+ MAP('/', "\\/");
+ MAP('\b', "\\b");
+ MAP('\f', "\\f");
+ MAP('\n', "\\n");
+ MAP('\r', "\\r");
+ MAP('\t', "\\t");
+#undef MAP
+ default:
+ if ((unsigned char)*i < 0x20 || *i == 0x7f) {
+ char buf[7];
+ SNPRINTF(buf, sizeof(buf), "\\u%04x", *i & 0xff);
+ copy(buf, buf + 6, oi);
+ } else {
+ *oi++ = *i;
+ }
+ break;
+ }
+ }
+ *oi++ = '"';
+ }
+
+ template <typename Iter> void value::serialize(Iter oi) const {
+ switch (type_) {
+ case string_type:
+ serialize_str(*string_, oi);
+ break;
+ case array_type: {
+ *oi++ = '[';
+ for (array::const_iterator i = array_->begin(); i != array_->end(); ++i) {
+ if (i != array_->begin()) {
+ *oi++ = ',';
+ }
+ i->serialize(oi);
+ }
+ *oi++ = ']';
+ break;
+ }
+ case object_type: {
+ *oi++ = '{';
+ for (object::const_iterator i = object_->begin();
+ i != object_->end();
+ ++i) {
+ if (i != object_->begin()) {
+ *oi++ = ',';
+ }
+ serialize_str(i->first, oi);
+ *oi++ = ':';
+ i->second.serialize(oi);
+ }
+ *oi++ = '}';
+ break;
+ }
+ default:
+ copy(to_str(), oi);
+ break;
+ }
+ }
+
+ inline std::string value::serialize() const {
+ std::string s;
+ serialize(std::back_inserter(s));
+ return s;
+ }
+
+ template <typename Iter> class input {
+ protected:
+ Iter cur_, end_;
+ int last_ch_;
+ bool ungot_;
+ int line_;
+ public:
+ input(const Iter& first, const Iter& last) : cur_(first), end_(last), last_ch_(-1), ungot_(false), line_(1) {}
+ int getc() {
+ if (ungot_) {
+ ungot_ = false;
+ return last_ch_;
+ }
+ if (cur_ == end_) {
+ last_ch_ = -1;
+ return -1;
+ }
+ if (last_ch_ == '\n') {
+ line_++;
+ }
+ last_ch_ = *cur_++ & 0xff;
+ return last_ch_;
+ }
+ void ungetc() {
+ if (last_ch_ != -1) {
+ assert(! ungot_);
+ ungot_ = true;
+ }
+ }
+ Iter cur() const { return cur_; }
+ int line() const { return line_; }
+ void skip_ws() {
+ while (1) {
+ int ch = getc();
+ if (! (ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r')) {
+ ungetc();
+ break;
+ }
+ }
+ }
+ int expect(int expect) {
+ skip_ws();
+ if (getc() != expect) {
+ ungetc();
+ return false;
+ }
+ return true;
+ }
+ bool match(const std::string& pattern) {
+ for (std::string::const_iterator pi(pattern.begin());
+ pi != pattern.end();
+ ++pi) {
+ if (getc() != *pi) {
+ ungetc();
+ return false;
+ }
+ }
+ return true;
+ }
+ };
+
+ template<typename Iter> inline int _parse_quadhex(input<Iter> &in) {
+ int uni_ch = 0, hex;
+ for (int i = 0; i < 4; i++) {
+ if ((hex = in.getc()) == -1) {
+ return -1;
+ }
+ if ('0' <= hex && hex <= '9') {
+ hex -= '0';
+ } else if ('A' <= hex && hex <= 'F') {
+ hex -= 'A' - 0xa;
+ } else if ('a' <= hex && hex <= 'f') {
+ hex -= 'a' - 0xa;
+ } else {
+ in.ungetc();
+ return -1;
+ }
+ uni_ch = uni_ch * 16 + hex;
+ }
+ return uni_ch;
+ }
+
+ template<typename String, typename Iter> inline bool _parse_codepoint(String& out, input<Iter>& in) {
+ int uni_ch;
+ if ((uni_ch = _parse_quadhex(in)) == -1) {
+ return false;
+ }
+ if (0xd800 <= uni_ch && uni_ch <= 0xdfff) {
+ if (0xdc00 <= uni_ch) {
+ // a second 16-bit of a surrogate pair appeared
+ return false;
+ }
+ // first 16-bit of surrogate pair, get the next one
+ if (in.getc() != '\\' || in.getc() != 'u') {
+ in.ungetc();
+ return false;
+ }
+ int second = _parse_quadhex(in);
+ if (! (0xdc00 <= second && second <= 0xdfff)) {
+ return false;
+ }
+ uni_ch = ((uni_ch - 0xd800) << 10) | ((second - 0xdc00) & 0x3ff);
+ uni_ch += 0x10000;
+ }
+ if (uni_ch < 0x80) {
+ out.push_back(uni_ch);
+ } else {
+ if (uni_ch < 0x800) {
+ out.push_back(0xc0 | (uni_ch >> 6));
+ } else {
+ if (uni_ch < 0x10000) {
+ out.push_back(0xe0 | (uni_ch >> 12));
+ } else {
+ out.push_back(0xf0 | (uni_ch >> 18));
+ out.push_back(0x80 | ((uni_ch >> 12) & 0x3f));
+ }
+ out.push_back(0x80 | ((uni_ch >> 6) & 0x3f));
+ }
+ out.push_back(0x80 | (uni_ch & 0x3f));
+ }
+ return true;
+ }
+
+ template<typename String, typename Iter> inline bool _parse_string(String& out, input<Iter>& in) {
+ while (1) {
+ int ch = in.getc();
+ if (ch < ' ') {
+ in.ungetc();
+ return false;
+ } else if (ch == '"') {
+ return true;
+ } else if (ch == '\\') {
+ if ((ch = in.getc()) == -1) {
+ return false;
+ }
+ switch (ch) {
+#define MAP(sym, val) case sym: out.push_back(val); break
+ MAP('"', '\"');
+ MAP('\\', '\\');
+ MAP('/', '/');
+ MAP('b', '\b');
+ MAP('f', '\f');
+ MAP('n', '\n');
+ MAP('r', '\r');
+ MAP('t', '\t');
+#undef MAP
+ case 'u':
+ if (! _parse_codepoint(out, in)) {
+ return false;
+ }
+ break;
+ default:
+ return false;
+ }
+ } else {
+ out.push_back(ch);
+ }
+ }
+ return false;
+ }
+
+ template <typename Context, typename Iter> inline bool _parse_array(Context& ctx, input<Iter>& in) {
+ if (! ctx.parse_array_start()) {
+ return false;
+ }
+ if (in.expect(']')) {
+ return true;
+ }
+ size_t idx = 0;
+ do {
+ if (! ctx.parse_array_item(in, idx)) {
+ return false;
+ }
+ idx++;
+ } while (in.expect(','));
+ return in.expect(']');
+ }
+
+ template <typename Context, typename Iter> inline bool _parse_object(Context& ctx, input<Iter>& in) {
+ if (! ctx.parse_object_start()) {
+ return false;
+ }
+ if (in.expect('}')) {
+ return true;
+ }
+ do {
+ std::string key;
+ if (! in.expect('"')
+ || ! _parse_string(key, in)
+ || ! in.expect(':')) {
+ return false;
+ }
+ if (! ctx.parse_object_item(in, key)) {
+ return false;
+ }
+ } while (in.expect(','));
+ return in.expect('}');
+ }
+
+ template <typename Iter> inline bool _parse_number(double& out, input<Iter>& in) {
+ std::string num_str;
+ while (1) {
+ int ch = in.getc();
+ if (('0' <= ch && ch <= '9') || ch == '+' || ch == '-' || ch == '.'
+ || ch == 'e' || ch == 'E') {
+ num_str.push_back(ch);
+ } else {
+ in.ungetc();
+ break;
+ }
+ }
+ char* endp;
+ out = strtod(num_str.c_str(), &endp);
+ return endp == num_str.c_str() + num_str.size();
+ }
+
+ template <typename Context, typename Iter> inline bool _parse(Context& ctx, input<Iter>& in) {
+ in.skip_ws();
+ int ch = in.getc();
+ switch (ch) {
+#define IS(ch, text, op) case ch: \
+ if (in.match(text) && op) { \
+ return true; \
+ } else { \
+ return false; \
+ }
+ IS('n', "ull", ctx.set_null());
+ IS('f', "alse", ctx.set_bool(false));
+ IS('t', "rue", ctx.set_bool(true));
+#undef IS
+ case '"':
+ return ctx.parse_string(in);
+ case '[':
+ return _parse_array(ctx, in);
+ case '{':
+ return _parse_object(ctx, in);
+ default:
+ if (('0' <= ch && ch <= '9') || ch == '-') {
+ in.ungetc();
+ double f;
+ if (_parse_number(f, in)) {
+ ctx.set_number(f);
+ return true;
+ } else {
+ return false;
+ }
+ }
+ break;
+ }
+ in.ungetc();
+ return false;
+ }
+
+ class deny_parse_context {
+ public:
+ bool set_null() { return false; }
+ bool set_bool(bool) { return false; }
+ bool set_number(double) { return false; }
+ template <typename Iter> bool parse_string(input<Iter>&) { return false; }
+ bool parse_array_start() { return false; }
+ template <typename Iter> bool parse_array_item(input<Iter>&, size_t) {
+ return false;
+ }
+ bool parse_object_start() { return false; }
+ template <typename Iter> bool parse_object_item(input<Iter>&, const std::string&) {
+ return false;
+ }
+ };
+
+ class default_parse_context {
+ protected:
+ value* out_;
+ public:
+ default_parse_context(value* out) : out_(out) {}
+ bool set_null() {
+ *out_ = value();
+ return true;
+ }
+ bool set_bool(bool b) {
+ *out_ = value(b);
+ return true;
+ }
+ bool set_number(double f) {
+ *out_ = value(f);
+ return true;
+ }
+ template<typename Iter> bool parse_string(input<Iter>& in) {
+ *out_ = value(string_type, false);
+ return _parse_string(out_->get<std::string>(), in);
+ }
+ bool parse_array_start() {
+ *out_ = value(array_type, false);
+ return true;
+ }
+ template <typename Iter> bool parse_array_item(input<Iter>& in, size_t) {
+ array& a = out_->get<array>();
+ a.push_back(value());
+ default_parse_context ctx(&a.back());
+ return _parse(ctx, in);
+ }
+ bool parse_object_start() {
+ *out_ = value(object_type, false);
+ return true;
+ }
+ template <typename Iter> bool parse_object_item(input<Iter>& in, const std::string& key) {
+ object& o = out_->get<object>();
+ default_parse_context ctx(&o[key]);
+ return _parse(ctx, in);
+ }
+ private:
+ default_parse_context(const default_parse_context&);
+ default_parse_context& operator=(const default_parse_context&);
+ };
+
+ class null_parse_context {
+ public:
+ struct dummy_str {
+ void push_back(int) {}
+ };
+ public:
+ null_parse_context() {}
+ bool set_null() { return true; }
+ bool set_bool(bool) { return true; }
+ bool set_number(double) { return true; }
+ template <typename Iter> bool parse_string(input<Iter>& in) {
+ dummy_str s;
+ return _parse_string(s, in);
+ }
+ bool parse_array_start() { return true; }
+ template <typename Iter> bool parse_array_item(input<Iter>& in, size_t) {
+ return _parse(*this, in);
+ }
+ bool parse_object_start() { return true; }
+ template <typename Iter> bool parse_object_item(input<Iter>& in, const std::string&) {
+ return _parse(*this, in);
+ }
+ private:
+ null_parse_context(const null_parse_context&);
+ null_parse_context& operator=(const null_parse_context&);
+ };
+
+ // obsolete, use the version below
+ template <typename Iter> inline std::string parse(value& out, Iter& pos, const Iter& last) {
+ std::string err;
+ pos = parse(out, pos, last, &err);
+ return err;
+ }
+
+ template <typename Context, typename Iter> inline Iter _parse(Context& ctx, const Iter& first, const Iter& last, std::string* err) {
+ input<Iter> in(first, last);
+ if (! _parse(ctx, in) && err != NULL) {
+ char buf[64];
+ SNPRINTF(buf, sizeof(buf), "syntax error at line %d near: ", in.line());
+ *err = buf;
+ while (1) {
+ int ch = in.getc();
+ if (ch == -1 || ch == '\n') {
+ break;
+ } else if (ch >= ' ') {
+ err->push_back(ch);
+ }
+ }
+ }
+ return in.cur();
+ }
+
+ template <typename Iter> inline Iter parse(value& out, const Iter& first, const Iter& last, std::string* err) {
+ default_parse_context ctx(&out);
+ return _parse(ctx, first, last, err);
+ }
+
+ inline std::string parse(value& out, std::istream& is) {
+ std::string err;
+ parse(out, std::istreambuf_iterator<char>(is.rdbuf()),
+ std::istreambuf_iterator<char>(), &err);
+ return err;
+ }
+
+ template <typename T> struct last_error_t {
+ static std::string s;
+ };
+ template <typename T> std::string last_error_t<T>::s;
+
+ inline void set_last_error(const std::string& s) {
+ last_error_t<bool>::s = s;
+ }
+
+ inline const std::string& get_last_error() {
+ return last_error_t<bool>::s;
+ }
+
+ inline bool operator==(const value& x, const value& y) {
+ if (x.is<null>())
+ return y.is<null>();
+#define PICOJSON_CMP(type) \
+ if (x.is<type>()) \
+ return y.is<type>() && x.get<type>() == y.get<type>()
+ PICOJSON_CMP(bool);
+ PICOJSON_CMP(double);
+ PICOJSON_CMP(std::string);
+ PICOJSON_CMP(array);
+ PICOJSON_CMP(object);
+#undef PICOJSON_CMP
+ assert(0);
+#ifdef _MSC_VER
+ __assume(0);
+#endif
+ return false;
+ }
+
+ inline bool operator!=(const value& x, const value& y) {
+ return ! (x == y);
+ }
+}
+
+inline std::istream& operator>>(std::istream& is, picojson::value& x)
+{
+ picojson::set_last_error(std::string());
+ std::string err = picojson::parse(x, is);
+ if (! err.empty()) {
+ picojson::set_last_error(err);
+ is.setstate(std::ios::failbit);
+ }
+ return is;
+}
+
+inline std::ostream& operator<<(std::ostream& os, const picojson::value& x)
+{
+ x.serialize(std::ostream_iterator<char>(os));
+ return os;
+}
+#ifdef _MSC_VER
+ #pragma warning(pop)
+#endif
+
+#endif
+#ifdef TEST_PICOJSON
+#ifdef _MSC_VER
+ #pragma warning(disable : 4127) // conditional expression is constant
+#endif
+
+using namespace std;
+
+static void plan(int num)
+{
+ printf("1..%d\n", num);
+}
+
+static bool success = true;
+
+static void ok(bool b, const char* name = "")
+{
+ static int n = 1;
+ if (! b)
+ success = false;
+ printf("%s %d - %s\n", b ? "ok" : "ng", n++, name);
+}
+
+template <typename T> void is(const T& x, const T& y, const char* name = "")
+{
+ if (x == y) {
+ ok(true, name);
+ } else {
+ ok(false, name);
+ }
+}
+
+#include <algorithm>
+
+int main(void)
+{
+ plan(75);
+
+ // constructors
+#define TEST(expr, expected) \
+ is(picojson::value expr .serialize(), string(expected), "picojson::value" #expr)
+
+ TEST( (true), "true");
+ TEST( (false), "false");
+ TEST( (42.0), "42");
+ TEST( (string("hello")), "\"hello\"");
+ TEST( ("hello"), "\"hello\"");
+ TEST( ("hello", 4), "\"hell\"");
+
+#undef TEST
+
+#define TEST(in, type, cmp, serialize_test) { \
+ picojson::value v; \
+ const char* s = in; \
+ string err = picojson::parse(v, s, s + strlen(s)); \
+ ok(err.empty(), in " no error"); \
+ ok(v.is<type>(), in " check type"); \
+ is<type>(v.get<type>(), cmp, in " correct output"); \
+ is(*s, '\0', in " read to eof"); \
+ if (serialize_test) { \
+ is(v.serialize(), string(in), in " serialize"); \
+ } \
+ }
+ TEST("false", bool, false, true);
+ TEST("true", bool, true, true);
+ TEST("90.5", double, 90.5, false);
+ TEST("\"hello\"", string, string("hello"), true);
+ TEST("\"\\\"\\\\\\/\\b\\f\\n\\r\\t\"", string, string("\"\\/\b\f\n\r\t"),
+ true);
+ TEST("\"\\u0061\\u30af\\u30ea\\u30b9\"", string,
+ string("a\xe3\x82\xaf\xe3\x83\xaa\xe3\x82\xb9"), false);
+ TEST("\"\\ud840\\udc0b\"", string, string("\xf0\xa0\x80\x8b"), false);
+#undef TEST
+
+#define TEST(type, expr) { \
+ picojson::value v; \
+ const char *s = expr; \
+ string err = picojson::parse(v, s, s + strlen(s)); \
+ ok(err.empty(), "empty " #type " no error"); \
+ ok(v.is<picojson::type>(), "empty " #type " check type"); \
+ ok(v.get<picojson::type>().empty(), "check " #type " array size"); \
+ }
+ TEST(array, "[]");
+ TEST(object, "{}");
+#undef TEST
+
+ {
+ picojson::value v;
+ const char *s = "[1,true,\"hello\"]";
+ string err = picojson::parse(v, s, s + strlen(s));
+ ok(err.empty(), "array no error");
+ ok(v.is<picojson::array>(), "array check type");
+ is(v.get<picojson::array>().size(), size_t(3), "check array size");
+ ok(v.contains(0), "check contains array[0]");
+ ok(v.get(0).is<double>(), "check array[0] type");
+ is(v.get(0).get<double>(), 1.0, "check array[0] value");
+ ok(v.contains(1), "check contains array[1]");
+ ok(v.get(1).is<bool>(), "check array[1] type");
+ ok(v.get(1).get<bool>(), "check array[1] value");
+ ok(v.contains(2), "check contains array[2]");
+ ok(v.get(2).is<string>(), "check array[2] type");
+ is(v.get(2).get<string>(), string("hello"), "check array[2] value");
+ ok(!v.contains(3), "check not contains array[3]");
+ }
+
+ {
+ picojson::value v;
+ const char *s = "{ \"a\": true }";
+ string err = picojson::parse(v, s, s + strlen(s));
+ ok(err.empty(), "object no error");
+ ok(v.is<picojson::object>(), "object check type");
+ is(v.get<picojson::object>().size(), size_t(1), "check object size");
+ ok(v.contains("a"), "check contains property");
+ ok(v.get("a").is<bool>(), "check bool property exists");
+ is(v.get("a").get<bool>(), true, "check bool property value");
+ is(v.serialize(), string("{\"a\":true}"), "serialize object");
+ ok(!v.contains("z"), "check not contains property");
+ }
+
+#define TEST(json, msg) do { \
+ picojson::value v; \
+ const char *s = json; \
+ string err = picojson::parse(v, s, s + strlen(s)); \
+ is(err, string("syntax error at line " msg), msg); \
+ } while (0)
+ TEST("falsoa", "1 near: oa");
+ TEST("{]", "1 near: ]");
+ TEST("\n\bbell", "2 near: bell");
+ TEST("\"abc\nd\"", "1 near: ");
+#undef TEST
+
+ {
+ picojson::value v1, v2;
+ const char *s;
+ string err;
+ s = "{ \"b\": true, \"a\": [1,2,\"three\"], \"d\": 2 }";
+ err = picojson::parse(v1, s, s + strlen(s));
+ s = "{ \"d\": 2.0, \"b\": true, \"a\": [1,2,\"three\"] }";
+ err = picojson::parse(v2, s, s + strlen(s));
+ ok((v1 == v2), "check == operator in deep comparison");
+ }
+
+ {
+ picojson::value v1, v2;
+ const char *s;
+ string err;
+ s = "{ \"b\": true, \"a\": [1,2,\"three\"], \"d\": 2 }";
+ err = picojson::parse(v1, s, s + strlen(s));
+ s = "{ \"d\": 2.0, \"a\": [1,\"three\"], \"b\": true }";
+ err = picojson::parse(v2, s, s + strlen(s));
+ ok((v1 != v2), "check != operator for array in deep comparison");
+ }
+
+ {
+ picojson::value v1, v2;
+ const char *s;
+ string err;
+ s = "{ \"b\": true, \"a\": [1,2,\"three\"], \"d\": 2 }";
+ err = picojson::parse(v1, s, s + strlen(s));
+ s = "{ \"d\": 2.0, \"a\": [1,2,\"three\"], \"b\": false }";
+ err = picojson::parse(v2, s, s + strlen(s));
+ ok((v1 != v2), "check != operator for object in deep comparison");
+ }
+
+ {
+ picojson::value v1, v2;
+ const char *s;
+ string err;
+ s = "{ \"b\": true, \"a\": [1,2,\"three\"], \"d\": 2 }";
+ err = picojson::parse(v1, s, s + strlen(s));
+ picojson::object& o = v1.get<picojson::object>();
+ o.erase("b");
+ picojson::array& a = o["a"].get<picojson::array>();
+ picojson::array::iterator i;
+ i = std::remove(a.begin(), a.end(), picojson::value(std::string("three")));
+ a.erase(i, a.end());
+ s = "{ \"a\": [1,2], \"d\": 2 }";
+ err = picojson::parse(v2, s, s + strlen(s));
+ ok((v1 == v2), "check erase()");
+ }
+
+ ok(picojson::value(3.0).serialize() == "3",
+ "integral number should be serialized as a integer");
+
+ {
+ const char* s = "{ \"a\": [1,2], \"d\": 2 }";
+ picojson::null_parse_context ctx;
+ string err;
+ picojson::_parse(ctx, s, s + strlen(s), &err);
+ ok(err.empty(), "null_parse_context");
+ }
+
+ return success ? 0 : 1;
+}
+
+#endif
diff --git a/rst_parser/rst.cc b/rst_parser/rst.cc
index f6b295b3..bc91330b 100644
--- a/rst_parser/rst.cc
+++ b/rst_parser/rst.cc
@@ -2,6 +2,81 @@
using namespace std;
-StochasticForest::StochasticForest(const ArcFactoredForest& af) {
+// David B. Wilson. Generating Random Spanning Trees More Quickly than the Cover Time.
+// this is an awesome algorithm
+TreeSampler::TreeSampler(const ArcFactoredForest& af) : forest(af), usucc(af.size() + 1) {
+ // edges are directed from modifiers to heads, and finally to the root
+ vector<double> p;
+ for (int m = 1; m <= forest.size(); ++m) {
+#if USE_ALIAS_SAMPLER
+ p.clear();
+#else
+ SampleSet<double>& ss = usucc[m];
+#endif
+ double z = 0;
+ for (int h = 0; h <= forest.size(); ++h) {
+ double u = forest(h-1,m-1).edge_prob.as_float();
+ z += u;
+#if USE_ALIAS_SAMPLER
+ p.push_back(u);
+#else
+ ss.add(u);
+#endif
+ }
+#if USE_ALIAS_SAMPLER
+ for (int i = 0; i < p.size(); ++i) { p[i] /= z; }
+ usucc[m].Init(p);
+#endif
+ }
}
+void TreeSampler::SampleRandomSpanningTree(EdgeSubset* tree, MT19937* prng) {
+ MT19937& rng = *prng;
+ const int r = 0;
+ bool success = false;
+ while (!success) {
+ int roots = 0;
+ tree->h_m_pairs.clear();
+ tree->roots.clear();
+ vector<int> next(forest.size() + 1, -1);
+ vector<char> in_tree(forest.size() + 1, 0);
+ in_tree[r] = 1;
+ //cerr << "Forest size: " << forest.size() << endl;
+ for (int i = 0; i <= forest.size(); ++i) {
+ //cerr << "Sampling starting at u=" << i << endl;
+ int u = i;
+ if (in_tree[u]) continue;
+ while(!in_tree[u]) {
+#if USE_ALIAS_SAMPLER
+ next[u] = usucc[u].Draw(rng);
+#else
+ next[u] = rng.SelectSample(usucc[u]);
+#endif
+ u = next[u];
+ }
+ u = i;
+ //cerr << (u-1);
+ int prev = u-1;
+ while(!in_tree[u]) {
+ in_tree[u] = true;
+ u = next[u];
+ //cerr << " > " << (u-1);
+ if (u == r) {
+ ++roots;
+ tree->roots.push_back(prev);
+ } else {
+ tree->h_m_pairs.push_back(make_pair<short,short>(u-1,prev));
+ }
+ prev = u-1;
+ }
+ //cerr << endl;
+ }
+ assert(roots > 0);
+ if (roots > 1) {
+ //cerr << "FAILURE\n";
+ } else {
+ success = true;
+ }
+ }
+};
+
diff --git a/rst_parser/rst.h b/rst_parser/rst.h
index 865871eb..8bf389f7 100644
--- a/rst_parser/rst.h
+++ b/rst_parser/rst.h
@@ -1,10 +1,21 @@
#ifndef _RST_H_
#define _RST_H_
+#include <vector>
+#include "sampler.h"
#include "arc_factored.h"
+#include "alias_sampler.h"
-struct StochasticForest {
- explicit StochasticForest(const ArcFactoredForest& af);
+struct TreeSampler {
+ explicit TreeSampler(const ArcFactoredForest& af);
+ void SampleRandomSpanningTree(EdgeSubset* tree, MT19937* rng);
+ const ArcFactoredForest& forest;
+#define USE_ALIAS_SAMPLER 1
+#if USE_ALIAS_SAMPLER
+ std::vector<AliasSampler> usucc;
+#else
+ std::vector<SampleSet<double> > usucc;
+#endif
};
#endif
diff --git a/rst_parser/rst_parse.cc b/rst_parser/rst_parse.cc
new file mode 100644
index 00000000..9c42a8f4
--- /dev/null
+++ b/rst_parser/rst_parse.cc
@@ -0,0 +1,111 @@
+#include "arc_factored.h"
+
+#include <vector>
+#include <iostream>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "timing_stats.h"
+#include "arc_ff.h"
+#include "dep_training.h"
+#include "stringlib.h"
+#include "filelib.h"
+#include "tdict.h"
+#include "weights.h"
+#include "rst.h"
+#include "global_ff.h"
+
+using namespace std;
+namespace po = boost::program_options;
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ string cfg_file;
+ opts.add_options()
+ ("input,i",po::value<string>()->default_value("-"), "File containing test data (jsent format)")
+ ("q_weights,q",po::value<string>(), "Arc-factored weights for proposal distribution (mandatory)")
+ ("p_weights,p",po::value<string>(), "Weights for target distribution (optional)")
+ ("samples,n",po::value<unsigned>()->default_value(1000), "Number of samples");
+ po::options_description clo("Command line options");
+ clo.add_options()
+ ("config,c", po::value<string>(&cfg_file), "Configuration file")
+ ("help,?", "Print this help message and exit");
+
+ po::options_description dconfig_options, dcmdline_options;
+ dconfig_options.add(opts);
+ dcmdline_options.add(dconfig_options).add(clo);
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ if (cfg_file.size() > 0) {
+ ReadFile rf(cfg_file);
+ po::store(po::parse_config_file(*rf.stream(), dconfig_options), *conf);
+ }
+ if (conf->count("help") || conf->count("q_weights") == 0) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+int main(int argc, char** argv) {
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+ vector<weight_t> qweights, pweights;
+ Weights::InitFromFile(conf["q_weights"].as<string>(), &qweights);
+ if (conf.count("p_weights"))
+ Weights::InitFromFile(conf["p_weights"].as<string>(), &pweights);
+ const bool global = pweights.size() > 0;
+ ArcFeatureFunctions ffs;
+ GlobalFeatureFunctions gff;
+ ReadFile rf(conf["input"].as<string>());
+ istream* in = rf.stream();
+ TrainingInstance sent;
+ MT19937 rng;
+ int samples = conf["samples"].as<unsigned>();
+ int totroot = 0, root_right = 0, tot = 0, cor = 0;
+ while(TrainingInstance::ReadInstance(in, &sent)) {
+ ffs.PrepareForInput(sent.ts);
+ if (global) gff.PrepareForInput(sent.ts);
+ ArcFactoredForest forest(sent.ts.pos.size());
+ forest.ExtractFeatures(sent.ts, ffs);
+ forest.Reweight(qweights);
+ TreeSampler ts(forest);
+ double best_score = -numeric_limits<double>::infinity();
+ EdgeSubset best_tree;
+ for (int n = 0; n < samples; ++n) {
+ EdgeSubset tree;
+ ts.SampleRandomSpanningTree(&tree, &rng);
+ SparseVector<double> qfeats, gfeats;
+ tree.ExtractFeatures(sent.ts, ffs, &qfeats);
+ double score = 0;
+ if (global) {
+ gff.Features(sent.ts, tree, &gfeats);
+ score = (qfeats + gfeats).dot(pweights);
+ } else {
+ score = qfeats.dot(qweights);
+ }
+ if (score > best_score) {
+ best_tree = tree;
+ best_score = score;
+ }
+ }
+ cerr << "BEST SCORE: " << best_score << endl;
+ cout << best_tree << endl;
+ const bool sent_has_ref = sent.tree.h_m_pairs.size() > 0;
+ if (sent_has_ref) {
+ map<pair<short,short>, bool> ref;
+ for (int i = 0; i < sent.tree.h_m_pairs.size(); ++i)
+ ref[sent.tree.h_m_pairs[i]] = true;
+ int ref_root = sent.tree.roots.front();
+ if (ref_root == best_tree.roots.front()) { ++root_right; }
+ ++totroot;
+ for (int i = 0; i < best_tree.h_m_pairs.size(); ++i) {
+ if (ref[best_tree.h_m_pairs[i]]) {
+ ++cor;
+ }
+ ++tot;
+ }
+ }
+ }
+ cerr << "F = " << (double(cor + root_right) / (tot + totroot)) << endl;
+ return 0;
+}
+
diff --git a/rst_parser/rst_test.cc b/rst_parser/rst_test.cc
deleted file mode 100644
index e8fe706e..00000000
--- a/rst_parser/rst_test.cc
+++ /dev/null
@@ -1,33 +0,0 @@
-#include "arc_factored.h"
-
-#include <iostream>
-
-using namespace std;
-
-int main(int argc, char** argv) {
- // John saw Mary
- // (H -> M)
- // (1 -> 2) 20
- // (1 -> 3) 3
- // (2 -> 1) 20
- // (2 -> 3) 30
- // (3 -> 2) 0
- // (3 -> 1) 11
- // (0, 2) 10
- // (0, 1) 9
- // (0, 3) 9
- ArcFactoredForest af(3);
- af(1,2).edge_prob.logeq(20);
- af(1,3).edge_prob.logeq(3);
- af(2,1).edge_prob.logeq(20);
- af(2,3).edge_prob.logeq(30);
- af(3,2).edge_prob.logeq(0);
- af(3,1).edge_prob.logeq(11);
- af(0,2).edge_prob.logeq(10);
- af(0,1).edge_prob.logeq(9);
- af(0,3).edge_prob.logeq(9);
- SpanningTree tree;
- af.MaximumSpanningTree(&tree);
- return 0;
-}
-
diff --git a/rst_parser/rst_train.cc b/rst_parser/rst_train.cc
new file mode 100644
index 00000000..9b730f3d
--- /dev/null
+++ b/rst_parser/rst_train.cc
@@ -0,0 +1,144 @@
+#include "arc_factored.h"
+
+#include <vector>
+#include <iostream>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "timing_stats.h"
+#include "arc_ff.h"
+#include "dep_training.h"
+#include "stringlib.h"
+#include "filelib.h"
+#include "tdict.h"
+#include "weights.h"
+#include "rst.h"
+#include "global_ff.h"
+
+using namespace std;
+namespace po = boost::program_options;
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ string cfg_file;
+ opts.add_options()
+ ("training_data,t",po::value<string>()->default_value("-"), "File containing training data (jsent format)")
+ ("q_weights,q",po::value<string>(), "Arc-factored weights for proposal distribution")
+ ("samples,n",po::value<unsigned>()->default_value(1000), "Number of samples");
+ po::options_description clo("Command line options");
+ clo.add_options()
+ ("config,c", po::value<string>(&cfg_file), "Configuration file")
+ ("help,?", "Print this help message and exit");
+
+ po::options_description dconfig_options, dcmdline_options;
+ dconfig_options.add(opts);
+ dcmdline_options.add(dconfig_options).add(clo);
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ if (cfg_file.size() > 0) {
+ ReadFile rf(cfg_file);
+ po::store(po::parse_config_file(*rf.stream(), dconfig_options), *conf);
+ }
+ if (conf->count("help")) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+int main(int argc, char** argv) {
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+ vector<weight_t> qweights(FD::NumFeats(), 0.0);
+ Weights::InitFromFile(conf["q_weights"].as<string>(), &qweights);
+ vector<TrainingInstance> corpus;
+ ArcFeatureFunctions ffs;
+ GlobalFeatureFunctions gff;
+ TrainingInstance::ReadTrainingCorpus(conf["training_data"].as<string>(), &corpus);
+ vector<ArcFactoredForest> forests(corpus.size());
+ vector<prob_t> zs(corpus.size());
+ SparseVector<double> empirical;
+ bool flag = false;
+ for (int i = 0; i < corpus.size(); ++i) {
+ TrainingInstance& cur = corpus[i];
+ if ((i+1) % 10 == 0) { cerr << '.' << flush; flag = true; }
+ if ((i+1) % 400 == 0) { cerr << " [" << (i+1) << "]\n"; flag = false; }
+ SparseVector<weight_t> efmap;
+ ffs.PrepareForInput(cur.ts);
+ gff.PrepareForInput(cur.ts);
+ for (int j = 0; j < cur.tree.h_m_pairs.size(); ++j) {
+ efmap.clear();
+ ffs.EdgeFeatures(cur.ts, cur.tree.h_m_pairs[j].first,
+ cur.tree.h_m_pairs[j].second,
+ &efmap);
+ cur.features += efmap;
+ }
+ for (int j = 0; j < cur.tree.roots.size(); ++j) {
+ efmap.clear();
+ ffs.EdgeFeatures(cur.ts, -1, cur.tree.roots[j], &efmap);
+ cur.features += efmap;
+ }
+ efmap.clear();
+ gff.Features(cur.ts, cur.tree, &efmap);
+ cur.features += efmap;
+ empirical += cur.features;
+ forests[i].resize(cur.ts.words.size());
+ forests[i].ExtractFeatures(cur.ts, ffs);
+ forests[i].Reweight(qweights);
+ forests[i].EdgeMarginals(&zs[i]);
+ zs[i] = prob_t::One() / zs[i];
+ // cerr << zs[i] << endl;
+ forests[i].Reweight(qweights); // EdgeMarginals overwrites edge_prob
+ }
+ if (flag) cerr << endl;
+ MT19937 rng;
+ SparseVector<double> model_exp;
+ SparseVector<double> weights;
+ Weights::InitSparseVector(qweights, &weights);
+ int samples = conf["samples"].as<unsigned>();
+ for (int i = 0; i < corpus.size(); ++i) {
+#if 0
+ forests[i].EdgeMarginals();
+ model_exp.clear();
+ for (int h = -1; h < num_words; ++h) {
+ for (int m = 0; m < num_words; ++m) {
+ if (h == m) continue;
+ const ArcFactoredForest::Edge& edge = forests[i](h,m);
+ const SparseVector<weight_t>& fmap = edge.features;
+ double prob = edge.edge_prob.as_float();
+ model_exp += fmap * prob;
+ }
+ }
+ cerr << "TRUE EXP: " << model_exp << endl;
+ forests[i].Reweight(weights);
+#endif
+
+ TreeSampler ts(forests[i]);
+ prob_t zhat = prob_t::Zero();
+ SparseVector<prob_t> sampled_exp;
+ for (int n = 0; n < samples; ++n) {
+ EdgeSubset tree;
+ ts.SampleRandomSpanningTree(&tree, &rng);
+ SparseVector<double> qfeats, gfeats;
+ tree.ExtractFeatures(corpus[i].ts, ffs, &qfeats);
+ prob_t u; u.logeq(qfeats.dot(qweights));
+ const prob_t q = u / zs[i]; // proposal mass
+ gff.Features(corpus[i].ts, tree, &gfeats);
+ SparseVector<double> tot_feats = qfeats + gfeats;
+ u.logeq(tot_feats.dot(weights));
+ prob_t w = u / q;
+ zhat += w;
+ for (SparseVector<double>::const_iterator it = tot_feats.begin(); it != tot_feats.end(); ++it)
+ sampled_exp.add_value(it->first, w * prob_t(it->second));
+ }
+ sampled_exp /= zhat;
+ SparseVector<double> tot_m;
+ for (SparseVector<prob_t>::const_iterator it = sampled_exp.begin(); it != sampled_exp.end(); ++it)
+ tot_m.add_value(it->first, it->second.as_float());
+ //cerr << "DIFF: " << (tot_m - corpus[i].features) << endl;
+ const double eta = 0.03;
+ weights -= (tot_m - corpus[i].features) * eta;
+ }
+ cerr << "WEIGHTS.\n";
+ cerr << weights << endl;
+ return 0;
+}
+