summaryrefslogtreecommitdiff
path: root/klm
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2013-06-18 11:34:20 -0700
committerKenneth Heafield <github@kheafield.com>2013-06-18 11:34:20 -0700
commit535d4016ec5179cb673b697c2e81500a2097924c (patch)
tree4ae43b02d23317f37017a93fd12552b55c8d2a06 /klm
parent5dc790adc222db09c25b8be1b7a443a142f70180 (diff)
lazy dd880b4 including kenlm 6eef0f1
Diffstat (limited to 'klm')
-rw-r--r--klm/lm/builder/lmplz_main.cc15
-rw-r--r--klm/lm/builder/ngram.hh2
-rw-r--r--klm/lm/model.cc21
-rw-r--r--klm/lm/model.hh5
-rw-r--r--klm/lm/search_hashed.cc29
-rw-r--r--klm/lm/search_hashed.hh19
-rw-r--r--klm/lm/state.hh2
-rw-r--r--klm/lm/virtual_interface.hh3
-rw-r--r--klm/lm/vocab.hh2
-rw-r--r--klm/search/Makefile.am17
-rw-r--r--klm/search/context.hh12
-rw-r--r--klm/search/edge_generator.cc12
-rw-r--r--klm/search/vertex.cc204
-rw-r--r--klm/search/vertex.hh121
-rw-r--r--klm/search/vertex_generator.hh36
-rw-r--r--klm/util/double-conversion/utils.h6
-rw-r--r--klm/util/file.cc14
-rw-r--r--klm/util/pool.cc4
-rw-r--r--klm/util/probing_hash_table.hh23
-rw-r--r--klm/util/proxy_iterator.hh25
-rw-r--r--klm/util/sized_iterator.hh21
-rw-r--r--klm/util/stream/chain.hh2
-rw-r--r--klm/util/usage.cc15
23 files changed, 418 insertions, 192 deletions
diff --git a/klm/lm/builder/lmplz_main.cc b/klm/lm/builder/lmplz_main.cc
index 1e086dcc..c87abdb8 100644
--- a/klm/lm/builder/lmplz_main.cc
+++ b/klm/lm/builder/lmplz_main.cc
@@ -52,13 +52,14 @@ int main(int argc, char *argv[]) {
std::cerr <<
"Builds unpruned language models with modified Kneser-Ney smoothing.\n\n"
"Please cite:\n"
- "@inproceedings{kenlm,\n"
- "author = {Kenneth Heafield},\n"
- "title = {{KenLM}: Faster and Smaller Language Model Queries},\n"
- "booktitle = {Proceedings of the Sixth Workshop on Statistical Machine Translation},\n"
- "month = {July}, year={2011},\n"
- "address = {Edinburgh, UK},\n"
- "publisher = {Association for Computational Linguistics},\n"
+ "@inproceedings{Heafield-estimate,\n"
+ " author = {Kenneth Heafield and Ivan Pouzyrevsky and Jonathan H. Clark and Philipp Koehn},\n"
+ " title = {Scalable Modified {Kneser-Ney} Language Model Estimation},\n"
+ " year = {2013},\n"
+ " month = {8},\n"
+ " booktitle = {Proceedings of the 51st Annual Meeting of the Association for Computational Linguistics},\n"
+ " address = {Sofia, Bulgaria},\n"
+ " url = {http://kheafield.com/professional/edinburgh/estimate\\_paper.pdf},\n"
"}\n\n"
"Provide the corpus on stdin. The ARPA file will be written to stdout. Order of\n"
"the model (-o) is the only mandatory option. As this is an on-disk program,\n"
diff --git a/klm/lm/builder/ngram.hh b/klm/lm/builder/ngram.hh
index 2984ed0b..f5681516 100644
--- a/klm/lm/builder/ngram.hh
+++ b/klm/lm/builder/ngram.hh
@@ -53,7 +53,7 @@ class NGram {
Payload &Value() { return *reinterpret_cast<Payload *>(end_); }
uint64_t &Count() { return Value().count; }
- const uint64_t Count() const { return Value().count; }
+ uint64_t Count() const { return Value().count; }
std::size_t Order() const { return end_ - begin_; }
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index a40fd2fb..a26654a6 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -304,5 +304,26 @@ template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiks
template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary>;
} // namespace detail
+
+base::Model *LoadVirtual(const char *file_name, const Config &config, ModelType model_type) {
+ RecognizeBinary(file_name, model_type);
+ switch (model_type) {
+ case PROBING:
+ return new ProbingModel(file_name, config);
+ case REST_PROBING:
+ return new RestProbingModel(file_name, config);
+ case TRIE:
+ return new TrieModel(file_name, config);
+ case QUANT_TRIE:
+ return new QuantTrieModel(file_name, config);
+ case ARRAY_TRIE:
+ return new ArrayTrieModel(file_name, config);
+ case QUANT_ARRAY_TRIE:
+ return new QuantArrayTrieModel(file_name, config);
+ default:
+ UTIL_THROW(FormatLoadException, "Confused by model type " << model_type);
+ }
+}
+
} // namespace ngram
} // namespace lm
diff --git a/klm/lm/model.hh b/klm/lm/model.hh
index 13ff864e..60f55110 100644
--- a/klm/lm/model.hh
+++ b/klm/lm/model.hh
@@ -153,6 +153,11 @@ LM_NAME_MODEL(QuantArrayTrieModel, detail::GenericModel<trie::TrieSearch<Separat
typedef ::lm::ngram::ProbingVocabulary Vocabulary;
typedef ProbingModel Model;
+/* Autorecognize the file type, load, and return the virtual base class. Don't
+ * use the virtual base class if you can avoid it. Instead, use the above
+ * classes as template arguments to your own virtual feature function.*/
+base::Model *LoadVirtual(const char *file_name, const Config &config = Config(), ModelType if_arpa = PROBING);
+
} // namespace ngram
} // namespace lm
diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc
index 2d6f15b2..62275d27 100644
--- a/klm/lm/search_hashed.cc
+++ b/klm/lm/search_hashed.cc
@@ -54,7 +54,7 @@ template <class Weights> class ActivateUnigram {
Weights *modify_;
};
-// Find the lower order entry, inserting blanks along the way as necessary.
+// Find the lower order entry, inserting blanks along the way as necessary.
template <class Value> void FindLower(
const std::vector<uint64_t> &keys,
typename Value::Weights &unigram,
@@ -64,7 +64,7 @@ template <class Value> void FindLower(
typename Value::ProbingEntry entry;
// Backoff will always be 0.0. We'll get the probability and rest in another pass.
entry.value.backoff = kNoExtensionBackoff;
- // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb.
+ // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb.
for (int lower = keys.size() - 2; ; --lower) {
if (lower == -1) {
between.push_back(&unigram);
@@ -77,11 +77,11 @@ template <class Value> void FindLower(
}
}
-// Between usually has single entry, the value to adjust. But sometimes SRI stupidly pruned entries so it has unitialized blank values to be set here.
+// Between usually has single entry, the value to adjust. But sometimes SRI stupidly pruned entries so it has uninitialized blank values to be set here.
template <class Added, class Build> void AdjustLower(
const Added &added,
const Build &build,
- std::vector<typename Build::Value::Weights *> &between,
+ std::vector<typename Build::Value::Weights *> &between,
const unsigned int n,
const std::vector<WordIndex> &vocab_ids,
typename Build::Value::Weights *unigrams,
@@ -93,14 +93,14 @@ template <class Added, class Build> void AdjustLower(
}
typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle;
float prob = -fabs(between.back()->prob);
- // Order of the n-gram on which probabilities are based.
+ // Order of the n-gram on which probabilities are based.
unsigned char basis = n - between.size();
assert(basis != 0);
typename Build::Value::Weights **change = &between.back();
// Skip the basis.
--change;
if (basis == 1) {
- // Hallucinate a bigram based on a unigram's backoff and a unigram probability.
+ // Hallucinate a bigram based on a unigram's backoff and a unigram probability.
float &backoff = unigrams[vocab_ids[1]].backoff;
SetExtension(backoff);
prob += backoff;
@@ -128,14 +128,14 @@ template <class Added, class Build> void AdjustLower(
typename std::vector<typename Value::Weights *>::const_iterator i(between.begin());
build.MarkExtends(**i, added);
const typename Value::Weights *longer = *i;
- // Everything has probability but is not marked as extending.
+ // Everything has probability but is not marked as extending.
for (++i; i != between.end(); ++i) {
build.MarkExtends(**i, *longer);
longer = *i;
}
}
-// Continue marking lower entries even they know that they extend left. This is used for upper/lower bounds.
+// Continue marking lower entries even if they know that they extend left. This is used for upper/lower bounds.
template <class Build> void MarkLower(
const std::vector<uint64_t> &keys,
const Build &build,
@@ -144,15 +144,15 @@ template <class Build> void MarkLower(
int start_order,
const typename Build::Value::Weights &longer) {
if (start_order == 0) return;
- typename util::ProbingHashTable<typename Build::Value::ProbingEntry, util::IdentityHash>::MutableIterator iter;
- // Hopefully the compiler will realize that if MarkExtends always returns false, it can simplify this code.
+ // Hopefully the compiler will realize that if MarkExtends always returns false, it can simplify this code.
for (int even_lower = start_order - 2 /* index in middle */; ; --even_lower) {
if (even_lower == -1) {
build.MarkExtends(unigram, longer);
return;
}
- middle[even_lower].UnsafeMutableFind(keys[even_lower], iter);
- if (!build.MarkExtends(iter->value, longer)) return;
+ if (!build.MarkExtends(
+ middle[even_lower].UnsafeMutableMustFind(keys[even_lower])->value,
+ longer)) return;
}
}
@@ -168,7 +168,6 @@ template <class Build, class Activate, class Store> void ReadNGrams(
Store &store,
PositiveProbWarn &warn) {
typedef typename Build::Value Value;
- typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle;
assert(n >= 2);
ReadNGramHeader(f, n);
@@ -186,7 +185,7 @@ template <class Build, class Activate, class Store> void ReadNGrams(
for (unsigned int h = 1; h < n - 1; ++h) {
keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]);
}
- // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0.
+ // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might be +0.0.
util::SetSign(entry.value.prob);
entry.key = keys[n-2];
@@ -203,7 +202,7 @@ template <class Build, class Activate, class Store> void ReadNGrams(
} // namespace
namespace detail {
-
+
template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
std::size_t allocated = Unigram::Size(counts[0]);
unigram_ = Unigram(start, counts[0], allocated);
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh
index 00595796..9d067bc2 100644
--- a/klm/lm/search_hashed.hh
+++ b/klm/lm/search_hashed.hh
@@ -71,7 +71,7 @@ template <class Value> class HashedSearch {
static const bool kDifferentRest = Value::kDifferentRest;
static const unsigned int kVersion = 0;
- // TODO: move probing_multiplier here with next binary file format update.
+ // TODO: move probing_multiplier here with next binary file format update.
static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
@@ -102,14 +102,9 @@ template <class Value> class HashedSearch {
return ret;
}
-#pragma GCC diagnostic ignored "-Wuninitialized"
MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const {
node = extend_pointer;
- typename Middle::ConstIterator found;
- bool got = middle_[extend_length - 2].Find(extend_pointer, found);
- assert(got);
- (void)got;
- return MiddlePointer(found->value);
+ return MiddlePointer(middle_[extend_length - 2].MustFind(extend_pointer)->value);
}
MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_pointer) const {
@@ -126,14 +121,14 @@ template <class Value> class HashedSearch {
}
LongestPointer LookupLongest(WordIndex word, const Node &node) const {
- // Sign bit is always on because longest n-grams do not extend left.
+ // Sign bit is always on because longest n-grams do not extend left.
typename Longest::ConstIterator found;
if (!longest_.Find(CombineWordHash(node, word), found)) return LongestPointer();
return LongestPointer(found->value.prob);
}
- // Generate a node without necessarily checking that it actually exists.
- // Optionally return false if it's know to not exist.
+ // Generate a node without necessarily checking that it actually exists.
+ // Optionally return false if it's known to not exist.
bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
assert(begin != end);
node = static_cast<Node>(*begin);
@@ -144,7 +139,7 @@ template <class Value> class HashedSearch {
}
private:
- // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild.
+ // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild.
void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn);
template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build);
@@ -153,7 +148,7 @@ template <class Value> class HashedSearch {
public:
Unigram() {}
- Unigram(void *start, uint64_t count, std::size_t /*allocated*/) :
+ Unigram(void *start, uint64_t count, std::size_t /*allocated*/) :
unigram_(static_cast<typename Value::Weights*>(start))
#ifdef DEBUG
, count_(count)
diff --git a/klm/lm/state.hh b/klm/lm/state.hh
index d8e6c132..a6b9accb 100644
--- a/klm/lm/state.hh
+++ b/klm/lm/state.hh
@@ -91,7 +91,7 @@ inline uint64_t hash_value(const Left &left) {
}
struct ChartState {
- bool operator==(const ChartState &other) {
+ bool operator==(const ChartState &other) const {
return (right == other.right) && (left == other.left);
}
diff --git a/klm/lm/virtual_interface.hh b/klm/lm/virtual_interface.hh
index 6a5a0196..17f064b2 100644
--- a/klm/lm/virtual_interface.hh
+++ b/klm/lm/virtual_interface.hh
@@ -6,6 +6,7 @@
#include "util/string_piece.hh"
#include <string>
+#include <string.h>
namespace lm {
namespace base {
@@ -119,7 +120,9 @@ class Model {
size_t StateSize() const { return state_size_; }
const void *BeginSentenceMemory() const { return begin_sentence_memory_; }
+ void BeginSentenceWrite(void *to) const { memcpy(to, begin_sentence_memory_, StateSize()); }
const void *NullContextMemory() const { return null_context_memory_; }
+ void NullContextWrite(void *to) const { memcpy(to, null_context_memory_, StateSize()); }
// Requires in_state != out_state
virtual float Score(const void *in_state, const WordIndex new_word, void *out_state) const = 0;
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
index 3902f117..226ae438 100644
--- a/klm/lm/vocab.hh
+++ b/klm/lm/vocab.hh
@@ -25,7 +25,7 @@ uint64_t HashForVocab(const char *str, std::size_t len);
inline uint64_t HashForVocab(const StringPiece &str) {
return HashForVocab(str.data(), str.length());
}
-class ProbingVocabularyHeader;
+struct ProbingVocabularyHeader;
} // namespace detail
class WriteWordsWrapper : public EnumerateVocab {
diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am
index 03554276..b8c8a050 100644
--- a/klm/search/Makefile.am
+++ b/klm/search/Makefile.am
@@ -1,23 +1,10 @@
noinst_LIBRARIES = libksearch.a
libksearch_a_SOURCES = \
- applied.hh \
- config.hh \
- context.hh \
- dedupe.hh \
- edge.hh \
- edge_generator.hh \
- header.hh \
- nbest.hh \
- rule.hh \
- types.hh \
- vertex.hh \
- vertex_generator.hh \
edge_generator.cc \
nbest.cc \
rule.cc \
- vertex.cc \
- vertex_generator.cc
+ vertex.cc
-AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/klm
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I..
diff --git a/klm/search/context.hh b/klm/search/context.hh
index 08f21bbf..c3c8e53b 100644
--- a/klm/search/context.hh
+++ b/klm/search/context.hh
@@ -12,16 +12,6 @@ class ContextBase {
public:
explicit ContextBase(const Config &config) : config_(config) {}
- VertexNode *NewVertexNode() {
- VertexNode *ret = vertex_node_pool_.construct();
- assert(ret);
- return ret;
- }
-
- void DeleteVertexNode(VertexNode *node) {
- vertex_node_pool_.destroy(node);
- }
-
unsigned int PopLimit() const { return config_.PopLimit(); }
Score LMWeight() const { return config_.LMWeight(); }
@@ -29,8 +19,6 @@ class ContextBase {
const Config &GetConfig() const { return config_; }
private:
- boost::object_pool<VertexNode> vertex_node_pool_;
-
Config config_;
};
diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc
index eacf5de5..dd9d61e4 100644
--- a/klm/search/edge_generator.cc
+++ b/klm/search/edge_generator.cc
@@ -54,20 +54,20 @@ template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) {
Arity victim = 0;
Arity victim_completed;
Arity incomplete;
+ unsigned char lowest_niceness = 255;
// Select victim or return if complete.
{
Arity completed = 0;
- unsigned char lowest_length = 255;
for (Arity i = 0; i != arity; ++i) {
if (top_nt[i].Complete()) {
++completed;
- } else if (top_nt[i].Length() < lowest_length) {
- lowest_length = top_nt[i].Length();
+ } else if (top_nt[i].Niceness() < lowest_niceness) {
+ lowest_niceness = top_nt[i].Niceness();
victim = i;
victim_completed = completed;
}
}
- if (lowest_length == 255) {
+ if (lowest_niceness == 255) {
return top;
}
incomplete = arity - completed;
@@ -92,10 +92,14 @@ template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) {
generate_.push(alternate);
}
+#ifndef NDEBUG
+ Score before = top.GetScore();
+#endif
// top is now the continuation.
FastScore(context, victim, victim - victim_completed, incomplete, old_value, top);
// TODO: dedupe?
generate_.push(top);
+ assert(lowest_niceness != 254 || top.GetScore() == before);
// Invalid indicates no new hypothesis generated.
return PartialEdge();
diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc
index 45842982..bf40810e 100644
--- a/klm/search/vertex.cc
+++ b/klm/search/vertex.cc
@@ -2,6 +2,8 @@
#include "search/context.hh"
+#include <boost/unordered_map.hpp>
+
#include <algorithm>
#include <functional>
@@ -11,45 +13,193 @@ namespace search {
namespace {
-struct GreaterByBound : public std::binary_function<const VertexNode *, const VertexNode *, bool> {
- bool operator()(const VertexNode *first, const VertexNode *second) const {
- return first->Bound() > second->Bound();
+const uint64_t kCompleteAdd = static_cast<uint64_t>(-1);
+
+class DivideLeft {
+ public:
+ explicit DivideLeft(unsigned char index)
+ : index_(index) {}
+
+ uint64_t operator()(const lm::ngram::ChartState &state) const {
+ return (index_ < state.left.length) ?
+ state.left.pointers[index_] :
+ (kCompleteAdd - state.left.full);
+ }
+
+ private:
+ unsigned char index_;
+};
+
+class DivideRight {
+ public:
+ explicit DivideRight(unsigned char index)
+ : index_(index) {}
+
+ uint64_t operator()(const lm::ngram::ChartState &state) const {
+ return (index_ < state.right.length) ?
+ static_cast<uint64_t>(state.right.words[index_]) :
+ (kCompleteAdd - state.left.full);
+ }
+
+ private:
+ unsigned char index_;
+};
+
+template <class Divider> void Split(const Divider &divider, const std::vector<HypoState> &hypos, std::vector<VertexNode> &extend) {
+ // Map from divider to index in extend.
+ typedef boost::unordered_map<uint64_t, std::size_t> Lookup;
+ Lookup lookup;
+ for (std::vector<HypoState>::const_iterator i = hypos.begin(); i != hypos.end(); ++i) {
+ uint64_t key = divider(i->state);
+ std::pair<Lookup::iterator, bool> res(lookup.insert(std::make_pair(key, extend.size())));
+ if (res.second) {
+ extend.resize(extend.size() + 1);
+ extend.back().AppendHypothesis(*i);
+ } else {
+ extend[res.first->second].AppendHypothesis(*i);
+ }
}
+ //assert((extend.size() != 1) || (hypos.size() == 1));
+}
+
+lm::WordIndex Identify(const lm::ngram::Right &right, unsigned char index) {
+ return right.words[index];
+}
+
+uint64_t Identify(const lm::ngram::Left &left, unsigned char index) {
+ return left.pointers[index];
+}
+
+template <class Side> class DetermineSame {
+ public:
+ DetermineSame(const Side &side, unsigned char guaranteed)
+ : side_(side), guaranteed_(guaranteed), shared_(side.length), complete_(true) {}
+
+ void Consider(const Side &other) {
+ if (shared_ != other.length) {
+ complete_ = false;
+ if (shared_ > other.length)
+ shared_ = other.length;
+ }
+ for (unsigned char i = guaranteed_; i < shared_; ++i) {
+ if (Identify(side_, i) != Identify(other, i)) {
+ shared_ = i;
+ complete_ = false;
+ return;
+ }
+ }
+ }
+
+ unsigned char Shared() const { return shared_; }
+
+ bool Complete() const { return complete_; }
+
+ private:
+ const Side &side_;
+ unsigned char guaranteed_, shared_;
+ bool complete_;
};
+// Custom enum to save memory: valid values of policy_.
+// Alternate and there is still alternation to do.
+const unsigned char kPolicyAlternate = 0;
+// Branch based on left state only, because right ran out or this is a left tree.
+const unsigned char kPolicyOneLeft = 1;
+// Branch based on right state only.
+const unsigned char kPolicyOneRight = 2;
+// Reveal everything in the next branch. Used to terminate the left/right policies.
+// static const unsigned char kPolicyEverything = 3;
+
+} // namespace
+
+namespace {
+struct GreaterByScore : public std::binary_function<const HypoState &, const HypoState &, bool> {
+ bool operator()(const HypoState &first, const HypoState &second) const {
+ return first.score > second.score;
+ }
+};
} // namespace
-void VertexNode::RecursiveSortAndSet(ContextBase &context, VertexNode *&parent_ptr) {
- if (Complete()) {
- assert(end_);
- assert(extend_.empty());
- return;
+void VertexNode::FinishRoot() {
+ std::sort(hypos_.begin(), hypos_.end(), GreaterByScore());
+ extend_.clear();
+ // HACK: extend to one hypo so that root can be blank.
+ state_.left.full = false;
+ state_.left.length = 0;
+ state_.right.length = 0;
+ right_full_ = false;
+ niceness_ = 0;
+ policy_ = kPolicyAlternate;
+ if (hypos_.size() == 1) {
+ extend_.resize(1);
+ extend_.front().AppendHypothesis(hypos_.front());
+ extend_.front().FinishedAppending(0, 0);
+ }
+ if (hypos_.empty()) {
+ bound_ = -INFINITY;
+ } else {
+ bound_ = hypos_.front().score;
}
- if (extend_.size() == 1) {
- parent_ptr = extend_[0];
- extend_[0]->RecursiveSortAndSet(context, parent_ptr);
- context.DeleteVertexNode(this);
- return;
+}
+
+void VertexNode::FinishedAppending(const unsigned char common_left, const unsigned char common_right) {
+ assert(!hypos_.empty());
+ assert(extend_.empty());
+ bound_ = hypos_.front().score;
+ state_ = hypos_.front().state;
+ bool all_full = state_.left.full;
+ bool all_non_full = !state_.left.full;
+ DetermineSame<lm::ngram::Left> left(state_.left, common_left);
+ DetermineSame<lm::ngram::Right> right(state_.right, common_right);
+ for (std::vector<HypoState>::const_iterator i = hypos_.begin() + 1; i != hypos_.end(); ++i) {
+ all_full &= i->state.left.full;
+ all_non_full &= !i->state.left.full;
+ left.Consider(i->state.left);
+ right.Consider(i->state.right);
}
- for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
- (*i)->RecursiveSortAndSet(context, *i);
+ state_.left.full = all_full && left.Complete();
+ right_full_ = all_full && right.Complete();
+ state_.left.length = left.Shared();
+ state_.right.length = right.Shared();
+
+ if (!all_full && !all_non_full) {
+ policy_ = kPolicyAlternate;
+ } else if (left.Complete()) {
+ policy_ = kPolicyOneRight;
+ } else if (right.Complete()) {
+ policy_ = kPolicyOneLeft;
+ } else {
+ policy_ = kPolicyAlternate;
}
- std::sort(extend_.begin(), extend_.end(), GreaterByBound());
- bound_ = extend_.front()->Bound();
+ niceness_ = state_.left.length + state_.right.length;
}
-void VertexNode::SortAndSet(ContextBase &context) {
- // This is the root. The root might be empty.
- if (extend_.empty()) {
- bound_ = -INFINITY;
- return;
+void VertexNode::BuildExtend() {
+ // Already built.
+ if (!extend_.empty()) return;
+ // Nothing to build since this is a leaf.
+ if (hypos_.size() <= 1) return;
+ bool left_branch = true;
+ switch (policy_) {
+ case kPolicyAlternate:
+ left_branch = (state_.left.length <= state_.right.length);
+ break;
+ case kPolicyOneLeft:
+ left_branch = true;
+ break;
+ case kPolicyOneRight:
+ left_branch = false;
+ break;
+ }
+ if (left_branch) {
+ Split(DivideLeft(state_.left.length), hypos_, extend_);
+ } else {
+ Split(DivideRight(state_.right.length), hypos_, extend_);
}
- // The root cannot be replaced. There's always one transition.
- for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
- (*i)->RecursiveSortAndSet(context, *i);
+ for (std::vector<VertexNode>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
+ // TODO: provide more here for branching?
+ i->FinishedAppending(state_.left.length, state_.right.length);
}
- std::sort(extend_.begin(), extend_.end(), GreaterByBound());
- bound_ = extend_.front()->Bound();
}
} // namespace search
diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh
index ca9a4fcd..81c3cfed 100644
--- a/klm/search/vertex.hh
+++ b/klm/search/vertex.hh
@@ -16,59 +16,74 @@ namespace search {
class ContextBase;
+struct HypoState {
+ History history;
+ lm::ngram::ChartState state;
+ Score score;
+};
+
class VertexNode {
public:
- VertexNode() : end_() {}
-
- void InitRoot() {
- extend_.clear();
- state_.left.full = false;
- state_.left.length = 0;
- state_.right.length = 0;
- right_full_ = false;
- end_ = History();
+ VertexNode() {}
+
+ void InitRoot() { hypos_.clear(); }
+
+ /* The steps of building a VertexNode:
+ * 1. Default construct.
+ * 2. AppendHypothesis at least once, possibly multiple times.
+ * 3. FinishAppending with the number of words on left and right guaranteed
+ * to be common.
+ * 4. If !Complete(), call BuildExtend to construct the extensions
+ */
+ // Must default construct, call AppendHypothesis 1 or more times then do FinishedAppending.
+ void AppendHypothesis(const NBestComplete &best) {
+ assert(hypos_.empty() || !(hypos_.front().state == *best.state));
+ HypoState hypo;
+ hypo.history = best.history;
+ hypo.state = *best.state;
+ hypo.score = best.score;
+ hypos_.push_back(hypo);
+ }
+ void AppendHypothesis(const HypoState &hypo) {
+ hypos_.push_back(hypo);
}
- lm::ngram::ChartState &MutableState() { return state_; }
- bool &MutableRightFull() { return right_full_; }
+ // Sort hypotheses for the root.
+ void FinishRoot();
- void AddExtend(VertexNode *next) {
- extend_.push_back(next);
- }
+ void FinishedAppending(const unsigned char common_left, const unsigned char common_right);
- void SetEnd(History end, Score score) {
- assert(!end_);
- end_ = end;
- bound_ = score;
- }
-
- void SortAndSet(ContextBase &context);
+ void BuildExtend();
// Should only happen to a root node when the entire vertex is empty.
bool Empty() const {
- return !end_ && extend_.empty();
+ return hypos_.empty() && extend_.empty();
}
bool Complete() const {
- return end_;
+ // HACK: prevent root from being complete. TODO: allow root to be complete.
+ return hypos_.size() == 1 && extend_.empty();
}
const lm::ngram::ChartState &State() const { return state_; }
bool RightFull() const { return right_full_; }
+ // Priority relative to other non-terminals. 0 is highest.
+ unsigned char Niceness() const { return niceness_; }
+
Score Bound() const {
return bound_;
}
- unsigned char Length() const {
- return state_.left.length + state_.right.length;
- }
-
// Will be invalid unless this is a leaf.
- History End() const { return end_; }
+ History End() const {
+ assert(hypos_.size() == 1);
+ return hypos_.front().history;
+ }
- const VertexNode &operator[](size_t index) const {
- return *extend_[index];
+ VertexNode &operator[](size_t index) {
+ assert(!extend_.empty());
+ return extend_[index];
}
size_t Size() const {
@@ -76,22 +91,26 @@ class VertexNode {
}
private:
- void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent);
+ // Hypotheses to be split.
+ std::vector<HypoState> hypos_;
- std::vector<VertexNode*> extend_;
+ std::vector<VertexNode> extend_;
lm::ngram::ChartState state_;
bool right_full_;
+ unsigned char niceness_;
+
+ unsigned char policy_;
+
Score bound_;
- History end_;
};
class PartialVertex {
public:
PartialVertex() {}
- explicit PartialVertex(const VertexNode &back) : back_(&back), index_(0) {}
+ explicit PartialVertex(VertexNode &back) : back_(&back), index_(0) {}
bool Empty() const { return back_->Empty(); }
@@ -100,17 +119,14 @@ class PartialVertex {
const lm::ngram::ChartState &State() const { return back_->State(); }
bool RightFull() const { return back_->RightFull(); }
- Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); }
-
- unsigned char Length() const { return back_->Length(); }
+ Score Bound() const { return index_ ? (*back_)[index_].Bound() : back_->Bound(); }
- bool HasAlternative() const {
- return index_ + 1 < back_->Size();
- }
+ unsigned char Niceness() const { return back_->Niceness(); }
// Split into continuation and alternative, rendering this the continuation.
bool Split(PartialVertex &alternative) {
assert(!Complete());
+ back_->BuildExtend();
bool ret;
if (index_ + 1 < back_->Size()) {
alternative.index_ = index_ + 1;
@@ -129,7 +145,7 @@ class PartialVertex {
}
private:
- const VertexNode *back_;
+ VertexNode *back_;
unsigned int index_;
};
@@ -139,10 +155,21 @@ class Vertex {
public:
Vertex() {}
- PartialVertex RootPartial() const { return PartialVertex(root_); }
+ //PartialVertex RootFirst() const { return PartialVertex(right_); }
+ PartialVertex RootAlternate() { return PartialVertex(root_); }
+ //PartialVertex RootLast() const { return PartialVertex(left_); }
+
+ bool Empty() const {
+ return root_.Empty();
+ }
+
+ Score Bound() const {
+ return root_.Bound();
+ }
- History BestChild() const {
- PartialVertex top(RootPartial());
+ History BestChild() {
+ // left_ and right_ are not set at the root.
+ PartialVertex top(RootAlternate());
if (top.Empty()) {
return History();
} else {
@@ -158,6 +185,12 @@ class Vertex {
template <class Output> friend class VertexGenerator;
template <class Output> friend class RootVertexGenerator;
VertexNode root_;
+
+ // These will not be set for the root vertex.
+ // Branches only on left state.
+ //VertexNode left_;
+ // Branches only on right state.
+ //VertexNode right_;
};
} // namespace search
diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh
index 646b8189..91000012 100644
--- a/klm/search/vertex_generator.hh
+++ b/klm/search/vertex_generator.hh
@@ -4,10 +4,8 @@
#include "search/edge.hh"
#include "search/types.hh"
#include "search/vertex.hh"
-#include "util/exception.hh"
#include <boost/unordered_map.hpp>
-#include <boost/version.hpp>
namespace lm {
namespace ngram {
@@ -19,45 +17,25 @@ namespace search {
class ContextBase;
-#if BOOST_VERSION > 104200
-// Parallel structure to VertexNode.
-struct Trie {
- Trie() : under(NULL) {}
-
- VertexNode *under;
- boost::unordered_map<uint64_t, Trie> extend;
-};
-
-void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end);
-
-#endif // BOOST_VERSION
-
// Output makes the single-best or n-best list.
template <class Output> class VertexGenerator {
public:
- VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) {
- gen.root_.InitRoot();
- }
+ VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) {}
void NewHypothesis(PartialEdge partial) {
nbest_.Add(existing_[hash_value(partial.CompletedState())], partial);
}
void FinishedSearch() {
-#if BOOST_VERSION > 104200
- Trie root;
- root.under = &gen_.root_;
+ gen_.root_.InitRoot();
for (typename Existing::iterator i(existing_.begin()); i != existing_.end(); ++i) {
- AddHypothesis(context_, root, nbest_.Complete(i->second));
+ gen_.root_.AppendHypothesis(nbest_.Complete(i->second));
}
existing_.clear();
- root.under->SortAndSet(context_);
-#else
- UTIL_THROW(util::Exception, "Upgrade Boost to >= 1.42.0 to use incremental search.");
-#endif
+ gen_.root_.FinishRoot();
}
- const Vertex &Generating() const { return gen_; }
+ Vertex &Generating() { return gen_; }
private:
ContextBase &context_;
@@ -84,8 +62,8 @@ template <class Output> class RootVertexGenerator {
void FinishedSearch() {
gen_.root_.InitRoot();
- NBestComplete completed(out_.Complete(combine_));
- gen_.root_.SetEnd(completed.history, completed.score);
+ gen_.root_.AppendHypothesis(out_.Complete(combine_));
+ gen_.root_.FinishRoot();
}
private:
diff --git a/klm/util/double-conversion/utils.h b/klm/util/double-conversion/utils.h
index 2bd71605..9ccb3b65 100644
--- a/klm/util/double-conversion/utils.h
+++ b/klm/util/double-conversion/utils.h
@@ -299,7 +299,11 @@ template <class Dest, class Source>
inline Dest BitCast(const Source& source) {
// Compile time assertion: sizeof(Dest) == sizeof(Source)
// A compile error here means your Dest and Source have different sizes.
- typedef char VerifySizesAreEqual[sizeof(Dest) == sizeof(Source) ? 1 : -1];
+ typedef char VerifySizesAreEqual[sizeof(Dest) == sizeof(Source) ? 1 : -1]
+#if __GNUC__ > 4 || __GNUC__ == 4 && __GNUC_MINOR__ >= 8
+ __attribute__((unused))
+#endif
+ ;
Dest dest;
memmove(&dest, &source, sizeof(dest));
diff --git a/klm/util/file.cc b/klm/util/file.cc
index c7d8e23b..bef04cb1 100644
--- a/klm/util/file.cc
+++ b/klm/util/file.cc
@@ -116,7 +116,7 @@ std::size_t GuardLarge(std::size_t size) {
// The following operating systems have broken read/write/pread/pwrite that
// only supports up to 2^31.
#if defined(_WIN32) || defined(_WIN64) || defined(__APPLE__) || defined(OS_ANDROID)
- return std::min(static_cast<std::size_t>(INT_MAX), size);
+ return std::min(static_cast<std::size_t>(static_cast<unsigned>(-1)), size);
#else
return size;
#endif
@@ -209,7 +209,7 @@ void WriteOrThrow(int fd, const void *data_void, std::size_t size) {
#endif
errno = 0;
do {
- ret =
+ ret =
#if defined(_WIN32) || defined(_WIN64)
_write
#else
@@ -229,7 +229,7 @@ void WriteOrThrow(FILE *to, const void *data, std::size_t size) {
}
void FSyncOrThrow(int fd) {
-// Apparently windows doesn't have fsync?
+// Apparently windows doesn't have fsync?
#if !defined(_WIN32) && !defined(_WIN64)
UTIL_THROW_IF_ARG(-1 == fsync(fd), FDException, (fd), "while syncing");
#endif
@@ -248,7 +248,7 @@ template <> struct CheckOffT<8> {
typedef CheckOffT<sizeof(off_t)>::True IgnoredType;
#endif
-// Can't we all just get along?
+// Can't we all just get along?
void InternalSeek(int fd, int64_t off, int whence) {
if (
#if defined(_WIN32) || defined(_WIN64)
@@ -457,9 +457,9 @@ bool TryName(int fd, std::string &out) {
std::ostringstream convert;
convert << fd;
name += convert.str();
-
+
struct stat sb;
- if (-1 == lstat(name.c_str(), &sb))
+ if (-1 == lstat(name.c_str(), &sb))
return false;
out.resize(sb.st_size + 1);
ssize_t ret = readlink(name.c_str(), &out[0], sb.st_size + 1);
@@ -471,7 +471,7 @@ bool TryName(int fd, std::string &out) {
}
out.resize(ret);
// Don't use the non-file names.
- if (!out.empty() && out[0] != '/')
+ if (!out.empty() && out[0] != '/')
return false;
return true;
#endif
diff --git a/klm/util/pool.cc b/klm/util/pool.cc
index 429ba158..db72a8ec 100644
--- a/klm/util/pool.cc
+++ b/klm/util/pool.cc
@@ -25,7 +25,9 @@ void Pool::FreeAll() {
}
void *Pool::More(std::size_t size) {
- std::size_t amount = std::max(static_cast<size_t>(32) << free_list_.size(), size);
+ // Double until we hit 2^21 (2 MB). Then grow in 2 MB blocks.
+ std::size_t desired_size = static_cast<size_t>(32) << std::min(static_cast<std::size_t>(16), free_list_.size());
+ std::size_t amount = std::max(desired_size, size);
uint8_t *ret = static_cast<uint8_t*>(MallocOrThrow(amount));
free_list_.push_back(ret);
current_ = ret + size;
diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh
index 57866ff9..51a2944d 100644
--- a/klm/util/probing_hash_table.hh
+++ b/klm/util/probing_hash_table.hh
@@ -109,9 +109,20 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
if (equal_(got, key)) { out = i; return true; }
if (equal_(got, invalid_)) return false;
if (++i == end_) i = begin_;
- }
+ }
+ }
+
+ // Like UnsafeMutableFind, but the key must be there.
+ template <class Key> MutableIterator UnsafeMutableMustFind(const Key key) {
+ for (MutableIterator i(begin_ + (hash_(key) % buckets_));;) {
+ Key got(i->GetKey());
+ if (equal_(got, key)) { return i; }
+ assert(!equal_(got, invalid_));
+ if (++i == end_) i = begin_;
+ }
}
+
template <class Key> bool Find(const Key key, ConstIterator &out) const {
#ifdef DEBUG
assert(initialized_);
@@ -124,6 +135,16 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
}
}
+ // Like Find but we're sure it must be there.
+ template <class Key> ConstIterator MustFind(const Key key) const {
+ for (ConstIterator i(begin_ + (hash_(key) % buckets_));;) {
+ Key got(i->GetKey());
+ if (equal_(got, key)) { return i; }
+ assert(!equal_(got, invalid_));
+ if (++i == end_) i = begin_;
+ }
+ }
+
void Clear() {
Entry invalid;
invalid.SetKey(invalid_);
diff --git a/klm/util/proxy_iterator.hh b/klm/util/proxy_iterator.hh
index 121a45fa..0ee1716f 100644
--- a/klm/util/proxy_iterator.hh
+++ b/klm/util/proxy_iterator.hh
@@ -6,11 +6,11 @@
/* This is a RandomAccessIterator that uses a proxy to access the underlying
* data. Useful for packing data at bit offsets but still using STL
- * algorithms.
+ * algorithms.
*
* Normally I would use boost::iterator_facade but some people are too lazy to
* install boost and still want to use my language model. It's amazing how
- * many operators an iterator has.
+ * many operators an iterator has.
*
* The Proxy needs to provide:
* class InnerIterator;
@@ -22,15 +22,15 @@
* operator<(InnerIterator)
* operator+=(std::ptrdiff_t)
* operator-(InnerIterator)
- * and of course whatever Proxy needs to dereference it.
+ * and of course whatever Proxy needs to dereference it.
*
- * It's also a good idea to specialize std::swap for Proxy.
+ * It's also a good idea to specialize std::swap for Proxy.
*/
namespace util {
template <class Proxy> class ProxyIterator {
private:
- // Self.
+ // Self.
typedef ProxyIterator<Proxy> S;
typedef typename Proxy::InnerIterator InnerIterator;
@@ -38,16 +38,21 @@ template <class Proxy> class ProxyIterator {
typedef std::random_access_iterator_tag iterator_category;
typedef typename Proxy::value_type value_type;
typedef std::ptrdiff_t difference_type;
- typedef Proxy reference;
+ typedef Proxy & reference;
typedef Proxy * pointer;
ProxyIterator() {}
- // For cast from non const to const.
+ // For cast from non const to const.
template <class AlternateProxy> ProxyIterator(const ProxyIterator<AlternateProxy> &in) : p_(*in) {}
explicit ProxyIterator(const Proxy &p) : p_(p) {}
- // p_'s operator= does value copying, but here we want iterator copying.
+ // p_'s swap does value swapping, but here we want iterator swapping
+ friend inline void swap(ProxyIterator<Proxy> &first, ProxyIterator<Proxy> &second) {
+ swap(first.I(), second.I());
+ }
+
+ // p_'s operator= does value copying, but here we want iterator copying.
S &operator=(const S &other) {
I() = other.I();
return *this;
@@ -72,8 +77,8 @@ template <class Proxy> class ProxyIterator {
std::ptrdiff_t operator-(const S &other) const { return I() - other.I(); }
- Proxy operator*() { return p_; }
- const Proxy operator*() const { return p_; }
+ Proxy &operator*() { return p_; }
+ const Proxy &operator*() const { return p_; }
Proxy *operator->() { return &p_; }
const Proxy *operator->() const { return &p_; }
Proxy operator[](std::ptrdiff_t amount) const { return *(*this + amount); }
diff --git a/klm/util/sized_iterator.hh b/klm/util/sized_iterator.hh
index cf998953..dce8f229 100644
--- a/klm/util/sized_iterator.hh
+++ b/klm/util/sized_iterator.hh
@@ -36,6 +36,11 @@ class SizedInnerIterator {
void *Data() { return ptr_; }
std::size_t EntrySize() const { return size_; }
+ friend inline void swap(SizedInnerIterator &first, SizedInnerIterator &second) {
+ std::swap(first.ptr_, second.ptr_);
+ std::swap(first.size_, second.size_);
+ }
+
private:
uint8_t *ptr_;
std::size_t size_;
@@ -64,9 +69,19 @@ class SizedProxy {
const void *Data() const { return inner_.Data(); }
void *Data() { return inner_.Data(); }
+ /**
+ // TODO: this (deep) swap was recently added. Why? If any std algorithms
+ // (heap sort etc.) are using swap, that's going to perform worse than using
+ // =. I'm not sure why we *want* a deep swap. If C++11 compilers are
+ // choosing between a move constructor and swap, then we'd better implement a
+ // (deep) move constructor. It may also be that this is moot, since I made
+ // ProxyIterator's reference type a real reference (Proxy &) and added a
+ // shallow ProxyIterator swap. (I need Ken or someone competent to judge
+ // whether that's correct also - let me know at graehl@gmail.com.)
+ */
friend void swap(SizedProxy &first, SizedProxy &second) {
std::swap_ranges(
- static_cast<char*>(first.inner_.Data()),
+ static_cast<char*>(first.inner_.Data()),
static_cast<char*>(first.inner_.Data()) + first.inner_.EntrySize(),
static_cast<char*>(second.inner_.Data()));
}
@@ -87,7 +102,7 @@ typedef ProxyIterator<SizedProxy> SizedIterator;
inline SizedIterator SizedIt(void *ptr, std::size_t size) { return SizedIterator(SizedProxy(ptr, size)); }
-// Useful wrapper for a comparison function i.e. sort.
+// Useful wrapper for a comparison function i.e. sort.
template <class Delegate, class Proxy = SizedProxy> class SizedCompare : public std::binary_function<const Proxy &, const Proxy &, bool> {
public:
explicit SizedCompare(const Delegate &delegate = Delegate()) : delegate_(delegate) {}
@@ -106,7 +121,7 @@ template <class Delegate, class Proxy = SizedProxy> class SizedCompare : public
}
const Delegate &GetDelegate() const { return delegate_; }
-
+
private:
const Delegate delegate_;
};
diff --git a/klm/util/stream/chain.hh b/klm/util/stream/chain.hh
index 154b9b33..0cc83a85 100644
--- a/klm/util/stream/chain.hh
+++ b/klm/util/stream/chain.hh
@@ -122,7 +122,7 @@ class Chain {
threads_.push_back(new Thread(Complete(), kRecycle));
}
- Chain &operator>>(const Recycler &recycle) {
+ Chain &operator>>(const Recycler &) {
CompleteLoop();
return *this;
}
diff --git a/klm/util/usage.cc b/klm/util/usage.cc
index 5fa3cc9a..8db375e1 100644
--- a/klm/util/usage.cc
+++ b/klm/util/usage.cc
@@ -21,6 +21,21 @@ namespace util {
#if !defined(_WIN32) && !defined(_WIN64)
namespace {
+
+// On Mac OS X, clock_gettime is not implemented.
+// CLOCK_MONOTONIC is not defined either.
+#ifdef __MACH__
+#define CLOCK_MONOTONIC 0
+
+int clock_gettime(int clk_id, struct timespec *tp) {
+ struct timeval tv;
+ gettimeofday(&tv, NULL);
+ tp->tv_sec = tv.tv_sec;
+ tp->tv_nsec = tv.tv_usec * 1000;
+ return 0;
+}
+#endif // __MACH__
+
// Convert a timeval (seconds + microseconds) into a single float number of
// seconds. Each field is narrowed to float first; the microsecond part is then
// scaled by a double constant, so the sum is formed in double and narrowed to
// float on return.
float FloatSec(const struct timeval &tv) {
  const float whole_seconds = static_cast<float>(tv.tv_sec);
  const double fractional_seconds = static_cast<float>(tv.tv_usec) / 1000000.0;
  return whole_seconds + fractional_seconds;
}