Diffstat (limited to 'klm/lm/interpolate')

 -rw-r--r--  klm/lm/interpolate/arpa_to_stream.cc    |  47
 -rw-r--r--  klm/lm/interpolate/arpa_to_stream.hh    |  38
 -rw-r--r--  klm/lm/interpolate/example_sort_main.cc | 144

3 files changed, 229 insertions, 0 deletions
diff --git a/klm/lm/interpolate/arpa_to_stream.cc b/klm/lm/interpolate/arpa_to_stream.cc
new file mode 100644
index 00000000..f2696f39
--- /dev/null
+++ b/klm/lm/interpolate/arpa_to_stream.cc
@@ -0,0 +1,47 @@
+#include "lm/interpolate/arpa_to_stream.hh"
+
+// TODO: should this move out of builder?
+#include "lm/builder/ngram_stream.hh"
+#include "lm/read_arpa.hh"
+#include "lm/vocab.hh"
+
+namespace lm { namespace interpolate {
+
+ARPAToStream::ARPAToStream(int fd, ngram::GrowableVocab<ngram::WriteUniqueWords> &vocab)
+  : in_(fd), vocab_(vocab) {
+
+  // Read the ARPA file header.
+  //
+  // After the following call, counts_ will be correctly initialized,
+  // and in_ will be positioned for reading the body of the ARPA file.
+  ReadARPACounts(in_, counts_);
+
+}
+
+void ARPAToStream::Run(const util::stream::ChainPositions &positions) {
+  // Make one stream for each order.
+  builder::NGramStreams streams(positions);
+  PositiveProbWarn warn;
+
+  // Unigrams are handled specially because they're being inserted into the vocab.
+  ReadNGramHeader(in_, 1);
+  for (uint64_t i = 0; i < counts_[0]; ++i, ++streams[0]) {
+    streams[0]->begin()[0] = vocab_.FindOrInsert(Read1Gram(in_, streams[0]->Value().complete, warn));
+  }
+  // Finish off the unigram stream.
+  streams[0].Poison();
+
+  // TODO: don't waste backoff field for highest order.
+  for (unsigned char n = 2; n <= counts_.size(); ++n) {
+    ReadNGramHeader(in_, n);
+    builder::NGramStream &stream = streams[n - 1];
+    const uint64_t end = counts_[n - 1];
+    for (std::size_t i = 0; i < end; ++i, ++stream) {
+      ReadNGram(in_, n, vocab_, stream->begin(), stream->Value().complete, warn);
+    }
+    // Finish the stream for order-n n-grams.
+    stream.Poison();
+  }
+}
+
+}} // namespaces
diff --git a/klm/lm/interpolate/arpa_to_stream.hh b/klm/lm/interpolate/arpa_to_stream.hh
new file mode 100644
index 00000000..4613998d
--- /dev/null
+++ b/klm/lm/interpolate/arpa_to_stream.hh
@@ -0,0 +1,38 @@
+#include "lm/read_arpa.hh"
+#include "util/file_piece.hh"
+
+#include <vector>
+
+#include <stdint.h>
+
+namespace util { namespace stream { class ChainPositions; } }
+
+namespace lm {
+
+namespace ngram {
+template <class T> class GrowableVocab;
+class WriteUniqueWords;
+} // namespace ngram
+
+namespace interpolate {
+
+class ARPAToStream {
+  public:
+    // Takes ownership of fd.
+    explicit ARPAToStream(int fd, ngram::GrowableVocab<ngram::WriteUniqueWords> &vocab);
+
+    std::size_t Order() const { return counts_.size(); }
+
+    const std::vector<uint64_t> &Counts() const { return counts_; }
+
+    void Run(const util::stream::ChainPositions &positions);
+
+  private:
+    util::FilePiece in_;
+
+    std::vector<uint64_t> counts_;
+
+    ngram::GrowableVocab<ngram::WriteUniqueWords> &vocab_;
+};
+
+}} // namespaces
diff --git a/klm/lm/interpolate/example_sort_main.cc b/klm/lm/interpolate/example_sort_main.cc
new file mode 100644
index 00000000..4282255e
--- /dev/null
+++ b/klm/lm/interpolate/example_sort_main.cc
@@ -0,0 +1,144 @@
+#include "lm/interpolate/arpa_to_stream.hh"
+
+#include "lm/builder/print.hh"
+#include "lm/builder/sort.hh"
+#include "lm/vocab.hh"
+#include "util/file.hh"
+#include "util/unistd.hh"
+
+
+int main() {
+
+  // TODO: Make these all command-line parameters
+  const std::size_t ONE_GB = 1 << 30;
+  const std::size_t SIXTY_FOUR_MB = 1 << 26;
+  const std::size_t NUMBER_OF_BLOCKS = 2;
+
+  // Vocab strings will be written to this file, forgotten, and reconstituted
+  // later.  This saves memory.
+  util::scoped_fd vocab_file(util::MakeTemp("/tmp/"));
+  std::vector<uint64_t> counts;
+  util::stream::Chains chains;
+  {
+    // Use consistent vocab ids across models.
+    lm::ngram::GrowableVocab<lm::ngram::WriteUniqueWords> vocab(10, vocab_file.get());
+    lm::interpolate::ARPAToStream reader(STDIN_FILENO, vocab);
+    counts = reader.Counts();
+
+    // Configure a chain for each order.  TODO: extract chain balance heuristics from lm/builder/pipeline.cc
+    chains.Init(reader.Order());
+
+    for (std::size_t i = 0; i < reader.Order(); ++i) {
+
+      // The following call to chains.push_back() invokes the Chain constructor
+      //     and appends the newly created Chain object to the chains array.
+      chains.push_back(util::stream::ChainConfig(lm::builder::NGram::TotalSize(i + 1), NUMBER_OF_BLOCKS, ONE_GB));
+
+    }
+
+    // The following call to the >> method of chains
+    //    constructs a ChainPosition for each chain in chains using Chain::Add();
+    //    that function begins with a call to Chain::Start()
+    //    that allocates memory for the chain.
+    //
+    // After the following call to the >> method of chains,
+    //    a new thread will be running
+    //    and will be executing the reader.Run() method
+    //    to read through the body of the ARPA file from standard input.
+    //
+    // For each n-gram line in the ARPA file,
+    //    the thread executing reader.Run()
+    //    will write the probability, the n-gram, and the backoff
+    //    to the appropriate location in the appropriate chain
+    //    (for details, see the ReadNGram() method in read_arpa.hh).
+    //
+    // Normally >> copies then runs so inline >> works.  But here we want a ref.
+    chains >> boost::ref(reader);
+
+    util::stream::SortConfig sort_config;
+    sort_config.temp_prefix  = "/tmp/";
+    sort_config.buffer_size  = SIXTY_FOUR_MB;
+    sort_config.total_memory = ONE_GB;
+
+    // Parallel sorts across orders (though somewhat limited because ARPA files
+    // are not being read in parallel across orders).
+    lm::builder::Sorts<lm::builder::SuffixOrder> sorts(reader.Order());
+    for (std::size_t i = 0; i < reader.Order(); ++i) {
+
+      // The following call to sorts.push_back() invokes the Sort constructor
+      //     and appends the newly constructed Sort object to the sorts array.
+      //
+      // After the construction of the Sort object,
+      //    two new threads will be running (each owned by the chains[i] object).
+      //
+      // The first new thread will execute BlockSorter.Run() to sort the n-gram entries of order (i+1)
+      //    that were previously read into chains[i] by the ARPA input reader thread.
+      //
+      // The second new thread will execute WriteAndRecycle.Run()
+      //    to write each sorted block of data to disk as a temporary file.
+      sorts.push_back(chains[i], sort_config, lm::builder::SuffixOrder(i + 1));

+    }
+
+    // Output to the same chains.
+    for (std::size_t i = 0; i < reader.Order(); ++i) {
+
+      // The following call to Chain::Wait()
+      //     joins the threads owned by chains[i].
+      //
+      // As such, the following call won't return
+      //     until all threads owned by chains[i] have completed.
+      //
+      // The following call also resets chains[i]
+      //     so that it can be reused
+      //     (including freeing the memory previously used by the chain).
+      chains[i].Wait();
+
+      // In an ideal world (without memory restrictions)
+      //     we could merge all of the previously sorted blocks
+      //     by reading them all completely into memory
+      //     and then running merge sort over them.
+      //
+      // In the real world, we have memory restrictions;
+      //     depending on how many blocks we have,
+      //     and how much memory we can use to read from each block (sort_config.buffer_size),
+      //     it may be the case that we have insufficient memory
+      //     to read sort_config.buffer_size of data from each block from disk.
+      //
+      // If this occurs, then it will be necessary to perform one or more rounds of merge sort on disk;
+      //     doing so will reduce the number of blocks that we will eventually need to read from
+      //     when performing the final round of merge sort in memory.
+      //
+      // So, the following call determines whether it is necessary
+      //     to perform one or more rounds of merge sort on disk;
+      //     if such on-disk merge sorting is required, it is performed here.
+      //
+      // Finally, the following method launches a thread that calls OwningMergingReader.Run()
+      //     to perform the final round of merge sort in memory.
+      //
+      // Merge sort could have been invoked directly
+      //     so that merge sort memory doesn't coexist with Chain memory.
+      sorts[i].Output(chains[i]);
+    }
+
+    // sorts can go out of scope even though it's still writing to the chains.
+    // Note that vocab going out of scope flushes to vocab_file.
+  }
+
+  // Get the vocabulary mapping used for this ARPA file.
+  lm::builder::VocabReconstitute reconstitute(vocab_file.get());
+
+  // After the following call to the >> method of chains,
+  //    a new thread will be running
+  //    and will be executing the Run() method of PrintARPA
+  //    to print the final sorted ARPA file to standard output.
+  chains >> lm::builder::PrintARPA(reconstitute, counts, NULL, STDOUT_FILENO);
+
+  // Joins all threads that chains owns:
+  //    does a for loop over each chain object in chains,
+  //    calling chain.Wait() on each such chain object.
+  chains.Wait(true);
+}
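For orientation, ReadARPACounts() parses the \data\ section that opens every ARPA file; the per-order counts it extracts become counts_ and drive both the chain setup and the loop bounds in Run(). An illustrative header for a small trigram model (the counts and entries below are made up):

\data\
ngram 1=5
ngram 2=12
ngram 3=7

\1-grams:
-1.0986	the	-0.3010
...

Each body line carries a log10 probability, the n-gram itself, and (except at the highest order) a backoff weight; these are the three fields ReadNGram() deposits into the chain for each entry.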
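The hand-off in chains >> boost::ref(reader), terminated by the Poison() calls in ARPAToStream::Run(), is a bounded producer/consumer pipeline. Below is a minimal standard-C++ (C++17) sketch of that pattern using a toy queue of individual values; it is an analogy only, not KenLM's util::stream, which circulates large fixed-size memory blocks between threads rather than single items.

#include <condition_variable>
#include <cstddef>
#include <iostream>
#include <mutex>
#include <optional>
#include <queue>
#include <thread>

// Toy bounded queue; an empty optional plays the role of Poison().
template <class T> class BoundedQueue {
 public:
  explicit BoundedQueue(std::size_t limit) : limit_(limit) {}
  void Push(std::optional<T> item) {
    std::unique_lock<std::mutex> lock(mu_);
    not_full_.wait(lock, [&] { return q_.size() < limit_; });
    q_.push(std::move(item));
    not_empty_.notify_one();
  }
  std::optional<T> Pop() {
    std::unique_lock<std::mutex> lock(mu_);
    not_empty_.wait(lock, [&] { return !q_.empty(); });
    std::optional<T> item = std::move(q_.front());
    q_.pop();
    not_full_.notify_one();
    return item;
  }
 private:
  std::mutex mu_;
  std::condition_variable not_full_, not_empty_;
  std::queue<std::optional<T> > q_;
  std::size_t limit_;
};

int main() {
  BoundedQueue<int> chain(2);        // cf. NUMBER_OF_BLOCKS blocks in flight
  std::thread reader([&] {           // plays the role of ARPAToStream::Run()
    for (int i = 0; i < 5; ++i) chain.Push(i);
    chain.Push(std::nullopt);        // "poison": no more items will arrive
  });
  while (std::optional<int> item = chain.Pop())
    std::cout << *item << '\n';      // downstream stage, cf. the sort threads
  reader.join();
}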
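The comparator handed to each sort, lm::builder::SuffixOrder(i + 1), orders fixed-width records of word ids by their suffix, comparing from the rightmost word leftward so that n-grams sharing the same final words end up adjacent. A hypothetical stand-alone version of that comparison (the names here are illustrative, not KenLM's exact implementation):

#include <cstddef>
#include <cstdint>

typedef uint32_t WordIndex;

// Compare two n-grams of the same order by suffix: rightmost word first,
// moving left only on ties.
struct SuffixOrderSketch {
  explicit SuffixOrderSketch(std::size_t order) : order_(order) {}
  bool operator()(const WordIndex *lhs, const WordIndex *rhs) const {
    for (std::size_t i = order_; i-- > 0;) {
      if (lhs[i] != rhs[i]) return lhs[i] < rhs[i];
    }
    return false;  // equal n-grams: not less-than
  }
  std::size_t order_;
};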
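Finally, the merge that sorts[i].Output(chains[i]) performs over sorted temporary blocks is at heart a standard k-way merge. A sketch of that core over in-memory vectors of uint64_t stand-ins (the real code merges n-gram records streamed from temporary files under the memory budget discussed in the comments above):

#include <cstddef>
#include <cstdint>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

// Merge k sorted runs into one sorted output using a min-heap.
std::vector<uint64_t> KWayMerge(const std::vector<std::vector<uint64_t> > &runs) {
  // Heap entry: (value, (run index, offset within run)).
  typedef std::pair<uint64_t, std::pair<std::size_t, std::size_t> > Entry;
  std::priority_queue<Entry, std::vector<Entry>, std::greater<Entry> > heap;
  for (std::size_t r = 0; r < runs.size(); ++r)
    if (!runs[r].empty()) heap.push(Entry(runs[r][0], std::make_pair(r, 0)));
  std::vector<uint64_t> out;
  while (!heap.empty()) {
    Entry top = heap.top();
    heap.pop();
    out.push_back(top.first);
    const std::size_t r = top.second.first, next = top.second.second + 1;
    // Refill the heap from the run that just yielded its smallest element.
    if (next < runs[r].size()) heap.push(Entry(runs[r][next], std::make_pair(r, next)));
  }
  return out;
}

Only one element per run needs to be resident at a time, which is why the number of runs (and sort_config.buffer_size per run) determines whether the merge fits in memory or needs extra on-disk rounds first.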
