diff options
Diffstat (limited to 'klm/lm/model_test.cc')
-rw-r--r-- | klm/lm/model_test.cc | 24 |
1 files changed, 19 insertions, 5 deletions
diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 2654071f..461704d4 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -19,6 +19,20 @@ std::ostream &operator<<(std::ostream &o, const State &state) { namespace { +const char *TestLocation() { + if (boost::unit_test::framework::master_test_suite().argc < 2) { + return "test.arpa"; + } + return boost::unit_test::framework::master_test_suite().argv[1]; +} +const char *TestNoUnkLocation() { + if (boost::unit_test::framework::master_test_suite().argc < 3) { + return "test_nounk.arpa"; + } + return boost::unit_test::framework::master_test_suite().argv[2]; + +} + #define StartTest(word, ngram, score, indep_left) \ ret = model.FullScore( \ state, \ @@ -307,7 +321,7 @@ template <class ModelT> void LoadingTest() { { ExpectEnumerateVocab enumerate; config.enumerate_vocab = &enumerate; - ModelT m("test.arpa", config); + ModelT m(TestLocation(), config); enumerate.Check(m.GetVocabulary()); BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound()); Everything(m); @@ -315,7 +329,7 @@ template <class ModelT> void LoadingTest() { { ExpectEnumerateVocab enumerate; config.enumerate_vocab = &enumerate; - ModelT m("test_nounk.arpa", config); + ModelT m(TestNoUnkLocation(), config); enumerate.Check(m.GetVocabulary()); BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound()); NoUnkCheck(m); @@ -346,7 +360,7 @@ template <class ModelT> void BinaryTest() { config.enumerate_vocab = &enumerate; { - ModelT copy_model("test.arpa", config); + ModelT copy_model(TestLocation(), config); enumerate.Check(copy_model.GetVocabulary()); enumerate.Clear(); Everything(copy_model); @@ -370,14 +384,14 @@ template <class ModelT> void BinaryTest() { config.messages = NULL; enumerate.Clear(); { - ModelT copy_model("test_nounk.arpa", config); + ModelT copy_model(TestNoUnkLocation(), config); enumerate.Check(copy_model.GetVocabulary()); enumerate.Clear(); NoUnkCheck(copy_model); } config.write_mmap = NULL; { - ModelT binary("test_nounk.binary", config); + ModelT binary(TestNoUnkLocation(), config); enumerate.Check(binary.GetVocabulary()); NoUnkCheck(binary); } |