diff options
Diffstat (limited to 'klm/lm/filter')
-rw-r--r--  klm/lm/filter/arpa_io.hh                 |   1
-rw-r--r--  klm/lm/filter/count_io.hh                |  12
-rw-r--r--  klm/lm/filter/filter_main.cc             | 155
-rw-r--r--  klm/lm/filter/format.hh                  |   2
-rw-r--r--  klm/lm/filter/phrase.cc                  |  59
-rw-r--r--  klm/lm/filter/phrase.hh                  |  42
-rw-r--r--  klm/lm/filter/phrase_table_vocab_main.cc | 165
-rw-r--r--  klm/lm/filter/vocab.cc                   |   1
-rw-r--r--  klm/lm/filter/wrapper.hh                 |  10
9 files changed, 318 insertions(+), 129 deletions(-)
diff --git a/klm/lm/filter/arpa_io.hh b/klm/lm/filter/arpa_io.hh index 5b31620b..602b5b31 100644 --- a/klm/lm/filter/arpa_io.hh +++ b/klm/lm/filter/arpa_io.hh @@ -14,7 +14,6 @@ #include <string> #include <vector> -#include <err.h> #include <string.h> #include <stdint.h> diff --git a/klm/lm/filter/count_io.hh b/klm/lm/filter/count_io.hh index 97c0fa25..d992026f 100644 --- a/klm/lm/filter/count_io.hh +++ b/klm/lm/filter/count_io.hh @@ -5,20 +5,18 @@ #include <iostream> #include <string> -#include <err.h> - +#include "util/fake_ofstream.hh" +#include "util/file.hh" #include "util/file_piece.hh" namespace lm { class CountOutput : boost::noncopyable { public: - explicit CountOutput(const char *name) : file_(name, std::ios::out) {} + explicit CountOutput(const char *name) : file_(util::CreateOrThrow(name)) {} void AddNGram(const StringPiece &line) { - if (!(file_ << line << '\n')) { - err(3, "Writing counts file failed"); - } + file_ << line << '\n'; } template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) { @@ -30,7 +28,7 @@ class CountOutput : boost::noncopyable { } private: - std::fstream file_; + util::FakeOFStream file_; }; class CountBatch { diff --git a/klm/lm/filter/filter_main.cc b/klm/lm/filter/filter_main.cc index 1736bc40..82fdc1ef 100644 --- a/klm/lm/filter/filter_main.cc +++ b/klm/lm/filter/filter_main.cc @@ -6,6 +6,7 @@ #endif #include "lm/filter/vocab.hh" #include "lm/filter/wrapper.hh" +#include "util/exception.hh" #include "util/file_piece.hh" #include <boost/ptr_container/ptr_vector.hpp> @@ -157,92 +158,96 @@ template <class Format> void DispatchFilterModes(const Config &config, std::istr } // namespace lm int main(int argc, char *argv[]) { - if (argc < 4) { - lm::DisplayHelp(argv[0]); - return 1; - } + try { + if (argc < 4) { + lm::DisplayHelp(argv[0]); + return 1; + } - // I used to have boost::program_options, but some users didn't want to compile boost. 
- lm::Config config; - config.mode = lm::MODE_UNSET; - for (int i = 1; i < argc - 2; ++i) { - const char *str = argv[i]; - if (!std::strcmp(str, "copy")) { - config.mode = lm::MODE_COPY; - } else if (!std::strcmp(str, "single")) { - config.mode = lm::MODE_SINGLE; - } else if (!std::strcmp(str, "multiple")) { - config.mode = lm::MODE_MULTIPLE; - } else if (!std::strcmp(str, "union")) { - config.mode = lm::MODE_UNION; - } else if (!std::strcmp(str, "phrase")) { - config.phrase = true; - } else if (!std::strcmp(str, "context")) { - config.context = true; - } else if (!std::strcmp(str, "arpa")) { - config.format = lm::FORMAT_ARPA; - } else if (!std::strcmp(str, "raw")) { - config.format = lm::FORMAT_COUNT; + // I used to have boost::program_options, but some users didn't want to compile boost. + lm::Config config; + config.mode = lm::MODE_UNSET; + for (int i = 1; i < argc - 2; ++i) { + const char *str = argv[i]; + if (!std::strcmp(str, "copy")) { + config.mode = lm::MODE_COPY; + } else if (!std::strcmp(str, "single")) { + config.mode = lm::MODE_SINGLE; + } else if (!std::strcmp(str, "multiple")) { + config.mode = lm::MODE_MULTIPLE; + } else if (!std::strcmp(str, "union")) { + config.mode = lm::MODE_UNION; + } else if (!std::strcmp(str, "phrase")) { + config.phrase = true; + } else if (!std::strcmp(str, "context")) { + config.context = true; + } else if (!std::strcmp(str, "arpa")) { + config.format = lm::FORMAT_ARPA; + } else if (!std::strcmp(str, "raw")) { + config.format = lm::FORMAT_COUNT; #ifndef NTHREAD - } else if (!std::strncmp(str, "threads:", 8)) { - config.threads = boost::lexical_cast<size_t>(str + 8); - if (!config.threads) { - std::cerr << "Specify at least one thread." << std::endl; + } else if (!std::strncmp(str, "threads:", 8)) { + config.threads = boost::lexical_cast<size_t>(str + 8); + if (!config.threads) { + std::cerr << "Specify at least one thread." 
<< std::endl; + return 1; + } + } else if (!std::strncmp(str, "batch_size:", 11)) { + config.batch_size = boost::lexical_cast<size_t>(str + 11); + if (config.batch_size < 5000) { + std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl; + if (!config.batch_size) return 1; + } +#endif + } else { + lm::DisplayHelp(argv[0]); return 1; } - } else if (!std::strncmp(str, "batch_size:", 11)) { - config.batch_size = boost::lexical_cast<size_t>(str + 11); - if (config.batch_size < 5000) { - std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl; - if (!config.batch_size) return 1; - } -#endif - } else { + } + + if (config.mode == lm::MODE_UNSET) { lm::DisplayHelp(argv[0]); return 1; } - } - - if (config.mode == lm::MODE_UNSET) { - lm::DisplayHelp(argv[0]); - return 1; - } - if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) { - std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl; - return 1; - } + if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) { + std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." 
<< std::endl; + return 1; + } - bool cmd_is_model = true; - const char *cmd_input = argv[argc - 2]; - if (!strncmp(cmd_input, "vocab:", 6)) { - cmd_is_model = false; - cmd_input += 6; - } else if (!strncmp(cmd_input, "model:", 6)) { - cmd_input += 6; - } else if (strchr(cmd_input, ':')) { - errx(1, "Specify vocab: or model: before the input file name, not \"%s\"", cmd_input); - } else { - std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl; - } - std::ifstream cmd_file; - std::istream *vocab; - if (cmd_is_model) { - vocab = &std::cin; - } else { - cmd_file.open(cmd_input, std::ios::in); - if (!cmd_file) { - err(2, "Could not open input file %s", cmd_input); + bool cmd_is_model = true; + const char *cmd_input = argv[argc - 2]; + if (!strncmp(cmd_input, "vocab:", 6)) { + cmd_is_model = false; + cmd_input += 6; + } else if (!strncmp(cmd_input, "model:", 6)) { + cmd_input += 6; + } else if (strchr(cmd_input, ':')) { + std::cerr << "Specify vocab: or model: before the input file name, not " << cmd_input << std::endl; + return 1; + } else { + std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl; + } + std::ifstream cmd_file; + std::istream *vocab; + if (cmd_is_model) { + vocab = &std::cin; + } else { + cmd_file.open(cmd_input, std::ios::in); + UTIL_THROW_IF(!cmd_file, util::ErrnoException, "Failed to open " << cmd_input); + vocab = &cmd_file; } - vocab = &cmd_file; - } - util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr); + util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? 
cmd_input : NULL, &std::cerr); - if (config.format == lm::FORMAT_ARPA) { - lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]); - } else if (config.format == lm::FORMAT_COUNT) { - lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]); + if (config.format == lm::FORMAT_ARPA) { + lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]); + } else if (config.format == lm::FORMAT_COUNT) { + lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]); + } + return 0; + } catch (const std::exception &e) { + std::cerr << e.what() << std::endl; + return 1; } - return 0; } diff --git a/klm/lm/filter/format.hh b/klm/lm/filter/format.hh index 7f945b0d..7d8c28db 100644 --- a/klm/lm/filter/format.hh +++ b/klm/lm/filter/format.hh @@ -1,5 +1,5 @@ #ifndef LM_FILTER_FORMAT_H__ -#define LM_FITLER_FORMAT_H__ +#define LM_FILTER_FORMAT_H__ #include "lm/filter/arpa_io.hh" #include "lm/filter/count_io.hh" diff --git a/klm/lm/filter/phrase.cc b/klm/lm/filter/phrase.cc index 1bef2a3f..e2946b14 100644 --- a/klm/lm/filter/phrase.cc +++ b/klm/lm/filter/phrase.cc @@ -48,21 +48,21 @@ unsigned int ReadMultiple(std::istream &in, Substrings &out) { return sentence_id + sentence_content; } -namespace detail { const StringPiece kEndSentence("</s>"); } - namespace { - typedef unsigned int Sentence; typedef std::vector<Sentence> Sentences; +} // namespace -class Vertex; +namespace detail { + +const StringPiece kEndSentence("</s>"); class Arc { public: Arc() {} // For arcs from one vertex to another. - void SetPhrase(Vertex &from, Vertex &to, const Sentences &intersect) { + void SetPhrase(detail::Vertex &from, detail::Vertex &to, const Sentences &intersect) { Set(to, intersect); from_ = &from; } @@ -71,7 +71,7 @@ class Arc { * aligned). These have no from_ vertex; it implictly matches every * sentence. This also handles when the n-gram is a substring of a phrase. 
*/ - void SetRight(Vertex &to, const Sentences &complete) { + void SetRight(detail::Vertex &to, const Sentences &complete) { Set(to, complete); from_ = NULL; } @@ -97,11 +97,11 @@ class Arc { void LowerBound(const Sentence to); private: - void Set(Vertex &to, const Sentences &sentences); + void Set(detail::Vertex &to, const Sentences &sentences); const Sentence *current_; const Sentence *last_; - Vertex *from_; + detail::Vertex *from_; }; struct ArcGreater : public std::binary_function<const Arc *, const Arc *, bool> { @@ -183,7 +183,13 @@ void Vertex::LowerBound(const Sentence to) { } } -void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, Vertex *const vertices, Arc *free_arc) { +} // namespace detail + +namespace { + +void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, detail::Vertex *const vertices, detail::Arc *free_arc) { + using detail::Vertex; + using detail::Arc; assert(!hashes.empty()); const Hash *const first_word = &*hashes.begin(); @@ -231,17 +237,29 @@ void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, Verte namespace detail { -} // namespace detail +// Here instead of header due to forward declaration. +ConditionCommon::ConditionCommon(const Substrings &substrings) : substrings_(substrings) {} -bool Union::Evaluate() { +// Rest of the variables are temporaries anyway +ConditionCommon::ConditionCommon(const ConditionCommon &from) : substrings_(from.substrings_) {} + +ConditionCommon::~ConditionCommon() {} + +detail::Vertex &ConditionCommon::MakeGraph() { assert(!hashes_.empty()); - // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable. - Vertex vertices[hashes_.size()]; + vertices_.clear(); + vertices_.resize(hashes_.size()); + arcs_.clear(); // One for every substring. 
- Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2]; - BuildGraph(substrings_, hashes_, vertices, arcs); - Vertex &last_vertex = vertices[hashes_.size() - 1]; + arcs_.resize(((hashes_.size() + 1) * hashes_.size()) / 2); + BuildGraph(substrings_, hashes_, &*vertices_.begin(), &*arcs_.begin()); + return vertices_[hashes_.size() - 1]; +} + +} // namespace detail +bool Union::Evaluate() { + detail::Vertex &last_vertex = MakeGraph(); unsigned int lower = 0; while (true) { last_vertex.LowerBound(lower); @@ -252,14 +270,7 @@ bool Union::Evaluate() { } template <class Output> void Multiple::Evaluate(const StringPiece &line, Output &output) { - assert(!hashes_.empty()); - // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable. - Vertex vertices[hashes_.size()]; - // One for every substring. - Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2]; - BuildGraph(substrings_, hashes_, vertices, arcs); - Vertex &last_vertex = vertices[hashes_.size() - 1]; - + detail::Vertex &last_vertex = MakeGraph(); unsigned int lower = 0; while (true) { last_vertex.LowerBound(lower); diff --git a/klm/lm/filter/phrase.hh b/klm/lm/filter/phrase.hh index b4edff41..e8e85835 100644 --- a/klm/lm/filter/phrase.hh +++ b/klm/lm/filter/phrase.hh @@ -103,11 +103,33 @@ template <class Iterator> void MakeHashes(Iterator i, const Iterator &end, std:: } } +class Vertex; +class Arc; + +class ConditionCommon { + protected: + ConditionCommon(const Substrings &substrings); + ConditionCommon(const ConditionCommon &from); + + ~ConditionCommon(); + + detail::Vertex &MakeGraph(); + + // Temporaries in PassNGram and Evaluate to avoid reallocation. 
+ std::vector<Hash> hashes_; + + private: + std::vector<detail::Vertex> vertices_; + std::vector<detail::Arc> arcs_; + + const Substrings &substrings_; +}; + } // namespace detail -class Union { +class Union : public detail::ConditionCommon { public: - explicit Union(const Substrings &substrings) : substrings_(substrings) {} + explicit Union(const Substrings &substrings) : detail::ConditionCommon(substrings) {} template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) { detail::MakeHashes(begin, end, hashes_); @@ -116,23 +138,19 @@ class Union { private: bool Evaluate(); - - std::vector<Hash> hashes_; - - const Substrings &substrings_; }; -class Multiple { +class Multiple : public detail::ConditionCommon { public: - explicit Multiple(const Substrings &substrings) : substrings_(substrings) {} + explicit Multiple(const Substrings &substrings) : detail::ConditionCommon(substrings) {} template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) { detail::MakeHashes(begin, end, hashes_); if (hashes_.empty()) { output.AddNGram(line); - return; + } else { + Evaluate(line, output); } - Evaluate(line, output); } template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { @@ -143,10 +161,6 @@ class Multiple { private: template <class Output> void Evaluate(const StringPiece &line, Output &output); - - std::vector<Hash> hashes_; - - const Substrings &substrings_; }; } // namespace phrase diff --git a/klm/lm/filter/phrase_table_vocab_main.cc b/klm/lm/filter/phrase_table_vocab_main.cc new file mode 100644 index 00000000..e0f47d89 --- /dev/null +++ b/klm/lm/filter/phrase_table_vocab_main.cc @@ -0,0 +1,165 @@ +#include "util/fake_ofstream.hh" +#include "util/file_piece.hh" +#include "util/murmur_hash.hh" +#include "util/pool.hh" +#include "util/string_piece.hh" +#include "util/string_piece_hash.hh" +#include 
"util/tokenize_piece.hh" + +#include <boost/unordered_map.hpp> +#include <boost/unordered_set.hpp> + +#include <cstddef> +#include <vector> + +namespace { + +struct MutablePiece { + mutable StringPiece behind; + bool operator==(const MutablePiece &other) const { + return behind == other.behind; + } +}; + +std::size_t hash_value(const MutablePiece &m) { + return hash_value(m.behind); +} + +class InternString { + public: + const char *Add(StringPiece str) { + MutablePiece mut; + mut.behind = str; + std::pair<boost::unordered_set<MutablePiece>::iterator, bool> res(strs_.insert(mut)); + if (res.second) { + void *mem = backing_.Allocate(str.size() + 1); + memcpy(mem, str.data(), str.size()); + static_cast<char*>(mem)[str.size()] = 0; + res.first->behind = StringPiece(static_cast<char*>(mem), str.size()); + } + return res.first->behind.data(); + } + + private: + util::Pool backing_; + boost::unordered_set<MutablePiece> strs_; +}; + +class TargetWords { + public: + void Introduce(StringPiece source) { + vocab_.resize(vocab_.size() + 1); + std::vector<unsigned int> temp(1, vocab_.size() - 1); + Add(temp, source); + } + + void Add(const std::vector<unsigned int> &sentences, StringPiece target) { + if (sentences.empty()) return; + interns_.clear(); + for (util::TokenIter<util::SingleCharacter, true> i(target, ' '); i; ++i) { + interns_.push_back(intern_.Add(*i)); + } + for (std::vector<unsigned int>::const_iterator i(sentences.begin()); i != sentences.end(); ++i) { + boost::unordered_set<const char *> &vocab = vocab_[*i]; + for (std::vector<const char *>::const_iterator j = interns_.begin(); j != interns_.end(); ++j) { + vocab.insert(*j); + } + } + } + + void Print() const { + util::FakeOFStream out(1); + for (std::vector<boost::unordered_set<const char *> >::const_iterator i = vocab_.begin(); i != vocab_.end(); ++i) { + for (boost::unordered_set<const char *>::const_iterator j = i->begin(); j != i->end(); ++j) { + out << *j << ' '; + } + out << '\n'; + } + } + + private: + 
InternString intern_; + + std::vector<boost::unordered_set<const char *> > vocab_; + + // Temporary in Add. + std::vector<const char *> interns_; +}; + +class Input { + public: + explicit Input(std::size_t max_length) + : max_length_(max_length), sentence_id_(0), empty_() {} + + void AddSentence(StringPiece sentence, TargetWords &targets) { + canonical_.clear(); + starts_.clear(); + starts_.push_back(0); + for (util::TokenIter<util::AnyCharacter, true> i(sentence, StringPiece("\0 \t", 3)); i; ++i) { + canonical_.append(i->data(), i->size()); + canonical_ += ' '; + starts_.push_back(canonical_.size()); + } + targets.Introduce(canonical_); + for (std::size_t i = 0; i < starts_.size() - 1; ++i) { + std::size_t subtract = starts_[i]; + const char *start = &canonical_[subtract]; + for (std::size_t j = i + 1; j < std::min(starts_.size(), i + max_length_ + 1); ++j) { + map_[util::MurmurHash64A(start, &canonical_[starts_[j]] - start - 1)].push_back(sentence_id_); + } + } + ++sentence_id_; + } + + // Assumes single space-delimited phrase with no space at the beginning or end. + const std::vector<unsigned int> &Matches(StringPiece phrase) const { + Map::const_iterator i = map_.find(util::MurmurHash64A(phrase.data(), phrase.size())); + return i == map_.end() ? empty_ : i->second; + } + + private: + const std::size_t max_length_; + + // hash of phrase is the key, array of sentences is the value. + typedef boost::unordered_map<uint64_t, std::vector<unsigned int> > Map; + Map map_; + + std::size_t sentence_id_; + + // Temporaries in AddSentence. 
+ std::string canonical_; + std::vector<std::size_t> starts_; + + const std::vector<unsigned int> empty_; +}; + +} // namespace + +int main(int argc, char *argv[]) { + if (argc != 2) { + std::cerr << "Expected source text on the command line" << std::endl; + return 1; + } + Input input(7); + TargetWords targets; + try { + util::FilePiece inputs(argv[1], &std::cerr); + while (true) + input.AddSentence(inputs.ReadLine(), targets); + } catch (const util::EndOfFileException &e) {} + + util::FilePiece table(0, NULL, &std::cerr); + StringPiece line; + const StringPiece pipes("|||"); + while (true) { + try { + line = table.ReadLine(); + } catch (const util::EndOfFileException &e) { break; } + util::TokenIter<util::MultiCharacter> it(line, pipes); + StringPiece source(*it); + if (!source.empty() && source[source.size() - 1] == ' ') + source.remove_suffix(1); + targets.Add(input.Matches(source), *++it); + } + targets.Print(); +} diff --git a/klm/lm/filter/vocab.cc b/klm/lm/filter/vocab.cc index 7ee4e84b..011ab599 100644 --- a/klm/lm/filter/vocab.cc +++ b/klm/lm/filter/vocab.cc @@ -4,7 +4,6 @@ #include <iostream> #include <ctype.h> -#include <err.h> namespace lm { namespace vocab { diff --git a/klm/lm/filter/wrapper.hh b/klm/lm/filter/wrapper.hh index 90b07a08..eb657501 100644 --- a/klm/lm/filter/wrapper.hh +++ b/klm/lm/filter/wrapper.hh @@ -39,17 +39,15 @@ template <class FilterT> class ContextFilter { explicit ContextFilter(Filter &backend) : backend_(backend) {} template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { - pieces_.clear(); - // TODO: this copy could be avoided by a lookahead iterator. - std::copy(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), std::back_insert_iterator<std::vector<StringPiece> >(pieces_)); - backend_.AddNGram(pieces_.begin(), pieces_.end() - !pieces_.empty(), line, output); + // Find beginning of string or last space. 
+ const char *last_space; + for (last_space = ngram.data() + ngram.size() - 1; last_space > ngram.data() && *last_space != ' '; --last_space) {} + backend_.AddNGram(StringPiece(ngram.data(), last_space - ngram.data()), line, output); } void Flush() const {} private: - std::vector<StringPiece> pieces_; - Filter backend_; }; |