diff options
Diffstat (limited to 'klm/lm/filter/filter_main.cc')
-rw-r--r-- | klm/lm/filter/filter_main.cc | 155 |
1 files changed, 80 insertions, 75 deletions
diff --git a/klm/lm/filter/filter_main.cc b/klm/lm/filter/filter_main.cc index 1736bc40..82fdc1ef 100644 --- a/klm/lm/filter/filter_main.cc +++ b/klm/lm/filter/filter_main.cc @@ -6,6 +6,7 @@ #endif #include "lm/filter/vocab.hh" #include "lm/filter/wrapper.hh" +#include "util/exception.hh" #include "util/file_piece.hh" #include <boost/ptr_container/ptr_vector.hpp> @@ -157,92 +158,96 @@ template <class Format> void DispatchFilterModes(const Config &config, std::istr } // namespace lm int main(int argc, char *argv[]) { - if (argc < 4) { - lm::DisplayHelp(argv[0]); - return 1; - } + try { + if (argc < 4) { + lm::DisplayHelp(argv[0]); + return 1; + } - // I used to have boost::program_options, but some users didn't want to compile boost. - lm::Config config; - config.mode = lm::MODE_UNSET; - for (int i = 1; i < argc - 2; ++i) { - const char *str = argv[i]; - if (!std::strcmp(str, "copy")) { - config.mode = lm::MODE_COPY; - } else if (!std::strcmp(str, "single")) { - config.mode = lm::MODE_SINGLE; - } else if (!std::strcmp(str, "multiple")) { - config.mode = lm::MODE_MULTIPLE; - } else if (!std::strcmp(str, "union")) { - config.mode = lm::MODE_UNION; - } else if (!std::strcmp(str, "phrase")) { - config.phrase = true; - } else if (!std::strcmp(str, "context")) { - config.context = true; - } else if (!std::strcmp(str, "arpa")) { - config.format = lm::FORMAT_ARPA; - } else if (!std::strcmp(str, "raw")) { - config.format = lm::FORMAT_COUNT; + // I used to have boost::program_options, but some users didn't want to compile boost. + lm::Config config; + config.mode = lm::MODE_UNSET; + for (int i = 1; i < argc - 2; ++i) { + const char *str = argv[i]; + if (!std::strcmp(str, "copy")) { + config.mode = lm::MODE_COPY; + } else if (!std::strcmp(str, "single")) { + config.mode = lm::MODE_SINGLE; + } else if (!std::strcmp(str, "multiple")) { + config.mode = lm::MODE_MULTIPLE; + } else if (!std::strcmp(str, "union")) { + config.mode = lm::MODE_UNION; + } else if (!std::strcmp(str, "phrase")) { + config.phrase = true; + } else if (!std::strcmp(str, "context")) { + config.context = true; + } else if (!std::strcmp(str, "arpa")) { + config.format = lm::FORMAT_ARPA; + } else if (!std::strcmp(str, "raw")) { + config.format = lm::FORMAT_COUNT; #ifndef NTHREAD - } else if (!std::strncmp(str, "threads:", 8)) { - config.threads = boost::lexical_cast<size_t>(str + 8); - if (!config.threads) { - std::cerr << "Specify at least one thread." << std::endl; + } else if (!std::strncmp(str, "threads:", 8)) { + config.threads = boost::lexical_cast<size_t>(str + 8); + if (!config.threads) { + std::cerr << "Specify at least one thread." << std::endl; + return 1; + } + } else if (!std::strncmp(str, "batch_size:", 11)) { + config.batch_size = boost::lexical_cast<size_t>(str + 11); + if (config.batch_size < 5000) { + std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl; + if (!config.batch_size) return 1; + } +#endif + } else { + lm::DisplayHelp(argv[0]); return 1; } - } else if (!std::strncmp(str, "batch_size:", 11)) { - config.batch_size = boost::lexical_cast<size_t>(str + 11); - if (config.batch_size < 5000) { - std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl; - if (!config.batch_size) return 1; - } -#endif - } else { + } + + if (config.mode == lm::MODE_UNSET) { lm::DisplayHelp(argv[0]); return 1; } - } - - if (config.mode == lm::MODE_UNSET) { - lm::DisplayHelp(argv[0]); - return 1; - } - if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) { - std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl; - return 1; - } + if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) { + std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl; + return 1; + } - bool cmd_is_model = true; - const char *cmd_input = argv[argc - 2]; - if (!strncmp(cmd_input, "vocab:", 6)) { - cmd_is_model = false; - cmd_input += 6; - } else if (!strncmp(cmd_input, "model:", 6)) { - cmd_input += 6; - } else if (strchr(cmd_input, ':')) { - errx(1, "Specify vocab: or model: before the input file name, not \"%s\"", cmd_input); - } else { - std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl; - } - std::ifstream cmd_file; - std::istream *vocab; - if (cmd_is_model) { - vocab = &std::cin; - } else { - cmd_file.open(cmd_input, std::ios::in); - if (!cmd_file) { - err(2, "Could not open input file %s", cmd_input); + bool cmd_is_model = true; + const char *cmd_input = argv[argc - 2]; + if (!strncmp(cmd_input, "vocab:", 6)) { + cmd_is_model = false; + cmd_input += 6; + } else if (!strncmp(cmd_input, "model:", 6)) { + cmd_input += 6; + } else if (strchr(cmd_input, ':')) { + std::cerr << "Specify vocab: or model: before the input file name, not " << cmd_input << std::endl; + return 1; + } else { + std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl; + } + std::ifstream cmd_file; + std::istream *vocab; + if (cmd_is_model) { + vocab = &std::cin; + } else { + cmd_file.open(cmd_input, std::ios::in); + UTIL_THROW_IF(!cmd_file, util::ErrnoException, "Failed to open " << cmd_input); + vocab = &cmd_file; } - vocab = &cmd_file; - } - util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr); + util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr); - if (config.format == lm::FORMAT_ARPA) { - lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]); - } else if (config.format == lm::FORMAT_COUNT) { - lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]); + if (config.format == lm::FORMAT_ARPA) { + lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]); + } else if (config.format == lm::FORMAT_COUNT) { + lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]); + } + return 0; + } catch (const std::exception &e) { + std::cerr << e.what() << std::endl; + return 1; } - return 0; } |