summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2011-09-09 15:33:35 +0200
committerPatrick Simianer <p@simianer.de>2011-09-23 19:13:58 +0200
commitedb0cc0cbae1e75e4aeedb6360eab325effe6573 (patch)
treea2fed4614b88f177f91e88fef3b269fa75e80188
parent2e6ef7cbec77b22ce3d64416a5ada3a6c081f9e2 (diff)
partial merge, ruleid feature
-rw-r--r--decoder/Makefile.am2
-rw-r--r--decoder/cdec_ff.cc9
-rw-r--r--decoder/ff_klm.cc38
-rw-r--r--decoder/ff_klm.h7
-rw-r--r--decoder/ff_ngrams.cc341
-rw-r--r--decoder/ff_ngrams.h29
-rw-r--r--decoder/ff_rules.cc107
-rw-r--r--decoder/ff_rules.h40
-rw-r--r--decoder/ff_spans.cc74
-rw-r--r--decoder/ff_spans.h15
-rw-r--r--dtrain/dtrain.cc50
-rwxr-xr-xdtrain/run.sh5
-rw-r--r--dtrain/sample.h18
-rw-r--r--dtrain/test/EXAMPLE/dtrain.ini2
-rw-r--r--klm/lm/Makefile.am1
-rw-r--r--klm/lm/bhiksha.cc93
-rw-r--r--klm/lm/bhiksha.hh108
-rw-r--r--klm/lm/binary_format.cc13
-rw-r--r--klm/lm/binary_format.hh9
-rw-r--r--klm/lm/build_binary.cc54
-rw-r--r--klm/lm/config.cc1
-rw-r--r--klm/lm/config.hh5
-rw-r--r--klm/lm/model.cc67
-rw-r--r--klm/lm/model.hh12
-rw-r--r--klm/lm/model_test.cc73
-rw-r--r--klm/lm/ngram_query.cc9
-rw-r--r--klm/lm/quantize.cc1
-rw-r--r--klm/lm/quantize.hh4
-rw-r--r--klm/lm/read_arpa.cc6
-rw-r--r--klm/lm/search_hashed.cc2
-rw-r--r--klm/lm/search_hashed.hh3
-rw-r--r--klm/lm/search_trie.cc45
-rw-r--r--klm/lm/search_trie.hh20
-rw-r--r--klm/lm/test_nounk.arpa120
-rw-r--r--klm/lm/trie.cc57
-rw-r--r--klm/lm/trie.hh24
-rw-r--r--klm/lm/vocab.cc6
-rw-r--r--klm/lm/vocab.hh4
-rw-r--r--klm/util/bit_packing.hh13
-rw-r--r--klm/util/exception.cc28
-rw-r--r--klm/util/exception.hh56
-rw-r--r--klm/util/file_piece.cc42
-rw-r--r--klm/util/file_piece.hh34
-rw-r--r--klm/util/murmur_hash.cc258
-rw-r--r--klm/util/probing_hash_table.hh2
-rw-r--r--klm/util/sorted_uniform.hh23
46 files changed, 1490 insertions, 440 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index 244da2de..e5f7505f 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -61,10 +61,12 @@ libcdec_a_SOURCES = \
phrasetable_fst.cc \
trule.cc \
ff.cc \
+ ff_rules.cc \
ff_wordset.cc \
ff_charset.cc \
ff_lm.cc \
ff_klm.cc \
+ ff_ngrams.cc \
ff_spans.cc \
ff_ruleshape.cc \
ff_wordalign.cc \
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index 31f88a4f..588842f1 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -4,10 +4,12 @@
#include "ff_spans.h"
#include "ff_lm.h"
#include "ff_klm.h"
+#include "ff_ngrams.h"
#include "ff_csplit.h"
#include "ff_wordalign.h"
#include "ff_tagger.h"
#include "ff_factory.h"
+#include "ff_rules.h"
#include "ff_ruleshape.h"
#include "ff_bleu.h"
#include "ff_lm_fsa.h"
@@ -51,12 +53,11 @@ void register_feature_functions() {
ff_registry.Register("RandLM", new FFFactory<LanguageModelRandLM>);
#endif
ff_registry.Register("SpanFeatures", new FFFactory<SpanFeatures>());
+ ff_registry.Register("NgramFeatures", new FFFactory<NgramDetector>());
+ ff_registry.Register("RuleIdentityFeatures", new FFFactory<RuleIdentityFeatures>());
ff_registry.Register("RuleNgramFeatures", new FFFactory<RuleNgramFeatures>());
ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory<CMR2008ReorderingFeatures>());
- ff_registry.Register("KLanguageModel", new FFFactory<KLanguageModel<lm::ngram::ProbingModel> >());
- ff_registry.Register("KLanguageModel_Trie", new FFFactory<KLanguageModel<lm::ngram::TrieModel> >());
- ff_registry.Register("KLanguageModel_QuantTrie", new FFFactory<KLanguageModel<lm::ngram::QuantTrieModel> >());
- ff_registry.Register("KLanguageModel_Probing", new FFFactory<KLanguageModel<lm::ngram::ProbingModel> >());
+ ff_registry.Register("KLanguageModel", new KLanguageModelFactory());
ff_registry.Register("NonLatinCount", new FFFactory<NonLatinCount>);
ff_registry.Register("RuleShape", new FFFactory<RuleShapeFeatures>);
ff_registry.Register("RelativeSentencePosition", new FFFactory<RelativeSentencePosition>);
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
index 9b7fe2d3..24dcb9c3 100644
--- a/decoder/ff_klm.cc
+++ b/decoder/ff_klm.cc
@@ -9,6 +9,7 @@
#include "stringlib.h"
#include "hg.h"
#include "tdict.h"
+#include "lm/model.hh"
#include "lm/enumerate_vocab.hh"
using namespace std;
@@ -434,8 +435,37 @@ void KLanguageModel<Model>::FinalTraversalFeatures(const void* ant_state,
features->set_value(oov_fid_, oovs);
}
-// instantiate templates
-template class KLanguageModel<lm::ngram::ProbingModel>;
-template class KLanguageModel<lm::ngram::TrieModel>;
-template class KLanguageModel<lm::ngram::QuantTrieModel>;
+template <class Model> boost::shared_ptr<FeatureFunction> CreateModel(const std::string &param) {
+ KLanguageModel<Model> *ret = new KLanguageModel<Model>(param);
+ ret->Init();
+ return boost::shared_ptr<FeatureFunction>(ret);
+}
+boost::shared_ptr<FeatureFunction> KLanguageModelFactory::Create(std::string param) const {
+ using namespace lm::ngram;
+ std::string filename, ignored_map;
+ bool ignored_markers;
+ std::string ignored_featname;
+ ParseLMArgs(param, &filename, &ignored_map, &ignored_markers, &ignored_featname);
+ ModelType m;
+ if (!RecognizeBinary(filename.c_str(), m)) m = HASH_PROBING;
+
+ switch (m) {
+ case HASH_PROBING:
+ return CreateModel<ProbingModel>(param);
+ case TRIE_SORTED:
+ return CreateModel<TrieModel>(param);
+ case ARRAY_TRIE_SORTED:
+ return CreateModel<ArrayTrieModel>(param);
+ case QUANT_TRIE_SORTED:
+ return CreateModel<QuantTrieModel>(param);
+ case QUANT_ARRAY_TRIE_SORTED:
+ return CreateModel<QuantArrayTrieModel>(param);
+ default:
+ UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m);
+ }
+}
+
+std::string KLanguageModelFactory::usage(bool params,bool verbose) const {
+ return KLanguageModel<lm::ngram::Model>::usage(params, verbose);
+}
diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h
index 5eafe8be..6efe50f6 100644
--- a/decoder/ff_klm.h
+++ b/decoder/ff_klm.h
@@ -4,8 +4,8 @@
#include <vector>
#include <string>
+#include "ff_factory.h"
#include "ff.h"
-#include "lm/model.hh"
template <class Model> struct KLanguageModelImpl;
@@ -34,4 +34,9 @@ class KLanguageModel : public FeatureFunction {
KLanguageModelImpl<Model>* pimpl_;
};
+struct KLanguageModelFactory : public FactoryBase<FeatureFunction> {
+ FP Create(std::string param) const;
+ std::string usage(bool params,bool verbose) const;
+};
+
#endif
diff --git a/decoder/ff_ngrams.cc b/decoder/ff_ngrams.cc
new file mode 100644
index 00000000..04dd1906
--- /dev/null
+++ b/decoder/ff_ngrams.cc
@@ -0,0 +1,341 @@
+#include "ff_ngrams.h"
+
+#include <cstring>
+#include <iostream>
+
+#include <boost/scoped_ptr.hpp>
+
+#include "filelib.h"
+#include "stringlib.h"
+#include "hg.h"
+#include "tdict.h"
+
+using namespace std;
+
+static const unsigned char HAS_FULL_CONTEXT = 1;
+static const unsigned char HAS_EOS_ON_RIGHT = 2;
+static const unsigned char MASK = 7;
+
+namespace {
+template <unsigned MAX_ORDER = 5>
+struct State {
+ explicit State() {
+ memset(state, 0, sizeof(state));
+ }
+ explicit State(int order) {
+ memset(state, 0, (order - 1) * sizeof(WordID));
+ }
+ State<MAX_ORDER>(char order, const WordID* mem) {
+ memcpy(state, mem, (order - 1) * sizeof(WordID));
+ }
+ State(const State<MAX_ORDER>& other) {
+ memcpy(state, other.state, sizeof(state));
+ }
+ const State& operator=(const State<MAX_ORDER>& other) {
+ memcpy(state, other.state, sizeof(state));
+ }
+ explicit State(const State<MAX_ORDER>& other, unsigned order, WordID extend) {
+ char om1 = order - 1;
+ assert(om1 > 0);
+ for (char i = 1; i < om1; ++i) state[i - 1]= other.state[i];
+ state[om1 - 1] = extend;
+ }
+ const WordID& operator[](size_t i) const { return state[i]; }
+ WordID& operator[](size_t i) { return state[i]; }
+ WordID state[MAX_ORDER];
+};
+}
+
+namespace {
+ string Escape(const string& x) {
+ string y = x;
+ for (int i = 0; i < y.size(); ++i) {
+ if (y[i] == '=') y[i]='_';
+ if (y[i] == ';') y[i]='_';
+ }
+ return y;
+ }
+}
+
+class NgramDetectorImpl {
+
+ // returns the number of unscored words at the left edge of a span
+ inline int UnscoredSize(const void* state) const {
+ return *(static_cast<const char*>(state) + unscored_size_offset_);
+ }
+
+ inline void SetUnscoredSize(int size, void* state) const {
+ *(static_cast<char*>(state) + unscored_size_offset_) = size;
+ }
+
+ inline State<5> RemnantLMState(const void* cstate) const {
+ return State<5>(order_, static_cast<const WordID*>(cstate));
+ }
+
+ inline const State<5> BeginSentenceState() const {
+ State<5> state(order_);
+ state.state[0] = kSOS_;
+ return state;
+ }
+
+ inline void SetRemnantLMState(const State<5>& lmstate, void* state) const {
+ // if we were clever, we could use the memory pointed to by state to do all
+ // the work, avoiding this copy
+ memcpy(state, lmstate.state, (order_-1) * sizeof(WordID));
+ }
+
+ WordID IthUnscoredWord(int i, const void* state) const {
+ const WordID* const mem = reinterpret_cast<const WordID*>(static_cast<const char*>(state) + unscored_words_offset_);
+ return mem[i];
+ }
+
+ void SetIthUnscoredWord(int i, const WordID index, void *state) const {
+ WordID* mem = reinterpret_cast<WordID*>(static_cast<char*>(state) + unscored_words_offset_);
+ mem[i] = index;
+ }
+
+ inline bool GetFlag(const void *state, unsigned char flag) const {
+ return (*(static_cast<const char*>(state) + is_complete_offset_) & flag);
+ }
+
+ inline void SetFlag(bool on, unsigned char flag, void *state) const {
+ if (on) {
+ *(static_cast<char*>(state) + is_complete_offset_) |= flag;
+ } else {
+ *(static_cast<char*>(state) + is_complete_offset_) &= (MASK ^ flag);
+ }
+ }
+
+ inline bool HasFullContext(const void *state) const {
+ return GetFlag(state, HAS_FULL_CONTEXT);
+ }
+
+ inline void SetHasFullContext(bool flag, void *state) const {
+ SetFlag(flag, HAS_FULL_CONTEXT, state);
+ }
+
+ void FireFeatures(const State<5>& state, WordID cur, SparseVector<double>* feats) {
+ FidTree* ft = &fidroot_;
+ int n = 0;
+ WordID buf[10];
+ int ci = order_ - 1;
+ WordID curword = cur;
+ while(curword) {
+ buf[n] = curword;
+ int& fid = ft->fids[curword];
+ ++n;
+ if (!fid) {
+ const char* code="_UBT456789"; // prefix code (unigram, bigram, etc.)
+ ostringstream os;
+ os << code[n] << ':';
+ for (int i = n-1; i >= 0; --i) {
+ os << (i != n-1 ? "_" : "");
+ const string& tok = TD::Convert(buf[i]);
+ if (tok.find('=') == string::npos)
+ os << tok;
+ else
+ os << Escape(tok);
+ }
+ fid = FD::Convert(os.str());
+ }
+ feats->set_value(fid, 1);
+ ft = &ft->levels[curword];
+ --ci;
+ if (ci < 0) break;
+ curword = state[ci];
+ }
+ }
+
+ public:
+ void LookupWords(const TRule& rule, const vector<const void*>& ant_states, SparseVector<double>* feats, SparseVector<double>* est_feats, void* remnant) {
+ double sum = 0.0;
+ double est_sum = 0.0;
+ int num_scored = 0;
+ int num_estimated = 0;
+ bool saw_eos = false;
+ bool has_some_history = false;
+ State<5> state;
+ const vector<WordID>& e = rule.e();
+ bool context_complete = false;
+ for (int j = 0; j < e.size(); ++j) {
+ if (e[j] < 1) { // handle non-terminal substitution
+ const void* astate = (ant_states[-e[j]]);
+ int unscored_ant_len = UnscoredSize(astate);
+ for (int k = 0; k < unscored_ant_len; ++k) {
+ const WordID cur_word = IthUnscoredWord(k, astate);
+ const bool is_oov = (cur_word == 0);
+ SparseVector<double> p;
+ if (cur_word == kSOS_) {
+ state = BeginSentenceState();
+ if (has_some_history) { // this is immediately fully scored, and bad
+ p.set_value(FD::Convert("Malformed"), 1.0);
+ context_complete = true;
+ } else { // this might be a real <s>
+ num_scored = max(0, order_ - 2);
+ }
+ } else {
+ FireFeatures(state, cur_word, &p);
+ const State<5> scopy = State<5>(state, order_, cur_word);
+ state = scopy;
+ if (saw_eos) { p.set_value(FD::Convert("Malformed"), 1.0); }
+ saw_eos = (cur_word == kEOS_);
+ }
+ has_some_history = true;
+ ++num_scored;
+ if (!context_complete) {
+ if (num_scored >= order_) context_complete = true;
+ }
+ if (context_complete) {
+ (*feats) += p;
+ } else {
+ if (remnant)
+ SetIthUnscoredWord(num_estimated, cur_word, remnant);
+ ++num_estimated;
+ (*est_feats) += p;
+ }
+ }
+ saw_eos = GetFlag(astate, HAS_EOS_ON_RIGHT);
+ if (HasFullContext(astate)) { // this is equivalent to the "star" in Chiang 2007
+ state = RemnantLMState(astate);
+ context_complete = true;
+ }
+ } else { // handle terminal
+ const WordID cur_word = e[j];
+ SparseVector<double> p;
+ if (cur_word == kSOS_) {
+ state = BeginSentenceState();
+ if (has_some_history) { // this is immediately fully scored, and bad
+ p.set_value(FD::Convert("Malformed"), -100);
+ context_complete = true;
+ } else { // this might be a real <s>
+ num_scored = max(0, order_ - 2);
+ }
+ } else {
+ FireFeatures(state, cur_word, &p);
+ const State<5> scopy = State<5>(state, order_, cur_word);
+ state = scopy;
+ if (saw_eos) { p.set_value(FD::Convert("Malformed"), 1.0); }
+ saw_eos = (cur_word == kEOS_);
+ }
+ has_some_history = true;
+ ++num_scored;
+ if (!context_complete) {
+ if (num_scored >= order_) context_complete = true;
+ }
+ if (context_complete) {
+ (*feats) += p;
+ } else {
+ if (remnant)
+ SetIthUnscoredWord(num_estimated, cur_word, remnant);
+ ++num_estimated;
+ (*est_feats) += p;
+ }
+ }
+ }
+ if (remnant) {
+ SetFlag(saw_eos, HAS_EOS_ON_RIGHT, remnant);
+ SetRemnantLMState(state, remnant);
+ SetUnscoredSize(num_estimated, remnant);
+ SetHasFullContext(context_complete || (num_scored >= order_), remnant);
+ }
+ }
+
+ // this assumes no target words on final unary -> goal rule. is that ok?
+ // for <s> (n-1 left words) and (n-1 right words) </s>
+ void FinalTraversal(const void* state, SparseVector<double>* feats) {
+ if (add_sos_eos_) { // rules do not produce <s> </s>, so do it here
+ SetRemnantLMState(BeginSentenceState(), dummy_state_);
+ SetHasFullContext(1, dummy_state_);
+ SetUnscoredSize(0, dummy_state_);
+ dummy_ants_[1] = state;
+ LookupWords(*dummy_rule_, dummy_ants_, feats, NULL, NULL);
+ } else { // rules DO produce <s> ... </s>
+#if 0
+ double p = 0;
+ if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; }
+ if (UnscoredSize(state) > 0) { // are there unscored words
+ if (kSOS_ != IthUnscoredWord(0, state)) {
+ p -= 100 * UnscoredSize(state);
+ }
+ }
+ return p;
+#endif
+ }
+ }
+
+ public:
+ explicit NgramDetectorImpl(bool explicit_markers) :
+ kCDEC_UNK(TD::Convert("<unk>")) ,
+ add_sos_eos_(!explicit_markers) {
+ order_ = 3;
+ state_size_ = (order_ - 1) * sizeof(WordID) + 2 + (order_ - 1) * sizeof(WordID);
+ unscored_size_offset_ = (order_ - 1) * sizeof(WordID);
+ is_complete_offset_ = unscored_size_offset_ + 1;
+ unscored_words_offset_ = is_complete_offset_ + 1;
+
+ // special handling of beginning / ending sentence markers
+ dummy_state_ = new char[state_size_];
+ memset(dummy_state_, 0, state_size_);
+ dummy_ants_.push_back(dummy_state_);
+ dummy_ants_.push_back(NULL);
+ dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] </s> ||| X=0"));
+ kSOS_ = TD::Convert("<s>");
+ kEOS_ = TD::Convert("</s>");
+ }
+
+ ~NgramDetectorImpl() {
+ delete[] dummy_state_;
+ }
+
+ int ReserveStateSize() const { return state_size_; }
+
+ private:
+ const WordID kCDEC_UNK;
+ WordID kSOS_; // <s> - requires special handling.
+ WordID kEOS_; // </s>
+ const bool add_sos_eos_; // flag indicating whether the hypergraph produces <s> and </s>
+ // if this is true, FinalTransitionFeatures will "add" <s> and </s>
+ // if false, FinalTransitionFeatures will score anything with the
+ // markers in the right place (i.e., the beginning and end of
+ // the sentence) with 0, and anything else with -100
+
+ int order_;
+ int state_size_;
+ int unscored_size_offset_;
+ int is_complete_offset_;
+ int unscored_words_offset_;
+ char* dummy_state_;
+ vector<const void*> dummy_ants_;
+ TRulePtr dummy_rule_;
+ struct FidTree {
+ map<WordID, int> fids;
+ map<WordID, FidTree> levels;
+ };
+ mutable FidTree fidroot_;
+};
+
+NgramDetector::NgramDetector(const string& param) {
+ string filename, mapfile, featname;
+ bool explicit_markers = (param == "-x");
+ pimpl_ = new NgramDetectorImpl(explicit_markers);
+ SetStateSize(pimpl_->ReserveStateSize());
+}
+
+NgramDetector::~NgramDetector() {
+ delete pimpl_;
+}
+
+void NgramDetector::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_states,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* state) const {
+ pimpl_->LookupWords(*edge.rule_, ant_states, features, estimated_features, state);
+}
+
+void NgramDetector::FinalTraversalFeatures(const void* ant_state,
+ SparseVector<double>* features) const {
+ pimpl_->FinalTraversal(ant_state, features);
+}
+
diff --git a/decoder/ff_ngrams.h b/decoder/ff_ngrams.h
new file mode 100644
index 00000000..82f61b33
--- /dev/null
+++ b/decoder/ff_ngrams.h
@@ -0,0 +1,29 @@
+#ifndef _NGRAMS_FF_H_
+#define _NGRAMS_FF_H_
+
+#include <vector>
+#include <map>
+#include <string>
+
+#include "ff.h"
+
+struct NgramDetectorImpl;
+class NgramDetector : public FeatureFunction {
+ public:
+ // param = "filename.lm [-o n]"
+ NgramDetector(const std::string& param);
+ ~NgramDetector();
+ virtual void FinalTraversalFeatures(const void* context,
+ SparseVector<double>* features) const;
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* out_context) const;
+ private:
+ NgramDetectorImpl* pimpl_;
+};
+
+#endif
diff --git a/decoder/ff_rules.cc b/decoder/ff_rules.cc
new file mode 100644
index 00000000..bd4c4cc0
--- /dev/null
+++ b/decoder/ff_rules.cc
@@ -0,0 +1,107 @@
+#include "ff_rules.h"
+
+#include <sstream>
+#include <cassert>
+#include <cmath>
+
+#include "filelib.h"
+#include "stringlib.h"
+#include "sentence_metadata.h"
+#include "lattice.h"
+#include "fdict.h"
+#include "verbose.h"
+
+using namespace std;
+
+namespace {
+ string Escape(const string& x) {
+ string y = x;
+ for (int i = 0; i < y.size(); ++i) {
+ if (y[i] == '=') y[i]='_';
+ if (y[i] == ';') y[i]='_';
+ }
+ return y;
+ }
+}
+
+RuleIdentityFeatures::RuleIdentityFeatures(const std::string& param) {
+}
+
+void RuleIdentityFeatures::PrepareForInput(const SentenceMetadata& smeta) {
+// std::map<const TRule*, SparseVector<double> >
+ rule2_fid_.clear();
+}
+
+void RuleIdentityFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const {
+ map<const TRule*, int>::iterator it = rule2_fid_.find(edge.rule_.get());
+ if (it == rule2_fid_.end()) {
+ const TRule& rule = *edge.rule_;
+ ostringstream os;
+ os << "R:";
+ if (rule.lhs_ < 0) os << TD::Convert(-rule.lhs_) << ':';
+ for (unsigned i = 0; i < rule.f_.size(); ++i) {
+ if (i > 0) os << '_';
+ WordID w = rule.f_[i];
+ if (w < 0) { os << 'N'; w = -w; }
+ assert(w > 0);
+ os << TD::Convert(w);
+ }
+ os << ':';
+ for (unsigned i = 0; i < rule.e_.size(); ++i) {
+ if (i > 0) os << '_';
+ WordID w = rule.e_[i];
+ if (w <= 0) {
+ os << 'N' << (1-w);
+ } else {
+ os << TD::Convert(w);
+ }
+ }
+ it = rule2_fid_.insert(make_pair(&rule, FD::Convert(Escape(os.str())))).first;
+ }
+ features->add_value(it->second, 1);
+}
+
+RuleNgramFeatures::RuleNgramFeatures(const std::string& param) {
+}
+
+void RuleNgramFeatures::PrepareForInput(const SentenceMetadata& smeta) {
+// std::map<const TRule*, SparseVector<double> >
+ rule2_feats_.clear();
+}
+
+void RuleNgramFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const {
+ map<const TRule*, SparseVector<double> >::iterator it = rule2_feats_.find(edge.rule_.get());
+ if (it == rule2_feats_.end()) {
+ const TRule& rule = *edge.rule_;
+ it = rule2_feats_.insert(make_pair(&rule, SparseVector<double>())).first;
+ SparseVector<double>& f = it->second;
+ string prev = "<r>";
+ for (int i = 0; i < rule.f_.size(); ++i) {
+ WordID w = rule.f_[i];
+ if (w < 0) w = -w;
+ assert(w > 0);
+ const string& cur = TD::Convert(w);
+ ostringstream os;
+ os << "RB:" << prev << '_' << cur;
+ const int fid = FD::Convert(Escape(os.str()));
+ if (fid <= 0) return;
+ f.add_value(fid, 1.0);
+ prev = cur;
+ }
+ ostringstream os;
+ os << "RB:" << prev << '_' << "</r>";
+ f.set_value(FD::Convert(Escape(os.str())), 1.0);
+ }
+ (*features) += it->second;
+}
+
diff --git a/decoder/ff_rules.h b/decoder/ff_rules.h
new file mode 100644
index 00000000..48d8bd05
--- /dev/null
+++ b/decoder/ff_rules.h
@@ -0,0 +1,40 @@
+#ifndef _FF_RULES_H_
+#define _FF_RULES_H_
+
+#include <vector>
+#include <map>
+#include "ff.h"
+#include "array2d.h"
+#include "wordid.h"
+
+class RuleIdentityFeatures : public FeatureFunction {
+ public:
+ RuleIdentityFeatures(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ virtual void PrepareForInput(const SentenceMetadata& smeta);
+ private:
+ mutable std::map<const TRule*, int> rule2_fid_;
+};
+
+class RuleNgramFeatures : public FeatureFunction {
+ public:
+ RuleNgramFeatures(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ virtual void PrepareForInput(const SentenceMetadata& smeta);
+ private:
+ mutable std::map<const TRule*, SparseVector<double> > rule2_feats_;
+};
+
+#endif
diff --git a/decoder/ff_spans.cc b/decoder/ff_spans.cc
index e1da088d..0483517b 100644
--- a/decoder/ff_spans.cc
+++ b/decoder/ff_spans.cc
@@ -13,6 +13,17 @@
using namespace std;
+namespace {
+ string Escape(const string& x) {
+ string y = x;
+ for (int i = 0; i < y.size(); ++i) {
+ if (y[i] == '=') y[i]='_';
+ if (y[i] == ';') y[i]='_';
+ }
+ return y;
+ }
+}
+
// log transform to make long spans cluster together
// but preserve differences
int SpanSizeTransform(unsigned span_size) {
@@ -140,19 +151,19 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) {
word = MapIfNecessary(word);
ostringstream sfid;
sfid << "ES:" << TD::Convert(word);
- end_span_ids_[i] = FD::Convert(sfid.str());
+ end_span_ids_[i] = FD::Convert(Escape(sfid.str()));
ostringstream esbiid;
esbiid << "EBI:" << TD::Convert(bword) << "_" << TD::Convert(word);
- end_bigram_ids_[i] = FD::Convert(esbiid.str());
+ end_bigram_ids_[i] = FD::Convert(Escape(esbiid.str()));
ostringstream bsbiid;
bsbiid << "BBI:" << TD::Convert(bword) << "_" << TD::Convert(word);
- beg_bigram_ids_[i] = FD::Convert(bsbiid.str());
+ beg_bigram_ids_[i] = FD::Convert(Escape(bsbiid.str()));
ostringstream bfid;
bfid << "BS:" << TD::Convert(bword);
- beg_span_ids_[i] = FD::Convert(bfid.str());
+ beg_span_ids_[i] = FD::Convert(Escape(bfid.str()));
if (use_collapsed_features_) {
- end_span_vals_[i] = feat2val_[sfid.str()] + feat2val_[esbiid.str()];
- beg_span_vals_[i] = feat2val_[bfid.str()] + feat2val_[bsbiid.str()];
+ end_span_vals_[i] = feat2val_[Escape(sfid.str())] + feat2val_[Escape(esbiid.str())];
+ beg_span_vals_[i] = feat2val_[Escape(bfid.str())] + feat2val_[Escape(bsbiid.str())];
}
}
for (int i = 0; i <= lattice.size(); ++i) {
@@ -167,60 +178,21 @@ void SpanFeatures::PrepareForInput(const SentenceMetadata& smeta) {
word = MapIfNecessary(word);
ostringstream pf;
pf << "S:" << TD::Convert(bword) << "_" << TD::Convert(word);
- span_feats_(i,j).first = FD::Convert(pf.str());
- span_feats_(i,j).second = FD::Convert("S_" + pf.str());
+ span_feats_(i,j).first = FD::Convert(Escape(pf.str()));
+ span_feats_(i,j).second = FD::Convert(Escape("S_" + pf.str()));
ostringstream lf;
const unsigned span_size = (i < j ? j - i : i - j);
lf << "LS:" << SpanSizeTransform(span_size) << "_" << TD::Convert(bword) << "_" << TD::Convert(word);
- len_span_feats_(i,j).first = FD::Convert(lf.str());
- len_span_feats_(i,j).second = FD::Convert("S_" + lf.str());
+ len_span_feats_(i,j).first = FD::Convert(Escape(lf.str()));
+ len_span_feats_(i,j).second = FD::Convert(Escape("S_" + lf.str()));
if (use_collapsed_features_) {
- span_vals_(i,j).first = feat2val_[pf.str()] + feat2val_[lf.str()];
- span_vals_(i,j).second = feat2val_["S_" + pf.str()] + feat2val_["S_" + lf.str()];
+ span_vals_(i,j).first = feat2val_[Escape(pf.str())] + feat2val_[Escape(lf.str())];
+ span_vals_(i,j).second = feat2val_[Escape("S_" + pf.str())] + feat2val_[Escape("S_" + lf.str())];
}
}
}
}
-RuleNgramFeatures::RuleNgramFeatures(const std::string& param) {
-}
-
-void RuleNgramFeatures::PrepareForInput(const SentenceMetadata& smeta) {
-// std::map<const TRule*, SparseVector<double> >
- rule2_feats_.clear();
-}
-
-void RuleNgramFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
- const Hypergraph::Edge& edge,
- const vector<const void*>& ant_contexts,
- SparseVector<double>* features,
- SparseVector<double>* estimated_features,
- void* context) const {
- map<const TRule*, SparseVector<double> >::iterator it = rule2_feats_.find(edge.rule_.get());
- if (it == rule2_feats_.end()) {
- const TRule& rule = *edge.rule_;
- it = rule2_feats_.insert(make_pair(&rule, SparseVector<double>())).first;
- SparseVector<double>& f = it->second;
- string prev = "<r>";
- for (int i = 0; i < rule.f_.size(); ++i) {
- WordID w = rule.f_[i];
- if (w < 0) w = -w;
- assert(w > 0);
- const string& cur = TD::Convert(w);
- ostringstream os;
- os << "RB:" << prev << '_' << cur;
- const int fid = FD::Convert(os.str());
- if (fid <= 0) return;
- f.add_value(fid, 1.0);
- prev = cur;
- }
- ostringstream os;
- os << "RB:" << prev << '_' << "</r>";
- f.set_value(FD::Convert(os.str()), 1.0);
- }
- (*features) += it->second;
-}
-
inline bool IsArity2RuleReordered(const TRule& rule) {
const vector<WordID>& e = rule.e_;
for (int i = 0; i < e.size(); ++i) {
diff --git a/decoder/ff_spans.h b/decoder/ff_spans.h
index b22c4d03..24e0dede 100644
--- a/decoder/ff_spans.h
+++ b/decoder/ff_spans.h
@@ -44,21 +44,6 @@ class SpanFeatures : public FeatureFunction {
WordID oov_;
};
-class RuleNgramFeatures : public FeatureFunction {
- public:
- RuleNgramFeatures(const std::string& param);
- protected:
- virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
- const Hypergraph::Edge& edge,
- const std::vector<const void*>& ant_contexts,
- SparseVector<double>* features,
- SparseVector<double>* estimated_features,
- void* context) const;
- virtual void PrepareForInput(const SentenceMetadata& smeta);
- private:
- mutable std::map<const TRule*, SparseVector<double> > rule2_feats_;
-};
-
class CMR2008ReorderingFeatures : public FeatureFunction {
public:
CMR2008ReorderingFeatures(const std::string& param);
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index d58478a8..35996d6d 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -215,7 +215,7 @@ main( int argc, char** argv )
// for the perceptron/SVM; TODO as params
double eta = 0.0005;
- double gamma = 0.01; // -> SVM
+ double gamma = 0.;//01; // -> SVM
lambdas.add_value( FD::Convert("__bias"), 0 );
// for random sampling
@@ -388,10 +388,8 @@ main( int argc, char** argv )
if ( !noup ) {
TrainingInstances pairs;
-
- sample_all_rand(kb, pairs);
- cout << pairs.size() << endl;
-
+ sample_all( kb, pairs );
+
for ( TrainingInstances::iterator ti = pairs.begin();
ti != pairs.end(); ti++ ) {
@@ -401,31 +399,29 @@ main( int argc, char** argv )
//} else {
//dv = ti->first - ti->second;
//}
- dv.add_value( FD::Convert("__bias"), -1 );
+ dv.add_value( FD::Convert("__bias"), -1 );
- SparseVector<double> reg;
- reg = lambdas * ( 2 * gamma );
- dv -= reg;
- lambdas += dv * eta;
-
- if ( verbose ) {
- cout << "{{ f("<< ti->first_rank <<") > f(" << ti->second_rank << ") but g(i)="<< ti->first_score <<" < g(j)="<< ti->second_score << " so update" << endl;
- cout << " i " << TD::GetString(kb->sents[ti->first_rank]) << endl;
- cout << " " << kb->feats[ti->first_rank] << endl;
- cout << " j " << TD::GetString(kb->sents[ti->second_rank]) << endl;
- cout << " " << kb->feats[ti->second_rank] << endl;
- cout << " diff vec: " << dv << endl;
- cout << " lambdas after update: " << lambdas << endl;
- cout << "}}" << endl;
- }
-
+ //SparseVector<double> reg;
+ //reg = lambdas * ( 2 * gamma );
+ //dv -= reg;
+ lambdas += dv * eta;
+
+ if ( verbose ) {
+ cout << "{{ f("<< ti->first_rank <<") > f(" << ti->second_rank << ") but g(i)="<< ti->first_score <<" < g(j)="<< ti->second_score << " so update" << endl;
+ cout << " i " << TD::GetString(kb->sents[ti->first_rank]) << endl;
+ cout << " " << kb->feats[ti->first_rank] << endl;
+ cout << " j " << TD::GetString(kb->sents[ti->second_rank]) << endl;
+ cout << " " << kb->feats[ti->second_rank] << endl;
+ cout << " diff vec: " << dv << endl;
+ cout << " lambdas after update: " << lambdas << endl;
+ cout << "}}" << endl;
+ }
} else {
- //if ( 0 ) {
- SparseVector<double> reg;
- reg = lambdas * ( gamma * 2 );
- lambdas += reg * ( -eta );
- //}
+ //SparseVector<double> reg;
+ //reg = lambdas * ( 2 * gamma );
+ //lambdas += reg * ( -eta );
}
+
}
//double l2 = lambdas.l2norm();
diff --git a/dtrain/run.sh b/dtrain/run.sh
index 16575c25..97123dfa 100755
--- a/dtrain/run.sh
+++ b/dtrain/run.sh
@@ -2,9 +2,10 @@
#INI=test/blunsom08.dtrain.ini
#INI=test/nc-wmt11/dtrain.ini
-#INI=test/EXAMPLE/dtrain.ini
-INI=test/EXAMPLE/dtrain.ruleids.ini
+INI=test/EXAMPLE/dtrain.ini
+#INI=test/EXAMPLE/dtrain.ruleids.ini
#INI=test/toy.dtrain.ini
+#INI=test/EXAMPLE/dtrain.cdecrid.ini
rm /tmp/dtrain-*
./dtrain -c $INI $1 $2 $3 $4
diff --git a/dtrain/sample.h b/dtrain/sample.h
index b6aa9abd..502901af 100644
--- a/dtrain/sample.h
+++ b/dtrain/sample.h
@@ -37,20 +37,20 @@ sample_all( KBestList* kb, TrainingInstances &training )
}
void
-sample_all_rand( KBestList* kb, TrainingInstances &training )
+sample_rand( KBestList* kb, TrainingInstances &training )
{
srand( time(NULL) );
for ( size_t i = 0; i < kb->GetSize()-1; i++ ) {
for ( size_t j = i+1; j < kb->GetSize(); j++ ) {
if ( rand() % 2 ) {
- TPair p;
- p.first = kb->feats[i];
- p.second = kb->feats[j];
- p.first_rank = i;
- p.second_rank = j;
- p.first_score = kb->scores[i];
- p.second_score = kb->scores[j];
- training.push_back( p );
+ TPair p;
+ p.first = kb->feats[i];
+ p.second = kb->feats[j];
+ p.first_rank = i;
+ p.second_rank = j;
+ p.first_score = kb->scores[i];
+ p.second_score = kb->scores[j];
+ training.push_back( p );
}
}
}
diff --git a/dtrain/test/EXAMPLE/dtrain.ini b/dtrain/test/EXAMPLE/dtrain.ini
index 7645921a..7221ba3f 100644
--- a/dtrain/test/EXAMPLE/dtrain.ini
+++ b/dtrain/test/EXAMPLE/dtrain.ini
@@ -5,6 +5,6 @@ epochs=3
input=test/EXAMPLE/dtrain.nc-1k
scorer=stupid_bleu
output=test/EXAMPLE/weights.gz
-stop_after=100
+stop_after=1000
wprint=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 PhraseModel_1 PhraseModel_2 PhraseModel_3 PhraseModel_4 PassThrough
diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am
index 395494bc..fae6b41a 100644
--- a/klm/lm/Makefile.am
+++ b/klm/lm/Makefile.am
@@ -12,6 +12,7 @@ build_binary_LDADD = libklm.a ../util/libklm_util.a -lz
noinst_LIBRARIES = libklm.a
libklm_a_SOURCES = \
+ bhiksha.cc \
binary_format.cc \
config.cc \
lm_exception.cc \
diff --git a/klm/lm/bhiksha.cc b/klm/lm/bhiksha.cc
new file mode 100644
index 00000000..bf86fd4b
--- /dev/null
+++ b/klm/lm/bhiksha.cc
@@ -0,0 +1,93 @@
+#include "lm/bhiksha.hh"
+#include "lm/config.hh"
+
+#include <limits>
+
+namespace lm {
+namespace ngram {
+namespace trie {
+
+DontBhiksha::DontBhiksha(const void * /*base*/, uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) :
+ next_(util::BitsMask::ByMax(max_next)) {}
+
+const uint8_t kArrayBhikshaVersion = 0;
+
+void ArrayBhiksha::UpdateConfigFromBinary(int fd, Config &config) {
+ uint8_t version;
+ uint8_t configured_bits;
+ if (read(fd, &version, 1) != 1 || read(fd, &configured_bits, 1) != 1) {
+ UTIL_THROW(util::ErrnoException, "Could not read from binary file");
+ }
+ if (version != kArrayBhikshaVersion) UTIL_THROW(FormatLoadException, "This file has sorted array compression version " << (unsigned) version << " but the code expects version " << (unsigned)kArrayBhikshaVersion);
+ config.pointer_bhiksha_bits = configured_bits;
+}
+
+namespace {
+
+// Find argmin_{chopped \in [0, RequiredBits(max_next)]} ChoppedDelta(max_offset)
+uint8_t ChopBits(uint64_t max_offset, uint64_t max_next, const Config &config) {
+ uint8_t required = util::RequiredBits(max_next);
+ uint8_t best_chop = 0;
+ int64_t lowest_change = std::numeric_limits<int64_t>::max();
+ // There are probably faster ways but I don't care because this is only done once per order at construction time.
+ for (uint8_t chop = 0; chop <= std::min(required, config.pointer_bhiksha_bits); ++chop) {
+ int64_t change = (max_next >> (required - chop)) * 64 /* table cost in bits */
+ - max_offset * static_cast<int64_t>(chop); /* savings in bits*/
+ if (change < lowest_change) {
+ lowest_change = change;
+ best_chop = chop;
+ }
+ }
+ return best_chop;
+}
+
+std::size_t ArrayCount(uint64_t max_offset, uint64_t max_next, const Config &config) {
+ uint8_t required = util::RequiredBits(max_next);
+ uint8_t chopping = ChopBits(max_offset, max_next, config);
+ return (max_next >> (required - chopping)) + 1 /* we store 0 too */;
+}
+} // namespace
+
+std::size_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) {
+ return sizeof(uint64_t) * (1 /* header */ + ArrayCount(max_offset, max_next, config)) + 7 /* 8-byte alignment */;
+}
+
+uint8_t ArrayBhiksha::InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config) {
+ return util::RequiredBits(max_next) - ChopBits(max_offset, max_next, config);
+}
+
+namespace {
+
+void *AlignTo8(void *from) {
+ uint8_t *val = reinterpret_cast<uint8_t*>(from);
+ std::size_t remainder = reinterpret_cast<std::size_t>(val) & 7;
+ if (!remainder) return val;
+ return val + 8 - remainder;
+}
+
+} // namespace
+
+ArrayBhiksha::ArrayBhiksha(void *base, uint64_t max_offset, uint64_t max_next, const Config &config)
+ : next_inline_(util::BitsMask::ByBits(InlineBits(max_offset, max_next, config))),
+ offset_begin_(reinterpret_cast<const uint64_t*>(AlignTo8(base)) + 1 /* 8-byte header */),
+ offset_end_(offset_begin_ + ArrayCount(max_offset, max_next, config)),
+ write_to_(reinterpret_cast<uint64_t*>(AlignTo8(base)) + 1 /* 8-byte header */ + 1 /* first entry is 0 */),
+ original_base_(base) {}
+
+void ArrayBhiksha::FinishedLoading(const Config &config) {
+ // *offset_begin_ = 0 but without a const_cast.
+ *(write_to_ - (write_to_ - offset_begin_)) = 0;
+
+ if (write_to_ != offset_end_) UTIL_THROW(util::Exception, "Did not get all the array entries that were expected.");
+
+ uint8_t *head_write = reinterpret_cast<uint8_t*>(original_base_);
+ *(head_write++) = kArrayBhikshaVersion;
+ *(head_write++) = config.pointer_bhiksha_bits;
+}
+
+void ArrayBhiksha::LoadedBinary() {
+}
+
+} // namespace trie
+} // namespace ngram
+} // namespace lm
diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh
new file mode 100644
index 00000000..cfb2b053
--- /dev/null
+++ b/klm/lm/bhiksha.hh
@@ -0,0 +1,108 @@
+/* Simple implementation of
+ * @inproceedings{bhikshacompression,
+ * author={Bhiksha Raj and Ed Whittaker},
+ * year={2003},
+ * title={Lossless Compression of Language Model Structure and Word Identifiers},
+ * booktitle={Proceedings of IEEE International Conference on Acoustics, Speech and Signal Processing},
+ * pages={388--391},
+ * }
+ *
+ * Currently only used for next pointers.
+ */
+
+#include <inttypes.h>
+
+#include "lm/binary_format.hh"
+#include "lm/trie.hh"
+#include "util/bit_packing.hh"
+#include "util/sorted_uniform.hh"
+
+namespace lm {
+namespace ngram {
+class Config;
+
+namespace trie {
+
+class DontBhiksha {
+ public:
+ static const ModelType kModelTypeAdd = static_cast<ModelType>(0);
+
+ static void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {}
+
+ static std::size_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; }
+
+ static uint8_t InlineBits(uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) {
+ return util::RequiredBits(max_next);
+ }
+
+ DontBhiksha(const void *base, uint64_t max_offset, uint64_t max_next, const Config &config);
+
+ void ReadNext(const void *base, uint64_t bit_offset, uint64_t /*index*/, uint8_t total_bits, NodeRange &out) const {
+ out.begin = util::ReadInt57(base, bit_offset, next_.bits, next_.mask);
+ out.end = util::ReadInt57(base, bit_offset + total_bits, next_.bits, next_.mask);
+ //assert(out.end >= out.begin);
+ }
+
+ void WriteNext(void *base, uint64_t bit_offset, uint64_t /*index*/, uint64_t value) {
+ util::WriteInt57(base, bit_offset, next_.bits, value);
+ }
+
+ void FinishedLoading(const Config &/*config*/) {}
+
+ void LoadedBinary() {}
+
+ uint8_t InlineBits() const { return next_.bits; }
+
+ private:
+ util::BitsMask next_;
+};
+
+class ArrayBhiksha {
+ public:
+ static const ModelType kModelTypeAdd = kArrayAdd;
+
+ static void UpdateConfigFromBinary(int fd, Config &config);
+
+ static std::size_t Size(uint64_t max_offset, uint64_t max_next, const Config &config);
+
+ static uint8_t InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config);
+
+ ArrayBhiksha(void *base, uint64_t max_offset, uint64_t max_value, const Config &config);
+
+ void ReadNext(const void *base, uint64_t bit_offset, uint64_t index, uint8_t total_bits, NodeRange &out) const {
+ const uint64_t *begin_it = util::BinaryBelow(util::IdentityAccessor<uint64_t>(), offset_begin_, offset_end_, index);
+ const uint64_t *end_it;
+ for (end_it = begin_it; (end_it < offset_end_) && (*end_it <= index + 1); ++end_it) {}
+ --end_it;
+ out.begin = ((begin_it - offset_begin_) << next_inline_.bits) |
+ util::ReadInt57(base, bit_offset, next_inline_.bits, next_inline_.mask);
+ out.end = ((end_it - offset_begin_) << next_inline_.bits) |
+ util::ReadInt57(base, bit_offset + total_bits, next_inline_.bits, next_inline_.mask);
+ }
+
+ void WriteNext(void *base, uint64_t bit_offset, uint64_t index, uint64_t value) {
+ uint64_t encode = value >> next_inline_.bits;
+ for (; write_to_ <= offset_begin_ + encode; ++write_to_) *write_to_ = index;
+ util::WriteInt57(base, bit_offset, next_inline_.bits, value & next_inline_.mask);
+ }
+
+ void FinishedLoading(const Config &config);
+
+ void LoadedBinary();
+
+ uint8_t InlineBits() const { return next_inline_.bits; }
+
+ private:
+ const util::BitsMask next_inline_;
+
+ const uint64_t *const offset_begin_;
+ const uint64_t *const offset_end_;
+
+ uint64_t *write_to_;
+
+ void *original_base_;
+};
+
+} // namespace trie
+} // namespace ngram
+} // namespace lm
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc
index 92b1008b..e02e621a 100644
--- a/klm/lm/binary_format.cc
+++ b/klm/lm/binary_format.cc
@@ -40,7 +40,7 @@ struct Sanity {
}
};
-const char *kModelNames[3] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "bit packed trie"};
+const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"};
std::size_t Align8(std::size_t in) {
std::size_t off = in % 8;
@@ -100,16 +100,17 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_
}
}
-uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing) {
+uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) {
+ std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad;
if (config.write_mmap) {
// Grow the file to accomodate the search, using zeros.
- if (-1 == ftruncate(backing.file.get(), backing.vocab.size() + memory_size))
- UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (backing.vocab.size() + memory_size) << " failed");
+ if (-1 == ftruncate(backing.file.get(), adjusted_vocab + memory_size))
+ UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (adjusted_vocab + memory_size) << " failed");
// We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down.
off_t page_size = sysconf(_SC_PAGE_SIZE);
- off_t alignment_cruft = backing.vocab.size() % page_size;
- backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), backing.vocab.size() - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
+ off_t alignment_cruft = adjusted_vocab % page_size;
+ backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft;
} else {
diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh
index 2b32b450..d28cb6c5 100644
--- a/klm/lm/binary_format.hh
+++ b/klm/lm/binary_format.hh
@@ -16,7 +16,12 @@
namespace lm {
namespace ngram {
-typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3} ModelType;
+/* Not the best numbering system, but it grew this way for historical reasons
+ * and I want to preserve existing binary files. */
+typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3, ARRAY_TRIE_SORTED=4, QUANT_ARRAY_TRIE_SORTED=5} ModelType;
+
+const static ModelType kQuantAdd = static_cast<ModelType>(QUANT_TRIE_SORTED - TRIE_SORTED);
+const static ModelType kArrayAdd = static_cast<ModelType>(ARRAY_TRIE_SORTED - TRIE_SORTED);
/*Inspect a file to determine if it is a binary lm. If not, return false.
* If so, return true and set recognized to the type. This is the only API in
@@ -55,7 +60,7 @@ void AdvanceOrThrow(int fd, off_t off);
// Create just enough of a binary file to write vocabulary to it.
uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing);
// Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin.
-uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing);
+uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing);
// Write header to binary file. This is done last to prevent incomplete files
// from loading.
diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc
index 4552c419..b7aee4de 100644
--- a/klm/lm/build_binary.cc
+++ b/klm/lm/build_binary.cc
@@ -15,12 +15,12 @@ namespace ngram {
namespace {
void Usage(const char *name) {
- std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-n] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [type] input.arpa output.mmap\n\n"
-"-u sets the default log10 probability for <unk> if the ARPA file does not have\n"
-"one.\n"
+ std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-c bits] [type] input.arpa [output.mmap]\n\n"
+"-u sets the log10 probability for <unk> if the ARPA file does not have one.\n"
+" Default is -100. The ARPA file will always take precedence.\n"
"-s allows models to be built even if they do not have <s> and </s>.\n"
-"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n"
-"type is either probing or trie:\n\n"
+"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n\n"
+"type is either probing or trie. Default is probing.\n\n"
"probing uses a probing hash table. It is the fastest but uses the most memory.\n"
"-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n"
"trie is a straightforward trie with bit-level packing. It uses the least\n"
@@ -29,10 +29,11 @@ void Usage(const char *name) {
"-t is the temporary directory prefix. Default is the output file name.\n"
"-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n"
"-q turns quantization on and sets the number of bits (e.g. -q 8).\n"
-"-b sets backoff quantization bits. Requires -q and defaults to that value.\n\n"
-"See http://kheafield.com/code/kenlm/benchmark/ for data structure benchmarks.\n"
-"Passing only an input file will print memory usage of each data structure.\n"
-"If the ARPA file does not have <unk>, -u sets <unk>'s probability; default 0.0.\n";
+"-b sets backoff quantization bits. Requires -q and defaults to that value.\n"
+"-a compresses pointers using an array of offsets. The parameter is the\n"
+" maximum number of bits encoded by the array. Memory is minimized subject\n"
+" to the maximum, so pick 255 to minimize memory.\n\n"
+"Get a memory estimate by passing an ARPA file without an output file name.\n";
exit(1);
}
@@ -63,12 +64,14 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) {
std::vector<uint64_t> counts;
util::FilePiece f(file);
lm::ReadARPACounts(f, counts);
- std::size_t sizes[3];
+ std::size_t sizes[5];
sizes[0] = ProbingModel::Size(counts, config);
sizes[1] = TrieModel::Size(counts, config);
sizes[2] = QuantTrieModel::Size(counts, config);
- std::size_t max_length = *std::max_element(sizes, sizes + 3);
- std::size_t min_length = *std::max_element(sizes, sizes + 3);
+ sizes[3] = ArrayTrieModel::Size(counts, config);
+ sizes[4] = QuantArrayTrieModel::Size(counts, config);
+ std::size_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(size_t));
+ std::size_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(size_t));
std::size_t divide;
char prefix;
if (min_length < (1 << 10) * 10) {
@@ -91,7 +94,9 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) {
std::cout << prefix << "B\n"
"probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n"
"trie " << std::setw(length) << (sizes[1] / divide) << " without quantization\n"
- "trie " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n";
+ "trie " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"
+ "trie " << std::setw(length) << (sizes[3] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n"
+ "trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n";
}
void ProbingQuantizationUnsupported() {
@@ -106,11 +111,11 @@ void ProbingQuantizationUnsupported() {
int main(int argc, char *argv[]) {
using namespace lm::ngram;
- bool quantize = false, set_backoff_bits = false;
try {
+ bool quantize = false, set_backoff_bits = false, bhiksha = false;
lm::ngram::Config config;
int opt;
- while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:")) != -1) {
+ while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:a:")) != -1) {
switch(opt) {
case 'q':
config.prob_bits = ParseBitCount(optarg);
@@ -121,6 +126,9 @@ int main(int argc, char *argv[]) {
config.backoff_bits = ParseBitCount(optarg);
set_backoff_bits = true;
break;
+ case 'a':
+ config.pointer_bhiksha_bits = ParseBitCount(optarg);
+ bhiksha = true;
case 'u':
config.unknown_missing_logprob = ParseFloat(optarg);
break;
@@ -162,9 +170,17 @@ int main(int argc, char *argv[]) {
ProbingModel(from_file, config);
} else if (!strcmp(model_type, "trie")) {
if (quantize) {
- QuantTrieModel(from_file, config);
+ if (bhiksha) {
+ QuantArrayTrieModel(from_file, config);
+ } else {
+ QuantTrieModel(from_file, config);
+ }
} else {
- TrieModel(from_file, config);
+ if (bhiksha) {
+ ArrayTrieModel(from_file, config);
+ } else {
+ TrieModel(from_file, config);
+ }
}
} else {
Usage(argv[0]);
@@ -173,9 +189,9 @@ int main(int argc, char *argv[]) {
Usage(argv[0]);
}
}
- catch (std::exception &e) {
+ catch (const std::exception &e) {
std::cerr << e.what() << std::endl;
- abort();
+ return 1;
}
return 0;
}
diff --git a/klm/lm/config.cc b/klm/lm/config.cc
index 08e1af5c..297589a4 100644
--- a/klm/lm/config.cc
+++ b/klm/lm/config.cc
@@ -20,6 +20,7 @@ Config::Config() :
include_vocab(true),
prob_bits(8),
backoff_bits(8),
+ pointer_bhiksha_bits(22),
load_method(util::POPULATE_OR_READ) {}
} // namespace ngram
diff --git a/klm/lm/config.hh b/klm/lm/config.hh
index dcc7cf35..227b8512 100644
--- a/klm/lm/config.hh
+++ b/klm/lm/config.hh
@@ -73,9 +73,12 @@ struct Config {
// Quantization options. Only effective for QuantTrieModel. One value is
// reserved for each of prob and backoff, so 2^bits - 1 buckets will be used
- // to quantize.
+ // to quantize (and one of the remaining backoffs will be 0).
uint8_t prob_bits, backoff_bits;
+ // Bhiksha compression (simple form). Only works with trie.
+ uint8_t pointer_bhiksha_bits;
+
// ONLY EFFECTIVE WHEN READING BINARY
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index a1d10b3d..27e24b1c 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -21,6 +21,8 @@ size_t hash_value(const State &state) {
namespace detail {
+template <class Search, class VocabularyT> const ModelType GenericModel<Search, VocabularyT>::kModelType = Search::kModelType;
+
template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {
return VocabularyT::Size(counts[0], config) + Search::Size(counts, config);
}
@@ -56,35 +58,40 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
// Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
util::FilePiece f(backing_.file.release(), file, config.messages);
- std::vector<uint64_t> counts;
- // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
- ReadARPACounts(f, counts);
-
- if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile.");
- if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
- if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
-
- std::size_t vocab_size = VocabularyT::Size(counts[0], config);
- // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
- vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);
-
- if (config.write_mmap) {
- WriteWordsWrapper wrap(config.enumerate_vocab);
- vocab_.ConfigureEnumerate(&wrap, counts[0]);
- search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
- wrap.Write(backing_.file.get());
- } else {
- vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
- search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
- }
+ try {
+ std::vector<uint64_t> counts;
+ // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
+ ReadARPACounts(f, counts);
+
+ if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile.");
+ if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
+ if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
+
+ std::size_t vocab_size = VocabularyT::Size(counts[0], config);
+ // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
+ vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);
+
+ if (config.write_mmap) {
+ WriteWordsWrapper wrap(config.enumerate_vocab);
+ vocab_.ConfigureEnumerate(&wrap, counts[0]);
+ search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
+ wrap.Write(backing_.file.get());
+ } else {
+ vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
+ search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
+ }
- if (!vocab_.SawUnk()) {
- assert(config.unknown_missing != THROW_UP);
- // Default probabilities for unknown.
- search_.unigram.Unknown().backoff = 0.0;
- search_.unigram.Unknown().prob = config.unknown_missing_logprob;
+ if (!vocab_.SawUnk()) {
+ assert(config.unknown_missing != THROW_UP);
+ // Default probabilities for unknown.
+ search_.unigram.Unknown().backoff = 0.0;
+ search_.unigram.Unknown().prob = config.unknown_missing_logprob;
+ }
+ FinishFile(config, kModelType, counts, backing_);
+ } catch (util::Exception &e) {
+ e << " Byte: " << f.Offset();
+ throw;
}
- FinishFile(config, kModelType, counts, backing_);
}
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
@@ -225,8 +232,10 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
}
template class GenericModel<ProbingHashedSearch, ProbingVocabulary>; // HASH_PROBING
-template class GenericModel<trie::TrieSearch<DontQuantize>, SortedVocabulary>; // TRIE_SORTED
-template class GenericModel<trie::TrieSearch<SeparatelyQuantize>, SortedVocabulary>; // TRIE_SORTED_QUANT
+template class GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary>; // TRIE_SORTED
+template class GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>, SortedVocabulary>;
+template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary>; // TRIE_SORTED_QUANT
+template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary>;
} // namespace detail
} // namespace ngram
diff --git a/klm/lm/model.hh b/klm/lm/model.hh
index 1f49a382..21595321 100644
--- a/klm/lm/model.hh
+++ b/klm/lm/model.hh
@@ -1,6 +1,7 @@
#ifndef LM_MODEL__
#define LM_MODEL__
+#include "lm/bhiksha.hh"
#include "lm/binary_format.hh"
#include "lm/config.hh"
#include "lm/facade.hh"
@@ -71,6 +72,9 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
private:
typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P;
public:
+ // This is the model type returned by RecognizeBinary.
+ static const ModelType kModelType;
+
/* Get the size of memory that will be mapped given ngram counts. This
* does not include small non-mapped control structures, such as this class
* itself.
@@ -131,8 +135,6 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
Backing &MutableBacking() { return backing_; }
- static const ModelType kModelType = Search::kModelType;
-
Backing backing_;
VocabularyT vocab_;
@@ -152,9 +154,11 @@ typedef ProbingModel Model;
// Smaller implementation.
typedef ::lm::ngram::SortedVocabulary SortedVocabulary;
-typedef detail::GenericModel<trie::TrieSearch<DontQuantize>, SortedVocabulary> TrieModel; // TRIE_SORTED
+typedef detail::GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary> TrieModel; // TRIE_SORTED
+typedef detail::GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>, SortedVocabulary> ArrayTrieModel;
-typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize>, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED
+typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED
+typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary> QuantArrayTrieModel;
} // namespace ngram
} // namespace lm
diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc
index 8bf040ff..57c7291c 100644
--- a/klm/lm/model_test.cc
+++ b/klm/lm/model_test.cc
@@ -193,6 +193,14 @@ template <class M> void Stateless(const M &model) {
BOOST_CHECK_EQUAL(static_cast<WordIndex>(0), state.history_[0]);
}
+template <class M> void NoUnkCheck(const M &model) {
+ WordIndex unk_index = 0;
+ State state;
+
+ FullScoreReturn ret = model.FullScoreForgotState(&unk_index, &unk_index + 1, unk_index, state);
+ BOOST_CHECK_CLOSE(-100.0, ret.prob, 0.001);
+}
+
template <class M> void Everything(const M &m) {
Starters(m);
Continuation(m);
@@ -231,25 +239,38 @@ template <class ModelT> void LoadingTest() {
Config config;
config.arpa_complain = Config::NONE;
config.messages = NULL;
- ExpectEnumerateVocab enumerate;
- config.enumerate_vocab = &enumerate;
config.probing_multiplier = 2.0;
- ModelT m("test.arpa", config);
- enumerate.Check(m.GetVocabulary());
- Everything(m);
+ {
+ ExpectEnumerateVocab enumerate;
+ config.enumerate_vocab = &enumerate;
+ ModelT m("test.arpa", config);
+ enumerate.Check(m.GetVocabulary());
+ Everything(m);
+ }
+ {
+ ExpectEnumerateVocab enumerate;
+ config.enumerate_vocab = &enumerate;
+ ModelT m("test_nounk.arpa", config);
+ enumerate.Check(m.GetVocabulary());
+ NoUnkCheck(m);
+ }
}
BOOST_AUTO_TEST_CASE(probing) {
LoadingTest<Model>();
}
-
BOOST_AUTO_TEST_CASE(trie) {
LoadingTest<TrieModel>();
}
-
-BOOST_AUTO_TEST_CASE(quant) {
+BOOST_AUTO_TEST_CASE(quant_trie) {
LoadingTest<QuantTrieModel>();
}
+BOOST_AUTO_TEST_CASE(bhiksha_trie) {
+ LoadingTest<ArrayTrieModel>();
+}
+BOOST_AUTO_TEST_CASE(quant_bhiksha_trie) {
+ LoadingTest<QuantArrayTrieModel>();
+}
template <class ModelT> void BinaryTest() {
Config config;
@@ -267,10 +288,34 @@ template <class ModelT> void BinaryTest() {
config.write_mmap = NULL;
- ModelT binary("test.binary", config);
- enumerate.Check(binary.GetVocabulary());
- Everything(binary);
+ ModelType type;
+ BOOST_REQUIRE(RecognizeBinary("test.binary", type));
+ BOOST_CHECK_EQUAL(ModelT::kModelType, type);
+
+ {
+ ModelT binary("test.binary", config);
+ enumerate.Check(binary.GetVocabulary());
+ Everything(binary);
+ }
unlink("test.binary");
+
+ // Now test without <unk>.
+ config.write_mmap = "test_nounk.binary";
+ config.messages = NULL;
+ enumerate.Clear();
+ {
+ ModelT copy_model("test_nounk.arpa", config);
+ enumerate.Check(copy_model.GetVocabulary());
+ enumerate.Clear();
+ NoUnkCheck(copy_model);
+ }
+ config.write_mmap = NULL;
+ {
+ ModelT binary("test_nounk.binary", config);
+ enumerate.Check(binary.GetVocabulary());
+ NoUnkCheck(binary);
+ }
+ unlink("test_nounk.binary");
}
BOOST_AUTO_TEST_CASE(write_and_read_probing) {
@@ -282,6 +327,12 @@ BOOST_AUTO_TEST_CASE(write_and_read_trie) {
BOOST_AUTO_TEST_CASE(write_and_read_quant_trie) {
BinaryTest<QuantTrieModel>();
}
+BOOST_AUTO_TEST_CASE(write_and_read_array_trie) {
+ BinaryTest<ArrayTrieModel>();
+}
+BOOST_AUTO_TEST_CASE(write_and_read_quant_array_trie) {
+ BinaryTest<QuantArrayTrieModel>();
+}
} // namespace
} // namespace ngram
diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc
index 9454a6d1..d9db4aa2 100644
--- a/klm/lm/ngram_query.cc
+++ b/klm/lm/ngram_query.cc
@@ -99,6 +99,15 @@ int main(int argc, char *argv[]) {
case lm::ngram::TRIE_SORTED:
Query<lm::ngram::TrieModel>(argv[1], sentence_context);
break;
+ case lm::ngram::QUANT_TRIE_SORTED:
+ Query<lm::ngram::QuantTrieModel>(argv[1], sentence_context);
+ break;
+ case lm::ngram::ARRAY_TRIE_SORTED:
+ Query<lm::ngram::ArrayTrieModel>(argv[1], sentence_context);
+ break;
+ case lm::ngram::QUANT_ARRAY_TRIE_SORTED:
+ Query<lm::ngram::QuantArrayTrieModel>(argv[1], sentence_context);
+ break;
case lm::ngram::HASH_SORTED:
default:
std::cerr << "Unrecognized kenlm model type " << model_type << std::endl;
diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc
index 4bb6b1b8..fd371cc8 100644
--- a/klm/lm/quantize.cc
+++ b/klm/lm/quantize.cc
@@ -43,6 +43,7 @@ void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector<uint64
if (read(fd, &version, 1) != 1 || read(fd, &config.prob_bits, 1) != 1 || read(fd, &config.backoff_bits, 1) != 1)
UTIL_THROW(util::ErrnoException, "Failed to read header for quantization.");
if (version != kSeparatelyQuantizeVersion) UTIL_THROW(FormatLoadException, "This file has quantization version " << (unsigned)version << " but the code expects version " << (unsigned)kSeparatelyQuantizeVersion);
+ AdvanceOrThrow(fd, -3);
}
void SeparatelyQuantize::SetupMemory(void *start, const Config &config) {
diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh
index aae72b34..0b71d14a 100644
--- a/klm/lm/quantize.hh
+++ b/klm/lm/quantize.hh
@@ -21,7 +21,7 @@ class Config;
/* Store values directly and don't quantize. */
class DontQuantize {
public:
- static const ModelType kModelType = TRIE_SORTED;
+ static const ModelType kModelTypeAdd = static_cast<ModelType>(0);
static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
static std::size_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; }
static uint8_t MiddleBits(const Config &/*config*/) { return 63; }
@@ -108,7 +108,7 @@ class SeparatelyQuantize {
};
public:
- static const ModelType kModelType = QUANT_TRIE_SORTED;
+ static const ModelType kModelTypeAdd = kQuantAdd;
static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config);
diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc
index 060a97ea..455bc4ba 100644
--- a/klm/lm/read_arpa.cc
+++ b/klm/lm/read_arpa.cc
@@ -31,15 +31,15 @@ const char kBinaryMagic[] = "mmap lm http://kheafield.com/code";
void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
number.clear();
StringPiece line;
- if (!IsEntirelyWhiteSpace(line = in.ReadLine())) {
+ while (IsEntirelyWhiteSpace(line = in.ReadLine())) {}
+ if (line != "\\data\\") {
if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast<unsigned char>(line.data()[1]) == 0x8b)) {
UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, pipe " << in.FileName() << " through zcat. If this already in binary format, you need to decompress it because mmap doesn't work on top of gzip.");
}
if (static_cast<size_t>(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic)
UTIL_THROW(FormatLoadException, "This looks like a binary file but got sent to the ARPA parser. Did you compress the binary file or pass a binary file where only ARPA files are accepted?");
- UTIL_THROW(FormatLoadException, "First line was \"" << line.data() << "\" not blank");
+ UTIL_THROW(FormatLoadException, "first non-empty line was \"" << line << "\" not \\data\\.");
}
- if ((line = in.ReadLine()) != "\\data\\") UTIL_THROW(FormatLoadException, "second line was \"" << line << "\" not \\data\\.");
while (!IsEntirelyWhiteSpace(line = in.ReadLine())) {
if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\"doesn't begin with \"ngram \"");
// So strtol doesn't go off the end of line.
diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc
index c56ba7b8..82c53ec8 100644
--- a/klm/lm/search_hashed.cc
+++ b/klm/lm/search_hashed.cc
@@ -98,7 +98,7 @@ template <class MiddleT, class LongestT> uint8_t *TemplateHashedSearch<MiddleT,
template <class MiddleT, class LongestT> template <class Voc> void TemplateHashedSearch<MiddleT, LongestT>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing) {
// TODO: fix sorted.
- SetupMemory(GrowForSearch(config, Size(counts, config), backing), counts, config);
+ SetupMemory(GrowForSearch(config, 0, Size(counts, config), backing), counts, config);
PositiveProbWarn warn(config.positive_log_probability);
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh
index f3acdefc..c62985e4 100644
--- a/klm/lm/search_hashed.hh
+++ b/klm/lm/search_hashed.hh
@@ -52,12 +52,11 @@ struct HashedSearch {
Unigram unigram;
- bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const {
+ void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const {
const ProbBackoff &entry = unigram.Lookup(word);
prob = entry.prob;
backoff = entry.backoff;
next = static_cast<Node>(word);
- return true;
}
};
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 91f87f1c..05059ffb 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -1,6 +1,7 @@
/* This is where the trie is built. It's on-disk. */
#include "lm/search_trie.hh"
+#include "lm/bhiksha.hh"
#include "lm/blank.hh"
#include "lm/lm_exception.hh"
#include "lm/max_order.hh"
@@ -543,8 +544,8 @@ void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uin
std::string unigram_name = file_prefix + "unigrams";
util::scoped_fd unigram_file;
// In case <unk> appears.
- size_t extra_count = counts[0] + 1;
- util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), extra_count * sizeof(ProbBackoff), unigram_file), extra_count * sizeof(ProbBackoff));
+ size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff);
+ util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out);
Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()), warn);
CheckSpecials(config, vocab);
if (!vocab.SawUnk()) ++counts[0];
@@ -610,9 +611,9 @@ class JustCount {
};
// Phase to actually write n-grams to the trie.
-template <class Quant> class WriteEntries {
+template <class Quant, class Bhiksha> class WriteEntries {
public:
- WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle> *middle, BitPackedLongest<typename Quant::Longest> &longest, const uint64_t * /*counts*/, unsigned char order) :
+ WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle, Bhiksha> *middle, BitPackedLongest<typename Quant::Longest> &longest, const uint64_t * /*counts*/, unsigned char order) :
contexts_(contexts),
unigrams_(unigrams),
middle_(middle),
@@ -649,7 +650,7 @@ template <class Quant> class WriteEntries {
private:
ContextReader *contexts_;
UnigramValue *const unigrams_;
- BitPackedMiddle<typename Quant::Middle> *const middle_;
+ BitPackedMiddle<typename Quant::Middle, Bhiksha> *const middle_;
BitPackedLongest<typename Quant::Longest> &longest_;
BitPacked &bigram_pack_;
};
@@ -821,7 +822,7 @@ template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, So
} // namespace
-template <class Quant> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing) {
+template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) {
std::vector<SortedFileReader> inputs(counts.size() - 1);
std::vector<ContextReader> contexts(counts.size() - 1);
@@ -846,7 +847,7 @@ template <class Quant> void BuildTrie(const std::string &file_prefix, std::vecto
SanityCheckCounts(counts, fixed_counts);
counts = fixed_counts;
- out.SetupMemory(GrowForSearch(config, TrieSearch<Quant>::Size(fixed_counts, config), backing), fixed_counts, config);
+ out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config);
if (Quant::kTrain) {
util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0));
@@ -863,7 +864,7 @@ template <class Quant> void BuildTrie(const std::string &file_prefix, std::vecto
UnigramValue *unigrams = out.unigram.Raw();
// Fill entries except unigram probabilities.
{
- RecursiveInsert<WriteEntries<Quant> > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size());
+ RecursiveInsert<WriteEntries<Quant, Bhiksha> > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size());
inserter.Apply(config.messages, "Building trie", fixed_counts[0]);
}
@@ -901,14 +902,14 @@ template <class Quant> void BuildTrie(const std::string &file_prefix, std::vecto
/* Set ending offsets so the last entry will be sized properly */
// Last entry for unigrams was already set.
if (out.middle_begin_ != out.middle_end_) {
- for (typename TrieSearch<Quant>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) {
- i->FinishedLoading((i+1)->InsertIndex());
+ for (typename TrieSearch<Quant, Bhiksha>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) {
+ i->FinishedLoading((i+1)->InsertIndex(), config);
}
- (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex());
+ (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex(), config);
}
}
-template <class Quant> uint8_t *TrieSearch<Quant>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
+template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
quant_.SetupMemory(start, config);
start += Quant::Size(counts.size(), config);
unigram.Init(start);
@@ -919,22 +920,24 @@ template <class Quant> uint8_t *TrieSearch<Quant>::SetupMemory(uint8_t *start, c
std::vector<uint8_t*> middle_starts(counts.size() - 2);
for (unsigned char i = 2; i < counts.size(); ++i) {
middle_starts[i-2] = start;
- start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]);
+ start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i], config);
}
- // Crazy backwards thing so we initialize in the correct order.
+ // Crazy backwards thing so we initialize using pointers to ones that have already been initialized
for (unsigned char i = counts.size() - 1; i >= 2; --i) {
new (middle_begin_ + i - 2) Middle(
middle_starts[i-2],
quant_.Mid(i),
+ counts[i-1],
counts[0],
counts[i],
- (i == counts.size() - 1) ? static_cast<const BitPacked&>(longest) : static_cast<const BitPacked &>(middle_begin_[i-1]));
+ (i == counts.size() - 1) ? static_cast<const BitPacked&>(longest) : static_cast<const BitPacked &>(middle_begin_[i-1]),
+ config);
}
longest.Init(start, quant_.Long(counts.size()), counts[0]);
return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
}
-template <class Quant> void TrieSearch<Quant>::LoadedBinary() {
+template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() {
unigram.LoadedBinary();
for (Middle *i = middle_begin_; i != middle_end_; ++i) {
i->LoadedBinary();
@@ -942,7 +945,7 @@ template <class Quant> void TrieSearch<Quant>::LoadedBinary() {
longest.LoadedBinary();
}
-template <class Quant> void TrieSearch<Quant>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
+template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
std::string temporary_directory;
if (config.temporary_directory_prefix) {
temporary_directory = config.temporary_directory_prefix;
@@ -966,14 +969,16 @@ template <class Quant> void TrieSearch<Quant>::InitializeFromARPA(const char *fi
// At least 1MB sorting memory.
ARPAToSortedFiles(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab);
- BuildTrie(temporary_directory, counts, config, *this, quant_, backing);
+ BuildTrie(temporary_directory, counts, config, *this, quant_, vocab, backing);
if (rmdir(temporary_directory.c_str()) && config.messages) {
*config.messages << "Failed to delete " << temporary_directory << std::endl;
}
}
-template class TrieSearch<DontQuantize>;
-template class TrieSearch<SeparatelyQuantize>;
+template class TrieSearch<DontQuantize, DontBhiksha>;
+template class TrieSearch<DontQuantize, ArrayBhiksha>;
+template class TrieSearch<SeparatelyQuantize, DontBhiksha>;
+template class TrieSearch<SeparatelyQuantize, ArrayBhiksha>;
} // namespace trie
} // namespace ngram
diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh
index 0a52acb5..2f39c09f 100644
--- a/klm/lm/search_trie.hh
+++ b/klm/lm/search_trie.hh
@@ -13,31 +13,33 @@ struct Backing;
class SortedVocabulary;
namespace trie {
-template <class Quant> class TrieSearch;
-template <class Quant> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing);
+template <class Quant, class Bhiksha> class TrieSearch;
+template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
-template <class Quant> class TrieSearch {
+template <class Quant, class Bhiksha> class TrieSearch {
public:
typedef NodeRange Node;
typedef ::lm::ngram::trie::Unigram Unigram;
Unigram unigram;
- typedef trie::BitPackedMiddle<typename Quant::Middle> Middle;
+ typedef trie::BitPackedMiddle<typename Quant::Middle, Bhiksha> Middle;
typedef trie::BitPackedLongest<typename Quant::Longest> Longest;
Longest longest;
- static const ModelType kModelType = Quant::kModelType;
+ static const ModelType kModelType = static_cast<ModelType>(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd);
static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
Quant::UpdateConfigFromBinary(fd, counts, config);
+ AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0]));
+ Bhiksha::UpdateConfigFromBinary(fd, config);
}
static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]);
for (unsigned char i = 1; i < counts.size() - 1; ++i) {
- ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1]);
+ ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config);
}
return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
}
@@ -55,8 +57,8 @@ template <class Quant> class TrieSearch {
void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing);
- bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const {
- return unigram.Find(word, prob, backoff, node);
+ void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const {
+ unigram.Find(word, prob, backoff, node);
}
bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const {
@@ -83,7 +85,7 @@ template <class Quant> class TrieSearch {
}
private:
- friend void BuildTrie<Quant>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing);
+ friend void BuildTrie<Quant, Bhiksha>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
// Middles are managed manually so we can delay construction and they don't have to be copyable.
void FreeMiddles() {
diff --git a/klm/lm/test_nounk.arpa b/klm/lm/test_nounk.arpa
new file mode 100644
index 00000000..060733d9
--- /dev/null
+++ b/klm/lm/test_nounk.arpa
@@ -0,0 +1,120 @@
+
+\data\
+ngram 1=36
+ngram 2=45
+ngram 3=10
+ngram 4=6
+ngram 5=4
+
+\1-grams:
+-1.383514 , -0.30103
+-1.139057 . -0.845098
+-1.029493 </s>
+-99 <s> -0.4149733
+-1.285941 a -0.69897
+-1.687872 also -0.30103
+-1.687872 beyond -0.30103
+-1.687872 biarritz -0.30103
+-1.687872 call -0.30103
+-1.687872 concerns -0.30103
+-1.687872 consider -0.30103
+-1.687872 considering -0.30103
+-1.687872 for -0.30103
+-1.509559 higher -0.30103
+-1.687872 however -0.30103
+-1.687872 i -0.30103
+-1.687872 immediate -0.30103
+-1.687872 in -0.30103
+-1.687872 is -0.30103
+-1.285941 little -0.69897
+-1.383514 loin -0.30103
+-1.687872 look -0.30103
+-1.285941 looking -0.4771212
+-1.206319 more -0.544068
+-1.509559 on -0.4771212
+-1.509559 screening -0.4771212
+-1.687872 small -0.30103
+-1.687872 the -0.30103
+-1.687872 to -0.30103
+-1.687872 watch -0.30103
+-1.687872 watching -0.30103
+-1.687872 what -0.30103
+-1.687872 would -0.30103
+-3.141592 foo
+-2.718281 bar 3.0
+-6.535897 baz -0.0
+
+\2-grams:
+-0.6925742 , .
+-0.7522095 , however
+-0.7522095 , is
+-0.0602359 . </s>
+-0.4846522 <s> looking -0.4771214
+-1.051485 <s> screening
+-1.07153 <s> the
+-1.07153 <s> watching
+-1.07153 <s> what
+-0.09132547 a little -0.69897
+-0.2922095 also call
+-0.2922095 beyond immediate
+-0.2705918 biarritz .
+-0.2922095 call for
+-0.2922095 concerns in
+-0.2922095 consider watch
+-0.2922095 considering consider
+-0.2834328 for ,
+-0.5511513 higher more
+-0.5845945 higher small
+-0.2834328 however ,
+-0.2922095 i would
+-0.2922095 immediate concerns
+-0.2922095 in biarritz
+-0.2922095 is to
+-0.09021038 little more -0.1998621
+-0.7273645 loin ,
+-0.6925742 loin .
+-0.6708385 loin </s>
+-0.2922095 look beyond
+-0.4638903 looking higher
+-0.4638903 looking on -0.4771212
+-0.5136299 more . -0.4771212
+-0.3561665 more loin
+-0.1649931 on a -0.4771213
+-0.1649931 screening a -0.4771213
+-0.2705918 small .
+-0.287799 the screening
+-0.2922095 to look
+-0.2622373 watch </s>
+-0.2922095 watching considering
+-0.2922095 what i
+-0.2922095 would also
+-2 also would -6
+-6 foo bar
+
+\3-grams:
+-0.01916512 more . </s>
+-0.0283603 on a little -0.4771212
+-0.0283603 screening a little -0.4771212
+-0.01660496 a little more -0.09409451
+-0.3488368 <s> looking higher
+-0.3488368 <s> looking on -0.4771212
+-0.1892331 little more loin
+-0.04835128 looking on a -0.4771212
+-3 also would consider -7
+-7 to look good
+
+\4-grams:
+-0.009249173 looking on a little -0.4771212
+-0.005464747 on a little more -0.4771212
+-0.005464747 screening a little more
+-0.1453306 a little more loin
+-0.01552657 <s> looking on a -0.4771212
+-4 also would consider higher -8
+
+\5-grams:
+-0.003061223 <s> looking on a little
+-0.001813953 looking on a little more
+-0.0432557 on a little more loin
+-5 also would consider higher looking
+
+\end\
diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc
index 63c2a612..8c536e66 100644
--- a/klm/lm/trie.cc
+++ b/klm/lm/trie.cc
@@ -1,5 +1,6 @@
#include "lm/trie.hh"
+#include "lm/bhiksha.hh"
#include "lm/quantize.hh"
#include "util/bit_packing.hh"
#include "util/exception.hh"
@@ -57,16 +58,21 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits)
max_vocab_ = max_vocab;
}
-template <class Quant> std::size_t BitPackedMiddle<Quant>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) {
- return BaseSize(entries, max_vocab, quant_bits + util::RequiredBits(max_ptr));
+template <class Quant, class Bhiksha> std::size_t BitPackedMiddle<Quant, Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) {
+ return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config));
}
-template <class Quant> BitPackedMiddle<Quant>::BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) : BitPacked(), quant_(quant), next_bits_(util::RequiredBits(max_next)), next_mask_((1ULL << next_bits_) - 1), next_source_(&next_source) {
- if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions.");
- BaseInit(base, max_vocab, quant.TotalBits() + next_bits_);
+template <class Quant, class Bhiksha> BitPackedMiddle<Quant, Bhiksha>::BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) :
+ BitPacked(),
+ quant_(quant),
+ // If the offset of the method changes, also change TrieSearch::UpdateConfigFromBinary.
+ bhiksha_(base, entries + 1, max_next, config),
+ next_source_(&next_source) {
+ if (entries + 1 >= (1ULL << 57) || (max_next >= (1ULL << 57))) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions.");
+ BaseInit(reinterpret_cast<uint8_t*>(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant.TotalBits() + bhiksha_.InlineBits());
}
-template <class Quant> void BitPackedMiddle<Quant>::Insert(WordIndex word, float prob, float backoff) {
+template <class Quant, class Bhiksha> void BitPackedMiddle<Quant, Bhiksha>::Insert(WordIndex word, float prob, float backoff) {
assert(word <= word_mask_);
uint64_t at_pointer = insert_index_ * total_bits_;
@@ -75,47 +81,42 @@ template <class Quant> void BitPackedMiddle<Quant>::Insert(WordIndex word, float
quant_.Write(base_, at_pointer, prob, backoff);
at_pointer += quant_.TotalBits();
uint64_t next = next_source_->InsertIndex();
- assert(next <= next_mask_);
- util::WriteInt57(base_, at_pointer, next_bits_, next);
+ bhiksha_.WriteNext(base_, at_pointer, insert_index_, next);
++insert_index_;
}
-template <class Quant> bool BitPackedMiddle<Quant>::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const {
+template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const {
uint64_t at_pointer;
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) {
return false;
}
+ uint64_t index = at_pointer;
at_pointer *= total_bits_;
at_pointer += word_bits_;
quant_.Read(base_, at_pointer, prob, backoff);
at_pointer += quant_.TotalBits();
- range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_);
- // Read the next entry's pointer.
- at_pointer += total_bits_;
- range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_);
+ bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range);
+
return true;
}
-template <class Quant> bool BitPackedMiddle<Quant>::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const {
- uint64_t at_pointer;
- if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false;
- at_pointer *= total_bits_;
+template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const {
+ uint64_t index;
+ if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, index)) return false;
+ uint64_t at_pointer = index * total_bits_;
at_pointer += word_bits_;
quant_.ReadBackoff(base_, at_pointer, backoff);
at_pointer += quant_.TotalBits();
- range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_);
- // Read the next entry's pointer.
- at_pointer += total_bits_;
- range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_);
+ bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range);
return true;
}
-template <class Quant> void BitPackedMiddle<Quant>::FinishedLoading(uint64_t next_end) {
- assert(next_end <= next_mask_);
- uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_;
- util::WriteInt57(base_, last_next_write, next_bits_, next_end);
+template <class Quant, class Bhiksha> void BitPackedMiddle<Quant, Bhiksha>::FinishedLoading(uint64_t next_end, const Config &config) {
+ uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - bhiksha_.InlineBits();
+ bhiksha_.WriteNext(base_, last_next_write, insert_index_ + 1, next_end);
+ bhiksha_.FinishedLoading(config);
}
template <class Quant> void BitPackedLongest<Quant>::Insert(WordIndex index, float prob) {
@@ -135,8 +136,10 @@ template <class Quant> bool BitPackedLongest<Quant>::Find(WordIndex word, float
return true;
}
-template class BitPackedMiddle<DontQuantize::Middle>;
-template class BitPackedMiddle<SeparatelyQuantize::Middle>;
+template class BitPackedMiddle<DontQuantize::Middle, DontBhiksha>;
+template class BitPackedMiddle<DontQuantize::Middle, ArrayBhiksha>;
+template class BitPackedMiddle<SeparatelyQuantize::Middle, DontBhiksha>;
+template class BitPackedMiddle<SeparatelyQuantize::Middle, ArrayBhiksha>;
template class BitPackedLongest<DontQuantize::Longest>;
template class BitPackedLongest<SeparatelyQuantize::Longest>;
diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh
index 8fa21aaf..53612064 100644
--- a/klm/lm/trie.hh
+++ b/klm/lm/trie.hh
@@ -10,6 +10,7 @@
namespace lm {
namespace ngram {
+class Config;
namespace trie {
struct NodeRange {
@@ -46,13 +47,12 @@ class Unigram {
void LoadedBinary() {}
- bool Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const {
+ void Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const {
UnigramValue *val = unigram_ + word;
prob = val->weights.prob;
backoff = val->weights.backoff;
next.begin = val->next;
next.end = (val+1)->next;
- return true;
}
private:
@@ -67,8 +67,6 @@ class BitPacked {
return insert_index_;
}
- void LoadedBinary() {}
-
protected:
static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits);
@@ -83,30 +81,30 @@ class BitPacked {
uint64_t insert_index_, max_vocab_;
};
-template <class Quant> class BitPackedMiddle : public BitPacked {
+template <class Quant, class Bhiksha> class BitPackedMiddle : public BitPacked {
public:
- static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next);
+ static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config);
// next_source need not be initialized.
- BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source);
+ BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config);
void Insert(WordIndex word, float prob, float backoff);
+ void FinishedLoading(uint64_t next_end, const Config &config);
+
+ void LoadedBinary() { bhiksha_.LoadedBinary(); }
+
bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const;
bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const;
- void FinishedLoading(uint64_t next_end);
-
private:
Quant quant_;
- uint8_t next_bits_;
- uint64_t next_mask_;
+ Bhiksha bhiksha_;
const BitPacked *next_source_;
};
-
template <class Quant> class BitPackedLongest : public BitPacked {
public:
static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) {
@@ -120,6 +118,8 @@ template <class Quant> class BitPackedLongest : public BitPacked {
BaseInit(base, max_vocab, quant_.TotalBits());
}
+ void LoadedBinary() {}
+
void Insert(WordIndex word, float prob);
bool Find(WordIndex word, float &prob, const NodeRange &node) const;
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc
index 7defd5c1..04979d51 100644
--- a/klm/lm/vocab.cc
+++ b/klm/lm/vocab.cc
@@ -37,14 +37,14 @@ WordIndex ReadWords(int fd, EnumerateVocab *enumerate) {
WordIndex index = 0;
while (true) {
ssize_t got = read(fd, &buf[0], kInitialRead);
- if (got == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words");
+ UTIL_THROW_IF(got == -1, util::ErrnoException, "Reading vocabulary words");
if (got == 0) return index;
buf.resize(got);
while (buf[buf.size() - 1]) {
char next_char;
ssize_t ret = read(fd, &next_char, 1);
- if (ret == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words");
- if (ret == 0) UTIL_THROW(FormatLoadException, "Missing null terminator on a vocab word.");
+ UTIL_THROW_IF(ret == -1, util::ErrnoException, "Reading vocabulary words");
+ UTIL_THROW_IF(ret == 0, FormatLoadException, "Missing null terminator on a vocab word.");
buf.push_back(next_char);
}
// Ok now we have null terminated strings.
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
index c92518e4..9d218fff 100644
--- a/klm/lm/vocab.hh
+++ b/klm/lm/vocab.hh
@@ -61,6 +61,7 @@ class SortedVocabulary : public base::Vocabulary {
}
}
+ // Size for purposes of file writing
static size_t Size(std::size_t entries, const Config &config);
// Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary.
@@ -77,6 +78,9 @@ class SortedVocabulary : public base::Vocabulary {
// Reorders reorder_vocab so that the IDs are sorted.
void FinishedLoading(ProbBackoff *reorder_vocab);
+ // Trie stores the correct counts including <unk> in the header. If this was previously sized based on a count exluding <unk>, padding with 8 bytes will make it the correct size based on a count including <unk>.
+ std::size_t UnkCountChangePadding() const { return SawUnk() ? 0 : sizeof(uint64_t); }
+
bool SawUnk() const { return saw_unk_; }
void LoadedBinary(int fd, EnumerateVocab *to);
diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh
index b35d80c8..9f47d559 100644
--- a/klm/util/bit_packing.hh
+++ b/klm/util/bit_packing.hh
@@ -107,9 +107,20 @@ void BitPackingSanity();
uint8_t RequiredBits(uint64_t max_value);
struct BitsMask {
+ static BitsMask ByMax(uint64_t max_value) {
+ BitsMask ret;
+ ret.FromMax(max_value);
+ return ret;
+ }
+ static BitsMask ByBits(uint8_t bits) {
+ BitsMask ret;
+ ret.bits = bits;
+ ret.mask = (1ULL << bits) - 1;
+ return ret;
+ }
void FromMax(uint64_t max_value) {
bits = RequiredBits(max_value);
- mask = (1 << bits) - 1;
+ mask = (1ULL << bits) - 1;
}
uint8_t bits;
uint64_t mask;
diff --git a/klm/util/exception.cc b/klm/util/exception.cc
index 84f9fe7c..62280970 100644
--- a/klm/util/exception.cc
+++ b/klm/util/exception.cc
@@ -1,5 +1,9 @@
#include "util/exception.hh"
+#ifdef __GXX_RTTI
+#include <typeinfo>
+#endif
+
#include <errno.h>
#include <string.h>
@@ -22,6 +26,30 @@ const char *Exception::what() const throw() {
return text_.c_str();
}
+void Exception::SetLocation(const char *file, unsigned int line, const char *func, const char *child_name, const char *condition) {
+ /* The child class might have set some text, but we want this to come first.
+ * Another option would be passing this information to the constructor, but
+ * then child classes would have to accept constructor arguments and pass
+ * them down.
+ */
+ text_ = stream_.str();
+ stream_.str("");
+ stream_ << file << ':' << line;
+ if (func) stream_ << " in " << func << " threw ";
+ if (child_name) {
+ stream_ << child_name;
+ } else {
+#ifdef __GXX_RTTI
+ stream_ << typeid(this).name();
+#else
+ stream_ << "an exception";
+#endif
+ }
+ if (condition) stream_ << " because `" << condition;
+ stream_ << "'.\n";
+ stream_ << text_;
+}
+
namespace {
// The XOPEN version.
const char *HandleStrerror(int ret, const char *buf) {
diff --git a/klm/util/exception.hh b/klm/util/exception.hh
index c6936914..81675a57 100644
--- a/klm/util/exception.hh
+++ b/klm/util/exception.hh
@@ -1,8 +1,6 @@
#ifndef UTIL_EXCEPTION__
#define UTIL_EXCEPTION__
-#include "util/string_piece.hh"
-
#include <exception>
#include <sstream>
#include <string>
@@ -22,6 +20,14 @@ class Exception : public std::exception {
// Not threadsafe, but probably doesn't matter. FWIW, Boost's exception guidance implies that what() isn't threadsafe.
const char *what() const throw();
+ // For use by the UTIL_THROW macros.
+ void SetLocation(
+ const char *file,
+ unsigned int line,
+ const char *func,
+ const char *child_name,
+ const char *condition);
+
private:
template <class Except, class Data> friend typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const Data &data);
@@ -43,7 +49,49 @@ template <class Except, class Data> typename Except::template ExceptionTag<Excep
return e;
}
-#define UTIL_THROW(Exception, Modify) { Exception UTIL_e; {UTIL_e << Modify;} throw UTIL_e; }
+#ifdef __GNUC__
+#define UTIL_FUNC_NAME __PRETTY_FUNCTION__
+#else
+#ifdef _WIN32
+#define UTIL_FUNC_NAME __FUNCTION__
+#else
+#define UTIL_FUNC_NAME NULL
+#endif
+#endif
+
+#define UTIL_SET_LOCATION(UTIL_e, child, condition) do { \
+ (UTIL_e).SetLocation(__FILE__, __LINE__, UTIL_FUNC_NAME, (child), (condition)); \
+} while (0)
+
+/* Create an instance of Exception, add the message Modify, and throw it.
+ * Modify is appended to the what() message and can contain << for ostream
+ * operations.
+ *
+ * do .. while kludge to swallow trailing ; character
+ * http://gcc.gnu.org/onlinedocs/cpp/Swallowing-the-Semicolon.html .
+ */
+#define UTIL_THROW(Exception, Modify) do { \
+ Exception UTIL_e; \
+ UTIL_SET_LOCATION(UTIL_e, #Exception, NULL); \
+ UTIL_e << Modify; \
+ throw UTIL_e; \
+} while (0)
+
+#define UTIL_THROW_VAR(Var, Modify) do { \
+ Exception &UTIL_e = (Var); \
+ UTIL_SET_LOCATION(UTIL_e, NULL, NULL); \
+ UTIL_e << Modify; \
+ throw UTIL_e; \
+} while (0)
+
+#define UTIL_THROW_IF(Condition, Exception, Modify) do { \
+ if (Condition) { \
+ Exception UTIL_e; \
+ UTIL_SET_LOCATION(UTIL_e, #Exception, #Condition); \
+ UTIL_e << Modify; \
+ throw UTIL_e; \
+ } \
+} while (0)
class ErrnoException : public Exception {
public:
@@ -51,7 +99,7 @@ class ErrnoException : public Exception {
virtual ~ErrnoException() throw();
- int Error() { return errno_; }
+ int Error() const throw() { return errno_; }
private:
int errno_;
diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc
index f447a70c..cbe4234f 100644
--- a/klm/util/file_piece.cc
+++ b/klm/util/file_piece.cc
@@ -41,8 +41,8 @@ GZException::GZException(void *file) {
const bool kSpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
int OpenReadOrThrow(const char *name) {
- int ret = open(name, O_RDONLY);
- if (ret == -1) UTIL_THROW(ErrnoException, "in open (" << name << ") for reading");
+ int ret;
+ UTIL_THROW_IF(-1 == (ret = open(name, O_RDONLY)), ErrnoException, "while opening " << name);
return ret;
}
@@ -52,13 +52,13 @@ off_t SizeFile(int fd) {
return sb.st_size;
}
-FilePiece::FilePiece(const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) :
+FilePiece::FilePiece(const char *name, std::ostream *show_progress, off_t min_buffer) :
file_(OpenReadOrThrow(name)), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)),
progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) {
Initialize(name, show_progress, min_buffer);
}
-FilePiece::FilePiece(int fd, const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) :
+FilePiece::FilePiece(int fd, const char *name, std::ostream *show_progress, off_t min_buffer) :
file_(fd), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)),
progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) {
Initialize(name, show_progress, min_buffer);
@@ -78,7 +78,7 @@ FilePiece::~FilePiece() {
#endif
}
-StringPiece FilePiece::ReadLine(char delim) throw (GZException, EndOfFileException) {
+StringPiece FilePiece::ReadLine(char delim) {
size_t skip = 0;
while (true) {
for (const char *i = position_ + skip; i < position_end_; ++i) {
@@ -97,20 +97,20 @@ StringPiece FilePiece::ReadLine(char delim) throw (GZException, EndOfFileExcepti
}
}
-float FilePiece::ReadFloat() throw(GZException, EndOfFileException, ParseNumberException) {
+float FilePiece::ReadFloat() {
return ReadNumber<float>();
}
-double FilePiece::ReadDouble() throw(GZException, EndOfFileException, ParseNumberException) {
+double FilePiece::ReadDouble() {
return ReadNumber<double>();
}
-long int FilePiece::ReadLong() throw(GZException, EndOfFileException, ParseNumberException) {
+long int FilePiece::ReadLong() {
return ReadNumber<long int>();
}
-unsigned long int FilePiece::ReadULong() throw(GZException, EndOfFileException, ParseNumberException) {
+unsigned long int FilePiece::ReadULong() {
return ReadNumber<unsigned long int>();
}
-void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) {
+void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) {
#ifdef HAVE_ZLIB
gz_file_ = NULL;
#endif
@@ -163,7 +163,7 @@ void ParseNumber(const char *begin, char *&end, unsigned long int &out) {
}
} // namespace
-template <class T> T FilePiece::ReadNumber() throw(GZException, EndOfFileException, ParseNumberException) {
+template <class T> T FilePiece::ReadNumber() {
SkipSpaces();
while (last_space_ < position_) {
if (at_end_) {
@@ -186,7 +186,7 @@ template <class T> T FilePiece::ReadNumber() throw(GZException, EndOfFileExcepti
return ret;
}
-const char *FilePiece::FindDelimiterOrEOF(const bool *delim) throw (GZException, EndOfFileException) {
+const char *FilePiece::FindDelimiterOrEOF(const bool *delim) {
size_t skip = 0;
while (true) {
for (const char *i = position_ + skip; i < position_end_; ++i) {
@@ -201,7 +201,7 @@ const char *FilePiece::FindDelimiterOrEOF(const bool *delim) throw (GZException,
}
}
-void FilePiece::Shift() throw(GZException, EndOfFileException) {
+void FilePiece::Shift() {
if (at_end_) {
progress_.Finished();
throw EndOfFileException();
@@ -217,7 +217,7 @@ void FilePiece::Shift() throw(GZException, EndOfFileException) {
}
}
-void FilePiece::MMapShift(off_t desired_begin) throw() {
+void FilePiece::MMapShift(off_t desired_begin) {
// Use mmap.
off_t ignore = desired_begin % page_;
// Duplicate request for Shift means give more data.
@@ -259,25 +259,23 @@ void FilePiece::MMapShift(off_t desired_begin) throw() {
progress_.Set(desired_begin);
}
-void FilePiece::TransitionToRead() throw (GZException) {
+void FilePiece::TransitionToRead() {
assert(!fallback_to_read_);
fallback_to_read_ = true;
data_.reset();
data_.reset(malloc(default_map_size_), default_map_size_, scoped_memory::MALLOC_ALLOCATED);
- if (!data_.get()) UTIL_THROW(ErrnoException, "malloc failed for " << default_map_size_);
+ UTIL_THROW_IF(!data_.get(), ErrnoException, "malloc failed for " << default_map_size_);
position_ = data_.begin();
position_end_ = position_;
#ifdef HAVE_ZLIB
assert(!gz_file_);
gz_file_ = gzdopen(file_.get(), "r");
- if (!gz_file_) {
- UTIL_THROW(GZException, "zlib failed to open " << file_name_);
- }
+ UTIL_THROW_IF(!gz_file_, GZException, "zlib failed to open " << file_name_);
#endif
}
-void FilePiece::ReadShift() throw(GZException, EndOfFileException) {
+void FilePiece::ReadShift() {
assert(fallback_to_read_);
// Bytes [data_.begin(), position_) have been consumed.
// Bytes [position_, position_end_) have been read into the buffer.
@@ -297,7 +295,7 @@ void FilePiece::ReadShift() throw(GZException, EndOfFileException) {
std::size_t valid_length = position_end_ - position_;
default_map_size_ *= 2;
data_.call_realloc(default_map_size_);
- if (!data_.get()) UTIL_THROW(ErrnoException, "realloc failed for " << default_map_size_);
+ UTIL_THROW_IF(!data_.get(), ErrnoException, "realloc failed for " << default_map_size_);
position_ = data_.begin();
position_end_ = position_ + valid_length;
} else {
@@ -320,7 +318,7 @@ void FilePiece::ReadShift() throw(GZException, EndOfFileException) {
}
#else
read_return = read(file_.get(), static_cast<char*>(data_.get()) + already_read, default_map_size_ - already_read);
- if (read_return == -1) UTIL_THROW(ErrnoException, "read failed");
+ UTIL_THROW_IF(read_return == -1, ErrnoException, "read failed");
progress_.Set(mapped_offset_);
#endif
if (read_return == 0) {
diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh
index 870ae5a3..a5c00910 100644
--- a/klm/util/file_piece.hh
+++ b/klm/util/file_piece.hh
@@ -45,13 +45,13 @@ off_t SizeFile(int fd);
class FilePiece {
public:
// 32 MB default.
- explicit FilePiece(const char *file, std::ostream *show_progress = NULL, off_t min_buffer = 33554432) throw(GZException);
+ explicit FilePiece(const char *file, std::ostream *show_progress = NULL, off_t min_buffer = 33554432);
// Takes ownership of fd. name is used for messages.
- explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, off_t min_buffer = 33554432) throw(GZException);
+ explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, off_t min_buffer = 33554432);
~FilePiece();
- char get() throw(GZException, EndOfFileException) {
+ char get() {
if (position_ == position_end_) {
Shift();
if (at_end_) throw EndOfFileException();
@@ -60,22 +60,22 @@ class FilePiece {
}
// Leaves the delimiter, if any, to be returned by get(). Delimiters defined by isspace().
- StringPiece ReadDelimited(const bool *delim = kSpaces) throw(GZException, EndOfFileException) {
+ StringPiece ReadDelimited(const bool *delim = kSpaces) {
SkipSpaces(delim);
return Consume(FindDelimiterOrEOF(delim));
}
// Unlike ReadDelimited, this includes leading spaces and consumes the delimiter.
// It is similar to getline in that way.
- StringPiece ReadLine(char delim = '\n') throw(GZException, EndOfFileException);
+ StringPiece ReadLine(char delim = '\n');
- float ReadFloat() throw(GZException, EndOfFileException, ParseNumberException);
- double ReadDouble() throw(GZException, EndOfFileException, ParseNumberException);
- long int ReadLong() throw(GZException, EndOfFileException, ParseNumberException);
- unsigned long int ReadULong() throw(GZException, EndOfFileException, ParseNumberException);
+ float ReadFloat();
+ double ReadDouble();
+ long int ReadLong();
+ unsigned long int ReadULong();
// Skip spaces defined by isspace.
- void SkipSpaces(const bool *delim = kSpaces) throw (GZException, EndOfFileException) {
+ void SkipSpaces(const bool *delim = kSpaces) {
for (; ; ++position_) {
if (position_ == position_end_) Shift();
if (!delim[static_cast<unsigned char>(*position_)]) return;
@@ -89,9 +89,9 @@ class FilePiece {
const std::string &FileName() const { return file_name_; }
private:
- void Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) throw(GZException);
+ void Initialize(const char *name, std::ostream *show_progress, off_t min_buffer);
- template <class T> T ReadNumber() throw(GZException, EndOfFileException, ParseNumberException);
+ template <class T> T ReadNumber();
StringPiece Consume(const char *to) {
StringPiece ret(position_, to - position_);
@@ -99,14 +99,14 @@ class FilePiece {
return ret;
}
- const char *FindDelimiterOrEOF(const bool *delim = kSpaces) throw (GZException, EndOfFileException);
+ const char *FindDelimiterOrEOF(const bool *delim = kSpaces);
- void Shift() throw (EndOfFileException, GZException);
+ void Shift();
// Backends to Shift().
- void MMapShift(off_t desired_begin) throw ();
+ void MMapShift(off_t desired_begin);
- void TransitionToRead() throw (GZException);
- void ReadShift() throw (GZException, EndOfFileException);
+ void TransitionToRead();
+ void ReadShift();
const char *position_, *last_space_, *position_end_;
diff --git a/klm/util/murmur_hash.cc b/klm/util/murmur_hash.cc
index d58a0727..fec47fd9 100644
--- a/klm/util/murmur_hash.cc
+++ b/klm/util/murmur_hash.cc
@@ -1,129 +1,129 @@
-/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All
- * code is released to the public domain. For business purposes, Murmurhash is
- * under the MIT license."
- * This is modified from the original:
- * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit.
- * length changed to unsigned int.
- * placed in namespace util
- * add MurmurHashNative
- * default option = 0 for seed
- */
-
-#include "util/murmur_hash.hh"
-
-namespace util {
-
-//-----------------------------------------------------------------------------
-// MurmurHash2, 64-bit versions, by Austin Appleby
-
-// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment
-// and endian-ness issues if used across multiple platforms.
-
-// 64-bit hash for 64-bit platforms
-
-uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed )
-{
- const uint64_t m = 0xc6a4a7935bd1e995ULL;
- const int r = 47;
-
- uint64_t h = seed ^ (len * m);
-
- const uint64_t * data = (const uint64_t *)key;
- const uint64_t * end = data + (len/8);
-
- while(data != end)
- {
- uint64_t k = *data++;
-
- k *= m;
- k ^= k >> r;
- k *= m;
-
- h ^= k;
- h *= m;
- }
-
- const unsigned char * data2 = (const unsigned char*)data;
-
- switch(len & 7)
- {
- case 7: h ^= uint64_t(data2[6]) << 48;
- case 6: h ^= uint64_t(data2[5]) << 40;
- case 5: h ^= uint64_t(data2[4]) << 32;
- case 4: h ^= uint64_t(data2[3]) << 24;
- case 3: h ^= uint64_t(data2[2]) << 16;
- case 2: h ^= uint64_t(data2[1]) << 8;
- case 1: h ^= uint64_t(data2[0]);
- h *= m;
- };
-
- h ^= h >> r;
- h *= m;
- h ^= h >> r;
-
- return h;
-}
-
-
-// 64-bit hash for 32-bit platforms
-
-uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed )
-{
- const unsigned int m = 0x5bd1e995;
- const int r = 24;
-
- unsigned int h1 = seed ^ len;
- unsigned int h2 = 0;
-
- const unsigned int * data = (const unsigned int *)key;
-
- while(len >= 8)
- {
- unsigned int k1 = *data++;
- k1 *= m; k1 ^= k1 >> r; k1 *= m;
- h1 *= m; h1 ^= k1;
- len -= 4;
-
- unsigned int k2 = *data++;
- k2 *= m; k2 ^= k2 >> r; k2 *= m;
- h2 *= m; h2 ^= k2;
- len -= 4;
- }
-
- if(len >= 4)
- {
- unsigned int k1 = *data++;
- k1 *= m; k1 ^= k1 >> r; k1 *= m;
- h1 *= m; h1 ^= k1;
- len -= 4;
- }
-
- switch(len)
- {
- case 3: h2 ^= ((unsigned char*)data)[2] << 16;
- case 2: h2 ^= ((unsigned char*)data)[1] << 8;
- case 1: h2 ^= ((unsigned char*)data)[0];
- h2 *= m;
- };
-
- h1 ^= h2 >> 18; h1 *= m;
- h2 ^= h1 >> 22; h2 *= m;
- h1 ^= h2 >> 17; h1 *= m;
- h2 ^= h1 >> 19; h2 *= m;
-
- uint64_t h = h1;
-
- h = (h << 32) | h2;
-
- return h;
-}
-
-uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) {
- if (sizeof(int) == 4) {
- return MurmurHash64B(key, len, seed);
- } else {
- return MurmurHash64A(key, len, seed);
- }
-}
-
-} // namespace util
+/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All
+ * code is released to the public domain. For business purposes, Murmurhash is
+ * under the MIT license."
+ * This is modified from the original:
+ * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit.
+ * length changed to unsigned int.
+ * placed in namespace util
+ * add MurmurHashNative
+ * default option = 0 for seed
+ */
+
+#include "util/murmur_hash.hh"
+
+namespace util {
+
+//-----------------------------------------------------------------------------
+// MurmurHash2, 64-bit versions, by Austin Appleby
+
+// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment
+// and endian-ness issues if used across multiple platforms.
+
+// 64-bit hash for 64-bit platforms
+
+uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed )
+{
+ const uint64_t m = 0xc6a4a7935bd1e995ULL;
+ const int r = 47;
+
+ uint64_t h = seed ^ (len * m);
+
+ const uint64_t * data = (const uint64_t *)key;
+ const uint64_t * end = data + (len/8);
+
+ while(data != end)
+ {
+ uint64_t k = *data++;
+
+ k *= m;
+ k ^= k >> r;
+ k *= m;
+
+ h ^= k;
+ h *= m;
+ }
+
+ const unsigned char * data2 = (const unsigned char*)data;
+
+ switch(len & 7)
+ {
+ case 7: h ^= uint64_t(data2[6]) << 48;
+ case 6: h ^= uint64_t(data2[5]) << 40;
+ case 5: h ^= uint64_t(data2[4]) << 32;
+ case 4: h ^= uint64_t(data2[3]) << 24;
+ case 3: h ^= uint64_t(data2[2]) << 16;
+ case 2: h ^= uint64_t(data2[1]) << 8;
+ case 1: h ^= uint64_t(data2[0]);
+ h *= m;
+ };
+
+ h ^= h >> r;
+ h *= m;
+ h ^= h >> r;
+
+ return h;
+}
+
+
+// 64-bit hash for 32-bit platforms
+
+uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed )
+{
+ const unsigned int m = 0x5bd1e995;
+ const int r = 24;
+
+ unsigned int h1 = seed ^ len;
+ unsigned int h2 = 0;
+
+ const unsigned int * data = (const unsigned int *)key;
+
+ while(len >= 8)
+ {
+ unsigned int k1 = *data++;
+ k1 *= m; k1 ^= k1 >> r; k1 *= m;
+ h1 *= m; h1 ^= k1;
+ len -= 4;
+
+ unsigned int k2 = *data++;
+ k2 *= m; k2 ^= k2 >> r; k2 *= m;
+ h2 *= m; h2 ^= k2;
+ len -= 4;
+ }
+
+ if(len >= 4)
+ {
+ unsigned int k1 = *data++;
+ k1 *= m; k1 ^= k1 >> r; k1 *= m;
+ h1 *= m; h1 ^= k1;
+ len -= 4;
+ }
+
+ switch(len)
+ {
+ case 3: h2 ^= ((unsigned char*)data)[2] << 16;
+ case 2: h2 ^= ((unsigned char*)data)[1] << 8;
+ case 1: h2 ^= ((unsigned char*)data)[0];
+ h2 *= m;
+ };
+
+ h1 ^= h2 >> 18; h1 *= m;
+ h2 ^= h1 >> 22; h2 *= m;
+ h1 ^= h2 >> 17; h1 *= m;
+ h2 ^= h1 >> 19; h2 *= m;
+
+ uint64_t h = h1;
+
+ h = (h << 32) | h2;
+
+ return h;
+}
+
+uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) {
+ if (sizeof(int) == 4) {
+ return MurmurHash64B(key, len, seed);
+ } else {
+ return MurmurHash64A(key, len, seed);
+ }
+}
+
+} // namespace util
diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh
index 00be0ed7..2ec342a6 100644
--- a/klm/util/probing_hash_table.hh
+++ b/klm/util/probing_hash_table.hh
@@ -57,7 +57,7 @@ template <class PackingT, class HashT, class EqualT = std::equal_to<typename Pac
equal_(equal_func),
entries_(0)
#ifdef DEBUG
- , initialized_(true),
+ , initialized_(true)
#endif
{}
diff --git a/klm/util/sorted_uniform.hh b/klm/util/sorted_uniform.hh
index 84d7aa02..0d6ecbbd 100644
--- a/klm/util/sorted_uniform.hh
+++ b/klm/util/sorted_uniform.hh
@@ -12,7 +12,7 @@ namespace util {
template <class T> class IdentityAccessor {
public:
typedef T Key;
- T operator()(const uint64_t *in) const { return *in; }
+ T operator()(const T *in) const { return *in; }
};
struct Pivot64 {
@@ -101,6 +101,27 @@ template <class Iterator, class Accessor, class Pivot> bool SortedUniformFind(co
return BoundedSortedUniformFind<Iterator, Accessor, Pivot>(accessor, begin, below, end, above, key, out);
}
+// May return begin - 1.
+template <class Iterator, class Accessor> Iterator BinaryBelow(
+ const Accessor &accessor,
+ Iterator begin,
+ Iterator end,
+ const typename Accessor::Key key) {
+ while (end > begin) {
+ Iterator pivot(begin + (end - begin) / 2);
+ typename Accessor::Key mid(accessor(pivot));
+ if (mid < key) {
+ begin = pivot + 1;
+ } else if (mid > key) {
+ end = pivot;
+ } else {
+ for (++pivot; (pivot < end) && accessor(pivot) == mid; ++pivot) {}
+ return pivot - 1;
+ }
+ }
+ return begin - 1;
+}
+
// To use this template, you need to define a Pivot function to match Key.
template <class PackingT> class SortedUniformMap {
public: