diff options
Diffstat (limited to 'klm/lm/model_test.cc')
-rw-r--r-- | klm/lm/model_test.cc | 73 |
1 files changed, 62 insertions, 11 deletions
diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 8bf040ff..57c7291c 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -193,6 +193,14 @@ template <class M> void Stateless(const M &model) { BOOST_CHECK_EQUAL(static_cast<WordIndex>(0), state.history_[0]); } +template <class M> void NoUnkCheck(const M &model) { + WordIndex unk_index = 0; + State state; + + FullScoreReturn ret = model.FullScoreForgotState(&unk_index, &unk_index + 1, unk_index, state); + BOOST_CHECK_CLOSE(-100.0, ret.prob, 0.001); +} + template <class M> void Everything(const M &m) { Starters(m); Continuation(m); @@ -231,25 +239,38 @@ template <class ModelT> void LoadingTest() { Config config; config.arpa_complain = Config::NONE; config.messages = NULL; - ExpectEnumerateVocab enumerate; - config.enumerate_vocab = &enumerate; config.probing_multiplier = 2.0; - ModelT m("test.arpa", config); - enumerate.Check(m.GetVocabulary()); - Everything(m); + { + ExpectEnumerateVocab enumerate; + config.enumerate_vocab = &enumerate; + ModelT m("test.arpa", config); + enumerate.Check(m.GetVocabulary()); + Everything(m); + } + { + ExpectEnumerateVocab enumerate; + config.enumerate_vocab = &enumerate; + ModelT m("test_nounk.arpa", config); + enumerate.Check(m.GetVocabulary()); + NoUnkCheck(m); + } } BOOST_AUTO_TEST_CASE(probing) { LoadingTest<Model>(); } - BOOST_AUTO_TEST_CASE(trie) { LoadingTest<TrieModel>(); } - -BOOST_AUTO_TEST_CASE(quant) { +BOOST_AUTO_TEST_CASE(quant_trie) { LoadingTest<QuantTrieModel>(); } +BOOST_AUTO_TEST_CASE(bhiksha_trie) { + LoadingTest<ArrayTrieModel>(); +} +BOOST_AUTO_TEST_CASE(quant_bhiksha_trie) { + LoadingTest<QuantArrayTrieModel>(); +} template <class ModelT> void BinaryTest() { Config config; @@ -267,10 +288,34 @@ template <class ModelT> void BinaryTest() { config.write_mmap = NULL; - ModelT binary("test.binary", config); - enumerate.Check(binary.GetVocabulary()); - Everything(binary); + ModelType type; + BOOST_REQUIRE(RecognizeBinary("test.binary", type)); + BOOST_CHECK_EQUAL(ModelT::kModelType, type); + + { + ModelT binary("test.binary", config); + enumerate.Check(binary.GetVocabulary()); + Everything(binary); + } unlink("test.binary"); + + // Now test without <unk>. + config.write_mmap = "test_nounk.binary"; + config.messages = NULL; + enumerate.Clear(); + { + ModelT copy_model("test_nounk.arpa", config); + enumerate.Check(copy_model.GetVocabulary()); + enumerate.Clear(); + NoUnkCheck(copy_model); + } + config.write_mmap = NULL; + { + ModelT binary("test_nounk.binary", config); + enumerate.Check(binary.GetVocabulary()); + NoUnkCheck(binary); + } + unlink("test_nounk.binary"); } BOOST_AUTO_TEST_CASE(write_and_read_probing) { @@ -282,6 +327,12 @@ BOOST_AUTO_TEST_CASE(write_and_read_trie) { BOOST_AUTO_TEST_CASE(write_and_read_quant_trie) { BinaryTest<QuantTrieModel>(); } +BOOST_AUTO_TEST_CASE(write_and_read_array_trie) { + BinaryTest<ArrayTrieModel>(); +} +BOOST_AUTO_TEST_CASE(write_and_read_quant_array_trie) { + BinaryTest<QuantArrayTrieModel>(); +} } // namespace } // namespace ngram |