From b1ed81ef3216b212295afa76c5d20a56fb647204 Mon Sep 17 00:00:00 2001
From: Chris Dyer
Date: Mon, 13 Oct 2014 00:42:37 -0400
Subject: new kenlm

---
 klm/lm/interpolate/example_sort_main.cc | 144 ++++++++++++++++++++++++++++++++
 1 file changed, 144 insertions(+)
 create mode 100644 klm/lm/interpolate/example_sort_main.cc

diff --git a/klm/lm/interpolate/example_sort_main.cc b/klm/lm/interpolate/example_sort_main.cc
new file mode 100644
index 00000000..4282255e
--- /dev/null
+++ b/klm/lm/interpolate/example_sort_main.cc
@@ -0,0 +1,144 @@
+#include "lm/interpolate/arpa_to_stream.hh"
+
+#include "lm/builder/print.hh"
+#include "lm/builder/sort.hh"
+#include "lm/vocab.hh"
+#include "util/file.hh"
+#include "util/unistd.hh"
+
+
+int main() {
+
+  // TODO: Make these all command-line parameters
+  const std::size_t ONE_GB = 1 << 30;
+  const std::size_t SIXTY_FOUR_MB = 1 << 26;
+  const std::size_t NUMBER_OF_BLOCKS = 2;
+
+  // Vocab strings will be written to this file, forgotten, and reconstituted
+  // later.  This saves memory.
+  util::scoped_fd vocab_file(util::MakeTemp("/tmp/"));
+  std::vector<uint64_t> counts;
+  util::stream::Chains chains;
+  {
+    // Use consistent vocab ids across models.
+    lm::ngram::GrowableVocab vocab(10, vocab_file.get());
+    lm::interpolate::ARPAToStream reader(STDIN_FILENO, vocab);
+    counts = reader.Counts();
+
+    // Configure a chain for each order.  TODO: extract chain balance heuristics from lm/builder/pipeline.cc
+    chains.Init(reader.Order());
+
+    for (std::size_t i = 0; i < reader.Order(); ++i) {
+
+      // The following call to chains.push_back() invokes the Chain constructor
+      // and appends the newly created Chain object to the chains array.
+      chains.push_back(util::stream::ChainConfig(lm::builder::NGram::TotalSize(i + 1), NUMBER_OF_BLOCKS, ONE_GB));
+
+    }
+
+    // The following call to the >> operator of chains
+    // constructs a ChainPosition for each chain in chains using Chain::Add();
+    // that function begins with a call to Chain::Start(),
+    // which allocates memory for the chain.
+    //
+    // After the following call to the >> operator of chains,
+    // a new thread will be running
+    // and will be executing the reader.Run() method
+    // to read through the body of the ARPA file from standard input.
+    //
+    // For each n-gram line in the ARPA file,
+    // the thread executing reader.Run()
+    // will write the probability, the n-gram, and the backoff
+    // to the appropriate location in the appropriate chain
+    // (for details, see the ReadNGram() method in read_arpa.hh).
+    //
+    // Normally >> copies the worker and then runs it, so an inline >> works.  But here we want a reference to reader.
+    chains >> boost::ref(reader);
+
+
+    util::stream::SortConfig sort_config;
+    sort_config.temp_prefix = "/tmp/";
+    sort_config.buffer_size = SIXTY_FOUR_MB;
+    sort_config.total_memory = ONE_GB;
+
+    // Parallel sorts across orders (though somewhat limited because ARPA files are not being read in parallel across orders).
+    lm::builder::Sorts<lm::builder::SuffixOrder> sorts(reader.Order());
+    for (std::size_t i = 0; i < reader.Order(); ++i) {
+
+      // The following call to sorts.push_back() invokes the Sort constructor
+      // and appends the newly constructed Sort object to the sorts array.
+      //
+      // After the construction of the Sort object,
+      // two new threads will be running (each owned by the chains[i] object).
+      //
+      // The first new thread will execute BlockSorter.Run() to sort the n-gram entries of order (i + 1)
+      // that were previously read into chains[i] by the ARPA input reader thread.
+      //
+      // The second new thread will execute WriteAndRecycle.Run()
+      // to write each sorted block of data to disk as a temporary file.
+      sorts.push_back(chains[i], sort_config, lm::builder::SuffixOrder(i + 1));
+
+    }
+
+    // Output to the same chains.
+    for (std::size_t i = 0; i < reader.Order(); ++i) {
+
+      // The following call to Chain::Wait()
+      // joins the threads owned by chains[i].
+      //
+      // As such, the following call won't return
+      // until all threads owned by chains[i] have completed.
+      //
+      // The following call also resets chains[i]
+      // so that it can be reused
+      // (including freeing the memory previously used by the chain).
+      chains[i].Wait();
+
+
+      // In an ideal world (without memory restrictions)
+      // we could merge all of the previously sorted blocks
+      // by reading them all completely into memory
+      // and then running merge sort over them.
+      //
+      // In the real world, we have memory restrictions;
+      // depending on how many blocks we have,
+      // and how much memory we can use to read from each block (sort_config.buffer_size),
+      // it may be the case that we have insufficient memory
+      // to read sort_config.buffer_size bytes of data from each block from disk at once.
+      //
+      // If this occurs, then it will be necessary to perform one or more rounds of merge sort on disk;
+      // doing so will reduce the number of blocks that we will eventually need to read from
+      // when performing the final round of merge sort in memory.
+      //
+      // So, the following call determines whether it is necessary
+      // to perform one or more rounds of merge sort on disk;
+      // if such on-disk merge sorting is required, it is performed here.
+      //
+      // Finally, the following method launches a thread that calls OwningMergingReader.Run()
+      // to perform the final round of merge sort in memory.
+      //
+      // Merge sort could have been invoked directly
+      // so that merge sort memory doesn't coexist with Chain memory.
+      sorts[i].Output(chains[i]);
+    }
+
+    // sorts can go out of scope even though its threads are still writing to the chains.
+    // Note that vocab going out of scope flushes to vocab_file.
+  }
+
+
+  // Get the vocabulary mapping used for this ARPA file.
+  lm::builder::VocabReconstitute reconstitute(vocab_file.get());
+
+  // After the following call to the >> operator of chains,
+  // a new thread will be running
+  // and will be executing the Run() method of PrintARPA
+  // to print the final sorted ARPA file to standard output.
+  chains >> lm::builder::PrintARPA(reconstitute, counts, NULL, STDOUT_FILENO);
+
+  // Joins all threads that chains owns
+  // by iterating over each chain object in chains
+  // and calling chain.Wait() on each one.
+  chains.Wait(true);
+
+}
-- 
cgit v1.2.3
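
A note on the threading contract used throughout this example: any object handed to a chain via >> must expose a Run(const util::stream::ChainPosition &) method, which the chain executes on a thread of its own. The sketch below illustrates that contract in isolation, using only the util::stream pieces already referenced in this patch; it is a minimal, hypothetical illustration, not part of the commit. The class names CountingProducer and SummingConsumer, the record payload, and the sizes are invented for the example.

#include <boost/ref.hpp>

#include <cstddef>
#include <cstring>

#include "util/stream/chain.hh"
#include "util/stream/stream.hh"

// Hypothetical producer: fills the chain with fixed-size std::size_t records,
// then poisons the stream to signal end-of-data, much as reader.Run() does
// for n-gram records after the last ARPA line.
class CountingProducer {
 public:
  explicit CountingProducer(std::size_t entries) : entries_(entries) {}

  // Executed by the chain's worker thread.
  void Run(const util::stream::ChainPosition &position) {
    util::stream::Stream stream(position);
    for (std::size_t i = 0; i < entries_; ++i, ++stream) {
      std::memcpy(stream.Get(), &i, sizeof(std::size_t));
    }
    stream.Poison();
  }

 private:
  std::size_t entries_;
};

// Hypothetical consumer: reads records off the same chain until it reaches the
// poison marker, accumulating a sum so the result is observable.
class SummingConsumer {
 public:
  explicit SummingConsumer(std::size_t &sum) : sum_(sum) {}

  void Run(const util::stream::ChainPosition &position) {
    for (util::stream::Stream stream(position); stream; ++stream) {
      sum_ += *static_cast<const std::size_t*>(stream.Get());
    }
  }

 private:
  std::size_t &sum_;
};

int main() {
  // One chain of 8-byte records, 2 blocks, 1 MB of total memory (illustrative sizes).
  util::stream::Chain chain(util::stream::ChainConfig(sizeof(std::size_t), 2, 1 << 20));
  CountingProducer producer(1000);
  std::size_t sum = 0;
  SummingConsumer consumer(sum);
  // Each >> spawns a thread running the worker's Run(); boost::ref() keeps a
  // reference to the local objects instead of copying them.
  chain >> boost::ref(producer) >> boost::ref(consumer);
  chain.Wait(true);  // join both threads and release the chain's memory
  return sum == 999 * 1000 / 2 ? 0 : 1;
}

The program in the patch follows the same pattern, just with one chain per n-gram order and with the sorter and writer threads interposed between the ARPA reader and the final PrintARPA consumer.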