diff options
Diffstat (limited to 'klm/lm/model_test.cc')
-rw-r--r-- | klm/lm/model_test.cc | 49 |
1 files changed, 47 insertions, 2 deletions
diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 89bbf2e8..548c098d 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -8,6 +8,15 @@ namespace lm { namespace ngram { + +std::ostream &operator<<(std::ostream &o, const State &state) { + o << "State length " << static_cast<unsigned int>(state.valid_length_) << ':'; + for (const WordIndex *i = state.history_; i < state.history_ + state.valid_length_; ++i) { + o << ' ' << *i; + } + return o; +} + namespace { #define StartTest(word, ngram, score) \ @@ -17,7 +26,15 @@ namespace { out);\ BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \ - BOOST_CHECK_EQUAL(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_); + BOOST_CHECK_GE(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_); \ + {\ + WordIndex context[state.valid_length_ + 1]; \ + context[0] = model.GetVocabulary().Index(word); \ + std::copy(state.history_, state.history_ + state.valid_length_, context + 1); \ + State get_state; \ + model.GetState(context, context + state.valid_length_ + 1, get_state); \ + BOOST_CHECK_EQUAL(out, get_state); \ + } #define AppendTest(word, ngram, score) \ StartTest(word, ngram, score) \ @@ -52,10 +69,13 @@ template <class M> void Continuation(const M &model) { AppendTest("more", 1, -1.20632 - 20.0); AppendTest(".", 2, -0.51363); AppendTest("</s>", 3, -0.0191651); + BOOST_CHECK_EQUAL(0, state.valid_length_); state = preserve; AppendTest("more", 5, -0.00181395); + BOOST_CHECK_EQUAL(4, state.valid_length_); AppendTest("loin", 5, -0.0432557); + BOOST_CHECK_EQUAL(1, state.valid_length_); } template <class M> void Blanks(const M &model) { @@ -68,6 +88,7 @@ template <class M> void Blanks(const M &model) { State preserve = state; AppendTest("higher", 4, -4); AppendTest("looking", 5, -5); + BOOST_CHECK_EQUAL(1, state.valid_length_); state = preserve; AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103); @@ -94,6 +115,29 @@ template <class M> void Unknowns(const M &model) { AppendTest("not_found3", 3, -6); } +template <class M> void MinimalState(const M &model) { + FullScoreReturn ret; + State state(model.NullContextState()); + State out; + + AppendTest("baz", 1, -6.535897); + BOOST_CHECK_EQUAL(0, state.valid_length_); + state = model.NullContextState(); + AppendTest("foo", 1, -3.141592); + BOOST_CHECK_EQUAL(1, state.valid_length_); + AppendTest("bar", 2, -6.0); + // Has to include the backoff weight. + BOOST_CHECK_EQUAL(1, state.valid_length_); + AppendTest("bar", 1, -2.718281 + 3.0); + BOOST_CHECK_EQUAL(1, state.valid_length_); + + state = model.NullContextState(); + AppendTest("to", 1, -1.687872); + AppendTest("look", 2, -0.2922095); + BOOST_CHECK_EQUAL(2, state.valid_length_); + AppendTest("good", 3, -7); +} + #define StatelessTest(word, provide, ngram, score) \ ret = model.FullScoreForgotState(indices + num_words - word, indices + num_words - word + provide, indices[num_words - word - 1], state); \ BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ @@ -154,6 +198,7 @@ template <class M> void Everything(const M &m) { Continuation(m); Blanks(m); Unknowns(m); + MinimalState(m); Stateless(m); } @@ -167,7 +212,7 @@ class ExpectEnumerateVocab : public EnumerateVocab { } void Check(const base::Vocabulary &vocab) { - BOOST_CHECK_EQUAL(34ULL, seen.size()); + BOOST_CHECK_EQUAL(37ULL, seen.size()); BOOST_REQUIRE(!seen.empty()); BOOST_CHECK_EQUAL("<unk>", seen[0]); for (WordIndex i = 0; i < seen.size(); ++i) { |