diff options
-rw-r--r-- | decoder/Makefile.am | 1 | ||||
-rw-r--r-- | decoder/hg.h | 17 | ||||
-rw-r--r-- | utils/Makefile.am | 2 | ||||
-rw-r--r-- | utils/exp_semiring.h (renamed from decoder/exp_semiring.h) | 19 | ||||
-rw-r--r-- | utils/logval.h | 12 | ||||
-rw-r--r-- | utils/star.h | 12 |
6 files changed, 48 insertions, 15 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am index c41cd7f9..7481192b 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -37,7 +37,6 @@ libcdec_a_SOURCES = \ csplit.h \ decoder.h \ earley_composer.h \ - exp_semiring.h \ factored_lexicon_helper.h \ ff.h \ ff_basic.h \ diff --git a/decoder/hg.h b/decoder/hg.h index 3d8cd9bc..343b99cf 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -25,6 +25,7 @@ #include "tdict.h" #include "trule.h" #include "prob.h" +#include "exp_semiring.h" #include "indices_after.h" #include "nt_span.h" @@ -527,7 +528,21 @@ struct EdgeFeaturesAndProbWeightFunction { struct TransitionCountWeightFunction { typedef double Weight; - inline double operator()(const HG::Edge& e) const { (void)e; return 1.0; } + inline double operator()(const HG::Edge&) const { return 1.0; } +}; + +template <class P, class PWeightFunction, class R, class RWeightFunction> +struct PRWeightFunction { + explicit PRWeightFunction(const PWeightFunction& pwf = PWeightFunction(), + const RWeightFunction& rwf = RWeightFunction()) : + pweight(pwf), rweight(rwf) {} + PRPair<P,R> operator()(const HG::Edge& e) const { + const P p = pweight(e); + const R r = rweight(e); + return PRPair<P,R>(p, r * p); + } + const PWeightFunction pweight; + const RWeightFunction rweight; }; #endif diff --git a/utils/Makefile.am b/utils/Makefile.am index a22b6727..c0ce3509 100644 --- a/utils/Makefile.am +++ b/utils/Makefile.am @@ -27,6 +27,7 @@ libutils_a_SOURCES = \ citycrc.h \ corpus_tools.h \ dict.h \ + exp_semiring.h \ fast_sparse_vector.h \ fdict.h \ feature_vector.h \ @@ -49,6 +50,7 @@ libutils_a_SOURCES = \ show.h \ small_vector.h \ sparse_vector.h \ + star.h \ static_utoa.h \ stringlib.h \ tdict.h \ diff --git a/decoder/exp_semiring.h b/utils/exp_semiring.h index 2a9034bb..7572ccf5 100644 --- a/decoder/exp_semiring.h +++ b/utils/exp_semiring.h @@ -2,6 +2,7 @@ #define _EXP_SEMIRING_H_ #include <iostream> +#include "star.h" // this file implements the first-order expectation semiring described // in Li & Eisner (EMNLP 2009) @@ -54,18 +55,10 @@ const PRPair<P,R> operator*(const PRPair<P,R>& a, const PRPair<P,R>& b) { return result; } -template <class P, class PWeightFunction, class R, class RWeightFunction> -struct PRWeightFunction { - explicit PRWeightFunction(const PWeightFunction& pwf = PWeightFunction(), - const RWeightFunction& rwf = RWeightFunction()) : - pweight(pwf), rweight(rwf) {} - PRPair<P,R> operator()(const HG::Edge& e) const { - const P p = pweight(e); - const R r = rweight(e); - return PRPair<P,R>(p, r * p); - } - const PWeightFunction pweight; - const RWeightFunction rweight; -}; +template <class P, class R> +const PRPair<P,R> star(const PRPair<P,R>& x) { + const P pstar = star(x.p); + return PRPair<P,R>(pstar, pstar * x.r * pstar); +} #endif diff --git a/utils/logval.h b/utils/logval.h index ec1f6acd..7f1e1024 100644 --- a/utils/logval.h +++ b/utils/logval.h @@ -11,6 +11,7 @@ #include <cassert> #include "semiring.h" #include "show.h" +#include "star.h" //TODO: template for supporting negation or not - most uses are for nonnegative "probs" only; probably some 10-20% speedup available template <class T> @@ -242,4 +243,15 @@ bool operator>=(const LogVal<T>& lhs, const LogVal<T>& rhs) { template <class T> std::size_t hash_value(const LogVal<T>& x) { return x.hash_impl(); } +template <class T> +LogVal<T> star(LogVal<T> x) { + if (x.is_0()) return x; + if (x.v_ >= 0) { + x.v_ = std::numeric_limits<T>::infinity(); + } else { + x.v_ = -log1p(-x.as_float()); + } + return x; +} + #endif diff --git a/utils/star.h b/utils/star.h new file mode 100644 index 00000000..3295112c --- /dev/null +++ b/utils/star.h @@ -0,0 +1,12 @@ +#ifndef _STAR_H_ +#define _STAR_H_ + +template <typename T> +T star(const T& x) { + if (!x) return T(); + if (x > T(1)) return std::numeric_limits<T>::infinity(); + if (x < -T(1)) return -std::numeric_limits<T>::infinity(); + return T(1) / (T(1) - x); +} + +#endif |