summaryrefslogtreecommitdiff
path: root/klm/lm/model_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/model_test.cc')
-rw-r--r--klm/lm/model_test.cc49
1 files changed, 47 insertions, 2 deletions
diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc
index 89bbf2e8..548c098d 100644
--- a/klm/lm/model_test.cc
+++ b/klm/lm/model_test.cc
@@ -8,6 +8,15 @@
namespace lm {
namespace ngram {
+
+std::ostream &operator<<(std::ostream &o, const State &state) {
+ o << "State length " << static_cast<unsigned int>(state.valid_length_) << ':';
+ for (const WordIndex *i = state.history_; i < state.history_ + state.valid_length_; ++i) {
+ o << ' ' << *i;
+ }
+ return o;
+}
+
namespace {
#define StartTest(word, ngram, score) \
@@ -17,7 +26,15 @@ namespace {
out);\
BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
- BOOST_CHECK_EQUAL(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_);
+ BOOST_CHECK_GE(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_); \
+ {\
+ WordIndex context[state.valid_length_ + 1]; \
+ context[0] = model.GetVocabulary().Index(word); \
+ std::copy(state.history_, state.history_ + state.valid_length_, context + 1); \
+ State get_state; \
+ model.GetState(context, context + state.valid_length_ + 1, get_state); \
+ BOOST_CHECK_EQUAL(out, get_state); \
+ }
#define AppendTest(word, ngram, score) \
StartTest(word, ngram, score) \
@@ -52,10 +69,13 @@ template <class M> void Continuation(const M &model) {
AppendTest("more", 1, -1.20632 - 20.0);
AppendTest(".", 2, -0.51363);
AppendTest("</s>", 3, -0.0191651);
+ BOOST_CHECK_EQUAL(0, state.valid_length_);
state = preserve;
AppendTest("more", 5, -0.00181395);
+ BOOST_CHECK_EQUAL(4, state.valid_length_);
AppendTest("loin", 5, -0.0432557);
+ BOOST_CHECK_EQUAL(1, state.valid_length_);
}
template <class M> void Blanks(const M &model) {
@@ -68,6 +88,7 @@ template <class M> void Blanks(const M &model) {
State preserve = state;
AppendTest("higher", 4, -4);
AppendTest("looking", 5, -5);
+ BOOST_CHECK_EQUAL(1, state.valid_length_);
state = preserve;
AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103);
@@ -94,6 +115,29 @@ template <class M> void Unknowns(const M &model) {
AppendTest("not_found3", 3, -6);
}
+template <class M> void MinimalState(const M &model) {
+ FullScoreReturn ret;
+ State state(model.NullContextState());
+ State out;
+
+ AppendTest("baz", 1, -6.535897);
+ BOOST_CHECK_EQUAL(0, state.valid_length_);
+ state = model.NullContextState();
+ AppendTest("foo", 1, -3.141592);
+ BOOST_CHECK_EQUAL(1, state.valid_length_);
+ AppendTest("bar", 2, -6.0);
+ // Has to include the backoff weight.
+ BOOST_CHECK_EQUAL(1, state.valid_length_);
+ AppendTest("bar", 1, -2.718281 + 3.0);
+ BOOST_CHECK_EQUAL(1, state.valid_length_);
+
+ state = model.NullContextState();
+ AppendTest("to", 1, -1.687872);
+ AppendTest("look", 2, -0.2922095);
+ BOOST_CHECK_EQUAL(2, state.valid_length_);
+ AppendTest("good", 3, -7);
+}
+
#define StatelessTest(word, provide, ngram, score) \
ret = model.FullScoreForgotState(indices + num_words - word, indices + num_words - word + provide, indices[num_words - word - 1], state); \
BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
@@ -154,6 +198,7 @@ template <class M> void Everything(const M &m) {
Continuation(m);
Blanks(m);
Unknowns(m);
+ MinimalState(m);
Stateless(m);
}
@@ -167,7 +212,7 @@ class ExpectEnumerateVocab : public EnumerateVocab {
}
void Check(const base::Vocabulary &vocab) {
- BOOST_CHECK_EQUAL(34ULL, seen.size());
+ BOOST_CHECK_EQUAL(37ULL, seen.size());
BOOST_REQUIRE(!seen.empty());
BOOST_CHECK_EQUAL("<unk>", seen[0]);
for (WordIndex i = 0; i < seen.size(); ++i) {