summaryrefslogtreecommitdiff
path: root/klm/lm/left_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/left_test.cc')
-rw-r--r--klm/lm/left_test.cc83
1 files changed, 55 insertions, 28 deletions
diff --git a/klm/lm/left_test.cc b/klm/lm/left_test.cc
index c85e5efa..b23e6a0f 100644
--- a/klm/lm/left_test.cc
+++ b/klm/lm/left_test.cc
@@ -24,7 +24,7 @@ template <class M> void Short(const M &m) {
Term("loin");
BOOST_CHECK_CLOSE(-1.206319 - 0.3561665, score.Finish(), 0.001);
}
- BOOST_CHECK(base.full);
+ BOOST_CHECK(base.left.full);
BOOST_CHECK_EQUAL(2, base.left.length);
BOOST_CHECK_EQUAL(1, base.right.length);
VCheck("loin", base.right.words[0]);
@@ -40,7 +40,7 @@ template <class M> void Short(const M &m) {
BOOST_CHECK_EQUAL(3, more_left.left.length);
BOOST_CHECK_EQUAL(1, more_left.right.length);
VCheck("loin", more_left.right.words[0]);
- BOOST_CHECK(more_left.full);
+ BOOST_CHECK(more_left.left.full);
ChartState shorter;
{
@@ -52,7 +52,7 @@ template <class M> void Short(const M &m) {
BOOST_CHECK_EQUAL(1, shorter.left.length);
BOOST_CHECK_EQUAL(1, shorter.right.length);
VCheck("loin", shorter.right.words[0]);
- BOOST_CHECK(shorter.full);
+ BOOST_CHECK(shorter.left.full);
}
template <class M> void Charge(const M &m) {
@@ -66,7 +66,7 @@ template <class M> void Charge(const M &m) {
BOOST_CHECK_EQUAL(1, base.left.length);
BOOST_CHECK_EQUAL(1, base.right.length);
VCheck("more", base.right.words[0]);
- BOOST_CHECK(base.full);
+ BOOST_CHECK(base.left.full);
ChartState extend;
{
@@ -78,7 +78,7 @@ template <class M> void Charge(const M &m) {
BOOST_CHECK_EQUAL(2, extend.left.length);
BOOST_CHECK_EQUAL(1, extend.right.length);
VCheck("more", extend.right.words[0]);
- BOOST_CHECK(extend.full);
+ BOOST_CHECK(extend.left.full);
ChartState tobos;
{
@@ -91,9 +91,9 @@ template <class M> void Charge(const M &m) {
BOOST_CHECK_EQUAL(1, tobos.right.length);
}
-template <class M> float LeftToRight(const M &m, const std::vector<WordIndex> &words) {
+template <class M> float LeftToRight(const M &m, const std::vector<WordIndex> &words, bool begin_sentence = false) {
float ret = 0.0;
- State right = m.NullContextState();
+ State right = begin_sentence ? m.BeginSentenceState() : m.NullContextState();
for (std::vector<WordIndex>::const_iterator i = words.begin(); i != words.end(); ++i) {
State copy(right);
ret += m.Score(copy, *i, right);
@@ -101,12 +101,12 @@ template <class M> float LeftToRight(const M &m, const std::vector<WordIndex> &w
return ret;
}
-template <class M> float RightToLeft(const M &m, const std::vector<WordIndex> &words) {
+template <class M> float RightToLeft(const M &m, const std::vector<WordIndex> &words, bool begin_sentence = false) {
float ret = 0.0;
ChartState state;
state.left.length = 0;
state.right.length = 0;
- state.full = false;
+ state.left.full = false;
for (std::vector<WordIndex>::const_reverse_iterator i = words.rbegin(); i != words.rend(); ++i) {
ChartState copy(state);
RuleScore<M> score(m, state);
@@ -114,10 +114,17 @@ template <class M> float RightToLeft(const M &m, const std::vector<WordIndex> &w
score.NonTerminal(copy, ret);
ret = score.Finish();
}
+ if (begin_sentence) {
+ ChartState copy(state);
+ RuleScore<M> score(m, state);
+ score.BeginSentence();
+ score.NonTerminal(copy, ret);
+ ret = score.Finish();
+ }
return ret;
}
-template <class M> float TreeMiddle(const M &m, const std::vector<WordIndex> &words) {
+template <class M> float TreeMiddle(const M &m, const std::vector<WordIndex> &words, bool begin_sentence = false) {
std::vector<std::pair<ChartState, float> > states(words.size());
for (unsigned int i = 0; i < words.size(); ++i) {
RuleScore<M> score(m, states[i].first);
@@ -137,7 +144,19 @@ template <class M> float TreeMiddle(const M &m, const std::vector<WordIndex> &wo
}
std::swap(states, upper);
}
- return states.empty() ? 0 : states.back().second;
+
+ if (states.empty()) return 0.0;
+
+ if (begin_sentence) {
+ ChartState ignored;
+ RuleScore<M> score(m, ignored);
+ score.BeginSentence();
+ score.NonTerminal(states.front().first, states.front().second);
+ return score.Finish();
+ } else {
+ return states.front().second;
+ }
+
}
template <class M> void LookupVocab(const M &m, const StringPiece &str, std::vector<WordIndex> &out) {
@@ -148,16 +167,15 @@ template <class M> void LookupVocab(const M &m, const StringPiece &str, std::vec
}
#define TEXT_TEST(str) \
-{ \
- std::vector<WordIndex> words; \
LookupVocab(m, str, words); \
- float expect = LeftToRight(m, words); \
- BOOST_CHECK_CLOSE(expect, RightToLeft(m, words), 0.001); \
- BOOST_CHECK_CLOSE(expect, TreeMiddle(m, words), 0.001); \
-}
+ expect = LeftToRight(m, words, rest); \
+ BOOST_CHECK_CLOSE(expect, RightToLeft(m, words, rest), 0.001); \
+ BOOST_CHECK_CLOSE(expect, TreeMiddle(m, words, rest), 0.001); \
// Build sentences, or parts thereof, from right to left.
-template <class M> void GrowBig(const M &m) {
+template <class M> void GrowBig(const M &m, bool rest = false) {
+ std::vector<WordIndex> words;
+ float expect;
TEXT_TEST("in biarritz watching considering looking . on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown </s>");
TEXT_TEST("on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown </s>");
TEXT_TEST("on a little more loin also would consider higher to look good");
@@ -171,6 +189,14 @@ template <class M> void GrowBig(const M &m) {
TEXT_TEST("consider higher");
}
+template <class M> void GrowSmall(const M &m, bool rest = false) {
+ std::vector<WordIndex> words;
+ float expect;
+ TEXT_TEST("in biarritz watching considering looking . </s>");
+ TEXT_TEST("in biarritz watching considering looking .");
+ TEXT_TEST("in biarritz");
+}
+
template <class M> void AlsoWouldConsiderHigher(const M &m) {
ChartState also;
{
@@ -210,7 +236,7 @@ template <class M> void AlsoWouldConsiderHigher(const M &m) {
}
BOOST_CHECK_EQUAL(1, consider.left.length);
BOOST_CHECK_EQUAL(1, consider.right.length);
- BOOST_CHECK(!consider.full);
+ BOOST_CHECK(!consider.left.full);
ChartState higher;
float higher_score;
@@ -222,7 +248,7 @@ template <class M> void AlsoWouldConsiderHigher(const M &m) {
BOOST_CHECK_CLOSE(-1.509559, higher_score, 0.001);
BOOST_CHECK_EQUAL(1, higher.left.length);
BOOST_CHECK_EQUAL(1, higher.right.length);
- BOOST_CHECK(!higher.full);
+ BOOST_CHECK(!higher.left.full);
VCheck("higher", higher.right.words[0]);
BOOST_CHECK_CLOSE(-0.30103, higher.right.backoff[0], 0.001);
@@ -234,7 +260,7 @@ template <class M> void AlsoWouldConsiderHigher(const M &m) {
BOOST_CHECK_CLOSE(-1.509559 - 1.687872 - 0.30103, score.Finish(), 0.001);
}
BOOST_CHECK_EQUAL(2, consider_higher.left.length);
- BOOST_CHECK(!consider_higher.full);
+ BOOST_CHECK(!consider_higher.left.full);
ChartState full;
{
@@ -246,12 +272,6 @@ template <class M> void AlsoWouldConsiderHigher(const M &m) {
BOOST_CHECK_EQUAL(4, full.right.length);
}
-template <class M> void GrowSmall(const M &m) {
- TEXT_TEST("in biarritz watching considering looking . </s>");
- TEXT_TEST("in biarritz watching considering looking .");
- TEXT_TEST("in biarritz");
-}
-
#define CHECK_SCORE(str, val) \
{ \
float got = val; \
@@ -315,7 +335,7 @@ template <class M> void FullGrow(const M &m) {
CHECK_SCORE("looking . </s>", l2_scores[1] = score.Finish());
}
BOOST_CHECK_EQUAL(l2[1].left.length, 1);
- BOOST_CHECK(l2[1].full);
+ BOOST_CHECK(l2[1].left.full);
ChartState top;
{
@@ -362,6 +382,13 @@ BOOST_AUTO_TEST_CASE(ArrayTrieAll) {
Everything<ArrayTrieModel>();
}
+BOOST_AUTO_TEST_CASE(RestProbing) {
+ Config config;
+ config.messages = NULL;
+ RestProbingModel m(FileLocation(), config);
+ GrowBig(m, true);
+}
+
} // namespace
} // namespace ngram
} // namespace lm