summaryrefslogtreecommitdiff
path: root/klm/lm/filter/vocab.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/filter/vocab.hh')
-rw-r--r--klm/lm/filter/vocab.hh132
1 files changed, 132 insertions, 0 deletions
diff --git a/klm/lm/filter/vocab.hh b/klm/lm/filter/vocab.hh
new file mode 100644
index 00000000..e2b6adff
--- /dev/null
+++ b/klm/lm/filter/vocab.hh
@@ -0,0 +1,132 @@
+#ifndef LM_FILTER_VOCAB_H__
+#define LM_FILTER_VOCAB_H__
+
+// Vocabulary-based filters for language models.
+
+#include "util/multi_intersection.hh"
+#include "util/string_piece.hh"
+#include "util/tokenize_piece.hh"
+
+#include <boost/noncopyable.hpp>
+#include <boost/range/iterator_range.hpp>
+#include <boost/unordered/unordered_map.hpp>
+#include <boost/unordered/unordered_set.hpp>
+
+#include <string>
+#include <vector>
+
+namespace lm {
+namespace vocab {
+
+void ReadSingle(std::istream &in, boost::unordered_set<std::string> &out);
+
+// Read one sentence vocabulary per line. Return the number of sentences.
+unsigned int ReadMultiple(std::istream &in, boost::unordered_map<std::string, std::vector<unsigned int> > &out);
+
+/* Is this a special tag like <s> or <UNK>? This actually includes anything
+ * surrounded with < and >, which most tokenizers separate for real words, so
+ * this should not catch real words as it looks at a single token.
+ */
+inline bool IsTag(const StringPiece &value) {
+ // The parser should never give an empty string.
+ assert(!value.empty());
+ return (value.data()[0] == '<' && value.data()[value.size() - 1] == '>');
+}
+
+class Single {
+ public:
+ typedef boost::unordered_set<std::string> Words;
+
+ explicit Single(const Words &vocab) : vocab_(vocab) {}
+
+ template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
+ for (Iterator i = begin; i != end; ++i) {
+ if (IsTag(*i)) continue;
+ if (FindStringPiece(vocab_, *i) == vocab_.end()) return false;
+ }
+ return true;
+ }
+
+ private:
+ const Words &vocab_;
+};
+
+class Union {
+ public:
+ typedef boost::unordered_map<std::string, std::vector<unsigned int> > Words;
+
+ explicit Union(const Words &vocabs) : vocabs_(vocabs) {}
+
+ template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
+ sets_.clear();
+
+ for (Iterator i(begin); i != end; ++i) {
+ if (IsTag(*i)) continue;
+ Words::const_iterator found(FindStringPiece(vocabs_, *i));
+ if (vocabs_.end() == found) return false;
+ sets_.push_back(boost::iterator_range<const unsigned int*>(&*found->second.begin(), &*found->second.end()));
+ }
+ return (sets_.empty() || util::FirstIntersection(sets_));
+ }
+
+ private:
+ const Words &vocabs_;
+
+ std::vector<boost::iterator_range<const unsigned int*> > sets_;
+};
+
+class Multiple {
+ public:
+ typedef boost::unordered_map<std::string, std::vector<unsigned int> > Words;
+
+ Multiple(const Words &vocabs) : vocabs_(vocabs) {}
+
+ private:
+ // Callback from AllIntersection that does AddNGram.
+ template <class Output> class Callback {
+ public:
+ Callback(Output &out, const StringPiece &line) : out_(out), line_(line) {}
+
+ void operator()(unsigned int index) {
+ out_.SingleAddNGram(index, line_);
+ }
+
+ private:
+ Output &out_;
+ const StringPiece &line_;
+ };
+
+ public:
+ template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
+ sets_.clear();
+ for (Iterator i(begin); i != end; ++i) {
+ if (IsTag(*i)) continue;
+ Words::const_iterator found(FindStringPiece(vocabs_, *i));
+ if (vocabs_.end() == found) return;
+ sets_.push_back(boost::iterator_range<const unsigned int*>(&*found->second.begin(), &*found->second.end()));
+ }
+ if (sets_.empty()) {
+ output.AddNGram(line);
+ return;
+ }
+
+ Callback<Output> cb(output, line);
+ util::AllIntersection(sets_, cb);
+ }
+
+ template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
+ AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output);
+ }
+
+ void Flush() const {}
+
+ private:
+ const Words &vocabs_;
+
+ std::vector<boost::iterator_range<const unsigned int*> > sets_;
+};
+
+} // namespace vocab
+} // namespace lm
+
+#endif // LM_FILTER_VOCAB_H__