summaryrefslogtreecommitdiff
path: root/klm/lm/filter
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/filter')
-rw-r--r--klm/lm/filter/arpa_io.cc122
-rw-r--r--klm/lm/filter/arpa_io.hh122
-rw-r--r--klm/lm/filter/count_io.hh91
-rw-r--r--klm/lm/filter/format.hh250
-rw-r--r--klm/lm/filter/main.cc249
-rw-r--r--klm/lm/filter/phrase.cc281
-rw-r--r--klm/lm/filter/phrase.hh153
-rw-r--r--klm/lm/filter/thread.hh167
-rw-r--r--klm/lm/filter/vocab.cc54
-rw-r--r--klm/lm/filter/vocab.hh132
-rw-r--r--klm/lm/filter/wrapper.hh58
11 files changed, 1679 insertions, 0 deletions
diff --git a/klm/lm/filter/arpa_io.cc b/klm/lm/filter/arpa_io.cc
new file mode 100644
index 00000000..caf8df95
--- /dev/null
+++ b/klm/lm/filter/arpa_io.cc
@@ -0,0 +1,122 @@
+#include "lm/filter/arpa_io.hh"
+#include "util/file_piece.hh"
+
+#include <iostream>
+#include <ostream>
+#include <string>
+#include <vector>
+
+#include <ctype.h>
+#include <errno.h>
+#include <string.h>
+
+namespace lm {
+
+ARPAInputException::ARPAInputException(const StringPiece &message) throw() : what_("Error: ") {
+ what_.append(message.data(), message.size());
+}
+
+ARPAInputException::ARPAInputException(const StringPiece &message, const StringPiece &line) throw() {
+ what_ = "Error: ";
+ what_.append(message.data(), message.size());
+ what_ += " in line '";
+ what_.append(line.data(), line.size());
+ what_ += "'.";
+}
+
+ARPAOutputException::ARPAOutputException(const char *message, const std::string &file_name) throw()
+ : what_(std::string(message) + " file " + file_name), file_name_(file_name) {
+ if (errno) {
+ char buf[1024];
+ buf[0] = 0;
+#if (_POSIX_C_SOURCE >= 200112L || _XOPEN_SOURCE >= 600) && ! _GNU_SOURCE
+ const char *add = buf;
+ if (!strerror_r(errno, buf, 1024)) {
+#else
+ const char *add = strerror_r(errno, buf, 1024);
+ if (add) {
+#endif
+ what_ += " :";
+ what_ += add;
+ }
+ }
+}
+
+// Seeking is the responsibility of the caller.
+void WriteCounts(std::ostream &out, const std::vector<size_t> &number) {
+ out << "\n\\data\\\n";
+ for (unsigned int i = 0; i < number.size(); ++i) {
+ out << "ngram " << i+1 << "=" << number[i] << '\n';
+ }
+ out << '\n';
+}
+
+size_t SizeNeededForCounts(const std::vector<size_t> &number) {
+ std::ostringstream buf;
+ WriteCounts(buf, number);
+ return buf.tellp();
+}
+
+bool IsEntirelyWhiteSpace(const StringPiece &line) {
+ for (size_t i = 0; i < static_cast<size_t>(line.size()); ++i) {
+ if (!isspace(line.data()[i])) return false;
+ }
+ return true;
+}
+
+ARPAOutput::ARPAOutput(const char *name, size_t buffer_size) : file_name_(name), buffer_(new char[buffer_size]) {
+ try {
+ file_.exceptions(std::ostream::eofbit | std::ostream::failbit | std::ostream::badbit);
+ if (!file_.rdbuf()->pubsetbuf(buffer_.get(), buffer_size)) {
+ std::cerr << "Warning: could not enlarge buffer for " << name << std::endl;
+ buffer_.reset();
+ }
+ file_.open(name, std::ios::out | std::ios::binary);
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Opening", file_name_);
+ }
+}
+
+void ARPAOutput::ReserveForCounts(std::streampos reserve) {
+ try {
+ for (std::streampos i = 0; i < reserve; i += std::streampos(1)) {
+ file_ << '\n';
+ }
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Writing blanks to reserve space for counts to ", file_name_);
+ }
+}
+
+void ARPAOutput::BeginLength(unsigned int length) {
+ fast_counter_ = 0;
+ try {
+ file_ << '\\' << length << "-grams:" << '\n';
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Writing n-gram header to ", file_name_);
+ }
+}
+
+void ARPAOutput::EndLength(unsigned int length) {
+ try {
+ file_ << '\n';
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Writing blank at end of count list to ", file_name_);
+ }
+ if (length > counts_.size()) {
+ counts_.resize(length);
+ }
+ counts_[length - 1] = fast_counter_;
+}
+
+void ARPAOutput::Finish() {
+ try {
+ file_ << "\\end\\\n";
+ file_.seekp(0);
+ WriteCounts(file_, counts_);
+ file_ << std::flush;
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Finishing including writing counts at beginning to ", file_name_);
+ }
+}
+
+} // namespace lm
diff --git a/klm/lm/filter/arpa_io.hh b/klm/lm/filter/arpa_io.hh
new file mode 100644
index 00000000..90f48447
--- /dev/null
+++ b/klm/lm/filter/arpa_io.hh
@@ -0,0 +1,122 @@
+#ifndef LM_FILTER_ARPA_IO__
+#define LM_FILTER_ARPA_IO__
+/* Input and output for ARPA format language model files.
+ */
+#include "lm/read_arpa.hh"
+#include "util/exception.hh"
+#include "util/string_piece.hh"
+#include "util/tokenize_piece.hh"
+
+#include <boost/noncopyable.hpp>
+#include <boost/scoped_array.hpp>
+
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include <err.h>
+#include <string.h>
+
+namespace util { class FilePiece; }
+
+namespace lm {
+
+class ARPAInputException : public util::Exception {
+ public:
+ explicit ARPAInputException(const StringPiece &message) throw();
+ explicit ARPAInputException(const StringPiece &message, const StringPiece &line) throw();
+ virtual ~ARPAInputException() throw() {}
+
+ const char *what() const throw() { return what_.c_str(); }
+
+ private:
+ std::string what_;
+};
+
+class ARPAOutputException : public std::exception {
+ public:
+ ARPAOutputException(const char *prefix, const std::string &file_name) throw();
+ virtual ~ARPAOutputException() throw() {}
+
+ const char *what() const throw() { return what_.c_str(); }
+
+ const std::string &File() const throw() { return file_name_; }
+
+ private:
+ std::string what_;
+ const std::string file_name_;
+};
+
+// Handling for the counts of n-grams at the beginning of ARPA files.
+size_t SizeNeededForCounts(const std::vector<size_t> &number);
+
+/* Writes an ARPA file. This has to be seekable so the counts can be written
+ * at the end. Hence, I just have it own a std::fstream instead of accepting
+ * a separately held std::ostream.
+ */
+class ARPAOutput : boost::noncopyable {
+ public:
+ explicit ARPAOutput(const char *name, size_t buffer_size = 65536);
+
+ void ReserveForCounts(std::streampos reserve);
+
+ void BeginLength(unsigned int length);
+
+ void AddNGram(const StringPiece &line) {
+ try {
+ file_ << line << '\n';
+ } catch (const std::ios_base::failure &f) {
+ throw ARPAOutputException("Writing an n-gram", file_name_);
+ }
+ ++fast_counter_;
+ }
+
+ void AddNGram(const StringPiece &ngram, const StringPiece &line) {
+ AddNGram(line);
+ }
+
+ template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
+ AddNGram(line);
+ }
+
+ void EndLength(unsigned int length);
+
+ void Finish();
+
+ private:
+ const std::string file_name_;
+ boost::scoped_array<char> buffer_;
+ std::fstream file_;
+ size_t fast_counter_;
+ std::vector<size_t> counts_;
+};
+
+
+template <class Output> void ReadNGrams(util::FilePiece &in, unsigned int length, size_t number, Output &out) {
+ ReadNGramHeader(in, length);
+ out.BeginLength(length);
+ for (size_t i = 0; i < number; ++i) {
+ StringPiece line = in.ReadLine();
+ util::TokenIter<util::SingleCharacter> tabber(line, '\t');
+ if (!tabber) throw ARPAInputException("blank line", line);
+ if (!++tabber) throw ARPAInputException("no tab", line);
+
+ out.AddNGram(*tabber, line);
+ }
+ out.EndLength(length);
+}
+
+template <class Output> void ReadARPA(util::FilePiece &in_lm, Output &out) {
+ std::vector<size_t> number;
+ ReadARPACounts(in_lm, number);
+ out.ReserveForCounts(SizeNeededForCounts(number));
+ for (unsigned int i = 0; i < number.size(); ++i) {
+ ReadNGrams(in_lm, i + 1, number[i], out);
+ }
+ ReadEnd(in_lm);
+ out.Finish();
+}
+
+} // namespace lm
+
+#endif // LM_FILTER_ARPA_IO__
diff --git a/klm/lm/filter/count_io.hh b/klm/lm/filter/count_io.hh
new file mode 100644
index 00000000..97c0fa25
--- /dev/null
+++ b/klm/lm/filter/count_io.hh
@@ -0,0 +1,91 @@
+#ifndef LM_FILTER_COUNT_IO__
+#define LM_FILTER_COUNT_IO__
+
+#include <fstream>
+#include <iostream>
+#include <string>
+
+#include <err.h>
+
+#include "util/file_piece.hh"
+
+namespace lm {
+
+class CountOutput : boost::noncopyable {
+ public:
+ explicit CountOutput(const char *name) : file_(name, std::ios::out) {}
+
+ void AddNGram(const StringPiece &line) {
+ if (!(file_ << line << '\n')) {
+ err(3, "Writing counts file failed");
+ }
+ }
+
+ template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
+ AddNGram(line);
+ }
+
+ void AddNGram(const StringPiece &ngram, const StringPiece &line) {
+ AddNGram(line);
+ }
+
+ private:
+ std::fstream file_;
+};
+
+class CountBatch {
+ public:
+ explicit CountBatch(std::streamsize initial_read)
+ : initial_read_(initial_read) {
+ buffer_.reserve(initial_read);
+ }
+
+ void Read(std::istream &in) {
+ buffer_.resize(initial_read_);
+ in.read(&*buffer_.begin(), initial_read_);
+ buffer_.resize(in.gcount());
+ char got;
+ while (in.get(got) && got != '\n')
+ buffer_.push_back(got);
+ }
+
+ template <class Output> void Send(Output &out) {
+ for (util::TokenIter<util::SingleCharacter> line(StringPiece(&*buffer_.begin(), buffer_.size()), '\n'); line; ++line) {
+ util::TokenIter<util::SingleCharacter> tabber(*line, '\t');
+ if (!tabber) {
+ std::cerr << "Warning: empty n-gram count line being removed\n";
+ continue;
+ }
+ util::TokenIter<util::SingleCharacter, true> words(*tabber, ' ');
+ if (!words) {
+ std::cerr << "Line has a tab but no words.\n";
+ continue;
+ }
+ out.AddNGram(words, util::TokenIter<util::SingleCharacter, true>::end(), *line);
+ }
+ }
+
+ private:
+ std::streamsize initial_read_;
+
+ // This could have been a std::string but that's less happy with raw writes.
+ std::vector<char> buffer_;
+};
+
+template <class Output> void ReadCount(util::FilePiece &in_file, Output &out) {
+ try {
+ while (true) {
+ StringPiece line = in_file.ReadLine();
+ util::TokenIter<util::SingleCharacter> tabber(line, '\t');
+ if (!tabber) {
+ std::cerr << "Warning: empty n-gram count line being removed\n";
+ continue;
+ }
+ out.AddNGram(*tabber, line);
+ }
+ } catch (const util::EndOfFileException &e) {}
+}
+
+} // namespace lm
+
+#endif // LM_FILTER_COUNT_IO__
diff --git a/klm/lm/filter/format.hh b/klm/lm/filter/format.hh
new file mode 100644
index 00000000..7f945b0d
--- /dev/null
+++ b/klm/lm/filter/format.hh
@@ -0,0 +1,250 @@
+#ifndef LM_FILTER_FORMAT_H__
+#define LM_FITLER_FORMAT_H__
+
+#include "lm/filter/arpa_io.hh"
+#include "lm/filter/count_io.hh"
+
+#include <boost/lexical_cast.hpp>
+#include <boost/ptr_container/ptr_vector.hpp>
+
+#include <iosfwd>
+
+namespace lm {
+
+template <class Single> class MultipleOutput {
+ private:
+ typedef boost::ptr_vector<Single> Singles;
+ typedef typename Singles::iterator SinglesIterator;
+
+ public:
+ MultipleOutput(const char *prefix, size_t number) {
+ files_.reserve(number);
+ std::string tmp;
+ for (unsigned int i = 0; i < number; ++i) {
+ tmp = prefix;
+ tmp += boost::lexical_cast<std::string>(i);
+ files_.push_back(new Single(tmp.c_str()));
+ }
+ }
+
+ void AddNGram(const StringPiece &line) {
+ for (SinglesIterator i = files_.begin(); i != files_.end(); ++i)
+ i->AddNGram(line);
+ }
+
+ template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
+ for (SinglesIterator i = files_.begin(); i != files_.end(); ++i)
+ i->AddNGram(begin, end, line);
+ }
+
+ void SingleAddNGram(size_t offset, const StringPiece &line) {
+ files_[offset].AddNGram(line);
+ }
+
+ template <class Iterator> void SingleAddNGram(size_t offset, const Iterator &begin, const Iterator &end, const StringPiece &line) {
+ files_[offset].AddNGram(begin, end, line);
+ }
+
+ protected:
+ Singles files_;
+};
+
+class MultipleARPAOutput : public MultipleOutput<ARPAOutput> {
+ public:
+ MultipleARPAOutput(const char *prefix, size_t number) : MultipleOutput<ARPAOutput>(prefix, number) {}
+
+ void ReserveForCounts(std::streampos reserve) {
+ for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i)
+ i->ReserveForCounts(reserve);
+ }
+
+ void BeginLength(unsigned int length) {
+ for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i)
+ i->BeginLength(length);
+ }
+
+ void EndLength(unsigned int length) {
+ for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i)
+ i->EndLength(length);
+ }
+
+ void Finish() {
+ for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i)
+ i->Finish();
+ }
+};
+
+template <class Filter, class Output> class DispatchInput {
+ public:
+ DispatchInput(Filter &filter, Output &output) : filter_(filter), output_(output) {}
+
+/* template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
+ filter_.AddNGram(begin, end, line, output_);
+ }*/
+
+ void AddNGram(const StringPiece &ngram, const StringPiece &line) {
+ filter_.AddNGram(ngram, line, output_);
+ }
+
+ protected:
+ Filter &filter_;
+ Output &output_;
+};
+
+template <class Filter, class Output> class DispatchARPAInput : public DispatchInput<Filter, Output> {
+ private:
+ typedef DispatchInput<Filter, Output> B;
+
+ public:
+ DispatchARPAInput(Filter &filter, Output &output) : B(filter, output) {}
+
+ void ReserveForCounts(std::streampos reserve) { B::output_.ReserveForCounts(reserve); }
+ void BeginLength(unsigned int length) { B::output_.BeginLength(length); }
+
+ void EndLength(unsigned int length) {
+ B::filter_.Flush();
+ B::output_.EndLength(length);
+ }
+ void Finish() { B::output_.Finish(); }
+};
+
+struct ARPAFormat {
+ typedef ARPAOutput Output;
+ typedef MultipleARPAOutput Multiple;
+ static void Copy(util::FilePiece &in, Output &out) {
+ ReadARPA(in, out);
+ }
+ template <class Filter, class Out> static void RunFilter(util::FilePiece &in, Filter &filter, Out &output) {
+ DispatchARPAInput<Filter, Out> dispatcher(filter, output);
+ ReadARPA(in, dispatcher);
+ }
+};
+
+struct CountFormat {
+ typedef CountOutput Output;
+ typedef MultipleOutput<Output> Multiple;
+ static void Copy(util::FilePiece &in, Output &out) {
+ ReadCount(in, out);
+ }
+ template <class Filter, class Out> static void RunFilter(util::FilePiece &in, Filter &filter, Out &output) {
+ DispatchInput<Filter, Out> dispatcher(filter, output);
+ ReadCount(in, dispatcher);
+ }
+};
+
+/* For multithreading, the buffer classes hold batches of filter inputs and
+ * outputs in memory. The strings get reused a lot, so keep them around
+ * instead of clearing each time.
+ */
+class InputBuffer {
+ public:
+ InputBuffer() : actual_(0) {}
+
+ void Reserve(size_t size) { lines_.reserve(size); }
+
+ template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
+ if (lines_.size() == actual_) lines_.resize(lines_.size() + 1);
+ // TODO avoid this copy.
+ std::string &copied = lines_[actual_].line;
+ copied.assign(line.data(), line.size());
+ lines_[actual_].ngram.set(copied.data() + (ngram.data() - line.data()), ngram.size());
+ ++actual_;
+ }
+
+ template <class Filter, class Output> void CallFilter(Filter &filter, Output &output) const {
+ for (std::vector<Line>::const_iterator i = lines_.begin(); i != lines_.begin() + actual_; ++i) {
+ filter.AddNGram(i->ngram, i->line, output);
+ }
+ }
+
+ void Clear() { actual_ = 0; }
+ bool Empty() { return actual_ == 0; }
+ size_t Size() { return actual_; }
+
+ private:
+ struct Line {
+ std::string line;
+ StringPiece ngram;
+ };
+
+ size_t actual_;
+
+ std::vector<Line> lines_;
+};
+
+class BinaryOutputBuffer {
+ public:
+ BinaryOutputBuffer() {}
+
+ void Reserve(size_t size) {
+ lines_.reserve(size);
+ }
+
+ void AddNGram(const StringPiece &line) {
+ lines_.push_back(line);
+ }
+
+ template <class Output> void Flush(Output &output) {
+ for (std::vector<StringPiece>::const_iterator i = lines_.begin(); i != lines_.end(); ++i) {
+ output.AddNGram(*i);
+ }
+ lines_.clear();
+ }
+
+ private:
+ std::vector<StringPiece> lines_;
+};
+
+class MultipleOutputBuffer {
+ public:
+ MultipleOutputBuffer() : last_(NULL) {}
+
+ void Reserve(size_t size) {
+ annotated_.reserve(size);
+ }
+
+ void AddNGram(const StringPiece &line) {
+ annotated_.resize(annotated_.size() + 1);
+ annotated_.back().line = line;
+ }
+
+ void SingleAddNGram(size_t offset, const StringPiece &line) {
+ if ((line.data() == last_.data()) && (line.length() == last_.length())) {
+ annotated_.back().systems.push_back(offset);
+ } else {
+ annotated_.resize(annotated_.size() + 1);
+ annotated_.back().systems.push_back(offset);
+ annotated_.back().line = line;
+ last_ = line;
+ }
+ }
+
+ template <class Output> void Flush(Output &output) {
+ for (std::vector<Annotated>::const_iterator i = annotated_.begin(); i != annotated_.end(); ++i) {
+ if (i->systems.empty()) {
+ output.AddNGram(i->line);
+ } else {
+ for (std::vector<size_t>::const_iterator j = i->systems.begin(); j != i->systems.end(); ++j) {
+ output.SingleAddNGram(*j, i->line);
+ }
+ }
+ }
+ annotated_.clear();
+ }
+
+ private:
+ struct Annotated {
+ // If this is empty, send to all systems.
+ // A filter should never send to all systems and send to a single one.
+ std::vector<size_t> systems;
+ StringPiece line;
+ };
+
+ StringPiece last_;
+
+ std::vector<Annotated> annotated_;
+};
+
+} // namespace lm
+
+#endif // LM_FILTER_FORMAT_H__
diff --git a/klm/lm/filter/main.cc b/klm/lm/filter/main.cc
new file mode 100644
index 00000000..c42243e2
--- /dev/null
+++ b/klm/lm/filter/main.cc
@@ -0,0 +1,249 @@
+#include "lm/filter/arpa_io.hh"
+#include "lm/filter/format.hh"
+#include "lm/filter/phrase.hh"
+#ifndef NTHREAD
+#include "lm/filter/thread.hh"
+#endif
+#include "lm/filter/vocab.hh"
+#include "lm/filter/wrapper.hh"
+#include "util/file_piece.hh"
+
+#include <boost/ptr_container/ptr_vector.hpp>
+
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <memory>
+
+namespace lm {
+namespace {
+
+void DisplayHelp(const char *name) {
+ std::cerr
+ << "Usage: " << name << " mode [context] [phrase] [raw|arpa] [threads:m] [batch_size:m] (vocab|model):input_file output_file\n\n"
+ "copy mode just copies, but makes the format nicer for e.g. irstlm's broken\n"
+ " parser.\n"
+ "single mode treats the entire input as a single sentence.\n"
+ "multiple mode filters to multiple sentences in parallel. Each sentence is on\n"
+ " a separate line. A separate file is created for each file by appending the\n"
+ " 0-indexed line number to the output file name.\n"
+ "union mode produces one filtered model that is the union of models created by\n"
+ " multiple mode.\n\n"
+ "context means only the context (all but last word) has to pass the filter, but\n"
+ " the entire n-gram is output.\n\n"
+ "phrase means that the vocabulary is actually tab-delimited phrases and that the\n"
+ " phrases can generate the n-gram when assembled in arbitrary order and\n"
+ " clipped. Currently works with multiple or union mode.\n\n"
+ "The file format is set by [raw|arpa] with default arpa:\n"
+ "raw means space-separated tokens, optionally followed by a tab and arbitrary\n"
+ " text. This is useful for ngram count files.\n"
+ "arpa means the ARPA file format for n-gram language models.\n\n"
+#ifndef NTHREAD
+ "threads:m sets m threads (default: conccurrency detected by boost)\n"
+ "batch_size:m sets the batch size for threading. Expect memory usage from this\n"
+ " of 2*threads*batch_size n-grams.\n\n"
+#else
+ "This binary was compiled with -DNTHREAD, disabling threading. If you wanted\n"
+ " threading, compile without this flag against Boost >=1.42.0.\n\n"
+#endif
+ "There are two inputs: vocabulary and model. Either may be given as a file\n"
+ " while the other is on stdin. Specify the type given as a file using\n"
+ " vocab: or model: before the file name. \n\n"
+ "For ARPA format, the output must be seekable. For raw format, it can be a\n"
+ " stream i.e. /dev/stdout\n";
+}
+
+typedef enum {MODE_COPY, MODE_SINGLE, MODE_MULTIPLE, MODE_UNION} FilterMode;
+typedef enum {FORMAT_ARPA, FORMAT_COUNT} Format;
+
+struct Config {
+ Config() :
+#ifndef NTHREAD
+ batch_size(25000),
+ threads(boost::thread::hardware_concurrency()),
+#endif
+ phrase(false),
+ context(false),
+ format(FORMAT_ARPA)
+ {
+#ifndef NTHREAD
+ if (!threads) threads = 1;
+#endif
+ }
+
+#ifndef NTHREAD
+ size_t batch_size;
+ size_t threads;
+#endif
+ bool phrase;
+ bool context;
+ FilterMode mode;
+ Format format;
+};
+
+template <class Format, class Filter, class OutputBuffer, class Output> void RunThreadedFilter(const Config &config, util::FilePiece &in_lm, Filter &filter, Output &output) {
+#ifndef NTHREAD
+ if (config.threads == 1) {
+#endif
+ Format::RunFilter(in_lm, filter, output);
+#ifndef NTHREAD
+ } else {
+ typedef Controller<Filter, OutputBuffer, Output> Threaded;
+ Threaded threading(config.batch_size, config.threads * 2, config.threads, filter, output);
+ Format::RunFilter(in_lm, threading, output);
+ }
+#endif
+}
+
+template <class Format, class Filter, class OutputBuffer, class Output> void RunContextFilter(const Config &config, util::FilePiece &in_lm, Filter filter, Output &output) {
+ if (config.context) {
+ ContextFilter<Filter> context_filter(filter);
+ RunThreadedFilter<Format, ContextFilter<Filter>, OutputBuffer, Output>(config, in_lm, context_filter, output);
+ } else {
+ RunThreadedFilter<Format, Filter, OutputBuffer, Output>(config, in_lm, filter, output);
+ }
+}
+
+template <class Format, class Binary> void DispatchBinaryFilter(const Config &config, util::FilePiece &in_lm, const Binary &binary, typename Format::Output &out) {
+ typedef BinaryFilter<Binary> Filter;
+ RunContextFilter<Format, Filter, BinaryOutputBuffer, typename Format::Output>(config, in_lm, Filter(binary), out);
+}
+
+template <class Format> void DispatchFilterModes(const Config &config, std::istream &in_vocab, util::FilePiece &in_lm, const char *out_name) {
+ if (config.mode == MODE_MULTIPLE) {
+ if (config.phrase) {
+ typedef phrase::Multiple Filter;
+ phrase::Substrings substrings;
+ typename Format::Multiple out(out_name, phrase::ReadMultiple(in_vocab, substrings));
+ RunContextFilter<Format, Filter, MultipleOutputBuffer, typename Format::Multiple>(config, in_lm, Filter(substrings), out);
+ } else {
+ typedef vocab::Multiple Filter;
+ boost::unordered_map<std::string, std::vector<unsigned int> > words;
+ typename Format::Multiple out(out_name, vocab::ReadMultiple(in_vocab, words));
+ RunContextFilter<Format, Filter, MultipleOutputBuffer, typename Format::Multiple>(config, in_lm, Filter(words), out);
+ }
+ return;
+ }
+
+ typename Format::Output out(out_name);
+
+ if (config.mode == MODE_COPY) {
+ Format::Copy(in_lm, out);
+ return;
+ }
+
+ if (config.mode == MODE_SINGLE) {
+ vocab::Single::Words words;
+ vocab::ReadSingle(in_vocab, words);
+ DispatchBinaryFilter<Format, vocab::Single>(config, in_lm, vocab::Single(words), out);
+ return;
+ }
+
+ if (config.mode == MODE_UNION) {
+ if (config.phrase) {
+ phrase::Substrings substrings;
+ phrase::ReadMultiple(in_vocab, substrings);
+ DispatchBinaryFilter<Format, phrase::Union>(config, in_lm, phrase::Union(substrings), out);
+ } else {
+ vocab::Union::Words words;
+ vocab::ReadMultiple(in_vocab, words);
+ DispatchBinaryFilter<Format, vocab::Union>(config, in_lm, vocab::Union(words), out);
+ }
+ return;
+ }
+}
+
+} // namespace
+} // namespace lm
+
+int main(int argc, char *argv[]) {
+ if (argc < 4) {
+ lm::DisplayHelp(argv[0]);
+ return 1;
+ }
+
+ // I used to have boost::program_options, but some users didn't want to compile boost.
+ lm::Config config;
+ boost::optional<lm::FilterMode> mode;
+ for (int i = 1; i < argc - 2; ++i) {
+ const char *str = argv[i];
+ if (!std::strcmp(str, "copy")) {
+ mode = lm::MODE_COPY;
+ } else if (!std::strcmp(str, "single")) {
+ mode = lm::MODE_SINGLE;
+ } else if (!std::strcmp(str, "multiple")) {
+ mode = lm::MODE_MULTIPLE;
+ } else if (!std::strcmp(str, "union")) {
+ mode = lm::MODE_UNION;
+ } else if (!std::strcmp(str, "phrase")) {
+ config.phrase = true;
+ } else if (!std::strcmp(str, "context")) {
+ config.context = true;
+ } else if (!std::strcmp(str, "arpa")) {
+ config.format = lm::FORMAT_ARPA;
+ } else if (!std::strcmp(str, "raw")) {
+ config.format = lm::FORMAT_COUNT;
+#ifndef NTHREAD
+ } else if (!std::strncmp(str, "threads:", 8)) {
+ config.threads = boost::lexical_cast<size_t>(str + 8);
+ if (!config.threads) {
+ std::cerr << "Specify at least one thread." << std::endl;
+ return 1;
+ }
+ } else if (!std::strncmp(str, "batch_size:", 11)) {
+ config.batch_size = boost::lexical_cast<size_t>(str + 11);
+ if (config.batch_size < 5000) {
+ std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl;
+ if (!config.batch_size) return 1;
+ }
+#endif
+ } else {
+ lm::DisplayHelp(argv[0]);
+ return 1;
+ }
+ }
+
+ if (!mode) {
+ lm::DisplayHelp(argv[0]);
+ return 1;
+ }
+ config.mode = *mode;
+
+ if (config.phrase && config.mode != lm::MODE_UNION && mode != lm::MODE_MULTIPLE) {
+ std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl;
+ return 1;
+ }
+
+ bool cmd_is_model = true;
+ const char *cmd_input = argv[argc - 2];
+ if (!strncmp(cmd_input, "vocab:", 6)) {
+ cmd_is_model = false;
+ cmd_input += 6;
+ } else if (!strncmp(cmd_input, "model:", 6)) {
+ cmd_input += 6;
+ } else if (strchr(cmd_input, ':')) {
+ errx(1, "Specify vocab: or model: before the input file name, not \"%s\"", cmd_input);
+ } else {
+ std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl;
+ }
+ std::ifstream cmd_file;
+ std::istream *vocab;
+ if (cmd_is_model) {
+ vocab = &std::cin;
+ } else {
+ cmd_file.open(cmd_input, std::ios::in);
+ if (!cmd_file) {
+ err(2, "Could not open input file %s", cmd_input);
+ }
+ vocab = &cmd_file;
+ }
+
+ util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr);
+
+ if (config.format == lm::FORMAT_ARPA) {
+ lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]);
+ } else if (config.format == lm::FORMAT_COUNT) {
+ lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]);
+ }
+ return 0;
+}
diff --git a/klm/lm/filter/phrase.cc b/klm/lm/filter/phrase.cc
new file mode 100644
index 00000000..1bef2a3f
--- /dev/null
+++ b/klm/lm/filter/phrase.cc
@@ -0,0 +1,281 @@
+#include "lm/filter/phrase.hh"
+
+#include "lm/filter/format.hh"
+
+#include <algorithm>
+#include <functional>
+#include <iostream>
+#include <queue>
+#include <string>
+#include <vector>
+
+#include <ctype.h>
+
+namespace lm {
+namespace phrase {
+
+unsigned int ReadMultiple(std::istream &in, Substrings &out) {
+ bool sentence_content = false;
+ unsigned int sentence_id = 0;
+ std::vector<Hash> phrase;
+ std::string word;
+ while (in) {
+ char c;
+ // Gather a word.
+ while (!isspace(c = in.get()) && in) word += c;
+ // Treat EOF like a newline.
+ if (!in) c = '\n';
+ // Add the word to the phrase.
+ if (!word.empty()) {
+ phrase.push_back(util::MurmurHashNative(word.data(), word.size()));
+ word.clear();
+ }
+ if (c == ' ') continue;
+ // It's more than just a space. Close out the phrase.
+ if (!phrase.empty()) {
+ sentence_content = true;
+ out.AddPhrase(sentence_id, phrase.begin(), phrase.end());
+ phrase.clear();
+ }
+ if (c == '\t' || c == '\v') continue;
+ // It's more than a space or tab: a newline.
+ if (sentence_content) {
+ ++sentence_id;
+ sentence_content = false;
+ }
+ }
+ if (!in.eof()) in.exceptions(std::istream::failbit | std::istream::badbit);
+ return sentence_id + sentence_content;
+}
+
+namespace detail { const StringPiece kEndSentence("</s>"); }
+
+namespace {
+
+typedef unsigned int Sentence;
+typedef std::vector<Sentence> Sentences;
+
+class Vertex;
+
+class Arc {
+ public:
+ Arc() {}
+
+ // For arcs from one vertex to another.
+ void SetPhrase(Vertex &from, Vertex &to, const Sentences &intersect) {
+ Set(to, intersect);
+ from_ = &from;
+ }
+
+ /* For arcs from before the n-gram begins to somewhere in the n-gram (right
+ * aligned). These have no from_ vertex; it implictly matches every
+ * sentence. This also handles when the n-gram is a substring of a phrase.
+ */
+ void SetRight(Vertex &to, const Sentences &complete) {
+ Set(to, complete);
+ from_ = NULL;
+ }
+
+ Sentence Current() const {
+ return *current_;
+ }
+
+ bool Empty() const {
+ return current_ == last_;
+ }
+
+ /* When this function returns:
+ * If Empty() then there's nothing left from this intersection.
+ *
+ * If Current() == to then to is part of the intersection.
+ *
+ * Otherwise, Current() > to. In this case, to is not part of the
+ * intersection and neither is anything < Current(). To determine if
+ * any value >= Current() is in the intersection, call LowerBound again
+ * with the value.
+ */
+ void LowerBound(const Sentence to);
+
+ private:
+ void Set(Vertex &to, const Sentences &sentences);
+
+ const Sentence *current_;
+ const Sentence *last_;
+ Vertex *from_;
+};
+
+struct ArcGreater : public std::binary_function<const Arc *, const Arc *, bool> {
+ bool operator()(const Arc *first, const Arc *second) const {
+ return first->Current() > second->Current();
+ }
+};
+
+class Vertex {
+ public:
+ Vertex() : current_(0) {}
+
+ Sentence Current() const {
+ return current_;
+ }
+
+ bool Empty() const {
+ return incoming_.empty();
+ }
+
+ void LowerBound(const Sentence to);
+
+ private:
+ friend class Arc;
+
+ void AddIncoming(Arc *arc) {
+ if (!arc->Empty()) incoming_.push(arc);
+ }
+
+ unsigned int current_;
+ std::priority_queue<Arc*, std::vector<Arc*>, ArcGreater> incoming_;
+};
+
+void Arc::LowerBound(const Sentence to) {
+ current_ = std::lower_bound(current_, last_, to);
+ // If *current_ > to, don't advance from_. The intervening values of
+ // from_ may be useful for another one of its outgoing arcs.
+ if (!from_ || Empty() || (Current() > to)) return;
+ assert(Current() == to);
+ from_->LowerBound(to);
+ if (from_->Empty()) {
+ current_ = last_;
+ return;
+ }
+ assert(from_->Current() >= to);
+ if (from_->Current() > to) {
+ current_ = std::lower_bound(current_ + 1, last_, from_->Current());
+ }
+}
+
+void Arc::Set(Vertex &to, const Sentences &sentences) {
+ current_ = &*sentences.begin();
+ last_ = &*sentences.end();
+ to.AddIncoming(this);
+}
+
+void Vertex::LowerBound(const Sentence to) {
+ if (Empty()) return;
+ // Union lower bound.
+ while (true) {
+ Arc *top = incoming_.top();
+ if (top->Current() > to) {
+ current_ = top->Current();
+ return;
+ }
+ // If top->Current() == to, we still need to verify that's an actual
+ // element and not just a bound.
+ incoming_.pop();
+ top->LowerBound(to);
+ if (!top->Empty()) {
+ incoming_.push(top);
+ if (top->Current() == to) {
+ current_ = to;
+ return;
+ }
+ } else if (Empty()) {
+ return;
+ }
+ }
+}
+
+void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, Vertex *const vertices, Arc *free_arc) {
+ assert(!hashes.empty());
+
+ const Hash *const first_word = &*hashes.begin();
+ const Hash *const last_word = &*hashes.end() - 1;
+
+ Hash hash = 0;
+ const Sentences *found;
+ // Phrases starting at or before the first word in the n-gram.
+ {
+ Vertex *vertex = vertices;
+ for (const Hash *word = first_word; ; ++word, ++vertex) {
+ hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *word);
+ // Now hash is [hashes.begin(), word].
+ if (word == last_word) {
+ if (phrase.FindSubstring(hash, found))
+ (free_arc++)->SetRight(*vertex, *found);
+ break;
+ }
+ if (!phrase.FindRight(hash, found)) break;
+ (free_arc++)->SetRight(*vertex, *found);
+ }
+ }
+
+ // Phrases starting at the second or later word in the n-gram.
+ Vertex *vertex_from = vertices;
+ for (const Hash *word_from = first_word + 1; word_from != &*hashes.end(); ++word_from, ++vertex_from) {
+ hash = 0;
+ Vertex *vertex_to = vertex_from + 1;
+ for (const Hash *word_to = word_from; ; ++word_to, ++vertex_to) {
+ // Notice that word_to and vertex_to have the same index.
+ hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *word_to);
+ // Now hash covers [word_from, word_to].
+ if (word_to == last_word) {
+ if (phrase.FindLeft(hash, found))
+ (free_arc++)->SetPhrase(*vertex_from, *vertex_to, *found);
+ break;
+ }
+ if (!phrase.FindPhrase(hash, found)) break;
+ (free_arc++)->SetPhrase(*vertex_from, *vertex_to, *found);
+ }
+ }
+}
+
+} // namespace
+
+namespace detail {
+
+} // namespace detail
+
+bool Union::Evaluate() {
+ assert(!hashes_.empty());
+ // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable.
+ Vertex vertices[hashes_.size()];
+ // One for every substring.
+ Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2];
+ BuildGraph(substrings_, hashes_, vertices, arcs);
+ Vertex &last_vertex = vertices[hashes_.size() - 1];
+
+ unsigned int lower = 0;
+ while (true) {
+ last_vertex.LowerBound(lower);
+ if (last_vertex.Empty()) return false;
+ if (last_vertex.Current() == lower) return true;
+ lower = last_vertex.Current();
+ }
+}
+
+template <class Output> void Multiple::Evaluate(const StringPiece &line, Output &output) {
+ assert(!hashes_.empty());
+ // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable.
+ Vertex vertices[hashes_.size()];
+ // One for every substring.
+ Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2];
+ BuildGraph(substrings_, hashes_, vertices, arcs);
+ Vertex &last_vertex = vertices[hashes_.size() - 1];
+
+ unsigned int lower = 0;
+ while (true) {
+ last_vertex.LowerBound(lower);
+ if (last_vertex.Empty()) return;
+ if (last_vertex.Current() == lower) {
+ output.SingleAddNGram(lower, line);
+ ++lower;
+ } else {
+ lower = last_vertex.Current();
+ }
+ }
+}
+
+template void Multiple::Evaluate<CountFormat::Multiple>(const StringPiece &line, CountFormat::Multiple &output);
+template void Multiple::Evaluate<ARPAFormat::Multiple>(const StringPiece &line, ARPAFormat::Multiple &output);
+template void Multiple::Evaluate<MultipleOutputBuffer>(const StringPiece &line, MultipleOutputBuffer &output);
+
+} // namespace phrase
+} // namespace lm
diff --git a/klm/lm/filter/phrase.hh b/klm/lm/filter/phrase.hh
new file mode 100644
index 00000000..07479dea
--- /dev/null
+++ b/klm/lm/filter/phrase.hh
@@ -0,0 +1,153 @@
+#ifndef LM_FILTER_PHRASE_H__
+#define LM_FILTER_PHRASE_H__
+
+#include "util/murmur_hash.hh"
+#include "util/string_piece.hh"
+#include "util/tokenize_piece.hh"
+
+#include <boost/unordered_map.hpp>
+
+#include <iosfwd>
+#include <vector>
+
+#define LM_FILTER_PHRASE_METHOD(caps, lower) \
+bool Find##caps(Hash key, const std::vector<unsigned int> *&out) const {\
+ Table::const_iterator i(table_.find(key));\
+ if (i==table_.end()) return false; \
+ out = &i->second.lower; \
+ return true; \
+}
+
+namespace lm {
+namespace phrase {
+
+typedef uint64_t Hash;
+
+class Substrings {
+ private:
+ /* This is the value in a hash table where the key is a string. It indicates
+ * four sets of sentences:
+ * substring is sentences with a phrase containing the key as a substring.
+ * left is sentencess with a phrase that begins with the key (left aligned).
+ * right is sentences with a phrase that ends with the key (right aligned).
+ * phrase is sentences where the key is a phrase.
+ * Each set is encoded as a vector of sentence ids in increasing order.
+ */
+ struct SentenceRelation {
+ std::vector<unsigned int> substring, left, right, phrase;
+ };
+ /* Most of the CPU is hash table lookups, so let's not complicate it with
+ * vector equality comparisons. If a collision happens, the SentenceRelation
+ * structure will contain the union of sentence ids over the colliding strings.
+ * In that case, the filter will be slightly more permissive.
+ * The key here is the same as boost's hash of std::vector<std::string>.
+ */
+ typedef boost::unordered_map<Hash, SentenceRelation> Table;
+
+ public:
+ Substrings() {}
+
+ /* If the string isn't a substring of any phrase, return NULL. Otherwise,
+ * return a pointer to std::vector<unsigned int> listing sentences with
+ * matching phrases. This set may be empty for Left, Right, or Phrase.
+ * Example: const std::vector<unsigned int> *FindSubstring(Hash key)
+ */
+ LM_FILTER_PHRASE_METHOD(Substring, substring)
+ LM_FILTER_PHRASE_METHOD(Left, left)
+ LM_FILTER_PHRASE_METHOD(Right, right)
+ LM_FILTER_PHRASE_METHOD(Phrase, phrase)
+
+ // sentence_id must be non-decreasing. Iterators are over words in the phrase.
+ template <class Iterator> void AddPhrase(unsigned int sentence_id, const Iterator &begin, const Iterator &end) {
+ // Iterate over all substrings.
+ for (Iterator start = begin; start != end; ++start) {
+ Hash hash = 0;
+ SentenceRelation *relation;
+ for (Iterator finish = start; finish != end; ++finish) {
+ hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *finish);
+ // Now hash is of [start, finish].
+ relation = &table_[hash];
+ AppendSentence(relation->substring, sentence_id);
+ if (start == begin) AppendSentence(relation->left, sentence_id);
+ }
+ AppendSentence(relation->right, sentence_id);
+ if (start == begin) AppendSentence(relation->phrase, sentence_id);
+ }
+ }
+
+ private:
+ void AppendSentence(std::vector<unsigned int> &vec, unsigned int sentence_id) {
+ if (vec.empty() || vec.back() != sentence_id) vec.push_back(sentence_id);
+ }
+
+ Table table_;
+};
+
+// Read a file with one sentence per line containing tab-delimited phrases of
+// space-separated words.
+unsigned int ReadMultiple(std::istream &in, Substrings &out);
+
+namespace detail {
+extern const StringPiece kEndSentence;
+
+template <class Iterator> void MakeHashes(Iterator i, const Iterator &end, std::vector<Hash> &hashes) {
+ hashes.clear();
+ if (i == end) return;
+ // TODO: check strict phrase boundaries after <s> and before </s>. For now, just skip tags.
+ if ((i->data()[0] == '<') && (i->data()[i->size() - 1] == '>')) {
+ ++i;
+ }
+ for (; i != end && (*i != kEndSentence); ++i) {
+ hashes.push_back(util::MurmurHashNative(i->data(), i->size()));
+ }
+}
+
+} // namespace detail
+
+class Union {
+ public:
+ explicit Union(const Substrings &substrings) : substrings_(substrings) {}
+
+ template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
+ detail::MakeHashes(begin, end, hashes_);
+ return hashes_.empty() || Evaluate();
+ }
+
+ private:
+ bool Evaluate();
+
+ std::vector<Hash> hashes_;
+
+ const Substrings &substrings_;
+};
+
+class Multiple {
+ public:
+ explicit Multiple(const Substrings &substrings) : substrings_(substrings) {}
+
+ template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
+ detail::MakeHashes(begin, end, hashes_);
+ if (hashes_.empty()) {
+ output.AddNGram(line);
+ return;
+ }
+ Evaluate(line, output);
+ }
+
+ template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
+ AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output);
+ }
+
+ void Flush() const {}
+
+ private:
+ template <class Output> void Evaluate(const StringPiece &line, Output &output);
+
+ std::vector<Hash> hashes_;
+
+ const Substrings &substrings_;
+};
+
+} // namespace phrase
+} // namespace lm
+#endif // LM_FILTER_PHRASE_H__
diff --git a/klm/lm/filter/thread.hh b/klm/lm/filter/thread.hh
new file mode 100644
index 00000000..e785b263
--- /dev/null
+++ b/klm/lm/filter/thread.hh
@@ -0,0 +1,167 @@
+#ifndef LM_FILTER_THREAD_H__
+#define LM_FILTER_THREAD_H__
+
+#include "util/thread_pool.hh"
+
+#include <boost/utility/in_place_factory.hpp>
+
+#include <deque>
+#include <stack>
+
+namespace lm {
+
+template <class OutputBuffer> class ThreadBatch {
+ public:
+ ThreadBatch() {}
+
+ void Reserve(size_t size) {
+ input_.Reserve(size);
+ output_.Reserve(size);
+ }
+
+ // File reading thread.
+ InputBuffer &Fill(uint64_t sequence) {
+ sequence_ = sequence;
+ // Why wait until now to clear instead of after output? free in the same
+ // thread as allocated.
+ input_.Clear();
+ return input_;
+ }
+
+ // Filter worker thread.
+ template <class Filter> void CallFilter(Filter &filter) {
+ input_.CallFilter(filter, output_);
+ }
+
+ uint64_t Sequence() const { return sequence_; }
+
+ // File writing thread.
+ template <class RealOutput> void Flush(RealOutput &output) {
+ output_.Flush(output);
+ }
+
+ private:
+ InputBuffer input_;
+ OutputBuffer output_;
+
+ uint64_t sequence_;
+};
+
+template <class Batch, class Filter> class FilterWorker {
+ public:
+ typedef Batch *Request;
+
+ FilterWorker(const Filter &filter, util::PCQueue<Request> &done) : filter_(filter), done_(done) {}
+
+ void operator()(Request request) {
+ request->CallFilter(filter_);
+ done_.Produce(request);
+ }
+
+ private:
+ Filter filter_;
+
+ util::PCQueue<Request> &done_;
+};
+
+// There should only be one OutputWorker.
+template <class Batch, class Output> class OutputWorker {
+ public:
+ typedef Batch *Request;
+
+ OutputWorker(Output &output, util::PCQueue<Request> &done) : output_(output), done_(done), base_sequence_(0) {}
+
+ void operator()(Request request) {
+ assert(request->Sequence() >= base_sequence_);
+ // Assemble the output in order.
+ uint64_t pos = request->Sequence() - base_sequence_;
+ if (pos >= ordering_.size()) {
+ ordering_.resize(pos + 1, NULL);
+ }
+ ordering_[pos] = request;
+ while (!ordering_.empty() && ordering_.front()) {
+ ordering_.front()->Flush(output_);
+ done_.Produce(ordering_.front());
+ ordering_.pop_front();
+ ++base_sequence_;
+ }
+ }
+
+ private:
+ Output &output_;
+
+ util::PCQueue<Request> &done_;
+
+ std::deque<Request> ordering_;
+
+ uint64_t base_sequence_;
+};
+
+template <class Filter, class OutputBuffer, class RealOutput> class Controller : boost::noncopyable {
+ private:
+ typedef ThreadBatch<OutputBuffer> Batch;
+
+ public:
+ Controller(size_t batch_size, size_t queue, size_t workers, const Filter &filter, RealOutput &output)
+ : batch_size_(batch_size), queue_size_(queue),
+ batches_(queue),
+ to_read_(queue),
+ output_(queue, 1, boost::in_place(boost::ref(output), boost::ref(to_read_)), NULL),
+ filter_(queue, workers, boost::in_place(boost::ref(filter), boost::ref(output_.In())), NULL),
+ sequence_(0) {
+ for (size_t i = 0; i < queue; ++i) {
+ batches_[i].Reserve(batch_size);
+ local_read_.push(&batches_[i]);
+ }
+ NewInput();
+ }
+
+ void AddNGram(const StringPiece &ngram, const StringPiece &line, RealOutput &output) {
+ input_->AddNGram(ngram, line, output);
+ if (input_->Size() == batch_size_) {
+ FlushInput();
+ NewInput();
+ }
+ }
+
+ void Flush() {
+ FlushInput();
+ while (local_read_.size() < queue_size_) {
+ MoveRead();
+ }
+ NewInput();
+ }
+
+ private:
+ void FlushInput() {
+ if (input_->Empty()) return;
+ filter_.Produce(local_read_.top());
+ local_read_.pop();
+ if (local_read_.empty()) MoveRead();
+ }
+
+ void NewInput() {
+ input_ = &local_read_.top()->Fill(sequence_++);
+ }
+
+ void MoveRead() {
+ local_read_.push(to_read_.Consume());
+ }
+
+ const size_t batch_size_;
+ const size_t queue_size_;
+
+ std::vector<Batch> batches_;
+
+ util::PCQueue<Batch*> to_read_;
+ std::stack<Batch*> local_read_;
+ util::ThreadPool<OutputWorker<Batch, RealOutput> > output_;
+ util::ThreadPool<FilterWorker<Batch, Filter> > filter_;
+
+ uint64_t sequence_;
+ InputBuffer *input_;
+};
+
+} // namespace lm
+
+#endif // LM_FILTER_THREAD_H__
diff --git a/klm/lm/filter/vocab.cc b/klm/lm/filter/vocab.cc
new file mode 100644
index 00000000..7ee4e84b
--- /dev/null
+++ b/klm/lm/filter/vocab.cc
@@ -0,0 +1,54 @@
+#include "lm/filter/vocab.hh"
+
+#include <istream>
+#include <iostream>
+
+#include <ctype.h>
+#include <err.h>
+
+namespace lm {
+namespace vocab {
+
+void ReadSingle(std::istream &in, boost::unordered_set<std::string> &out) {
+ in.exceptions(std::istream::badbit);
+ std::string word;
+ while (in >> word) {
+ out.insert(word);
+ }
+}
+
+namespace {
+bool IsLineEnd(std::istream &in) {
+ int got;
+ do {
+ got = in.get();
+ if (!in) return true;
+ if (got == '\n') return true;
+ } while (isspace(got));
+ in.unget();
+ return false;
+}
+}// namespace
+
+// Read space separated words in enter separated lines. These lines can be
+// very long, so don't read an entire line at a time.
+unsigned int ReadMultiple(std::istream &in, boost::unordered_map<std::string, std::vector<unsigned int> > &out) {
+ in.exceptions(std::istream::badbit);
+ unsigned int sentence = 0;
+ bool used_id = false;
+ std::string word;
+ while (in >> word) {
+ used_id = true;
+ std::vector<unsigned int> &posting = out[word];
+ if (posting.empty() || (posting.back() != sentence))
+ posting.push_back(sentence);
+ if (IsLineEnd(in)) {
+ ++sentence;
+ used_id = false;
+ }
+ }
+ return sentence + used_id;
+}
+
+} // namespace vocab
+} // namespace lm
diff --git a/klm/lm/filter/vocab.hh b/klm/lm/filter/vocab.hh
new file mode 100644
index 00000000..e2b6adff
--- /dev/null
+++ b/klm/lm/filter/vocab.hh
@@ -0,0 +1,132 @@
+#ifndef LM_FILTER_VOCAB_H__
+#define LM_FILTER_VOCAB_H__
+
+// Vocabulary-based filters for language models.
+
+#include "util/multi_intersection.hh"
+#include "util/string_piece.hh"
+#include "util/tokenize_piece.hh"
+
+#include <boost/noncopyable.hpp>
+#include <boost/range/iterator_range.hpp>
+#include <boost/unordered/unordered_map.hpp>
+#include <boost/unordered/unordered_set.hpp>
+
+#include <string>
+#include <vector>
+
+namespace lm {
+namespace vocab {
+
+void ReadSingle(std::istream &in, boost::unordered_set<std::string> &out);
+
+// Read one sentence vocabulary per line. Return the number of sentences.
+unsigned int ReadMultiple(std::istream &in, boost::unordered_map<std::string, std::vector<unsigned int> > &out);
+
+/* Is this a special tag like <s> or <UNK>? This actually includes anything
+ * surrounded with < and >, which most tokenizers separate for real words, so
+ * this should not catch real words as it looks at a single token.
+ */
+inline bool IsTag(const StringPiece &value) {
+ // The parser should never give an empty string.
+ assert(!value.empty());
+ return (value.data()[0] == '<' && value.data()[value.size() - 1] == '>');
+}
+
+class Single {
+ public:
+ typedef boost::unordered_set<std::string> Words;
+
+ explicit Single(const Words &vocab) : vocab_(vocab) {}
+
+ template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
+ for (Iterator i = begin; i != end; ++i) {
+ if (IsTag(*i)) continue;
+ if (FindStringPiece(vocab_, *i) == vocab_.end()) return false;
+ }
+ return true;
+ }
+
+ private:
+ const Words &vocab_;
+};
+
+class Union {
+ public:
+ typedef boost::unordered_map<std::string, std::vector<unsigned int> > Words;
+
+ explicit Union(const Words &vocabs) : vocabs_(vocabs) {}
+
+ template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
+ sets_.clear();
+
+ for (Iterator i(begin); i != end; ++i) {
+ if (IsTag(*i)) continue;
+ Words::const_iterator found(FindStringPiece(vocabs_, *i));
+ if (vocabs_.end() == found) return false;
+ sets_.push_back(boost::iterator_range<const unsigned int*>(&*found->second.begin(), &*found->second.end()));
+ }
+ return (sets_.empty() || util::FirstIntersection(sets_));
+ }
+
+ private:
+ const Words &vocabs_;
+
+ std::vector<boost::iterator_range<const unsigned int*> > sets_;
+};
+
+class Multiple {
+ public:
+ typedef boost::unordered_map<std::string, std::vector<unsigned int> > Words;
+
+ Multiple(const Words &vocabs) : vocabs_(vocabs) {}
+
+ private:
+ // Callback from AllIntersection that does AddNGram.
+ template <class Output> class Callback {
+ public:
+ Callback(Output &out, const StringPiece &line) : out_(out), line_(line) {}
+
+ void operator()(unsigned int index) {
+ out_.SingleAddNGram(index, line_);
+ }
+
+ private:
+ Output &out_;
+ const StringPiece &line_;
+ };
+
+ public:
+ template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
+ sets_.clear();
+ for (Iterator i(begin); i != end; ++i) {
+ if (IsTag(*i)) continue;
+ Words::const_iterator found(FindStringPiece(vocabs_, *i));
+ if (vocabs_.end() == found) return;
+ sets_.push_back(boost::iterator_range<const unsigned int*>(&*found->second.begin(), &*found->second.end()));
+ }
+ if (sets_.empty()) {
+ output.AddNGram(line);
+ return;
+ }
+
+ Callback<Output> cb(output, line);
+ util::AllIntersection(sets_, cb);
+ }
+
+ template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
+ AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output);
+ }
+
+ void Flush() const {}
+
+ private:
+ const Words &vocabs_;
+
+ std::vector<boost::iterator_range<const unsigned int*> > sets_;
+};
+
+} // namespace vocab
+} // namespace lm
+
+#endif // LM_FILTER_VOCAB_H__
diff --git a/klm/lm/filter/wrapper.hh b/klm/lm/filter/wrapper.hh
new file mode 100644
index 00000000..90b07a08
--- /dev/null
+++ b/klm/lm/filter/wrapper.hh
@@ -0,0 +1,58 @@
+#ifndef LM_FILTER_WRAPPER_H__
+#define LM_FILTER_WRAPPER_H__
+
+#include "util/string_piece.hh"
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+namespace lm {
+
+// Provide a single-output filter with the same interface as a
+// multiple-output filter so clients code against one interface.
+template <class Binary> class BinaryFilter {
+ public:
+ // Binary modes are just references (and a set) and it makes the API cleaner to copy them.
+ explicit BinaryFilter(Binary binary) : binary_(binary) {}
+
+ template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
+ if (binary_.PassNGram(begin, end))
+ output.AddNGram(line);
+ }
+
+ template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
+ AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output);
+ }
+
+ void Flush() const {}
+
+ private:
+ Binary binary_;
+};
+
+// Wrap another filter to pay attention only to context words
+template <class FilterT> class ContextFilter {
+ public:
+ typedef FilterT Filter;
+
+ explicit ContextFilter(Filter &backend) : backend_(backend) {}
+
+ template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
+ pieces_.clear();
+ // TODO: this copy could be avoided by a lookahead iterator.
+ std::copy(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), std::back_insert_iterator<std::vector<StringPiece> >(pieces_));
+ backend_.AddNGram(pieces_.begin(), pieces_.end() - !pieces_.empty(), line, output);
+ }
+
+ void Flush() const {}
+
+ private:
+ std::vector<StringPiece> pieces_;
+
+ Filter backend_;
+};
+
+} // namespace lm
+
+#endif // LM_FILTER_WRAPPER_H__