diff options
| author | Patrick Simianer <p@simianer.de> | 2012-04-23 21:44:02 +0200 | 
|---|---|---|
| committer | Patrick Simianer <p@simianer.de> | 2012-04-23 21:44:02 +0200 | 
| commit | 2f427278616cbe3fa6f56d6b97c40b3894dbd950 (patch) | |
| tree | 6998435e4677437c474cf0f835ce9f72d70d3945 /rst_parser | |
| parent | 6d0d0eb6bbfaee6b6998659a55e2195977ccd217 (diff) | |
| parent | 217c4aaeba1c9f19b3420b526235bffd86c7a92b (diff) | |
Merge remote-tracking branch 'upstream/master'
Conflicts:
	Makefile.am
	configure.ac
Diffstat (limited to 'rst_parser')
| -rw-r--r-- | rst_parser/Makefile.am | 20 | ||||
| -rw-r--r-- | rst_parser/arc_factored.cc | 128 | ||||
| -rw-r--r-- | rst_parser/arc_factored.h | 82 | ||||
| -rw-r--r-- | rst_parser/arc_factored_marginals.cc | 58 | ||||
| -rw-r--r-- | rst_parser/arc_ff.cc | 183 | ||||
| -rw-r--r-- | rst_parser/arc_ff.h | 28 | ||||
| -rw-r--r-- | rst_parser/dep_training.cc | 76 | ||||
| -rw-r--r-- | rst_parser/dep_training.h | 19 | ||||
| -rw-r--r-- | rst_parser/global_ff.cc | 44 | ||||
| -rw-r--r-- | rst_parser/global_ff.h | 18 | ||||
| -rw-r--r-- | rst_parser/mst_train.cc | 220 | ||||
| -rw-r--r-- | rst_parser/picojson.h | 979 | ||||
| -rw-r--r-- | rst_parser/rst.cc | 77 | ||||
| -rw-r--r-- | rst_parser/rst.h | 15 | ||||
| -rw-r--r-- | rst_parser/rst_parse.cc | 111 | ||||
| -rw-r--r-- | rst_parser/rst_test.cc | 33 | ||||
| -rw-r--r-- | rst_parser/rst_train.cc | 144 | 
17 files changed, 2159 insertions, 76 deletions
| diff --git a/rst_parser/Makefile.am b/rst_parser/Makefile.am index e97ab5c5..4977f584 100644 --- a/rst_parser/Makefile.am +++ b/rst_parser/Makefile.am @@ -1,19 +1,17 @@  bin_PROGRAMS = \ -  mst_train - -noinst_PROGRAMS = \ -  rst_test - -TESTS = rst_test +  mst_train rst_train rst_parse  noinst_LIBRARIES = librst.a -librst_a_SOURCES = arc_factored.cc rst.cc +librst_a_SOURCES = arc_factored.cc arc_factored_marginals.cc rst.cc arc_ff.cc dep_training.cc global_ff.cc  mst_train_SOURCES = mst_train.cc -mst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a ../training/optimize.o -lz + +rst_train_SOURCES = rst_train.cc +rst_train_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz -rst_test_SOURCES = rst_test.cc -rst_test_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +rst_parse_SOURCES = rst_parse.cc +rst_parse_LDADD = librst.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz -AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/training -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm diff --git a/rst_parser/arc_factored.cc b/rst_parser/arc_factored.cc index 1e75600b..74bf7516 100644 --- a/rst_parser/arc_factored.cc +++ b/rst_parser/arc_factored.cc @@ -1,31 +1,151 @@  #include "arc_factored.h"  #include <set> +#include <tr1/unordered_set>  #include <boost/pending/disjoint_sets.hpp> +#include <boost/functional/hash.hpp> + +#include "arc_ff.h"  using namespace std; +using namespace std::tr1;  using namespace boost; +void EdgeSubset::ExtractFeatures(const TaggedSentence& sentence, +                                 const ArcFeatureFunctions& ffs, +                                 SparseVector<double>* features) const { +  SparseVector<weight_t> efmap; +  for (int j = 0; j < h_m_pairs.size(); ++j) { +    efmap.clear(); +    ffs.EdgeFeatures(sentence, h_m_pairs[j].first, +                     h_m_pairs[j].second, +                     &efmap); +    (*features) += efmap; +  } +  for (int j = 0; j < roots.size(); ++j) { +    efmap.clear(); +    ffs.EdgeFeatures(sentence, -1, roots[j], &efmap); +    (*features) += efmap; +  } +} + +void ArcFactoredForest::ExtractFeatures(const TaggedSentence& sentence, +                                        const ArcFeatureFunctions& ffs) { +  for (int m = 0; m < num_words_; ++m) { +    for (int h = 0; h < num_words_; ++h) { +      ffs.EdgeFeatures(sentence, h, m, &edges_(h,m).features); +    } +    ffs.EdgeFeatures(sentence, -1, m, &root_edges_[m].features); +  } +} + +void ArcFactoredForest::PickBestParentForEachWord(EdgeSubset* st) const { +  for (int m = 0; m < num_words_; ++m) { +    int best_head = -2; +    prob_t best_score; +    for (int h = -1; h < num_words_; ++h) { +      const Edge& edge = (*this)(h,m); +      if (best_head < -1 || edge.edge_prob > best_score) { +        best_score = edge.edge_prob; +        best_head = h; +      } +    } +    assert(best_head >= -1); +    if (best_head >= 0) +      st->h_m_pairs.push_back(make_pair<short,short>(best_head, m)); +    else +      st->roots.push_back(m); +  } +} + +struct WeightedEdge { +  WeightedEdge() : h(), m(), weight() {} +  WeightedEdge(short hh, short mm, float w) : h(hh), m(mm), weight(w) {} +  short h, m; +  float weight; +  inline bool operator==(const WeightedEdge& o) const { +    return h == o.h && m == o.m && weight == o.weight; +  } +  inline bool operator!=(const WeightedEdge& o) const { +    return h != o.h || m != o.m || weight != o.weight; +  } +}; +inline bool operator<(const WeightedEdge& l, const WeightedEdge& o) { return l.weight < o.weight; } +inline size_t hash_value(const WeightedEdge& e) { return reinterpret_cast<const size_t&>(e); } + + +struct PriorityQueue { +  void push(const WeightedEdge& e) {} +  const WeightedEdge& top() const { +    static WeightedEdge w(1,2,3); +    return w; +  } +  void pop() {} +  void increment_all(float p) {} +}; +  // based on Trajan 1977 -void ArcFactoredForest::MaximumSpanningTree(SpanningTree* st) const { +void ArcFactoredForest::MaximumSpanningTree(EdgeSubset* st) const {    typedef disjoint_sets_with_storage<identity_property_map, identity_property_map,        find_with_full_path_compression> DisjointSet;    DisjointSet strongly(num_words_ + 1);    DisjointSet weakly(num_words_ + 1); -  set<unsigned> roots, h, rset; -  vector<pair<short, short> > enter(num_words_ + 1); +  set<unsigned> roots, rset; +  unordered_set<WeightedEdge, boost::hash<WeightedEdge> > h; +  vector<PriorityQueue> qs(num_words_ + 1); +  vector<WeightedEdge> enter(num_words_ + 1); +  vector<unsigned> mins(num_words_ + 1); +  const WeightedEdge kDUMMY(0,0,0.0f);    for (unsigned i = 0; i <= num_words_; ++i) { +    if (i > 0) { +      // I(i) incidence on i -- all incoming edges +      for (unsigned j = 0; j <= num_words_; ++j) { +        qs[i].push(WeightedEdge(j, i, Weight(j,i))); +      } +    }      strongly.make_set(i);      weakly.make_set(i);      roots.insert(i); +    enter[i] = kDUMMY; +    mins[i] = i;    }    while(!roots.empty()) {      set<unsigned>::iterator it = roots.begin();      const unsigned k = *it;      roots.erase(it);      cerr << "k=" << k << endl; -    pair<short,short> ij; // TODO = Max(k); +    WeightedEdge ij = qs[k].top();  // MAX(k) +    qs[k].pop(); +    if (ij.weight <= 0) { +      rset.insert(k); +    } else { +      if (strongly.find_set(ij.h) == k) { +        roots.insert(k); +      } else { +        h.insert(ij); +        if (weakly.find_set(ij.h) != weakly.find_set(ij.m)) { +          weakly.union_set(ij.h, ij.m); +          enter[k] = ij; +        } else { +          unsigned vertex = 0; +          float val = 99999999999; +          WeightedEdge xy = ij; +          while(xy != kDUMMY) { +            if (xy.weight < val) { +              val = xy.weight; +              vertex = strongly.find_set(xy.m); +            } +            xy = enter[strongly.find_set(xy.h)]; +          } +          qs[k].increment_all(val - ij.weight); +          mins[k] = mins[vertex]; +          xy = enter[strongly.find_set(ij.h)]; +          while (xy != kDUMMY) { +          } +        } +      } +    }    }  } diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h index e99be482..c5481d80 100644 --- a/rst_parser/arc_factored.h +++ b/rst_parser/arc_factored.h @@ -5,37 +5,65 @@  #include <cassert>  #include <vector>  #include <utility> +#include <boost/shared_ptr.hpp>  #include "array2d.h"  #include "sparse_vector.h"  #include "prob.h"  #include "weights.h" +#include "wordid.h" -struct SpanningTree { -  SpanningTree() : roots(1, -1) {} +struct TaggedSentence { +  std::vector<WordID> words; +  std::vector<WordID> pos; +}; + +struct ArcFeatureFunctions; +struct EdgeSubset { +  EdgeSubset() {}    std::vector<short> roots; // unless multiroot trees are supported, this                              // will have a single member -  std::vector<std::pair<short, short> > h_m_pairs; +  std::vector<std::pair<short, short> > h_m_pairs; // h,m start at 0 +  // assumes ArcFeatureFunction::PrepareForInput has already been called +  void ExtractFeatures(const TaggedSentence& sentence, +                       const ArcFeatureFunctions& ffs, +                       SparseVector<double>* features) const;  };  class ArcFactoredForest {   public: -  explicit ArcFactoredForest(short num_words) : -      num_words_(num_words), -      root_edges_(num_words), -      edges_(num_words, num_words) { +  ArcFactoredForest() : num_words_() {} +  explicit ArcFactoredForest(short num_words) : num_words_(num_words) { +    resize(num_words); +  } + +  unsigned size() const { return num_words_; } + +  void resize(unsigned num_words) { +    num_words_ = num_words; +    root_edges_.clear(); +    edges_.clear(); +    root_edges_.resize(num_words); +    edges_.resize(num_words, num_words);      for (int h = 0; h < num_words; ++h) {        for (int m = 0; m < num_words; ++m) { -        edges_(h, m).h = h + 1; -        edges_(h, m).m = m + 1; +        edges_(h, m).h = h; +        edges_(h, m).m = m;        } -      root_edges_[h].h = 0; -      root_edges_[h].m = h + 1; +      root_edges_[h].h = -1; +      root_edges_[h].m = h;      }    }    // compute the maximum spanning tree based on the current weighting    // using the O(n^2) CLE algorithm -  void MaximumSpanningTree(SpanningTree* st) const; +  void MaximumSpanningTree(EdgeSubset* st) const; + +  // Reweight edges so that edge_prob is the edge's marginals +  // optionally returns log partition +  void EdgeMarginals(prob_t* p_log_z = NULL); + +  // This may not return a tree +  void PickBestParentForEachWord(EdgeSubset* st) const;    struct Edge {      Edge() : h(), m(), features(), edge_prob(prob_t::Zero()) {} @@ -45,20 +73,20 @@ class ArcFactoredForest {      prob_t edge_prob;    }; +  // set eges_[*].features +  void ExtractFeatures(const TaggedSentence& sentence, +                       const ArcFeatureFunctions& ffs); +    const Edge& operator()(short h, short m) const { -    assert(m > 0); -    assert(m <= num_words_); -    assert(h >= 0); -    assert(h <= num_words_); -    return h ? edges_(h - 1, m - 1) : root_edges_[m - 1]; +    return h >= 0 ? edges_(h, m) : root_edges_[m];    }    Edge& operator()(short h, short m) { -    assert(m > 0); -    assert(m <= num_words_); -    assert(h >= 0); -    assert(h <= num_words_); -    return h ? edges_(h - 1, m - 1) : root_edges_[m - 1]; +    return h >= 0 ? edges_(h, m) : root_edges_[m]; +  } + +  float Weight(short h, short m) const { +    return log((*this)(h,m).edge_prob);    }    template <class V> @@ -76,7 +104,7 @@ class ArcFactoredForest {    }   private: -  unsigned num_words_; +  int num_words_;    std::vector<Edge> root_edges_;    Array2D<Edge> edges_;  }; @@ -85,4 +113,12 @@ inline std::ostream& operator<<(std::ostream& os, const ArcFactoredForest::Edge&    return os << "(" << edge.h << " < " << edge.m << ")";  } +inline std::ostream& operator<<(std::ostream& os, const EdgeSubset& ss) { +  for (unsigned i = 0; i < ss.roots.size(); ++i) +    os << "ROOT < " << ss.roots[i] << std::endl; +  for (unsigned i = 0; i < ss.h_m_pairs.size(); ++i) +    os << ss.h_m_pairs[i].first << " < " << ss.h_m_pairs[i].second << std::endl; +  return os; +} +  #endif diff --git a/rst_parser/arc_factored_marginals.cc b/rst_parser/arc_factored_marginals.cc new file mode 100644 index 00000000..3e8c9f86 --- /dev/null +++ b/rst_parser/arc_factored_marginals.cc @@ -0,0 +1,58 @@ +#include "arc_factored.h" + +#include <iostream> + +#include "config.h" + +using namespace std; + +#if HAVE_EIGEN + +#include <Eigen/Dense> +typedef Eigen::Matrix<prob_t, Eigen::Dynamic, Eigen::Dynamic> ArcMatrix; +typedef Eigen::Matrix<prob_t, Eigen::Dynamic, 1> RootVector; + +void ArcFactoredForest::EdgeMarginals(prob_t *plog_z) { +  ArcMatrix A(num_words_,num_words_); +  RootVector r(num_words_); +  for (int h = 0; h < num_words_; ++h) { +    for (int m = 0; m < num_words_; ++m) { +      if (h != m) +        A(h,m) = edges_(h,m).edge_prob; +      else +        A(h,m) = prob_t::Zero(); +    } +    r(h) = root_edges_[h].edge_prob; +  } + +  ArcMatrix L = -A; +  L.diagonal() = A.colwise().sum(); +  L.row(0) = r; +  ArcMatrix Linv = L.inverse(); +  if (plog_z) *plog_z = Linv.determinant(); +  RootVector rootMarginals = r.cwiseProduct(Linv.col(0)); +  static const prob_t ZERO(0); +  static const prob_t ONE(1); +//  ArcMatrix T = Linv; +  for (int h = 0; h < num_words_; ++h) { +    for (int m = 0; m < num_words_; ++m) { +      const prob_t marginal = (m == 0 ? ZERO : ONE) * A(h,m) * Linv(m,m) - +                              (h == 0 ? ZERO : ONE) * A(h,m) * Linv(m,h); +      edges_(h,m).edge_prob = marginal; +//      T(h,m) = marginal; +    } +    root_edges_[h].edge_prob = rootMarginals(h); +  } +//   cerr << "ROOT MARGINALS: " << rootMarginals.transpose() << endl; +//  cerr << "M:\n" << T << endl; +} + +#else + +void ArcFactoredForest::EdgeMarginals(prob_t *) { +  cerr << "EdgeMarginals() requires --with-eigen!\n"; +  abort(); +} + +#endif + diff --git a/rst_parser/arc_ff.cc b/rst_parser/arc_ff.cc new file mode 100644 index 00000000..c4e5aa17 --- /dev/null +++ b/rst_parser/arc_ff.cc @@ -0,0 +1,183 @@ +#include "arc_ff.h" + +#include <iostream> +#include <sstream> + +#include "stringlib.h" +#include "tdict.h" +#include "fdict.h" +#include "sentence_metadata.h" + +using namespace std; + +struct ArcFFImpl { +  ArcFFImpl() : kROOT("ROOT"), kLEFT_POS("LEFT"), kRIGHT_POS("RIGHT") {} +  const string kROOT; +  const string kLEFT_POS; +  const string kRIGHT_POS; +  map<WordID, vector<int> > pcs; + +  void PrepareForInput(const TaggedSentence& sent) { +    pcs.clear(); +    for (int i = 0; i < sent.pos.size(); ++i) +      pcs[sent.pos[i]].resize(1, 0); +    pcs[sent.pos[0]][0] = 1; +    for (int i = 1; i < sent.pos.size(); ++i) { +      const WordID posi = sent.pos[i]; +      for (map<WordID, vector<int> >::iterator j = pcs.begin(); j != pcs.end(); ++j) { +        const WordID posj = j->first; +        vector<int>& cs = j->second; +        cs.push_back(cs.back() + (posj == posi ? 1 : 0)); +      } +    } +  } + +  template <typename A> +  static void Fire(SparseVector<weight_t>* v, const A& a) { +    ostringstream os; +    os << a; +    v->set_value(FD::Convert(os.str()), 1); +  } + +  template <typename A, typename B> +  static void Fire(SparseVector<weight_t>* v, const A& a, const B& b) { +    ostringstream os; +    os << a << ':' << b; +    v->set_value(FD::Convert(os.str()), 1); +  } + +  template <typename A, typename B, typename C> +  static void Fire(SparseVector<weight_t>* v, const A& a, const B& b, const C& c) { +    ostringstream os; +    os << a << ':' << b << '_' << c; +    v->set_value(FD::Convert(os.str()), 1); +  } + +  template <typename A, typename B, typename C, typename D> +  static void Fire(SparseVector<weight_t>* v, const A& a, const B& b, const C& c, const D& d) { +    ostringstream os; +    os << a << ':' << b << '_' << c << '_' << d; +    v->set_value(FD::Convert(os.str()), 1); +  } + +  template <typename A, typename B, typename C, typename D, typename E> +  static void Fire(SparseVector<weight_t>* v, const A& a, const B& b, const C& c, const D& d, const E& e) { +    ostringstream os; +    os << a << ':' << b << '_' << c << '_' << d << '_' << e; +    v->set_value(FD::Convert(os.str()), 1); +  } + +  static void AddConjoin(const SparseVector<double>& v, const string& feat, SparseVector<double>* pf) { +    for (SparseVector<double>::const_iterator it = v.begin(); it != v.end(); ++it) +      pf->set_value(FD::Convert(FD::Convert(it->first) + "_" + feat), it->second); +  } + +  static inline string Fixup(const string& str) { +    string res = LowercaseString(str); +    if (res.size() < 6) return res; +    return res.substr(0, 5) + "*"; +  } + +  static inline string Suffix(const string& str) { +    if (str.size() < 4) return ""; else return str.substr(str.size() - 3); +  } + +  void EdgeFeatures(const TaggedSentence& sent, +                    short h, +                    short m, +                    SparseVector<weight_t>* features) const { +    const bool is_root = (h == -1); +    const string head_word = (is_root ? kROOT : Fixup(TD::Convert(sent.words[h]))); +    int num_words = sent.words.size(); +    const string& head_pos = (is_root ? kROOT : TD::Convert(sent.pos[h])); +    const string mod_word = Fixup(TD::Convert(sent.words[m])); +    const string& mod_pos = TD::Convert(sent.pos[m]); +    const string& mod_pos_L = (m > 0 ? TD::Convert(sent.pos[m-1]) : kLEFT_POS); +    const string& mod_pos_R = (m < sent.pos.size() - 1 ? TD::Convert(sent.pos[m]) : kRIGHT_POS); +    const bool bdir = m < h; +    const string dir = (bdir ? "MLeft" : "MRight"); +    int v = m - h; +    if (v < 0) { +      v= -1 - int(log(-v) / log(1.6)); +    } else { +      v= int(log(v) / log(1.6)) + 1; +    } +    ostringstream os; +    if (v < 0) os << "LenL" << -v; else os << "LenR" << v; +    const string lenstr = os.str(); +    Fire(features, dir); +    Fire(features, lenstr); +    // dir, lenstr +    if (is_root) { +      Fire(features, "wROOT", mod_word); +      Fire(features, "pROOT", mod_pos); +      Fire(features, "wpROOT", mod_word, mod_pos); +      Fire(features, "DROOT", mod_pos, lenstr); +      Fire(features, "LROOT", mod_pos_L); +      Fire(features, "RROOT", mod_pos_R); +      Fire(features, "LROOT", mod_pos_L, mod_pos); +      Fire(features, "RROOT", mod_pos_R, mod_pos); +      Fire(features, "LDist", m); +      Fire(features, "RDist", num_words - m); +    } else { // not root +      const string& head_pos_L = (h > 0 ? TD::Convert(sent.pos[h-1]) : kLEFT_POS); +      const string& head_pos_R = (h < sent.pos.size() - 1 ? TD::Convert(sent.pos[h]) : kRIGHT_POS); +      SparseVector<double> fv; +      SparseVector<double>* f = &fv; +      Fire(f, "H", head_pos); +      Fire(f, "M", mod_pos); +      Fire(f, "HM", head_pos, mod_pos); + +      // surrounders +      Fire(f, "posLL", head_pos, mod_pos, head_pos_L, mod_pos_L); +      Fire(f, "posRR", head_pos, mod_pos, head_pos_R, mod_pos_R); +      Fire(f, "posLR", head_pos, mod_pos, head_pos_L, mod_pos_R); +      Fire(f, "posRL", head_pos, mod_pos, head_pos_R, mod_pos_L); + +      // between features +      int left = min(h,m); +      int right = max(h,m); +      if (right - left >= 2) { +        if (bdir) --right; else ++left; +        for (map<WordID, vector<int> >::const_iterator it = pcs.begin(); it != pcs.end(); ++it) { +          if (it->second[left] != it->second[right]) { +            Fire(f, "BT", head_pos, TD::Convert(it->first), mod_pos); +          } +        } +      } + +      Fire(f, "wH", head_word); +      Fire(f, "wM", mod_word); +      Fire(f, "wpH", head_word, head_pos); +      Fire(f, "wpM", mod_word, mod_pos); +      Fire(f, "pHwM", head_pos, mod_word); +      Fire(f, "wHpM", head_word, mod_pos); + +      Fire(f, "wHM", head_word, mod_word); +      Fire(f, "pHMwH", head_pos, mod_pos, head_word); +      Fire(f, "pHMwM", head_pos, mod_pos, mod_word); +      Fire(f, "wHMpH", head_word, mod_word, head_pos); +      Fire(f, "wHMpM", head_word, mod_word, mod_pos); +      Fire(f, "wHMpHM", head_word, mod_word, head_pos, mod_pos); + +      AddConjoin(fv, dir, features); +      AddConjoin(fv, lenstr, features); +      (*features) += fv; +    } +  } +}; + +ArcFeatureFunctions::ArcFeatureFunctions() : pimpl(new ArcFFImpl) {} +ArcFeatureFunctions::~ArcFeatureFunctions() { delete pimpl; } + +void ArcFeatureFunctions::PrepareForInput(const TaggedSentence& sentence) { +  pimpl->PrepareForInput(sentence); +} + +void ArcFeatureFunctions::EdgeFeatures(const TaggedSentence& sentence, +                                       short h, +                                       short m, +                                       SparseVector<weight_t>* features) const { +  pimpl->EdgeFeatures(sentence, h, m, features); +} + diff --git a/rst_parser/arc_ff.h b/rst_parser/arc_ff.h new file mode 100644 index 00000000..52f311d2 --- /dev/null +++ b/rst_parser/arc_ff.h @@ -0,0 +1,28 @@ +#ifndef _ARC_FF_H_ +#define _ARC_FF_H_ + +#include <string> +#include "sparse_vector.h" +#include "weights.h" +#include "arc_factored.h" + +struct TaggedSentence; +struct ArcFFImpl; +class ArcFeatureFunctions { + public: +  ArcFeatureFunctions(); +  ~ArcFeatureFunctions(); + +  // called once, per input, before any calls to EdgeFeatures +  // used to initialize sentence-specific data structures +  void PrepareForInput(const TaggedSentence& sentence); + +  void EdgeFeatures(const TaggedSentence& sentence, +                    short h, +                    short m, +                    SparseVector<weight_t>* features) const; + private: +  ArcFFImpl* pimpl; +}; + +#endif diff --git a/rst_parser/dep_training.cc b/rst_parser/dep_training.cc new file mode 100644 index 00000000..ef97798b --- /dev/null +++ b/rst_parser/dep_training.cc @@ -0,0 +1,76 @@ +#include "dep_training.h" + +#include <vector> +#include <iostream> + +#include "stringlib.h" +#include "filelib.h" +#include "tdict.h" +#include "picojson.h" + +using namespace std; + +static void ParseInstance(const string& line, int start, TrainingInstance* out, int lc = 0) { +  picojson::value obj; +  string err; +  picojson::parse(obj, line.begin() + start, line.end(), &err); +  if (err.size() > 0) { cerr << "JSON parse error in " << lc << ": " << err << endl; abort(); } +  TrainingInstance& cur = *out; +  TaggedSentence& ts = cur.ts; +  EdgeSubset& tree = cur.tree; +  ts.pos.clear(); +  ts.words.clear(); +  tree.roots.clear(); +  tree.h_m_pairs.clear(); +  assert(obj.is<picojson::object>()); +  const picojson::object& d = obj.get<picojson::object>(); +  const picojson::array& ta = d.find("tokens")->second.get<picojson::array>(); +  for (unsigned i = 0; i < ta.size(); ++i) { +    ts.words.push_back(TD::Convert(ta[i].get<picojson::array>()[0].get<string>())); +    ts.pos.push_back(TD::Convert(ta[i].get<picojson::array>()[1].get<string>())); +  } +  if (d.find("deps") != d.end()) { +    const picojson::array& da = d.find("deps")->second.get<picojson::array>(); +    for (unsigned i = 0; i < da.size(); ++i) { +      const picojson::array& thm = da[i].get<picojson::array>(); +      // get dep type here +      short h = thm[2].get<double>(); +      short m = thm[1].get<double>(); +      if (h < 0) +        tree.roots.push_back(m); +      else +        tree.h_m_pairs.push_back(make_pair(h,m)); +    } +  } +  //cerr << TD::GetString(ts.words) << endl << TD::GetString(ts.pos) << endl << tree << endl; +} + +bool TrainingInstance::ReadInstance(std::istream* in, TrainingInstance* instance) { +  string line; +  if (!getline(*in, line)) return false; +  size_t pos = line.rfind('\t'); +  assert(pos != string::npos); +  static int lc = 0; ++lc; +  ParseInstance(line, pos + 1, instance, lc); +  return true; +} + +void TrainingInstance::ReadTrainingCorpus(const string& fname, vector<TrainingInstance>* corpus, int rank, int size) { +  ReadFile rf(fname); +  istream& in = *rf.stream(); +  string line; +  int lc = 0; +  bool flag = false; +  while(getline(in, line)) { +    ++lc; +    if ((lc-1) % size != rank) continue; +    if (rank == 0 && lc % 10 == 0) { cerr << '.' << flush; flag = true; } +    if (rank == 0 && lc % 400 == 0) { cerr << " [" << lc << "]\n"; flag = false; } +    size_t pos = line.rfind('\t'); +    assert(pos != string::npos); +    corpus->push_back(TrainingInstance()); +    ParseInstance(line, pos + 1, &corpus->back(), lc); +  } +  if (flag) cerr << "\nRead " << lc << " training instances\n"; +} + diff --git a/rst_parser/dep_training.h b/rst_parser/dep_training.h new file mode 100644 index 00000000..3eeee22e --- /dev/null +++ b/rst_parser/dep_training.h @@ -0,0 +1,19 @@ +#ifndef _DEP_TRAINING_H_ +#define _DEP_TRAINING_H_ + +#include <iostream> +#include <string> +#include <vector> +#include "arc_factored.h" +#include "weights.h" + +struct TrainingInstance { +  TaggedSentence ts; +  EdgeSubset tree; +  SparseVector<weight_t> features; +  // reads a "Jsent" formatted dependency file +  static bool ReadInstance(std::istream* in, TrainingInstance* instance); // returns false at EOF +  static void ReadTrainingCorpus(const std::string& fname, std::vector<TrainingInstance>* corpus, int rank = 0, int size = 1); +}; + +#endif diff --git a/rst_parser/global_ff.cc b/rst_parser/global_ff.cc new file mode 100644 index 00000000..ae410875 --- /dev/null +++ b/rst_parser/global_ff.cc @@ -0,0 +1,44 @@ +#include "global_ff.h" + +#include <iostream> +#include <sstream> + +#include "tdict.h" + +using namespace std; + +struct GFFImpl { +  void PrepareForInput(const TaggedSentence& sentence) { +  } +  void Features(const TaggedSentence& sentence, +                const EdgeSubset& tree, +                SparseVector<double>* feats) const { +    const vector<WordID>& words = sentence.words; +    const vector<WordID>& tags = sentence.pos; +    const vector<pair<short,short> >& hms = tree.h_m_pairs; +    assert(words.size() == tags.size()); +    vector<int> mods(words.size()); +    for (int i = 0; i < hms.size(); ++i) { +      mods[hms[i].first]++;        // first = head, second = modifier +    } +    for (int i = 0; i < mods.size(); ++i) { +      ostringstream os; +      os << "NM:" << TD::Convert(tags[i]) << "_" << mods[i]; +      feats->add_value(FD::Convert(os.str()), 1.0); +    } +  } +}; + +GlobalFeatureFunctions::GlobalFeatureFunctions() : pimpl(new GFFImpl) {} +GlobalFeatureFunctions::~GlobalFeatureFunctions() { delete pimpl; } + +void GlobalFeatureFunctions::PrepareForInput(const TaggedSentence& sentence) { +  pimpl->PrepareForInput(sentence); +} + +void GlobalFeatureFunctions::Features(const TaggedSentence& sentence, +                                      const EdgeSubset& tree, +                                      SparseVector<double>* feats) const { +  pimpl->Features(sentence, tree, feats); +} + diff --git a/rst_parser/global_ff.h b/rst_parser/global_ff.h new file mode 100644 index 00000000..d71d0fa1 --- /dev/null +++ b/rst_parser/global_ff.h @@ -0,0 +1,18 @@ +#ifndef _GLOBAL_FF_H_ +#define _GLOBAL_FF_H_ + +#include "arc_factored.h" + +struct GFFImpl; +struct GlobalFeatureFunctions { +  GlobalFeatureFunctions(); +  ~GlobalFeatureFunctions(); +  void PrepareForInput(const TaggedSentence& sentence); +  void Features(const TaggedSentence& sentence, +                const EdgeSubset& tree, +                SparseVector<double>* feats) const; + private: +  GFFImpl* pimpl; +}; + +#endif diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc index 7b5af4c1..6332693e 100644 --- a/rst_parser/mst_train.cc +++ b/rst_parser/mst_train.cc @@ -1,12 +1,228 @@  #include "arc_factored.h" +#include <vector>  #include <iostream> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> +// #define HAVE_THREAD 1 +#if HAVE_THREAD +#include <boost/thread.hpp> +#endif + +#include "arc_ff.h" +#include "stringlib.h" +#include "filelib.h" +#include "tdict.h" +#include "dep_training.h" +#include "optimize.h" +#include "weights.h"  using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { +  po::options_description opts("Configuration options"); +  string cfg_file; +  opts.add_options() +        ("training_data,t",po::value<string>()->default_value("-"), "File containing training data (jsent format)") +        ("weights,w",po::value<string>(), "Optional starting weights") +        ("output_every_i_iterations,I",po::value<unsigned>()->default_value(1), "Write weights every I iterations") +        ("regularization_strength,C",po::value<double>()->default_value(1.0), "Regularization strength") +#ifdef HAVE_CMPH +        ("cmph_perfect_feature_hash,h", po::value<string>(), "Load perfect hash function for features") +#endif +#if HAVE_THREAD +        ("threads,T",po::value<unsigned>()->default_value(1), "Number of threads") +#endif +        ("correction_buffers,m", po::value<int>()->default_value(10), "LBFGS correction buffers"); +  po::options_description clo("Command line options"); +  clo.add_options() +        ("config,c", po::value<string>(&cfg_file), "Configuration file") +        ("help,?", "Print this help message and exit"); + +  po::options_description dconfig_options, dcmdline_options; +  dconfig_options.add(opts); +  dcmdline_options.add(dconfig_options).add(clo); +  po::store(parse_command_line(argc, argv, dcmdline_options), *conf); +  if (cfg_file.size() > 0) { +    ReadFile rf(cfg_file); +    po::store(po::parse_config_file(*rf.stream(), dconfig_options), *conf); +  } +  if (conf->count("help")) { +    cerr << dcmdline_options << endl; +    exit(1); +  } +} + +void AddFeatures(double prob, const SparseVector<double>& fmap, vector<double>* g) { +  for (SparseVector<double>::const_iterator it = fmap.begin(); it != fmap.end(); ++it) +    (*g)[it->first] += it->second * prob; +} + +double ApplyRegularizationTerms(const double C, +                                const vector<double>& weights, +                                vector<double>* g) { +  assert(weights.size() == g->size()); +  double reg = 0; +  for (size_t i = 0; i < weights.size(); ++i) { +//    const double prev_w_i = (i < prev_weights.size() ? prev_weights[i] : 0.0); +    const double& w_i = weights[i]; +    double& g_i = (*g)[i]; +    reg += C * w_i * w_i; +    g_i += 2 * C * w_i; + +//    reg += T * (w_i - prev_w_i) * (w_i - prev_w_i); +//    g_i += 2 * T * (w_i - prev_w_i); +  } +  return reg; +} + +struct GradientWorker { +  GradientWorker(int f, +                 int t, +                 vector<double>* w, +                 vector<TrainingInstance>* c, +                 vector<ArcFactoredForest>* fs) : obj(), weights(*w), from(f), to(t), corpus(*c), forests(*fs), g(w->size()) {} +  void operator()() { +    int every = (to - from) / 20; +    if (!every) every++; +    for (int i = from; i < to; ++i) { +      if ((from == 0) && (i + 1) % every == 0) cerr << '.' << flush; +      const int num_words = corpus[i].ts.words.size(); +      forests[i].Reweight(weights); +      prob_t z; +      forests[i].EdgeMarginals(&z); +      obj -= log(z); +      //cerr << " O = " << (-corpus[i].features.dot(weights)) << " D=" << -lz << "  OO= " << (-corpus[i].features.dot(weights) - lz) << endl; +      //cerr << " ZZ = " << zz << endl; +      for (int h = -1; h < num_words; ++h) { +        for (int m = 0; m < num_words; ++m) { +          if (h == m) continue; +          const ArcFactoredForest::Edge& edge = forests[i](h,m); +          const SparseVector<weight_t>& fmap = edge.features; +          double prob = edge.edge_prob.as_float(); +          if (prob < -0.000001) { cerr << "Prob < 0: " << prob << endl; prob = 0; } +          if (prob > 1.000001) { cerr << "Prob > 1: " << prob << endl; prob = 1; } +          AddFeatures(prob, fmap, &g); +          //mfm += fmap * prob;  // DE +        } +      } +    } +  } +  double obj; +  vector<double>& weights; +  const int from, to; +  vector<TrainingInstance>& corpus; +  vector<ArcFactoredForest>& forests; +  vector<double> g; // local gradient +};  int main(int argc, char** argv) { -  ArcFactoredForest af(5); -  cerr << af(0,3) << endl; +  int rank = 0; +  int size = 1; +  po::variables_map conf; +  InitCommandLine(argc, argv, &conf); +  if (conf.count("cmph_perfect_feature_hash")) { +    cerr << "Loading perfect hash function from " << conf["cmph_perfect_feature_hash"].as<string>() << " ...\n"; +    FD::EnableHash(conf["cmph_perfect_feature_hash"].as<string>()); +    cerr << "  " << FD::NumFeats() << " features in map\n"; +  } +  ArcFeatureFunctions ffs; +  vector<TrainingInstance> corpus; +  TrainingInstance::ReadTrainingCorpus(conf["training_data"].as<string>(), &corpus, rank, size); +  vector<weight_t> weights; +    Weights::InitFromFile(conf["weights"].as<string>(), &weights); +  vector<ArcFactoredForest> forests(corpus.size()); +  SparseVector<double> empirical; +  cerr << "Extracting features...\n"; +  bool flag = false; +  for (int i = 0; i < corpus.size(); ++i) { +    TrainingInstance& cur = corpus[i]; +    if (rank == 0 && (i+1) % 10 == 0) { cerr << '.' << flush; flag = true; } +    if (rank == 0 && (i+1) % 400 == 0) { cerr << " [" << (i+1) << "]\n"; flag = false; } +    ffs.PrepareForInput(cur.ts); +    SparseVector<weight_t> efmap; +    for (int j = 0; j < cur.tree.h_m_pairs.size(); ++j) { +      efmap.clear(); +      ffs.EdgeFeatures(cur.ts, cur.tree.h_m_pairs[j].first, +                       cur.tree.h_m_pairs[j].second, +                       &efmap); +      cur.features += efmap; +    } +    for (int j = 0; j < cur.tree.roots.size(); ++j) { +      efmap.clear(); +      ffs.EdgeFeatures(cur.ts, -1, cur.tree.roots[j], &efmap); +      cur.features += efmap; +    } +    empirical += cur.features; +    forests[i].resize(cur.ts.words.size()); +    forests[i].ExtractFeatures(cur.ts, ffs); +  } +  if (flag) cerr << endl; +  //cerr << "EMP: " << empirical << endl; //DE +  weights.resize(FD::NumFeats(), 0.0); +  vector<weight_t> g(FD::NumFeats(), 0.0); +  cerr << "features initialized\noptimizing...\n"; +  boost::shared_ptr<BatchOptimizer> o; +#if HAVE_THREAD +  unsigned threads = conf["threads"].as<unsigned>(); +  if (threads > corpus.size()) threads = corpus.size(); +#else +  const unsigned threads = 1; +#endif +  int chunk = corpus.size() / threads; +  o.reset(new LBFGSOptimizer(g.size(), conf["correction_buffers"].as<int>())); +  int iterations = 1000; +  for (int iter = 0; iter < iterations; ++iter) { +    cerr << "ITERATION " << iter << " " << flush; +    fill(g.begin(), g.end(), 0.0); +    for (SparseVector<double>::const_iterator it = empirical.begin(); it != empirical.end(); ++it) +      g[it->first] = -it->second; +    double obj = -empirical.dot(weights); +    vector<boost::shared_ptr<GradientWorker> > jobs; +    for (int from = 0; from < corpus.size(); from += chunk) { +      int to = from + chunk; +      if (to > corpus.size()) to = corpus.size(); +      jobs.push_back(boost::shared_ptr<GradientWorker>(new GradientWorker(from, to, &weights, &corpus, &forests))); +    } +#if HAVE_THREAD +    boost::thread_group tg; +    for (int i = 0; i < threads; ++i) +      tg.create_thread(boost::ref(*jobs[i])); +    tg.join_all(); +#else +    (*jobs[0])(); +#endif +    for (int i = 0; i < threads; ++i) { +      obj += jobs[i]->obj; +      vector<double>& tg = jobs[i]->g; +      for (unsigned j = 0; j < g.size(); ++j) +        g[j] += tg[j]; +    } +    // SparseVector<double> mfm;  //DE +    //cerr << endl << "E: " << empirical << endl;  // DE +    //cerr << "M: " << mfm << endl;  // DE +    double r = ApplyRegularizationTerms(conf["regularization_strength"].as<double>(), weights, &g); +    double gnorm = 0; +    for (int i = 0; i < g.size(); ++i) +      gnorm += g[i]*g[i]; +    ostringstream ll; +    ll << "ITER=" << (iter+1) << "\tOBJ=" << (obj+r) << "\t[F=" << obj << " R=" << r << "]\tGnorm=" << sqrt(gnorm); +    cerr << ' ' << ll.str().substr(ll.str().find('\t')+1) << endl; +    obj += r; +    assert(obj >= 0); +    o->Optimize(obj, g, &weights); +    Weights::ShowLargestFeatures(weights); +    const bool converged = o->HasConverged(); +    const char* ofname = converged ? "weights.final.gz" : "weights.cur.gz"; +    if (converged || ((iter+1) % conf["output_every_i_iterations"].as<unsigned>()) == 0) { +      cerr << "writing..." << flush; +      const string sl = ll.str(); +      Weights::WriteToFile(ofname, weights, true, &sl); +      cerr << "done" << endl; +    } +    if (converged) { cerr << "CONVERGED\n"; break; } +  }    return 0;  } diff --git a/rst_parser/picojson.h b/rst_parser/picojson.h new file mode 100644 index 00000000..bdb26057 --- /dev/null +++ b/rst_parser/picojson.h @@ -0,0 +1,979 @@ +/* + * Copyright 2009-2010 Cybozu Labs, Inc. + * Copyright 2011 Kazuho Oku + *  + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + *  + * 1. Redistributions of source code must retain the above copyright notice, + *    this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + *    this list of conditions and the following disclaimer in the documentation + *    and/or other materials provided with the distribution. + *  + * THIS SOFTWARE IS PROVIDED BY CYBOZU LABS, INC. ``AS IS'' AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO + * EVENT SHALL CYBOZU LABS, INC. OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, + * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF + * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + *  + * The views and conclusions contained in the software and documentation are + * those of the authors and should not be interpreted as representing official + * policies, either expressed or implied, of Cybozu Labs, Inc. + * + */ +#ifndef picojson_h +#define picojson_h + +#include <cassert> +#include <cmath> +#include <cstdio> +#include <cstdlib> +#include <cstring> +#include <iostream> +#include <iterator> +#include <map> +#include <string> +#include <vector> + +#ifdef _MSC_VER +    #define SNPRINTF _snprintf_s +    #pragma warning(push) +    #pragma warning(disable : 4244) // conversion from int to char +#else +    #define SNPRINTF snprintf +#endif + +namespace picojson { +   +  enum { +    null_type, +    boolean_type, +    number_type, +    string_type, +    array_type, +    object_type +  }; +   +  struct null {}; +   +  class value { +  public: +    typedef std::vector<value> array; +    typedef std::map<std::string, value> object; +  protected: +    int type_; +    union { +      bool boolean_; +      double number_; +      std::string* string_; +      array* array_; +      object* object_; +    }; +  public: +    value(); +    value(int type, bool); +    explicit value(bool b); +    explicit value(double n); +    explicit value(const std::string& s); +    explicit value(const array& a); +    explicit value(const object& o); +    explicit value(const char* s); +    value(const char* s, size_t len); +    ~value(); +    value(const value& x); +    value& operator=(const value& x); +    template <typename T> bool is() const; +    template <typename T> const T& get() const; +    template <typename T> T& get(); +    bool evaluate_as_boolean() const; +    const value& get(size_t idx) const; +    const value& get(const std::string& key) const; +    bool contains(size_t idx) const; +    bool contains(const std::string& key) const; +    std::string to_str() const; +    template <typename Iter> void serialize(Iter os) const; +    std::string serialize() const; +  private: +    template <typename T> value(const T*); // intentionally defined to block implicit conversion of pointer to bool +  }; +   +  typedef value::array array; +  typedef value::object object; +   +  inline value::value() : type_(null_type) {} +   +  inline value::value(int type, bool) : type_(type) { +    switch (type) { +#define INIT(p, v) case p##type: p = v; break +      INIT(boolean_, false); +      INIT(number_, 0.0); +      INIT(string_, new std::string()); +      INIT(array_, new array()); +      INIT(object_, new object()); +#undef INIT +    default: break; +    } +  } +   +  inline value::value(bool b) : type_(boolean_type) { +    boolean_ = b; +  } +   +  inline value::value(double n) : type_(number_type) { +    number_ = n; +  } +   +  inline value::value(const std::string& s) : type_(string_type) { +    string_ = new std::string(s); +  } +   +  inline value::value(const array& a) : type_(array_type) { +    array_ = new array(a); +  } +   +  inline value::value(const object& o) : type_(object_type) { +    object_ = new object(o); +  } +   +  inline value::value(const char* s) : type_(string_type) { +    string_ = new std::string(s); +  } +   +  inline value::value(const char* s, size_t len) : type_(string_type) { +    string_ = new std::string(s, len); +  } +   +  inline value::~value() { +    switch (type_) { +#define DEINIT(p) case p##type: delete p; break +      DEINIT(string_); +      DEINIT(array_); +      DEINIT(object_); +#undef DEINIT +    default: break; +    } +  } +   +  inline value::value(const value& x) : type_(x.type_) { +    switch (type_) { +#define INIT(p, v) case p##type: p = v; break +      INIT(boolean_, x.boolean_); +      INIT(number_, x.number_); +      INIT(string_, new std::string(*x.string_)); +      INIT(array_, new array(*x.array_)); +      INIT(object_, new object(*x.object_)); +#undef INIT +    default: break; +    } +  } +   +  inline value& value::operator=(const value& x) { +    if (this != &x) { +      this->~value(); +      new (this) value(x); +    } +    return *this; +  } +   +#define IS(ctype, jtype)			     \ +  template <> inline bool value::is<ctype>() const { \ +    return type_ == jtype##_type;		     \ +  } +  IS(null, null) +  IS(bool, boolean) +  IS(int, number) +  IS(double, number) +  IS(std::string, string) +  IS(array, array) +  IS(object, object) +#undef IS +   +#define GET(ctype, var)						\ +  template <> inline const ctype& value::get<ctype>() const {	\ +    assert("type mismatch! call vis<type>() before get<type>()" \ +	   && is<ctype>());				        \ +    return var;							\ +  }								\ +  template <> inline ctype& value::get<ctype>() {		\ +    assert("type mismatch! call is<type>() before get<type>()"	\ +	   && is<ctype>());					\ +    return var;							\ +  } +  GET(bool, boolean_) +  GET(double, number_) +  GET(std::string, *string_) +  GET(array, *array_) +  GET(object, *object_) +#undef GET +   +  inline bool value::evaluate_as_boolean() const { +    switch (type_) { +    case null_type: +      return false; +    case boolean_type: +      return boolean_; +    case number_type: +      return number_ != 0; +    case string_type: +      return ! string_->empty(); +    default: +      return true; +    } +  } +   +  inline const value& value::get(size_t idx) const { +    static value s_null; +    assert(is<array>()); +    return idx < array_->size() ? (*array_)[idx] : s_null; +  } + +  inline const value& value::get(const std::string& key) const { +    static value s_null; +    assert(is<object>()); +    object::const_iterator i = object_->find(key); +    return i != object_->end() ? i->second : s_null; +  } + +  inline bool value::contains(size_t idx) const { +    assert(is<array>()); +    return idx < array_->size(); +  } + +  inline bool value::contains(const std::string& key) const { +    assert(is<object>()); +    object::const_iterator i = object_->find(key); +    return i != object_->end(); +  } +   +  inline std::string value::to_str() const { +    switch (type_) { +    case null_type:      return "null"; +    case boolean_type:   return boolean_ ? "true" : "false"; +    case number_type:    { +      char buf[256]; +      double tmp; +      SNPRINTF(buf, sizeof(buf), modf(number_, &tmp) == 0 ? "%.f" : "%f", number_); +      return buf; +    } +    case string_type:    return *string_; +    case array_type:     return "array"; +    case object_type:    return "object"; +    default:             assert(0); +#ifdef _MSC_VER +      __assume(0); +#endif +    } +  } +   +  template <typename Iter> void copy(const std::string& s, Iter oi) { +    std::copy(s.begin(), s.end(), oi); +  } +   +  template <typename Iter> void serialize_str(const std::string& s, Iter oi) { +    *oi++ = '"'; +    for (std::string::const_iterator i = s.begin(); i != s.end(); ++i) { +      switch (*i) { +#define MAP(val, sym) case val: copy(sym, oi); break +	MAP('"', "\\\""); +	MAP('\\', "\\\\"); +	MAP('/', "\\/"); +	MAP('\b', "\\b"); +	MAP('\f', "\\f"); +	MAP('\n', "\\n"); +	MAP('\r', "\\r"); +	MAP('\t', "\\t"); +#undef MAP +      default: +	if ((unsigned char)*i < 0x20 || *i == 0x7f) { +	  char buf[7]; +	  SNPRINTF(buf, sizeof(buf), "\\u%04x", *i & 0xff); +	  copy(buf, buf + 6, oi); +	  } else { +	  *oi++ = *i; +	} +	break; +      } +    } +    *oi++ = '"'; +  } +   +  template <typename Iter> void value::serialize(Iter oi) const { +    switch (type_) { +    case string_type: +      serialize_str(*string_, oi); +      break; +    case array_type: { +      *oi++ = '['; +      for (array::const_iterator i = array_->begin(); i != array_->end(); ++i) { +	if (i != array_->begin()) { +	  *oi++ = ','; +	} +	i->serialize(oi); +      } +      *oi++ = ']'; +      break; +    } +    case object_type: { +      *oi++ = '{'; +      for (object::const_iterator i = object_->begin(); +	   i != object_->end(); +	   ++i) { +	if (i != object_->begin()) { +	  *oi++ = ','; +	} +	serialize_str(i->first, oi); +	*oi++ = ':'; +	i->second.serialize(oi); +      } +      *oi++ = '}'; +      break; +    } +    default: +      copy(to_str(), oi); +      break; +    } +  } +   +  inline std::string value::serialize() const { +    std::string s; +    serialize(std::back_inserter(s)); +    return s; +  } +   +  template <typename Iter> class input { +  protected: +    Iter cur_, end_; +    int last_ch_; +    bool ungot_; +    int line_; +  public: +    input(const Iter& first, const Iter& last) : cur_(first), end_(last), last_ch_(-1), ungot_(false), line_(1) {} +    int getc() { +      if (ungot_) { +	ungot_ = false; +	return last_ch_; +      } +      if (cur_ == end_) { +	last_ch_ = -1; +	return -1; +      } +      if (last_ch_ == '\n') { +	line_++; +      } +      last_ch_ = *cur_++ & 0xff; +      return last_ch_; +    } +    void ungetc() { +      if (last_ch_ != -1) { +	assert(! ungot_); +	ungot_ = true; +      } +    } +    Iter cur() const { return cur_; } +    int line() const { return line_; } +    void skip_ws() { +      while (1) { +	int ch = getc(); +	if (! (ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r')) { +	  ungetc(); +	  break; +	} +      } +    } +    int expect(int expect) { +      skip_ws(); +      if (getc() != expect) { +	ungetc(); +	return false; +      } +      return true; +    } +    bool match(const std::string& pattern) { +      for (std::string::const_iterator pi(pattern.begin()); +	   pi != pattern.end(); +	   ++pi) { +	if (getc() != *pi) { +	  ungetc(); +	  return false; +	} +      } +      return true; +    } +  }; +   +  template<typename Iter> inline int _parse_quadhex(input<Iter> &in) { +    int uni_ch = 0, hex; +    for (int i = 0; i < 4; i++) { +      if ((hex = in.getc()) == -1) { +	return -1; +      } +      if ('0' <= hex && hex <= '9') { +	hex -= '0'; +      } else if ('A' <= hex && hex <= 'F') { +	hex -= 'A' - 0xa; +      } else if ('a' <= hex && hex <= 'f') { +	hex -= 'a' - 0xa; +      } else { +	in.ungetc(); +	return -1; +      } +      uni_ch = uni_ch * 16 + hex; +    } +    return uni_ch; +  } +   +  template<typename String, typename Iter> inline bool _parse_codepoint(String& out, input<Iter>& in) { +    int uni_ch; +    if ((uni_ch = _parse_quadhex(in)) == -1) { +      return false; +    } +    if (0xd800 <= uni_ch && uni_ch <= 0xdfff) { +      if (0xdc00 <= uni_ch) { +	// a second 16-bit of a surrogate pair appeared +	return false; +      } +      // first 16-bit of surrogate pair, get the next one +      if (in.getc() != '\\' || in.getc() != 'u') { +	in.ungetc(); +	return false; +      } +      int second = _parse_quadhex(in); +      if (! (0xdc00 <= second && second <= 0xdfff)) { +	return false; +      } +      uni_ch = ((uni_ch - 0xd800) << 10) | ((second - 0xdc00) & 0x3ff); +      uni_ch += 0x10000; +    } +    if (uni_ch < 0x80) { +      out.push_back(uni_ch); +    } else { +      if (uni_ch < 0x800) { +	out.push_back(0xc0 | (uni_ch >> 6)); +      } else { +	if (uni_ch < 0x10000) { +	  out.push_back(0xe0 | (uni_ch >> 12)); +	} else { +	  out.push_back(0xf0 | (uni_ch >> 18)); +	  out.push_back(0x80 | ((uni_ch >> 12) & 0x3f)); +	} +	out.push_back(0x80 | ((uni_ch >> 6) & 0x3f)); +      } +      out.push_back(0x80 | (uni_ch & 0x3f)); +    } +    return true; +  } +   +  template<typename String, typename Iter> inline bool _parse_string(String& out, input<Iter>& in) { +    while (1) { +      int ch = in.getc(); +      if (ch < ' ') { +	in.ungetc(); +	return false; +      } else if (ch == '"') { +	return true; +      } else if (ch == '\\') { +	if ((ch = in.getc()) == -1) { +	  return false; +	} +	switch (ch) { +#define MAP(sym, val) case sym: out.push_back(val); break +	  MAP('"', '\"'); +	  MAP('\\', '\\'); +	  MAP('/', '/'); +	  MAP('b', '\b'); +	  MAP('f', '\f'); +	  MAP('n', '\n'); +	  MAP('r', '\r'); +	  MAP('t', '\t'); +#undef MAP +	case 'u': +	  if (! _parse_codepoint(out, in)) { +	    return false; +	  } +	  break; +	default: +	  return false; +	} +      } else { +	out.push_back(ch); +      } +    } +    return false; +  } +   +  template <typename Context, typename Iter> inline bool _parse_array(Context& ctx, input<Iter>& in) { +    if (! ctx.parse_array_start()) { +      return false; +    } +    if (in.expect(']')) { +      return true; +    } +    size_t idx = 0; +    do { +      if (! ctx.parse_array_item(in, idx)) { +	return false; +      } +      idx++; +    } while (in.expect(',')); +    return in.expect(']'); +  } +   +  template <typename Context, typename Iter> inline bool _parse_object(Context& ctx, input<Iter>& in) { +    if (! ctx.parse_object_start()) { +      return false; +    } +    if (in.expect('}')) { +      return true; +    } +    do { +      std::string key; +      if (! in.expect('"') +	  || ! _parse_string(key, in) +	  || ! in.expect(':')) { +	return false; +      } +      if (! ctx.parse_object_item(in, key)) { +	return false; +      } +    } while (in.expect(',')); +    return in.expect('}'); +  } +   +  template <typename Iter> inline bool _parse_number(double& out, input<Iter>& in) { +    std::string num_str; +    while (1) { +      int ch = in.getc(); +      if (('0' <= ch && ch <= '9') || ch == '+' || ch == '-' || ch == '.' +	  || ch == 'e' || ch == 'E') { +	num_str.push_back(ch); +      } else { +	in.ungetc(); +	break; +      } +    } +    char* endp; +    out = strtod(num_str.c_str(), &endp); +    return endp == num_str.c_str() + num_str.size(); +  } +   +  template <typename Context, typename Iter> inline bool _parse(Context& ctx, input<Iter>& in) { +    in.skip_ws(); +    int ch = in.getc(); +    switch (ch) { +#define IS(ch, text, op) case ch: \ +      if (in.match(text) && op) { \ +	return true; \ +      } else { \ +	return false; \ +      } +      IS('n', "ull", ctx.set_null()); +      IS('f', "alse", ctx.set_bool(false)); +      IS('t', "rue", ctx.set_bool(true)); +#undef IS +    case '"': +      return ctx.parse_string(in); +    case '[': +      return _parse_array(ctx, in); +    case '{': +      return _parse_object(ctx, in); +    default: +      if (('0' <= ch && ch <= '9') || ch == '-') { +	in.ungetc(); +	double f; +	if (_parse_number(f, in)) { +	  ctx.set_number(f); +	  return true; +	} else { +	  return false; +	} +      } +      break; +    } +    in.ungetc(); +    return false; +  } +   +  class deny_parse_context { +  public: +    bool set_null() { return false; } +    bool set_bool(bool) { return false; } +    bool set_number(double) { return false; } +    template <typename Iter> bool parse_string(input<Iter>&) { return false; } +    bool parse_array_start() { return false; } +    template <typename Iter> bool parse_array_item(input<Iter>&, size_t) { +      return false; +    } +    bool parse_object_start() { return false; } +    template <typename Iter> bool parse_object_item(input<Iter>&, const std::string&) { +      return false; +    } +  }; +   +  class default_parse_context { +  protected: +    value* out_; +  public: +    default_parse_context(value* out) : out_(out) {} +    bool set_null() { +      *out_ = value(); +      return true; +    } +    bool set_bool(bool b) { +      *out_ = value(b); +      return true; +    } +    bool set_number(double f) { +      *out_ = value(f); +      return true; +    } +    template<typename Iter> bool parse_string(input<Iter>& in) { +      *out_ = value(string_type, false); +      return _parse_string(out_->get<std::string>(), in); +    } +    bool parse_array_start() { +      *out_ = value(array_type, false); +      return true; +    } +    template <typename Iter> bool parse_array_item(input<Iter>& in, size_t) { +      array& a = out_->get<array>(); +      a.push_back(value()); +      default_parse_context ctx(&a.back()); +      return _parse(ctx, in); +    } +    bool parse_object_start() { +      *out_ = value(object_type, false); +      return true; +    } +    template <typename Iter> bool parse_object_item(input<Iter>& in, const std::string& key) { +      object& o = out_->get<object>(); +      default_parse_context ctx(&o[key]); +      return _parse(ctx, in); +    } +  private: +    default_parse_context(const default_parse_context&); +    default_parse_context& operator=(const default_parse_context&); +  }; + +  class null_parse_context { +  public: +    struct dummy_str { +      void push_back(int) {} +    }; +  public: +    null_parse_context() {} +    bool set_null() { return true; } +    bool set_bool(bool) { return true; } +    bool set_number(double) { return true; } +    template <typename Iter> bool parse_string(input<Iter>& in) { +      dummy_str s; +      return _parse_string(s, in); +    } +    bool parse_array_start() { return true; } +    template <typename Iter> bool parse_array_item(input<Iter>& in, size_t) { +      return _parse(*this, in); +    } +    bool parse_object_start() { return true; } +    template <typename Iter> bool parse_object_item(input<Iter>& in, const std::string&) { +      return _parse(*this, in); +    } +  private: +    null_parse_context(const null_parse_context&); +    null_parse_context& operator=(const null_parse_context&); +  }; +   +  // obsolete, use the version below +  template <typename Iter> inline std::string parse(value& out, Iter& pos, const Iter& last) { +    std::string err; +    pos = parse(out, pos, last, &err); +    return err; +  } +   +  template <typename Context, typename Iter> inline Iter _parse(Context& ctx, const Iter& first, const Iter& last, std::string* err) { +    input<Iter> in(first, last); +    if (! _parse(ctx, in) && err != NULL) { +      char buf[64]; +      SNPRINTF(buf, sizeof(buf), "syntax error at line %d near: ", in.line()); +      *err = buf; +      while (1) { +	int ch = in.getc(); +	if (ch == -1 || ch == '\n') { +	  break; +	} else if (ch >= ' ') { +	  err->push_back(ch); +	} +      } +    } +    return in.cur(); +  } +   +  template <typename Iter> inline Iter parse(value& out, const Iter& first, const Iter& last, std::string* err) { +    default_parse_context ctx(&out); +    return _parse(ctx, first, last, err); +  } +   +  inline std::string parse(value& out, std::istream& is) { +    std::string err; +    parse(out, std::istreambuf_iterator<char>(is.rdbuf()), +	  std::istreambuf_iterator<char>(), &err); +    return err; +  } +   +  template <typename T> struct last_error_t { +    static std::string s; +  }; +  template <typename T> std::string last_error_t<T>::s; +   +  inline void set_last_error(const std::string& s) { +    last_error_t<bool>::s = s; +  } +   +  inline const std::string& get_last_error() { +    return last_error_t<bool>::s; +  } + +  inline bool operator==(const value& x, const value& y) { +    if (x.is<null>()) +      return y.is<null>(); +#define PICOJSON_CMP(type)					\ +    if (x.is<type>())						\ +      return y.is<type>() && x.get<type>() == y.get<type>() +    PICOJSON_CMP(bool); +    PICOJSON_CMP(double); +    PICOJSON_CMP(std::string); +    PICOJSON_CMP(array); +    PICOJSON_CMP(object); +#undef PICOJSON_CMP +    assert(0); +#ifdef _MSC_VER +    __assume(0); +#endif +    return false; +  } +   +  inline bool operator!=(const value& x, const value& y) { +    return ! (x == y); +  } +} + +inline std::istream& operator>>(std::istream& is, picojson::value& x) +{ +  picojson::set_last_error(std::string()); +  std::string err = picojson::parse(x, is); +  if (! err.empty()) { +    picojson::set_last_error(err); +    is.setstate(std::ios::failbit); +  } +  return is; +} + +inline std::ostream& operator<<(std::ostream& os, const picojson::value& x) +{ +  x.serialize(std::ostream_iterator<char>(os)); +  return os; +} +#ifdef _MSC_VER +    #pragma warning(pop) +#endif + +#endif +#ifdef TEST_PICOJSON +#ifdef _MSC_VER +    #pragma warning(disable : 4127) // conditional expression is constant +#endif + +using namespace std; +   +static void plan(int num) +{ +  printf("1..%d\n", num); +} + +static bool success = true; + +static void ok(bool b, const char* name = "") +{ +  static int n = 1; +  if (! b) +    success = false; +  printf("%s %d - %s\n", b ? "ok" : "ng", n++, name); +} + +template <typename T> void is(const T& x, const T& y, const char* name = "") +{ +  if (x == y) { +    ok(true, name); +  } else { +    ok(false, name); +  } +} + +#include <algorithm> + +int main(void) +{ +  plan(75); + +  // constructors +#define TEST(expr, expected) \ +    is(picojson::value expr .serialize(), string(expected), "picojson::value" #expr) +   +  TEST( (true),  "true"); +  TEST( (false), "false"); +  TEST( (42.0),   "42"); +  TEST( (string("hello")), "\"hello\""); +  TEST( ("hello"), "\"hello\""); +  TEST( ("hello", 4), "\"hell\""); +   +#undef TEST +   +#define TEST(in, type, cmp, serialize_test) {				\ +    picojson::value v;							\ +    const char* s = in;							\ +    string err = picojson::parse(v, s, s + strlen(s));			\ +    ok(err.empty(), in " no error");					\ +    ok(v.is<type>(), in " check type");					\ +    is<type>(v.get<type>(), cmp, in " correct output");			\ +    is(*s, '\0', in " read to eof");					\ +    if (serialize_test) {						\ +      is(v.serialize(), string(in), in " serialize");			\ +    }									\ +  } +  TEST("false", bool, false, true); +  TEST("true", bool, true, true); +  TEST("90.5", double, 90.5, false); +  TEST("\"hello\"", string, string("hello"), true); +  TEST("\"\\\"\\\\\\/\\b\\f\\n\\r\\t\"", string, string("\"\\/\b\f\n\r\t"), +       true); +  TEST("\"\\u0061\\u30af\\u30ea\\u30b9\"", string, +       string("a\xe3\x82\xaf\xe3\x83\xaa\xe3\x82\xb9"), false); +  TEST("\"\\ud840\\udc0b\"", string, string("\xf0\xa0\x80\x8b"), false); +#undef TEST + +#define TEST(type, expr) {					       \ +    picojson::value v;						       \ +    const char *s = expr;					       \ +    string err = picojson::parse(v, s, s + strlen(s));		       \ +    ok(err.empty(), "empty " #type " no error");		       \ +    ok(v.is<picojson::type>(), "empty " #type " check type");	       \ +    ok(v.get<picojson::type>().empty(), "check " #type " array size"); \ +  } +  TEST(array, "[]"); +  TEST(object, "{}"); +#undef TEST +   +  { +    picojson::value v; +    const char *s = "[1,true,\"hello\"]"; +    string err = picojson::parse(v, s, s + strlen(s)); +    ok(err.empty(), "array no error"); +    ok(v.is<picojson::array>(), "array check type"); +    is(v.get<picojson::array>().size(), size_t(3), "check array size"); +    ok(v.contains(0), "check contains array[0]"); +    ok(v.get(0).is<double>(), "check array[0] type"); +    is(v.get(0).get<double>(), 1.0, "check array[0] value"); +    ok(v.contains(1), "check contains array[1]"); +    ok(v.get(1).is<bool>(), "check array[1] type"); +    ok(v.get(1).get<bool>(), "check array[1] value"); +    ok(v.contains(2), "check contains array[2]"); +    ok(v.get(2).is<string>(), "check array[2] type"); +    is(v.get(2).get<string>(), string("hello"), "check array[2] value"); +    ok(!v.contains(3), "check not contains array[3]"); +  } +   +  { +    picojson::value v; +    const char *s = "{ \"a\": true }"; +    string err = picojson::parse(v, s, s + strlen(s)); +    ok(err.empty(), "object no error"); +    ok(v.is<picojson::object>(), "object check type"); +    is(v.get<picojson::object>().size(), size_t(1), "check object size"); +    ok(v.contains("a"), "check contains property"); +    ok(v.get("a").is<bool>(), "check bool property exists"); +    is(v.get("a").get<bool>(), true, "check bool property value"); +    is(v.serialize(), string("{\"a\":true}"), "serialize object"); +    ok(!v.contains("z"), "check not contains property"); +  } + +#define TEST(json, msg) do {				\ +    picojson::value v;					\ +    const char *s = json;				\ +    string err = picojson::parse(v, s, s + strlen(s));	\ +    is(err, string("syntax error at line " msg), msg);	\ +  } while (0) +  TEST("falsoa", "1 near: oa"); +  TEST("{]", "1 near: ]"); +  TEST("\n\bbell", "2 near: bell"); +  TEST("\"abc\nd\"", "1 near: "); +#undef TEST +   +  { +    picojson::value v1, v2; +    const char *s; +    string err; +    s = "{ \"b\": true, \"a\": [1,2,\"three\"], \"d\": 2 }"; +    err = picojson::parse(v1, s, s + strlen(s)); +    s = "{ \"d\": 2.0, \"b\": true, \"a\": [1,2,\"three\"] }"; +    err = picojson::parse(v2, s, s + strlen(s)); +    ok((v1 == v2), "check == operator in deep comparison"); +  } + +  { +    picojson::value v1, v2; +    const char *s; +    string err; +    s = "{ \"b\": true, \"a\": [1,2,\"three\"], \"d\": 2 }"; +    err = picojson::parse(v1, s, s + strlen(s)); +    s = "{ \"d\": 2.0, \"a\": [1,\"three\"], \"b\": true }"; +    err = picojson::parse(v2, s, s + strlen(s)); +    ok((v1 != v2), "check != operator for array in deep comparison"); +  } + +  { +    picojson::value v1, v2; +    const char *s; +    string err; +    s = "{ \"b\": true, \"a\": [1,2,\"three\"], \"d\": 2 }"; +    err = picojson::parse(v1, s, s + strlen(s)); +    s = "{ \"d\": 2.0, \"a\": [1,2,\"three\"], \"b\": false }"; +    err = picojson::parse(v2, s, s + strlen(s)); +    ok((v1 != v2), "check != operator for object in deep comparison"); +  } + +  { +    picojson::value v1, v2; +    const char *s; +    string err; +    s = "{ \"b\": true, \"a\": [1,2,\"three\"], \"d\": 2 }"; +    err = picojson::parse(v1, s, s + strlen(s)); +    picojson::object& o = v1.get<picojson::object>(); +    o.erase("b"); +    picojson::array& a = o["a"].get<picojson::array>(); +    picojson::array::iterator i; +    i = std::remove(a.begin(), a.end(), picojson::value(std::string("three"))); +    a.erase(i, a.end()); +    s = "{ \"a\": [1,2], \"d\": 2 }"; +    err = picojson::parse(v2, s, s + strlen(s)); +    ok((v1 == v2), "check erase()"); +  } + +  ok(picojson::value(3.0).serialize() == "3", +     "integral number should be serialized as a integer"); +   +  { +    const char* s = "{ \"a\": [1,2], \"d\": 2 }"; +    picojson::null_parse_context ctx; +    string err; +    picojson::_parse(ctx, s, s + strlen(s), &err); +    ok(err.empty(), "null_parse_context"); +  } +   +  return success ? 0 : 1; +} + +#endif diff --git a/rst_parser/rst.cc b/rst_parser/rst.cc index f6b295b3..bc91330b 100644 --- a/rst_parser/rst.cc +++ b/rst_parser/rst.cc @@ -2,6 +2,81 @@  using namespace std; -StochasticForest::StochasticForest(const ArcFactoredForest& af) { +// David B. Wilson. Generating Random Spanning Trees More Quickly than the Cover Time. +// this is an awesome algorithm +TreeSampler::TreeSampler(const ArcFactoredForest& af) : forest(af), usucc(af.size() + 1) { +  // edges are directed from modifiers to heads, and finally to the root +  vector<double> p; +  for (int m = 1; m <= forest.size(); ++m) { +#if USE_ALIAS_SAMPLER +    p.clear(); +#else +    SampleSet<double>& ss = usucc[m]; +#endif +    double z = 0; +    for (int h = 0; h <= forest.size(); ++h) { +      double u = forest(h-1,m-1).edge_prob.as_float(); +      z += u; +#if USE_ALIAS_SAMPLER +      p.push_back(u); +#else +      ss.add(u); +#endif +    } +#if USE_ALIAS_SAMPLER +    for (int i = 0; i < p.size(); ++i) { p[i] /= z; } +    usucc[m].Init(p); +#endif +  }  } +void TreeSampler::SampleRandomSpanningTree(EdgeSubset* tree, MT19937* prng) { +  MT19937& rng = *prng; +  const int r = 0; +  bool success = false; +  while (!success) { +    int roots = 0; +    tree->h_m_pairs.clear(); +    tree->roots.clear(); +    vector<int> next(forest.size() + 1, -1); +    vector<char> in_tree(forest.size() + 1, 0); +    in_tree[r] = 1; +    //cerr << "Forest size: " << forest.size() << endl; +    for (int i = 0; i <= forest.size(); ++i) { +      //cerr << "Sampling starting at u=" << i << endl; +      int u = i; +      if (in_tree[u]) continue; +      while(!in_tree[u]) { +#if USE_ALIAS_SAMPLER +        next[u] = usucc[u].Draw(rng); +#else +        next[u] = rng.SelectSample(usucc[u]); +#endif +        u = next[u]; +      } +      u = i; +      //cerr << (u-1); +      int prev = u-1; +      while(!in_tree[u]) { +        in_tree[u] = true; +        u = next[u]; +        //cerr << " > " << (u-1); +        if (u == r) { +          ++roots; +          tree->roots.push_back(prev); +        } else { +          tree->h_m_pairs.push_back(make_pair<short,short>(u-1,prev)); +        } +        prev = u-1; +      } +      //cerr << endl; +    } +    assert(roots > 0); +    if (roots > 1) { +      //cerr << "FAILURE\n"; +    } else { +      success = true; +    } +  } +}; + diff --git a/rst_parser/rst.h b/rst_parser/rst.h index 865871eb..8bf389f7 100644 --- a/rst_parser/rst.h +++ b/rst_parser/rst.h @@ -1,10 +1,21 @@  #ifndef _RST_H_  #define _RST_H_ +#include <vector> +#include "sampler.h"  #include "arc_factored.h" +#include "alias_sampler.h" -struct StochasticForest { -  explicit StochasticForest(const ArcFactoredForest& af); +struct TreeSampler { +  explicit TreeSampler(const ArcFactoredForest& af); +  void SampleRandomSpanningTree(EdgeSubset* tree, MT19937* rng); +  const ArcFactoredForest& forest; +#define USE_ALIAS_SAMPLER 1 +#if USE_ALIAS_SAMPLER +  std::vector<AliasSampler> usucc; +#else +  std::vector<SampleSet<double> > usucc; +#endif  };  #endif diff --git a/rst_parser/rst_parse.cc b/rst_parser/rst_parse.cc new file mode 100644 index 00000000..9c42a8f4 --- /dev/null +++ b/rst_parser/rst_parse.cc @@ -0,0 +1,111 @@ +#include "arc_factored.h" + +#include <vector> +#include <iostream> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "timing_stats.h" +#include "arc_ff.h" +#include "dep_training.h" +#include "stringlib.h" +#include "filelib.h" +#include "tdict.h" +#include "weights.h" +#include "rst.h" +#include "global_ff.h" + +using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { +  po::options_description opts("Configuration options"); +  string cfg_file; +  opts.add_options() +        ("input,i",po::value<string>()->default_value("-"), "File containing test data (jsent format)") +        ("q_weights,q",po::value<string>(), "Arc-factored weights for proposal distribution (mandatory)") +        ("p_weights,p",po::value<string>(), "Weights for target distribution (optional)") +        ("samples,n",po::value<unsigned>()->default_value(1000), "Number of samples"); +  po::options_description clo("Command line options"); +  clo.add_options() +        ("config,c", po::value<string>(&cfg_file), "Configuration file") +        ("help,?", "Print this help message and exit"); + +  po::options_description dconfig_options, dcmdline_options; +  dconfig_options.add(opts); +  dcmdline_options.add(dconfig_options).add(clo); +  po::store(parse_command_line(argc, argv, dcmdline_options), *conf); +  if (cfg_file.size() > 0) { +    ReadFile rf(cfg_file); +    po::store(po::parse_config_file(*rf.stream(), dconfig_options), *conf); +  } +  if (conf->count("help") || conf->count("q_weights") == 0) { +    cerr << dcmdline_options << endl; +    exit(1); +  } +} + +int main(int argc, char** argv) { +  po::variables_map conf; +  InitCommandLine(argc, argv, &conf); +  vector<weight_t> qweights, pweights; +  Weights::InitFromFile(conf["q_weights"].as<string>(), &qweights); +  if (conf.count("p_weights")) +    Weights::InitFromFile(conf["p_weights"].as<string>(), &pweights); +  const bool global = pweights.size() > 0; +  ArcFeatureFunctions ffs; +  GlobalFeatureFunctions gff; +  ReadFile rf(conf["input"].as<string>()); +  istream* in = rf.stream(); +  TrainingInstance sent; +  MT19937 rng; +  int samples = conf["samples"].as<unsigned>(); +  int totroot = 0, root_right = 0, tot = 0, cor = 0; +  while(TrainingInstance::ReadInstance(in, &sent)) { +    ffs.PrepareForInput(sent.ts); +    if (global) gff.PrepareForInput(sent.ts); +    ArcFactoredForest forest(sent.ts.pos.size()); +    forest.ExtractFeatures(sent.ts, ffs); +    forest.Reweight(qweights); +    TreeSampler ts(forest); +    double best_score = -numeric_limits<double>::infinity(); +    EdgeSubset best_tree; +    for (int n = 0; n < samples; ++n) { +      EdgeSubset tree; +      ts.SampleRandomSpanningTree(&tree, &rng); +      SparseVector<double> qfeats, gfeats; +      tree.ExtractFeatures(sent.ts, ffs, &qfeats); +      double score = 0; +      if (global) { +        gff.Features(sent.ts, tree, &gfeats); +        score = (qfeats + gfeats).dot(pweights); +      } else { +        score = qfeats.dot(qweights); +      } +      if (score > best_score) { +        best_tree = tree; +        best_score = score; +      } +    } +    cerr << "BEST SCORE: " << best_score << endl; +    cout << best_tree << endl; +    const bool sent_has_ref = sent.tree.h_m_pairs.size() > 0; +    if (sent_has_ref) { +      map<pair<short,short>, bool> ref; +      for (int i = 0; i < sent.tree.h_m_pairs.size(); ++i) +        ref[sent.tree.h_m_pairs[i]] = true; +      int ref_root = sent.tree.roots.front(); +      if (ref_root == best_tree.roots.front()) { ++root_right; } +      ++totroot; +      for (int i = 0; i < best_tree.h_m_pairs.size(); ++i) { +        if (ref[best_tree.h_m_pairs[i]]) { +          ++cor; +        } +        ++tot; +      } +    } +  } +  cerr << "F = " << (double(cor + root_right) / (tot + totroot)) << endl; +  return 0; +} + diff --git a/rst_parser/rst_test.cc b/rst_parser/rst_test.cc deleted file mode 100644 index e8fe706e..00000000 --- a/rst_parser/rst_test.cc +++ /dev/null @@ -1,33 +0,0 @@ -#include "arc_factored.h" - -#include <iostream> - -using namespace std; - -int main(int argc, char** argv) { -  // John saw Mary -  //   (H -> M) -  //   (1 -> 2) 20 -  //   (1 -> 3) 3 -  //   (2 -> 1) 20 -  //   (2 -> 3) 30 -  //   (3 -> 2) 0 -  //   (3 -> 1) 11 -  //   (0, 2) 10 -  //   (0, 1) 9 -  //   (0, 3) 9 -  ArcFactoredForest af(3); -  af(1,2).edge_prob.logeq(20); -  af(1,3).edge_prob.logeq(3); -  af(2,1).edge_prob.logeq(20); -  af(2,3).edge_prob.logeq(30); -  af(3,2).edge_prob.logeq(0); -  af(3,1).edge_prob.logeq(11); -  af(0,2).edge_prob.logeq(10); -  af(0,1).edge_prob.logeq(9); -  af(0,3).edge_prob.logeq(9); -  SpanningTree tree; -  af.MaximumSpanningTree(&tree); -  return 0; -} - diff --git a/rst_parser/rst_train.cc b/rst_parser/rst_train.cc new file mode 100644 index 00000000..9b730f3d --- /dev/null +++ b/rst_parser/rst_train.cc @@ -0,0 +1,144 @@ +#include "arc_factored.h" + +#include <vector> +#include <iostream> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "timing_stats.h" +#include "arc_ff.h" +#include "dep_training.h" +#include "stringlib.h" +#include "filelib.h" +#include "tdict.h" +#include "weights.h" +#include "rst.h" +#include "global_ff.h" + +using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { +  po::options_description opts("Configuration options"); +  string cfg_file; +  opts.add_options() +        ("training_data,t",po::value<string>()->default_value("-"), "File containing training data (jsent format)") +        ("q_weights,q",po::value<string>(), "Arc-factored weights for proposal distribution") +        ("samples,n",po::value<unsigned>()->default_value(1000), "Number of samples"); +  po::options_description clo("Command line options"); +  clo.add_options() +        ("config,c", po::value<string>(&cfg_file), "Configuration file") +        ("help,?", "Print this help message and exit"); + +  po::options_description dconfig_options, dcmdline_options; +  dconfig_options.add(opts); +  dcmdline_options.add(dconfig_options).add(clo); +  po::store(parse_command_line(argc, argv, dcmdline_options), *conf); +  if (cfg_file.size() > 0) { +    ReadFile rf(cfg_file); +    po::store(po::parse_config_file(*rf.stream(), dconfig_options), *conf); +  } +  if (conf->count("help")) { +    cerr << dcmdline_options << endl; +    exit(1); +  } +} + +int main(int argc, char** argv) { +  po::variables_map conf; +  InitCommandLine(argc, argv, &conf); +  vector<weight_t> qweights(FD::NumFeats(), 0.0); +  Weights::InitFromFile(conf["q_weights"].as<string>(), &qweights); +  vector<TrainingInstance> corpus; +  ArcFeatureFunctions ffs; +  GlobalFeatureFunctions gff; +  TrainingInstance::ReadTrainingCorpus(conf["training_data"].as<string>(), &corpus); +  vector<ArcFactoredForest> forests(corpus.size()); +  vector<prob_t> zs(corpus.size()); +  SparseVector<double> empirical; +  bool flag = false; +  for (int i = 0; i < corpus.size(); ++i) { +    TrainingInstance& cur = corpus[i]; +    if ((i+1) % 10 == 0) { cerr << '.' << flush; flag = true; } +    if ((i+1) % 400 == 0) { cerr << " [" << (i+1) << "]\n"; flag = false; } +    SparseVector<weight_t> efmap; +    ffs.PrepareForInput(cur.ts); +    gff.PrepareForInput(cur.ts); +    for (int j = 0; j < cur.tree.h_m_pairs.size(); ++j) { +      efmap.clear(); +      ffs.EdgeFeatures(cur.ts, cur.tree.h_m_pairs[j].first, +                       cur.tree.h_m_pairs[j].second, +                       &efmap); +      cur.features += efmap; +    } +    for (int j = 0; j < cur.tree.roots.size(); ++j) { +      efmap.clear(); +      ffs.EdgeFeatures(cur.ts, -1, cur.tree.roots[j], &efmap); +      cur.features += efmap; +    } +    efmap.clear(); +    gff.Features(cur.ts, cur.tree, &efmap); +    cur.features += efmap; +    empirical += cur.features; +    forests[i].resize(cur.ts.words.size()); +    forests[i].ExtractFeatures(cur.ts, ffs); +    forests[i].Reweight(qweights); +    forests[i].EdgeMarginals(&zs[i]); +    zs[i] = prob_t::One() / zs[i]; +    // cerr << zs[i] << endl; +    forests[i].Reweight(qweights);    // EdgeMarginals overwrites edge_prob +  } +  if (flag) cerr << endl; +  MT19937 rng; +  SparseVector<double> model_exp; +  SparseVector<double> weights; +  Weights::InitSparseVector(qweights, &weights); +  int samples = conf["samples"].as<unsigned>(); +  for (int i = 0; i < corpus.size(); ++i) { +#if 0 +    forests[i].EdgeMarginals(); +    model_exp.clear(); +    for (int h = -1; h < num_words; ++h) { +      for (int m = 0; m < num_words; ++m) { +        if (h == m) continue; +        const ArcFactoredForest::Edge& edge = forests[i](h,m); +        const SparseVector<weight_t>& fmap = edge.features; +        double prob = edge.edge_prob.as_float(); +        model_exp += fmap * prob; +      } +    } +    cerr << "TRUE EXP: " << model_exp << endl; +    forests[i].Reweight(weights); +#endif + +    TreeSampler ts(forests[i]); +    prob_t zhat = prob_t::Zero(); +    SparseVector<prob_t> sampled_exp; +    for (int n = 0; n < samples; ++n) { +      EdgeSubset tree; +      ts.SampleRandomSpanningTree(&tree, &rng); +      SparseVector<double> qfeats, gfeats; +      tree.ExtractFeatures(corpus[i].ts, ffs, &qfeats); +      prob_t u; u.logeq(qfeats.dot(qweights)); +      const prob_t q = u / zs[i];  // proposal mass +      gff.Features(corpus[i].ts, tree, &gfeats); +      SparseVector<double> tot_feats = qfeats + gfeats; +      u.logeq(tot_feats.dot(weights)); +      prob_t w = u / q; +      zhat += w; +      for (SparseVector<double>::const_iterator it = tot_feats.begin(); it != tot_feats.end(); ++it) +        sampled_exp.add_value(it->first, w * prob_t(it->second)); +    } +    sampled_exp /= zhat; +    SparseVector<double> tot_m; +    for (SparseVector<prob_t>::const_iterator it = sampled_exp.begin(); it != sampled_exp.end(); ++it) +      tot_m.add_value(it->first, it->second.as_float()); +    //cerr << "DIFF: " << (tot_m - corpus[i].features) << endl; +    const double eta = 0.03; +    weights -= (tot_m - corpus[i].features) * eta; +  } +  cerr << "WEIGHTS.\n"; +  cerr << weights << endl; +  return 0; +} + | 
