// Unit tests for the n-gram language model query interface.
//
// Each test loads "test.arpa" (or a binary image written from it) and checks
// FullScore / FullScoreForgotState / GetState results against hard-coded
// log-probabilities and n-gram lengths.
//
// NOTE(review): this file was reconstructed from a copy in which line breaks
// were collapsed and every "<...>" token was stripped.  Include targets,
// template parameter lists, template arguments and the "<s>" / "</s>" /
// "<unk>" literals were restored from context; items that could not be proven
// from the surviving text carry their own NOTE(review) marker below.
#include "lm/model.hh"

// NOTE(review): the three original include targets were lost.  The headers
// below cover every non-project name this file uses; confirm against upstream.
#include <stdlib.h>
#include <unistd.h>  // unlink

#include <algorithm>  // std::copy, std::min
#include <ostream>
#include <string>
#include <vector>

#define BOOST_TEST_MODULE ModelTest
#include <boost/test/unit_test.hpp>
// NOTE(review): BOOST_CHECK_CLOSE historically lived in this header; newer
// Boost releases may provide it through unit_test.hpp alone.
#include <boost/test/floating_point_comparison.hpp>

namespace lm {
namespace ngram {

// Streams a State as "State length N: w1 w2 ...", which lets
// BOOST_CHECK_EQUAL print both operands when a State comparison fails.
// NOTE(review): the cast target was stripped; unsigned int is assumed so the
// length prints as a number rather than as a character.
std::ostream &operator<<(std::ostream &o, const State &state) {
  o << "State length " << static_cast<unsigned int>(state.valid_length_) << ':';
  for (const WordIndex *i = state.history_; i < state.history_ + state.valid_length_; ++i) {
    o << ' ' << *i;
  }
  return o;
}

namespace {

// Scores `word` from the local `state` into the local `out` and checks:
//   * the returned probability is within 0.001% of `score`,
//   * the matched n-gram length equals `ngram`,
//   * out.valid_length_ does not exceed min(ngram, 5 - 1),
//   * GetState() on (word, history of `state`) produces a State equal to `out`.
// Expects `model`, `state`, `out` and `ret` to be in scope at the use site.
// `context` is a variable-length array (compiler extension).
#define StartTest(word, ngram, score) \
  ret = model.FullScore( \
      state, \
      model.GetVocabulary().Index(word), \
      out);\
  BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
  BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
  BOOST_CHECK_GE(std::min(ngram, 5 - 1), out.valid_length_); \
  {\
    WordIndex context[state.valid_length_ + 1]; \
    context[0] = model.GetVocabulary().Index(word); \
    std::copy(state.history_, state.history_ + state.valid_length_, context + 1); \
    State get_state; \
    model.GetState(context, context + state.valid_length_ + 1, get_state); \
    BOOST_CHECK_EQUAL(out, get_state); \
  }

// Same checks as StartTest, then advances `state` to `out` so that the next
// call continues the same word sequence.
#define AppendTest(word, ngram, score) \
  StartTest(word, ngram, score) \
  state = out;

// Scores single words directly after the begin-of-sentence state.  `state`
// is never advanced, so every check starts from the same context.
template <class M> void Starters(const M &model) {
  FullScoreReturn ret;
  Model::State state(model.BeginSentenceState());
  Model::State out;

  StartTest("looking", 2, -0.4846522);

  // , probability plus <s> backoff
  StartTest(",", 1, -1.383514 + -0.4149733);
  // <unk> probability plus <s> backoff
  StartTest("this_is_not_found", 1, -1.995635 + -0.4149733);
}

// Scores a running sequence from the begin-of-sentence state, then rewinds to
// a saved state to score an alternative continuation.
template <class M> void Continuation(const M &model) {
  FullScoreReturn ret;
  Model::State state(model.BeginSentenceState());
  Model::State out;

  AppendTest("looking", 2, -0.484652);
  AppendTest("on", 3, -0.348837);
  AppendTest("a", 4, -0.0155266);
  AppendTest("little", 5, -0.00306122);
  State preserve = state;
  AppendTest("the", 1, -4.04005);
  AppendTest("biarritz", 1, -1.9889);
  AppendTest("not_found", 1, -2.29666);
  AppendTest("more", 1, -1.20632 - 20.0);
  AppendTest(".", 2, -0.51363);
  // NOTE(review): literal restored from "" -- the end-of-sentence token.
  AppendTest("</s>", 3, -0.0191651);
  BOOST_CHECK_EQUAL(0, state.valid_length_);

  // Rewind to the state saved after "little" and try another continuation.
  state = preserve;
  AppendTest("more", 5, -0.00181395);
  BOOST_CHECK_EQUAL(4, state.valid_length_);
  AppendTest("loin", 5, -0.0432557);
  BOOST_CHECK_EQUAL(1, state.valid_length_);
}

// Scores sequences starting from the null-context state, including the
// "higher looking" case that the original comment calls a blank.
template <class M> void Blanks(const M &model) {
  FullScoreReturn ret;
  State state(model.NullContextState());
  State out;
  AppendTest("also", 1, -1.687872);
  AppendTest("would", 2, -2);
  AppendTest("consider", 3, -3);
  State preserve = state;
  AppendTest("higher", 4, -4);
  AppendTest("looking", 5, -5);
  BOOST_CHECK_EQUAL(1, state.valid_length_);

  state = preserve;
  AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103);

  state = model.NullContextState();
  // higher looking is a blank.
  AppendTest("higher", 1, -1.509559);
  AppendTest("looking", 1, -1.285941 - 0.30103);
  AppendTest("not_found", 1, -1.995635 - 0.4771212);
}

// Scores out-of-vocabulary words, alone and as context for later words.
template <class M> void Unknowns(const M &model) {
  FullScoreReturn ret;
  State state(model.NullContextState());
  State out;

  AppendTest("not_found", 1, -1.995635);
  State preserve = state;
  AppendTest("not_found2", 2, -15.0);
  AppendTest("not_found3", 2, -15.0 - 2.0);

  state = preserve;
  AppendTest("however", 2, -4);
  AppendTest("not_found3", 3, -6);
}

// Checks the exact valid_length_ retained in the state after scoring.
template <class M> void MinimalState(const M &model) {
  FullScoreReturn ret;
  State state(model.NullContextState());
  State out;

  AppendTest("baz", 1, -6.535897);
  BOOST_CHECK_EQUAL(0, state.valid_length_);
  state = model.NullContextState();
  AppendTest("foo", 1, -3.141592);
  BOOST_CHECK_EQUAL(1, state.valid_length_);
  AppendTest("bar", 2, -6.0);
  // Has to include the backoff weight.
  BOOST_CHECK_EQUAL(1, state.valid_length_);
  AppendTest("bar", 1, -2.718281 + 3.0);
  BOOST_CHECK_EQUAL(1, state.valid_length_);

  state = model.NullContextState();
  AppendTest("to", 1, -1.687872);
  AppendTest("look", 2, -0.2922095);
  BOOST_CHECK_EQUAL(2, state.valid_length_);
  AppendTest("good", 3, -7);
}

// Scores one word through FullScoreForgotState using an explicit context
// slice of the reversed `indices` array, then repeats the query through
// GetState + FullScore and checks that both routes agree.
//   word    - 1-based position in `words` of the word being scored
//   provide - number of context words supplied
//   ngram   - expected matched n-gram length
//   score   - expected log-probability
// Expects `model`, `indices`, `num_words`, `state`, `out`, `before` and `ret`
// to be in scope at the use site.
#define StatelessTest(word, provide, ngram, score) \
  ret = model.FullScoreForgotState(indices + num_words - word, indices + num_words - word + provide, indices[num_words - word - 1], state); \
  BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
  BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
  model.GetState(indices + num_words - word, indices + num_words - word + provide, before); \
  ret = model.FullScore(before, indices[num_words - word - 1], out); \
  BOOST_CHECK(state == out); \
  BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
  BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length);

// Exercises the stateless query path with contexts of varying length.
template <class M> void Stateless(const M &model) {
  // NOTE(review): the first and last literals were restored from "" -- the
  // sentence-boundary tokens.
  const char *words[] = {"<s>", "looking", "on", "a", "little", "the", "biarritz", "not_found", "more", ".", "</s>"};
  const size_t num_words = sizeof(words) / sizeof(const char*);
  // Silence "array subscript is above array bounds" when extracting end pointer.
  WordIndex indices[num_words + 1];
  // Vocabulary ids are stored in reverse word order.
  for (unsigned int i = 0; i < num_words; ++i) {
    indices[num_words - 1 - i] = model.GetVocabulary().Index(words[i]);
  }
  FullScoreReturn ret;
  State state, out, before;

  ret = model.FullScoreForgotState(indices + num_words - 1, indices + num_words, indices[num_words - 2], state);
  BOOST_CHECK_CLOSE(-0.484652, ret.prob, 0.001);
  StatelessTest(1, 1, 2, -0.484652);

  // looking
  StatelessTest(1, 2, 2, -0.484652);

  // on
  AppendTest("on", 3, -0.348837);
  StatelessTest(2, 3, 3, -0.348837);
  StatelessTest(2, 2, 3, -0.348837);
  StatelessTest(2, 1, 2, -0.4638903);

  // a
  StatelessTest(3, 4, 4, -0.0155266);

  // little
  AppendTest("little", 5, -0.00306122);
  StatelessTest(4, 5, 5, -0.00306122);

  // the
  AppendTest("the", 1, -4.04005);
  StatelessTest(5, 5, 1, -4.04005);

  // No context of the.
  StatelessTest(5, 0, 1, -1.687872);

  // biarritz
  StatelessTest(6, 1, 1, -1.9889);

  // not found
  StatelessTest(7, 1, 1, -2.29666);
  StatelessTest(7, 0, 1, -1.995635);

  // A context holding only word index 0 must survive GetState unchanged.
  WordIndex unk[1];
  unk[0] = 0;
  model.GetState(unk, unk + 1, state);
  BOOST_CHECK_EQUAL(1, state.valid_length_);
  BOOST_CHECK_EQUAL(static_cast<WordIndex>(0), state.history_[0]);
}

// Runs every query test against one loaded model.
template <class M> void Everything(const M &m) {
  Starters(m);
  Continuation(m);
  Blanks(m);
  Unknowns(m);
  MinimalState(m);
  Stateless(m);
}

// EnumerateVocab callback that records each word reported during loading and
// can verify the recording against the loaded vocabulary afterwards.
class ExpectEnumerateVocab : public EnumerateVocab {
  public:
    ExpectEnumerateVocab() {}

    // Records `str`; indices must arrive densely, in order, starting from 0.
    void Add(WordIndex index, const StringPiece &str) {
      BOOST_CHECK_EQUAL(seen.size(), index);
      seen.push_back(std::string(str.data(), str.length()));
    }

    // Expects exactly 37 recorded words and that each one maps back to its
    // own position through vocab.Index().
    void Check(const base::Vocabulary &vocab) {
      BOOST_CHECK_EQUAL(37ULL, seen.size());
      BOOST_REQUIRE(!seen.empty());
      // NOTE(review): literal restored from "" -- the unknown-word token.
      BOOST_CHECK_EQUAL("<unk>", seen[0]);
      for (WordIndex i = 0; i < seen.size(); ++i) {
        BOOST_CHECK_EQUAL(i, vocab.Index(seen[i]));
      }
    }

    // Forgets everything recorded so the object can observe another load.
    void Clear() {
      seen.clear();
    }

    std::vector<std::string> seen;
};

// Loads "test.arpa" into a ModelT, verifies the enumerated vocabulary, then
// runs the full query suite.
template <class ModelT> void LoadingTest() {
  Config config;
  config.arpa_complain = Config::NONE;
  config.messages = NULL;
  ExpectEnumerateVocab enumerate;
  config.enumerate_vocab = &enumerate;
  config.probing_multiplier = 2.0;
  ModelT m("test.arpa", config);
  enumerate.Check(m.GetVocabulary());
  Everything(m);
}

// NOTE(review): the template arguments of the three LoadingTest calls below
// were stripped; the model type names are inferred from the test-case names
// and must be confirmed against lm/model.hh.
BOOST_AUTO_TEST_CASE(probing) {
  LoadingTest<Model>();
}
BOOST_AUTO_TEST_CASE(trie) {
  LoadingTest<TrieModel>();
}
BOOST_AUTO_TEST_CASE(quant) {
  LoadingTest<QuantTrieModel>();
}

// Loads "test.arpa" with write_mmap set to "test.binary", tests that model,
// then loads "test.binary" itself and repeats the same tests.  The binary
// file is unlinked at the end.
template <class ModelT> void BinaryTest() {
  Config config;
  config.write_mmap = "test.binary";
  config.messages = NULL;
  ExpectEnumerateVocab enumerate;
  config.enumerate_vocab = &enumerate;

  {
    // Scoped so this model is destroyed before "test.binary" is loaded.
    ModelT copy_model("test.arpa", config);
    enumerate.Check(copy_model.GetVocabulary());
    enumerate.Clear();
    Everything(copy_model);
  }

  config.write_mmap = NULL;

  ModelT binary("test.binary", config);
  enumerate.Check(binary.GetVocabulary());
  Everything(binary);
  unlink("test.binary");
}

// NOTE(review): the template arguments of the three BinaryTest calls below
// were stripped; the model type names are inferred from the test-case names
// and must be confirmed against lm/model.hh.
BOOST_AUTO_TEST_CASE(write_and_read_probing) {
  BinaryTest<Model>();
}
BOOST_AUTO_TEST_CASE(write_and_read_trie) {
  BinaryTest<TrieModel>();
}
BOOST_AUTO_TEST_CASE(write_and_read_quant_trie) {
  BinaryTest<QuantTrieModel>();
}

} // namespace
} // namespace ngram
} // namespace lm