diff options
Diffstat (limited to 'klm/lm/filter/phrase_table_vocab_main.cc')
-rw-r--r-- | klm/lm/filter/phrase_table_vocab_main.cc | 165 |
1 files changed, 165 insertions, 0 deletions
diff --git a/klm/lm/filter/phrase_table_vocab_main.cc b/klm/lm/filter/phrase_table_vocab_main.cc new file mode 100644 index 00000000..e0f47d89 --- /dev/null +++ b/klm/lm/filter/phrase_table_vocab_main.cc @@ -0,0 +1,165 @@ +#include "util/fake_ofstream.hh" +#include "util/file_piece.hh" +#include "util/murmur_hash.hh" +#include "util/pool.hh" +#include "util/string_piece.hh" +#include "util/string_piece_hash.hh" +#include "util/tokenize_piece.hh" + +#include <boost/unordered_map.hpp> +#include <boost/unordered_set.hpp> + +#include <cstddef> +#include <vector> + +namespace { + +struct MutablePiece { + mutable StringPiece behind; + bool operator==(const MutablePiece &other) const { + return behind == other.behind; + } +}; + +std::size_t hash_value(const MutablePiece &m) { + return hash_value(m.behind); +} + +class InternString { + public: + const char *Add(StringPiece str) { + MutablePiece mut; + mut.behind = str; + std::pair<boost::unordered_set<MutablePiece>::iterator, bool> res(strs_.insert(mut)); + if (res.second) { + void *mem = backing_.Allocate(str.size() + 1); + memcpy(mem, str.data(), str.size()); + static_cast<char*>(mem)[str.size()] = 0; + res.first->behind = StringPiece(static_cast<char*>(mem), str.size()); + } + return res.first->behind.data(); + } + + private: + util::Pool backing_; + boost::unordered_set<MutablePiece> strs_; +}; + +class TargetWords { + public: + void Introduce(StringPiece source) { + vocab_.resize(vocab_.size() + 1); + std::vector<unsigned int> temp(1, vocab_.size() - 1); + Add(temp, source); + } + + void Add(const std::vector<unsigned int> &sentences, StringPiece target) { + if (sentences.empty()) return; + interns_.clear(); + for (util::TokenIter<util::SingleCharacter, true> i(target, ' '); i; ++i) { + interns_.push_back(intern_.Add(*i)); + } + for (std::vector<unsigned int>::const_iterator i(sentences.begin()); i != sentences.end(); ++i) { + boost::unordered_set<const char *> &vocab = vocab_[*i]; + for (std::vector<const char *>::const_iterator j = interns_.begin(); j != interns_.end(); ++j) { + vocab.insert(*j); + } + } + } + + void Print() const { + util::FakeOFStream out(1); + for (std::vector<boost::unordered_set<const char *> >::const_iterator i = vocab_.begin(); i != vocab_.end(); ++i) { + for (boost::unordered_set<const char *>::const_iterator j = i->begin(); j != i->end(); ++j) { + out << *j << ' '; + } + out << '\n'; + } + } + + private: + InternString intern_; + + std::vector<boost::unordered_set<const char *> > vocab_; + + // Temporary in Add. + std::vector<const char *> interns_; +}; + +class Input { + public: + explicit Input(std::size_t max_length) + : max_length_(max_length), sentence_id_(0), empty_() {} + + void AddSentence(StringPiece sentence, TargetWords &targets) { + canonical_.clear(); + starts_.clear(); + starts_.push_back(0); + for (util::TokenIter<util::AnyCharacter, true> i(sentence, StringPiece("\0 \t", 3)); i; ++i) { + canonical_.append(i->data(), i->size()); + canonical_ += ' '; + starts_.push_back(canonical_.size()); + } + targets.Introduce(canonical_); + for (std::size_t i = 0; i < starts_.size() - 1; ++i) { + std::size_t subtract = starts_[i]; + const char *start = &canonical_[subtract]; + for (std::size_t j = i + 1; j < std::min(starts_.size(), i + max_length_ + 1); ++j) { + map_[util::MurmurHash64A(start, &canonical_[starts_[j]] - start - 1)].push_back(sentence_id_); + } + } + ++sentence_id_; + } + + // Assumes single space-delimited phrase with no space at the beginning or end. + const std::vector<unsigned int> &Matches(StringPiece phrase) const { + Map::const_iterator i = map_.find(util::MurmurHash64A(phrase.data(), phrase.size())); + return i == map_.end() ? empty_ : i->second; + } + + private: + const std::size_t max_length_; + + // hash of phrase is the key, array of sentences is the value. + typedef boost::unordered_map<uint64_t, std::vector<unsigned int> > Map; + Map map_; + + std::size_t sentence_id_; + + // Temporaries in AddSentence. + std::string canonical_; + std::vector<std::size_t> starts_; + + const std::vector<unsigned int> empty_; +}; + +} // namespace + +int main(int argc, char *argv[]) { + if (argc != 2) { + std::cerr << "Expected source text on the command line" << std::endl; + return 1; + } + Input input(7); + TargetWords targets; + try { + util::FilePiece inputs(argv[1], &std::cerr); + while (true) + input.AddSentence(inputs.ReadLine(), targets); + } catch (const util::EndOfFileException &e) {} + + util::FilePiece table(0, NULL, &std::cerr); + StringPiece line; + const StringPiece pipes("|||"); + while (true) { + try { + line = table.ReadLine(); + } catch (const util::EndOfFileException &e) { break; } + util::TokenIter<util::MultiCharacter> it(line, pipes); + StringPiece source(*it); + if (!source.empty() && source[source.size() - 1] == ' ') + source.remove_suffix(1); + targets.Add(input.Matches(source), *++it); + } + targets.Print(); +} |