Diffstat (limited to 'klm/lm/builder')
 -rw-r--r--  klm/lm/builder/corpus_count.cc |  9
 -rw-r--r--  klm/lm/builder/lmplz_main.cc   | 33
 -rw-r--r--  klm/lm/builder/pipeline.cc     |  1
3 files changed, 35 insertions, 8 deletions
diff --git a/klm/lm/builder/corpus_count.cc b/klm/lm/builder/corpus_count.cc
index aea93ad1..ccc06efc 100644
--- a/klm/lm/builder/corpus_count.cc
+++ b/klm/lm/builder/corpus_count.cc
@@ -238,12 +238,17 @@ void CorpusCount::Run(const util::stream::ChainPosition &position) {
   const WordIndex end_sentence = vocab.Lookup("</s>");
   Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
   uint64_t count = 0;
-  StringPiece delimiters("\0\t\r ", 4);
+  bool delimiters[256];
+  memset(delimiters, 0, sizeof(delimiters));
+  const char kDelimiterSet[] = "\0\t\n\r ";
+  for (const char *i = kDelimiterSet; i < kDelimiterSet + sizeof(kDelimiterSet); ++i) {
+    delimiters[static_cast<unsigned char>(*i)] = true;
+  }
   try {
     while(true) {
       StringPiece line(from_.ReadLine());
       writer.StartSentence();
-      for (util::TokenIter<util::AnyCharacter, true> w(line, delimiters); w; ++w) {
+      for (util::TokenIter<util::BoolCharacter, true> w(line, delimiters); w; ++w) {
         WordIndex word = vocab.Lookup(*w);
         UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus.  I plan to support models containing <unk> in the future.");
         writer.Append(word);
diff --git a/klm/lm/builder/lmplz_main.cc b/klm/lm/builder/lmplz_main.cc
index c87abdb8..2563deed 100644
--- a/klm/lm/builder/lmplz_main.cc
+++ b/klm/lm/builder/lmplz_main.cc
@@ -33,7 +33,10 @@ int main(int argc, char *argv[]) {
     po::options_description options("Language model building options");
     lm::builder::PipelineConfig pipeline;
 
+    std::string text, arpa;
+
     options.add_options()
+      ("help", po::bool_switch(), "Show this help message")
       ("order,o", po::value<std::size_t>(&pipeline.order)
 #if BOOST_VERSION >= 104200
         ->required()
@@ -47,8 +50,13 @@
       ("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table")
       ("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)")
       ("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file")
-      ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.");
-    if (argc == 1) {
+      ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.")
+      ("text", po::value<std::string>(&text), "Read text from a file instead of stdin")
+      ("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout");
+    po::variables_map vm;
+    po::store(po::parse_command_line(argc, argv, options), vm);
+
+    if (argc == 1 || vm["help"].as<bool>()) {
       std::cerr <<
         "Builds unpruned language models with modified Kneser-Ney smoothing.\n\n"
         "Please cite:\n"
@@ -66,12 +74,17 @@
         "setting the temporary file location (-T) and sorting memory (-S) is recommended.\n\n"
         "Memory sizes are specified like GNU sort: a number followed by a unit character.\n"
         "Valid units are \% for percentage of memory (supported platforms only) and (in\n"
-        "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y.  Default is K (*1024).\n\n";
+        "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y.  Default is K (*1024).\n";
+      uint64_t mem = util::GuessPhysicalMemory();
+      if (mem) {
+        std::cerr << "This machine has " << mem << " bytes of memory.\n\n";
+      } else {
+        std::cerr << "Unable to determine the amount of memory on this machine.\n\n";
+      }
       std::cerr << options << std::endl;
       return 1;
     }
-    po::variables_map vm;
-    po::store(po::parse_command_line(argc, argv, options), vm);
+
     po::notify(vm);
 
     // required() appeared in Boost 1.42.0.
@@ -92,9 +105,17 @@
     initial.adder_out.block_count = 2;
     pipeline.read_backoffs = initial.adder_out;
 
+    util::scoped_fd in(0), out(1);
+    if (vm.count("text")) {
+      in.reset(util::OpenReadOrThrow(text.c_str()));
+    }
+    if (vm.count("arpa")) {
+      out.reset(util::CreateOrThrow(arpa.c_str()));
+    }
+
     // Read from stdin
     try {
-      lm::builder::Pipeline(pipeline, 0, 1);
+      lm::builder::Pipeline(pipeline, in.release(), out.release());
     } catch (const util::MallocException &e) {
       std::cerr << e.what() << std::endl;
       std::cerr << "Try rerunning with a more conservative -S setting than " << vm["memory"].as<std::string>() << std::endl;
diff --git a/klm/lm/builder/pipeline.cc b/klm/lm/builder/pipeline.cc
index b89ea6ba..44a2313c 100644
--- a/klm/lm/builder/pipeline.cc
+++ b/klm/lm/builder/pipeline.cc
@@ -226,6 +226,7 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m
 
   util::stream::Sort<SuffixOrder, AddCombiner> sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner());
   chain.Wait(true);
+  std::cerr << "Unigram tokens " << token_count << " types " << type_count << std::endl;
   std::cerr << "=== 2/5 Calculating and sorting adjusted counts ===" << std::endl;
   master.InitForAdjust(sorter, type_count);
 }

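A note on the lmplz_main.cc portion of the diff: --help, --text, and --arpa are new options, and the usage message now reports physical memory via util::GuessPhysicalMemory(). po::store is moved ahead of the usage check so that vm["help"] can be consulted, while po::notify (which enforces the required() constraint on --order in Boost >= 1.42) still runs afterward, so --help works without any other arguments. The --text and --arpa paths are opened with util::OpenReadOrThrow and util::CreateOrThrow; the util::scoped_fd wrappers default to the stdin/stdout descriptors 0 and 1, and release() hands the raw descriptors to lm::builder::Pipeline, presumably leaving the pipeline responsible for them. A typical invocation would then look something like lmplz -o 5 --text corpus.txt --arpa model.arpa (file names here are illustrative). The pipeline.cc hunk simply prints the unigram token and type counts once the corpus-count step finishes.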