summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--decoder/Makefile.am1
-rw-r--r--decoder/hg.h17
-rw-r--r--utils/Makefile.am2
-rw-r--r--utils/exp_semiring.h (renamed from decoder/exp_semiring.h)19
-rw-r--r--utils/logval.h12
-rw-r--r--utils/star.h12
6 files changed, 48 insertions, 15 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index c41cd7f9..7481192b 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -37,7 +37,6 @@ libcdec_a_SOURCES = \
csplit.h \
decoder.h \
earley_composer.h \
- exp_semiring.h \
factored_lexicon_helper.h \
ff.h \
ff_basic.h \
diff --git a/decoder/hg.h b/decoder/hg.h
index 3d8cd9bc..343b99cf 100644
--- a/decoder/hg.h
+++ b/decoder/hg.h
@@ -25,6 +25,7 @@
#include "tdict.h"
#include "trule.h"
#include "prob.h"
+#include "exp_semiring.h"
#include "indices_after.h"
#include "nt_span.h"
@@ -527,7 +528,21 @@ struct EdgeFeaturesAndProbWeightFunction {
struct TransitionCountWeightFunction {
typedef double Weight;
- inline double operator()(const HG::Edge& e) const { (void)e; return 1.0; }
+ inline double operator()(const HG::Edge&) const { return 1.0; }
+};
+
+template <class P, class PWeightFunction, class R, class RWeightFunction>
+struct PRWeightFunction {
+ explicit PRWeightFunction(const PWeightFunction& pwf = PWeightFunction(),
+ const RWeightFunction& rwf = RWeightFunction()) :
+ pweight(pwf), rweight(rwf) {}
+ PRPair<P,R> operator()(const HG::Edge& e) const {
+ const P p = pweight(e);
+ const R r = rweight(e);
+ return PRPair<P,R>(p, r * p);
+ }
+ const PWeightFunction pweight;
+ const RWeightFunction rweight;
};
#endif
diff --git a/utils/Makefile.am b/utils/Makefile.am
index a22b6727..c0ce3509 100644
--- a/utils/Makefile.am
+++ b/utils/Makefile.am
@@ -27,6 +27,7 @@ libutils_a_SOURCES = \
citycrc.h \
corpus_tools.h \
dict.h \
+ exp_semiring.h \
fast_sparse_vector.h \
fdict.h \
feature_vector.h \
@@ -49,6 +50,7 @@ libutils_a_SOURCES = \
show.h \
small_vector.h \
sparse_vector.h \
+ star.h \
static_utoa.h \
stringlib.h \
tdict.h \
diff --git a/decoder/exp_semiring.h b/utils/exp_semiring.h
index 2a9034bb..7572ccf5 100644
--- a/decoder/exp_semiring.h
+++ b/utils/exp_semiring.h
@@ -2,6 +2,7 @@
#define _EXP_SEMIRING_H_
#include <iostream>
+#include "star.h"
// this file implements the first-order expectation semiring described
// in Li & Eisner (EMNLP 2009)
@@ -54,18 +55,10 @@ const PRPair<P,R> operator*(const PRPair<P,R>& a, const PRPair<P,R>& b) {
return result;
}
-template <class P, class PWeightFunction, class R, class RWeightFunction>
-struct PRWeightFunction {
- explicit PRWeightFunction(const PWeightFunction& pwf = PWeightFunction(),
- const RWeightFunction& rwf = RWeightFunction()) :
- pweight(pwf), rweight(rwf) {}
- PRPair<P,R> operator()(const HG::Edge& e) const {
- const P p = pweight(e);
- const R r = rweight(e);
- return PRPair<P,R>(p, r * p);
- }
- const PWeightFunction pweight;
- const RWeightFunction rweight;
-};
+template <class P, class R>
+const PRPair<P,R> star(const PRPair<P,R>& x) {
+ const P pstar = star(x.p);
+ return PRPair<P,R>(pstar, pstar * x.r * pstar);
+}
#endif
diff --git a/utils/logval.h b/utils/logval.h
index ec1f6acd..7f1e1024 100644
--- a/utils/logval.h
+++ b/utils/logval.h
@@ -11,6 +11,7 @@
#include <cassert>
#include "semiring.h"
#include "show.h"
+#include "star.h"
//TODO: template for supporting negation or not - most uses are for nonnegative "probs" only; probably some 10-20% speedup available
template <class T>
@@ -242,4 +243,15 @@ bool operator>=(const LogVal<T>& lhs, const LogVal<T>& rhs) {
template <class T>
std::size_t hash_value(const LogVal<T>& x) { return x.hash_impl(); }
+template <class T>
+LogVal<T> star(LogVal<T> x) {
+ if (x.is_0()) return x;
+ if (x.v_ >= 0) {
+ x.v_ = std::numeric_limits<T>::infinity();
+ } else {
+ x.v_ = -log1p(-x.as_float());
+ }
+ return x;
+}
+
#endif
diff --git a/utils/star.h b/utils/star.h
new file mode 100644
index 00000000..3295112c
--- /dev/null
+++ b/utils/star.h
@@ -0,0 +1,12 @@
+#ifndef _STAR_H_
+#define _STAR_H_
+
+template <typename T>
+T star(const T& x) {
+ if (!x) return T();
+ if (x > T(1)) return std::numeric_limits<T>::infinity();
+ if (x < -T(1)) return -std::numeric_limits<T>::infinity();
+ return T(1) / (T(1) - x);
+}
+
+#endif