summary refs log tree commit diff
path: root/klm/lm/model_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/model_test.cc')
-rw-r--r-- klm/lm/model_test.cc 42
1 files changed, 29 insertions, 13 deletions
diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc
index 461704d4..8a122c60 100644
--- a/klm/lm/model_test.cc
+++ b/klm/lm/model_test.cc
@@ -30,7 +30,15 @@ const char *TestNoUnkLocation() {
return "test_nounk.arpa";
}
return boost::unit_test::framework::master_test_suite().argv[2];
+}
+template <class Model> State GetState(const Model &model, const char *word, const State &in) { // Test helper: returns the State obtained from model.GetState for `word` followed by the words held in `in`.
+ WordIndex context[in.length + 1]; // NOTE(review): array size is a runtime value, which is a compiler extension rather than standard C++ - confirm all target compilers accept it.
+ context[0] = model.GetVocabulary().Index(word); // Slot 0 holds the vocabulary index looked up for `word`.
+ std::copy(in.words, in.words + in.length, context + 1); // Slots 1..in.length hold the in.length words copied from the incoming state.
+ State ret;
+ model.GetState(context, context + in.length + 1, ret); // Passes the [begin, end) range of the assembled array; ret is presumably the output argument - verify against Model::GetState.
+ return ret;
}
#define StartTest(word, ngram, score, indep_left) \
@@ -42,14 +50,7 @@ const char *TestNoUnkLocation() {
BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
BOOST_CHECK_GE(std::min<unsigned char>(ngram, 5 - 1), out.length); \
BOOST_CHECK_EQUAL(indep_left, ret.independent_left); \
- {\
- WordIndex context[state.length + 1]; \
- context[0] = model.GetVocabulary().Index(word); \
- std::copy(state.words, state.words + state.length, context + 1); \
- State get_state; \
- model.GetState(context, context + state.length + 1, get_state); \
- BOOST_CHECK_EQUAL(out, get_state); \
- }
+ BOOST_CHECK_EQUAL(out, GetState(model, word, state));
#define AppendTest(word, ngram, score, indep_left) \
StartTest(word, ngram, score, indep_left) \
@@ -182,7 +183,7 @@ template <class M> void ExtendLeftTest(const M &model) {
FullScoreReturn extend_none(model.ExtendLeft(NULL, NULL, NULL, little.extend_left, 1, NULL, next_use));
BOOST_CHECK_EQUAL(0, next_use);
BOOST_CHECK_EQUAL(little.extend_left, extend_none.extend_left);
- BOOST_CHECK_CLOSE(0.0, extend_none.prob, 0.001);
+ BOOST_CHECK_CLOSE(little.prob - little.rest, extend_none.prob, 0.001);
BOOST_CHECK_EQUAL(1, extend_none.ngram_length);
const WordIndex a = model.GetVocabulary().Index("a");
@@ -191,7 +192,7 @@ template <class M> void ExtendLeftTest(const M &model) {
FullScoreReturn extend_a(model.ExtendLeft(&a, &a + 1, &backoff_in, little.extend_left, 1, backoff_out, next_use));
BOOST_CHECK_EQUAL(1, next_use);
BOOST_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001);
- BOOST_CHECK_CLOSE(-0.09132547 - kLittleProb, extend_a.prob, 0.001);
+ BOOST_CHECK_CLOSE(-0.09132547 - little.rest, extend_a.prob, 0.001);
BOOST_CHECK_EQUAL(2, extend_a.ngram_length);
BOOST_CHECK(!extend_a.independent_left);
@@ -199,7 +200,7 @@ template <class M> void ExtendLeftTest(const M &model) {
FullScoreReturn extend_on(model.ExtendLeft(&on, &on + 1, &backoff_in, extend_a.extend_left, 2, backoff_out, next_use));
BOOST_CHECK_EQUAL(1, next_use);
BOOST_CHECK_CLOSE(-0.4771212, backoff_out[0], 0.001);
- BOOST_CHECK_CLOSE(-0.0283603 - -0.09132547, extend_on.prob, 0.001);
+ BOOST_CHECK_CLOSE(-0.0283603 - (extend_a.rest + little.rest), extend_on.prob, 0.001);
BOOST_CHECK_EQUAL(3, extend_on.ngram_length);
BOOST_CHECK(!extend_on.independent_left);
@@ -209,7 +210,7 @@ template <class M> void ExtendLeftTest(const M &model) {
BOOST_CHECK_EQUAL(2, next_use);
BOOST_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001);
BOOST_CHECK_CLOSE(-0.4771212, backoff_out[1], 0.001);
- BOOST_CHECK_CLOSE(-0.0283603 - kLittleProb, extend_both.prob, 0.001);
+ BOOST_CHECK_CLOSE(-0.0283603 - little.rest, extend_both.prob, 0.001);
BOOST_CHECK_EQUAL(3, extend_both.ngram_length);
BOOST_CHECK(!extend_both.independent_left);
BOOST_CHECK_EQUAL(extend_on.extend_left, extend_both.extend_left);
@@ -399,7 +400,10 @@ template <class ModelT> void BinaryTest() {
}
BOOST_AUTO_TEST_CASE(write_and_read_probing) {
- BinaryTest<Model>();
+ BinaryTest<ProbingModel>(); // Runs the shared BinaryTest routine with ProbingModel as the model type (BinaryTest's body is not shown in this diff).
+}
+BOOST_AUTO_TEST_CASE(write_and_read_rest_probing) { // Same BinaryTest routine as the neighbouring write_and_read_* cases, instantiated for RestProbingModel.
+ BinaryTest<RestProbingModel>();
}
BOOST_AUTO_TEST_CASE(write_and_read_trie) {
BinaryTest<TrieModel>();
@@ -414,6 +418,18 @@ BOOST_AUTO_TEST_CASE(write_and_read_quant_array_trie) {
BinaryTest<QuantArrayTrieModel>();
}
+BOOST_AUTO_TEST_CASE(rest_max) { // Checks the `rest` field of two consecutive FullScore results from a RestProbingModel against fixed expected values.
+ Config config;
+ config.arpa_complain = Config::NONE; // presumably turns off complaints about ARPA input - confirm against Config.
+ config.messages = NULL; // presumably suppresses progress/message output - confirm against Config.
+
+ RestProbingModel model(TestLocation(), config); // Builds the model from the path returned by TestLocation() using the config above.
+ State state, out;
+ FullScoreReturn ret(model.FullScore(model.NullContextState(), model.GetVocabulary().Index("."), state)); // Scores "." starting from the null-context state; `state` is reused as the input state of the next call.
+ BOOST_CHECK_CLOSE(-0.2705918, ret.rest, 0.001); // Third argument is Boost.Test's tolerance, expressed as a percentage.
+ BOOST_CHECK_CLOSE(-0.01916512, model.FullScore(state, model.GetVocabulary().EndSentence(), out).rest, 0.001); // Scores the end-of-sentence word continuing from `state`; only the .rest value is checked.
+}
+
} // namespace
} // namespace ngram
} // namespace lm