diff options
Diffstat (limited to 'klm/lm/model_test.cc')
-rw-r--r-- | klm/lm/model_test.cc | 74 |
1 files changed, 57 insertions, 17 deletions
diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index b5125a95..89bbf2e8 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -33,7 +33,7 @@ template <class M> void Starters(const M &model) { // , probability plus <s> backoff StartTest(",", 1, -1.383514 + -0.4149733); // <unk> probability plus <s> backoff - StartTest("this_is_not_found", 0, -1.995635 + -0.4149733); + StartTest("this_is_not_found", 1, -1.995635 + -0.4149733); } template <class M> void Continuation(const M &model) { @@ -48,8 +48,8 @@ template <class M> void Continuation(const M &model) { State preserve = state; AppendTest("the", 1, -4.04005); AppendTest("biarritz", 1, -1.9889); - AppendTest("not_found", 0, -2.29666); - AppendTest("more", 1, -1.20632); + AppendTest("not_found", 1, -2.29666); + AppendTest("more", 1, -1.20632 - 20.0); AppendTest(".", 2, -0.51363); AppendTest("</s>", 3, -0.0191651); @@ -58,6 +58,42 @@ template <class M> void Continuation(const M &model) { AppendTest("loin", 5, -0.0432557); } +template <class M> void Blanks(const M &model) { + FullScoreReturn ret; + State state(model.NullContextState()); + State out; + AppendTest("also", 1, -1.687872); + AppendTest("would", 2, -2); + AppendTest("consider", 3, -3); + State preserve = state; + AppendTest("higher", 4, -4); + AppendTest("looking", 5, -5); + + state = preserve; + AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103); + + state = model.NullContextState(); + // higher looking is a blank. + AppendTest("higher", 1, -1.509559); + AppendTest("looking", 1, -1.285941 - 0.30103); + AppendTest("not_found", 1, -1.995635 - 0.4771212); +} + +template <class M> void Unknowns(const M &model) { + FullScoreReturn ret; + State state(model.NullContextState()); + State out; + + AppendTest("not_found", 1, -1.995635); + State preserve = state; + AppendTest("not_found2", 2, -15.0); + AppendTest("not_found3", 2, -15.0 - 2.0); + + state = preserve; + AppendTest("however", 2, -4); + AppendTest("not_found3", 3, -6); +} + #define StatelessTest(word, provide, ngram, score) \ ret = model.FullScoreForgotState(indices + num_words - word, indices + num_words - word + provide, indices[num_words - word - 1], state); \ BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ @@ -103,16 +139,23 @@ template <class M> void Stateless(const M &model) { // biarritz StatelessTest(6, 1, 1, -1.9889); // not found - StatelessTest(7, 1, 0, -2.29666); - StatelessTest(7, 0, 0, -1.995635); + StatelessTest(7, 1, 1, -2.29666); + StatelessTest(7, 0, 1, -1.995635); WordIndex unk[1]; unk[0] = 0; model.GetState(unk, unk + 1, state); - BOOST_CHECK_EQUAL(0, state.valid_length_); + BOOST_CHECK_EQUAL(1, state.valid_length_); + BOOST_CHECK_EQUAL(static_cast<WordIndex>(0), state.history_[0]); } -//const char *kExpectedOrderProbing[] = {"<unk>", ",", ".", "</s>", "<s>", "a", "also", "beyond", "biarritz", "call", "concerns", "consider", "considering", "for", "higher", "however", "i", "immediate", "in", "is", "little", "loin", "look", "looking", "more", "on", "screening", "small", "the", "to", "watch", "watching", "what", "would"}; +template <class M> void Everything(const M &m) { + Starters(m); + Continuation(m); + Blanks(m); + Unknowns(m); + Stateless(m); +} class ExpectEnumerateVocab : public EnumerateVocab { public: @@ -148,18 +191,16 @@ template <class ModelT> void LoadingTest() { config.probing_multiplier = 2.0; ModelT m("test.arpa", config); enumerate.Check(m.GetVocabulary()); - Starters(m); - Continuation(m); - Stateless(m); + Everything(m); } BOOST_AUTO_TEST_CASE(probing) { LoadingTest<Model>(); } -BOOST_AUTO_TEST_CASE(sorted) { +/*BOOST_AUTO_TEST_CASE(sorted) { LoadingTest<SortedModel>(); -} +}*/ BOOST_AUTO_TEST_CASE(trie) { LoadingTest<TrieModel>(); } @@ -175,24 +216,23 @@ template <class ModelT> void BinaryTest() { ModelT copy_model("test.arpa", config); enumerate.Check(copy_model.GetVocabulary()); enumerate.Clear(); + Everything(copy_model); } config.write_mmap = NULL; ModelT binary("test.binary", config); enumerate.Check(binary.GetVocabulary()); - Starters(binary); - Continuation(binary); - Stateless(binary); + Everything(binary); unlink("test.binary"); } BOOST_AUTO_TEST_CASE(write_and_read_probing) { BinaryTest<Model>(); } -BOOST_AUTO_TEST_CASE(write_and_read_sorted) { +/*BOOST_AUTO_TEST_CASE(write_and_read_sorted) { BinaryTest<SortedModel>(); -} +}*/ BOOST_AUTO_TEST_CASE(write_and_read_trie) { BinaryTest<TrieModel>(); } |