summaryrefslogtreecommitdiff
path: root/klm/lm/model_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/model_test.cc')
-rw-r--r--klm/lm/model_test.cc35
1 files changed, 19 insertions, 16 deletions
diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc
index 8a122c60..32084b5b 100644
--- a/klm/lm/model_test.cc
+++ b/klm/lm/model_test.cc
@@ -6,6 +6,9 @@
#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>
+// Apparently some Boost versions use templates and are pretty strict about types matching.
+#define SLOPPY_CHECK_CLOSE(ref, value, tol) BOOST_CHECK_CLOSE(static_cast<double>(ref), static_cast<double>(value), static_cast<double>(tol));
+
namespace lm {
namespace ngram {
@@ -46,7 +49,7 @@ template <class Model> State GetState(const Model &model, const char *word, cons
state, \
model.GetVocabulary().Index(word), \
out);\
- BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
+ SLOPPY_CHECK_CLOSE(score, ret.prob, 0.001); \
BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
BOOST_CHECK_GE(std::min<unsigned char>(ngram, 5 - 1), out.length); \
BOOST_CHECK_EQUAL(indep_left, ret.independent_left); \
@@ -176,14 +179,14 @@ template <class M> void ExtendLeftTest(const M &model) {
State right;
FullScoreReturn little(model.FullScore(model.NullContextState(), model.GetVocabulary().Index("little"), right));
const float kLittleProb = -1.285941;
- BOOST_CHECK_CLOSE(kLittleProb, little.prob, 0.001);
+ SLOPPY_CHECK_CLOSE(kLittleProb, little.prob, 0.001);
unsigned char next_use;
float backoff_out[4];
FullScoreReturn extend_none(model.ExtendLeft(NULL, NULL, NULL, little.extend_left, 1, NULL, next_use));
BOOST_CHECK_EQUAL(0, next_use);
BOOST_CHECK_EQUAL(little.extend_left, extend_none.extend_left);
- BOOST_CHECK_CLOSE(little.prob - little.rest, extend_none.prob, 0.001);
+ SLOPPY_CHECK_CLOSE(little.prob - little.rest, extend_none.prob, 0.001);
BOOST_CHECK_EQUAL(1, extend_none.ngram_length);
const WordIndex a = model.GetVocabulary().Index("a");
@@ -191,16 +194,16 @@ template <class M> void ExtendLeftTest(const M &model) {
// a little
FullScoreReturn extend_a(model.ExtendLeft(&a, &a + 1, &backoff_in, little.extend_left, 1, backoff_out, next_use));
BOOST_CHECK_EQUAL(1, next_use);
- BOOST_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001);
- BOOST_CHECK_CLOSE(-0.09132547 - little.rest, extend_a.prob, 0.001);
+ SLOPPY_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001);
+ SLOPPY_CHECK_CLOSE(-0.09132547 - little.rest, extend_a.prob, 0.001);
BOOST_CHECK_EQUAL(2, extend_a.ngram_length);
BOOST_CHECK(!extend_a.independent_left);
const WordIndex on = model.GetVocabulary().Index("on");
FullScoreReturn extend_on(model.ExtendLeft(&on, &on + 1, &backoff_in, extend_a.extend_left, 2, backoff_out, next_use));
BOOST_CHECK_EQUAL(1, next_use);
- BOOST_CHECK_CLOSE(-0.4771212, backoff_out[0], 0.001);
- BOOST_CHECK_CLOSE(-0.0283603 - (extend_a.rest + little.rest), extend_on.prob, 0.001);
+ SLOPPY_CHECK_CLOSE(-0.4771212, backoff_out[0], 0.001);
+ SLOPPY_CHECK_CLOSE(-0.0283603 - (extend_a.rest + little.rest), extend_on.prob, 0.001);
BOOST_CHECK_EQUAL(3, extend_on.ngram_length);
BOOST_CHECK(!extend_on.independent_left);
@@ -208,9 +211,9 @@ template <class M> void ExtendLeftTest(const M &model) {
float backoff_in_arr[4];
FullScoreReturn extend_both(model.ExtendLeft(both, both + 2, backoff_in_arr, little.extend_left, 1, backoff_out, next_use));
BOOST_CHECK_EQUAL(2, next_use);
- BOOST_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001);
- BOOST_CHECK_CLOSE(-0.4771212, backoff_out[1], 0.001);
- BOOST_CHECK_CLOSE(-0.0283603 - little.rest, extend_both.prob, 0.001);
+ SLOPPY_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001);
+ SLOPPY_CHECK_CLOSE(-0.4771212, backoff_out[1], 0.001);
+ SLOPPY_CHECK_CLOSE(-0.0283603 - little.rest, extend_both.prob, 0.001);
BOOST_CHECK_EQUAL(3, extend_both.ngram_length);
BOOST_CHECK(!extend_both.independent_left);
BOOST_CHECK_EQUAL(extend_on.extend_left, extend_both.extend_left);
@@ -218,12 +221,12 @@ template <class M> void ExtendLeftTest(const M &model) {
#define StatelessTest(word, provide, ngram, score) \
ret = model.FullScoreForgotState(indices + num_words - word, indices + num_words - word + provide, indices[num_words - word - 1], state); \
- BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
+ SLOPPY_CHECK_CLOSE(score, ret.prob, 0.001); \
BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
model.GetState(indices + num_words - word, indices + num_words - word + provide, before); \
ret = model.FullScore(before, indices[num_words - word - 1], out); \
BOOST_CHECK(state == out); \
- BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
+ SLOPPY_CHECK_CLOSE(score, ret.prob, 0.001); \
BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length);
template <class M> void Stateless(const M &model) {
@@ -238,7 +241,7 @@ template <class M> void Stateless(const M &model) {
State state, out, before;
ret = model.FullScoreForgotState(indices + num_words - 1, indices + num_words, indices[num_words - 2], state);
- BOOST_CHECK_CLOSE(-0.484652, ret.prob, 0.001);
+ SLOPPY_CHECK_CLOSE(-0.484652, ret.prob, 0.001);
StatelessTest(1, 1, 2, -0.484652);
// looking
@@ -276,7 +279,7 @@ template <class M> void NoUnkCheck(const M &model) {
State state;
FullScoreReturn ret = model.FullScoreForgotState(&unk_index, &unk_index + 1, unk_index, state);
- BOOST_CHECK_CLOSE(-100.0, ret.prob, 0.001);
+ SLOPPY_CHECK_CLOSE(-100.0, ret.prob, 0.001);
}
template <class M> void Everything(const M &m) {
@@ -426,8 +429,8 @@ BOOST_AUTO_TEST_CASE(rest_max) {
RestProbingModel model(TestLocation(), config);
State state, out;
FullScoreReturn ret(model.FullScore(model.NullContextState(), model.GetVocabulary().Index("."), state));
- BOOST_CHECK_CLOSE(-0.2705918, ret.rest, 0.001);
- BOOST_CHECK_CLOSE(-0.01916512, model.FullScore(state, model.GetVocabulary().EndSentence(), out).rest, 0.001);
+ SLOPPY_CHECK_CLOSE(-0.2705918, ret.rest, 0.001);
+ SLOPPY_CHECK_CLOSE(-0.01916512, model.FullScore(state, model.GetVocabulary().EndSentence(), out).rest, 0.001);
}
} // namespace