summaryrefslogtreecommitdiff
path: root/klm
diff options
context:
space:
mode:
Diffstat (limited to 'klm')
-rw-r--r--klm/alone/Jamfile4
-rw-r--r--klm/alone/assemble.cc76
-rw-r--r--klm/alone/assemble.hh21
-rw-r--r--klm/alone/graph.hh87
-rw-r--r--klm/alone/just_vocab.cc14
-rw-r--r--klm/alone/labeled_edge.hh30
-rw-r--r--klm/alone/main.cc85
-rw-r--r--klm/alone/read.cc118
-rw-r--r--klm/alone/read.hh29
-rw-r--r--klm/alone/threading.cc80
-rw-r--r--klm/alone/threading.hh129
-rw-r--r--klm/alone/vocab.cc19
-rw-r--r--klm/alone/vocab.hh34
-rw-r--r--klm/lm/fragment.cc37
-rw-r--r--klm/lm/partial.hh167
-rw-r--r--klm/lm/partial_test.cc199
-rw-r--r--klm/search/Jamfile5
-rw-r--r--klm/search/arity.hh8
-rw-r--r--klm/search/config.hh25
-rw-r--r--klm/search/context.hh65
-rw-r--r--klm/search/edge.hh31
-rw-r--r--klm/search/edge_generator.cc120
-rw-r--r--klm/search/edge_generator.hh58
-rw-r--r--klm/search/edge_queue.cc25
-rw-r--r--klm/search/edge_queue.hh73
-rw-r--r--klm/search/final.hh39
-rw-r--r--klm/search/note.hh12
-rw-r--r--klm/search/rule.cc43
-rw-r--r--klm/search/rule.hh20
-rw-r--r--klm/search/source.hh48
-rw-r--r--klm/search/types.hh18
-rw-r--r--klm/search/vertex.cc48
-rw-r--r--klm/search/vertex.hh158
-rw-r--r--klm/search/vertex_generator.cc83
-rw-r--r--klm/search/vertex_generator.hh59
-rw-r--r--klm/search/weights.cc71
-rw-r--r--klm/search/weights.hh52
-rw-r--r--klm/search/weights_test.cc38
-rw-r--r--klm/util/have.hh2
39 files changed, 2229 insertions, 1 deletions
diff --git a/klm/alone/Jamfile b/klm/alone/Jamfile
new file mode 100644
index 00000000..2cc90c05
--- /dev/null
+++ b/klm/alone/Jamfile
@@ -0,0 +1,4 @@
+lib standalone : assemble.cc read.cc threading.cc vocab.cc ../lm//kenlm ../util//kenutil ../search//search : <include>.. : : <include>.. <library>../search//search <library>../lm//kenlm ;
+
+exe decode : main.cc standalone main.cc : <threading>multi:<library>..//boost_thread ;
+exe just_vocab : just_vocab.cc standalone : <threading>multi:<library>..//boost_thread ;
diff --git a/klm/alone/assemble.cc b/klm/alone/assemble.cc
new file mode 100644
index 00000000..2ae72ce9
--- /dev/null
+++ b/klm/alone/assemble.cc
@@ -0,0 +1,76 @@
+#include "alone/assemble.hh"
+
+#include "alone/labeled_edge.hh"
+#include "search/final.hh"
+
+#include <iostream>
+
+namespace alone {
+
+std::ostream &operator<<(std::ostream &o, const search::Final &final) {
+ const std::vector<const std::string*> &words = static_cast<const LabeledEdge&>(final.From()).Words();
+ if (words.empty()) return o;
+ const search::Final *const *child = final.Children().data();
+ std::vector<const std::string*>::const_iterator i(words.begin());
+ for (; i != words.end() - 1; ++i) {
+ if (*i) {
+ o << **i << ' ';
+ } else {
+ o << **child << ' ';
+ ++child;
+ }
+ }
+
+ if (*i) {
+ if (**i != "</s>") {
+ o << **i;
+ }
+ } else {
+ o << **child;
+ }
+
+ return o;
+}
+
+namespace {
+
+void MakeIndent(std::ostream &o, const char *indent_str, unsigned int level) {
+ for (unsigned int i = 0; i < level; ++i)
+ o << indent_str;
+}
+
+void DetailedFinalInternal(std::ostream &o, const search::Final &final, const char *indent_str, unsigned int indent) {
+ o << "(\n";
+ MakeIndent(o, indent_str, indent);
+ const std::vector<const std::string*> &words = static_cast<const LabeledEdge&>(final.From()).Words();
+ const search::Final *const *child = final.Children().data();
+ for (std::vector<const std::string*>::const_iterator i(words.begin()); i != words.end(); ++i) {
+ if (*i) {
+ o << **i;
+ if (i == words.end() - 1) {
+ o << '\n';
+ MakeIndent(o, indent_str, indent);
+ } else {
+ o << ' ';
+ }
+ } else {
+ // One extra indent from the line we're currently on.
+ o << indent_str;
+ DetailedFinalInternal(o, **child, indent_str, indent + 1);
+ for (unsigned int i = 0; i < indent; ++i) o << indent_str;
+ ++child;
+ }
+ }
+ o << ")=" << final.Bound() << '\n';
+}
+} // namespace
+
+void DetailedFinal(std::ostream &o, const search::Final &final, const char *indent_str) {
+ DetailedFinalInternal(o, final, indent_str, 0);
+}
+
+void PrintFinal(const search::Final &final) {
+ std::cout << final << std::endl;
+}
+
+} // namespace alone
diff --git a/klm/alone/assemble.hh b/klm/alone/assemble.hh
new file mode 100644
index 00000000..e6b0ad5c
--- /dev/null
+++ b/klm/alone/assemble.hh
@@ -0,0 +1,21 @@
+#ifndef ALONE_ASSEMBLE__
+#define ALONE_ASSEMBLE__
+
+#include <iosfwd>
+
+namespace search {
+class Final;
+} // namespace search
+
+namespace alone {
+
+std::ostream &operator<<(std::ostream &o, const search::Final &final);
+
+void DetailedFinal(std::ostream &o, const search::Final &final, const char *indent_str = " ");
+
+// This isn't called anywhere but makes it easy to print from gdb.
+void PrintFinal(const search::Final &final);
+
+} // namespace alone
+
+#endif // ALONE_ASSEMBLE__
diff --git a/klm/alone/graph.hh b/klm/alone/graph.hh
new file mode 100644
index 00000000..788352c9
--- /dev/null
+++ b/klm/alone/graph.hh
@@ -0,0 +1,87 @@
+#ifndef ALONE_GRAPH__
+#define ALONE_GRAPH__
+
+#include "alone/labeled_edge.hh"
+#include "search/rule.hh"
+#include "search/types.hh"
+#include "search/vertex.hh"
+#include "util/exception.hh"
+
+#include <boost/noncopyable.hpp>
+#include <boost/pool/object_pool.hpp>
+#include <boost/scoped_array.hpp>
+
+namespace alone {
+
+template <class T> class FixedAllocator : boost::noncopyable {
+ public:
+ FixedAllocator() : current_(NULL), end_(NULL) {}
+
+ void Init(std::size_t count) {
+ assert(!current_);
+ array_.reset(new T[count]);
+ current_ = array_.get();
+ end_ = current_ + count;
+ }
+
+ T &operator[](std::size_t idx) {
+ return array_.get()[idx];
+ }
+
+ T *New() {
+ T *ret = current_++;
+ UTIL_THROW_IF(ret >= end_, util::Exception, "Allocating past end");
+ return ret;
+ }
+
+ std::size_t Size() const {
+ return end_ - array_.get();
+ }
+
+ private:
+ boost::scoped_array<T> array_;
+ T *current_, *end_;
+};
+
+class Graph : boost::noncopyable {
+ public:
+ typedef LabeledEdge Edge;
+ typedef search::Vertex Vertex;
+
+ Graph() {}
+
+ void SetCounts(std::size_t vertices, std::size_t edges) {
+ vertices_.Init(vertices);
+ edges_.Init(edges);
+ }
+
+ Vertex *NewVertex() {
+ return vertices_.New();
+ }
+
+ std::size_t VertexSize() const { return vertices_.Size(); }
+
+ Vertex &MutableVertex(std::size_t index) {
+ return vertices_[index];
+ }
+
+ Edge *NewEdge() {
+ return edges_.New();
+ }
+
+ std::size_t EdgeSize() const { return edges_.Size(); }
+
+ void SetRoot(Vertex *root) { root_ = root; }
+
+ Vertex &Root() { return *root_; }
+
+ private:
+ FixedAllocator<Vertex> vertices_;
+ FixedAllocator<Edge> edges_;
+
+ Vertex *root_;
+};
+
+} // namespace alone
+
+#endif // ALONE_GRAPH__
diff --git a/klm/alone/just_vocab.cc b/klm/alone/just_vocab.cc
new file mode 100644
index 00000000..35aea5ed
--- /dev/null
+++ b/klm/alone/just_vocab.cc
@@ -0,0 +1,14 @@
+#include "alone/read.hh"
+#include "util/file_piece.hh"
+
+#include <iostream>
+
+int main() {
+ util::FilePiece f(0, "stdin", &std::cerr);
+ while (true) {
+ try {
+ alone::JustVocab(f, std::cout);
+ } catch (const util::EndOfFileException &e) { break; }
+ std::cout << '\n';
+ }
+}
diff --git a/klm/alone/labeled_edge.hh b/klm/alone/labeled_edge.hh
new file mode 100644
index 00000000..94d8cbdf
--- /dev/null
+++ b/klm/alone/labeled_edge.hh
@@ -0,0 +1,30 @@
+#ifndef ALONE_LABELED_EDGE__
+#define ALONE_LABELED_EDGE__
+
+#include "search/edge.hh"
+
+#include <string>
+#include <vector>
+
+namespace alone {
+
+class LabeledEdge : public search::Edge {
+ public:
+ LabeledEdge() {}
+
+ void AppendWord(const std::string *word) {
+ words_.push_back(word);
+ }
+
+ const std::vector<const std::string *> &Words() const {
+ return words_;
+ }
+
+ private:
+ // NULL for non-terminals.
+ std::vector<const std::string*> words_;
+};
+
+} // namespace alone
+
+#endif // ALONE_LABELED_EDGE__
diff --git a/klm/alone/main.cc b/klm/alone/main.cc
new file mode 100644
index 00000000..e09ab01d
--- /dev/null
+++ b/klm/alone/main.cc
@@ -0,0 +1,85 @@
+#include "alone/threading.hh"
+#include "search/config.hh"
+#include "search/context.hh"
+#include "util/exception.hh"
+#include "util/file_piece.hh"
+#include "util/usage.hh"
+
+#include <boost/lexical_cast.hpp>
+
+#include <iostream>
+#include <memory>
+
+namespace alone {
+
+template <class Control> void ReadLoop(const std::string &graph_prefix, Control &control) {
+ for (unsigned int sentence = 0; ; ++sentence) {
+ std::stringstream name;
+ name << graph_prefix << '/' << sentence;
+ std::auto_ptr<util::FilePiece> file;
+ try {
+ file.reset(new util::FilePiece(name.str().c_str()));
+ } catch (const util::ErrnoException &e) {
+ if (e.Error() == ENOENT) return;
+ throw;
+ }
+ control.Add(file.release());
+ }
+}
+
+template <class Model> void RunWithModelType(const char *graph_prefix, const char *model_file, StringPiece weight_str, unsigned int pop_limit, unsigned int threads) {
+ Model model(model_file);
+ search::Weights weights(weight_str);
+ search::Config config(weights, pop_limit);
+
+ if (threads > 1) {
+#ifdef WITH_THREADS
+ Controller<Model> controller(config, model, threads, std::cout);
+ ReadLoop(graph_prefix, controller);
+#else
+ UTIL_THROW(util::Exception, "Threading support not compiled in.");
+#endif
+ } else {
+ InThread<Model> controller(config, model, std::cout);
+ ReadLoop(graph_prefix, controller);
+ }
+}
+
+void Run(const char *graph_prefix, const char *lm_name, StringPiece weight_str, unsigned int pop_limit, unsigned int threads) {
+ lm::ngram::ModelType model_type;
+ if (!lm::ngram::RecognizeBinary(lm_name, model_type)) model_type = lm::ngram::PROBING;
+ switch (model_type) {
+ case lm::ngram::PROBING:
+ RunWithModelType<lm::ngram::ProbingModel>(graph_prefix, lm_name, weight_str, pop_limit, threads);
+ break;
+ case lm::ngram::REST_PROBING:
+ RunWithModelType<lm::ngram::RestProbingModel>(graph_prefix, lm_name, weight_str, pop_limit, threads);
+ break;
+ default:
+ UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet.");
+ }
+}
+
+} // namespace alone
+
+int main(int argc, char *argv[]) {
+ if (argc < 5 || argc > 6) {
+ std::cerr << argv[0] << " graph_prefix lm \"weights\" pop [threads]" << std::endl;
+ return 1;
+ }
+
+#ifdef WITH_THREADS
+ unsigned thread_count = boost::thread::hardware_concurrency();
+#else
+ unsigned thread_count = 1;
+#endif
+ if (argc == 6) {
+ thread_count = boost::lexical_cast<unsigned>(argv[5]);
+ UTIL_THROW_IF(!thread_count, util::Exception, "Thread count 0");
+ }
+ UTIL_THROW_IF(!thread_count, util::Exception, "Boost doesn't know how many threads there are. Pass it on the command line.");
+ alone::Run(argv[1], argv[2], argv[3], boost::lexical_cast<unsigned int>(argv[4]), thread_count);
+
+ util::PrintUsage(std::cerr);
+ return 0;
+}
diff --git a/klm/alone/read.cc b/klm/alone/read.cc
new file mode 100644
index 00000000..0b20be35
--- /dev/null
+++ b/klm/alone/read.cc
@@ -0,0 +1,118 @@
+#include "alone/read.hh"
+
+#include "alone/graph.hh"
+#include "alone/vocab.hh"
+#include "search/arity.hh"
+#include "search/context.hh"
+#include "search/weights.hh"
+#include "util/file_piece.hh"
+
+#include <boost/unordered_set.hpp>
+#include <boost/unordered_map.hpp>
+
+#include <cstdlib>
+
+namespace alone {
+
+namespace {
+
+template <class Model> Graph::Edge &ReadEdge(search::Context<Model> &context, util::FilePiece &from, Graph &to, Vocab &vocab, bool final) {
+ Graph::Edge *ret = to.NewEdge();
+
+ StringPiece got;
+
+ std::vector<lm::WordIndex> words;
+ unsigned long int terminals = 0;
+ while ("|||" != (got = from.ReadDelimited())) {
+ if ('[' == *got.data() && ']' == got.data()[got.size() - 1]) {
+ // non-terminal
+ char *end_ptr;
+ unsigned long int child = std::strtoul(got.data() + 1, &end_ptr, 10);
+ UTIL_THROW_IF(end_ptr != got.data() + got.size() - 1, FormatException, "Bad non-terminal" << got);
+ UTIL_THROW_IF(child >= to.VertexSize(), FormatException, "Reference to vertex " << child << " but we only have " << to.VertexSize() << " vertices. Is the file in bottom-up format?");
+ ret->Add(to.MutableVertex(child));
+ words.push_back(lm::kMaxWordIndex);
+ ret->AppendWord(NULL);
+ } else {
+ const std::pair<const std::string, lm::WordIndex> &found = vocab.FindOrAdd(got);
+ words.push_back(found.second);
+ ret->AppendWord(&found.first);
+ ++terminals;
+ }
+ }
+ if (final) {
+ // This is not counted for the word penalty.
+ words.push_back(vocab.EndSentence().second);
+ ret->AppendWord(&vocab.EndSentence().first);
+ }
+ // Hard-coded word penalty.
+ float additive = context.GetWeights().DotNoLM(from.ReadLine()) - context.GetWeights().WordPenalty() * static_cast<float>(terminals) / M_LN10;
+ ret->InitRule().Init(context, additive, words, final);
+ unsigned int arity = ret->GetRule().Arity();
+ UTIL_THROW_IF(arity > search::kMaxArity, util::Exception, "Edit search/arity.hh and increase " << search::kMaxArity << " to at least " << arity);
+ return *ret;
+}
+
+} // namespace
+
+// TODO: refactor
+void JustVocab(util::FilePiece &from, std::ostream &out) {
+ boost::unordered_set<std::string> seen;
+ unsigned long int vertices = from.ReadULong();
+ from.ReadULong(); // edges
+ UTIL_THROW_IF(vertices == 0, FormatException, "Vertex count is zero");
+ UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected newline after counts");
+ std::string temp;
+ for (unsigned long int i = 0; i < vertices; ++i) {
+ unsigned long int edge_count = from.ReadULong();
+ UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected after edge count");
+ for (unsigned long int e = 0; e < edge_count; ++e) {
+ StringPiece got;
+ while ("|||" != (got = from.ReadDelimited())) {
+ if ('[' == *got.data() && ']' == got.data()[got.size() - 1]) continue;
+ temp.assign(got.data(), got.size());
+ if (seen.insert(temp).second) out << temp << ' ';
+ }
+ from.ReadLine(); // weights
+ }
+ }
+ // Eat sentence
+ from.ReadLine();
+}
+
+template <class Model> bool ReadCDec(search::Context<Model> &context, util::FilePiece &from, Graph &to, Vocab &vocab) {
+ unsigned long int vertices;
+ try {
+ vertices = from.ReadULong();
+ } catch (const util::EndOfFileException &e) { return false; }
+ unsigned long int edges = from.ReadULong();
+ UTIL_THROW_IF(vertices < 2, FormatException, "Vertex count is " << vertices);
+ UTIL_THROW_IF(edges == 0, FormatException, "Edge count is " << edges);
+ --vertices;
+ --edges;
+ UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected newline after counts");
+ to.SetCounts(vertices, edges);
+ Graph::Vertex *vertex;
+ for (unsigned long int i = 0; ; ++i) {
+ vertex = to.NewVertex();
+ unsigned long int edge_count = from.ReadULong();
+ bool root = (i == vertices - 1);
+ UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected after edge count");
+ for (unsigned long int e = 0; e < edge_count; ++e) {
+ vertex->Add(ReadEdge(context, from, to, vocab, root));
+ }
+ vertex->FinishedAdding();
+ if (root) break;
+ }
+ to.SetRoot(vertex);
+ StringPiece str = from.ReadLine();
+ UTIL_THROW_IF("1" != str, FormatException, "Expected one edge to root");
+ // The edge
+ from.ReadLine();
+ return true;
+}
+
+template bool ReadCDec(search::Context<lm::ngram::ProbingModel> &context, util::FilePiece &from, Graph &to, Vocab &vocab);
+template bool ReadCDec(search::Context<lm::ngram::RestProbingModel> &context, util::FilePiece &from, Graph &to, Vocab &vocab);
+
+} // namespace alone
diff --git a/klm/alone/read.hh b/klm/alone/read.hh
new file mode 100644
index 00000000..10769a86
--- /dev/null
+++ b/klm/alone/read.hh
@@ -0,0 +1,29 @@
+#ifndef ALONE_READ__
+#define ALONE_READ__
+
+#include "util/exception.hh"
+
+#include <iosfwd>
+
+namespace util { class FilePiece; }
+
+namespace search { template <class Model> class Context; }
+
+namespace alone {
+
+class Graph;
+class Vocab;
+
+class FormatException : public util::Exception {
+ public:
+ FormatException() {}
+ ~FormatException() throw() {}
+};
+
+void JustVocab(util::FilePiece &from, std::ostream &to);
+
+template <class Model> bool ReadCDec(search::Context<Model> &context, util::FilePiece &from, Graph &to, Vocab &vocab);
+
+} // namespace alone
+
+#endif // ALONE_READ__
diff --git a/klm/alone/threading.cc b/klm/alone/threading.cc
new file mode 100644
index 00000000..475386b6
--- /dev/null
+++ b/klm/alone/threading.cc
@@ -0,0 +1,80 @@
+#include "alone/threading.hh"
+
+#include "alone/assemble.hh"
+#include "alone/graph.hh"
+#include "alone/read.hh"
+#include "alone/vocab.hh"
+#include "lm/model.hh"
+#include "search/context.hh"
+#include "search/vertex_generator.hh"
+
+#include <boost/ref.hpp>
+#include <boost/scoped_ptr.hpp>
+#include <boost/utility/in_place_factory.hpp>
+
+#include <sstream>
+
+namespace alone {
+template <class Model> void Decode(const search::Config &config, const Model &model, util::FilePiece *in_ptr, std::ostream &out) {
+ search::Context<Model> context(config, model);
+ Graph graph;
+ Vocab vocab(model.GetVocabulary());
+ {
+ boost::scoped_ptr<util::FilePiece> in(in_ptr);
+ ReadCDec(context, *in, graph, vocab);
+ }
+
+ for (std::size_t i = 0; i < graph.VertexSize(); ++i) {
+ search::VertexGenerator(context, graph.MutableVertex(i));
+ }
+ search::PartialVertex top = graph.Root().RootPartial();
+ if (top.Empty()) {
+ out << "NO PATH FOUND";
+ } else {
+ search::PartialVertex continuation;
+ while (!top.Complete()) {
+ top.Split(continuation);
+ top = continuation;
+ }
+ out << top.End() << " ||| " << top.End().Bound() << std::endl;
+ }
+}
+
+template void Decode(const search::Config &config, const lm::ngram::ProbingModel &model, util::FilePiece *in_ptr, std::ostream &out);
+template void Decode(const search::Config &config, const lm::ngram::RestProbingModel &model, util::FilePiece *in_ptr, std::ostream &out);
+
+#ifdef WITH_THREADS
+template <class Model> void DecodeHandler<Model>::operator()(Input message) {
+ std::stringstream assemble;
+ Decode(config_, model_, message.file, assemble);
+ Produce(message.sentence_id, assemble.str());
+}
+
+template <class Model> void DecodeHandler<Model>::Produce(unsigned int sentence_id, const std::string &str) {
+ Output out;
+ out.sentence_id = sentence_id;
+ out.str = new std::string(str);
+ out_.Produce(out);
+}
+
+void PrintHandler::operator()(Output message) {
+ unsigned int relative = message.sentence_id - done_;
+ if (waiting_.size() <= relative) waiting_.resize(relative + 1);
+ waiting_[relative] = message.str;
+ for (std::string *lead; !waiting_.empty() && (lead = waiting_[0]); waiting_.pop_front(), ++done_) {
+ out_ << *lead;
+ delete lead;
+ }
+}
+
+template <class Model> Controller<Model>::Controller(const search::Config &config, const Model &model, size_t decode_workers, std::ostream &to) :
+ sentence_id_(0),
+ printer_(decode_workers, 1, boost::ref(to), Output::Poison()),
+ decoder_(3, decode_workers, boost::in_place(boost::ref(config), boost::ref(model), boost::ref(printer_.In())), Input::Poison()) {}
+
+template class Controller<lm::ngram::RestProbingModel>;
+template class Controller<lm::ngram::ProbingModel>;
+
+#endif
+
+} // namespace alone
diff --git a/klm/alone/threading.hh b/klm/alone/threading.hh
new file mode 100644
index 00000000..0ab0f739
--- /dev/null
+++ b/klm/alone/threading.hh
@@ -0,0 +1,129 @@
+#ifndef ALONE_THREADING__
+#define ALONE_THREADING__
+
+#ifdef WITH_THREADS
+#include "util/pcqueue.hh"
+#include "util/pool.hh"
+#endif
+
+#include <iosfwd>
+#include <queue>
+#include <string>
+
+namespace util {
+class FilePiece;
+} // namespace util
+
+namespace search {
+class Config;
+template <class Model> class Context;
+} // namespace search
+
+namespace alone {
+
+template <class Model> void Decode(const search::Config &config, const Model &model, util::FilePiece *in_ptr, std::ostream &out);
+
+class Graph;
+
+#ifdef WITH_THREADS
+struct SentenceID {
+ unsigned int sentence_id;
+ bool operator==(const SentenceID &other) const {
+ return sentence_id == other.sentence_id;
+ }
+};
+
+struct Input : public SentenceID {
+ util::FilePiece *file;
+ static Input Poison() {
+ Input ret;
+ ret.sentence_id = static_cast<unsigned int>(-1);
+ ret.file = NULL;
+ return ret;
+ }
+};
+
+struct Output : public SentenceID {
+ std::string *str;
+ static Output Poison() {
+ Output ret;
+ ret.sentence_id = static_cast<unsigned int>(-1);
+ ret.str = NULL;
+ return ret;
+ }
+};
+
+template <class Model> class DecodeHandler {
+ public:
+ typedef Input Request;
+
+ DecodeHandler(const search::Config &config, const Model &model, util::PCQueue<Output> &out) : config_(config), model_(model), out_(out) {}
+
+ void operator()(Input message);
+
+ private:
+ void Produce(unsigned int sentence_id, const std::string &str);
+
+ const search::Config &config_;
+
+ const Model &model_;
+
+ util::PCQueue<Output> &out_;
+};
+
+class PrintHandler {
+ public:
+ typedef Output Request;
+
+ explicit PrintHandler(std::ostream &o) : out_(o), done_(0) {}
+
+ void operator()(Output message);
+
+ private:
+ std::ostream &out_;
+ std::deque<std::string*> waiting_;
+ unsigned int done_;
+};
+
+template <class Model> class Controller {
+ public:
+ // This config must remain valid.
+ explicit Controller(const search::Config &config, const Model &model, size_t decode_workers, std::ostream &to);
+
+ // Takes ownership of in.
+ void Add(util::FilePiece *in) {
+ Input input;
+ input.sentence_id = sentence_id_++;
+ input.file = in;
+ decoder_.Produce(input);
+ }
+
+ private:
+ unsigned int sentence_id_;
+
+ util::Pool<PrintHandler> printer_;
+
+ util::Pool<DecodeHandler<Model> > decoder_;
+};
+#endif
+
+// Same API as controller.
+template <class Model> class InThread {
+ public:
+ InThread(const search::Config &config, const Model &model, std::ostream &to) : config_(config), model_(model), to_(to) {}
+
+ // Takes ownership of in.
+ void Add(util::FilePiece *in) {
+ Decode(config_, model_, in, to_);
+ }
+
+ private:
+ const search::Config &config_;
+
+ const Model &model_;
+
+ std::ostream &to_;
+};
+
+} // namespace alone
+#endif // ALONE_THREADING__
diff --git a/klm/alone/vocab.cc b/klm/alone/vocab.cc
new file mode 100644
index 00000000..ffe55301
--- /dev/null
+++ b/klm/alone/vocab.cc
@@ -0,0 +1,19 @@
+#include "alone/vocab.hh"
+
+#include "lm/virtual_interface.hh"
+#include "util/string_piece.hh"
+
+namespace alone {
+
+Vocab::Vocab(const lm::base::Vocabulary &backing) : backing_(backing), end_sentence_(FindOrAdd("</s>")) {}
+
+const std::pair<const std::string, lm::WordIndex> &Vocab::FindOrAdd(const StringPiece &str) {
+ Map::const_iterator i(FindStringPiece(map_, str));
+ if (i != map_.end()) return *i;
+ std::pair<std::string, lm::WordIndex> to_ins;
+ to_ins.first.assign(str.data(), str.size());
+ to_ins.second = backing_.Index(str);
+ return *map_.insert(to_ins).first;
+}
+
+} // namespace alone
diff --git a/klm/alone/vocab.hh b/klm/alone/vocab.hh
new file mode 100644
index 00000000..3ac0f542
--- /dev/null
+++ b/klm/alone/vocab.hh
@@ -0,0 +1,34 @@
+#ifndef ALONE_VOCAB__
+#define ALONE_VOCAB__
+
+#include "lm/word_index.hh"
+#include "util/string_piece.hh"
+
+#include <boost/functional/hash/hash.hpp>
+#include <boost/unordered_map.hpp>
+
+#include <string>
+
+namespace lm { namespace base { class Vocabulary; } }
+
+namespace alone {
+
+class Vocab {
+ public:
+ explicit Vocab(const lm::base::Vocabulary &backing);
+
+ const std::pair<const std::string, lm::WordIndex> &FindOrAdd(const StringPiece &str);
+
+ const std::pair<const std::string, lm::WordIndex> &EndSentence() const { return end_sentence_; }
+
+ private:
+ typedef boost::unordered_map<std::string, lm::WordIndex> Map;
+ Map map_;
+
+ const lm::base::Vocabulary &backing_;
+
+ const std::pair<const std::string, lm::WordIndex> &end_sentence_;
+};
+
+} // namespace alone
+#endif // ALONE_VCOAB__
diff --git a/klm/lm/fragment.cc b/klm/lm/fragment.cc
new file mode 100644
index 00000000..0267cd4e
--- /dev/null
+++ b/klm/lm/fragment.cc
@@ -0,0 +1,37 @@
+#include "lm/binary_format.hh"
+#include "lm/model.hh"
+#include "lm/left.hh"
+#include "util/tokenize_piece.hh"
+
+template <class Model> void Query(const char *name) {
+ Model model(name);
+ std::string line;
+ lm::ngram::ChartState ignored;
+ while (getline(std::cin, line)) {
+ lm::ngram::RuleScore<Model> scorer(model, ignored);
+ for (util::TokenIter<util::SingleCharacter, true> i(line, ' '); i; ++i) {
+ scorer.Terminal(model.GetVocabulary().Index(*i));
+ }
+ std::cout << scorer.Finish() << '\n';
+ }
+}
+
+int main(int argc, char *argv[]) {
+ if (argc != 2) {
+ std::cerr << "Expected model file name." << std::endl;
+ return 1;
+ }
+ const char *name = argv[1];
+ lm::ngram::ModelType model_type = lm::ngram::PROBING;
+ lm::ngram::RecognizeBinary(name, model_type);
+ switch (model_type) {
+ case lm::ngram::PROBING:
+ Query<lm::ngram::ProbingModel>(name);
+ break;
+ case lm::ngram::REST_PROBING:
+ Query<lm::ngram::RestProbingModel>(name);
+ break;
+ default:
+ std::cerr << "Model type not supported yet." << std::endl;
+ }
+}
diff --git a/klm/lm/partial.hh b/klm/lm/partial.hh
new file mode 100644
index 00000000..1dede359
--- /dev/null
+++ b/klm/lm/partial.hh
@@ -0,0 +1,167 @@
+#ifndef LM_PARTIAL__
+#define LM_PARTIAL__
+
+#include "lm/return.hh"
+#include "lm/state.hh"
+
+#include <algorithm>
+
+#include <assert.h>
+
+namespace lm {
+namespace ngram {
+
+struct ExtendReturn {
+ float adjust;
+ bool make_full;
+ unsigned char next_use;
+};
+
+template <class Model> ExtendReturn ExtendLoop(
+ const Model &model,
+ unsigned char seen, const WordIndex *add_rbegin, const WordIndex *add_rend, const float *backoff_start,
+ const uint64_t *pointers, const uint64_t *pointers_end,
+ uint64_t *&pointers_write,
+ float *backoff_write) {
+ unsigned char add_length = add_rend - add_rbegin;
+
+ float backoff_buf[2][KENLM_MAX_ORDER - 1];
+ float *backoff_in = backoff_buf[0], *backoff_out = backoff_buf[1];
+ std::copy(backoff_start, backoff_start + add_length, backoff_in);
+
+ ExtendReturn value;
+ value.make_full = false;
+ value.adjust = 0.0;
+ value.next_use = add_length;
+
+ unsigned char i = 0;
+ unsigned char length = pointers_end - pointers;
+ // pointers_write is NULL means that the existing left state is full, so we should use completed probabilities.
+ if (pointers_write) {
+ // Using full context, writing to new left state.
+ for (; i < length; ++i) {
+ FullScoreReturn ret(model.ExtendLeft(
+ add_rbegin, add_rbegin + value.next_use,
+ backoff_in,
+ pointers[i], i + seen + 1,
+ backoff_out,
+ value.next_use));
+ std::swap(backoff_in, backoff_out);
+ if (ret.independent_left) {
+ value.adjust += ret.prob;
+ value.make_full = true;
+ ++i;
+ break;
+ }
+ value.adjust += ret.rest;
+ *pointers_write++ = ret.extend_left;
+ if (value.next_use != add_length) {
+ value.make_full = true;
+ ++i;
+ break;
+ }
+ }
+ }
+ // Using some of the new context.
+ for (; i < length && value.next_use; ++i) {
+ FullScoreReturn ret(model.ExtendLeft(
+ add_rbegin, add_rbegin + value.next_use,
+ backoff_in,
+ pointers[i], i + seen + 1,
+ backoff_out,
+ value.next_use));
+ std::swap(backoff_in, backoff_out);
+ value.adjust += ret.prob;
+ }
+ float unrest = model.UnRest(pointers + i, pointers_end, i + seen + 1);
+ // Using none of the new context.
+ value.adjust += unrest;
+
+ std::copy(backoff_in, backoff_in + value.next_use, backoff_write);
+ return value;
+}
+
+template <class Model> float RevealBefore(const Model &model, const Right &reveal, const unsigned char seen, bool reveal_full, Left &left, Right &right) {
+ assert(seen < reveal.length || reveal_full);
+ uint64_t *pointers_write = reveal_full ? NULL : left.pointers;
+ float backoff_buffer[KENLM_MAX_ORDER - 1];
+ ExtendReturn value(ExtendLoop(
+ model,
+ seen, reveal.words + seen, reveal.words + reveal.length, reveal.backoff + seen,
+ left.pointers, left.pointers + left.length,
+ pointers_write,
+ left.full ? backoff_buffer : (right.backoff + right.length)));
+ if (reveal_full) {
+ left.length = 0;
+ value.make_full = true;
+ } else {
+ left.length = pointers_write - left.pointers;
+ value.make_full |= (left.length == model.Order() - 1);
+ }
+ if (left.full) {
+ for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += backoff_buffer[i];
+ } else {
+ // If left wasn't full when it came in, put words into right state.
+ std::copy(reveal.words + seen, reveal.words + seen + value.next_use, right.words + right.length);
+ right.length += value.next_use;
+ left.full = value.make_full || (right.length == model.Order() - 1);
+ }
+ return value.adjust;
+}
+
+template <class Model> float RevealAfter(const Model &model, Left &left, Right &right, const Left &reveal, unsigned char seen) {
+ assert(seen < reveal.length || reveal.full);
+ uint64_t *pointers_write = left.full ? NULL : (left.pointers + left.length);
+ ExtendReturn value(ExtendLoop(
+ model,
+ seen, right.words, right.words + right.length, right.backoff,
+ reveal.pointers + seen, reveal.pointers + reveal.length,
+ pointers_write,
+ right.backoff));
+ if (reveal.full) {
+ for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += right.backoff[i];
+ right.length = 0;
+ value.make_full = true;
+ } else {
+ right.length = value.next_use;
+ value.make_full |= (right.length == model.Order() - 1);
+ }
+ if (!left.full) {
+ left.length = pointers_write - left.pointers;
+ left.full = value.make_full || (left.length == model.Order() - 1);
+ }
+ return value.adjust;
+}
+
+template <class Model> float Subsume(const Model &model, Left &first_left, const Right &first_right, const Left &second_left, Right &second_right, const unsigned int between_length) {
+ assert(first_right.length < KENLM_MAX_ORDER);
+ assert(second_left.length < KENLM_MAX_ORDER);
+ assert(between_length < KENLM_MAX_ORDER - 1);
+ uint64_t *pointers_write = first_left.full ? NULL : (first_left.pointers + first_left.length);
+ float backoff_buffer[KENLM_MAX_ORDER - 1];
+ ExtendReturn value(ExtendLoop(
+ model,
+ between_length, first_right.words, first_right.words + first_right.length, first_right.backoff,
+ second_left.pointers, second_left.pointers + second_left.length,
+ pointers_write,
+ second_left.full ? backoff_buffer : (second_right.backoff + second_right.length)));
+ if (second_left.full) {
+ for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += backoff_buffer[i];
+ } else {
+ std::copy(first_right.words, first_right.words + value.next_use, second_right.words + second_right.length);
+ second_right.length += value.next_use;
+ value.make_full |= (second_right.length == model.Order() - 1);
+ }
+ if (!first_left.full) {
+ first_left.length = pointers_write - first_left.pointers;
+ first_left.full = value.make_full || second_left.full || (first_left.length == model.Order() - 1);
+ }
+ assert(first_left.length < KENLM_MAX_ORDER);
+ assert(second_right.length < KENLM_MAX_ORDER);
+ return value.adjust;
+}
+
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_PARTIAL__
diff --git a/klm/lm/partial_test.cc b/klm/lm/partial_test.cc
new file mode 100644
index 00000000..8d309c85
--- /dev/null
+++ b/klm/lm/partial_test.cc
@@ -0,0 +1,199 @@
+#include "lm/partial.hh"
+
+#include "lm/left.hh"
+#include "lm/model.hh"
+#include "util/tokenize_piece.hh"
+
+#define BOOST_TEST_MODULE PartialTest
+#include <boost/test/unit_test.hpp>
+#include <boost/test/floating_point_comparison.hpp>
+
+namespace lm {
+namespace ngram {
+namespace {
+
+const char *TestLocation() {
+ if (boost::unit_test::framework::master_test_suite().argc < 2) {
+ return "test.arpa";
+ }
+ return boost::unit_test::framework::master_test_suite().argv[1];
+}
+
+Config SilentConfig() {
+ Config config;
+ config.arpa_complain = Config::NONE;
+ config.messages = NULL;
+ return config;
+}
+
+struct ModelFixture {
+ ModelFixture() : m(TestLocation(), SilentConfig()) {}
+
+ RestProbingModel m;
+};
+
+BOOST_FIXTURE_TEST_SUITE(suite, ModelFixture)
+
+BOOST_AUTO_TEST_CASE(SimpleBefore) {
+ Left left;
+ left.full = false;
+ left.length = 0;
+ Right right;
+ right.length = 0;
+
+ Right reveal;
+ reveal.length = 1;
+ WordIndex period = m.GetVocabulary().Index(".");
+ reveal.words[0] = period;
+ reveal.backoff[0] = -0.845098;
+
+ BOOST_CHECK_CLOSE(0.0, RevealBefore(m, reveal, 0, false, left, right), 0.001);
+ BOOST_CHECK_EQUAL(0, left.length);
+ BOOST_CHECK(!left.full);
+ BOOST_CHECK_EQUAL(1, right.length);
+ BOOST_CHECK_EQUAL(period, right.words[0]);
+ BOOST_CHECK_CLOSE(-0.845098, right.backoff[0], 0.001);
+
+ WordIndex more = m.GetVocabulary().Index("more");
+ reveal.words[1] = more;
+ reveal.backoff[1] = -0.4771212;
+ reveal.length = 2;
+ BOOST_CHECK_CLOSE(0.0, RevealBefore(m, reveal, 1, false, left, right), 0.001);
+ BOOST_CHECK_EQUAL(0, left.length);
+ BOOST_CHECK(!left.full);
+ BOOST_CHECK_EQUAL(2, right.length);
+ BOOST_CHECK_EQUAL(period, right.words[0]);
+ BOOST_CHECK_EQUAL(more, right.words[1]);
+ BOOST_CHECK_CLOSE(-0.845098, right.backoff[0], 0.001);
+ BOOST_CHECK_CLOSE(-0.4771212, right.backoff[1], 0.001);
+}
+
+BOOST_AUTO_TEST_CASE(AlsoWouldConsider) {
+ WordIndex would = m.GetVocabulary().Index("would");
+ WordIndex consider = m.GetVocabulary().Index("consider");
+
+ ChartState current;
+ current.left.length = 1;
+ current.left.pointers[0] = would;
+ current.left.full = false;
+ current.right.length = 1;
+ current.right.words[0] = would;
+ current.right.backoff[0] = -0.30103;
+
+ Left after;
+ after.full = false;
+ after.length = 1;
+ after.pointers[0] = consider;
+
+ // adjustment for would consider
+ BOOST_CHECK_CLOSE(-1.687872 - -0.2922095 - 0.30103, RevealAfter(m, current.left, current.right, after, 0), 0.001);
+
+ BOOST_CHECK_EQUAL(2, current.left.length);
+ BOOST_CHECK_EQUAL(would, current.left.pointers[0]);
+ BOOST_CHECK_EQUAL(false, current.left.full);
+
+ WordIndex also = m.GetVocabulary().Index("also");
+ Right before;
+ before.length = 1;
+ before.words[0] = also;
+ before.backoff[0] = -0.30103;
+ // r(would) = -0.2922095 [i would], r(would -> consider) = -1.988902 [b(would) + p(consider)]
+ // p(also -> would) = -2, p(also would -> consider) = -3
+ BOOST_CHECK_CLOSE(-2 + 0.2922095 -3 + 1.988902, RevealBefore(m, before, 0, false, current.left, current.right), 0.001);
+ BOOST_CHECK_EQUAL(0, current.left.length);
+ BOOST_CHECK(current.left.full);
+ BOOST_CHECK_EQUAL(2, current.right.length);
+ BOOST_CHECK_EQUAL(would, current.right.words[0]);
+ BOOST_CHECK_EQUAL(also, current.right.words[1]);
+}
+
+BOOST_AUTO_TEST_CASE(EndSentence) {
+ WordIndex loin = m.GetVocabulary().Index("loin");
+ WordIndex period = m.GetVocabulary().Index(".");
+ WordIndex eos = m.GetVocabulary().EndSentence();
+
+ ChartState between;
+ between.left.length = 1;
+ between.left.pointers[0] = eos;
+ between.left.full = true;
+ between.right.length = 0;
+
+ Right before;
+ before.words[0] = period;
+ before.words[1] = loin;
+ before.backoff[0] = -0.845098;
+ before.backoff[1] = 0.0;
+
+ before.length = 1;
+ BOOST_CHECK_CLOSE(-0.0410707, RevealBefore(m, before, 0, true, between.left, between.right), 0.001);
+ BOOST_CHECK_EQUAL(0, between.left.length);
+}
+
+float ScoreFragment(const RestProbingModel &model, unsigned int *begin, unsigned int *end, ChartState &out) {
+ RuleScore<RestProbingModel> scorer(model, out);
+ for (unsigned int *i = begin; i < end; ++i) {
+ scorer.Terminal(*i);
+ }
+ return scorer.Finish();
+}
+
+void CheckAdjustment(const RestProbingModel &model, float expect, const Right &before_in, bool before_full, ChartState between, const Left &after_in) {
+ Right before(before_in);
+ Left after(after_in);
+ after.full = false;
+ float got = 0.0;
+ for (unsigned int i = 1; i < 5; ++i) {
+ if (before_in.length >= i) {
+ before.length = i;
+ got += RevealBefore(model, before, i - 1, false, between.left, between.right);
+ }
+ if (after_in.length >= i) {
+ after.length = i;
+ got += RevealAfter(model, between.left, between.right, after, i - 1);
+ }
+ }
+ if (after_in.full) {
+ after.full = true;
+ got += RevealAfter(model, between.left, between.right, after, after.length);
+ }
+ if (before_full) {
+ got += RevealBefore(model, before, before.length, true, between.left, between.right);
+ }
+ // Sometimes they're zero and BOOST_CHECK_CLOSE fails for this.
+ BOOST_CHECK(fabs(expect - got) < 0.001);
+}
+
+void FullDivide(const RestProbingModel &model, StringPiece str) {
+ std::vector<WordIndex> indices;
+ for (util::TokenIter<util::SingleCharacter, true> i(str, ' '); i; ++i) {
+ indices.push_back(model.GetVocabulary().Index(*i));
+ }
+ ChartState full_state;
+ float full = ScoreFragment(model, &indices.front(), &indices.back() + 1, full_state);
+
+ ChartState before_state;
+ before_state.left.full = false;
+ RuleScore<RestProbingModel> before_scorer(model, before_state);
+ float before_score = 0.0;
+ for (unsigned int before = 0; before < indices.size(); ++before) {
+ for (unsigned int after = before; after <= indices.size(); ++after) {
+ ChartState after_state, between_state;
+ float after_score = ScoreFragment(model, &indices.front() + after, &indices.front() + indices.size(), after_state);
+ float between_score = ScoreFragment(model, &indices.front() + before, &indices.front() + after, between_state);
+ CheckAdjustment(model, full - before_score - after_score - between_score, before_state.right, before_state.left.full, between_state, after_state.left);
+ }
+ before_scorer.Terminal(indices[before]);
+ before_score = before_scorer.Finish();
+ }
+}
+
+BOOST_AUTO_TEST_CASE(Strings) {
+ FullDivide(m, "also would consider");
+ FullDivide(m, "looking on a little more loin . </s>");
+ FullDivide(m, "in biarritz watching considering looking . on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown </s>");
+}
+
+BOOST_AUTO_TEST_SUITE_END()
+} // namespace
+} // namespace ngram
+} // namespace lm
diff --git a/klm/search/Jamfile b/klm/search/Jamfile
new file mode 100644
index 00000000..e8b14363
--- /dev/null
+++ b/klm/search/Jamfile
@@ -0,0 +1,5 @@
+lib search : weights.cc vertex.cc vertex_generator.cc edge_queue.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : <include>.. ;
+
+import testing ;
+
+unit-test weights_test : weights_test.cc search /top//boost_unit_test_framework ;
diff --git a/klm/search/arity.hh b/klm/search/arity.hh
new file mode 100644
index 00000000..09c2c671
--- /dev/null
+++ b/klm/search/arity.hh
@@ -0,0 +1,8 @@
+#ifndef SEARCH_ARITY__
+#define SEARCH_ARITY__
+namespace search {
+
+const unsigned int kMaxArity = 2;
+
+} // namespace search
+#endif // SEARCH_ARITY__
diff --git a/klm/search/config.hh b/klm/search/config.hh
new file mode 100644
index 00000000..ef8e2354
--- /dev/null
+++ b/klm/search/config.hh
@@ -0,0 +1,25 @@
+#ifndef SEARCH_CONFIG__
+#define SEARCH_CONFIG__
+
+#include "search/weights.hh"
+#include "util/string_piece.hh"
+
+namespace search {
+
+class Config {
+ public:
+ Config(const Weights &weights, unsigned int pop_limit) :
+ weights_(weights), pop_limit_(pop_limit) {}
+
+ const Weights &GetWeights() const { return weights_; }
+
+ unsigned int PopLimit() const { return pop_limit_; }
+
+ private:
+ Weights weights_;
+ unsigned int pop_limit_;
+};
+
+} // namespace search
+
+#endif // SEARCH_CONFIG__
diff --git a/klm/search/context.hh b/klm/search/context.hh
new file mode 100644
index 00000000..27940053
--- /dev/null
+++ b/klm/search/context.hh
@@ -0,0 +1,65 @@
+#ifndef SEARCH_CONTEXT__
+#define SEARCH_CONTEXT__
+
+#include "lm/model.hh"
+#include "search/config.hh"
+#include "search/final.hh"
+#include "search/types.hh"
+#include "search/vertex.hh"
+#include "util/exception.hh"
+
+#include <boost/pool/object_pool.hpp>
+#include <boost/ptr_container/ptr_vector.hpp>
+
+#include <vector>
+
+namespace search {
+
+class Weights;
+
+class ContextBase {
+ public:
+ explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {}
+
+ Final *NewFinal() {
+ Final *ret = final_pool_.construct();
+ assert(ret);
+ return ret;
+ }
+
+ VertexNode *NewVertexNode() {
+ VertexNode *ret = vertex_node_pool_.construct();
+ assert(ret);
+ return ret;
+ }
+
+ void DeleteVertexNode(VertexNode *node) {
+ vertex_node_pool_.destroy(node);
+ }
+
+ unsigned int PopLimit() const { return pop_limit_; }
+
+ const Weights &GetWeights() const { return weights_; }
+
+ private:
+ boost::object_pool<Final> final_pool_;
+ boost::object_pool<VertexNode> vertex_node_pool_;
+
+ unsigned int pop_limit_;
+
+ const Weights &weights_;
+};
+
+template <class Model> class Context : public ContextBase {
+ public:
+ Context(const Config &config, const Model &model) : ContextBase(config), model_(model) {}
+
+ const Model &LanguageModel() const { return model_; }
+
+ private:
+ const Model &model_;
+};
+
+} // namespace search
+
+#endif // SEARCH_CONTEXT__
diff --git a/klm/search/edge.hh b/klm/search/edge.hh
new file mode 100644
index 00000000..77ab0ade
--- /dev/null
+++ b/klm/search/edge.hh
@@ -0,0 +1,31 @@
+#ifndef SEARCH_EDGE__
+#define SEARCH_EDGE__
+
+#include "lm/state.hh"
+#include "search/arity.hh"
+#include "search/rule.hh"
+#include "search/types.hh"
+#include "search/vertex.hh"
+
+#include <queue>
+
+namespace search {
+
+struct PartialEdge {
+ Score score;
+ // Terminals
+ lm::ngram::ChartState between[kMaxArity + 1];
+ // Non-terminals
+ PartialVertex nt[kMaxArity];
+
+ const lm::ngram::ChartState &CompletedState() const {
+ return between[0];
+ }
+
+ bool operator<(const PartialEdge &other) const {
+ return score < other.score;
+ }
+};
+
+} // namespace search
+#endif // SEARCH_EDGE__
diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc
new file mode 100644
index 00000000..56239dfb
--- /dev/null
+++ b/klm/search/edge_generator.cc
@@ -0,0 +1,120 @@
+#include "search/edge_generator.hh"
+
+#include "lm/left.hh"
+#include "lm/partial.hh"
+#include "search/context.hh"
+#include "search/vertex.hh"
+#include "search/vertex_generator.hh"
+
+#include <numeric>
+
+namespace search {
+
+EdgeGenerator::EdgeGenerator(PartialEdge &root, unsigned char arity, Note note) : arity_(arity), note_(note) {
+/* for (unsigned char i = 0; i < edge.Arity(); ++i) {
+ root.nt[i] = edge.GetVertex(i).RootPartial();
+ }
+ for (unsigned char i = edge.Arity(); i < 2; ++i) {
+ root.nt[i] = kBlankPartialVertex;
+ }*/
+ generate_.push(&root);
+ top_score_ = root.score;
+}
+
+namespace {
+
+template <class Model> float FastScore(const Context<Model> &context, unsigned char victim, unsigned char arity, const PartialEdge &previous, PartialEdge &update) {
+ memcpy(update.between, previous.between, sizeof(lm::ngram::ChartState) * (arity + 1));
+
+ float ret = 0.0;
+ lm::ngram::ChartState *before, *after;
+ if (victim == 0) {
+ before = &update.between[0];
+ after = &update.between[(arity == 2 && previous.nt[1].Complete()) ? 2 : 1];
+ } else {
+ assert(victim == 1);
+ assert(arity == 2);
+ before = &update.between[previous.nt[0].Complete() ? 0 : 1];
+ after = &update.between[2];
+ }
+ const lm::ngram::ChartState &previous_reveal = previous.nt[victim].State();
+ const PartialVertex &update_nt = update.nt[victim];
+ const lm::ngram::ChartState &update_reveal = update_nt.State();
+ float just_after = 0.0;
+ if ((update_reveal.left.length > previous_reveal.left.length) || (update_reveal.left.full && !previous_reveal.left.full)) {
+ just_after += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length);
+ }
+ if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous.nt[victim].RightFull())) {
+ ret += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right);
+ }
+ if (update_nt.Complete()) {
+ if (update_reveal.left.full) {
+ before->left.full = true;
+ } else {
+ assert(update_reveal.left.length == update_reveal.right.length);
+ ret += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length);
+ }
+ if (victim == 0) {
+ update.between[0].right = after->right;
+ } else {
+ update.between[2].left = before->left;
+ }
+ }
+ return previous.score + (ret + just_after) * context.GetWeights().LM();
+}
+
+} // namespace
+
+template <class Model> PartialEdge *EdgeGenerator::Pop(Context<Model> &context, boost::pool<> &partial_edge_pool) {
+ assert(!generate_.empty());
+ PartialEdge &top = *generate_.top();
+ generate_.pop();
+ unsigned int victim = 0;
+ unsigned char lowest_length = 255;
+ for (unsigned char i = 0; i != arity_; ++i) {
+ if (!top.nt[i].Complete() && top.nt[i].Length() < lowest_length) {
+ lowest_length = top.nt[i].Length();
+ victim = i;
+ }
+ }
+ if (lowest_length == 255) {
+ // All states report complete.
+ top.between[0].right = top.between[arity_].right;
+ // Now top.between[0] is the full edge state.
+ top_score_ = generate_.empty() ? -kScoreInf : generate_.top()->score;
+ return &top;
+ }
+
+ unsigned int stay = !victim;
+ PartialEdge &continuation = *static_cast<PartialEdge*>(partial_edge_pool.malloc());
+ float old_bound = top.nt[victim].Bound();
+ // The alternate's score will change because alternate.nt[victim] changes.
+ bool split = top.nt[victim].Split(continuation.nt[victim]);
+ // top is now the alternate.
+
+ continuation.nt[stay] = top.nt[stay];
+ continuation.score = FastScore(context, victim, arity_, top, continuation);
+ // TODO: dedupe?
+ generate_.push(&continuation);
+
+ if (split) {
+ // We have an alternate.
+ top.score += top.nt[victim].Bound() - old_bound;
+ // TODO: dedupe?
+ generate_.push(&top);
+ } else {
+ partial_edge_pool.free(&top);
+ }
+
+ top_score_ = generate_.top()->score;
+ return NULL;
+}
+
+template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context, boost::pool<> &partial_edge_pool);
+template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context, boost::pool<> &partial_edge_pool);
+template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::TrieModel> &context, boost::pool<> &partial_edge_pool);
+template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::QuantTrieModel> &context, boost::pool<> &partial_edge_pool);
+template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::ArrayTrieModel> &context, boost::pool<> &partial_edge_pool);
+template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::QuantArrayTrieModel> &context, boost::pool<> &partial_edge_pool);
+
+} // namespace search
diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh
new file mode 100644
index 00000000..875ccc5e
--- /dev/null
+++ b/klm/search/edge_generator.hh
@@ -0,0 +1,58 @@
+#ifndef SEARCH_EDGE_GENERATOR__
+#define SEARCH_EDGE_GENERATOR__
+
+#include "search/edge.hh"
+#include "search/note.hh"
+
+#include <boost/pool/pool.hpp>
+#include <boost/unordered_map.hpp>
+
+#include <functional>
+#include <queue>
+
+namespace lm {
+namespace ngram {
+class ChartState;
+} // namespace ngram
+} // namespace lm
+
+namespace search {
+
+template <class Model> class Context;
+
+class VertexGenerator;
+
+struct PartialEdgePointerLess : std::binary_function<const PartialEdge *, const PartialEdge *, bool> {
+ bool operator()(const PartialEdge *first, const PartialEdge *second) const {
+ return *first < *second;
+ }
+};
+
+class EdgeGenerator {
+ public:
+ EdgeGenerator(PartialEdge &root, unsigned char arity, Note note);
+
+ Score TopScore() const {
+ return top_score_;
+ }
+
+ Note GetNote() const {
+ return note_;
+ }
+
+ // Pop. If there's a complete hypothesis, return it. Otherwise return NULL.
+ template <class Model> PartialEdge *Pop(Context<Model> &context, boost::pool<> &partial_edge_pool);
+
+ private:
+ Score top_score_;
+
+ unsigned char arity_;
+
+ typedef std::priority_queue<PartialEdge*, std::vector<PartialEdge*>, PartialEdgePointerLess> Generate;
+ Generate generate_;
+
+ Note note_;
+};
+
+} // namespace search
+#endif // SEARCH_EDGE_GENERATOR__
diff --git a/klm/search/edge_queue.cc b/klm/search/edge_queue.cc
new file mode 100644
index 00000000..e3ae6ebf
--- /dev/null
+++ b/klm/search/edge_queue.cc
@@ -0,0 +1,25 @@
+#include "search/edge_queue.hh"
+
+#include "lm/left.hh"
+#include "search/context.hh"
+
+#include <stdint.h>
+
+namespace search {
+
+EdgeQueue::EdgeQueue(unsigned int pop_limit_hint) : partial_edge_pool_(sizeof(PartialEdge), pop_limit_hint * 2) {
+ take_ = static_cast<PartialEdge*>(partial_edge_pool_.malloc());
+}
+
+/*void EdgeQueue::AddEdge(PartialEdge &root, unsigned char arity, Note note) {
+ // Ignore empty edges.
+ for (unsigned char i = 0; i < edge.Arity(); ++i) {
+ PartialVertex root(edge.GetVertex(i).RootPartial());
+ if (root.Empty()) return;
+ total_score += root.Bound();
+ }
+ PartialEdge &allocated = *static_cast<PartialEdge*>(partial_edge_pool_.malloc());
+ allocated.score = total_score;
+}*/
+
+} // namespace search
diff --git a/klm/search/edge_queue.hh b/klm/search/edge_queue.hh
new file mode 100644
index 00000000..187eaed7
--- /dev/null
+++ b/klm/search/edge_queue.hh
@@ -0,0 +1,73 @@
+#ifndef SEARCH_EDGE_QUEUE__
+#define SEARCH_EDGE_QUEUE__
+
+#include "search/edge.hh"
+#include "search/edge_generator.hh"
+#include "search/note.hh"
+
+#include <boost/pool/pool.hpp>
+#include <boost/pool/object_pool.hpp>
+
+#include <queue>
+
+namespace search {
+
+template <class Model> class Context;
+
+class EdgeQueue {
+ public:
+ explicit EdgeQueue(unsigned int pop_limit_hint);
+
+ PartialEdge &InitializeEdge() {
+ return *take_;
+ }
+
+ void AddEdge(unsigned char arity, Note note) {
+ generate_.push(edge_pool_.construct(*take_, arity, note));
+ take_ = static_cast<PartialEdge*>(partial_edge_pool_.malloc());
+ }
+
+ bool Empty() const { return generate_.empty(); }
+
+ /* Generate hypotheses and send them to output. Normally, output is a
+ * VertexGenerator, but the decoder may want to route edges to different
+ * vertices i.e. if they have different LHS non-terminal labels.
+ */
+ template <class Model, class Output> void Search(Context<Model> &context, Output &output) {
+ int to_pop = context.PopLimit();
+ while (to_pop > 0 && !generate_.empty()) {
+ EdgeGenerator *top = generate_.top();
+ generate_.pop();
+ PartialEdge *ret = top->Pop(context, partial_edge_pool_);
+ if (ret) {
+ output.NewHypothesis(*ret, top->GetNote());
+ --to_pop;
+ if (top->TopScore() != -kScoreInf) {
+ generate_.push(top);
+ }
+ } else {
+ generate_.push(top);
+ }
+ }
+ output.FinishedSearch();
+ }
+
+ private:
+ boost::object_pool<EdgeGenerator> edge_pool_;
+
+ struct LessByTopScore : public std::binary_function<const EdgeGenerator *, const EdgeGenerator *, bool> {
+ bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const {
+ return first->TopScore() < second->TopScore();
+ }
+ };
+
+ typedef std::priority_queue<EdgeGenerator*, std::vector<EdgeGenerator*>, LessByTopScore> Generate;
+ Generate generate_;
+
+ boost::pool<> partial_edge_pool_;
+
+ PartialEdge *take_;
+};
+
+} // namespace search
+#endif // SEARCH_EDGE_QUEUE__
diff --git a/klm/search/final.hh b/klm/search/final.hh
new file mode 100644
index 00000000..1b3092ac
--- /dev/null
+++ b/klm/search/final.hh
@@ -0,0 +1,39 @@
+#ifndef SEARCH_FINAL__
+#define SEARCH_FINAL__
+
+#include "search/arity.hh"
+#include "search/note.hh"
+#include "search/types.hh"
+
+#include <boost/array.hpp>
+
+namespace search {
+
+class Final {
+ public:
+ typedef boost::array<const Final*, search::kMaxArity> ChildArray;
+
+ void Reset(Score bound, Note note, const Final &left, const Final &right) {
+ bound_ = bound;
+ note_ = note;
+ children_[0] = &left;
+ children_[1] = &right;
+ }
+
+ const ChildArray &Children() const { return children_; }
+
+ Note GetNote() const { return note_; }
+
+ Score Bound() const { return bound_; }
+
+ private:
+ Score bound_;
+
+ Note note_;
+
+ ChildArray children_;
+};
+
+} // namespace search
+
+#endif // SEARCH_FINAL__
diff --git a/klm/search/note.hh b/klm/search/note.hh
new file mode 100644
index 00000000..50bed06e
--- /dev/null
+++ b/klm/search/note.hh
@@ -0,0 +1,12 @@
+#ifndef SEARCH_NOTE__
+#define SEARCH_NOTE__
+
+namespace search {
+
+union Note {
+ const void *vp;
+};
+
+} // namespace search
+
+#endif // SEARCH_NOTE__
diff --git a/klm/search/rule.cc b/klm/search/rule.cc
new file mode 100644
index 00000000..5b00207e
--- /dev/null
+++ b/klm/search/rule.cc
@@ -0,0 +1,43 @@
+#include "search/rule.hh"
+
+#include "search/context.hh"
+#include "search/final.hh"
+
+#include <ostream>
+
+#include <math.h>
+
+namespace search {
+
+template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing) {
+ unsigned int oov_count = 0;
+ float prob = 0.0;
+ const Model &model = context.LanguageModel();
+ const lm::WordIndex oov = model.GetVocabulary().NotFound();
+ for (std::vector<lm::WordIndex>::const_iterator word = words.begin(); ; ++word) {
+ lm::ngram::RuleScore<Model> scorer(model, *(writing++));
+ // TODO: optimize
+ if (prepend_bos && (word == words.begin())) {
+ scorer.BeginSentence();
+ }
+ for (; ; ++word) {
+ if (word == words.end()) {
+ prob += scorer.Finish();
+ return static_cast<float>(oov_count) * context.GetWeights().OOV() + prob * context.GetWeights().LM();
+ }
+ if (*word == kNonTerminal) break;
+ if (*word == oov) ++oov_count;
+ scorer.Terminal(*word);
+ }
+ prob += scorer.Finish();
+ }
+}
+
+template float ScoreRule(const Context<lm::ngram::RestProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::ProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::TrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::QuantTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::ArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::QuantArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+
+} // namespace search
diff --git a/klm/search/rule.hh b/klm/search/rule.hh
new file mode 100644
index 00000000..0ce2794d
--- /dev/null
+++ b/klm/search/rule.hh
@@ -0,0 +1,20 @@
+#ifndef SEARCH_RULE__
+#define SEARCH_RULE__
+
+#include "lm/left.hh"
+#include "lm/word_index.hh"
+#include "search/types.hh"
+
+#include <vector>
+
+namespace search {
+
+template <class Model> class Context;
+
+const lm::WordIndex kNonTerminal = lm::kMaxWordIndex;
+
+template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *state_out);
+
+} // namespace search
+
+#endif // SEARCH_RULE__
diff --git a/klm/search/source.hh b/klm/search/source.hh
new file mode 100644
index 00000000..11839f7b
--- /dev/null
+++ b/klm/search/source.hh
@@ -0,0 +1,48 @@
+#ifndef SEARCH_SOURCE__
+#define SEARCH_SOURCE__
+
+#include "search/types.hh"
+
+#include <assert.h>
+#include <vector>
+
+namespace search {
+
+template <class Final> class Source {
+ public:
+ Source() : bound_(kScoreInf) {}
+
+ Index Size() const {
+ return final_.size();
+ }
+
+ Score Bound() const {
+ return bound_;
+ }
+
+ const Final &operator[](Index index) const {
+ return *final_[index];
+ }
+
+ Score ScoreOrBound(Index index) const {
+ return Size() > index ? final_[index]->Total() : Bound();
+ }
+
+ protected:
+ void AddFinal(const Final &store) {
+ final_.push_back(&store);
+ }
+
+ void SetBound(Score to) {
+ assert(to <= bound_ + 0.001);
+ bound_ = to;
+ }
+
+ private:
+ std::vector<const Final *> final_;
+
+ Score bound_;
+};
+
+} // namespace search
+#endif // SEARCH_SOURCE__
diff --git a/klm/search/types.hh b/klm/search/types.hh
new file mode 100644
index 00000000..9726379f
--- /dev/null
+++ b/klm/search/types.hh
@@ -0,0 +1,18 @@
+#ifndef SEARCH_TYPES__
+#define SEARCH_TYPES__
+
+#include <cmath>
+
+namespace search {
+
+typedef float Score;
+const Score kScoreInf = INFINITY;
+
+// This could have been an enum but gcc wants 4 bytes.
+typedef bool ExtendDirection;
+const ExtendDirection kExtendLeft = 0;
+const ExtendDirection kExtendRight = 1;
+
+} // namespace search
+
+#endif // SEARCH_TYPES__
diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc
new file mode 100644
index 00000000..cc53c0dd
--- /dev/null
+++ b/klm/search/vertex.cc
@@ -0,0 +1,48 @@
+#include "search/vertex.hh"
+
+#include "search/context.hh"
+
+#include <algorithm>
+#include <functional>
+
+#include <assert.h>
+
+namespace search {
+
+namespace {
+
+struct GreaterByBound : public std::binary_function<const VertexNode *, const VertexNode *, bool> {
+ bool operator()(const VertexNode *first, const VertexNode *second) const {
+ return first->Bound() > second->Bound();
+ }
+};
+
+} // namespace
+
+void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) {
+ if (Complete()) {
+ assert(end_);
+ assert(extend_.empty());
+ bound_ = end_->Bound();
+ return;
+ }
+ if (extend_.size() == 1 && parent_ptr) {
+ *parent_ptr = extend_[0];
+ extend_[0]->SortAndSet(context, parent_ptr);
+ context.DeleteVertexNode(this);
+ return;
+ }
+ for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
+ (*i)->SortAndSet(context, &*i);
+ }
+ std::sort(extend_.begin(), extend_.end(), GreaterByBound());
+ bound_ = extend_.front()->Bound();
+}
+
+namespace {
+VertexNode kBlankVertexNode;
+} // namespace
+
+PartialVertex kBlankPartialVertex(kBlankVertexNode);
+
+} // namespace search
diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh
new file mode 100644
index 00000000..e1a9ad11
--- /dev/null
+++ b/klm/search/vertex.hh
@@ -0,0 +1,158 @@
+#ifndef SEARCH_VERTEX__
+#define SEARCH_VERTEX__
+
+#include "lm/left.hh"
+#include "search/final.hh"
+#include "search/types.hh"
+
+#include <boost/unordered_set.hpp>
+
+#include <queue>
+#include <vector>
+
+#include <stdint.h>
+
+namespace search {
+
+class ContextBase;
+
+class VertexNode {
+ public:
+ VertexNode() : end_(NULL) {}
+
+ void InitRoot() {
+ extend_.clear();
+ state_.left.full = false;
+ state_.left.length = 0;
+ state_.right.length = 0;
+ right_full_ = false;
+ bound_ = -kScoreInf;
+ end_ = NULL;
+ }
+
+ lm::ngram::ChartState &MutableState() { return state_; }
+ bool &MutableRightFull() { return right_full_; }
+
+ void AddExtend(VertexNode *next) {
+ extend_.push_back(next);
+ }
+
+ void SetEnd(Final *end) { end_ = end; }
+
+ Final &MutableEnd() { return *end_; }
+
+ void SortAndSet(ContextBase &context, VertexNode **parent_pointer);
+
+ // Should only happen to a root node when the entire vertex is empty.
+ bool Empty() const {
+ return !end_ && extend_.empty();
+ }
+
+ bool Complete() const {
+ return end_;
+ }
+
+ const lm::ngram::ChartState &State() const { return state_; }
+ bool RightFull() const { return right_full_; }
+
+ Score Bound() const {
+ return bound_;
+ }
+
+ unsigned char Length() const {
+ return state_.left.length + state_.right.length;
+ }
+
+ // May be NULL.
+ const Final *End() const { return end_; }
+
+ const VertexNode &operator[](size_t index) const {
+ return *extend_[index];
+ }
+
+ size_t Size() const {
+ return extend_.size();
+ }
+
+ private:
+ std::vector<VertexNode*> extend_;
+
+ lm::ngram::ChartState state_;
+ bool right_full_;
+
+ Score bound_;
+ Final *end_;
+};
+
+class PartialVertex {
+ public:
+ PartialVertex() {}
+
+ explicit PartialVertex(const VertexNode &back) : back_(&back), index_(0) {}
+
+ bool Empty() const { return back_->Empty(); }
+
+ bool Complete() const { return back_->Complete(); }
+
+ const lm::ngram::ChartState &State() const { return back_->State(); }
+ bool RightFull() const { return back_->RightFull(); }
+
+ Score Bound() const { return Complete() ? back_->End()->Bound() : (*back_)[index_].Bound(); }
+
+ unsigned char Length() const { return back_->Length(); }
+
+ bool HasAlternative() const {
+ return index_ + 1 < back_->Size();
+ }
+
+ // Split into continuation and alternative, rendering this the alternative.
+ bool Split(PartialVertex &continuation) {
+ assert(!Complete());
+ continuation.back_ = &((*back_)[index_]);
+ continuation.index_ = 0;
+ if (index_ + 1 < back_->Size()) {
+ ++index_;
+ return true;
+ }
+ return false;
+ }
+
+ const Final &End() const {
+ return *back_->End();
+ }
+
+ private:
+ const VertexNode *back_;
+ unsigned int index_;
+};
+
+extern PartialVertex kBlankPartialVertex;
+
+class Vertex {
+ public:
+ Vertex() {}
+
+ PartialVertex RootPartial() const { return PartialVertex(root_); }
+
+ const Final *BestChild() const {
+ PartialVertex top(RootPartial());
+ if (top.Empty()) {
+ return NULL;
+ } else {
+ PartialVertex continuation;
+ while (!top.Complete()) {
+ top.Split(continuation);
+ top = continuation;
+ }
+ return &top.End();
+ }
+ }
+
+ private:
+ friend class VertexGenerator;
+
+ VertexNode root_;
+};
+
+} // namespace search
+#endif // SEARCH_VERTEX__
diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc
new file mode 100644
index 00000000..d94e6e06
--- /dev/null
+++ b/klm/search/vertex_generator.cc
@@ -0,0 +1,83 @@
+#include "search/vertex_generator.hh"
+
+#include "lm/left.hh"
+#include "search/context.hh"
+#include "search/edge.hh"
+
+#include <stdint.h>
+
+namespace search {
+
+VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) {
+ gen.root_.InitRoot();
+ root_.under = &gen.root_;
+}
+
+namespace {
+const uint64_t kCompleteAdd = static_cast<uint64_t>(-1);
+} // namespace
+
+void VertexGenerator::NewHypothesis(const PartialEdge &partial, Note note) {
+ const lm::ngram::ChartState &state = partial.CompletedState();
+ std::pair<Existing::iterator, bool> got(existing_.insert(std::pair<uint64_t, Final*>(hash_value(state), NULL)));
+ if (!got.second) {
+ // Found it already.
+ Final &exists = *got.first->second;
+ if (exists.Bound() < partial.score) {
+ exists.Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End());
+ }
+ return;
+ }
+ unsigned char left = 0, right = 0;
+ Trie *node = &root_;
+ while (true) {
+ if (left == state.left.length) {
+ node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, true, right, false);
+ for (; right < state.right.length; ++right) {
+ node = &FindOrInsert(*node, state.right.words[right], state, left, true, right + 1, false);
+ }
+ break;
+ }
+ node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, false);
+ left++;
+ if (right == state.right.length) {
+ node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, false, right, true);
+ for (; left < state.left.length; ++left) {
+ node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, true);
+ }
+ break;
+ }
+ node = &FindOrInsert(*node, state.right.words[right], state, left, false, right + 1, false);
+ right++;
+ }
+
+ node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true);
+ got.first->second = CompleteTransition(*node, state, note, partial);
+}
+
+VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) {
+ VertexGenerator::Trie &next = node.extend[added];
+ if (!next.under) {
+ next.under = context_.NewVertexNode();
+ lm::ngram::ChartState &writing = next.under->MutableState();
+ writing = state;
+ writing.left.full &= left_full && state.left.full;
+ next.under->MutableRightFull() = right_full && state.left.full;
+ writing.left.length = left;
+ writing.right.length = right;
+ node.under->AddExtend(next.under);
+ }
+ return next;
+}
+
+Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial) {
+ VertexNode &node = *starter.under;
+ assert(node.State().left.full == state.left.full);
+ assert(!node.End());
+ Final *final = context_.NewFinal();
+ final->Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End());
+ node.SetEnd(final);
+ return final;
+}
+
+} // namespace search
diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh
new file mode 100644
index 00000000..6b98da3e
--- /dev/null
+++ b/klm/search/vertex_generator.hh
@@ -0,0 +1,59 @@
+#ifndef SEARCH_VERTEX_GENERATOR__
+#define SEARCH_VERTEX_GENERATOR__
+
+#include "search/note.hh"
+#include "search/vertex.hh"
+
+#include <boost/unordered_map.hpp>
+
+#include <queue>
+
+namespace lm {
+namespace ngram {
+class ChartState;
+} // namespace ngram
+} // namespace lm
+
+namespace search {
+
+class ContextBase;
+class Final;
+struct PartialEdge;
+
+class VertexGenerator {
+ public:
+ VertexGenerator(ContextBase &context, Vertex &gen);
+
+ void NewHypothesis(const PartialEdge &partial, Note note);
+
+ void FinishedSearch() {
+ root_.under->SortAndSet(context_, NULL);
+ }
+
+ const Vertex &Generating() const { return gen_; }
+
+ private:
+ // Parallel structure to VertexNode.
+ struct Trie {
+ Trie() : under(NULL) {}
+
+ VertexNode *under;
+ boost::unordered_map<uint64_t, Trie> extend;
+ };
+
+ Trie &FindOrInsert(Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full);
+
+ Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial);
+
+ ContextBase &context_;
+
+ Vertex &gen_;
+
+ Trie root_;
+
+ typedef boost::unordered_map<uint64_t, Final*> Existing;
+ Existing existing_;
+};
+
+} // namespace search
+#endif // SEARCH_VERTEX_GENERATOR__
diff --git a/klm/search/weights.cc b/klm/search/weights.cc
new file mode 100644
index 00000000..d65471ad
--- /dev/null
+++ b/klm/search/weights.cc
@@ -0,0 +1,71 @@
+#include "search/weights.hh"
+#include "util/tokenize_piece.hh"
+
+#include <cstdlib>
+
+namespace search {
+
+namespace {
+struct Insert {
+ void operator()(boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) const {
+ std::string copy(name.data(), name.size());
+ map[copy] = score;
+ }
+};
+
+struct DotProduct {
+ search::Score total;
+ DotProduct() : total(0.0) {}
+
+ void operator()(const boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) {
+ boost::unordered_map<std::string, search::Score>::const_iterator i(FindStringPiece(map, name));
+ if (i != map.end())
+ total += score * i->second;
+ }
+};
+
+template <class Map, class Op> void Parse(StringPiece text, Map &map, Op &op) {
+ for (util::TokenIter<util::SingleCharacter, true> spaces(text, ' '); spaces; ++spaces) {
+ util::TokenIter<util::SingleCharacter> equals(*spaces, '=');
+ UTIL_THROW_IF(!equals, WeightParseException, "Bad weight token " << *spaces);
+ StringPiece name(*equals);
+ UTIL_THROW_IF(!++equals, WeightParseException, "Bad weight token " << *spaces);
+ char *end;
+ // Assumes proper termination.
+ double value = std::strtod(equals->data(), &end);
+ UTIL_THROW_IF(end != equals->data() + equals->size(), WeightParseException, "Failed to parse weight" << *equals);
+ UTIL_THROW_IF(++equals, WeightParseException, "Too many equals in " << *spaces);
+ op(map, name, value);
+ }
+}
+
+} // namespace
+
+Weights::Weights(StringPiece text) {
+ Insert op;
+ Parse<Map, Insert>(text, map_, op);
+ lm_ = Steal("LanguageModel");
+ oov_ = Steal("OOV");
+ word_penalty_ = Steal("WordPenalty");
+}
+
+Weights::Weights(Score lm, Score oov, Score word_penalty) : lm_(lm), oov_(oov), word_penalty_(word_penalty) {}
+
+search::Score Weights::DotNoLM(StringPiece text) const {
+ DotProduct dot;
+ Parse<const Map, DotProduct>(text, map_, dot);
+ return dot.total;
+}
+
+float Weights::Steal(const std::string &str) {
+ Map::iterator i(map_.find(str));
+ if (i == map_.end()) {
+ return 0.0;
+ } else {
+ float ret = i->second;
+ map_.erase(i);
+ return ret;
+ }
+}
+
+} // namespace search
diff --git a/klm/search/weights.hh b/klm/search/weights.hh
new file mode 100644
index 00000000..df1c419f
--- /dev/null
+++ b/klm/search/weights.hh
@@ -0,0 +1,52 @@
+// For now, the individual features are not kept.
+#ifndef SEARCH_WEIGHTS__
+#define SEARCH_WEIGHTS__
+
+#include "search/types.hh"
+#include "util/exception.hh"
+#include "util/string_piece.hh"
+
+#include <boost/unordered_map.hpp>
+
+#include <string>
+
+namespace search {
+
+class WeightParseException : public util::Exception {
+ public:
+ WeightParseException() {}
+ ~WeightParseException() throw() {}
+};
+
+class Weights {
+ public:
+ // Parses weights, sets lm_weight_, removes it from map_.
+ explicit Weights(StringPiece text);
+
+ // Just the three scores we care about adding.
+ Weights(Score lm, Score oov, Score word_penalty);
+
+ Score DotNoLM(StringPiece text) const;
+
+ Score LM() const { return lm_; }
+
+ Score OOV() const { return oov_; }
+
+ Score WordPenalty() const { return word_penalty_; }
+
+ // Mostly for testing.
+ const boost::unordered_map<std::string, Score> &GetMap() const { return map_; }
+
+ private:
+ float Steal(const std::string &str);
+
+ typedef boost::unordered_map<std::string, Score> Map;
+
+ Map map_;
+
+ Score lm_, oov_, word_penalty_;
+};
+
+} // namespace search
+
+#endif // SEARCH_WEIGHTS__
diff --git a/klm/search/weights_test.cc b/klm/search/weights_test.cc
new file mode 100644
index 00000000..4811ff06
--- /dev/null
+++ b/klm/search/weights_test.cc
@@ -0,0 +1,38 @@
+#include "search/weights.hh"
+
+#define BOOST_TEST_MODULE WeightTest
+#include <boost/test/unit_test.hpp>
+#include <boost/test/floating_point_comparison.hpp>
+
+namespace search {
+namespace {
+
+#define CHECK_WEIGHT(value, string) \
+ i = parsed.find(string); \
+ BOOST_REQUIRE(i != parsed.end()); \
+ BOOST_CHECK_CLOSE((value), i->second, 0.001);
+
+BOOST_AUTO_TEST_CASE(parse) {
+ // These are not real feature weights.
+ Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5");
+ const boost::unordered_map<std::string, search::Score> &parsed = w.GetMap();
+ boost::unordered_map<std::string, search::Score>::const_iterator i;
+ CHECK_WEIGHT(0.0, "rarity");
+ CHECK_WEIGHT(0.0, "phrase-SGT");
+ CHECK_WEIGHT(9.45117, "phrase-TGS");
+ CHECK_WEIGHT(2.33833, "lexical-SGT");
+ BOOST_CHECK(parsed.end() == parsed.find("lm"));
+ BOOST_CHECK_CLOSE(3.0, w.LM(), 0.001);
+ CHECK_WEIGHT(-28.3317, "lexical-TGS");
+ CHECK_WEIGHT(5.0, "glue?");
+}
+
+BOOST_AUTO_TEST_CASE(dot) {
+ Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5");
+ BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0"), 0.001);
+ BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0 LanguageModel=10"), 0.001);
+ BOOST_CHECK_CLOSE(9.45117 * 3.0 + 28.3317 * 17.4, w.DotNoLM("rarity=5 phrase-TGS=3.0 LanguageModel=10 lexical-TGS=-17.4"), 0.001);
+}
+
+} // namespace
+} // namespace search
diff --git a/klm/util/have.hh b/klm/util/have.hh
index 1d76a7fc..b8181e99 100644
--- a/klm/util/have.hh
+++ b/klm/util/have.hh
@@ -13,7 +13,7 @@
#endif
#ifndef HAVE_BOOST
-//#define HAVE_BOOST
+#define HAVE_BOOST
#endif
#ifndef HAVE_THREADS