summaryrefslogtreecommitdiff
path: root/klm
diff options
context:
space:
mode:
Diffstat (limited to 'klm')
-rw-r--r--klm/alone/Jamfile4
-rw-r--r--klm/alone/assemble.cc76
-rw-r--r--klm/alone/assemble.hh21
-rw-r--r--klm/alone/graph.hh87
-rw-r--r--klm/alone/just_vocab.cc14
-rw-r--r--klm/alone/labeled_edge.hh30
-rw-r--r--klm/alone/main.cc85
-rw-r--r--klm/alone/read.cc118
-rw-r--r--klm/alone/read.hh29
-rw-r--r--klm/alone/threading.cc80
-rw-r--r--klm/alone/threading.hh129
-rw-r--r--klm/alone/vocab.cc19
-rw-r--r--klm/alone/vocab.hh34
-rw-r--r--klm/lm/model.cc2
-rw-r--r--klm/lm/vocab.cc4
-rw-r--r--klm/lm/vocab.hh2
-rw-r--r--klm/search/Jamfile2
-rw-r--r--klm/search/Makefile.am11
-rw-r--r--klm/search/arity.hh8
-rw-r--r--klm/search/context.hh10
-rw-r--r--klm/search/edge.hh53
-rw-r--r--klm/search/edge_generator.cc144
-rw-r--r--klm/search/edge_generator.hh49
-rw-r--r--klm/search/edge_queue.cc25
-rw-r--r--klm/search/edge_queue.hh73
-rw-r--r--klm/search/final.hh41
-rw-r--r--klm/search/header.hh57
-rw-r--r--klm/search/source.hh48
-rw-r--r--klm/search/types.hh8
-rw-r--r--klm/search/vertex.cc10
-rw-r--r--klm/search/vertex.hh55
-rw-r--r--klm/search/vertex_generator.cc97
-rw-r--r--klm/search/vertex_generator.hh33
-rw-r--r--klm/util/Makefile.am2
-rw-r--r--klm/util/ersatz_progress.hh2
-rw-r--r--klm/util/exception.hh2
-rw-r--r--klm/util/pool.cc35
-rw-r--r--klm/util/pool.hh45
-rw-r--r--klm/util/probing_hash_table.hh2
-rw-r--r--klm/util/string_piece.cc192
-rw-r--r--klm/util/tokenize_piece.hh12
41 files changed, 611 insertions, 1139 deletions
diff --git a/klm/alone/Jamfile b/klm/alone/Jamfile
deleted file mode 100644
index 2cc90c05..00000000
--- a/klm/alone/Jamfile
+++ /dev/null
@@ -1,4 +0,0 @@
-lib standalone : assemble.cc read.cc threading.cc vocab.cc ../lm//kenlm ../util//kenutil ../search//search : <include>.. : : <include>.. <library>../search//search <library>../lm//kenlm ;
-
-exe decode : main.cc standalone main.cc : <threading>multi:<library>..//boost_thread ;
-exe just_vocab : just_vocab.cc standalone : <threading>multi:<library>..//boost_thread ;
diff --git a/klm/alone/assemble.cc b/klm/alone/assemble.cc
deleted file mode 100644
index 2ae72ce9..00000000
--- a/klm/alone/assemble.cc
+++ /dev/null
@@ -1,76 +0,0 @@
-#include "alone/assemble.hh"
-
-#include "alone/labeled_edge.hh"
-#include "search/final.hh"
-
-#include <iostream>
-
-namespace alone {
-
-std::ostream &operator<<(std::ostream &o, const search::Final &final) {
- const std::vector<const std::string*> &words = static_cast<const LabeledEdge&>(final.From()).Words();
- if (words.empty()) return o;
- const search::Final *const *child = final.Children().data();
- std::vector<const std::string*>::const_iterator i(words.begin());
- for (; i != words.end() - 1; ++i) {
- if (*i) {
- o << **i << ' ';
- } else {
- o << **child << ' ';
- ++child;
- }
- }
-
- if (*i) {
- if (**i != "</s>") {
- o << **i;
- }
- } else {
- o << **child;
- }
-
- return o;
-}
-
-namespace {
-
-void MakeIndent(std::ostream &o, const char *indent_str, unsigned int level) {
- for (unsigned int i = 0; i < level; ++i)
- o << indent_str;
-}
-
-void DetailedFinalInternal(std::ostream &o, const search::Final &final, const char *indent_str, unsigned int indent) {
- o << "(\n";
- MakeIndent(o, indent_str, indent);
- const std::vector<const std::string*> &words = static_cast<const LabeledEdge&>(final.From()).Words();
- const search::Final *const *child = final.Children().data();
- for (std::vector<const std::string*>::const_iterator i(words.begin()); i != words.end(); ++i) {
- if (*i) {
- o << **i;
- if (i == words.end() - 1) {
- o << '\n';
- MakeIndent(o, indent_str, indent);
- } else {
- o << ' ';
- }
- } else {
- // One extra indent from the line we're currently on.
- o << indent_str;
- DetailedFinalInternal(o, **child, indent_str, indent + 1);
- for (unsigned int i = 0; i < indent; ++i) o << indent_str;
- ++child;
- }
- }
- o << ")=" << final.Bound() << '\n';
-}
-} // namespace
-
-void DetailedFinal(std::ostream &o, const search::Final &final, const char *indent_str) {
- DetailedFinalInternal(o, final, indent_str, 0);
-}
-
-void PrintFinal(const search::Final &final) {
- std::cout << final << std::endl;
-}
-
-} // namespace alone
diff --git a/klm/alone/assemble.hh b/klm/alone/assemble.hh
deleted file mode 100644
index e6b0ad5c..00000000
--- a/klm/alone/assemble.hh
+++ /dev/null
@@ -1,21 +0,0 @@
-#ifndef ALONE_ASSEMBLE__
-#define ALONE_ASSEMBLE__
-
-#include <iosfwd>
-
-namespace search {
-class Final;
-} // namespace search
-
-namespace alone {
-
-std::ostream &operator<<(std::ostream &o, const search::Final &final);
-
-void DetailedFinal(std::ostream &o, const search::Final &final, const char *indent_str = " ");
-
-// This isn't called anywhere but makes it easy to print from gdb.
-void PrintFinal(const search::Final &final);
-
-} // namespace alone
-
-#endif // ALONE_ASSEMBLE__
diff --git a/klm/alone/graph.hh b/klm/alone/graph.hh
deleted file mode 100644
index 788352c9..00000000
--- a/klm/alone/graph.hh
+++ /dev/null
@@ -1,87 +0,0 @@
-#ifndef ALONE_GRAPH__
-#define ALONE_GRAPH__
-
-#include "alone/labeled_edge.hh"
-#include "search/rule.hh"
-#include "search/types.hh"
-#include "search/vertex.hh"
-#include "util/exception.hh"
-
-#include <boost/noncopyable.hpp>
-#include <boost/pool/object_pool.hpp>
-#include <boost/scoped_array.hpp>
-
-namespace alone {
-
-template <class T> class FixedAllocator : boost::noncopyable {
- public:
- FixedAllocator() : current_(NULL), end_(NULL) {}
-
- void Init(std::size_t count) {
- assert(!current_);
- array_.reset(new T[count]);
- current_ = array_.get();
- end_ = current_ + count;
- }
-
- T &operator[](std::size_t idx) {
- return array_.get()[idx];
- }
-
- T *New() {
- T *ret = current_++;
- UTIL_THROW_IF(ret >= end_, util::Exception, "Allocating past end");
- return ret;
- }
-
- std::size_t Size() const {
- return end_ - array_.get();
- }
-
- private:
- boost::scoped_array<T> array_;
- T *current_, *end_;
-};
-
-class Graph : boost::noncopyable {
- public:
- typedef LabeledEdge Edge;
- typedef search::Vertex Vertex;
-
- Graph() {}
-
- void SetCounts(std::size_t vertices, std::size_t edges) {
- vertices_.Init(vertices);
- edges_.Init(edges);
- }
-
- Vertex *NewVertex() {
- return vertices_.New();
- }
-
- std::size_t VertexSize() const { return vertices_.Size(); }
-
- Vertex &MutableVertex(std::size_t index) {
- return vertices_[index];
- }
-
- Edge *NewEdge() {
- return edges_.New();
- }
-
- std::size_t EdgeSize() const { return edges_.Size(); }
-
- void SetRoot(Vertex *root) { root_ = root; }
-
- Vertex &Root() { return *root_; }
-
- private:
- FixedAllocator<Vertex> vertices_;
- FixedAllocator<Edge> edges_;
-
- Vertex *root_;
-};
-
-} // namespace alone
-
-#endif // ALONE_GRAPH__
diff --git a/klm/alone/just_vocab.cc b/klm/alone/just_vocab.cc
deleted file mode 100644
index 35aea5ed..00000000
--- a/klm/alone/just_vocab.cc
+++ /dev/null
@@ -1,14 +0,0 @@
-#include "alone/read.hh"
-#include "util/file_piece.hh"
-
-#include <iostream>
-
-int main() {
- util::FilePiece f(0, "stdin", &std::cerr);
- while (true) {
- try {
- alone::JustVocab(f, std::cout);
- } catch (const util::EndOfFileException &e) { break; }
- std::cout << '\n';
- }
-}
diff --git a/klm/alone/labeled_edge.hh b/klm/alone/labeled_edge.hh
deleted file mode 100644
index 94d8cbdf..00000000
--- a/klm/alone/labeled_edge.hh
+++ /dev/null
@@ -1,30 +0,0 @@
-#ifndef ALONE_LABELED_EDGE__
-#define ALONE_LABELED_EDGE__
-
-#include "search/edge.hh"
-
-#include <string>
-#include <vector>
-
-namespace alone {
-
-class LabeledEdge : public search::Edge {
- public:
- LabeledEdge() {}
-
- void AppendWord(const std::string *word) {
- words_.push_back(word);
- }
-
- const std::vector<const std::string *> &Words() const {
- return words_;
- }
-
- private:
- // NULL for non-terminals.
- std::vector<const std::string*> words_;
-};
-
-} // namespace alone
-
-#endif // ALONE_LABELED_EDGE__
diff --git a/klm/alone/main.cc b/klm/alone/main.cc
deleted file mode 100644
index e09ab01d..00000000
--- a/klm/alone/main.cc
+++ /dev/null
@@ -1,85 +0,0 @@
-#include "alone/threading.hh"
-#include "search/config.hh"
-#include "search/context.hh"
-#include "util/exception.hh"
-#include "util/file_piece.hh"
-#include "util/usage.hh"
-
-#include <boost/lexical_cast.hpp>
-
-#include <iostream>
-#include <memory>
-
-namespace alone {
-
-template <class Control> void ReadLoop(const std::string &graph_prefix, Control &control) {
- for (unsigned int sentence = 0; ; ++sentence) {
- std::stringstream name;
- name << graph_prefix << '/' << sentence;
- std::auto_ptr<util::FilePiece> file;
- try {
- file.reset(new util::FilePiece(name.str().c_str()));
- } catch (const util::ErrnoException &e) {
- if (e.Error() == ENOENT) return;
- throw;
- }
- control.Add(file.release());
- }
-}
-
-template <class Model> void RunWithModelType(const char *graph_prefix, const char *model_file, StringPiece weight_str, unsigned int pop_limit, unsigned int threads) {
- Model model(model_file);
- search::Weights weights(weight_str);
- search::Config config(weights, pop_limit);
-
- if (threads > 1) {
-#ifdef WITH_THREADS
- Controller<Model> controller(config, model, threads, std::cout);
- ReadLoop(graph_prefix, controller);
-#else
- UTIL_THROW(util::Exception, "Threading support not compiled in.");
-#endif
- } else {
- InThread<Model> controller(config, model, std::cout);
- ReadLoop(graph_prefix, controller);
- }
-}
-
-void Run(const char *graph_prefix, const char *lm_name, StringPiece weight_str, unsigned int pop_limit, unsigned int threads) {
- lm::ngram::ModelType model_type;
- if (!lm::ngram::RecognizeBinary(lm_name, model_type)) model_type = lm::ngram::PROBING;
- switch (model_type) {
- case lm::ngram::PROBING:
- RunWithModelType<lm::ngram::ProbingModel>(graph_prefix, lm_name, weight_str, pop_limit, threads);
- break;
- case lm::ngram::REST_PROBING:
- RunWithModelType<lm::ngram::RestProbingModel>(graph_prefix, lm_name, weight_str, pop_limit, threads);
- break;
- default:
- UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet.");
- }
-}
-
-} // namespace alone
-
-int main(int argc, char *argv[]) {
- if (argc < 5 || argc > 6) {
- std::cerr << argv[0] << " graph_prefix lm \"weights\" pop [threads]" << std::endl;
- return 1;
- }
-
-#ifdef WITH_THREADS
- unsigned thread_count = boost::thread::hardware_concurrency();
-#else
- unsigned thread_count = 1;
-#endif
- if (argc == 6) {
- thread_count = boost::lexical_cast<unsigned>(argv[5]);
- UTIL_THROW_IF(!thread_count, util::Exception, "Thread count 0");
- }
- UTIL_THROW_IF(!thread_count, util::Exception, "Boost doesn't know how many threads there are. Pass it on the command line.");
- alone::Run(argv[1], argv[2], argv[3], boost::lexical_cast<unsigned int>(argv[4]), thread_count);
-
- util::PrintUsage(std::cerr);
- return 0;
-}
diff --git a/klm/alone/read.cc b/klm/alone/read.cc
deleted file mode 100644
index 0b20be35..00000000
--- a/klm/alone/read.cc
+++ /dev/null
@@ -1,118 +0,0 @@
-#include "alone/read.hh"
-
-#include "alone/graph.hh"
-#include "alone/vocab.hh"
-#include "search/arity.hh"
-#include "search/context.hh"
-#include "search/weights.hh"
-#include "util/file_piece.hh"
-
-#include <boost/unordered_set.hpp>
-#include <boost/unordered_map.hpp>
-
-#include <cstdlib>
-
-namespace alone {
-
-namespace {
-
-template <class Model> Graph::Edge &ReadEdge(search::Context<Model> &context, util::FilePiece &from, Graph &to, Vocab &vocab, bool final) {
- Graph::Edge *ret = to.NewEdge();
-
- StringPiece got;
-
- std::vector<lm::WordIndex> words;
- unsigned long int terminals = 0;
- while ("|||" != (got = from.ReadDelimited())) {
- if ('[' == *got.data() && ']' == got.data()[got.size() - 1]) {
- // non-terminal
- char *end_ptr;
- unsigned long int child = std::strtoul(got.data() + 1, &end_ptr, 10);
- UTIL_THROW_IF(end_ptr != got.data() + got.size() - 1, FormatException, "Bad non-terminal" << got);
- UTIL_THROW_IF(child >= to.VertexSize(), FormatException, "Reference to vertex " << child << " but we only have " << to.VertexSize() << " vertices. Is the file in bottom-up format?");
- ret->Add(to.MutableVertex(child));
- words.push_back(lm::kMaxWordIndex);
- ret->AppendWord(NULL);
- } else {
- const std::pair<const std::string, lm::WordIndex> &found = vocab.FindOrAdd(got);
- words.push_back(found.second);
- ret->AppendWord(&found.first);
- ++terminals;
- }
- }
- if (final) {
- // This is not counted for the word penalty.
- words.push_back(vocab.EndSentence().second);
- ret->AppendWord(&vocab.EndSentence().first);
- }
- // Hard-coded word penalty.
- float additive = context.GetWeights().DotNoLM(from.ReadLine()) - context.GetWeights().WordPenalty() * static_cast<float>(terminals) / M_LN10;
- ret->InitRule().Init(context, additive, words, final);
- unsigned int arity = ret->GetRule().Arity();
- UTIL_THROW_IF(arity > search::kMaxArity, util::Exception, "Edit search/arity.hh and increase " << search::kMaxArity << " to at least " << arity);
- return *ret;
-}
-
-} // namespace
-
-// TODO: refactor
-void JustVocab(util::FilePiece &from, std::ostream &out) {
- boost::unordered_set<std::string> seen;
- unsigned long int vertices = from.ReadULong();
- from.ReadULong(); // edges
- UTIL_THROW_IF(vertices == 0, FormatException, "Vertex count is zero");
- UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected newline after counts");
- std::string temp;
- for (unsigned long int i = 0; i < vertices; ++i) {
- unsigned long int edge_count = from.ReadULong();
- UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected after edge count");
- for (unsigned long int e = 0; e < edge_count; ++e) {
- StringPiece got;
- while ("|||" != (got = from.ReadDelimited())) {
- if ('[' == *got.data() && ']' == got.data()[got.size() - 1]) continue;
- temp.assign(got.data(), got.size());
- if (seen.insert(temp).second) out << temp << ' ';
- }
- from.ReadLine(); // weights
- }
- }
- // Eat sentence
- from.ReadLine();
-}
-
-template <class Model> bool ReadCDec(search::Context<Model> &context, util::FilePiece &from, Graph &to, Vocab &vocab) {
- unsigned long int vertices;
- try {
- vertices = from.ReadULong();
- } catch (const util::EndOfFileException &e) { return false; }
- unsigned long int edges = from.ReadULong();
- UTIL_THROW_IF(vertices < 2, FormatException, "Vertex count is " << vertices);
- UTIL_THROW_IF(edges == 0, FormatException, "Edge count is " << edges);
- --vertices;
- --edges;
- UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected newline after counts");
- to.SetCounts(vertices, edges);
- Graph::Vertex *vertex;
- for (unsigned long int i = 0; ; ++i) {
- vertex = to.NewVertex();
- unsigned long int edge_count = from.ReadULong();
- bool root = (i == vertices - 1);
- UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected after edge count");
- for (unsigned long int e = 0; e < edge_count; ++e) {
- vertex->Add(ReadEdge(context, from, to, vocab, root));
- }
- vertex->FinishedAdding();
- if (root) break;
- }
- to.SetRoot(vertex);
- StringPiece str = from.ReadLine();
- UTIL_THROW_IF("1" != str, FormatException, "Expected one edge to root");
- // The edge
- from.ReadLine();
- return true;
-}
-
-template bool ReadCDec(search::Context<lm::ngram::ProbingModel> &context, util::FilePiece &from, Graph &to, Vocab &vocab);
-template bool ReadCDec(search::Context<lm::ngram::RestProbingModel> &context, util::FilePiece &from, Graph &to, Vocab &vocab);
-
-} // namespace alone
diff --git a/klm/alone/read.hh b/klm/alone/read.hh
deleted file mode 100644
index 10769a86..00000000
--- a/klm/alone/read.hh
+++ /dev/null
@@ -1,29 +0,0 @@
-#ifndef ALONE_READ__
-#define ALONE_READ__
-
-#include "util/exception.hh"
-
-#include <iosfwd>
-
-namespace util { class FilePiece; }
-
-namespace search { template <class Model> class Context; }
-
-namespace alone {
-
-class Graph;
-class Vocab;
-
-class FormatException : public util::Exception {
- public:
- FormatException() {}
- ~FormatException() throw() {}
-};
-
-void JustVocab(util::FilePiece &from, std::ostream &to);
-
-template <class Model> bool ReadCDec(search::Context<Model> &context, util::FilePiece &from, Graph &to, Vocab &vocab);
-
-} // namespace alone
-
-#endif // ALONE_READ__
diff --git a/klm/alone/threading.cc b/klm/alone/threading.cc
deleted file mode 100644
index 475386b6..00000000
--- a/klm/alone/threading.cc
+++ /dev/null
@@ -1,80 +0,0 @@
-#include "alone/threading.hh"
-
-#include "alone/assemble.hh"
-#include "alone/graph.hh"
-#include "alone/read.hh"
-#include "alone/vocab.hh"
-#include "lm/model.hh"
-#include "search/context.hh"
-#include "search/vertex_generator.hh"
-
-#include <boost/ref.hpp>
-#include <boost/scoped_ptr.hpp>
-#include <boost/utility/in_place_factory.hpp>
-
-#include <sstream>
-
-namespace alone {
-template <class Model> void Decode(const search::Config &config, const Model &model, util::FilePiece *in_ptr, std::ostream &out) {
- search::Context<Model> context(config, model);
- Graph graph;
- Vocab vocab(model.GetVocabulary());
- {
- boost::scoped_ptr<util::FilePiece> in(in_ptr);
- ReadCDec(context, *in, graph, vocab);
- }
-
- for (std::size_t i = 0; i < graph.VertexSize(); ++i) {
- search::VertexGenerator(context, graph.MutableVertex(i));
- }
- search::PartialVertex top = graph.Root().RootPartial();
- if (top.Empty()) {
- out << "NO PATH FOUND";
- } else {
- search::PartialVertex continuation;
- while (!top.Complete()) {
- top.Split(continuation);
- top = continuation;
- }
- out << top.End() << " ||| " << top.End().Bound() << std::endl;
- }
-}
-
-template void Decode(const search::Config &config, const lm::ngram::ProbingModel &model, util::FilePiece *in_ptr, std::ostream &out);
-template void Decode(const search::Config &config, const lm::ngram::RestProbingModel &model, util::FilePiece *in_ptr, std::ostream &out);
-
-#ifdef WITH_THREADS
-template <class Model> void DecodeHandler<Model>::operator()(Input message) {
- std::stringstream assemble;
- Decode(config_, model_, message.file, assemble);
- Produce(message.sentence_id, assemble.str());
-}
-
-template <class Model> void DecodeHandler<Model>::Produce(unsigned int sentence_id, const std::string &str) {
- Output out;
- out.sentence_id = sentence_id;
- out.str = new std::string(str);
- out_.Produce(out);
-}
-
-void PrintHandler::operator()(Output message) {
- unsigned int relative = message.sentence_id - done_;
- if (waiting_.size() <= relative) waiting_.resize(relative + 1);
- waiting_[relative] = message.str;
- for (std::string *lead; !waiting_.empty() && (lead = waiting_[0]); waiting_.pop_front(), ++done_) {
- out_ << *lead;
- delete lead;
- }
-}
-
-template <class Model> Controller<Model>::Controller(const search::Config &config, const Model &model, size_t decode_workers, std::ostream &to) :
- sentence_id_(0),
- printer_(decode_workers, 1, boost::ref(to), Output::Poison()),
- decoder_(3, decode_workers, boost::in_place(boost::ref(config), boost::ref(model), boost::ref(printer_.In())), Input::Poison()) {}
-
-template class Controller<lm::ngram::RestProbingModel>;
-template class Controller<lm::ngram::ProbingModel>;
-
-#endif
-
-} // namespace alone
diff --git a/klm/alone/threading.hh b/klm/alone/threading.hh
deleted file mode 100644
index 0ab0f739..00000000
--- a/klm/alone/threading.hh
+++ /dev/null
@@ -1,129 +0,0 @@
-#ifndef ALONE_THREADING__
-#define ALONE_THREADING__
-
-#ifdef WITH_THREADS
-#include "util/pcqueue.hh"
-#include "util/pool.hh"
-#endif
-
-#include <iosfwd>
-#include <queue>
-#include <string>
-
-namespace util {
-class FilePiece;
-} // namespace util
-
-namespace search {
-class Config;
-template <class Model> class Context;
-} // namespace search
-
-namespace alone {
-
-template <class Model> void Decode(const search::Config &config, const Model &model, util::FilePiece *in_ptr, std::ostream &out);
-
-class Graph;
-
-#ifdef WITH_THREADS
-struct SentenceID {
- unsigned int sentence_id;
- bool operator==(const SentenceID &other) const {
- return sentence_id == other.sentence_id;
- }
-};
-
-struct Input : public SentenceID {
- util::FilePiece *file;
- static Input Poison() {
- Input ret;
- ret.sentence_id = static_cast<unsigned int>(-1);
- ret.file = NULL;
- return ret;
- }
-};
-
-struct Output : public SentenceID {
- std::string *str;
- static Output Poison() {
- Output ret;
- ret.sentence_id = static_cast<unsigned int>(-1);
- ret.str = NULL;
- return ret;
- }
-};
-
-template <class Model> class DecodeHandler {
- public:
- typedef Input Request;
-
- DecodeHandler(const search::Config &config, const Model &model, util::PCQueue<Output> &out) : config_(config), model_(model), out_(out) {}
-
- void operator()(Input message);
-
- private:
- void Produce(unsigned int sentence_id, const std::string &str);
-
- const search::Config &config_;
-
- const Model &model_;
-
- util::PCQueue<Output> &out_;
-};
-
-class PrintHandler {
- public:
- typedef Output Request;
-
- explicit PrintHandler(std::ostream &o) : out_(o), done_(0) {}
-
- void operator()(Output message);
-
- private:
- std::ostream &out_;
- std::deque<std::string*> waiting_;
- unsigned int done_;
-};
-
-template <class Model> class Controller {
- public:
- // This config must remain valid.
- explicit Controller(const search::Config &config, const Model &model, size_t decode_workers, std::ostream &to);
-
- // Takes ownership of in.
- void Add(util::FilePiece *in) {
- Input input;
- input.sentence_id = sentence_id_++;
- input.file = in;
- decoder_.Produce(input);
- }
-
- private:
- unsigned int sentence_id_;
-
- util::Pool<PrintHandler> printer_;
-
- util::Pool<DecodeHandler<Model> > decoder_;
-};
-#endif
-
-// Same API as controller.
-template <class Model> class InThread {
- public:
- InThread(const search::Config &config, const Model &model, std::ostream &to) : config_(config), model_(model), to_(to) {}
-
- // Takes ownership of in.
- void Add(util::FilePiece *in) {
- Decode(config_, model_, in, to_);
- }
-
- private:
- const search::Config &config_;
-
- const Model &model_;
-
- std::ostream &to_;
-};
-
-} // namespace alone
-#endif // ALONE_THREADING__
diff --git a/klm/alone/vocab.cc b/klm/alone/vocab.cc
deleted file mode 100644
index ffe55301..00000000
--- a/klm/alone/vocab.cc
+++ /dev/null
@@ -1,19 +0,0 @@
-#include "alone/vocab.hh"
-
-#include "lm/virtual_interface.hh"
-#include "util/string_piece.hh"
-
-namespace alone {
-
-Vocab::Vocab(const lm::base::Vocabulary &backing) : backing_(backing), end_sentence_(FindOrAdd("</s>")) {}
-
-const std::pair<const std::string, lm::WordIndex> &Vocab::FindOrAdd(const StringPiece &str) {
- Map::const_iterator i(FindStringPiece(map_, str));
- if (i != map_.end()) return *i;
- std::pair<std::string, lm::WordIndex> to_ins;
- to_ins.first.assign(str.data(), str.size());
- to_ins.second = backing_.Index(str);
- return *map_.insert(to_ins).first;
-}
-
-} // namespace alone
diff --git a/klm/alone/vocab.hh b/klm/alone/vocab.hh
deleted file mode 100644
index 3ac0f542..00000000
--- a/klm/alone/vocab.hh
+++ /dev/null
@@ -1,34 +0,0 @@
-#ifndef ALONE_VOCAB__
-#define ALONE_VOCAB__
-
-#include "lm/word_index.hh"
-#include "util/string_piece.hh"
-
-#include <boost/functional/hash/hash.hpp>
-#include <boost/unordered_map.hpp>
-
-#include <string>
-
-namespace lm { namespace base { class Vocabulary; } }
-
-namespace alone {
-
-class Vocab {
- public:
- explicit Vocab(const lm::base::Vocabulary &backing);
-
- const std::pair<const std::string, lm::WordIndex> &FindOrAdd(const StringPiece &str);
-
- const std::pair<const std::string, lm::WordIndex> &EndSentence() const { return end_sentence_; }
-
- private:
- typedef boost::unordered_map<std::string, lm::WordIndex> Map;
- Map map_;
-
- const lm::base::Vocabulary &backing_;
-
- const std::pair<const std::string, lm::WordIndex> &end_sentence_;
-};
-
-} // namespace alone
-#endif // ALONE_VCOAB__
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index 40af8a63..2fd20481 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -87,7 +87,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
WriteWordsWrapper wrap(config.enumerate_vocab);
vocab_.ConfigureEnumerate(&wrap, counts[0]);
search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
- wrap.Write(backing_.file.get());
+ wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + backing_.search.size());
} else {
vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc
index 398475be..11c27518 100644
--- a/klm/lm/vocab.cc
+++ b/klm/lm/vocab.cc
@@ -80,8 +80,8 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {
buffer_.push_back(0);
}
-void WriteWordsWrapper::Write(int fd) {
- util::SeekEnd(fd);
+void WriteWordsWrapper::Write(int fd, uint64_t start) {
+ util::SeekOrThrow(fd, start);
util::WriteOrThrow(fd, buffer_.data(), buffer_.size());
}
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
index 074cd446..de54eb06 100644
--- a/klm/lm/vocab.hh
+++ b/klm/lm/vocab.hh
@@ -35,7 +35,7 @@ class WriteWordsWrapper : public EnumerateVocab {
void Add(WordIndex index, const StringPiece &str);
- void Write(int fd);
+ void Write(int fd, uint64_t start);
private:
EnumerateVocab *inner_;
diff --git a/klm/search/Jamfile b/klm/search/Jamfile
index e8b14363..bc95c53a 100644
--- a/klm/search/Jamfile
+++ b/klm/search/Jamfile
@@ -1,4 +1,4 @@
-lib search : weights.cc vertex.cc vertex_generator.cc edge_queue.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : <include>.. ;
+lib search : weights.cc vertex.cc vertex_generator.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : <include>.. ;
import testing ;
diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am
new file mode 100644
index 00000000..ccc5b7f6
--- /dev/null
+++ b/klm/search/Makefile.am
@@ -0,0 +1,11 @@
+noinst_LIBRARIES = libksearch.a
+
+libksearch_a_SOURCES = \
+ edge_generator.cc \
+ rule.cc \
+ vertex.cc \
+ vertex_generator.cc \
+ weights.cc
+
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I..
+
diff --git a/klm/search/arity.hh b/klm/search/arity.hh
deleted file mode 100644
index 09c2c671..00000000
--- a/klm/search/arity.hh
+++ /dev/null
@@ -1,8 +0,0 @@
-#ifndef SEARCH_ARITY__
-#define SEARCH_ARITY__
-namespace search {
-
-const unsigned int kMaxArity = 2;
-
-} // namespace search
-#endif // SEARCH_ARITY__
diff --git a/klm/search/context.hh b/klm/search/context.hh
index 27940053..62163144 100644
--- a/klm/search/context.hh
+++ b/klm/search/context.hh
@@ -7,6 +7,7 @@
#include "search/types.hh"
#include "search/vertex.hh"
#include "util/exception.hh"
+#include "util/pool.hh"
#include <boost/pool/object_pool.hpp>
#include <boost/ptr_container/ptr_vector.hpp>
@@ -21,10 +22,8 @@ class ContextBase {
public:
explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {}
- Final *NewFinal() {
- Final *ret = final_pool_.construct();
- assert(ret);
- return ret;
+ util::Pool &FinalPool() {
+ return final_pool_;
}
VertexNode *NewVertexNode() {
@@ -42,7 +41,8 @@ class ContextBase {
const Weights &GetWeights() const { return weights_; }
private:
- boost::object_pool<Final> final_pool_;
+ util::Pool final_pool_;
+
boost::object_pool<VertexNode> vertex_node_pool_;
unsigned int pop_limit_;
diff --git a/klm/search/edge.hh b/klm/search/edge.hh
index 77ab0ade..187904bf 100644
--- a/klm/search/edge.hh
+++ b/klm/search/edge.hh
@@ -2,30 +2,53 @@
#define SEARCH_EDGE__
#include "lm/state.hh"
-#include "search/arity.hh"
-#include "search/rule.hh"
+#include "search/header.hh"
#include "search/types.hh"
#include "search/vertex.hh"
+#include "util/pool.hh"
-#include <queue>
+#include <functional>
+
+#include <stdint.h>
namespace search {
-struct PartialEdge {
- Score score;
- // Terminals
- lm::ngram::ChartState between[kMaxArity + 1];
- // Non-terminals
- PartialVertex nt[kMaxArity];
+// Copyable, but the copy will be shallow.
+class PartialEdge : public Header {
+ public:
+ // Allow default construction for STL.
+ PartialEdge() {}
+
+ PartialEdge(util::Pool &pool, Arity arity)
+ : Header(pool.Allocate(Size(arity, arity + 1)), arity) {}
+
+ PartialEdge(util::Pool &pool, Arity arity, Arity chart_states)
+ : Header(pool.Allocate(Size(arity, chart_states)), arity) {}
- const lm::ngram::ChartState &CompletedState() const {
- return between[0];
- }
+ // Non-terminals
+ const PartialVertex *NT() const {
+ return reinterpret_cast<const PartialVertex*>(After());
+ }
+ PartialVertex *NT() {
+ return reinterpret_cast<PartialVertex*>(After());
+ }
- bool operator<(const PartialEdge &other) const {
- return score < other.score;
- }
+ const lm::ngram::ChartState &CompletedState() const {
+ return *Between();
+ }
+ const lm::ngram::ChartState *Between() const {
+ return reinterpret_cast<const lm::ngram::ChartState*>(After() + GetArity() * sizeof(PartialVertex));
+ }
+ lm::ngram::ChartState *Between() {
+ return reinterpret_cast<lm::ngram::ChartState*>(After() + GetArity() * sizeof(PartialVertex));
+ }
+
+ private:
+ static std::size_t Size(Arity arity, Arity chart_states) {
+ return kHeaderSize + arity * sizeof(PartialVertex) + chart_states * sizeof(lm::ngram::ChartState);
+ }
};
+
} // namespace search
#endif // SEARCH_EDGE__
diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc
index 56239dfb..260159b1 100644
--- a/klm/search/edge_generator.cc
+++ b/klm/search/edge_generator.cc
@@ -4,117 +4,107 @@
#include "lm/partial.hh"
#include "search/context.hh"
#include "search/vertex.hh"
-#include "search/vertex_generator.hh"
#include <numeric>
namespace search {
-EdgeGenerator::EdgeGenerator(PartialEdge &root, unsigned char arity, Note note) : arity_(arity), note_(note) {
-/* for (unsigned char i = 0; i < edge.Arity(); ++i) {
- root.nt[i] = edge.GetVertex(i).RootPartial();
- }
- for (unsigned char i = edge.Arity(); i < 2; ++i) {
- root.nt[i] = kBlankPartialVertex;
- }*/
- generate_.push(&root);
- top_score_ = root.score;
-}
-
namespace {
-template <class Model> float FastScore(const Context<Model> &context, unsigned char victim, unsigned char arity, const PartialEdge &previous, PartialEdge &update) {
- memcpy(update.between, previous.between, sizeof(lm::ngram::ChartState) * (arity + 1));
-
- float ret = 0.0;
- lm::ngram::ChartState *before, *after;
- if (victim == 0) {
- before = &update.between[0];
- after = &update.between[(arity == 2 && previous.nt[1].Complete()) ? 2 : 1];
- } else {
- assert(victim == 1);
- assert(arity == 2);
- before = &update.between[previous.nt[0].Complete() ? 0 : 1];
- after = &update.between[2];
- }
- const lm::ngram::ChartState &previous_reveal = previous.nt[victim].State();
- const PartialVertex &update_nt = update.nt[victim];
+template <class Model> void FastScore(const Context<Model> &context, Arity victim, Arity before_idx, Arity incomplete, const PartialVertex &previous_vertex, PartialEdge update) {
+ lm::ngram::ChartState *between = update.Between();
+ lm::ngram::ChartState *before = &between[before_idx], *after = &between[before_idx + 1];
+
+ float adjustment = 0.0;
+ const lm::ngram::ChartState &previous_reveal = previous_vertex.State();
+ const PartialVertex &update_nt = update.NT()[victim];
const lm::ngram::ChartState &update_reveal = update_nt.State();
- float just_after = 0.0;
if ((update_reveal.left.length > previous_reveal.left.length) || (update_reveal.left.full && !previous_reveal.left.full)) {
- just_after += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length);
+ adjustment += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length);
}
- if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous.nt[victim].RightFull())) {
- ret += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right);
+ if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous_vertex.RightFull())) {
+ adjustment += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right);
}
if (update_nt.Complete()) {
if (update_reveal.left.full) {
before->left.full = true;
} else {
assert(update_reveal.left.length == update_reveal.right.length);
- ret += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length);
+ adjustment += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length);
}
- if (victim == 0) {
- update.between[0].right = after->right;
- } else {
- update.between[2].left = before->left;
+ before->right = after->right;
+ // Shift the others shifted one down, covering after.
+ for (lm::ngram::ChartState *cover = after; cover < between + incomplete; ++cover) {
+ *cover = *(cover + 1);
}
}
- return previous.score + (ret + just_after) * context.GetWeights().LM();
+ update.SetScore(update.GetScore() + adjustment * context.GetWeights().LM());
}
} // namespace
-template <class Model> PartialEdge *EdgeGenerator::Pop(Context<Model> &context, boost::pool<> &partial_edge_pool) {
+template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) {
assert(!generate_.empty());
- PartialEdge &top = *generate_.top();
+ PartialEdge top = generate_.top();
generate_.pop();
- unsigned int victim = 0;
- unsigned char lowest_length = 255;
- for (unsigned char i = 0; i != arity_; ++i) {
- if (!top.nt[i].Complete() && top.nt[i].Length() < lowest_length) {
- lowest_length = top.nt[i].Length();
- victim = i;
+ PartialVertex *const top_nt = top.NT();
+ const Arity arity = top.GetArity();
+
+ Arity victim = 0;
+ Arity victim_completed;
+ Arity incomplete;
+ // Select victim or return if complete.
+ {
+ Arity completed = 0;
+ unsigned char lowest_length = 255;
+ for (Arity i = 0; i != arity; ++i) {
+ if (top_nt[i].Complete()) {
+ ++completed;
+ } else if (top_nt[i].Length() < lowest_length) {
+ lowest_length = top_nt[i].Length();
+ victim = i;
+ victim_completed = completed;
+ }
}
- }
- if (lowest_length == 255) {
- // All states report complete.
- top.between[0].right = top.between[arity_].right;
- // Now top.between[0] is the full edge state.
- top_score_ = generate_.empty() ? -kScoreInf : generate_.top()->score;
- return &top;
+ if (lowest_length == 255) {
+ return top;
+ }
+ incomplete = arity - completed;
}
- unsigned int stay = !victim;
- PartialEdge &continuation = *static_cast<PartialEdge*>(partial_edge_pool.malloc());
- float old_bound = top.nt[victim].Bound();
- // The alternate's score will change because alternate.nt[victim] changes.
- bool split = top.nt[victim].Split(continuation.nt[victim]);
- // top is now the alternate.
+ PartialVertex old_value(top_nt[victim]);
+ PartialVertex alternate_changed;
+ if (top_nt[victim].Split(alternate_changed)) {
+ PartialEdge alternate(partial_edge_pool_, arity, incomplete + 1);
+ alternate.SetScore(top.GetScore() + alternate_changed.Bound() - old_value.Bound());
- continuation.nt[stay] = top.nt[stay];
- continuation.score = FastScore(context, victim, arity_, top, continuation);
- // TODO: dedupe?
- generate_.push(&continuation);
+ alternate.SetNote(top.GetNote());
+
+ PartialVertex *alternate_nt = alternate.NT();
+ for (Arity i = 0; i < victim; ++i) alternate_nt[i] = top_nt[i];
+ alternate_nt[victim] = alternate_changed;
+ for (Arity i = victim + 1; i < arity; ++i) alternate_nt[i] = top_nt[i];
+
+ memcpy(alternate.Between(), top.Between(), sizeof(lm::ngram::ChartState) * (incomplete + 1));
- if (split) {
- // We have an alternate.
- top.score += top.nt[victim].Bound() - old_bound;
// TODO: dedupe?
- generate_.push(&top);
- } else {
- partial_edge_pool.free(&top);
+ generate_.push(alternate);
}
- top_score_ = generate_.top()->score;
- return NULL;
+ // top is now the continuation.
+ FastScore(context, victim, victim - victim_completed, incomplete, old_value, top);
+ // TODO: dedupe?
+ generate_.push(top);
+
+ // Invalid indicates no new hypothesis generated.
+ return PartialEdge();
}
-template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context, boost::pool<> &partial_edge_pool);
-template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context, boost::pool<> &partial_edge_pool);
-template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::TrieModel> &context, boost::pool<> &partial_edge_pool);
-template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::QuantTrieModel> &context, boost::pool<> &partial_edge_pool);
-template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::ArrayTrieModel> &context, boost::pool<> &partial_edge_pool);
-template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::QuantArrayTrieModel> &context, boost::pool<> &partial_edge_pool);
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context);
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context);
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::TrieModel> &context);
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::QuantTrieModel> &context);
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::ArrayTrieModel> &context);
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::QuantArrayTrieModel> &context);
} // namespace search
diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh
index 875ccc5e..582c78b7 100644
--- a/klm/search/edge_generator.hh
+++ b/klm/search/edge_generator.hh
@@ -3,11 +3,8 @@
#include "search/edge.hh"
#include "search/note.hh"
+#include "search/types.hh"
-#include <boost/pool/pool.hpp>
-#include <boost/unordered_map.hpp>
-
-#include <functional>
#include <queue>
namespace lm {
@@ -20,38 +17,40 @@ namespace search {
template <class Model> class Context;
-class VertexGenerator;
-
-struct PartialEdgePointerLess : std::binary_function<const PartialEdge *, const PartialEdge *, bool> {
- bool operator()(const PartialEdge *first, const PartialEdge *second) const {
- return *first < *second;
- }
-};
-
class EdgeGenerator {
public:
- EdgeGenerator(PartialEdge &root, unsigned char arity, Note note);
+ EdgeGenerator() {}
- Score TopScore() const {
- return top_score_;
+ PartialEdge AllocateEdge(Arity arity) {
+ return PartialEdge(partial_edge_pool_, arity);
}
- Note GetNote() const {
- return note_;
+ void AddEdge(PartialEdge edge) {
+ generate_.push(edge);
}
- // Pop. If there's a complete hypothesis, return it. Otherwise return NULL.
- template <class Model> PartialEdge *Pop(Context<Model> &context, boost::pool<> &partial_edge_pool);
+ bool Empty() const { return generate_.empty(); }
+
+ // Pop. If there's a complete hypothesis, return it. Otherwise return an invalid PartialEdge.
+ template <class Model> PartialEdge Pop(Context<Model> &context);
+
+ template <class Model, class Output> void Search(Context<Model> &context, Output &output) {
+ unsigned to_pop = context.PopLimit();
+ while (to_pop > 0 && !generate_.empty()) {
+ PartialEdge got(Pop(context));
+ if (got.Valid()) {
+ output.NewHypothesis(got);
+ --to_pop;
+ }
+ }
+ output.FinishedSearch();
+ }
private:
- Score top_score_;
-
- unsigned char arity_;
+ util::Pool partial_edge_pool_;
- typedef std::priority_queue<PartialEdge*, std::vector<PartialEdge*>, PartialEdgePointerLess> Generate;
+ typedef std::priority_queue<PartialEdge> Generate;
Generate generate_;
-
- Note note_;
};
} // namespace search
diff --git a/klm/search/edge_queue.cc b/klm/search/edge_queue.cc
deleted file mode 100644
index e3ae6ebf..00000000
--- a/klm/search/edge_queue.cc
+++ /dev/null
@@ -1,25 +0,0 @@
-#include "search/edge_queue.hh"
-
-#include "lm/left.hh"
-#include "search/context.hh"
-
-#include <stdint.h>
-
-namespace search {
-
-EdgeQueue::EdgeQueue(unsigned int pop_limit_hint) : partial_edge_pool_(sizeof(PartialEdge), pop_limit_hint * 2) {
- take_ = static_cast<PartialEdge*>(partial_edge_pool_.malloc());
-}
-
-/*void EdgeQueue::AddEdge(PartialEdge &root, unsigned char arity, Note note) {
- // Ignore empty edges.
- for (unsigned char i = 0; i < edge.Arity(); ++i) {
- PartialVertex root(edge.GetVertex(i).RootPartial());
- if (root.Empty()) return;
- total_score += root.Bound();
- }
- PartialEdge &allocated = *static_cast<PartialEdge*>(partial_edge_pool_.malloc());
- allocated.score = total_score;
-}*/
-
-} // namespace search
diff --git a/klm/search/edge_queue.hh b/klm/search/edge_queue.hh
deleted file mode 100644
index 187eaed7..00000000
--- a/klm/search/edge_queue.hh
+++ /dev/null
@@ -1,73 +0,0 @@
-#ifndef SEARCH_EDGE_QUEUE__
-#define SEARCH_EDGE_QUEUE__
-
-#include "search/edge.hh"
-#include "search/edge_generator.hh"
-#include "search/note.hh"
-
-#include <boost/pool/pool.hpp>
-#include <boost/pool/object_pool.hpp>
-
-#include <queue>
-
-namespace search {
-
-template <class Model> class Context;
-
-class EdgeQueue {
- public:
- explicit EdgeQueue(unsigned int pop_limit_hint);
-
- PartialEdge &InitializeEdge() {
- return *take_;
- }
-
- void AddEdge(unsigned char arity, Note note) {
- generate_.push(edge_pool_.construct(*take_, arity, note));
- take_ = static_cast<PartialEdge*>(partial_edge_pool_.malloc());
- }
-
- bool Empty() const { return generate_.empty(); }
-
- /* Generate hypotheses and send them to output. Normally, output is a
- * VertexGenerator, but the decoder may want to route edges to different
- * vertices i.e. if they have different LHS non-terminal labels.
- */
- template <class Model, class Output> void Search(Context<Model> &context, Output &output) {
- int to_pop = context.PopLimit();
- while (to_pop > 0 && !generate_.empty()) {
- EdgeGenerator *top = generate_.top();
- generate_.pop();
- PartialEdge *ret = top->Pop(context, partial_edge_pool_);
- if (ret) {
- output.NewHypothesis(*ret, top->GetNote());
- --to_pop;
- if (top->TopScore() != -kScoreInf) {
- generate_.push(top);
- }
- } else {
- generate_.push(top);
- }
- }
- output.FinishedSearch();
- }
-
- private:
- boost::object_pool<EdgeGenerator> edge_pool_;
-
- struct LessByTopScore : public std::binary_function<const EdgeGenerator *, const EdgeGenerator *, bool> {
- bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const {
- return first->TopScore() < second->TopScore();
- }
- };
-
- typedef std::priority_queue<EdgeGenerator*, std::vector<EdgeGenerator*>, LessByTopScore> Generate;
- Generate generate_;
-
- boost::pool<> partial_edge_pool_;
-
- PartialEdge *take_;
-};
-
-} // namespace search
-#endif // SEARCH_EDGE_QUEUE__
diff --git a/klm/search/final.hh b/klm/search/final.hh
index 1b3092ac..50e62cf2 100644
--- a/klm/search/final.hh
+++ b/klm/search/final.hh
@@ -1,37 +1,34 @@
#ifndef SEARCH_FINAL__
#define SEARCH_FINAL__
-#include "search/arity.hh"
-#include "search/note.hh"
-#include "search/types.hh"
-
-#include <boost/array.hpp>
+#include "search/header.hh"
+#include "util/pool.hh"
namespace search {
-class Final {
+// A full hypothesis with pointers to children.
+class Final : public Header {
public:
- typedef boost::array<const Final*, search::kMaxArity> ChildArray;
+ Final() {}
- void Reset(Score bound, Note note, const Final &left, const Final &right) {
- bound_ = bound;
- note_ = note;
- children_[0] = &left;
- children_[1] = &right;
+ Final(util::Pool &pool, Score score, Arity arity, Note note)
+ : Header(pool.Allocate(Size(arity)), arity) {
+ SetScore(score);
+ SetNote(note);
}
- const ChildArray &Children() const { return children_; }
-
- Note GetNote() const { return note_; }
-
- Score Bound() const { return bound_; }
+ // These are arrays of length GetArity().
+ Final *Children() {
+ return reinterpret_cast<Final*>(After());
+ }
+ const Final *Children() const {
+ return reinterpret_cast<const Final*>(After());
+ }
private:
- Score bound_;
-
- Note note_;
-
- ChildArray children_;
+ static std::size_t Size(Arity arity) {
+ return kHeaderSize + arity * sizeof(const Final);
+ }
};
} // namespace search
diff --git a/klm/search/header.hh b/klm/search/header.hh
new file mode 100644
index 00000000..25550dbe
--- /dev/null
+++ b/klm/search/header.hh
@@ -0,0 +1,57 @@
+#ifndef SEARCH_HEADER__
+#define SEARCH_HEADER__
+
+// Header consisting of Score, Arity, and Note
+
+#include "search/note.hh"
+#include "search/types.hh"
+
+#include <stdint.h>
+
+namespace search {
+
+// Copying is shallow.
+class Header {
+ public:
+ bool Valid() const { return base_; }
+
+ Score GetScore() const {
+ return *reinterpret_cast<const float*>(base_);
+ }
+ void SetScore(Score to) {
+ *reinterpret_cast<float*>(base_) = to;
+ }
+ bool operator<(const Header &other) const {
+ return GetScore() < other.GetScore();
+ }
+
+ Arity GetArity() const {
+ return *reinterpret_cast<const Arity*>(base_ + sizeof(Score));
+ }
+
+ Note GetNote() const {
+ return *reinterpret_cast<const Note*>(base_ + sizeof(Score) + sizeof(Arity));
+ }
+ void SetNote(Note to) {
+ *reinterpret_cast<Note*>(base_ + sizeof(Score) + sizeof(Arity)) = to;
+ }
+
+ protected:
+ Header() : base_(NULL) {}
+
+ Header(void *base, Arity arity) : base_(static_cast<uint8_t*>(base)) {
+ *reinterpret_cast<Arity*>(base_ + sizeof(Score)) = arity;
+ }
+
+ static const std::size_t kHeaderSize = sizeof(Score) + sizeof(Arity) + sizeof(Note);
+
+ uint8_t *After() { return base_ + kHeaderSize; }
+ const uint8_t *After() const { return base_ + kHeaderSize; }
+
+ private:
+ uint8_t *base_;
+};
+
+} // namespace search
+
+#endif // SEARCH_HEADER__
diff --git a/klm/search/source.hh b/klm/search/source.hh
deleted file mode 100644
index 11839f7b..00000000
--- a/klm/search/source.hh
+++ /dev/null
@@ -1,48 +0,0 @@
-#ifndef SEARCH_SOURCE__
-#define SEARCH_SOURCE__
-
-#include "search/types.hh"
-
-#include <assert.h>
-#include <vector>
-
-namespace search {
-
-template <class Final> class Source {
- public:
- Source() : bound_(kScoreInf) {}
-
- Index Size() const {
- return final_.size();
- }
-
- Score Bound() const {
- return bound_;
- }
-
- const Final &operator[](Index index) const {
- return *final_[index];
- }
-
- Score ScoreOrBound(Index index) const {
- return Size() > index ? final_[index]->Total() : Bound();
- }
-
- protected:
- void AddFinal(const Final &store) {
- final_.push_back(&store);
- }
-
- void SetBound(Score to) {
- assert(to <= bound_ + 0.001);
- bound_ = to;
- }
-
- private:
- std::vector<const Final *> final_;
-
- Score bound_;
-};
-
-} // namespace search
-#endif // SEARCH_SOURCE__
diff --git a/klm/search/types.hh b/klm/search/types.hh
index 9726379f..06eb5bfa 100644
--- a/klm/search/types.hh
+++ b/klm/search/types.hh
@@ -1,17 +1,13 @@
#ifndef SEARCH_TYPES__
#define SEARCH_TYPES__
-#include <cmath>
+#include <stdint.h>
namespace search {
typedef float Score;
-const Score kScoreInf = INFINITY;
-// This could have been an enum but gcc wants 4 bytes.
-typedef bool ExtendDirection;
-const ExtendDirection kExtendLeft = 0;
-const ExtendDirection kExtendRight = 1;
+typedef uint32_t Arity;
} // namespace search
diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc
index cc53c0dd..11f4631f 100644
--- a/klm/search/vertex.cc
+++ b/klm/search/vertex.cc
@@ -21,9 +21,9 @@ struct GreaterByBound : public std::binary_function<const VertexNode *, const Ve
void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) {
if (Complete()) {
- assert(end_);
+ assert(end_.Valid());
assert(extend_.empty());
- bound_ = end_->Bound();
+ bound_ = end_.GetScore();
return;
}
if (extend_.size() == 1 && parent_ptr) {
@@ -39,10 +39,4 @@ void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) {
bound_ = extend_.front()->Bound();
}
-namespace {
-VertexNode kBlankVertexNode;
-} // namespace
-
-PartialVertex kBlankPartialVertex(kBlankVertexNode);
-
} // namespace search
diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh
index e1a9ad11..52bc1dfe 100644
--- a/klm/search/vertex.hh
+++ b/klm/search/vertex.hh
@@ -18,7 +18,7 @@ class ContextBase;
class VertexNode {
public:
- VertexNode() : end_(NULL) {}
+ VertexNode() {}
void InitRoot() {
extend_.clear();
@@ -26,8 +26,7 @@ class VertexNode {
state_.left.length = 0;
state_.right.length = 0;
right_full_ = false;
- bound_ = -kScoreInf;
- end_ = NULL;
+ end_ = Final();
}
lm::ngram::ChartState &MutableState() { return state_; }
@@ -37,19 +36,20 @@ class VertexNode {
extend_.push_back(next);
}
- void SetEnd(Final *end) { end_ = end; }
+ void SetEnd(Final end) {
+ assert(!end_.Valid());
+ end_ = end;
+ }
- Final &MutableEnd() { return *end_; }
-
void SortAndSet(ContextBase &context, VertexNode **parent_pointer);
// Should only happen to a root node when the entire vertex is empty.
bool Empty() const {
- return !end_ && extend_.empty();
+ return !end_.Valid() && extend_.empty();
}
bool Complete() const {
- return end_;
+ return end_.Valid();
}
const lm::ngram::ChartState &State() const { return state_; }
@@ -63,8 +63,8 @@ class VertexNode {
return state_.left.length + state_.right.length;
}
- // May be NULL.
- const Final *End() const { return end_; }
+ // Will be invalid unless this is a leaf.
+ const Final End() const { return end_; }
const VertexNode &operator[](size_t index) const {
return *extend_[index];
@@ -81,7 +81,7 @@ class VertexNode {
bool right_full_;
Score bound_;
- Final *end_;
+ Final end_;
};
class PartialVertex {
@@ -97,7 +97,7 @@ class PartialVertex {
const lm::ngram::ChartState &State() const { return back_->State(); }
bool RightFull() const { return back_->RightFull(); }
- Score Bound() const { return Complete() ? back_->End()->Bound() : (*back_)[index_].Bound(); }
+ Score Bound() const { return Complete() ? back_->End().GetScore() : (*back_)[index_].Bound(); }
unsigned char Length() const { return back_->Length(); }
@@ -105,20 +105,24 @@ class PartialVertex {
return index_ + 1 < back_->Size();
}
- // Split into continuation and alternative, rendering this the alternative.
- bool Split(PartialVertex &continuation) {
+ // Split into continuation and alternative, rendering this the continuation.
+ bool Split(PartialVertex &alternative) {
assert(!Complete());
- continuation.back_ = &((*back_)[index_]);
- continuation.index_ = 0;
+ bool ret;
if (index_ + 1 < back_->Size()) {
- ++index_;
- return true;
+ alternative.index_ = index_ + 1;
+ alternative.back_ = back_;
+ ret = true;
+ } else {
+ ret = false;
}
- return false;
+ back_ = &((*back_)[index_]);
+ index_ = 0;
+ return ret;
}
- const Final &End() const {
- return *back_->End();
+ const Final End() const {
+ return back_->End();
}
private:
@@ -126,25 +130,22 @@ class PartialVertex {
unsigned int index_;
};
-extern PartialVertex kBlankPartialVertex;
-
class Vertex {
public:
Vertex() {}
PartialVertex RootPartial() const { return PartialVertex(root_); }
- const Final *BestChild() const {
+ const Final BestChild() const {
PartialVertex top(RootPartial());
if (top.Empty()) {
- return NULL;
+ return Final();
} else {
PartialVertex continuation;
while (!top.Complete()) {
top.Split(continuation);
- top = continuation;
}
- return &top.End();
+ return top.End();
}
}
diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc
index d94e6e06..0945fe55 100644
--- a/klm/search/vertex_generator.cc
+++ b/klm/search/vertex_generator.cc
@@ -10,74 +10,85 @@ namespace search {
VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) {
gen.root_.InitRoot();
- root_.under = &gen.root_;
}
namespace {
+
const uint64_t kCompleteAdd = static_cast<uint64_t>(-1);
-} // namespace
-void VertexGenerator::NewHypothesis(const PartialEdge &partial, Note note) {
- const lm::ngram::ChartState &state = partial.CompletedState();
- std::pair<Existing::iterator, bool> got(existing_.insert(std::pair<uint64_t, Final*>(hash_value(state), NULL)));
- if (!got.second) {
- // Found it already.
- Final &exists = *got.first->second;
- if (exists.Bound() < partial.score) {
- exists.Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End());
- }
- return;
+// Parallel structure to VertexNode.
+struct Trie {
+ Trie() : under(NULL) {}
+
+ VertexNode *under;
+ boost::unordered_map<uint64_t, Trie> extend;
+};
+
+Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) {
+ Trie &next = node.extend[added];
+ if (!next.under) {
+ next.under = context.NewVertexNode();
+ lm::ngram::ChartState &writing = next.under->MutableState();
+ writing = state;
+ writing.left.full &= left_full && state.left.full;
+ next.under->MutableRightFull() = right_full && state.left.full;
+ writing.left.length = left;
+ writing.right.length = right;
+ node.under->AddExtend(next.under);
}
+ return next;
+}
+
+void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) {
+ Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote());
+ Final *child_out = final.Children();
+ const PartialVertex *part = partial.NT();
+ const PartialVertex *const part_end_loop = part + partial.GetArity();
+ for (; part != part_end_loop; ++part, ++child_out)
+ *child_out = part->End();
+
+ starter.under->SetEnd(final);
+}
+
+void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) {
+ const lm::ngram::ChartState &state = partial.CompletedState();
+
unsigned char left = 0, right = 0;
- Trie *node = &root_;
+ Trie *node = &root;
while (true) {
if (left == state.left.length) {
- node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, true, right, false);
+ node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, true, right, false);
for (; right < state.right.length; ++right) {
- node = &FindOrInsert(*node, state.right.words[right], state, left, true, right + 1, false);
+ node = &FindOrInsert(context, *node, state.right.words[right], state, left, true, right + 1, false);
}
break;
}
- node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, false);
+ node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, false);
left++;
if (right == state.right.length) {
- node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, false, right, true);
+ node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, false, right, true);
for (; left < state.left.length; ++left) {
- node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, true);
+ node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, true);
}
break;
}
- node = &FindOrInsert(*node, state.right.words[right], state, left, false, right + 1, false);
+ node = &FindOrInsert(context, *node, state.right.words[right], state, left, false, right + 1, false);
right++;
}
- node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true);
- got.first->second = CompleteTransition(*node, state, note, partial);
+ node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true);
+ CompleteTransition(context, *node, partial);
}
-VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) {
- VertexGenerator::Trie &next = node.extend[added];
- if (!next.under) {
- next.under = context_.NewVertexNode();
- lm::ngram::ChartState &writing = next.under->MutableState();
- writing = state;
- writing.left.full &= left_full && state.left.full;
- next.under->MutableRightFull() = right_full && state.left.full;
- writing.left.length = left;
- writing.right.length = right;
- node.under->AddExtend(next.under);
- }
- return next;
-}
+} // namespace
-Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial) {
- VertexNode &node = *starter.under;
- assert(node.State().left.full == state.left.full);
- assert(!node.End());
- Final *final = context_.NewFinal();
- final->Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End());
- node.SetEnd(final);
- return final;
+void VertexGenerator::FinishedSearch() {
+ Trie root;
+ root.under = &gen_.root_;
+ for (Existing::const_iterator i(existing_.begin()); i != existing_.end(); ++i) {
+ AddHypothesis(context_, root, i->second);
+ }
+ root.under->SortAndSet(context_, NULL);
}
} // namespace search
diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh
index 6b98da3e..60e86112 100644
--- a/klm/search/vertex_generator.hh
+++ b/klm/search/vertex_generator.hh
@@ -1,13 +1,11 @@
#ifndef SEARCH_VERTEX_GENERATOR__
#define SEARCH_VERTEX_GENERATOR__
-#include "search/note.hh"
+#include "search/edge.hh"
#include "search/vertex.hh"
#include <boost/unordered_map.hpp>
-#include <queue>
-
namespace lm {
namespace ngram {
class ChartState;
@@ -18,40 +16,29 @@ namespace search {
class ContextBase;
class Final;
-struct PartialEdge;
class VertexGenerator {
public:
VertexGenerator(ContextBase &context, Vertex &gen);
- void NewHypothesis(const PartialEdge &partial, Note note);
-
- void FinishedSearch() {
- root_.under->SortAndSet(context_, NULL);
+ void NewHypothesis(PartialEdge partial) {
+ const lm::ngram::ChartState &state = partial.CompletedState();
+ std::pair<Existing::iterator, bool> ret(existing_.insert(std::make_pair(hash_value(state), partial)));
+ if (!ret.second && ret.first->second < partial) {
+ ret.first->second = partial;
+ }
}
+ void FinishedSearch();
+
const Vertex &Generating() const { return gen_; }
private:
- // Parallel structure to VertexNode.
- struct Trie {
- Trie() : under(NULL) {}
-
- VertexNode *under;
- boost::unordered_map<uint64_t, Trie> extend;
- };
-
- Trie &FindOrInsert(Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full);
-
- Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial);
-
ContextBase &context_;
Vertex &gen_;
- Trie root_;
-
- typedef boost::unordered_map<uint64_t, Final*> Existing;
+ typedef boost::unordered_map<uint64_t, PartialEdge> Existing;
Existing existing_;
};
diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am
index 5ceccf2c..5306850f 100644
--- a/klm/util/Makefile.am
+++ b/klm/util/Makefile.am
@@ -26,6 +26,8 @@ libklm_util_a_SOURCES = \
file_piece.cc \
mmap.cc \
murmur_hash.cc \
+ pool.cc \
+ string_piece.cc \
usage.cc
AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I..
diff --git a/klm/util/ersatz_progress.hh b/klm/util/ersatz_progress.hh
index ff4d590f..9909736d 100644
--- a/klm/util/ersatz_progress.hh
+++ b/klm/util/ersatz_progress.hh
@@ -4,7 +4,7 @@
#include <iostream>
#include <string>
-#include <inttypes.h>
+#include <stdint.h>
// Ersatz version of boost::progress so core language model doesn't depend on
// boost. Also adds option to print nothing.
diff --git a/klm/util/exception.hh b/klm/util/exception.hh
index 83f99cd6..053a850b 100644
--- a/klm/util/exception.hh
+++ b/klm/util/exception.hh
@@ -6,7 +6,7 @@
#include <sstream>
#include <string>
-#include <inttypes.h>
+#include <stdint.h>
namespace util {
diff --git a/klm/util/pool.cc b/klm/util/pool.cc
new file mode 100644
index 00000000..2dffd06f
--- /dev/null
+++ b/klm/util/pool.cc
@@ -0,0 +1,35 @@
+#include "util/pool.hh"
+
+#include <stdlib.h>
+
+namespace util {
+
+Pool::Pool() {
+ current_ = NULL;
+ current_end_ = NULL;
+}
+
+Pool::~Pool() {
+ FreeAll();
+}
+
+void Pool::FreeAll() {
+ for (std::vector<void *>::const_iterator i(free_list_.begin()); i != free_list_.end(); ++i) {
+ free(*i);
+ }
+ free_list_.clear();
+ current_ = NULL;
+ current_end_ = NULL;
+}
+
+void *Pool::More(std::size_t size) {
+ std::size_t amount = std::max(static_cast<size_t>(32) << free_list_.size(), size);
+ uint8_t *ret = static_cast<uint8_t*>(malloc(amount));
+ if (!ret) throw std::bad_alloc();
+ free_list_.push_back(ret);
+ current_ = ret + size;
+ current_end_ = ret + amount;
+ return ret;
+}
+
+} // namespace util
diff --git a/klm/util/pool.hh b/klm/util/pool.hh
new file mode 100644
index 00000000..72f8a0c8
--- /dev/null
+++ b/klm/util/pool.hh
@@ -0,0 +1,45 @@
+// Very simple pool. It can only allocate memory. And all of the memory it
+// allocates must be freed at the same time.
+
+#ifndef UTIL_POOL__
+#define UTIL_POOL__
+
+#include <vector>
+
+#include <stdint.h>
+
+namespace util {
+
+class Pool {
+ public:
+ Pool();
+
+ ~Pool();
+
+ void *Allocate(std::size_t size) {
+ void *ret = current_;
+ current_ += size;
+ if (current_ < current_end_) {
+ return ret;
+ } else {
+ return More(size);
+ }
+ }
+
+ void FreeAll();
+
+ private:
+ void *More(std::size_t size);
+
+ std::vector<void *> free_list_;
+
+ uint8_t *current_, *current_end_;
+
+ // no copying
+ Pool(const Pool &);
+ Pool &operator=(const Pool &);
+};
+
+} // namespace util
+
+#endif // UTIL_POOL__
diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh
index 770faa7e..4a8aff35 100644
--- a/klm/util/probing_hash_table.hh
+++ b/klm/util/probing_hash_table.hh
@@ -8,7 +8,7 @@
#include <functional>
#include <assert.h>
-#include <inttypes.h>
+#include <stdint.h>
namespace util {
diff --git a/klm/util/string_piece.cc b/klm/util/string_piece.cc
new file mode 100644
index 00000000..b422cefc
--- /dev/null
+++ b/klm/util/string_piece.cc
@@ -0,0 +1,192 @@
+// Copyright 2004 The RE2 Authors. All Rights Reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in string_piece.hh.
+
+#include "util/string_piece.hh"
+
+#include <algorithm>
+
+#include <limits.h>
+
+#ifndef HAVE_ICU
+
+// File-local shorthand so the member definitions below can name their return
+// type without the StringPiece:: qualifier.
+typedef StringPiece::size_type size_type;
+
+// Overwrites *target with a copy of the length_ bytes starting at ptr_.
+void StringPiece::CopyToString(std::string* target) const {
+  target->assign(ptr_, ptr_ + length_);
+}
+
+// Returns the offset of the first occurrence of `s` that starts at or after
+// `pos`, or npos when there is none or `pos` lies beyond the end.
+size_type StringPiece::find(const StringPiece& s, size_type pos) const {
+  if (length_ < 0 || pos > static_cast<size_type>(length_)) {
+    return npos;
+  }
+
+  const char* const haystack_end = ptr_ + length_;
+  const char* const needle_end = s.ptr_ + s.length_;
+  const char* const hit = std::search(ptr_ + pos, haystack_end, s.ptr_, needle_end);
+  const size_type offset = hit - ptr_;
+  // On a miss std::search returns haystack_end, so a non-empty needle no
+  // longer fits between the offset and the end of the view.
+  if (offset + s.length_ <= length_) {
+    return offset;
+  }
+  return npos;
+}
+
+// Returns the offset of the first `c` at or after `pos`, or npos when the
+// view is empty, `pos` is out of range, or `c` does not occur.
+size_type StringPiece::find(char c, size_type pos) const {
+  const bool in_range = length_ > 0 && pos < static_cast<size_type>(length_);
+  if (!in_range) {
+    return npos;
+  }
+  const char* const end = ptr_ + length_;
+  const char* const hit = std::find(ptr_ + pos, end, c);
+  if (hit == end) {
+    return npos;
+  }
+  return hit - ptr_;
+}
+
+// Returns the offset of the last occurrence of `s` that starts at or before
+// `pos`, or npos when there is none. An empty `s` matches at min(length, pos).
+size_type StringPiece::rfind(const StringPiece& s, size_type pos) const {
+  if (length_ < s.length_) {
+    return npos;
+  }
+  const size_t haystack_len = length_;
+  if (s.length_ == 0) {
+    return std::min(haystack_len, pos);
+  }
+
+  // A match may start no later than min(haystack_len - s.length_, pos), so
+  // the searched range ends s.length_ bytes after that start.
+  const char* search_end = ptr_ + std::min(haystack_len - s.length_, pos) + s.length_;
+  const char* hit = std::find_end(ptr_, search_end, s.ptr_, s.ptr_ + s.length_);
+  if (hit == search_end) {
+    return npos;
+  }
+  return hit - ptr_;
+}
+
+// Returns the offset of the last `c` at or before `pos`, or npos when the
+// view is empty or `c` does not occur in that range.
+size_type StringPiece::rfind(char c, size_type pos) const {
+  if (length_ <= 0) return npos;
+  // Count down in size_type: an int index would truncate the starting offset
+  // for views longer than INT_MAX. The explicit i == 0 test ends the loop
+  // without relying on the index going negative.
+  size_type i = std::min(pos, static_cast<size_type>(length_ - 1));
+  for (;; --i) {
+    if (ptr_[i] == c) {
+      return i;
+    }
+    if (i == 0) {
+      break;
+    }
+  }
+  return npos;
+}
+
+// Marks, for every character in characters_wanted, the table slot indexed by
+// that character's unsigned char value. The find_.*_of methods below use the
+// table to test membership in constant time.
+// The argument `table' must be an array large enough to hold every possible
+// unsigned char value, so it should be declared as follows:
+//   bool table[UCHAR_MAX + 1]
+// Slots for characters that are not wanted are left untouched; callers are
+// expected to pass a table that starts out all false.
+static inline void BuildLookupTable(const StringPiece& characters_wanted,
+                                    bool* table) {
+  const char* cursor = characters_wanted.data();
+  const char* const end = cursor + characters_wanted.length();
+  for (; cursor != end; ++cursor) {
+    table[static_cast<unsigned char>(*cursor)] = true;
+  }
+}
+
+// Returns the offset of the first character at or after `pos` that occurs in
+// `s`, or npos when no such character exists.
+size_type StringPiece::find_first_of(const StringPiece& s,
+                                     size_type pos) const {
+  if (length_ == 0 || s.length_ == 0) {
+    return npos;
+  }
+
+  // A one-character set does not justify building the lookup table.
+  if (s.length_ == 1) {
+    return find_first_of(s.ptr_[0], pos);
+  }
+
+  bool wanted[UCHAR_MAX + 1] = { false };
+  BuildLookupTable(s, wanted);
+  size_type i = pos;
+  while (i < length_ && !wanted[static_cast<unsigned char>(ptr_[i])]) {
+    ++i;
+  }
+  return i < length_ ? i : npos;
+}
+
+// Returns the offset of the first character at or after `pos` that does NOT
+// occur in `s`, or npos when every remaining character is in `s` (or `pos`
+// is past the end).
+size_type StringPiece::find_first_not_of(const StringPiece& s,
+                                         size_type pos) const {
+  if (length_ == 0)
+    return npos;
+
+  // With an empty set every character qualifies, so the answer is the first
+  // position searched -- `pos` itself, not 0 -- provided it is inside the view.
+  if (s.length_ == 0)
+    return pos < length_ ? pos : npos;
+
+  // Avoid the cost of BuildLookupTable() for a single-character search.
+  if (s.length_ == 1)
+    return find_first_not_of(s.ptr_[0], pos);
+
+  bool lookup[UCHAR_MAX + 1] = { false };
+  BuildLookupTable(s, lookup);
+  for (size_type i = pos; i < length_; ++i) {
+    if (!lookup[static_cast<unsigned char>(ptr_[i])]) {
+      return i;
+    }
+  }
+  return npos;
+}
+
+// Returns the offset of the first character at or after `pos` that differs
+// from `c`, or npos when there is none.
+size_type StringPiece::find_first_not_of(char c, size_type pos) const {
+  if (length_ == 0) {
+    return npos;
+  }
+
+  // Skip the run of `c`; whatever stops the scan inside the view is the hit.
+  while (pos < length_ && ptr_[pos] == c) {
+    ++pos;
+  }
+  return pos < length_ ? pos : npos;
+}
+
+// Returns the offset of the last character at or before `pos` that occurs in
+// `s`, or npos when no such character exists.
+size_type StringPiece::find_last_of(const StringPiece& s, size_type pos) const {
+  if (length_ == 0 || s.length_ == 0) {
+    return npos;
+  }
+
+  // A one-character set does not justify building the lookup table.
+  if (s.length_ == 1) {
+    return find_last_of(s.ptr_[0], pos);
+  }
+
+  bool wanted[UCHAR_MAX + 1] = { false };
+  BuildLookupTable(s, wanted);
+  size_type i = std::min(pos, length_ - 1);
+  // Index 0 is examined before the loop stops, so the index never has to go
+  // below zero.
+  do {
+    if (wanted[static_cast<unsigned char>(ptr_[i])]) {
+      return i;
+    }
+  } while (i-- != 0);
+  return npos;
+}
+
+// Returns the offset of the last character at or before `pos` that does NOT
+// occur in `s`, or npos when there is none. With an empty `s` the answer is
+// the starting offset itself, min(pos, length - 1).
+size_type StringPiece::find_last_not_of(const StringPiece& s,
+                                        size_type pos) const {
+  if (length_ == 0) {
+    return npos;
+  }
+
+  size_type i = std::min(pos, length_ - 1);
+  if (s.length_ == 0) {
+    return i;
+  }
+
+  // A one-character set does not justify building the lookup table.
+  if (s.length_ == 1) {
+    return find_last_not_of(s.ptr_[0], pos);
+  }
+
+  bool excluded[UCHAR_MAX + 1] = { false };
+  BuildLookupTable(s, excluded);
+  // Index 0 is examined before the loop stops, so the index never has to go
+  // below zero.
+  do {
+    if (!excluded[static_cast<unsigned char>(ptr_[i])]) {
+      return i;
+    }
+  } while (i-- != 0);
+  return npos;
+}
+
+// Returns the offset of the last character at or before `pos` that differs
+// from `c`, or npos when there is none.
+size_type StringPiece::find_last_not_of(char c, size_type pos) const {
+  if (length_ == 0) {
+    return npos;
+  }
+
+  size_type i = std::min(pos, length_ - 1);
+  // Index 0 is examined before the loop stops, so the index never has to go
+  // below zero.
+  do {
+    if (ptr_[i] != c) {
+      return i;
+    }
+  } while (i-- != 0);
+  return npos;
+}
+
+// Returns the view of up to `n` bytes starting at `pos`. Both arguments are
+// clamped so the result never extends past the end of this view.
+StringPiece StringPiece::substr(size_type pos, size_type n) const {
+  const size_type start = std::min<size_type>(pos, length_);
+  const size_type count = std::min<size_type>(n, length_ - start);
+  return StringPiece(ptr_ + start, count);
+}
+
+// Value the search members in this file return when nothing matches.
+const size_type StringPiece::npos = size_type(-1);
+
+#endif // !HAVE_ICU
diff --git a/klm/util/tokenize_piece.hh b/klm/util/tokenize_piece.hh
index c7e1c863..4a7f5460 100644
--- a/klm/util/tokenize_piece.hh
+++ b/klm/util/tokenize_piece.hh
@@ -54,6 +54,18 @@ class AnyCharacter {
StringPiece chars_;
};
+// Finder that locates a delimiter by searching from the back of the input.
+// Presumably meant as the Find policy of TokenIter -- confirm against callers.
+class AnyCharacterLast {
+  public:
+    explicit AnyCharacterLast(const StringPiece &chars) : chars_(chars) {}
+
+    // Returns a one-byte piece positioned where std::find_end places chars_
+    // inside `in`. When there is no match the piece starts at
+    // in.data() + in.size().
+    // NOTE(review): std::find_end matches chars_ as one contiguous sequence,
+    // not as a set of alternative characters -- confirm this is intended.
+    StringPiece Find(const StringPiece &in) const {
+      const char *const in_begin = in.data();
+      const char *const in_end = in_begin + in.size();
+      const char *const pattern_begin = chars_.data();
+      const char *const pattern_end = pattern_begin + chars_.size();
+      return StringPiece(std::find_end(in_begin, in_end, pattern_begin, pattern_end), 1);
+    }
+
+  private:
+    StringPiece chars_;
+};
+
template <class Find, bool SkipEmpty = false> class TokenIter : public boost::iterator_facade<TokenIter<Find, SkipEmpty>, const StringPiece, boost::forward_traversal_tag> {
public:
TokenIter() {}