summaryrefslogtreecommitdiff
path: root/klm/lm/model_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/model_test.cc')
-rw-r--r--klm/lm/model_test.cc184
1 files changed, 125 insertions, 59 deletions
diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc
index 57c7291c..2654071f 100644
--- a/klm/lm/model_test.cc
+++ b/klm/lm/model_test.cc
@@ -10,8 +10,8 @@ namespace lm {
namespace ngram {
std::ostream &operator<<(std::ostream &o, const State &state) {
- o << "State length " << static_cast<unsigned int>(state.valid_length_) << ':';
- for (const WordIndex *i = state.history_; i < state.history_ + state.valid_length_; ++i) {
+ o << "State length " << static_cast<unsigned int>(state.length) << ':';
+ for (const WordIndex *i = state.words; i < state.words + state.length; ++i) {
o << ' ' << *i;
}
return o;
@@ -19,25 +19,26 @@ std::ostream &operator<<(std::ostream &o, const State &state) {
namespace {
-#define StartTest(word, ngram, score) \
+#define StartTest(word, ngram, score, indep_left) \
ret = model.FullScore( \
state, \
model.GetVocabulary().Index(word), \
out);\
BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
- BOOST_CHECK_GE(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_); \
+ BOOST_CHECK_GE(std::min<unsigned char>(ngram, 5 - 1), out.length); \
+ BOOST_CHECK_EQUAL(indep_left, ret.independent_left); \
{\
- WordIndex context[state.valid_length_ + 1]; \
+ WordIndex context[state.length + 1]; \
context[0] = model.GetVocabulary().Index(word); \
- std::copy(state.history_, state.history_ + state.valid_length_, context + 1); \
+ std::copy(state.words, state.words + state.length, context + 1); \
State get_state; \
- model.GetState(context, context + state.valid_length_ + 1, get_state); \
+ model.GetState(context, context + state.length + 1, get_state); \
BOOST_CHECK_EQUAL(out, get_state); \
}
-#define AppendTest(word, ngram, score) \
- StartTest(word, ngram, score) \
+#define AppendTest(word, ngram, score, indep_left) \
+ StartTest(word, ngram, score, indep_left) \
state = out;
template <class M> void Starters(const M &model) {
@@ -45,12 +46,12 @@ template <class M> void Starters(const M &model) {
Model::State state(model.BeginSentenceState());
Model::State out;
- StartTest("looking", 2, -0.4846522);
+ StartTest("looking", 2, -0.4846522, true);
// , probability plus <s> backoff
- StartTest(",", 1, -1.383514 + -0.4149733);
+ StartTest(",", 1, -1.383514 + -0.4149733, true);
// <unk> probability plus <s> backoff
- StartTest("this_is_not_found", 1, -1.995635 + -0.4149733);
+ StartTest("this_is_not_found", 1, -1.995635 + -0.4149733, true);
}
template <class M> void Continuation(const M &model) {
@@ -58,46 +59,64 @@ template <class M> void Continuation(const M &model) {
Model::State state(model.BeginSentenceState());
Model::State out;
- AppendTest("looking", 2, -0.484652);
- AppendTest("on", 3, -0.348837);
- AppendTest("a", 4, -0.0155266);
- AppendTest("little", 5, -0.00306122);
+ AppendTest("looking", 2, -0.484652, true);
+ AppendTest("on", 3, -0.348837, true);
+ AppendTest("a", 4, -0.0155266, true);
+ AppendTest("little", 5, -0.00306122, true);
State preserve = state;
- AppendTest("the", 1, -4.04005);
- AppendTest("biarritz", 1, -1.9889);
- AppendTest("not_found", 1, -2.29666);
- AppendTest("more", 1, -1.20632 - 20.0);
- AppendTest(".", 2, -0.51363);
- AppendTest("</s>", 3, -0.0191651);
- BOOST_CHECK_EQUAL(0, state.valid_length_);
+ AppendTest("the", 1, -4.04005, true);
+ AppendTest("biarritz", 1, -1.9889, true);
+ AppendTest("not_found", 1, -2.29666, true);
+ AppendTest("more", 1, -1.20632 - 20.0, true);
+ AppendTest(".", 2, -0.51363, true);
+ AppendTest("</s>", 3, -0.0191651, true);
+ BOOST_CHECK_EQUAL(0, state.length);
state = preserve;
- AppendTest("more", 5, -0.00181395);
- BOOST_CHECK_EQUAL(4, state.valid_length_);
- AppendTest("loin", 5, -0.0432557);
- BOOST_CHECK_EQUAL(1, state.valid_length_);
+ AppendTest("more", 5, -0.00181395, true);
+ BOOST_CHECK_EQUAL(4, state.length);
+ AppendTest("loin", 5, -0.0432557, true);
+ BOOST_CHECK_EQUAL(1, state.length);
}
template <class M> void Blanks(const M &model) {
FullScoreReturn ret;
State state(model.NullContextState());
State out;
- AppendTest("also", 1, -1.687872);
- AppendTest("would", 2, -2);
- AppendTest("consider", 3, -3);
+ AppendTest("also", 1, -1.687872, false);
+ AppendTest("would", 2, -2, true);
+ AppendTest("consider", 3, -3, true);
State preserve = state;
- AppendTest("higher", 4, -4);
- AppendTest("looking", 5, -5);
- BOOST_CHECK_EQUAL(1, state.valid_length_);
+ AppendTest("higher", 4, -4, true);
+ AppendTest("looking", 5, -5, true);
+ BOOST_CHECK_EQUAL(1, state.length);
state = preserve;
- AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103);
+ // also would consider not_found
+ AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103, true);
state = model.NullContextState();
// higher looking is a blank.
- AppendTest("higher", 1, -1.509559);
- AppendTest("looking", 1, -1.285941 - 0.30103);
- AppendTest("not_found", 1, -1.995635 - 0.4771212);
+ AppendTest("higher", 1, -1.509559, false);
+ AppendTest("looking", 2, -1.285941 - 0.30103, false);
+
+ State higher_looking = state;
+
+ BOOST_CHECK_EQUAL(1, state.length);
+ AppendTest("not_found", 1, -1.995635 - 0.4771212, true);
+
+ state = higher_looking;
+ // higher looking consider
+ AppendTest("consider", 1, -1.687872 - 0.4771212, true);
+
+ state = model.NullContextState();
+ AppendTest("would", 1, -1.687872, false);
+ BOOST_CHECK_EQUAL(1, state.length);
+ AppendTest("consider", 2, -1.687872 -0.30103, false);
+ BOOST_CHECK_EQUAL(2, state.length);
+ AppendTest("higher", 3, -1.509559 - 0.30103, false);
+ BOOST_CHECK_EQUAL(3, state.length);
+ AppendTest("looking", 4, -1.285941 - 0.30103, false);
}
template <class M> void Unknowns(const M &model) {
@@ -105,14 +124,14 @@ template <class M> void Unknowns(const M &model) {
State state(model.NullContextState());
State out;
- AppendTest("not_found", 1, -1.995635);
+ AppendTest("not_found", 1, -1.995635, false);
State preserve = state;
- AppendTest("not_found2", 2, -15.0);
- AppendTest("not_found3", 2, -15.0 - 2.0);
+ AppendTest("not_found2", 2, -15.0, true);
+ AppendTest("not_found3", 2, -15.0 - 2.0, true);
state = preserve;
- AppendTest("however", 2, -4);
- AppendTest("not_found3", 3, -6);
+ AppendTest("however", 2, -4, true);
+ AppendTest("not_found3", 3, -6, true);
}
template <class M> void MinimalState(const M &model) {
@@ -120,22 +139,66 @@ template <class M> void MinimalState(const M &model) {
State state(model.NullContextState());
State out;
- AppendTest("baz", 1, -6.535897);
- BOOST_CHECK_EQUAL(0, state.valid_length_);
+ AppendTest("baz", 1, -6.535897, true);
+ BOOST_CHECK_EQUAL(0, state.length);
state = model.NullContextState();
- AppendTest("foo", 1, -3.141592);
- BOOST_CHECK_EQUAL(1, state.valid_length_);
- AppendTest("bar", 2, -6.0);
+ AppendTest("foo", 1, -3.141592, true);
+ BOOST_CHECK_EQUAL(1, state.length);
+ AppendTest("bar", 2, -6.0, true);
// Has to include the backoff weight.
- BOOST_CHECK_EQUAL(1, state.valid_length_);
- AppendTest("bar", 1, -2.718281 + 3.0);
- BOOST_CHECK_EQUAL(1, state.valid_length_);
+ BOOST_CHECK_EQUAL(1, state.length);
+ AppendTest("bar", 1, -2.718281 + 3.0, true);
+ BOOST_CHECK_EQUAL(1, state.length);
state = model.NullContextState();
- AppendTest("to", 1, -1.687872);
- AppendTest("look", 2, -0.2922095);
- BOOST_CHECK_EQUAL(2, state.valid_length_);
- AppendTest("good", 3, -7);
+ AppendTest("to", 1, -1.687872, false);
+ AppendTest("look", 2, -0.2922095, true);
+ BOOST_CHECK_EQUAL(2, state.length);
+ AppendTest("good", 3, -7, true);
+}
+
+template <class M> void ExtendLeftTest(const M &model) {
+ State right;
+ FullScoreReturn little(model.FullScore(model.NullContextState(), model.GetVocabulary().Index("little"), right));
+ const float kLittleProb = -1.285941;
+ BOOST_CHECK_CLOSE(kLittleProb, little.prob, 0.001);
+ unsigned char next_use;
+ float backoff_out[4];
+
+ FullScoreReturn extend_none(model.ExtendLeft(NULL, NULL, NULL, little.extend_left, 1, NULL, next_use));
+ BOOST_CHECK_EQUAL(0, next_use);
+ BOOST_CHECK_EQUAL(little.extend_left, extend_none.extend_left);
+ BOOST_CHECK_CLOSE(0.0, extend_none.prob, 0.001);
+ BOOST_CHECK_EQUAL(1, extend_none.ngram_length);
+
+ const WordIndex a = model.GetVocabulary().Index("a");
+ float backoff_in = 3.14;
+ // a little
+ FullScoreReturn extend_a(model.ExtendLeft(&a, &a + 1, &backoff_in, little.extend_left, 1, backoff_out, next_use));
+ BOOST_CHECK_EQUAL(1, next_use);
+ BOOST_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001);
+ BOOST_CHECK_CLOSE(-0.09132547 - kLittleProb, extend_a.prob, 0.001);
+ BOOST_CHECK_EQUAL(2, extend_a.ngram_length);
+ BOOST_CHECK(!extend_a.independent_left);
+
+ const WordIndex on = model.GetVocabulary().Index("on");
+ FullScoreReturn extend_on(model.ExtendLeft(&on, &on + 1, &backoff_in, extend_a.extend_left, 2, backoff_out, next_use));
+ BOOST_CHECK_EQUAL(1, next_use);
+ BOOST_CHECK_CLOSE(-0.4771212, backoff_out[0], 0.001);
+ BOOST_CHECK_CLOSE(-0.0283603 - -0.09132547, extend_on.prob, 0.001);
+ BOOST_CHECK_EQUAL(3, extend_on.ngram_length);
+ BOOST_CHECK(!extend_on.independent_left);
+
+ const WordIndex both[2] = {a, on};
+ float backoff_in_arr[4];
+ FullScoreReturn extend_both(model.ExtendLeft(both, both + 2, backoff_in_arr, little.extend_left, 1, backoff_out, next_use));
+ BOOST_CHECK_EQUAL(2, next_use);
+ BOOST_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001);
+ BOOST_CHECK_CLOSE(-0.4771212, backoff_out[1], 0.001);
+ BOOST_CHECK_CLOSE(-0.0283603 - kLittleProb, extend_both.prob, 0.001);
+ BOOST_CHECK_EQUAL(3, extend_both.ngram_length);
+ BOOST_CHECK(!extend_both.independent_left);
+ BOOST_CHECK_EQUAL(extend_on.extend_left, extend_both.extend_left);
}
#define StatelessTest(word, provide, ngram, score) \
@@ -166,17 +229,17 @@ template <class M> void Stateless(const M &model) {
// looking
StatelessTest(1, 2, 2, -0.484652);
// on
- AppendTest("on", 3, -0.348837);
+ AppendTest("on", 3, -0.348837, true);
StatelessTest(2, 3, 3, -0.348837);
StatelessTest(2, 2, 3, -0.348837);
StatelessTest(2, 1, 2, -0.4638903);
// a
StatelessTest(3, 4, 4, -0.0155266);
// little
- AppendTest("little", 5, -0.00306122);
+ AppendTest("little", 5, -0.00306122, true);
StatelessTest(4, 5, 5, -0.00306122);
// the
- AppendTest("the", 1, -4.04005);
+ AppendTest("the", 1, -4.04005, true);
StatelessTest(5, 5, 1, -4.04005);
// No context of the.
StatelessTest(5, 0, 1, -1.687872);
@@ -189,8 +252,8 @@ template <class M> void Stateless(const M &model) {
WordIndex unk[1];
unk[0] = 0;
model.GetState(unk, unk + 1, state);
- BOOST_CHECK_EQUAL(1, state.valid_length_);
- BOOST_CHECK_EQUAL(static_cast<WordIndex>(0), state.history_[0]);
+ BOOST_CHECK_EQUAL(1, state.length);
+ BOOST_CHECK_EQUAL(static_cast<WordIndex>(0), state.words[0]);
}
template <class M> void NoUnkCheck(const M &model) {
@@ -207,6 +270,7 @@ template <class M> void Everything(const M &m) {
Blanks(m);
Unknowns(m);
MinimalState(m);
+ ExtendLeftTest(m);
Stateless(m);
}
@@ -245,6 +309,7 @@ template <class ModelT> void LoadingTest() {
config.enumerate_vocab = &enumerate;
ModelT m("test.arpa", config);
enumerate.Check(m.GetVocabulary());
+ BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound());
Everything(m);
}
{
@@ -252,6 +317,7 @@ template <class ModelT> void LoadingTest() {
config.enumerate_vocab = &enumerate;
ModelT m("test_nounk.arpa", config);
enumerate.Check(m.GetVocabulary());
+ BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound());
NoUnkCheck(m);
}
}