summaryrefslogtreecommitdiff
path: root/klm/lm/model_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/model_test.cc')
-rw-r--r--klm/lm/model_test.cc73
1 files changed, 62 insertions, 11 deletions
diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc
index 8bf040ff..57c7291c 100644
--- a/klm/lm/model_test.cc
+++ b/klm/lm/model_test.cc
@@ -193,6 +193,14 @@ template <class M> void Stateless(const M &model) {
BOOST_CHECK_EQUAL(static_cast<WordIndex>(0), state.history_[0]);
}
+template <class M> void NoUnkCheck(const M &model) {
+ WordIndex unk_index = 0;
+ State state;
+
+ FullScoreReturn ret = model.FullScoreForgotState(&unk_index, &unk_index + 1, unk_index, state);
+ BOOST_CHECK_CLOSE(-100.0, ret.prob, 0.001);
+}
+
template <class M> void Everything(const M &m) {
Starters(m);
Continuation(m);
@@ -231,25 +239,38 @@ template <class ModelT> void LoadingTest() {
Config config;
config.arpa_complain = Config::NONE;
config.messages = NULL;
- ExpectEnumerateVocab enumerate;
- config.enumerate_vocab = &enumerate;
config.probing_multiplier = 2.0;
- ModelT m("test.arpa", config);
- enumerate.Check(m.GetVocabulary());
- Everything(m);
+ {
+ ExpectEnumerateVocab enumerate;
+ config.enumerate_vocab = &enumerate;
+ ModelT m("test.arpa", config);
+ enumerate.Check(m.GetVocabulary());
+ Everything(m);
+ }
+ {
+ ExpectEnumerateVocab enumerate;
+ config.enumerate_vocab = &enumerate;
+ ModelT m("test_nounk.arpa", config);
+ enumerate.Check(m.GetVocabulary());
+ NoUnkCheck(m);
+ }
}
BOOST_AUTO_TEST_CASE(probing) {
LoadingTest<Model>();
}
-
BOOST_AUTO_TEST_CASE(trie) {
LoadingTest<TrieModel>();
}
-
-BOOST_AUTO_TEST_CASE(quant) {
+BOOST_AUTO_TEST_CASE(quant_trie) {
LoadingTest<QuantTrieModel>();
}
+BOOST_AUTO_TEST_CASE(bhiksha_trie) {
+ LoadingTest<ArrayTrieModel>();
+}
+BOOST_AUTO_TEST_CASE(quant_bhiksha_trie) {
+ LoadingTest<QuantArrayTrieModel>();
+}
template <class ModelT> void BinaryTest() {
Config config;
@@ -267,10 +288,34 @@ template <class ModelT> void BinaryTest() {
config.write_mmap = NULL;
- ModelT binary("test.binary", config);
- enumerate.Check(binary.GetVocabulary());
- Everything(binary);
+ ModelType type;
+ BOOST_REQUIRE(RecognizeBinary("test.binary", type));
+ BOOST_CHECK_EQUAL(ModelT::kModelType, type);
+
+ {
+ ModelT binary("test.binary", config);
+ enumerate.Check(binary.GetVocabulary());
+ Everything(binary);
+ }
unlink("test.binary");
+
+ // Now test without <unk>.
+ config.write_mmap = "test_nounk.binary";
+ config.messages = NULL;
+ enumerate.Clear();
+ {
+ ModelT copy_model("test_nounk.arpa", config);
+ enumerate.Check(copy_model.GetVocabulary());
+ enumerate.Clear();
+ NoUnkCheck(copy_model);
+ }
+ config.write_mmap = NULL;
+ {
+ ModelT binary("test_nounk.binary", config);
+ enumerate.Check(binary.GetVocabulary());
+ NoUnkCheck(binary);
+ }
+ unlink("test_nounk.binary");
}
BOOST_AUTO_TEST_CASE(write_and_read_probing) {
@@ -282,6 +327,12 @@ BOOST_AUTO_TEST_CASE(write_and_read_trie) {
BOOST_AUTO_TEST_CASE(write_and_read_quant_trie) {
BinaryTest<QuantTrieModel>();
}
+BOOST_AUTO_TEST_CASE(write_and_read_array_trie) {
+ BinaryTest<ArrayTrieModel>();
+}
+BOOST_AUTO_TEST_CASE(write_and_read_quant_array_trie) {
+ BinaryTest<QuantArrayTrieModel>();
+}
} // namespace
} // namespace ngram