summaryrefslogtreecommitdiff
path: root/klm/lm
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-11-07 18:10:00 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2011-11-07 18:10:00 -0500
commitc2b05499ffa82cfadc668e140b8f96ab43b1c715 (patch)
tree696f234835b7758bbb6f6b528d6bdbef1f6193e5 /klm/lm
parentbcda3258ab35cba2f71e28e1c93863958f5aca8b (diff)
parentbdd7fe7b513ade0b979fc050766e375044e84e86 (diff)
Merge branch 'master' of github.com:redpony/cdec
Diffstat (limited to 'klm/lm')
-rw-r--r--klm/lm/bhiksha.hh5
-rw-r--r--klm/lm/build_binary.cc2
-rw-r--r--klm/lm/config.hh6
-rw-r--r--klm/lm/enumerate_vocab.hh2
-rw-r--r--klm/lm/left.hh43
-rw-r--r--klm/lm/model.cc4
-rw-r--r--klm/lm/read_arpa.cc2
-rw-r--r--klm/lm/search_hashed.cc16
-rw-r--r--klm/lm/search_trie.cc3
-rw-r--r--klm/lm/sri.cc108
-rw-r--r--klm/lm/sri.hh102
-rw-r--r--klm/lm/vocab.cc1
-rw-r--r--klm/lm/vocab.hh3
13 files changed, 53 insertions, 244 deletions
diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh
index bc705959..3df43dda 100644
--- a/klm/lm/bhiksha.hh
+++ b/klm/lm/bhiksha.hh
@@ -10,6 +10,9 @@
* Currently only used for next pointers.
*/
+#ifndef LM_BHIKSHA__
+#define LM_BHIKSHA__
+
#include <inttypes.h>
#include <assert.h>
@@ -108,3 +111,5 @@ class ArrayBhiksha {
} // namespace trie
} // namespace ngram
} // namespace lm
+
+#endif // LM_BHIKSHA__
diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc
index b7aee4de..fdb62a71 100644
--- a/klm/lm/build_binary.cc
+++ b/klm/lm/build_binary.cc
@@ -15,7 +15,7 @@ namespace ngram {
namespace {
void Usage(const char *name) {
- std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-c bits] [type] input.arpa [output.mmap]\n\n"
+ std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n"
"-u sets the log10 probability for <unk> if the ARPA file does not have one.\n"
" Default is -100. The ARPA file will always take precedence.\n"
"-s allows models to be built even if they do not have <s> and </s>.\n"
diff --git a/klm/lm/config.hh b/klm/lm/config.hh
index 227b8512..8564661b 100644
--- a/klm/lm/config.hh
+++ b/klm/lm/config.hh
@@ -8,10 +8,12 @@
/* Configuration for ngram model. Separate header to reduce pollution. */
-namespace lm { namespace ngram {
-
+namespace lm {
+
class EnumerateVocab;
+namespace ngram {
+
struct Config {
// EFFECTIVE FOR BOTH ARPA AND BINARY READS
diff --git a/klm/lm/enumerate_vocab.hh b/klm/lm/enumerate_vocab.hh
index e734316b..27263621 100644
--- a/klm/lm/enumerate_vocab.hh
+++ b/klm/lm/enumerate_vocab.hh
@@ -5,7 +5,6 @@
#include "util/string_piece.hh"
namespace lm {
-namespace ngram {
/* If you need the actual strings in the vocabulary, inherit from this class
* and implement Add. Then put a pointer in Config.enumerate_vocab; it does
@@ -23,7 +22,6 @@ class EnumerateVocab {
EnumerateVocab() {}
};
-} // namespace ngram
} // namespace lm
#endif // LM_ENUMERATE_VOCAB__
diff --git a/klm/lm/left.hh b/klm/lm/left.hh
index bb3f5539..41f71f84 100644
--- a/klm/lm/left.hh
+++ b/klm/lm/left.hh
@@ -175,22 +175,14 @@ template <class M> class RuleScore {
float backoffs[kMaxOrder - 1], backoffs2[kMaxOrder - 1];
float *back = backoffs, *back2 = backoffs2;
- unsigned char next_use;
- FullScoreReturn ret;
- ProcessRet(ret = model_.ExtendLeft(out_.right.words, out_.right.words + out_.right.length, out_.right.backoff, in.left.pointers[0], 1, back, next_use));
- if (!next_use) {
- left_done_ = true;
- out_.right = in.right;
- return;
- }
- unsigned char extend_length = 2;
- for (const uint64_t *i = in.left.pointers + 1; i < in.left.pointers + in.left.length; ++i, ++extend_length) {
- ProcessRet(ret = model_.ExtendLeft(out_.right.words, out_.right.words + next_use, back, *i, extend_length, back2, next_use));
- if (!next_use) {
- left_done_ = true;
- out_.right = in.right;
- return;
- }
+ unsigned char next_use = out_.right.length;
+
+ // First word
+ if (ExtendLeft(in, next_use, 1, out_.right.backoff, back)) return;
+
+ // Words after the first, so extending a bigram to begin with
+ for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) {
+ if (ExtendLeft(in, next_use, extend_length, back, back2)) return;
std::swap(back, back2);
}
@@ -226,6 +218,25 @@ template <class M> class RuleScore {
}
private:
+ bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) {
+ ProcessRet(model_.ExtendLeft(
+ out_.right.words, out_.right.words + next_use, // Words to extend into
+ back_in, // Backoffs to use
+ in.left.pointers[extend_length - 1], extend_length, // Words to be extended
+ back_out, // Backoffs for the next score
+ next_use)); // Length of n-gram to use in next scoring.
+ if (next_use != out_.right.length) {
+ left_done_ = true;
+ if (!next_use) {
+ out_.right = in.right;
+ // Early exit.
+ return true;
+ }
+ }
+ // Continue scoring.
+ return false;
+ }
+
void ProcessRet(const FullScoreReturn &ret) {
prob_ += ret.prob;
if (left_done_) return;
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index 25f1ab7c..e4c1ec1d 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -91,8 +91,8 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state);
- if (ret.ngram_length - 1 < in_state.length) {
- ret.prob = std::accumulate(in_state.backoff + ret.ngram_length - 1, in_state.backoff + in_state.length, ret.prob);
+ for (const float *i = in_state.backoff + ret.ngram_length - 1; i < in_state.backoff + in_state.length; ++i) {
+ ret.prob += *i;
}
return ret;
}
diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc
index 455bc4ba..dce73f77 100644
--- a/klm/lm/read_arpa.cc
+++ b/klm/lm/read_arpa.cc
@@ -38,6 +38,8 @@ void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
}
if (static_cast<size_t>(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic)
UTIL_THROW(FormatLoadException, "This looks like a binary file but got sent to the ARPA parser. Did you compress the binary file or pass a binary file where only ARPA files are accepted?");
+ UTIL_THROW_IF(line.size() >= 4 && StringPiece(line.data(), 4) == "blmt", FormatLoadException, "This looks like an IRSTLM binary file. Did you forget to pass --text yes to compile-lm?");
+ UTIL_THROW_IF(line == "iARPA", FormatLoadException, "This looks like an IRSTLM iARPA file. You need an ARPA file. Run\n compile-lm --text yes " << in.FileName() << " " << in.FileName() << ".arpa\nfirst.");
UTIL_THROW(FormatLoadException, "first non-empty line was \"" << line << "\" not \\data\\.");
}
while (!IsEntirelyWhiteSpace(line = in.ReadLine())) {
diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc
index 334adf12..247832b0 100644
--- a/klm/lm/search_hashed.cc
+++ b/klm/lm/search_hashed.cc
@@ -87,14 +87,14 @@ template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(
ReadNGramHeader(f, n);
// vocab ids of words in reverse order
- WordIndex vocab_ids[n];
- uint64_t keys[n - 1];
+ std::vector<WordIndex> vocab_ids(n);
+ std::vector<uint64_t> keys(n-1);
typename Store::Packing::Value value;
typename Middle::MutableIterator found;
for (size_t i = 0; i < count; ++i) {
- ReadNGram(f, n, vocab, vocab_ids, value, warn);
+ ReadNGram(f, n, vocab, &*vocab_ids.begin(), value, warn);
- keys[0] = detail::CombineWordHash(static_cast<uint64_t>(*vocab_ids), vocab_ids[1]);
+ keys[0] = detail::CombineWordHash(static_cast<uint64_t>(vocab_ids.front()), vocab_ids[1]);
for (unsigned int h = 1; h < n - 1; ++h) {
keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]);
}
@@ -106,9 +106,9 @@ template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(
util::FloatEnc fix_prob;
for (lower = n - 3; ; --lower) {
if (lower == -1) {
- fix_prob.f = unigrams[vocab_ids[0]].prob;
+ fix_prob.f = unigrams[vocab_ids.front()].prob;
fix_prob.i &= ~util::kSignBit;
- unigrams[vocab_ids[0]].prob = fix_prob.f;
+ unigrams[vocab_ids.front()].prob = fix_prob.f;
break;
}
if (middle[lower].UnsafeMutableFind(keys[lower], found)) {
@@ -120,8 +120,8 @@ template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(
break;
}
}
- if (lower != static_cast<int>(n) - 3) FixSRI(lower, fix_prob.f, n, keys, vocab_ids, unigrams, middle);
- activate(vocab_ids, n);
+ if (lower != static_cast<int>(n) - 3) FixSRI(lower, fix_prob.f, n, &*keys.begin(), &*vocab_ids.begin(), unigrams, middle);
+ activate(&*vocab_ids.begin(), n);
}
store.FinishedInserting();
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 5d8c70db..4bd3f4ee 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -358,6 +358,7 @@ template <class Doing> class BlankManager {
// Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram.
basis_[blank - 1] = kBadProb;
}
+ *pre = *cur;
been_length_ = length;
}
@@ -493,7 +494,7 @@ template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_pre
util::scoped_FILE unigram_file;
{
std::string name(file_prefix + "unigrams");
- unigram_file.reset(OpenOrThrow(name.c_str(), "r"));
+ unigram_file.reset(OpenOrThrow(name.c_str(), "r+"));
util::RemoveOrThrow(name.c_str());
}
sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs);
diff --git a/klm/lm/sri.cc b/klm/lm/sri.cc
deleted file mode 100644
index 825f699b..00000000
--- a/klm/lm/sri.cc
+++ /dev/null
@@ -1,108 +0,0 @@
-#include "lm/lm_exception.hh"
-#include "lm/sri.hh"
-
-#include <Ngram.h>
-#include <Vocab.h>
-
-#include <errno.h>
-
-namespace lm {
-namespace sri {
-
-Vocabulary::Vocabulary() : sri_(new Vocab) {}
-
-Vocabulary::~Vocabulary() {}
-
-WordIndex Vocabulary::Index(const char *str) const {
- WordIndex ret = sri_->getIndex(str);
- // NGram wants the index of Vocab_Unknown for unknown words, but for some reason SRI returns Vocab_None here :-(.
- if (ret == Vocab_None) {
- return not_found_;
- } else {
- return ret;
- }
-}
-
-const char *Vocabulary::Word(WordIndex index) const {
- return sri_->getWord(index);
-}
-
-void Vocabulary::FinishedLoading() {
- SetSpecial(
- sri_->ssIndex(),
- sri_->seIndex(),
- sri_->unkIndex());
-}
-
-namespace {
-Ngram *MakeSRIModel(const char *file_name, unsigned int ngram_length, Vocab &sri_vocab) {
- sri_vocab.unkIsWord() = true;
- std::auto_ptr<Ngram> ret(new Ngram(sri_vocab, ngram_length));
- File file(file_name, "r");
- errno = 0;
- if (!ret->read(file)) {
- UTIL_THROW(FormatLoadException, "reading file " << file_name << " with SRI failed.");
- }
- return ret.release();
-}
-} // namespace
-
-Model::Model(const char *file_name, unsigned int ngram_length) : sri_(MakeSRIModel(file_name, ngram_length, *vocab_.sri_)) {
- if (!sri_->setorder()) {
- UTIL_THROW(FormatLoadException, "Can't have an SRI model with order 0.");
- }
- vocab_.FinishedLoading();
- State begin_state = State();
- begin_state.valid_length_ = 1;
- if (kMaxOrder > 1) {
- begin_state.history_[0] = vocab_.BeginSentence();
- if (kMaxOrder > 2) begin_state.history_[1] = Vocab_None;
- }
- State null_state = State();
- null_state.valid_length_ = 0;
- if (kMaxOrder > 1) null_state.history_[0] = Vocab_None;
- Init(begin_state, null_state, vocab_, sri_->setorder());
- not_found_ = vocab_.NotFound();
-}
-
-Model::~Model() {}
-
-namespace {
-
-/* Argh SRI's wordProb knows the ngram length but doesn't return it. One more
- * reason you should use my model. */
-// TODO(stolcke): fix SRILM so I don't have to do this.
-unsigned int MatchedLength(Ngram &model, const WordIndex new_word, const SRIVocabIndex *const_history) {
- unsigned int out_length = 0;
- // This gets the length of context used, which is ngram_length - 1 unless new_word is OOV in which case it is 0.
- model.contextID(new_word, const_history, out_length);
- return out_length + 1;
-}
-
-} // namespace
-
-FullScoreReturn Model::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
- // If you get a compiler in this function, change SRIVocabIndex in sri.hh to match the one found in SRI's Vocab.h.
- const SRIVocabIndex *const_history;
- SRIVocabIndex local_history[Order()];
- if (in_state.valid_length_ < kMaxOrder - 1) {
- const_history = in_state.history_;
- } else {
- std::copy(in_state.history_, in_state.history_ + in_state.valid_length_, local_history);
- local_history[in_state.valid_length_] = Vocab_None;
- const_history = local_history;
- }
- FullScoreReturn ret;
- ret.ngram_length = MatchedLength(*sri_, new_word, const_history);
- out_state.history_[0] = new_word;
- out_state.valid_length_ = std::min<unsigned char>(ret.ngram_length, Order() - 1);
- std::copy(const_history, const_history + out_state.valid_length_ - 1, out_state.history_ + 1);
- if (out_state.valid_length_ < kMaxOrder - 1) {
- out_state.history_[out_state.valid_length_] = Vocab_None;
- }
- ret.prob = sri_->wordProb(new_word, const_history);
- return ret;
-}
-
-} // namespace sri
-} // namespace lm
diff --git a/klm/lm/sri.hh b/klm/lm/sri.hh
deleted file mode 100644
index b57e9b73..00000000
--- a/klm/lm/sri.hh
+++ /dev/null
@@ -1,102 +0,0 @@
-#ifndef LM_SRI__
-#define LM_SRI__
-
-#include "lm/facade.hh"
-#include "util/murmur_hash.hh"
-
-#include <cmath>
-#include <exception>
-#include <memory>
-
-class Ngram;
-class Vocab;
-
-/* The ngram length reported uses some random API I found and may be wrong.
- *
- * See ngram, which should return equivalent results.
- */
-
-namespace lm {
-namespace sri {
-
-static const unsigned int kMaxOrder = 6;
-
-/* This should match VocabIndex found in SRI's Vocab.h
- * The reason I define this here independently is that SRI's headers
- * pollute and increase compile time.
- * It's difficult to extract this from their header and anyway would
- * break packaging.
- * If these differ there will be a compiler error in ActuallyCall.
- */
-typedef unsigned int SRIVocabIndex;
-
-class State {
- public:
- // You shouldn't need to touch these, but they're public so State will be a POD.
- // If valid_length_ < kMaxOrder - 1 then history_[valid_length_] == Vocab_None.
- SRIVocabIndex history_[kMaxOrder - 1];
- unsigned char valid_length_;
-};
-
-inline bool operator==(const State &left, const State &right) {
- if (left.valid_length_ != right.valid_length_) {
- return false;
- }
- for (const SRIVocabIndex *l = left.history_, *r = right.history_;
- l != left.history_ + left.valid_length_;
- ++l, ++r) {
- if (*l != *r) return false;
- }
- return true;
-}
-
-inline size_t hash_value(const State &state) {
- return util::MurmurHashNative(&state.history_, sizeof(SRIVocabIndex) * state.valid_length_);
-}
-
-class Vocabulary : public base::Vocabulary {
- public:
- Vocabulary();
-
- ~Vocabulary();
-
- WordIndex Index(const StringPiece &str) const {
- std::string temp(str.data(), str.length());
- return Index(temp.c_str());
- }
- WordIndex Index(const std::string &str) const {
- return Index(str.c_str());
- }
- WordIndex Index(const char *str) const;
-
- const char *Word(WordIndex index) const;
-
- private:
- friend class Model;
- void FinishedLoading();
-
- // The parent class isn't copyable so auto_ptr is the same as scoped_ptr
- // but without the boost dependence.
- mutable std::auto_ptr<Vocab> sri_;
-};
-
-class Model : public base::ModelFacade<Model, State, Vocabulary> {
- public:
- Model(const char *file_name, unsigned int ngram_length);
-
- ~Model();
-
- FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const;
-
- private:
- Vocabulary vocab_;
-
- mutable std::auto_ptr<Ngram> sri_;
-
- WordIndex not_found_;
-};
-
-} // namespace sri
-} // namespace lm
-
-#endif // LM_SRI__
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc
index 03b0767a..ffec41ca 100644
--- a/klm/lm/vocab.cc
+++ b/klm/lm/vocab.cc
@@ -135,6 +135,7 @@ void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
ReadWords(fd, to);
SetSpecial(Index("<s>"), Index("</s>"), 0);
+ bound_ = end_ - begin_ + 1;
}
namespace {
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
index 41e97052..3c3414fb 100644
--- a/klm/lm/vocab.hh
+++ b/klm/lm/vocab.hh
@@ -15,10 +15,10 @@
namespace lm {
class ProbBackoff;
+class EnumerateVocab;
namespace ngram {
class Config;
-class EnumerateVocab;
namespace detail {
uint64_t HashForVocab(const char *str, std::size_t len);
@@ -66,7 +66,6 @@ class SortedVocabulary : public base::Vocabulary {
static size_t Size(std::size_t entries, const Config &config);
// Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary.
- // While this number is correct, ProbingVocabulary::Bound might not be correct in some cases.
WordIndex Bound() const { return bound_; }
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.