path: root/klm/lm/filter/
diff options
Diffstat (limited to 'klm/lm/filter/')
1 files changed, 165 insertions, 0 deletions
diff --git a/klm/lm/filter/ b/klm/lm/filter/
new file mode 100644
index 00000000..e0f47d89
--- /dev/null
+++ b/klm/lm/filter/
@@ -0,0 +1,165 @@
+#include "util/fake_ofstream.hh"
+#include "util/file_piece.hh"
+#include "util/murmur_hash.hh"
+#include "util/pool.hh"
+#include "util/string_piece.hh"
+#include "util/string_piece_hash.hh"
+#include "util/tokenize_piece.hh"
+#include <boost/unordered_map.hpp>
+#include <boost/unordered_set.hpp>
+#include <cstddef>
+#include <vector>
+namespace {
+// Hash-set element that wraps a StringPiece.  `behind` is declared mutable so
+// that it can be reassigned through a const reference (elements stored in a set
+// are only reachable as const).  Equality is defined purely by `behind`, so any
+// reassignment must keep the same bytes or the containing set is corrupted.
+struct MutablePiece {
+  mutable StringPiece behind;
+  bool operator==(const MutablePiece &other) const {
+    return behind == other.behind;
+  }
+};
+// Hash for MutablePiece.  Delegates to the hash_value overload for the wrapped
+// StringPiece, so two MutablePieces that compare equal also hash equal.
+std::size_t hash_value(const MutablePiece &m) {
+  return hash_value(m.behind);
+}
+// String interner.  Each distinct string is copied once, with a terminating
+// NUL, into memory obtained from backing_; Add hands back a pointer to that
+// copy.  Equal strings therefore map to the same pointer, which lets callers
+// compare and hash interned strings by address.
+class InternString {
+  public:
+    // Returns the interned, NUL-terminated copy of str.  The caller's buffer is
+    // not retained: on first sight the bytes are copied, and on later calls the
+    // pointer stored for the earlier copy is returned.
+    const char *Add(StringPiece str) {
+      MutablePiece mut;
+      mut.behind = str;
+      std::pair<boost::unordered_set<MutablePiece>::iterator, bool> res(strs_.insert(mut));
+      if (res.second) {
+        // Newly inserted: the set element still points at the caller's bytes.
+        // Copy them (plus a NUL) into memory from backing_.
+        void *mem = backing_.Allocate(str.size() + 1);
+        memcpy(mem, str.data(), str.size());
+        static_cast<char*>(mem)[str.size()] = 0;
+        // Re-point the stored element at the copy.  The bytes are identical, so
+        // its hash and equality are unchanged and the set stays consistent.
+        res.first->behind = StringPiece(static_cast<char*>(mem), str.size());
+      }
+      return res.first->behind.data();
+    }
+  private:
+    // Memory source for the interned copies.
+    util::Pool backing_;
+    // One element per distinct string; each element views its interned copy.
+    boost::unordered_set<MutablePiece> strs_;
+};
+// Collects, for every input sentence, the set of words that may be needed for
+// it.  Sentences are identified by their index in vocab_, i.e. by the order in
+// which Introduce was called.  Words are stored as interned pointers, so the
+// per-sentence sets compare and hash by address.
+class TargetWords {
+  public:
+    // Registers the next sentence: appends a fresh vocabulary set and seeds it
+    // with the space-separated words of `source`.
+    void Introduce(StringPiece source) {
+      vocab_.resize(vocab_.size() + 1);
+      // Single-element id list naming the set that was just appended.
+      std::vector<unsigned int> temp(1, vocab_.size() - 1);
+      Add(temp, source);
+    }
+    // Adds every space-separated word of `target` to the vocabulary of each
+    // sentence id listed in `sentences`.  Each id must be the index of a
+    // sentence already registered through Introduce.  Does nothing when
+    // `sentences` is empty.
+    void Add(const std::vector<unsigned int> &sentences, StringPiece target) {
+      if (sentences.empty()) return;
+      // Intern each word once, then reuse the pointers for every sentence.
+      interns_.clear();
+      // NOTE(review): the `true` template argument presumably makes TokenIter
+      // skip empty tokens (repeated/leading/trailing spaces) - confirm.
+      for (util::TokenIter<util::SingleCharacter, true> i(target, ' '); i; ++i) {
+        interns_.push_back(intern_.Add(*i));
+      }
+      for (std::vector<unsigned int>::const_iterator i(sentences.begin()); i != sentences.end(); ++i) {
+        boost::unordered_set<const char *> &vocab = vocab_[*i];
+        for (std::vector<const char *>::const_iterator j = interns_.begin(); j != interns_.end(); ++j) {
+          vocab.insert(*j);
+        }
+      }
+    }
+    // Writes one line per sentence, in sentence order: the sentence's words,
+    // each followed by a space, in the set's iteration order.
+    // NOTE(review): FakeOFStream(1) presumably writes to file descriptor 1
+    // (standard output) - confirm.
+    void Print() const {
+      util::FakeOFStream out(1);
+      for (std::vector<boost::unordered_set<const char *> >::const_iterator i = vocab_.begin(); i != vocab_.end(); ++i) {
+        for (boost::unordered_set<const char *>::const_iterator j = i->begin(); j != i->end(); ++j) {
+          out << *j << ' ';
+        }
+        out << '\n';
+      }
+    }
+  private:
+    InternString intern_;
+    // vocab_[k] holds the interned words collected for sentence k.
+    std::vector<boost::unordered_set<const char *> > vocab_;
+    // Temporary in Add.
+    std::vector<const char *> interns_;
+};
+// Index from source phrases to the sentences that contain them.  Every run of
+// 1..max_length consecutive tokens of every added sentence is hashed, and the
+// sentence's id is recorded under that hash.  Only the 64-bit hash is kept as
+// the key, so two phrases whose hashes collide share one entry.
+class Input {
+  public:
+    // max_length: maximum number of tokens in an indexed phrase.
+    explicit Input(std::size_t max_length)
+      : max_length_(max_length), sentence_id_(0), empty_() {}
+    // Indexes one sentence under the next sentence id and registers it with
+    // `targets` via Introduce, passing the canonical form built below.
+    void AddSentence(StringPiece sentence, TargetWords &targets) {
+      // Canonical form: tokens split on NUL, space or tab, each followed by
+      // exactly one space.  starts_[k] is the offset where token k begins; the
+      // final entry is the end of the string.
+      canonical_.clear();
+      starts_.clear();
+      starts_.push_back(0);
+      for (util::TokenIter<util::AnyCharacter, true> i(sentence, StringPiece("\0 \t", 3)); i; ++i) {
+        canonical_.append(i->data(), i->size());
+        canonical_ += ' ';
+        starts_.push_back(canonical_.size());
+      }
+      targets.Introduce(canonical_);
+      // For every start token i, hash each phrase covering tokens [i, j).
+      for (std::size_t i = 0; i < starts_.size() - 1; ++i) {
+        std::size_t subtract = starts_[i];
+        const char *start = &canonical_[subtract];
+        for (std::size_t j = i + 1; j < std::min(starts_.size(), i + max_length_ + 1); ++j) {
+          // The "- 1" drops the space that follows the phrase's last token, so
+          // the hashed bytes match what Matches expects.
+          map_[util::MurmurHash64A(start, &canonical_[starts_[j]] - start - 1)].push_back(sentence_id_);
+        }
+      }
+      ++sentence_id_;
+    }
+    // Assumes single space-delimited phrase with no space at the beginning or end.
+    // Returns the ids recorded for the phrase's hash, or an empty vector when
+    // the hash was never recorded.  An id appears once per occurrence of the
+    // phrase in its sentence, so the result may contain repeats.
+    const std::vector<unsigned int> &Matches(StringPiece phrase) const {
+      Map::const_iterator i = map_.find(util::MurmurHash64A(phrase.data(), phrase.size()));
+      return i == map_.end() ? empty_ : i->second;
+    }
+  private:
+    const std::size_t max_length_;
+    // hash of phrase is the key, array of sentences is the value.
+    typedef boost::unordered_map<uint64_t, std::vector<unsigned int> > Map;
+    Map map_;
+    // Id that the next call to AddSentence will assign.
+    std::size_t sentence_id_;
+    // Temporaries in AddSentence.
+    std::string canonical_;
+    std::vector<std::size_t> starts_;
+    // Returned by Matches when there is no entry for a phrase.
+    const std::vector<unsigned int> empty_;
+};
+} // namespace
+// Usage: <program> source_text < phrase_table
+// Indexes every phrase of up to 7 tokens from the sentences in the file named
+// by argv[1], then reads "source ||| target ||| ..." lines from file descriptor
+// 0.  For each line whose source phrase occurs in some sentence, the words of
+// the target field are added to those sentences' vocabularies.  Finally prints
+// one vocabulary line per sentence.
+int main(int argc, char *argv[]) {
+  if (argc != 2) {
+    std::cerr << "Expected source text on the command line" << std::endl;
+    return 1;
+  }
+  // 7 is the maximum phrase length, in tokens, that Input indexes.
+  Input input(7);
+  TargetWords targets;
+  try {
+    util::FilePiece inputs(argv[1], &std::cerr);
+    // ReadLine signals end of input by throwing, which ends this loop.
+    while (true)
+      input.AddSentence(inputs.ReadLine(), targets);
+  } catch (const util::EndOfFileException &e) {}
+  util::FilePiece table(0, NULL, &std::cerr);
+  StringPiece line;
+  const StringPiece pipes("|||");
+  while (true) {
+    try {
+      line = table.ReadLine();
+    } catch (const util::EndOfFileException &e) { break; }
+    util::TokenIter<util::MultiCharacter> it(line, pipes);
+    // Skip lines that yield no field at all rather than dereferencing an
+    // exhausted iterator.
+    if (!it) continue;
+    StringPiece source(*it);
+    // The source field normally ends with the space that precedes "|||";
+    // Input::Matches expects the phrase without it.
+    if (!source.empty() && source[source.size() - 1] == ' ')
+      source.remove_suffix(1);
+    // Skip lines with no target field (no "|||" separator).
+    ++it;
+    if (!it) continue;
+    targets.Add(input.Matches(source), *it);
+  }
+  targets.Print();
+}