summary | refs | log | tree | commit | diff
path: root/klm/lm/model_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/model_test.cc')
-rw-r--r-- klm/lm/model_test.cc | 74
1 files changed, 57 insertions, 17 deletions
diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc
index b5125a95..89bbf2e8 100644
--- a/klm/lm/model_test.cc
+++ b/klm/lm/model_test.cc
@@ -33,7 +33,7 @@ template <class M> void Starters(const M &model) {
// , probability plus <s> backoff
StartTest(",", 1, -1.383514 + -0.4149733);
// <unk> probability plus <s> backoff
- StartTest("this_is_not_found", 0, -1.995635 + -0.4149733);
+ StartTest("this_is_not_found", 1, -1.995635 + -0.4149733);
}
template <class M> void Continuation(const M &model) {
@@ -48,8 +48,8 @@ template <class M> void Continuation(const M &model) {
State preserve = state;
AppendTest("the", 1, -4.04005);
AppendTest("biarritz", 1, -1.9889);
- AppendTest("not_found", 0, -2.29666);
- AppendTest("more", 1, -1.20632);
+ AppendTest("not_found", 1, -2.29666);
+ AppendTest("more", 1, -1.20632 - 20.0);
AppendTest(".", 2, -0.51363);
AppendTest("</s>", 3, -0.0191651);
@@ -58,6 +58,42 @@ template <class M> void Continuation(const M &model) {
AppendTest("loin", 5, -0.0432557);
}
+template <class M> void Blanks(const M &model) {
+ FullScoreReturn ret;
+ State state(model.NullContextState());
+ State out;
+ AppendTest("also", 1, -1.687872);
+ AppendTest("would", 2, -2);
+ AppendTest("consider", 3, -3);
+ State preserve = state;
+ AppendTest("higher", 4, -4);
+ AppendTest("looking", 5, -5);
+
+ state = preserve;
+ AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103);
+
+ state = model.NullContextState();
+ // higher looking is a blank.
+ AppendTest("higher", 1, -1.509559);
+ AppendTest("looking", 1, -1.285941 - 0.30103);
+ AppendTest("not_found", 1, -1.995635 - 0.4771212);
+}
+
+template <class M> void Unknowns(const M &model) {
+ FullScoreReturn ret;
+ State state(model.NullContextState());
+ State out;
+
+ AppendTest("not_found", 1, -1.995635);
+ State preserve = state;
+ AppendTest("not_found2", 2, -15.0);
+ AppendTest("not_found3", 2, -15.0 - 2.0);
+
+ state = preserve;
+ AppendTest("however", 2, -4);
+ AppendTest("not_found3", 3, -6);
+}
+
#define StatelessTest(word, provide, ngram, score) \
ret = model.FullScoreForgotState(indices + num_words - word, indices + num_words - word + provide, indices[num_words - word - 1], state); \
BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
@@ -103,16 +139,23 @@ template <class M> void Stateless(const M &model) {
// biarritz
StatelessTest(6, 1, 1, -1.9889);
// not found
- StatelessTest(7, 1, 0, -2.29666);
- StatelessTest(7, 0, 0, -1.995635);
+ StatelessTest(7, 1, 1, -2.29666);
+ StatelessTest(7, 0, 1, -1.995635);
WordIndex unk[1];
unk[0] = 0;
model.GetState(unk, unk + 1, state);
- BOOST_CHECK_EQUAL(0, state.valid_length_);
+ BOOST_CHECK_EQUAL(1, state.valid_length_);
+ BOOST_CHECK_EQUAL(static_cast<WordIndex>(0), state.history_[0]);
}
-//const char *kExpectedOrderProbing[] = {"<unk>", ",", ".", "</s>", "<s>", "a", "also", "beyond", "biarritz", "call", "concerns", "consider", "considering", "for", "higher", "however", "i", "immediate", "in", "is", "little", "loin", "look", "looking", "more", "on", "screening", "small", "the", "to", "watch", "watching", "what", "would"};
+template <class M> void Everything(const M &m) {
+ Starters(m);
+ Continuation(m);
+ Blanks(m);
+ Unknowns(m);
+ Stateless(m);
+}
class ExpectEnumerateVocab : public EnumerateVocab {
public:
@@ -148,18 +191,16 @@ template <class ModelT> void LoadingTest() {
config.probing_multiplier = 2.0;
ModelT m("test.arpa", config);
enumerate.Check(m.GetVocabulary());
- Starters(m);
- Continuation(m);
- Stateless(m);
+ Everything(m);
}
BOOST_AUTO_TEST_CASE(probing) {
LoadingTest<Model>();
}
-BOOST_AUTO_TEST_CASE(sorted) {
+/*BOOST_AUTO_TEST_CASE(sorted) {
LoadingTest<SortedModel>();
-}
+}*/
BOOST_AUTO_TEST_CASE(trie) {
LoadingTest<TrieModel>();
}
@@ -175,24 +216,23 @@ template <class ModelT> void BinaryTest() {
ModelT copy_model("test.arpa", config);
enumerate.Check(copy_model.GetVocabulary());
enumerate.Clear();
+ Everything(copy_model);
}
config.write_mmap = NULL;
ModelT binary("test.binary", config);
enumerate.Check(binary.GetVocabulary());
- Starters(binary);
- Continuation(binary);
- Stateless(binary);
+ Everything(binary);
unlink("test.binary");
}
BOOST_AUTO_TEST_CASE(write_and_read_probing) {
BinaryTest<Model>();
}
-BOOST_AUTO_TEST_CASE(write_and_read_sorted) {
+/*BOOST_AUTO_TEST_CASE(write_and_read_sorted) {
BinaryTest<SortedModel>();
-}
+}*/
BOOST_AUTO_TEST_CASE(write_and_read_trie) {
BinaryTest<TrieModel>();
}