summaryrefslogtreecommitdiff
path: root/klm/lm/interpolate
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/interpolate')
-rw-r--r--klm/lm/interpolate/arpa_to_stream.cc47
-rw-r--r--klm/lm/interpolate/arpa_to_stream.hh38
-rw-r--r--klm/lm/interpolate/example_sort_main.cc144
3 files changed, 229 insertions, 0 deletions
diff --git a/klm/lm/interpolate/arpa_to_stream.cc b/klm/lm/interpolate/arpa_to_stream.cc
new file mode 100644
index 00000000..f2696f39
--- /dev/null
+++ b/klm/lm/interpolate/arpa_to_stream.cc
@@ -0,0 +1,47 @@
+#include "lm/interpolate/arpa_to_stream.hh"
+
+// TODO: should this move out of builder?
+#include "lm/builder/ngram_stream.hh"
+#include "lm/read_arpa.hh"
+#include "lm/vocab.hh"
+
+namespace lm { namespace interpolate {
+
+ARPAToStream::ARPAToStream(int fd, ngram::GrowableVocab<ngram::WriteUniqueWords> &vocab)
+ : in_(fd), vocab_(vocab) {
+
+ // Read the ARPA file header.
+ //
+ // After the following call, counts_ will be correctly initialized,
+ // and in_ will be positioned for reading the body of the ARPA file.
+ ReadARPACounts(in_, counts_);
+
+}
+
+void ARPAToStream::Run(const util::stream::ChainPositions &positions) {
+ // Make one stream for each order.
+ builder::NGramStreams streams(positions);
+ PositiveProbWarn warn;
+
+ // Unigrams are handled specially because they're being inserted into the vocab.
+ ReadNGramHeader(in_, 1);
+ for (uint64_t i = 0; i < counts_[0]; ++i, ++streams[0]) {
+ streams[0]->begin()[0] = vocab_.FindOrInsert(Read1Gram(in_, streams[0]->Value().complete, warn));
+ }
+ // Finish off the unigram stream.
+ streams[0].Poison();
+
+ // TODO: don't waste backoff field for highest order.
+ for (unsigned char n = 2; n <= counts_.size(); ++n) {
+ ReadNGramHeader(in_, n);
+ builder::NGramStream &stream = streams[n - 1];
+ const uint64_t end = counts_[n - 1];
+ for (std::size_t i = 0; i < end; ++i, ++stream) {
+ ReadNGram(in_, n, vocab_, stream->begin(), stream->Value().complete, warn);
+ }
+ // Finish the stream for n-grams..
+ stream.Poison();
+ }
+}
+
+}} // namespaces
diff --git a/klm/lm/interpolate/arpa_to_stream.hh b/klm/lm/interpolate/arpa_to_stream.hh
new file mode 100644
index 00000000..4613998d
--- /dev/null
+++ b/klm/lm/interpolate/arpa_to_stream.hh
@@ -0,0 +1,38 @@
+#include "lm/read_arpa.hh"
+#include "util/file_piece.hh"
+
+#include <vector>
+
+#include <stdint.h>
+
+namespace util { namespace stream { class ChainPositions; } }
+
+namespace lm {
+
+namespace ngram {
+template <class T> class GrowableVocab;
+class WriteUniqueWords;
+} // namespace ngram
+
+namespace interpolate {
+
+class ARPAToStream {
+ public:
+ // Takes ownership of fd.
+ explicit ARPAToStream(int fd, ngram::GrowableVocab<ngram::WriteUniqueWords> &vocab);
+
+ std::size_t Order() const { return counts_.size(); }
+
+ const std::vector<uint64_t> &Counts() const { return counts_; }
+
+ void Run(const util::stream::ChainPositions &positions);
+
+ private:
+ util::FilePiece in_;
+
+ std::vector<uint64_t> counts_;
+
+ ngram::GrowableVocab<ngram::WriteUniqueWords> &vocab_;
+};
+
+}} // namespaces
diff --git a/klm/lm/interpolate/example_sort_main.cc b/klm/lm/interpolate/example_sort_main.cc
new file mode 100644
index 00000000..4282255e
--- /dev/null
+++ b/klm/lm/interpolate/example_sort_main.cc
@@ -0,0 +1,144 @@
+#include "lm/interpolate/arpa_to_stream.hh"
+
+#include "lm/builder/print.hh"
+#include "lm/builder/sort.hh"
+#include "lm/vocab.hh"
+#include "util/file.hh"
+#include "util/unistd.hh"
+
+
+int main() {
+
+ // TODO: Make these all command-line parameters
+ const std::size_t ONE_GB = 1 << 30;
+ const std::size_t SIXTY_FOUR_MB = 1 << 26;
+ const std::size_t NUMBER_OF_BLOCKS = 2;
+
+ // Vocab strings will be written to this file, forgotten, and reconstituted
+ // later. This saves memory.
+ util::scoped_fd vocab_file(util::MakeTemp("/tmp/"));
+ std::vector<uint64_t> counts;
+ util::stream::Chains chains;
+ {
+ // Use consistent vocab ids across models.
+ lm::ngram::GrowableVocab<lm::ngram::WriteUniqueWords> vocab(10, vocab_file.get());
+ lm::interpolate::ARPAToStream reader(STDIN_FILENO, vocab);
+ counts = reader.Counts();
+
+ // Configure a chain for each order. TODO: extract chain balance heuristics from lm/builder/pipeline.cc
+ chains.Init(reader.Order());
+
+ for (std::size_t i = 0; i < reader.Order(); ++i) {
+
+ // The following call to chains.push_back() invokes the Chain constructor
+ // and appends the newly created Chain object to the chains array
+ chains.push_back(util::stream::ChainConfig(lm::builder::NGram::TotalSize(i + 1), NUMBER_OF_BLOCKS, ONE_GB));
+
+ }
+
+ // The following call to the >> method of chains
+ // constructs a ChainPosition for each chain in chains using Chain::Add();
+ // that function begins with a call to Chain::Start()
+ // that allocates memory for the chain.
+ //
+ // After the following call to the >> method of chains,
+ // a new thread will be running
+ // and will be executing the reader.Run() method
+ // to read through the body of the ARPA file from standard input.
+ //
+ // For each n-gram line in the ARPA file,
+ // the thread executing reader.Run()
+ // will write the probability, the n-gram, and the backoff
+ // to the appropriate location in the appropriate chain
+ // (for details, see the ReadNGram() method in read_arpa.hh).
+ //
+ // Normally >> copies then runs so inline >> works. But here we want a ref.
+ chains >> boost::ref(reader);
+
+
+ util::stream::SortConfig sort_config;
+ sort_config.temp_prefix = "/tmp/";
+ sort_config.buffer_size = SIXTY_FOUR_MB;
+ sort_config.total_memory = ONE_GB;
+
+ // Parallel sorts across orders (though somewhat limited because ARPA files are not being read in parallel across orders)
+ lm::builder::Sorts<lm::builder::SuffixOrder> sorts(reader.Order());
+ for (std::size_t i = 0; i < reader.Order(); ++i) {
+
+ // The following call to sorts.push_back() invokes the Sort constructor
+ // and appends the newly constructed Sort object to the sorts array.
+ //
+ // After the construction of the Sort object,
+ // two new threads will be running (each owned by the chains[i] object).
+ //
+ // The first new thread will execute BlockSorter.Run() to sort the n-gram entries of order (i+1)
+ // that were previously read into chains[i] by the ARPA input reader thread.
+ //
+ // The second new thread will execute WriteAndRecycle.Run()
+ // to write each sorted block of data to disk as a temporary file.
+ sorts.push_back(chains[i], sort_config, lm::builder::SuffixOrder(i + 1));
+
+ }
+
+ // Output to the same chains.
+ for (std::size_t i = 0; i < reader.Order(); ++i) {
+
+ // The following call to Chain::Wait()
+ // joins the threads owned by chains[i].
+ //
+ // As such the following call won't return
+ // until all threads owned by chains[i] have completed.
+ //
+ // The following call also resets chain[i]
+ // so that it can be reused
+ // (including free'ing the memory previously used by the chain)
+ chains[i].Wait();
+
+
+ // In an ideal world (without memory restrictions)
+ // we could merge all of the previously sorted blocks
+ // by reading them all completely into memory
+ // and then running merge sort over them.
+ //
+ // In the real world, we have memory restrictions;
+ // depending on how many blocks we have,
+ // and how much memory we can use to read from each block (sort_config.buffer_size)
+ // it may be the case that we have insufficient memory
+ // to read sort_config.buffer_size of data from each block from disk.
+ //
+ // If this occurs, then it will be necessary to perform one or more rounds of merge sort on disk;
+ // doing so will reduce the number of blocks that we will eventually need to read from
+ // when performing the final round of merge sort in memory.
+ //
+ // So, the following call determines whether it is necessary
+ // to perform one or more rounds of merge sort on disk;
+ // if such on-disk merge sorting is required, such sorting is performed.
+ //
+ // Finally, the following method launches a thread that calls OwningMergingReader.Run()
+ // to perform the final round of merge sort in memory.
+ //
+ // Merge sort could have be invoked directly
+ // so that merge sort memory doesn't coexist with Chain memory.
+ sorts[i].Output(chains[i]);
+ }
+
+ // sorts can go out of scope even though it's still writing to the chains.
+ // note that vocab going out of scope flushes to vocab_file.
+ }
+
+
+ // Get the vocabulary mapping used for this ARPA file
+ lm::builder::VocabReconstitute reconstitute(vocab_file.get());
+
+ // After the following call to the << method of chains,
+ // a new thread will be running
+ // and will be executing the Run() method of PrintARPA
+ // to print the final sorted ARPA file to standard output.
+ chains >> lm::builder::PrintARPA(reconstitute, counts, NULL, STDOUT_FILENO);
+
+ // Joins all threads that chains owns,
+ // and does a for loop over each chain object in chains,
+ // calling chain.Wait() on each such chain object
+ chains.Wait(true);
+
+}