diff options
Diffstat (limited to 'klm/lm/filter')
-rw-r--r-- | klm/lm/filter/arpa_io.cc | 108 | ||||
-rw-r--r-- | klm/lm/filter/arpa_io.hh | 115 | ||||
-rw-r--r-- | klm/lm/filter/count_io.hh | 91 | ||||
-rw-r--r-- | klm/lm/filter/filter_main.cc | 248 | ||||
-rw-r--r-- | klm/lm/filter/format.hh | 250 | ||||
-rw-r--r-- | klm/lm/filter/phrase.cc | 281 | ||||
-rw-r--r-- | klm/lm/filter/phrase.hh | 154 | ||||
-rw-r--r-- | klm/lm/filter/thread.hh | 167 | ||||
-rw-r--r-- | klm/lm/filter/vocab.cc | 54 | ||||
-rw-r--r-- | klm/lm/filter/vocab.hh | 133 | ||||
-rw-r--r-- | klm/lm/filter/wrapper.hh | 58 |
11 files changed, 1659 insertions, 0 deletions
diff --git a/klm/lm/filter/arpa_io.cc b/klm/lm/filter/arpa_io.cc new file mode 100644 index 00000000..f8568ac4 --- /dev/null +++ b/klm/lm/filter/arpa_io.cc @@ -0,0 +1,108 @@ +#include "lm/filter/arpa_io.hh" +#include "util/file_piece.hh" + +#include <iostream> +#include <ostream> +#include <string> +#include <vector> + +#include <ctype.h> +#include <errno.h> +#include <string.h> + +namespace lm { + +ARPAInputException::ARPAInputException(const StringPiece &message) throw() { + *this << message; +} + +ARPAInputException::ARPAInputException(const StringPiece &message, const StringPiece &line) throw() { + *this << message << " in line " << line; +} + +ARPAInputException::~ARPAInputException() throw() {} + +ARPAOutputException::ARPAOutputException(const char *message, const std::string &file_name) throw() { + *this << message << " in file " << file_name; +} + +ARPAOutputException::~ARPAOutputException() throw() {} + +// Seeking is the responsibility of the caller. +void WriteCounts(std::ostream &out, const std::vector<uint64_t> &number) { + out << "\n\\data\\\n"; + for (unsigned int i = 0; i < number.size(); ++i) { + out << "ngram " << i+1 << "=" << number[i] << '\n'; + } + out << '\n'; +} + +size_t SizeNeededForCounts(const std::vector<uint64_t> &number) { + std::ostringstream buf; + WriteCounts(buf, number); + return buf.tellp(); +} + +bool IsEntirelyWhiteSpace(const StringPiece &line) { + for (size_t i = 0; i < static_cast<size_t>(line.size()); ++i) { + if (!isspace(line.data()[i])) return false; + } + return true; +} + +ARPAOutput::ARPAOutput(const char *name, size_t buffer_size) : file_name_(name), buffer_(new char[buffer_size]) { + try { + file_.exceptions(std::ostream::eofbit | std::ostream::failbit | std::ostream::badbit); + if (!file_.rdbuf()->pubsetbuf(buffer_.get(), buffer_size)) { + std::cerr << "Warning: could not enlarge buffer for " << name << std::endl; + buffer_.reset(); + } + file_.open(name, std::ios::out | std::ios::binary); + } catch (const std::ios_base::failure &f) { + throw ARPAOutputException("Opening", file_name_); + } +} + +void ARPAOutput::ReserveForCounts(std::streampos reserve) { + try { + for (std::streampos i = 0; i < reserve; i += std::streampos(1)) { + file_ << '\n'; + } + } catch (const std::ios_base::failure &f) { + throw ARPAOutputException("Writing blanks to reserve space for counts to ", file_name_); + } +} + +void ARPAOutput::BeginLength(unsigned int length) { + fast_counter_ = 0; + try { + file_ << '\\' << length << "-grams:" << '\n'; + } catch (const std::ios_base::failure &f) { + throw ARPAOutputException("Writing n-gram header to ", file_name_); + } +} + +void ARPAOutput::EndLength(unsigned int length) { + try { + file_ << '\n'; + } catch (const std::ios_base::failure &f) { + throw ARPAOutputException("Writing blank at end of count list to ", file_name_); + } + if (length > counts_.size()) { + counts_.resize(length); + } + counts_[length - 1] = fast_counter_; +} + +void ARPAOutput::Finish() { + try { + file_ << "\\end\\\n"; + file_.seekp(0); + WriteCounts(file_, counts_); + file_ << std::flush; + } catch (const std::ios_base::failure &f) { + throw ARPAOutputException("Finishing including writing counts at beginning to ", file_name_); + } +} + +} // namespace lm diff --git a/klm/lm/filter/arpa_io.hh b/klm/lm/filter/arpa_io.hh new file mode 100644 index 00000000..5b31620b --- /dev/null +++ b/klm/lm/filter/arpa_io.hh @@ -0,0 +1,115 @@ +#ifndef LM_FILTER_ARPA_IO__ +#define LM_FILTER_ARPA_IO__ +/* Input and output for ARPA format language model files. + */ +#include "lm/read_arpa.hh" +#include "util/exception.hh" +#include "util/string_piece.hh" +#include "util/tokenize_piece.hh" + +#include <boost/noncopyable.hpp> +#include <boost/scoped_array.hpp> + +#include <fstream> +#include <string> +#include <vector> + +#include <err.h> +#include <string.h> +#include <stdint.h> + +namespace util { class FilePiece; } + +namespace lm { + +class ARPAInputException : public util::Exception { + public: + explicit ARPAInputException(const StringPiece &message) throw(); + explicit ARPAInputException(const StringPiece &message, const StringPiece &line) throw(); + virtual ~ARPAInputException() throw(); +}; + +class ARPAOutputException : public util::ErrnoException { + public: + ARPAOutputException(const char *prefix, const std::string &file_name) throw(); + virtual ~ARPAOutputException() throw(); + + const std::string &File() const throw() { return file_name_; } + + private: + const std::string file_name_; +}; + +// Handling for the counts of n-grams at the beginning of ARPA files. +size_t SizeNeededForCounts(const std::vector<uint64_t> &number); + +/* Writes an ARPA file. This has to be seekable so the counts can be written + * at the end. Hence, I just have it own a std::fstream instead of accepting + * a separately held std::ostream. TODO: use the fast one from estimation. + */ +class ARPAOutput : boost::noncopyable { + public: + explicit ARPAOutput(const char *name, size_t buffer_size = 65536); + + void ReserveForCounts(std::streampos reserve); + + void BeginLength(unsigned int length); + + void AddNGram(const StringPiece &line) { + try { + file_ << line << '\n'; + } catch (const std::ios_base::failure &f) { + throw ARPAOutputException("Writing an n-gram", file_name_); + } + ++fast_counter_; + } + + void AddNGram(const StringPiece &ngram, const StringPiece &line) { + AddNGram(line); + } + + template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) { + AddNGram(line); + } + + void EndLength(unsigned int length); + + void Finish(); + + private: + const std::string file_name_; + boost::scoped_array<char> buffer_; + std::fstream file_; + size_t fast_counter_; + std::vector<uint64_t> counts_; +}; + + +template <class Output> void ReadNGrams(util::FilePiece &in, unsigned int length, uint64_t number, Output &out) { + ReadNGramHeader(in, length); + out.BeginLength(length); + for (uint64_t i = 0; i < number; ++i) { + StringPiece line = in.ReadLine(); + util::TokenIter<util::SingleCharacter> tabber(line, '\t'); + if (!tabber) throw ARPAInputException("blank line", line); + if (!++tabber) throw ARPAInputException("no tab", line); + + out.AddNGram(*tabber, line); + } + out.EndLength(length); +} + +template <class Output> void ReadARPA(util::FilePiece &in_lm, Output &out) { + std::vector<uint64_t> number; + ReadARPACounts(in_lm, number); + out.ReserveForCounts(SizeNeededForCounts(number)); + for (unsigned int i = 0; i < number.size(); ++i) { + ReadNGrams(in_lm, i + 1, number[i], out); + } + ReadEnd(in_lm); + out.Finish(); +} + +} // namespace lm + +#endif // LM_FILTER_ARPA_IO__ diff --git a/klm/lm/filter/count_io.hh b/klm/lm/filter/count_io.hh new file mode 100644 index 00000000..97c0fa25 --- /dev/null +++ b/klm/lm/filter/count_io.hh @@ -0,0 +1,91 @@ +#ifndef LM_FILTER_COUNT_IO__ +#define LM_FILTER_COUNT_IO__ + +#include <fstream> +#include <iostream> +#include <string> + +#include <err.h> + +#include "util/file_piece.hh" + +namespace lm { + +class CountOutput : boost::noncopyable { + public: + explicit CountOutput(const char *name) : file_(name, std::ios::out) {} + + void AddNGram(const StringPiece &line) { + if (!(file_ << line << '\n')) { + err(3, "Writing counts file failed"); + } + } + + template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) { + AddNGram(line); + } + + void AddNGram(const StringPiece &ngram, const StringPiece &line) { + AddNGram(line); + } + + private: + std::fstream file_; +}; + +class CountBatch { + public: + explicit CountBatch(std::streamsize initial_read) + : initial_read_(initial_read) { + buffer_.reserve(initial_read); + } + + void Read(std::istream &in) { + buffer_.resize(initial_read_); + in.read(&*buffer_.begin(), initial_read_); + buffer_.resize(in.gcount()); + char got; + while (in.get(got) && got != '\n') + buffer_.push_back(got); + } + + template <class Output> void Send(Output &out) { + for (util::TokenIter<util::SingleCharacter> line(StringPiece(&*buffer_.begin(), buffer_.size()), '\n'); line; ++line) { + util::TokenIter<util::SingleCharacter> tabber(*line, '\t'); + if (!tabber) { + std::cerr << "Warning: empty n-gram count line being removed\n"; + continue; + } + util::TokenIter<util::SingleCharacter, true> words(*tabber, ' '); + if (!words) { + std::cerr << "Line has a tab but no words.\n"; + continue; + } + out.AddNGram(words, util::TokenIter<util::SingleCharacter, true>::end(), *line); + } + } + + private: + std::streamsize initial_read_; + + // This could have been a std::string but that's less happy with raw writes. + std::vector<char> buffer_; +}; + +template <class Output> void ReadCount(util::FilePiece &in_file, Output &out) { + try { + while (true) { + StringPiece line = in_file.ReadLine(); + util::TokenIter<util::SingleCharacter> tabber(line, '\t'); + if (!tabber) { + std::cerr << "Warning: empty n-gram count line being removed\n"; + continue; + } + out.AddNGram(*tabber, line); + } + } catch (const util::EndOfFileException &e) {} +} + +} // namespace lm + +#endif // LM_FILTER_COUNT_IO__ diff --git a/klm/lm/filter/filter_main.cc b/klm/lm/filter/filter_main.cc new file mode 100644 index 00000000..1a4ba84f --- /dev/null +++ b/klm/lm/filter/filter_main.cc @@ -0,0 +1,248 @@ +#include "lm/filter/arpa_io.hh" +#include "lm/filter/format.hh" +#include "lm/filter/phrase.hh" +#ifndef NTHREAD +#include "lm/filter/thread.hh" +#endif +#include "lm/filter/vocab.hh" +#include "lm/filter/wrapper.hh" +#include "util/file_piece.hh" + +#include <boost/ptr_container/ptr_vector.hpp> + +#include <cstring> +#include <fstream> +#include <iostream> +#include <memory> + +namespace lm { +namespace { + +void DisplayHelp(const char *name) { + std::cerr + << "Usage: " << name << " mode [context] [phrase] [raw|arpa] [threads:m] [batch_size:m] (vocab|model):input_file output_file\n\n" + "copy mode just copies, but makes the format nicer for e.g. irstlm's broken\n" + " parser.\n" + "single mode treats the entire input as a single sentence.\n" + "multiple mode filters to multiple sentences in parallel. Each sentence is on\n" + " a separate line. A separate file is created for each file by appending the\n" + " 0-indexed line number to the output file name.\n" + "union mode produces one filtered model that is the union of models created by\n" + " multiple mode.\n\n" + "context means only the context (all but last word) has to pass the filter, but\n" + " the entire n-gram is output.\n\n" + "phrase means that the vocabulary is actually tab-delimited phrases and that the\n" + " phrases can generate the n-gram when assembled in arbitrary order and\n" + " clipped. Currently works with multiple or union mode.\n\n" + "The file format is set by [raw|arpa] with default arpa:\n" + "raw means space-separated tokens, optionally followed by a tab and arbitrary\n" + " text. This is useful for ngram count files.\n" + "arpa means the ARPA file format for n-gram language models.\n\n" +#ifndef NTHREAD + "threads:m sets m threads (default: conccurrency detected by boost)\n" + "batch_size:m sets the batch size for threading. Expect memory usage from this\n" + " of 2*threads*batch_size n-grams.\n\n" +#else + "This binary was compiled with -DNTHREAD, disabling threading. If you wanted\n" + " threading, compile without this flag against Boost >=1.42.0.\n\n" +#endif + "There are two inputs: vocabulary and model. Either may be given as a file\n" + " while the other is on stdin. Specify the type given as a file using\n" + " vocab: or model: before the file name. \n\n" + "For ARPA format, the output must be seekable. For raw format, it can be a\n" + " stream i.e. /dev/stdout\n"; +} + +typedef enum {MODE_COPY, MODE_SINGLE, MODE_MULTIPLE, MODE_UNION, MODE_UNSET} FilterMode; +typedef enum {FORMAT_ARPA, FORMAT_COUNT} Format; + +struct Config { + Config() : +#ifndef NTHREAD + batch_size(25000), + threads(boost::thread::hardware_concurrency()), +#endif + phrase(false), + context(false), + format(FORMAT_ARPA) + { +#ifndef NTHREAD + if (!threads) threads = 1; +#endif + } + +#ifndef NTHREAD + size_t batch_size; + size_t threads; +#endif + bool phrase; + bool context; + FilterMode mode; + Format format; +}; + +template <class Format, class Filter, class OutputBuffer, class Output> void RunThreadedFilter(const Config &config, util::FilePiece &in_lm, Filter &filter, Output &output) { +#ifndef NTHREAD + if (config.threads == 1) { +#endif + Format::RunFilter(in_lm, filter, output); +#ifndef NTHREAD + } else { + typedef Controller<Filter, OutputBuffer, Output> Threaded; + Threaded threading(config.batch_size, config.threads * 2, config.threads, filter, output); + Format::RunFilter(in_lm, threading, output); + } +#endif +} + +template <class Format, class Filter, class OutputBuffer, class Output> void RunContextFilter(const Config &config, util::FilePiece &in_lm, Filter filter, Output &output) { + if (config.context) { + ContextFilter<Filter> context_filter(filter); + RunThreadedFilter<Format, ContextFilter<Filter>, OutputBuffer, Output>(config, in_lm, context_filter, output); + } else { + RunThreadedFilter<Format, Filter, OutputBuffer, Output>(config, in_lm, filter, output); + } +} + +template <class Format, class Binary> void DispatchBinaryFilter(const Config &config, util::FilePiece &in_lm, const Binary &binary, typename Format::Output &out) { + typedef BinaryFilter<Binary> Filter; + RunContextFilter<Format, Filter, BinaryOutputBuffer, typename Format::Output>(config, in_lm, Filter(binary), out); +} + +template <class Format> void DispatchFilterModes(const Config &config, std::istream &in_vocab, util::FilePiece &in_lm, const char *out_name) { + if (config.mode == MODE_MULTIPLE) { + if (config.phrase) { + typedef phrase::Multiple Filter; + phrase::Substrings substrings; + typename Format::Multiple out(out_name, phrase::ReadMultiple(in_vocab, substrings)); + RunContextFilter<Format, Filter, MultipleOutputBuffer, typename Format::Multiple>(config, in_lm, Filter(substrings), out); + } else { + typedef vocab::Multiple Filter; + boost::unordered_map<std::string, std::vector<unsigned int> > words; + typename Format::Multiple out(out_name, vocab::ReadMultiple(in_vocab, words)); + RunContextFilter<Format, Filter, MultipleOutputBuffer, typename Format::Multiple>(config, in_lm, Filter(words), out); + } + return; + } + + typename Format::Output out(out_name); + + if (config.mode == MODE_COPY) { + Format::Copy(in_lm, out); + return; + } + + if (config.mode == MODE_SINGLE) { + vocab::Single::Words words; + vocab::ReadSingle(in_vocab, words); + DispatchBinaryFilter<Format, vocab::Single>(config, in_lm, vocab::Single(words), out); + return; + } + + if (config.mode == MODE_UNION) { + if (config.phrase) { + phrase::Substrings substrings; + phrase::ReadMultiple(in_vocab, substrings); + DispatchBinaryFilter<Format, phrase::Union>(config, in_lm, phrase::Union(substrings), out); + } else { + vocab::Union::Words words; + vocab::ReadMultiple(in_vocab, words); + DispatchBinaryFilter<Format, vocab::Union>(config, in_lm, vocab::Union(words), out); + } + return; + } +} + +} // namespace +} // namespace lm + +int main(int argc, char *argv[]) { + if (argc < 4) { + lm::DisplayHelp(argv[0]); + return 1; + } + + // I used to have boost::program_options, but some users didn't want to compile boost. + lm::Config config; + config.mode = lm::MODE_UNSET; + for (int i = 1; i < argc - 2; ++i) { + const char *str = argv[i]; + if (!std::strcmp(str, "copy")) { + config.mode = lm::MODE_COPY; + } else if (!std::strcmp(str, "single")) { + config.mode = lm::MODE_SINGLE; + } else if (!std::strcmp(str, "multiple")) { + config.mode = lm::MODE_MULTIPLE; + } else if (!std::strcmp(str, "union")) { + config.mode = lm::MODE_UNION; + } else if (!std::strcmp(str, "phrase")) { + config.phrase = true; + } else if (!std::strcmp(str, "context")) { + config.context = true; + } else if (!std::strcmp(str, "arpa")) { + config.format = lm::FORMAT_ARPA; + } else if (!std::strcmp(str, "raw")) { + config.format = lm::FORMAT_COUNT; +#ifndef NTHREAD + } else if (!std::strncmp(str, "threads:", 8)) { + config.threads = boost::lexical_cast<size_t>(str + 8); + if (!config.threads) { + std::cerr << "Specify at least one thread." << std::endl; + return 1; + } + } else if (!std::strncmp(str, "batch_size:", 11)) { + config.batch_size = boost::lexical_cast<size_t>(str + 11); + if (config.batch_size < 5000) { + std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl; + if (!config.batch_size) return 1; + } +#endif + } else { + lm::DisplayHelp(argv[0]); + return 1; + } + } + + if (config.mode == lm::MODE_UNSET) { + lm::DisplayHelp(argv[0]); + return 1; + } + + if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) { + std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl; + return 1; + } + + bool cmd_is_model = true; + const char *cmd_input = argv[argc - 2]; + if (!strncmp(cmd_input, "vocab:", 6)) { + cmd_is_model = false; + cmd_input += 6; + } else if (!strncmp(cmd_input, "model:", 6)) { + cmd_input += 6; + } else if (strchr(cmd_input, ':')) { + errx(1, "Specify vocab: or model: before the input file name, not \"%s\"", cmd_input); + } else { + std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl; + } + std::ifstream cmd_file; + std::istream *vocab; + if (cmd_is_model) { + vocab = &std::cin; + } else { + cmd_file.open(cmd_input, std::ios::in); + if (!cmd_file) { + err(2, "Could not open input file %s", cmd_input); + } + vocab = &cmd_file; + } + + util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr); + + if (config.format == lm::FORMAT_ARPA) { + lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]); + } else if (config.format == lm::FORMAT_COUNT) { + lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]); + } + return 0; +} diff --git a/klm/lm/filter/format.hh b/klm/lm/filter/format.hh new file mode 100644 index 00000000..7f945b0d --- /dev/null +++ b/klm/lm/filter/format.hh @@ -0,0 +1,250 @@ +#ifndef LM_FILTER_FORMAT_H__ +#define LM_FITLER_FORMAT_H__ + +#include "lm/filter/arpa_io.hh" +#include "lm/filter/count_io.hh" + +#include <boost/lexical_cast.hpp> +#include <boost/ptr_container/ptr_vector.hpp> + +#include <iosfwd> + +namespace lm { + +template <class Single> class MultipleOutput { + private: + typedef boost::ptr_vector<Single> Singles; + typedef typename Singles::iterator SinglesIterator; + + public: + MultipleOutput(const char *prefix, size_t number) { + files_.reserve(number); + std::string tmp; + for (unsigned int i = 0; i < number; ++i) { + tmp = prefix; + tmp += boost::lexical_cast<std::string>(i); + files_.push_back(new Single(tmp.c_str())); + } + } + + void AddNGram(const StringPiece &line) { + for (SinglesIterator i = files_.begin(); i != files_.end(); ++i) + i->AddNGram(line); + } + + template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) { + for (SinglesIterator i = files_.begin(); i != files_.end(); ++i) + i->AddNGram(begin, end, line); + } + + void SingleAddNGram(size_t offset, const StringPiece &line) { + files_[offset].AddNGram(line); + } + + template <class Iterator> void SingleAddNGram(size_t offset, const Iterator &begin, const Iterator &end, const StringPiece &line) { + files_[offset].AddNGram(begin, end, line); + } + + protected: + Singles files_; +}; + +class MultipleARPAOutput : public MultipleOutput<ARPAOutput> { + public: + MultipleARPAOutput(const char *prefix, size_t number) : MultipleOutput<ARPAOutput>(prefix, number) {} + + void ReserveForCounts(std::streampos reserve) { + for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i) + i->ReserveForCounts(reserve); + } + + void BeginLength(unsigned int length) { + for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i) + i->BeginLength(length); + } + + void EndLength(unsigned int length) { + for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i) + i->EndLength(length); + } + + void Finish() { + for (boost::ptr_vector<ARPAOutput>::iterator i = files_.begin(); i != files_.end(); ++i) + i->Finish(); + } +}; + +template <class Filter, class Output> class DispatchInput { + public: + DispatchInput(Filter &filter, Output &output) : filter_(filter), output_(output) {} + +/* template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) { + filter_.AddNGram(begin, end, line, output_); + }*/ + + void AddNGram(const StringPiece &ngram, const StringPiece &line) { + filter_.AddNGram(ngram, line, output_); + } + + protected: + Filter &filter_; + Output &output_; +}; + +template <class Filter, class Output> class DispatchARPAInput : public DispatchInput<Filter, Output> { + private: + typedef DispatchInput<Filter, Output> B; + + public: + DispatchARPAInput(Filter &filter, Output &output) : B(filter, output) {} + + void ReserveForCounts(std::streampos reserve) { B::output_.ReserveForCounts(reserve); } + void BeginLength(unsigned int length) { B::output_.BeginLength(length); } + + void EndLength(unsigned int length) { + B::filter_.Flush(); + B::output_.EndLength(length); + } + void Finish() { B::output_.Finish(); } +}; + +struct ARPAFormat { + typedef ARPAOutput Output; + typedef MultipleARPAOutput Multiple; + static void Copy(util::FilePiece &in, Output &out) { + ReadARPA(in, out); + } + template <class Filter, class Out> static void RunFilter(util::FilePiece &in, Filter &filter, Out &output) { + DispatchARPAInput<Filter, Out> dispatcher(filter, output); + ReadARPA(in, dispatcher); + } +}; + +struct CountFormat { + typedef CountOutput Output; + typedef MultipleOutput<Output> Multiple; + static void Copy(util::FilePiece &in, Output &out) { + ReadCount(in, out); + } + template <class Filter, class Out> static void RunFilter(util::FilePiece &in, Filter &filter, Out &output) { + DispatchInput<Filter, Out> dispatcher(filter, output); + ReadCount(in, dispatcher); + } +}; + +/* For multithreading, the buffer classes hold batches of filter inputs and + * outputs in memory. The strings get reused a lot, so keep them around + * instead of clearing each time. + */ +class InputBuffer { + public: + InputBuffer() : actual_(0) {} + + void Reserve(size_t size) { lines_.reserve(size); } + + template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { + if (lines_.size() == actual_) lines_.resize(lines_.size() + 1); + // TODO avoid this copy. + std::string &copied = lines_[actual_].line; + copied.assign(line.data(), line.size()); + lines_[actual_].ngram.set(copied.data() + (ngram.data() - line.data()), ngram.size()); + ++actual_; + } + + template <class Filter, class Output> void CallFilter(Filter &filter, Output &output) const { + for (std::vector<Line>::const_iterator i = lines_.begin(); i != lines_.begin() + actual_; ++i) { + filter.AddNGram(i->ngram, i->line, output); + } + } + + void Clear() { actual_ = 0; } + bool Empty() { return actual_ == 0; } + size_t Size() { return actual_; } + + private: + struct Line { + std::string line; + StringPiece ngram; + }; + + size_t actual_; + + std::vector<Line> lines_; +}; + +class BinaryOutputBuffer { + public: + BinaryOutputBuffer() {} + + void Reserve(size_t size) { + lines_.reserve(size); + } + + void AddNGram(const StringPiece &line) { + lines_.push_back(line); + } + + template <class Output> void Flush(Output &output) { + for (std::vector<StringPiece>::const_iterator i = lines_.begin(); i != lines_.end(); ++i) { + output.AddNGram(*i); + } + lines_.clear(); + } + + private: + std::vector<StringPiece> lines_; +}; + +class MultipleOutputBuffer { + public: + MultipleOutputBuffer() : last_(NULL) {} + + void Reserve(size_t size) { + annotated_.reserve(size); + } + + void AddNGram(const StringPiece &line) { + annotated_.resize(annotated_.size() + 1); + annotated_.back().line = line; + } + + void SingleAddNGram(size_t offset, const StringPiece &line) { + if ((line.data() == last_.data()) && (line.length() == last_.length())) { + annotated_.back().systems.push_back(offset); + } else { + annotated_.resize(annotated_.size() + 1); + annotated_.back().systems.push_back(offset); + annotated_.back().line = line; + last_ = line; + } + } + + template <class Output> void Flush(Output &output) { + for (std::vector<Annotated>::const_iterator i = annotated_.begin(); i != annotated_.end(); ++i) { + if (i->systems.empty()) { + output.AddNGram(i->line); + } else { + for (std::vector<size_t>::const_iterator j = i->systems.begin(); j != i->systems.end(); ++j) { + output.SingleAddNGram(*j, i->line); + } + } + } + annotated_.clear(); + } + + private: + struct Annotated { + // If this is empty, send to all systems. + // A filter should never send to all systems and send to a single one. + std::vector<size_t> systems; + StringPiece line; + }; + + StringPiece last_; + + std::vector<Annotated> annotated_; +}; + +} // namespace lm + +#endif // LM_FILTER_FORMAT_H__ diff --git a/klm/lm/filter/phrase.cc b/klm/lm/filter/phrase.cc new file mode 100644 index 00000000..1bef2a3f --- /dev/null +++ b/klm/lm/filter/phrase.cc @@ -0,0 +1,281 @@ +#include "lm/filter/phrase.hh" + +#include "lm/filter/format.hh" + +#include <algorithm> +#include <functional> +#include <iostream> +#include <queue> +#include <string> +#include <vector> + +#include <ctype.h> + +namespace lm { +namespace phrase { + +unsigned int ReadMultiple(std::istream &in, Substrings &out) { + bool sentence_content = false; + unsigned int sentence_id = 0; + std::vector<Hash> phrase; + std::string word; + while (in) { + char c; + // Gather a word. + while (!isspace(c = in.get()) && in) word += c; + // Treat EOF like a newline. + if (!in) c = '\n'; + // Add the word to the phrase. + if (!word.empty()) { + phrase.push_back(util::MurmurHashNative(word.data(), word.size())); + word.clear(); + } + if (c == ' ') continue; + // It's more than just a space. Close out the phrase. + if (!phrase.empty()) { + sentence_content = true; + out.AddPhrase(sentence_id, phrase.begin(), phrase.end()); + phrase.clear(); + } + if (c == '\t' || c == '\v') continue; + // It's more than a space or tab: a newline. + if (sentence_content) { + ++sentence_id; + sentence_content = false; + } + } + if (!in.eof()) in.exceptions(std::istream::failbit | std::istream::badbit); + return sentence_id + sentence_content; +} + +namespace detail { const StringPiece kEndSentence("</s>"); } + +namespace { + +typedef unsigned int Sentence; +typedef std::vector<Sentence> Sentences; + +class Vertex; + +class Arc { + public: + Arc() {} + + // For arcs from one vertex to another. + void SetPhrase(Vertex &from, Vertex &to, const Sentences &intersect) { + Set(to, intersect); + from_ = &from; + } + + /* For arcs from before the n-gram begins to somewhere in the n-gram (right + * aligned). These have no from_ vertex; it implictly matches every + * sentence. This also handles when the n-gram is a substring of a phrase. + */ + void SetRight(Vertex &to, const Sentences &complete) { + Set(to, complete); + from_ = NULL; + } + + Sentence Current() const { + return *current_; + } + + bool Empty() const { + return current_ == last_; + } + + /* When this function returns: + * If Empty() then there's nothing left from this intersection. + * + * If Current() == to then to is part of the intersection. + * + * Otherwise, Current() > to. In this case, to is not part of the + * intersection and neither is anything < Current(). To determine if + * any value >= Current() is in the intersection, call LowerBound again + * with the value. + */ + void LowerBound(const Sentence to); + + private: + void Set(Vertex &to, const Sentences &sentences); + + const Sentence *current_; + const Sentence *last_; + Vertex *from_; +}; + +struct ArcGreater : public std::binary_function<const Arc *, const Arc *, bool> { + bool operator()(const Arc *first, const Arc *second) const { + return first->Current() > second->Current(); + } +}; + +class Vertex { + public: + Vertex() : current_(0) {} + + Sentence Current() const { + return current_; + } + + bool Empty() const { + return incoming_.empty(); + } + + void LowerBound(const Sentence to); + + private: + friend class Arc; + + void AddIncoming(Arc *arc) { + if (!arc->Empty()) incoming_.push(arc); + } + + unsigned int current_; + std::priority_queue<Arc*, std::vector<Arc*>, ArcGreater> incoming_; +}; + +void Arc::LowerBound(const Sentence to) { + current_ = std::lower_bound(current_, last_, to); + // If *current_ > to, don't advance from_. The intervening values of + // from_ may be useful for another one of its outgoing arcs. + if (!from_ || Empty() || (Current() > to)) return; + assert(Current() == to); + from_->LowerBound(to); + if (from_->Empty()) { + current_ = last_; + return; + } + assert(from_->Current() >= to); + if (from_->Current() > to) { + current_ = std::lower_bound(current_ + 1, last_, from_->Current()); + } +} + +void Arc::Set(Vertex &to, const Sentences &sentences) { + current_ = &*sentences.begin(); + last_ = &*sentences.end(); + to.AddIncoming(this); +} + +void Vertex::LowerBound(const Sentence to) { + if (Empty()) return; + // Union lower bound. + while (true) { + Arc *top = incoming_.top(); + if (top->Current() > to) { + current_ = top->Current(); + return; + } + // If top->Current() == to, we still need to verify that's an actual + // element and not just a bound. + incoming_.pop(); + top->LowerBound(to); + if (!top->Empty()) { + incoming_.push(top); + if (top->Current() == to) { + current_ = to; + return; + } + } else if (Empty()) { + return; + } + } +} + +void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, Vertex *const vertices, Arc *free_arc) { + assert(!hashes.empty()); + + const Hash *const first_word = &*hashes.begin(); + const Hash *const last_word = &*hashes.end() - 1; + + Hash hash = 0; + const Sentences *found; + // Phrases starting at or before the first word in the n-gram. + { + Vertex *vertex = vertices; + for (const Hash *word = first_word; ; ++word, ++vertex) { + hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *word); + // Now hash is [hashes.begin(), word]. + if (word == last_word) { + if (phrase.FindSubstring(hash, found)) + (free_arc++)->SetRight(*vertex, *found); + break; + } + if (!phrase.FindRight(hash, found)) break; + (free_arc++)->SetRight(*vertex, *found); + } + } + + // Phrases starting at the second or later word in the n-gram. + Vertex *vertex_from = vertices; + for (const Hash *word_from = first_word + 1; word_from != &*hashes.end(); ++word_from, ++vertex_from) { + hash = 0; + Vertex *vertex_to = vertex_from + 1; + for (const Hash *word_to = word_from; ; ++word_to, ++vertex_to) { + // Notice that word_to and vertex_to have the same index. + hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *word_to); + // Now hash covers [word_from, word_to]. + if (word_to == last_word) { + if (phrase.FindLeft(hash, found)) + (free_arc++)->SetPhrase(*vertex_from, *vertex_to, *found); + break; + } + if (!phrase.FindPhrase(hash, found)) break; + (free_arc++)->SetPhrase(*vertex_from, *vertex_to, *found); + } + } +} + +} // namespace + +namespace detail { + +} // namespace detail + +bool Union::Evaluate() { + assert(!hashes_.empty()); + // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable. + Vertex vertices[hashes_.size()]; + // One for every substring. + Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2]; + BuildGraph(substrings_, hashes_, vertices, arcs); + Vertex &last_vertex = vertices[hashes_.size() - 1]; + + unsigned int lower = 0; + while (true) { + last_vertex.LowerBound(lower); + if (last_vertex.Empty()) return false; + if (last_vertex.Current() == lower) return true; + lower = last_vertex.Current(); + } +} + +template <class Output> void Multiple::Evaluate(const StringPiece &line, Output &output) { + assert(!hashes_.empty()); + // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable. + Vertex vertices[hashes_.size()]; + // One for every substring. + Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2]; + BuildGraph(substrings_, hashes_, vertices, arcs); + Vertex &last_vertex = vertices[hashes_.size() - 1]; + + unsigned int lower = 0; + while (true) { + last_vertex.LowerBound(lower); + if (last_vertex.Empty()) return; + if (last_vertex.Current() == lower) { + output.SingleAddNGram(lower, line); + ++lower; + } else { + lower = last_vertex.Current(); + } + } +} + +template void Multiple::Evaluate<CountFormat::Multiple>(const StringPiece &line, CountFormat::Multiple &output); +template void Multiple::Evaluate<ARPAFormat::Multiple>(const StringPiece &line, ARPAFormat::Multiple &output); +template void Multiple::Evaluate<MultipleOutputBuffer>(const StringPiece &line, MultipleOutputBuffer &output); + +} // namespace phrase +} // namespace lm diff --git a/klm/lm/filter/phrase.hh b/klm/lm/filter/phrase.hh new file mode 100644 index 00000000..b4edff41 --- /dev/null +++ b/klm/lm/filter/phrase.hh @@ -0,0 +1,154 @@ +#ifndef LM_FILTER_PHRASE_H__ +#define LM_FILTER_PHRASE_H__ + +#include "util/murmur_hash.hh" +#include "util/string_piece.hh" +#include "util/tokenize_piece.hh" + +#include <boost/unordered_map.hpp> + +#include <iosfwd> +#include <vector> + +#define LM_FILTER_PHRASE_METHOD(caps, lower) \ +bool Find##caps(Hash key, const std::vector<unsigned int> *&out) const {\ + Table::const_iterator i(table_.find(key));\ + if (i==table_.end()) return false; \ + out = &i->second.lower; \ + return true; \ +} + +namespace lm { +namespace phrase { + +typedef uint64_t Hash; + +class Substrings { + private: + /* This is the value in a hash table where the key is a string. It indicates + * four sets of sentences: + * substring is sentences with a phrase containing the key as a substring. + * left is sentencess with a phrase that begins with the key (left aligned). + * right is sentences with a phrase that ends with the key (right aligned). + * phrase is sentences where the key is a phrase. + * Each set is encoded as a vector of sentence ids in increasing order. + */ + struct SentenceRelation { + std::vector<unsigned int> substring, left, right, phrase; + }; + /* Most of the CPU is hash table lookups, so let's not complicate it with + * vector equality comparisons. If a collision happens, the SentenceRelation + * structure will contain the union of sentence ids over the colliding strings. + * In that case, the filter will be slightly more permissive. + * The key here is the same as boost's hash of std::vector<std::string>. + */ + typedef boost::unordered_map<Hash, SentenceRelation> Table; + + public: + Substrings() {} + + /* If the string isn't a substring of any phrase, return NULL. Otherwise, + * return a pointer to std::vector<unsigned int> listing sentences with + * matching phrases. This set may be empty for Left, Right, or Phrase. + * Example: const std::vector<unsigned int> *FindSubstring(Hash key) + */ + LM_FILTER_PHRASE_METHOD(Substring, substring) + LM_FILTER_PHRASE_METHOD(Left, left) + LM_FILTER_PHRASE_METHOD(Right, right) + LM_FILTER_PHRASE_METHOD(Phrase, phrase) + +#pragma GCC diagnostic ignored "-Wuninitialized" // end != finish so there's always an initialization + // sentence_id must be non-decreasing. Iterators are over words in the phrase. + template <class Iterator> void AddPhrase(unsigned int sentence_id, const Iterator &begin, const Iterator &end) { + // Iterate over all substrings. + for (Iterator start = begin; start != end; ++start) { + Hash hash = 0; + SentenceRelation *relation; + for (Iterator finish = start; finish != end; ++finish) { + hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *finish); + // Now hash is of [start, finish]. + relation = &table_[hash]; + AppendSentence(relation->substring, sentence_id); + if (start == begin) AppendSentence(relation->left, sentence_id); + } + AppendSentence(relation->right, sentence_id); + if (start == begin) AppendSentence(relation->phrase, sentence_id); + } + } + + private: + void AppendSentence(std::vector<unsigned int> &vec, unsigned int sentence_id) { + if (vec.empty() || vec.back() != sentence_id) vec.push_back(sentence_id); + } + + Table table_; +}; + +// Read a file with one sentence per line containing tab-delimited phrases of +// space-separated words. +unsigned int ReadMultiple(std::istream &in, Substrings &out); + +namespace detail { +extern const StringPiece kEndSentence; + +template <class Iterator> void MakeHashes(Iterator i, const Iterator &end, std::vector<Hash> &hashes) { + hashes.clear(); + if (i == end) return; + // TODO: check strict phrase boundaries after <s> and before </s>. For now, just skip tags. + if ((i->data()[0] == '<') && (i->data()[i->size() - 1] == '>')) { + ++i; + } + for (; i != end && (*i != kEndSentence); ++i) { + hashes.push_back(util::MurmurHashNative(i->data(), i->size())); + } +} + +} // namespace detail + +class Union { + public: + explicit Union(const Substrings &substrings) : substrings_(substrings) {} + + template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) { + detail::MakeHashes(begin, end, hashes_); + return hashes_.empty() || Evaluate(); + } + + private: + bool Evaluate(); + + std::vector<Hash> hashes_; + + const Substrings &substrings_; +}; + +class Multiple { + public: + explicit Multiple(const Substrings &substrings) : substrings_(substrings) {} + + template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) { + detail::MakeHashes(begin, end, hashes_); + if (hashes_.empty()) { + output.AddNGram(line); + return; + } + Evaluate(line, output); + } + + template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { + AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output); + } + + void Flush() const {} + + private: + template <class Output> void Evaluate(const StringPiece &line, Output &output); + + std::vector<Hash> hashes_; + + const Substrings &substrings_; +}; + +} // namespace phrase +} // namespace lm +#endif // LM_FILTER_PHRASE_H__ diff --git a/klm/lm/filter/thread.hh b/klm/lm/filter/thread.hh new file mode 100644 index 00000000..e785b263 --- /dev/null +++ b/klm/lm/filter/thread.hh @@ -0,0 +1,167 @@ +#ifndef LM_FILTER_THREAD_H__ +#define LM_FILTER_THREAD_H__ + +#include "util/thread_pool.hh" + +#include <boost/utility/in_place_factory.hpp> + +#include <deque> +#include <stack> + +namespace lm { + +template <class OutputBuffer> class ThreadBatch { + public: + ThreadBatch() {} + + void Reserve(size_t size) { + input_.Reserve(size); + output_.Reserve(size); + } + + // File reading thread. + InputBuffer &Fill(uint64_t sequence) { + sequence_ = sequence; + // Why wait until now to clear instead of after output? free in the same + // thread as allocated. + input_.Clear(); + return input_; + } + + // Filter worker thread. + template <class Filter> void CallFilter(Filter &filter) { + input_.CallFilter(filter, output_); + } + + uint64_t Sequence() const { return sequence_; } + + // File writing thread. + template <class RealOutput> void Flush(RealOutput &output) { + output_.Flush(output); + } + + private: + InputBuffer input_; + OutputBuffer output_; + + uint64_t sequence_; +}; + +template <class Batch, class Filter> class FilterWorker { + public: + typedef Batch *Request; + + FilterWorker(const Filter &filter, util::PCQueue<Request> &done) : filter_(filter), done_(done) {} + + void operator()(Request request) { + request->CallFilter(filter_); + done_.Produce(request); + } + + private: + Filter filter_; + + util::PCQueue<Request> &done_; +}; + +// There should only be one OutputWorker. +template <class Batch, class Output> class OutputWorker { + public: + typedef Batch *Request; + + OutputWorker(Output &output, util::PCQueue<Request> &done) : output_(output), done_(done), base_sequence_(0) {} + + void operator()(Request request) { + assert(request->Sequence() >= base_sequence_); + // Assemble the output in order. + uint64_t pos = request->Sequence() - base_sequence_; + if (pos >= ordering_.size()) { + ordering_.resize(pos + 1, NULL); + } + ordering_[pos] = request; + while (!ordering_.empty() && ordering_.front()) { + ordering_.front()->Flush(output_); + done_.Produce(ordering_.front()); + ordering_.pop_front(); + ++base_sequence_; + } + } + + private: + Output &output_; + + util::PCQueue<Request> &done_; + + std::deque<Request> ordering_; + + uint64_t base_sequence_; +}; + +template <class Filter, class OutputBuffer, class RealOutput> class Controller : boost::noncopyable { + private: + typedef ThreadBatch<OutputBuffer> Batch; + + public: + Controller(size_t batch_size, size_t queue, size_t workers, const Filter &filter, RealOutput &output) + : batch_size_(batch_size), queue_size_(queue), + batches_(queue), + to_read_(queue), + output_(queue, 1, boost::in_place(boost::ref(output), boost::ref(to_read_)), NULL), + filter_(queue, workers, boost::in_place(boost::ref(filter), boost::ref(output_.In())), NULL), + sequence_(0) { + for (size_t i = 0; i < queue; ++i) { + batches_[i].Reserve(batch_size); + local_read_.push(&batches_[i]); + } + NewInput(); + } + + void AddNGram(const StringPiece &ngram, const StringPiece &line, RealOutput &output) { + input_->AddNGram(ngram, line, output); + if (input_->Size() == batch_size_) { + FlushInput(); + NewInput(); + } + } + + void Flush() { + FlushInput(); + while (local_read_.size() < queue_size_) { + MoveRead(); + } + NewInput(); + } + + private: + void FlushInput() { + if (input_->Empty()) return; + filter_.Produce(local_read_.top()); + local_read_.pop(); + if (local_read_.empty()) MoveRead(); + } + + void NewInput() { + input_ = &local_read_.top()->Fill(sequence_++); + } + + void MoveRead() { + local_read_.push(to_read_.Consume()); + } + + const size_t batch_size_; + const size_t queue_size_; + + std::vector<Batch> batches_; + + util::PCQueue<Batch*> to_read_; + std::stack<Batch*> local_read_; + util::ThreadPool<OutputWorker<Batch, RealOutput> > output_; + util::ThreadPool<FilterWorker<Batch, Filter> > filter_; + + uint64_t sequence_; + InputBuffer *input_; +}; + +} // namespace lm + +#endif // LM_FILTER_THREAD_H__ diff --git a/klm/lm/filter/vocab.cc b/klm/lm/filter/vocab.cc new file mode 100644 index 00000000..7ee4e84b --- /dev/null +++ b/klm/lm/filter/vocab.cc @@ -0,0 +1,54 @@ +#include "lm/filter/vocab.hh" + +#include <istream> +#include <iostream> + +#include <ctype.h> +#include <err.h> + +namespace lm { +namespace vocab { + +void ReadSingle(std::istream &in, boost::unordered_set<std::string> &out) { + in.exceptions(std::istream::badbit); + std::string word; + while (in >> word) { + out.insert(word); + } +} + +namespace { +bool IsLineEnd(std::istream &in) { + int got; + do { + got = in.get(); + if (!in) return true; + if (got == '\n') return true; + } while (isspace(got)); + in.unget(); + return false; +} +}// namespace + +// Read space separated words in enter separated lines. These lines can be +// very long, so don't read an entire line at a time. +unsigned int ReadMultiple(std::istream &in, boost::unordered_map<std::string, std::vector<unsigned int> > &out) { + in.exceptions(std::istream::badbit); + unsigned int sentence = 0; + bool used_id = false; + std::string word; + while (in >> word) { + used_id = true; + std::vector<unsigned int> &posting = out[word]; + if (posting.empty() || (posting.back() != sentence)) + posting.push_back(sentence); + if (IsLineEnd(in)) { + ++sentence; + used_id = false; + } + } + return sentence + used_id; +} + +} // namespace vocab +} // namespace lm diff --git a/klm/lm/filter/vocab.hh b/klm/lm/filter/vocab.hh new file mode 100644 index 00000000..7f0fadaa --- /dev/null +++ b/klm/lm/filter/vocab.hh @@ -0,0 +1,133 @@ +#ifndef LM_FILTER_VOCAB_H__ +#define LM_FILTER_VOCAB_H__ + +// Vocabulary-based filters for language models. + +#include "util/multi_intersection.hh" +#include "util/string_piece.hh" +#include "util/string_piece_hash.hh" +#include "util/tokenize_piece.hh" + +#include <boost/noncopyable.hpp> +#include <boost/range/iterator_range.hpp> +#include <boost/unordered/unordered_map.hpp> +#include <boost/unordered/unordered_set.hpp> + +#include <string> +#include <vector> + +namespace lm { +namespace vocab { + +void ReadSingle(std::istream &in, boost::unordered_set<std::string> &out); + +// Read one sentence vocabulary per line. Return the number of sentences. +unsigned int ReadMultiple(std::istream &in, boost::unordered_map<std::string, std::vector<unsigned int> > &out); + +/* Is this a special tag like <s> or <UNK>? This actually includes anything + * surrounded with < and >, which most tokenizers separate for real words, so + * this should not catch real words as it looks at a single token. + */ +inline bool IsTag(const StringPiece &value) { + // The parser should never give an empty string. + assert(!value.empty()); + return (value.data()[0] == '<' && value.data()[value.size() - 1] == '>'); +} + +class Single { + public: + typedef boost::unordered_set<std::string> Words; + + explicit Single(const Words &vocab) : vocab_(vocab) {} + + template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) { + for (Iterator i = begin; i != end; ++i) { + if (IsTag(*i)) continue; + if (FindStringPiece(vocab_, *i) == vocab_.end()) return false; + } + return true; + } + + private: + const Words &vocab_; +}; + +class Union { + public: + typedef boost::unordered_map<std::string, std::vector<unsigned int> > Words; + + explicit Union(const Words &vocabs) : vocabs_(vocabs) {} + + template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) { + sets_.clear(); + + for (Iterator i(begin); i != end; ++i) { + if (IsTag(*i)) continue; + Words::const_iterator found(FindStringPiece(vocabs_, *i)); + if (vocabs_.end() == found) return false; + sets_.push_back(boost::iterator_range<const unsigned int*>(&*found->second.begin(), &*found->second.end())); + } + return (sets_.empty() || util::FirstIntersection(sets_)); + } + + private: + const Words &vocabs_; + + std::vector<boost::iterator_range<const unsigned int*> > sets_; +}; + +class Multiple { + public: + typedef boost::unordered_map<std::string, std::vector<unsigned int> > Words; + + Multiple(const Words &vocabs) : vocabs_(vocabs) {} + + private: + // Callback from AllIntersection that does AddNGram. + template <class Output> class Callback { + public: + Callback(Output &out, const StringPiece &line) : out_(out), line_(line) {} + + void operator()(unsigned int index) { + out_.SingleAddNGram(index, line_); + } + + private: + Output &out_; + const StringPiece &line_; + }; + + public: + template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) { + sets_.clear(); + for (Iterator i(begin); i != end; ++i) { + if (IsTag(*i)) continue; + Words::const_iterator found(FindStringPiece(vocabs_, *i)); + if (vocabs_.end() == found) return; + sets_.push_back(boost::iterator_range<const unsigned int*>(&*found->second.begin(), &*found->second.end())); + } + if (sets_.empty()) { + output.AddNGram(line); + return; + } + + Callback<Output> cb(output, line); + util::AllIntersection(sets_, cb); + } + + template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { + AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output); + } + + void Flush() const {} + + private: + const Words &vocabs_; + + std::vector<boost::iterator_range<const unsigned int*> > sets_; +}; + +} // namespace vocab +} // namespace lm + +#endif // LM_FILTER_VOCAB_H__ diff --git a/klm/lm/filter/wrapper.hh b/klm/lm/filter/wrapper.hh new file mode 100644 index 00000000..90b07a08 --- /dev/null +++ b/klm/lm/filter/wrapper.hh @@ -0,0 +1,58 @@ +#ifndef LM_FILTER_WRAPPER_H__ +#define LM_FILTER_WRAPPER_H__ + +#include "util/string_piece.hh" + +#include <algorithm> +#include <string> +#include <vector> + +namespace lm { + +// Provide a single-output filter with the same interface as a +// multiple-output filter so clients code against one interface. +template <class Binary> class BinaryFilter { + public: + // Binary modes are just references (and a set) and it makes the API cleaner to copy them. + explicit BinaryFilter(Binary binary) : binary_(binary) {} + + template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) { + if (binary_.PassNGram(begin, end)) + output.AddNGram(line); + } + + template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { + AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output); + } + + void Flush() const {} + + private: + Binary binary_; +}; + +// Wrap another filter to pay attention only to context words +template <class FilterT> class ContextFilter { + public: + typedef FilterT Filter; + + explicit ContextFilter(Filter &backend) : backend_(backend) {} + + template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { + pieces_.clear(); + // TODO: this copy could be avoided by a lookahead iterator. + std::copy(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), std::back_insert_iterator<std::vector<StringPiece> >(pieces_)); + backend_.AddNGram(pieces_.begin(), pieces_.end() - !pieces_.empty(), line, output); + } + + void Flush() const {} + + private: + std::vector<StringPiece> pieces_; + + Filter backend_; +}; + +} // namespace lm + +#endif // LM_FILTER_WRAPPER_H__ |