summaryrefslogtreecommitdiff
path: root/klm/lm/filter/vocab.hh
blob: 2ee6e1f8aafb2cf77664582edacd8cde276912d3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#ifndef LM_FILTER_VOCAB_H
#define LM_FILTER_VOCAB_H

// Vocabulary-based filters for language models.

#include "util/multi_intersection.hh"
#include "util/string_piece.hh"
#include "util/string_piece_hash.hh"
#include "util/tokenize_piece.hh"

#include <boost/noncopyable.hpp>
#include <boost/range/iterator_range.hpp>
#include <boost/unordered/unordered_map.hpp>
#include <boost/unordered/unordered_set.hpp>

#include <string>
#include <vector>

namespace lm {
namespace vocab {

void ReadSingle(std::istream &in, boost::unordered_set<std::string> &out);

// Read one sentence vocabulary per line.  Return the number of sentences.
unsigned int ReadMultiple(std::istream &in, boost::unordered_map<std::string, std::vector<unsigned int> > &out);

/* Is this a special tag like <s> or <UNK>?  This actually includes anything
 * surrounded with < and >, which most tokenizers separate for real words, so
 * this should not catch real words as it looks at a single token.   
 */
inline bool IsTag(const StringPiece &value) {
  // The parser should never give an empty string.
  assert(!value.empty());
  return (value.data()[0] == '<' && value.data()[value.size() - 1] == '>');
}

class Single {
  public:
    typedef boost::unordered_set<std::string> Words;

    explicit Single(const Words &vocab) : vocab_(vocab) {}

    template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
      for (Iterator i = begin; i != end; ++i) {
        if (IsTag(*i)) continue;
        if (FindStringPiece(vocab_, *i) == vocab_.end()) return false;
      }
      return true;
    }

  private:
    const Words &vocab_;
};

class Union {
  public:
    typedef boost::unordered_map<std::string, std::vector<unsigned int> > Words;

    explicit Union(const Words &vocabs) : vocabs_(vocabs) {}

    template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
      sets_.clear();

      for (Iterator i(begin); i != end; ++i) {
        if (IsTag(*i)) continue;
        Words::const_iterator found(FindStringPiece(vocabs_, *i));
        if (vocabs_.end() == found) return false;
        sets_.push_back(boost::iterator_range<const unsigned int*>(&*found->second.begin(), &*found->second.end()));
      }
      return (sets_.empty() || util::FirstIntersection(sets_));
    }

  private:
    const Words &vocabs_;

    std::vector<boost::iterator_range<const unsigned int*> > sets_;
};

class Multiple {
  public:
    typedef boost::unordered_map<std::string, std::vector<unsigned int> > Words;

    Multiple(const Words &vocabs) : vocabs_(vocabs) {}

  private:
    // Callback from AllIntersection that does AddNGram.
    template <class Output> class Callback {
      public:
        Callback(Output &out, const StringPiece &line) : out_(out), line_(line) {}

        void operator()(unsigned int index) {
          out_.SingleAddNGram(index, line_);
        }

      private:
        Output &out_;
        const StringPiece &line_;
    };

  public:
    template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
      sets_.clear();
      for (Iterator i(begin); i != end; ++i) {
        if (IsTag(*i)) continue;
        Words::const_iterator found(FindStringPiece(vocabs_, *i));
        if (vocabs_.end() == found) return;
        sets_.push_back(boost::iterator_range<const unsigned int*>(&*found->second.begin(), &*found->second.end()));
      }
      if (sets_.empty()) {
        output.AddNGram(line);
        return;
      }

      Callback<Output> cb(output, line);
      util::AllIntersection(sets_, cb);
    }

    template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
      AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output);
    }

    void Flush() const {}

  private:
    const Words &vocabs_;

    std::vector<boost::iterator_range<const unsigned int*> > sets_;
};

} // namespace vocab
} // namespace lm

#endif // LM_FILTER_VOCAB_H