summaryrefslogtreecommitdiff
path: root/klm/lm/model_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/model_test.cc')
-rw-r--r--klm/lm/model_test.cc24
1 files changed, 19 insertions, 5 deletions
diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc
index 2654071f..461704d4 100644
--- a/klm/lm/model_test.cc
+++ b/klm/lm/model_test.cc
@@ -19,6 +19,20 @@ std::ostream &operator<<(std::ostream &o, const State &state) {
namespace {
+const char *TestLocation() {
+ if (boost::unit_test::framework::master_test_suite().argc < 2) {
+ return "test.arpa";
+ }
+ return boost::unit_test::framework::master_test_suite().argv[1];
+}
+const char *TestNoUnkLocation() {
+ if (boost::unit_test::framework::master_test_suite().argc < 3) {
+ return "test_nounk.arpa";
+ }
+ return boost::unit_test::framework::master_test_suite().argv[2];
+
+}
+
#define StartTest(word, ngram, score, indep_left) \
ret = model.FullScore( \
state, \
@@ -307,7 +321,7 @@ template <class ModelT> void LoadingTest() {
{
ExpectEnumerateVocab enumerate;
config.enumerate_vocab = &enumerate;
- ModelT m("test.arpa", config);
+ ModelT m(TestLocation(), config);
enumerate.Check(m.GetVocabulary());
BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound());
Everything(m);
@@ -315,7 +329,7 @@ template <class ModelT> void LoadingTest() {
{
ExpectEnumerateVocab enumerate;
config.enumerate_vocab = &enumerate;
- ModelT m("test_nounk.arpa", config);
+ ModelT m(TestNoUnkLocation(), config);
enumerate.Check(m.GetVocabulary());
BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound());
NoUnkCheck(m);
@@ -346,7 +360,7 @@ template <class ModelT> void BinaryTest() {
config.enumerate_vocab = &enumerate;
{
- ModelT copy_model("test.arpa", config);
+ ModelT copy_model(TestLocation(), config);
enumerate.Check(copy_model.GetVocabulary());
enumerate.Clear();
Everything(copy_model);
@@ -370,14 +384,14 @@ template <class ModelT> void BinaryTest() {
config.messages = NULL;
enumerate.Clear();
{
- ModelT copy_model("test_nounk.arpa", config);
+ ModelT copy_model(TestNoUnkLocation(), config);
enumerate.Check(copy_model.GetVocabulary());
enumerate.Clear();
NoUnkCheck(copy_model);
}
config.write_mmap = NULL;
{
- ModelT binary("test_nounk.binary", config);
+ ModelT binary(TestNoUnkLocation(), config);
enumerate.Check(binary.GetVocabulary());
NoUnkCheck(binary);
}