diff options
author | Kenneth Heafield <github@kheafield.com> | 2012-09-12 15:07:44 +0100 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2012-09-12 15:07:44 +0100 |
commit | 173910593bf6bf3dc52902f99a683560d8c73942 (patch) | |
tree | 746ea920283178f3bf6b7f86e7b9e6b195821676 /klm | |
parent | 143ba7317dcaee3058d66f9e6558316f88f95212 (diff) |
Add the alone stuff, using a wrapper to the edge class.
Diffstat (limited to 'klm')
-rw-r--r-- | klm/alone/Jamfile | 4 | ||||
-rw-r--r-- | klm/alone/assemble.cc | 76 | ||||
-rw-r--r-- | klm/alone/assemble.hh | 21 | ||||
-rw-r--r-- | klm/alone/graph.hh | 87 | ||||
-rw-r--r-- | klm/alone/just_vocab.cc | 14 | ||||
-rw-r--r-- | klm/alone/labeled_edge.hh | 30 | ||||
-rw-r--r-- | klm/alone/main.cc | 84 | ||||
-rw-r--r-- | klm/alone/read.cc | 118 | ||||
-rw-r--r-- | klm/alone/read.hh | 29 | ||||
-rw-r--r-- | klm/alone/threading.cc | 80 | ||||
-rw-r--r-- | klm/alone/threading.hh | 129 | ||||
-rw-r--r-- | klm/alone/vocab.cc | 19 | ||||
-rw-r--r-- | klm/alone/vocab.hh | 34 |
13 files changed, 725 insertions, 0 deletions
diff --git a/klm/alone/Jamfile b/klm/alone/Jamfile new file mode 100644 index 00000000..2cc90c05 --- /dev/null +++ b/klm/alone/Jamfile @@ -0,0 +1,4 @@ +lib standalone : assemble.cc read.cc threading.cc vocab.cc ../lm//kenlm ../util//kenutil ../search//search : <include>.. : : <include>.. <library>../search//search <library>../lm//kenlm ; + +exe decode : main.cc standalone main.cc : <threading>multi:<library>..//boost_thread ; +exe just_vocab : just_vocab.cc standalone : <threading>multi:<library>..//boost_thread ; diff --git a/klm/alone/assemble.cc b/klm/alone/assemble.cc new file mode 100644 index 00000000..2ae72ce9 --- /dev/null +++ b/klm/alone/assemble.cc @@ -0,0 +1,76 @@ +#include "alone/assemble.hh" + +#include "alone/labeled_edge.hh" +#include "search/final.hh" + +#include <iostream> + +namespace alone { + +std::ostream &operator<<(std::ostream &o, const search::Final &final) { + const std::vector<const std::string*> &words = static_cast<const LabeledEdge&>(final.From()).Words(); + if (words.empty()) return o; + const search::Final *const *child = final.Children().data(); + std::vector<const std::string*>::const_iterator i(words.begin()); + for (; i != words.end() - 1; ++i) { + if (*i) { + o << **i << ' '; + } else { + o << **child << ' '; + ++child; + } + } + + if (*i) { + if (**i != "</s>") { + o << **i; + } + } else { + o << **child; + } + + return o; +} + +namespace { + +void MakeIndent(std::ostream &o, const char *indent_str, unsigned int level) { + for (unsigned int i = 0; i < level; ++i) + o << indent_str; +} + +void DetailedFinalInternal(std::ostream &o, const search::Final &final, const char *indent_str, unsigned int indent) { + o << "(\n"; + MakeIndent(o, indent_str, indent); + const std::vector<const std::string*> &words = static_cast<const LabeledEdge&>(final.From()).Words(); + const search::Final *const *child = final.Children().data(); + for (std::vector<const std::string*>::const_iterator i(words.begin()); i != words.end(); ++i) { + if (*i) { + o << **i; + if (i == words.end() - 1) { + o << '\n'; + MakeIndent(o, indent_str, indent); + } else { + o << ' '; + } + } else { + // One extra indent from the line we're currently on. + o << indent_str; + DetailedFinalInternal(o, **child, indent_str, indent + 1); + for (unsigned int i = 0; i < indent; ++i) o << indent_str; + ++child; + } + } + o << ")=" << final.Bound() << '\n'; +} +} // namespace + +void DetailedFinal(std::ostream &o, const search::Final &final, const char *indent_str) { + DetailedFinalInternal(o, final, indent_str, 0); +} + +void PrintFinal(const search::Final &final) { + std::cout << final << std::endl; +} + +} // namespace alone diff --git a/klm/alone/assemble.hh b/klm/alone/assemble.hh new file mode 100644 index 00000000..e6b0ad5c --- /dev/null +++ b/klm/alone/assemble.hh @@ -0,0 +1,21 @@ +#ifndef ALONE_ASSEMBLE__ +#define ALONE_ASSEMBLE__ + +#include <iosfwd> + +namespace search { +class Final; +} // namespace search + +namespace alone { + +std::ostream &operator<<(std::ostream &o, const search::Final &final); + +void DetailedFinal(std::ostream &o, const search::Final &final, const char *indent_str = " "); + +// This isn't called anywhere but makes it easy to print from gdb. +void PrintFinal(const search::Final &final); + +} // namespace alone + +#endif // ALONE_ASSEMBLE__ diff --git a/klm/alone/graph.hh b/klm/alone/graph.hh new file mode 100644 index 00000000..788352c9 --- /dev/null +++ b/klm/alone/graph.hh @@ -0,0 +1,87 @@ +#ifndef ALONE_GRAPH__ +#define ALONE_GRAPH__ + +#include "alone/labeled_edge.hh" +#include "search/rule.hh" +#include "search/types.hh" +#include "search/vertex.hh" +#include "util/exception.hh" + +#include <boost/noncopyable.hpp> +#include <boost/pool/object_pool.hpp> +#include <boost/scoped_array.hpp> + +namespace alone { + +template <class T> class FixedAllocator : boost::noncopyable { + public: + FixedAllocator() : current_(NULL), end_(NULL) {} + + void Init(std::size_t count) { + assert(!current_); + array_.reset(new T[count]); + current_ = array_.get(); + end_ = current_ + count; + } + + T &operator[](std::size_t idx) { + return array_.get()[idx]; + } + + T *New() { + T *ret = current_++; + UTIL_THROW_IF(ret >= end_, util::Exception, "Allocating past end"); + return ret; + } + + std::size_t Size() const { + return end_ - array_.get(); + } + + private: + boost::scoped_array<T> array_; + T *current_, *end_; +}; + +class Graph : boost::noncopyable { + public: + typedef LabeledEdge Edge; + typedef search::Vertex Vertex; + + Graph() {} + + void SetCounts(std::size_t vertices, std::size_t edges) { + vertices_.Init(vertices); + edges_.Init(edges); + } + + Vertex *NewVertex() { + return vertices_.New(); + } + + std::size_t VertexSize() const { return vertices_.Size(); } + + Vertex &MutableVertex(std::size_t index) { + return vertices_[index]; + } + + Edge *NewEdge() { + return edges_.New(); + } + + std::size_t EdgeSize() const { return edges_.Size(); } + + void SetRoot(Vertex *root) { root_ = root; } + + Vertex &Root() { return *root_; } + + private: + FixedAllocator<Vertex> vertices_; + FixedAllocator<Edge> edges_; + + Vertex *root_; +}; + +} // namespace alone + +#endif // ALONE_GRAPH__ diff --git a/klm/alone/just_vocab.cc b/klm/alone/just_vocab.cc new file mode 100644 index 00000000..35aea5ed --- /dev/null +++ b/klm/alone/just_vocab.cc @@ -0,0 +1,14 @@ +#include "alone/read.hh" +#include "util/file_piece.hh" + +#include <iostream> + +int main() { + util::FilePiece f(0, "stdin", &std::cerr); + while (true) { + try { + alone::JustVocab(f, std::cout); + } catch (const util::EndOfFileException &e) { break; } + std::cout << '\n'; + } +} diff --git a/klm/alone/labeled_edge.hh b/klm/alone/labeled_edge.hh new file mode 100644 index 00000000..94d8cbdf --- /dev/null +++ b/klm/alone/labeled_edge.hh @@ -0,0 +1,30 @@ +#ifndef ALONE_LABELED_EDGE__ +#define ALONE_LABELED_EDGE__ + +#include "search/edge.hh" + +#include <string> +#include <vector> + +namespace alone { + +class LabeledEdge : public search::Edge { + public: + LabeledEdge() {} + + void AppendWord(const std::string *word) { + words_.push_back(word); + } + + const std::vector<const std::string *> &Words() const { + return words_; + } + + private: + // NULL for non-terminals. + std::vector<const std::string*> words_; +}; + +} // namespace alone + +#endif // ALONE_LABELED_EDGE__ diff --git a/klm/alone/main.cc b/klm/alone/main.cc new file mode 100644 index 00000000..7768b89c --- /dev/null +++ b/klm/alone/main.cc @@ -0,0 +1,84 @@ +#include "alone/threading.hh" +#include "search/config.hh" +#include "search/context.hh" +#include "util/exception.hh" +#include "util/file_piece.hh" +#include "util/usage.hh" + +#include <boost/lexical_cast.hpp> + +#include <iostream> +#include <memory> + +namespace alone { + +template <class Control> void ReadLoop(const std::string &graph_prefix, Control &control) { + for (unsigned int sentence = 0; ; ++sentence) { + std::stringstream name; + name << graph_prefix << '/' << sentence; + std::auto_ptr<util::FilePiece> file; + try { + file.reset(new util::FilePiece(name.str().c_str())); + } catch (const util::ErrnoException &e) { + if (e.Error() == ENOENT) return; + throw; + } + control.Add(file.release()); + } +} + +template <class Model> void RunWithModelType(const char *graph_prefix, const char *model_file, StringPiece weight_str, unsigned int pop_limit, unsigned int threads) { + Model model(model_file); + search::Config config(weight_str, pop_limit); + + if (threads > 1) { +#ifdef WITH_THREADS + Controller<Model> controller(config, model, threads, std::cout); + ReadLoop(graph_prefix, controller); +#else + UTIL_THROW(util::Exception, "Threading support not compiled in."); +#endif + } else { + InThread<Model> controller(config, model, std::cout); + ReadLoop(graph_prefix, controller); + } +} + +void Run(const char *graph_prefix, const char *lm_name, StringPiece weight_str, unsigned int pop_limit, unsigned int threads) { + lm::ngram::ModelType model_type; + if (!lm::ngram::RecognizeBinary(lm_name, model_type)) model_type = lm::ngram::PROBING; + switch (model_type) { + case lm::ngram::PROBING: + RunWithModelType<lm::ngram::ProbingModel>(graph_prefix, lm_name, weight_str, pop_limit, threads); + break; + case lm::ngram::REST_PROBING: + RunWithModelType<lm::ngram::RestProbingModel>(graph_prefix, lm_name, weight_str, pop_limit, threads); + break; + default: + UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); + } +} + +} // namespace alone + +int main(int argc, char *argv[]) { + if (argc < 5 || argc > 6) { + std::cerr << argv[0] << " graph_prefix lm \"weights\" pop [threads]" << std::endl; + return 1; + } + +#ifdef WITH_THREADS + unsigned thread_count = boost::thread::hardware_concurrency(); +#else + unsigned thread_count = 1; +#endif + if (argc == 6) { + thread_count = boost::lexical_cast<unsigned>(argv[5]); + UTIL_THROW_IF(!thread_count, util::Exception, "Thread count 0"); + } + UTIL_THROW_IF(!thread_count, util::Exception, "Boost doesn't know how many threads there are. Pass it on the command line."); + alone::Run(argv[1], argv[2], argv[3], boost::lexical_cast<unsigned int>(argv[4]), thread_count); + + util::PrintUsage(std::cerr); + return 0; +} diff --git a/klm/alone/read.cc b/klm/alone/read.cc new file mode 100644 index 00000000..0b20be35 --- /dev/null +++ b/klm/alone/read.cc @@ -0,0 +1,118 @@ +#include "alone/read.hh" + +#include "alone/graph.hh" +#include "alone/vocab.hh" +#include "search/arity.hh" +#include "search/context.hh" +#include "search/weights.hh" +#include "util/file_piece.hh" + +#include <boost/unordered_set.hpp> +#include <boost/unordered_map.hpp> + +#include <cstdlib> + +namespace alone { + +namespace { + +template <class Model> Graph::Edge &ReadEdge(search::Context<Model> &context, util::FilePiece &from, Graph &to, Vocab &vocab, bool final) { + Graph::Edge *ret = to.NewEdge(); + + StringPiece got; + + std::vector<lm::WordIndex> words; + unsigned long int terminals = 0; + while ("|||" != (got = from.ReadDelimited())) { + if ('[' == *got.data() && ']' == got.data()[got.size() - 1]) { + // non-terminal + char *end_ptr; + unsigned long int child = std::strtoul(got.data() + 1, &end_ptr, 10); + UTIL_THROW_IF(end_ptr != got.data() + got.size() - 1, FormatException, "Bad non-terminal" << got); + UTIL_THROW_IF(child >= to.VertexSize(), FormatException, "Reference to vertex " << child << " but we only have " << to.VertexSize() << " vertices. Is the file in bottom-up format?"); + ret->Add(to.MutableVertex(child)); + words.push_back(lm::kMaxWordIndex); + ret->AppendWord(NULL); + } else { + const std::pair<const std::string, lm::WordIndex> &found = vocab.FindOrAdd(got); + words.push_back(found.second); + ret->AppendWord(&found.first); + ++terminals; + } + } + if (final) { + // This is not counted for the word penalty. + words.push_back(vocab.EndSentence().second); + ret->AppendWord(&vocab.EndSentence().first); + } + // Hard-coded word penalty. + float additive = context.GetWeights().DotNoLM(from.ReadLine()) - context.GetWeights().WordPenalty() * static_cast<float>(terminals) / M_LN10; + ret->InitRule().Init(context, additive, words, final); + unsigned int arity = ret->GetRule().Arity(); + UTIL_THROW_IF(arity > search::kMaxArity, util::Exception, "Edit search/arity.hh and increase " << search::kMaxArity << " to at least " << arity); + return *ret; +} + +} // namespace + +// TODO: refactor +void JustVocab(util::FilePiece &from, std::ostream &out) { + boost::unordered_set<std::string> seen; + unsigned long int vertices = from.ReadULong(); + from.ReadULong(); // edges + UTIL_THROW_IF(vertices == 0, FormatException, "Vertex count is zero"); + UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected newline after counts"); + std::string temp; + for (unsigned long int i = 0; i < vertices; ++i) { + unsigned long int edge_count = from.ReadULong(); + UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected after edge count"); + for (unsigned long int e = 0; e < edge_count; ++e) { + StringPiece got; + while ("|||" != (got = from.ReadDelimited())) { + if ('[' == *got.data() && ']' == got.data()[got.size() - 1]) continue; + temp.assign(got.data(), got.size()); + if (seen.insert(temp).second) out << temp << ' '; + } + from.ReadLine(); // weights + } + } + // Eat sentence + from.ReadLine(); +} + +template <class Model> bool ReadCDec(search::Context<Model> &context, util::FilePiece &from, Graph &to, Vocab &vocab) { + unsigned long int vertices; + try { + vertices = from.ReadULong(); + } catch (const util::EndOfFileException &e) { return false; } + unsigned long int edges = from.ReadULong(); + UTIL_THROW_IF(vertices < 2, FormatException, "Vertex count is " << vertices); + UTIL_THROW_IF(edges == 0, FormatException, "Edge count is " << edges); + --vertices; + --edges; + UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected newline after counts"); + to.SetCounts(vertices, edges); + Graph::Vertex *vertex; + for (unsigned long int i = 0; ; ++i) { + vertex = to.NewVertex(); + unsigned long int edge_count = from.ReadULong(); + bool root = (i == vertices - 1); + UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected after edge count"); + for (unsigned long int e = 0; e < edge_count; ++e) { + vertex->Add(ReadEdge(context, from, to, vocab, root)); + } + vertex->FinishedAdding(); + if (root) break; + } + to.SetRoot(vertex); + StringPiece str = from.ReadLine(); + UTIL_THROW_IF("1" != str, FormatException, "Expected one edge to root"); + // The edge + from.ReadLine(); + return true; +} + +template bool ReadCDec(search::Context<lm::ngram::ProbingModel> &context, util::FilePiece &from, Graph &to, Vocab &vocab); +template bool ReadCDec(search::Context<lm::ngram::RestProbingModel> &context, util::FilePiece &from, Graph &to, Vocab &vocab); + +} // namespace alone diff --git a/klm/alone/read.hh b/klm/alone/read.hh new file mode 100644 index 00000000..10769a86 --- /dev/null +++ b/klm/alone/read.hh @@ -0,0 +1,29 @@ +#ifndef ALONE_READ__ +#define ALONE_READ__ + +#include "util/exception.hh" + +#include <iosfwd> + +namespace util { class FilePiece; } + +namespace search { template <class Model> class Context; } + +namespace alone { + +class Graph; +class Vocab; + +class FormatException : public util::Exception { + public: + FormatException() {} + ~FormatException() throw() {} +}; + +void JustVocab(util::FilePiece &from, std::ostream &to); + +template <class Model> bool ReadCDec(search::Context<Model> &context, util::FilePiece &from, Graph &to, Vocab &vocab); + +} // namespace alone + +#endif // ALONE_READ__ diff --git a/klm/alone/threading.cc b/klm/alone/threading.cc new file mode 100644 index 00000000..475386b6 --- /dev/null +++ b/klm/alone/threading.cc @@ -0,0 +1,80 @@ +#include "alone/threading.hh" + +#include "alone/assemble.hh" +#include "alone/graph.hh" +#include "alone/read.hh" +#include "alone/vocab.hh" +#include "lm/model.hh" +#include "search/context.hh" +#include "search/vertex_generator.hh" + +#include <boost/ref.hpp> +#include <boost/scoped_ptr.hpp> +#include <boost/utility/in_place_factory.hpp> + +#include <sstream> + +namespace alone { +template <class Model> void Decode(const search::Config &config, const Model &model, util::FilePiece *in_ptr, std::ostream &out) { + search::Context<Model> context(config, model); + Graph graph; + Vocab vocab(model.GetVocabulary()); + { + boost::scoped_ptr<util::FilePiece> in(in_ptr); + ReadCDec(context, *in, graph, vocab); + } + + for (std::size_t i = 0; i < graph.VertexSize(); ++i) { + search::VertexGenerator(context, graph.MutableVertex(i)); + } + search::PartialVertex top = graph.Root().RootPartial(); + if (top.Empty()) { + out << "NO PATH FOUND"; + } else { + search::PartialVertex continuation; + while (!top.Complete()) { + top.Split(continuation); + top = continuation; + } + out << top.End() << " ||| " << top.End().Bound() << std::endl; + } +} + +template void Decode(const search::Config &config, const lm::ngram::ProbingModel &model, util::FilePiece *in_ptr, std::ostream &out); +template void Decode(const search::Config &config, const lm::ngram::RestProbingModel &model, util::FilePiece *in_ptr, std::ostream &out); + +#ifdef WITH_THREADS +template <class Model> void DecodeHandler<Model>::operator()(Input message) { + std::stringstream assemble; + Decode(config_, model_, message.file, assemble); + Produce(message.sentence_id, assemble.str()); +} + +template <class Model> void DecodeHandler<Model>::Produce(unsigned int sentence_id, const std::string &str) { + Output out; + out.sentence_id = sentence_id; + out.str = new std::string(str); + out_.Produce(out); +} + +void PrintHandler::operator()(Output message) { + unsigned int relative = message.sentence_id - done_; + if (waiting_.size() <= relative) waiting_.resize(relative + 1); + waiting_[relative] = message.str; + for (std::string *lead; !waiting_.empty() && (lead = waiting_[0]); waiting_.pop_front(), ++done_) { + out_ << *lead; + delete lead; + } +} + +template <class Model> Controller<Model>::Controller(const search::Config &config, const Model &model, size_t decode_workers, std::ostream &to) : + sentence_id_(0), + printer_(decode_workers, 1, boost::ref(to), Output::Poison()), + decoder_(3, decode_workers, boost::in_place(boost::ref(config), boost::ref(model), boost::ref(printer_.In())), Input::Poison()) {} + +template class Controller<lm::ngram::RestProbingModel>; +template class Controller<lm::ngram::ProbingModel>; + +#endif + +} // namespace alone diff --git a/klm/alone/threading.hh b/klm/alone/threading.hh new file mode 100644 index 00000000..0ab0f739 --- /dev/null +++ b/klm/alone/threading.hh @@ -0,0 +1,129 @@ +#ifndef ALONE_THREADING__ +#define ALONE_THREADING__ + +#ifdef WITH_THREADS +#include "util/pcqueue.hh" +#include "util/pool.hh" +#endif + +#include <iosfwd> +#include <queue> +#include <string> + +namespace util { +class FilePiece; +} // namespace util + +namespace search { +class Config; +template <class Model> class Context; +} // namespace search + +namespace alone { + +template <class Model> void Decode(const search::Config &config, const Model &model, util::FilePiece *in_ptr, std::ostream &out); + +class Graph; + +#ifdef WITH_THREADS +struct SentenceID { + unsigned int sentence_id; + bool operator==(const SentenceID &other) const { + return sentence_id == other.sentence_id; + } +}; + +struct Input : public SentenceID { + util::FilePiece *file; + static Input Poison() { + Input ret; + ret.sentence_id = static_cast<unsigned int>(-1); + ret.file = NULL; + return ret; + } +}; + +struct Output : public SentenceID { + std::string *str; + static Output Poison() { + Output ret; + ret.sentence_id = static_cast<unsigned int>(-1); + ret.str = NULL; + return ret; + } +}; + +template <class Model> class DecodeHandler { + public: + typedef Input Request; + + DecodeHandler(const search::Config &config, const Model &model, util::PCQueue<Output> &out) : config_(config), model_(model), out_(out) {} + + void operator()(Input message); + + private: + void Produce(unsigned int sentence_id, const std::string &str); + + const search::Config &config_; + + const Model &model_; + + util::PCQueue<Output> &out_; +}; + +class PrintHandler { + public: + typedef Output Request; + + explicit PrintHandler(std::ostream &o) : out_(o), done_(0) {} + + void operator()(Output message); + + private: + std::ostream &out_; + std::deque<std::string*> waiting_; + unsigned int done_; +}; + +template <class Model> class Controller { + public: + // This config must remain valid. + explicit Controller(const search::Config &config, const Model &model, size_t decode_workers, std::ostream &to); + + // Takes ownership of in. + void Add(util::FilePiece *in) { + Input input; + input.sentence_id = sentence_id_++; + input.file = in; + decoder_.Produce(input); + } + + private: + unsigned int sentence_id_; + + util::Pool<PrintHandler> printer_; + + util::Pool<DecodeHandler<Model> > decoder_; +}; +#endif + +// Same API as controller. +template <class Model> class InThread { + public: + InThread(const search::Config &config, const Model &model, std::ostream &to) : config_(config), model_(model), to_(to) {} + + // Takes ownership of in. + void Add(util::FilePiece *in) { + Decode(config_, model_, in, to_); + } + + private: + const search::Config &config_; + + const Model &model_; + + std::ostream &to_; +}; + +} // namespace alone +#endif // ALONE_THREADING__ diff --git a/klm/alone/vocab.cc b/klm/alone/vocab.cc new file mode 100644 index 00000000..ffe55301 --- /dev/null +++ b/klm/alone/vocab.cc @@ -0,0 +1,19 @@ +#include "alone/vocab.hh" + +#include "lm/virtual_interface.hh" +#include "util/string_piece.hh" + +namespace alone { + +Vocab::Vocab(const lm::base::Vocabulary &backing) : backing_(backing), end_sentence_(FindOrAdd("</s>")) {} + +const std::pair<const std::string, lm::WordIndex> &Vocab::FindOrAdd(const StringPiece &str) { + Map::const_iterator i(FindStringPiece(map_, str)); + if (i != map_.end()) return *i; + std::pair<std::string, lm::WordIndex> to_ins; + to_ins.first.assign(str.data(), str.size()); + to_ins.second = backing_.Index(str); + return *map_.insert(to_ins).first; +} + +} // namespace alone diff --git a/klm/alone/vocab.hh b/klm/alone/vocab.hh new file mode 100644 index 00000000..3ac0f542 --- /dev/null +++ b/klm/alone/vocab.hh @@ -0,0 +1,34 @@ +#ifndef ALONE_VOCAB__ +#define ALONE_VOCAB__ + +#include "lm/word_index.hh" +#include "util/string_piece.hh" + +#include <boost/functional/hash/hash.hpp> +#include <boost/unordered_map.hpp> + +#include <string> + +namespace lm { namespace base { class Vocabulary; } } + +namespace alone { + +class Vocab { + public: + explicit Vocab(const lm::base::Vocabulary &backing); + + const std::pair<const std::string, lm::WordIndex> &FindOrAdd(const StringPiece &str); + + const std::pair<const std::string, lm::WordIndex> &EndSentence() const { return end_sentence_; } + + private: + typedef boost::unordered_map<std::string, lm::WordIndex> Map; + Map map_; + + const lm::base::Vocabulary &backing_; + + const std::pair<const std::string, lm::WordIndex> &end_sentence_; +}; + +} // namespace alone +#endif // ALONE_VCOAB__ |