diff options
| author | Chris Dyer <cdyer@cs.cmu.edu> | 2012-04-02 23:48:19 -0400 | 
|---|---|---|
| committer | Chris Dyer <cdyer@cs.cmu.edu> | 2012-04-02 23:48:19 -0400 | 
| commit | bf4a7606151301dba49265e91c289f2caab2b7ec (patch) | |
| tree | ab98435d1e42efe71659e75773850301ba01aa98 /rst_parser/arc_factored.h | |
| parent | b6eede632af4fa58a6f5325ee0d059c02a898b9f (diff) | |
fix bug in lattices with OOVs
Diffstat (limited to 'rst_parser/arc_factored.h')
| -rw-r--r-- | rst_parser/arc_factored.h | 58 | 
1 files changed, 58 insertions, 0 deletions
| diff --git a/rst_parser/arc_factored.h b/rst_parser/arc_factored.h new file mode 100644 index 00000000..312d7d67 --- /dev/null +++ b/rst_parser/arc_factored.h @@ -0,0 +1,58 @@ +#ifndef _ARC_FACTORED_H_ +#define _ARC_FACTORED_H_ + +#include <vector> +#include <cassert> +#include "array2d.h" +#include "sparse_vector.h" + +class ArcFactoredForest { + public: +  explicit ArcFactoredForest(short num_words) : +      num_words_(num_words), +      root_edges_(num_words), +      edges_(num_words, num_words) {} + +  struct Edge { +    Edge() : features(), edge_prob(prob_t::Zero()) {} +    SparseVector<weight_t> features; +    prob_t edge_prob; +  }; + +  template <class V> +  void Reweight(const V& weights) { +    for (int m = 0; m < num_words_; ++m) { +      for (int h = 0; h < num_words_; ++h) { +        if (h != m) { +          Edge& e = edges_(h, m); +          e.edge_prob.logeq(e.features.dot(weights)); +        } +      } +      if (m) { +        Edge& e = root_edges_[m]; +        e.edge_prob.logeq(e.features.dot(weights)); +      } +    } +  } + +  const Edge& operator()(short h, short m) const { +    assert(m > 0); +    assert(m <= num_words_); +    assert(h >= 0); +    assert(h <= num_words_); +    return h ? edges_(h - 1, m - 1) : root_edges[m - 1]; +  } +  Edge& operator()(short h, short m) { +    assert(m > 0); +    assert(m <= num_words_); +    assert(h >= 0); +    assert(h <= num_words_); +    return h ? edges_(h - 1, m - 1) : root_edges[m - 1]; +  } + private: +  unsigned num_words_; +  std::vector<Edge> root_edges_; +  Array2D<Edge> edges_; +}; + +#endif | 
