diff options
author | Kenneth Heafield <github@kheafield.com> | 2012-05-16 13:24:08 -0700 |
---|---|---|
committer | Chris Dyer <cdyer@cab.ark.cs.cmu.edu> | 2012-05-26 22:59:54 -0400 |
commit | 149232c38eec558ddb1097698d1570aacb67b59f (patch) | |
tree | 5860b4d6f681eeb04a1020cbb2fe7e6ac394af99 /klm/lm/left_test.cc | |
parent | 01ecc09f8e3a82c32bf7dd2f90c12554becea71d (diff) |
Big kenlm change includes lower order models for probing only. And other stuff.
Diffstat (limited to 'klm/lm/left_test.cc')
-rw-r--r-- | klm/lm/left_test.cc | 83 |
1 files changed, 55 insertions, 28 deletions
diff --git a/klm/lm/left_test.cc b/klm/lm/left_test.cc index c85e5efa..b23e6a0f 100644 --- a/klm/lm/left_test.cc +++ b/klm/lm/left_test.cc @@ -24,7 +24,7 @@ template <class M> void Short(const M &m) { Term("loin"); BOOST_CHECK_CLOSE(-1.206319 - 0.3561665, score.Finish(), 0.001); } - BOOST_CHECK(base.full); + BOOST_CHECK(base.left.full); BOOST_CHECK_EQUAL(2, base.left.length); BOOST_CHECK_EQUAL(1, base.right.length); VCheck("loin", base.right.words[0]); @@ -40,7 +40,7 @@ template <class M> void Short(const M &m) { BOOST_CHECK_EQUAL(3, more_left.left.length); BOOST_CHECK_EQUAL(1, more_left.right.length); VCheck("loin", more_left.right.words[0]); - BOOST_CHECK(more_left.full); + BOOST_CHECK(more_left.left.full); ChartState shorter; { @@ -52,7 +52,7 @@ template <class M> void Short(const M &m) { BOOST_CHECK_EQUAL(1, shorter.left.length); BOOST_CHECK_EQUAL(1, shorter.right.length); VCheck("loin", shorter.right.words[0]); - BOOST_CHECK(shorter.full); + BOOST_CHECK(shorter.left.full); } template <class M> void Charge(const M &m) { @@ -66,7 +66,7 @@ template <class M> void Charge(const M &m) { BOOST_CHECK_EQUAL(1, base.left.length); BOOST_CHECK_EQUAL(1, base.right.length); VCheck("more", base.right.words[0]); - BOOST_CHECK(base.full); + BOOST_CHECK(base.left.full); ChartState extend; { @@ -78,7 +78,7 @@ template <class M> void Charge(const M &m) { BOOST_CHECK_EQUAL(2, extend.left.length); BOOST_CHECK_EQUAL(1, extend.right.length); VCheck("more", extend.right.words[0]); - BOOST_CHECK(extend.full); + BOOST_CHECK(extend.left.full); ChartState tobos; { @@ -91,9 +91,9 @@ template <class M> void Charge(const M &m) { BOOST_CHECK_EQUAL(1, tobos.right.length); } -template <class M> float LeftToRight(const M &m, const std::vector<WordIndex> &words) { +template <class M> float LeftToRight(const M &m, const std::vector<WordIndex> &words, bool begin_sentence = false) { float ret = 0.0; - State right = m.NullContextState(); + State right = begin_sentence ? m.BeginSentenceState() : m.NullContextState(); for (std::vector<WordIndex>::const_iterator i = words.begin(); i != words.end(); ++i) { State copy(right); ret += m.Score(copy, *i, right); @@ -101,12 +101,12 @@ template <class M> float LeftToRight(const M &m, const std::vector<WordIndex> &w return ret; } -template <class M> float RightToLeft(const M &m, const std::vector<WordIndex> &words) { +template <class M> float RightToLeft(const M &m, const std::vector<WordIndex> &words, bool begin_sentence = false) { float ret = 0.0; ChartState state; state.left.length = 0; state.right.length = 0; - state.full = false; + state.left.full = false; for (std::vector<WordIndex>::const_reverse_iterator i = words.rbegin(); i != words.rend(); ++i) { ChartState copy(state); RuleScore<M> score(m, state); @@ -114,10 +114,17 @@ template <class M> float RightToLeft(const M &m, const std::vector<WordIndex> &w score.NonTerminal(copy, ret); ret = score.Finish(); } + if (begin_sentence) { + ChartState copy(state); + RuleScore<M> score(m, state); + score.BeginSentence(); + score.NonTerminal(copy, ret); + ret = score.Finish(); + } return ret; } -template <class M> float TreeMiddle(const M &m, const std::vector<WordIndex> &words) { +template <class M> float TreeMiddle(const M &m, const std::vector<WordIndex> &words, bool begin_sentence = false) { std::vector<std::pair<ChartState, float> > states(words.size()); for (unsigned int i = 0; i < words.size(); ++i) { RuleScore<M> score(m, states[i].first); @@ -137,7 +144,19 @@ template <class M> float TreeMiddle(const M &m, const std::vector<WordIndex> &wo } std::swap(states, upper); } - return states.empty() ? 0 : states.back().second; + + if (states.empty()) return 0.0; + + if (begin_sentence) { + ChartState ignored; + RuleScore<M> score(m, ignored); + score.BeginSentence(); + score.NonTerminal(states.front().first, states.front().second); + return score.Finish(); + } else { + return states.front().second; + } + } template <class M> void LookupVocab(const M &m, const StringPiece &str, std::vector<WordIndex> &out) { @@ -148,16 +167,15 @@ template <class M> void LookupVocab(const M &m, const StringPiece &str, std::vec } #define TEXT_TEST(str) \ -{ \ - std::vector<WordIndex> words; \ LookupVocab(m, str, words); \ - float expect = LeftToRight(m, words); \ - BOOST_CHECK_CLOSE(expect, RightToLeft(m, words), 0.001); \ - BOOST_CHECK_CLOSE(expect, TreeMiddle(m, words), 0.001); \ -} + expect = LeftToRight(m, words, rest); \ + BOOST_CHECK_CLOSE(expect, RightToLeft(m, words, rest), 0.001); \ + BOOST_CHECK_CLOSE(expect, TreeMiddle(m, words, rest), 0.001); \ // Build sentences, or parts thereof, from right to left. -template <class M> void GrowBig(const M &m) { +template <class M> void GrowBig(const M &m, bool rest = false) { + std::vector<WordIndex> words; + float expect; TEXT_TEST("in biarritz watching considering looking . on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown </s>"); TEXT_TEST("on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown </s>"); TEXT_TEST("on a little more loin also would consider higher to look good"); @@ -171,6 +189,14 @@ template <class M> void GrowBig(const M &m) { TEXT_TEST("consider higher"); } +template <class M> void GrowSmall(const M &m, bool rest = false) { + std::vector<WordIndex> words; + float expect; + TEXT_TEST("in biarritz watching considering looking . </s>"); + TEXT_TEST("in biarritz watching considering looking ."); + TEXT_TEST("in biarritz"); +} + template <class M> void AlsoWouldConsiderHigher(const M &m) { ChartState also; { @@ -210,7 +236,7 @@ template <class M> void AlsoWouldConsiderHigher(const M &m) { } BOOST_CHECK_EQUAL(1, consider.left.length); BOOST_CHECK_EQUAL(1, consider.right.length); - BOOST_CHECK(!consider.full); + BOOST_CHECK(!consider.left.full); ChartState higher; float higher_score; @@ -222,7 +248,7 @@ template <class M> void AlsoWouldConsiderHigher(const M &m) { BOOST_CHECK_CLOSE(-1.509559, higher_score, 0.001); BOOST_CHECK_EQUAL(1, higher.left.length); BOOST_CHECK_EQUAL(1, higher.right.length); - BOOST_CHECK(!higher.full); + BOOST_CHECK(!higher.left.full); VCheck("higher", higher.right.words[0]); BOOST_CHECK_CLOSE(-0.30103, higher.right.backoff[0], 0.001); @@ -234,7 +260,7 @@ template <class M> void AlsoWouldConsiderHigher(const M &m) { BOOST_CHECK_CLOSE(-1.509559 - 1.687872 - 0.30103, score.Finish(), 0.001); } BOOST_CHECK_EQUAL(2, consider_higher.left.length); - BOOST_CHECK(!consider_higher.full); + BOOST_CHECK(!consider_higher.left.full); ChartState full; { @@ -246,12 +272,6 @@ template <class M> void AlsoWouldConsiderHigher(const M &m) { BOOST_CHECK_EQUAL(4, full.right.length); } -template <class M> void GrowSmall(const M &m) { - TEXT_TEST("in biarritz watching considering looking . </s>"); - TEXT_TEST("in biarritz watching considering looking ."); - TEXT_TEST("in biarritz"); -} - #define CHECK_SCORE(str, val) \ { \ float got = val; \ @@ -315,7 +335,7 @@ template <class M> void FullGrow(const M &m) { CHECK_SCORE("looking . </s>", l2_scores[1] = score.Finish()); } BOOST_CHECK_EQUAL(l2[1].left.length, 1); - BOOST_CHECK(l2[1].full); + BOOST_CHECK(l2[1].left.full); ChartState top; { @@ -362,6 +382,13 @@ BOOST_AUTO_TEST_CASE(ArrayTrieAll) { Everything<ArrayTrieModel>(); } +BOOST_AUTO_TEST_CASE(RestProbing) { + Config config; + config.messages = NULL; + RestProbingModel m(FileLocation(), config); + GrowBig(m, true); +} + } // namespace } // namespace ngram } // namespace lm |