diff options
author | Kenneth Heafield <github@kheafield.com> | 2012-09-12 15:07:44 +0100 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2012-09-12 15:07:44 +0100 |
commit | 173910593bf6bf3dc52902f99a683560d8c73942 (patch) | |
tree | 746ea920283178f3bf6b7f86e7b9e6b195821676 /klm/alone/main.cc | |
parent | 143ba7317dcaee3058d66f9e6558316f88f95212 (diff) |
Add the alone stuff, using a wrapper to the edge class.
Diffstat (limited to 'klm/alone/main.cc')
-rw-r--r-- | klm/alone/main.cc | 84 |
1 files changed, 84 insertions, 0 deletions
diff --git a/klm/alone/main.cc b/klm/alone/main.cc new file mode 100644 index 00000000..7768b89c --- /dev/null +++ b/klm/alone/main.cc @@ -0,0 +1,84 @@ +#include "alone/threading.hh" +#include "search/config.hh" +#include "search/context.hh" +#include "util/exception.hh" +#include "util/file_piece.hh" +#include "util/usage.hh" + +#include <boost/lexical_cast.hpp> + +#include <iostream> +#include <memory> + +namespace alone { + +template <class Control> void ReadLoop(const std::string &graph_prefix, Control &control) { + for (unsigned int sentence = 0; ; ++sentence) { + std::stringstream name; + name << graph_prefix << '/' << sentence; + std::auto_ptr<util::FilePiece> file; + try { + file.reset(new util::FilePiece(name.str().c_str())); + } catch (const util::ErrnoException &e) { + if (e.Error() == ENOENT) return; + throw; + } + control.Add(file.release()); + } +} + +template <class Model> void RunWithModelType(const char *graph_prefix, const char *model_file, StringPiece weight_str, unsigned int pop_limit, unsigned int threads) { + Model model(model_file); + search::Config config(weight_str, pop_limit); + + if (threads > 1) { +#ifdef WITH_THREADS + Controller<Model> controller(config, model, threads, std::cout); + ReadLoop(graph_prefix, controller); +#else + UTIL_THROW(util::Exception, "Threading support not compiled in."); +#endif + } else { + InThread<Model> controller(config, model, std::cout); + ReadLoop(graph_prefix, controller); + } +} + +void Run(const char *graph_prefix, const char *lm_name, StringPiece weight_str, unsigned int pop_limit, unsigned int threads) { + lm::ngram::ModelType model_type; + if (!lm::ngram::RecognizeBinary(lm_name, model_type)) model_type = lm::ngram::PROBING; + switch (model_type) { + case lm::ngram::PROBING: + RunWithModelType<lm::ngram::ProbingModel>(graph_prefix, lm_name, weight_str, pop_limit, threads); + break; + case lm::ngram::REST_PROBING: + RunWithModelType<lm::ngram::RestProbingModel>(graph_prefix, lm_name, weight_str, pop_limit, threads); + break; + default: + UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); + } +} + +} // namespace alone + +int main(int argc, char *argv[]) { + if (argc < 5 || argc > 6) { + std::cerr << argv[0] << " graph_prefix lm \"weights\" pop [threads]" << std::endl; + return 1; + } + +#ifdef WITH_THREADS + unsigned thread_count = boost::thread::hardware_concurrency(); +#else + unsigned thread_count = 1; +#endif + if (argc == 6) { + thread_count = boost::lexical_cast<unsigned>(argv[5]); + UTIL_THROW_IF(!thread_count, util::Exception, "Thread count 0"); + } + UTIL_THROW_IF(!thread_count, util::Exception, "Boost doesn't know how many threads there are. Pass it on the command line."); + alone::Run(argv[1], argv[2], argv[3], boost::lexical_cast<unsigned int>(argv[4]), thread_count); + + util::PrintUsage(std::cerr); + return 0; +} |