summaryrefslogtreecommitdiff
path: root/klm/lm/wrappers/nplm.hh
blob: b7dd4a21e9949d5fa6f09502513d6bcf8a62e7d3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#ifndef LM_WRAPPERS_NPLM_H
#define LM_WRAPPERS_NPLM_H

#include "lm/facade.hh"
#include "lm/max_order.hh"
#include "util/string_piece.hh"

#include <boost/thread/tss.hpp>
#include <boost/scoped_ptr.hpp>

/* Wrapper around NPLM "by Ashish Vaswani, with contributions from David
 * Chiang and Victoria Fossum."
 * http://nlg.isi.edu/software/nplm/
 */

// Forward declarations of the NPLM classes this header refers to.  They are
// only used here through references and smart-pointer template arguments, so
// the incomplete types are sufficient in this header.
namespace nplm {
class neuralLM;
class vocabulary;
} // namespace nplm

namespace lm {
namespace np {

// Vocabulary implementation that derives from lm::base::Vocabulary and is
// built from an nplm::vocabulary.
class Vocabulary : public base::Vocabulary {
  public:
    // NOTE(review): the member vocab_ is a reference, presumably bound to
    // this argument -- if so, the nplm::vocabulary must outlive this object.
    // Confirm in the implementation file.
    Vocabulary(const nplm::vocabulary &vocab);

    ~Vocabulary();

    // Word index for a word supplied as a std::string (defined out of line).
    WordIndex Index(const std::string &str) const;

    // StringPiece overload: copies the bytes into a std::string and delegates
    // to the std::string overload above.
    // TODO: lobby them to support StringPiece
    WordIndex Index(const StringPiece &str) const {
      const std::string copied(str.data(), str.size());
      return Index(copied);
    }

    // Returns the value stored in null_word_.
    lm::WordIndex NullWord() const {
      return null_word_;
    }

  private:
    // Reference to the NPLM vocabulary; not owned by this class.
    const nplm::vocabulary &vocab_;

    // Value handed back by NullWord(); const, so fixed at construction.
    const lm::WordIndex null_word_;
};

// Sorry for imposing my limitations on your code.
// Compile-time constant used to size State::words, which holds
// NPLM_MAX_ORDER - 1 word indices.
// NOTE(review): presumably the largest n-gram order of NPLM model this
// wrapper can handle -- confirm against the implementation file.
#define NPLM_MAX_ORDER 7

// State type passed into and out of Model::FullScore and
// Model::FullScoreForgotState.  It is a plain fixed-size array of
// NPLM_MAX_ORDER - 1 word indices.
// NOTE(review): presumably the most recent context words; the ordering and
// padding convention are not visible in this header -- confirm in the
// implementation file.
struct State {
  WordIndex words[NPLM_MAX_ORDER - 1];
};

// Language-model class for NPLM models.  Derives from
// lm::base::ModelFacade, parameterized on itself and on the State and
// Vocabulary types declared in this header.
class Model : public lm::base::ModelFacade<Model, State, Vocabulary> {
  private:
    // Shorthand for the facade base class.
    typedef lm::base::ModelFacade<Model, State, Vocabulary> P;

  public:
    // Does this look like an NPLM?
    static bool Recognize(const std::string &file);

    // file: the model file to load.
    // cache_size: defaults to 1 << 20 (1,048,576).
    //   NOTE(review): presumably the size of an NPLM-side cache -- confirm
    //   the meaning and units in the implementation file.
    explicit Model(const std::string &file, std::size_t cache_size = 1 << 20);

    // Declared here and defined out of line; nplm::neuralLM is only
    // forward-declared in this header.
    ~Model();

    // Takes an input state and a word, returns a FullScoreReturn and writes
    // out_state.
    // NOTE(review): presumably scores new_word given the context in `from`
    // and stores the successor state in out_state -- confirm.
    FullScoreReturn FullScore(const State &from, const WordIndex new_word, State &out_state) const;

    // Same as FullScore, but the context is supplied as a pointer range
    // (context_rbegin, context_rend) of word indices instead of a State.
    // NOTE(review): the "r" prefix suggests the range is in reverse order
    // (most recent word first) -- confirm against the implementation.
    FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;

  private:
    // NOTE: non-static data members are initialized in declaration order; do
    // not reorder the members below without checking the constructor.

    // NPLM model instance held by scoped_ptr, so it is deleted when this
    // Model is destroyed.
    boost::scoped_ptr<nplm::neuralLM> base_instance_;

    // Per-thread NPLM instance pointer.  mutable so that const member
    // functions can assign it.
    // NOTE(review): presumably each thread gets its own instance derived
    // from base_instance_ -- confirm in the implementation file.
    mutable boost::thread_specific_ptr<nplm::neuralLM> backend_;

    Vocabulary vocab_;

    // NOTE(review): presumably a copy of vocab_.NullWord() -- confirm.
    lm::WordIndex null_word_;

    // const, so fixed at construction; presumably the constructor's
    // cache_size argument -- confirm.
    const std::size_t cache_size_;
};

} // namespace np
} // namespace lm

#endif // LM_WRAPPERS_NPLM_H