summaryrefslogtreecommitdiff
path: root/klm/lm/sri.cc
blob: 7bd23d76e1ae849922785f5df65fc6920e30805b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#include "lm/exception.hh"
#include "lm/sri.hh"

#include <Ngram.h>
#include <Vocab.h>

#include <errno.h>

namespace lm {
namespace sri {

Vocabulary::Vocabulary() : sri_(new Vocab) {}

Vocabulary::~Vocabulary() {}

WordIndex Vocabulary::Index(const char *str) const {
  WordIndex ret = sri_->getIndex(str);
  // NGram wants the index of Vocab_Unknown for unknown words, but for some reason SRI returns Vocab_None here :-(.
  if (ret == Vocab_None) {
    return not_found_;
  } else {
    return ret;
  }
}

const char *Vocabulary::Word(WordIndex index) const {
  return sri_->getWord(index);
}

void Vocabulary::FinishedLoading() {
  SetSpecial(
    sri_->ssIndex(),
    sri_->seIndex(),
    sri_->unkIndex(),
    sri_->highIndex() + 1);
}

namespace {
Ngram *MakeSRIModel(const char *file_name, unsigned int ngram_length, Vocab &sri_vocab) {
  sri_vocab.unkIsWord() = true;
  std::auto_ptr<Ngram> ret(new Ngram(sri_vocab, ngram_length));
  File file(file_name, "r");
  errno = 0;
  if (!ret->read(file)) {
    UTIL_THROW(FormatLoadException, "reading file " << file_name << " with SRI failed.");
  }
  return ret.release();
}
} // namespace

Model::Model(const char *file_name, unsigned int ngram_length) : sri_(MakeSRIModel(file_name, ngram_length, *vocab_.sri_)) {
  if (!sri_->setorder()) {
    UTIL_THROW(FormatLoadException, "Can't have an SRI model with order 0.");
  }
  vocab_.FinishedLoading();
  State begin_state = State();
  begin_state.valid_length_ = 1;
  if (kMaxOrder > 1) {
    begin_state.history_[0] = vocab_.BeginSentence();
    if (kMaxOrder > 2) begin_state.history_[1] = Vocab_None;
  }
  State null_state = State();
  null_state.valid_length_ = 0;
  if (kMaxOrder > 1) null_state.history_[0] = Vocab_None;
  Init(begin_state, null_state, vocab_, sri_->setorder());
  not_found_ = vocab_.NotFound();
}

Model::~Model() {}

namespace {

/* Argh SRI's wordProb knows the ngram length but doesn't return it.  One more
 * reason you should use my model.  */
// TODO(stolcke): fix SRILM so I don't have to do this.   
unsigned int MatchedLength(Ngram &model, const WordIndex new_word, const SRIVocabIndex *const_history) {
  unsigned int out_length = 0;
  // This gets the length of context used, which is ngram_length - 1 unless new_word is OOV in which case it is 0.
  model.contextID(new_word, const_history, out_length);
  return out_length + 1;
}

} // namespace

FullScoreReturn Model::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
  // If you get a compiler in this function, change SRIVocabIndex in sri.hh to match the one found in SRI's Vocab.h.
  const SRIVocabIndex *const_history;
  SRIVocabIndex local_history[Order()];
  if (in_state.valid_length_ < kMaxOrder - 1) {
    const_history = in_state.history_;
  } else {
    std::copy(in_state.history_, in_state.history_ + in_state.valid_length_, local_history);
    local_history[in_state.valid_length_] = Vocab_None;
    const_history = local_history;
  }
  FullScoreReturn ret;
  if (new_word != not_found_) {
    ret.ngram_length = MatchedLength(*sri_, new_word, const_history);
    out_state.history_[0] = new_word;
    out_state.valid_length_ = std::min<unsigned char>(ret.ngram_length, Order() - 1);
    std::copy(const_history, const_history + out_state.valid_length_ - 1, out_state.history_ + 1);
    if (out_state.valid_length_ < kMaxOrder - 1) {
      out_state.history_[out_state.valid_length_] = Vocab_None;
    }
  } else {
    ret.ngram_length = 0;
    if (kMaxOrder > 1) out_state.history_[0] = Vocab_None;
    out_state.valid_length_ = 0;
  }
  ret.prob = sri_->wordProb(new_word, const_history);
  return ret;
}

} // namespace sri
} // namespace lm