Diffstat (limited to 'klm')
39 files changed, 2230 insertions, 1 deletion
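For orientation, the standalone decoder added by this patch is invoked as decode graph_prefix lm "weights" pop [threads] (the usage message in alone/main.cc). A hypothetical run, with illustrative paths and a placeholder weight string (the weight-string syntax is defined by search/weights.cc, which is outside this diff):

    decode graphs model.binary "..." 1000 8

This decodes the per-sentence hypergraphs graphs/0, graphs/1, ... until a file is missing, with a pop limit of 1000 per vertex and 8 decoding threads.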
diff --git a/klm/alone/Jamfile b/klm/alone/Jamfile new file mode 100644 index 00000000..2cc90c05 --- /dev/null +++ b/klm/alone/Jamfile @@ -0,0 +1,4 @@ +lib standalone : assemble.cc read.cc threading.cc vocab.cc ../lm//kenlm ../util//kenutil ../search//search : <include>.. : : <include>.. <library>../search//search <library>../lm//kenlm ; + +exe decode : main.cc standalone : <threading>multi:<library>..//boost_thread ; +exe just_vocab : just_vocab.cc standalone : <threading>multi:<library>..//boost_thread ; diff --git a/klm/alone/assemble.cc b/klm/alone/assemble.cc new file mode 100644 index 00000000..2ae72ce9 --- /dev/null +++ b/klm/alone/assemble.cc @@ -0,0 +1,76 @@ +#include "alone/assemble.hh" + +#include "alone/labeled_edge.hh" +#include "search/final.hh" + +#include <iostream> + +namespace alone { + +std::ostream &operator<<(std::ostream &o, const search::Final &final) { +  const std::vector<const std::string*> &words = static_cast<const LabeledEdge&>(final.From()).Words(); +  if (words.empty()) return o; +  const search::Final *const *child = final.Children().data(); +  std::vector<const std::string*>::const_iterator i(words.begin()); +  for (; i != words.end() - 1; ++i) { +    if (*i) { +      o << **i << ' '; +    } else { +      o << **child << ' '; +      ++child; +    } +  } + +  if (*i) { +    if (**i != "</s>") { +      o << **i; +    } +  } else { +    o << **child; +  } + +  return o; +} + +namespace { + +void MakeIndent(std::ostream &o, const char *indent_str, unsigned int level) { +  for (unsigned int i = 0; i < level; ++i) +    o << indent_str; +} + +void DetailedFinalInternal(std::ostream &o, const search::Final &final, const char *indent_str, unsigned int indent) { +  o << "(\n"; +  MakeIndent(o, indent_str, indent); +  const std::vector<const std::string*> &words = static_cast<const LabeledEdge&>(final.From()).Words(); +  const search::Final *const *child = final.Children().data(); +  for (std::vector<const std::string*>::const_iterator i(words.begin()); i != words.end(); ++i) { +    if (*i) { +      o << **i; +      if (i == words.end() - 1) { +        o << '\n'; +        MakeIndent(o, indent_str, indent); +      } else { +        o << ' '; +      } +    } else { +      // One extra indent from the line we're currently on. +      o << indent_str; +      DetailedFinalInternal(o, **child, indent_str, indent + 1); +      for (unsigned int i = 0; i < indent; ++i) o << indent_str; +      ++child; +    } +  } +  o << ")=" << final.Bound() << '\n'; +} +} // namespace + +void DetailedFinal(std::ostream &o, const search::Final &final, const char *indent_str) { +  DetailedFinalInternal(o, final, indent_str, 0); +} + +void PrintFinal(const search::Final &final) { +  std::cout << final << std::endl; +} + +} // namespace alone diff --git a/klm/alone/assemble.hh b/klm/alone/assemble.hh new file mode 100644 index 00000000..e6b0ad5c --- /dev/null +++ b/klm/alone/assemble.hh @@ -0,0 +1,21 @@ +#ifndef ALONE_ASSEMBLE__ +#define ALONE_ASSEMBLE__ + +#include <iosfwd> + +namespace search { +class Final; +} // namespace search + +namespace alone { + +std::ostream &operator<<(std::ostream &o, const search::Final &final); + +void DetailedFinal(std::ostream &o, const search::Final &final, const char *indent_str = "  "); + +// This isn't called anywhere but makes it easy to print from gdb.
+void PrintFinal(const search::Final &final); + +} // namespace alone + +#endif // ALONE_ASSEMBLE__ diff --git a/klm/alone/graph.hh b/klm/alone/graph.hh new file mode 100644 index 00000000..788352c9 --- /dev/null +++ b/klm/alone/graph.hh @@ -0,0 +1,87 @@ +#ifndef ALONE_GRAPH__ +#define ALONE_GRAPH__ + +#include "alone/labeled_edge.hh" +#include "search/rule.hh" +#include "search/types.hh" +#include "search/vertex.hh" +#include "util/exception.hh" + +#include <boost/noncopyable.hpp> +#include <boost/pool/object_pool.hpp> +#include <boost/scoped_array.hpp> + +namespace alone { + +template <class T> class FixedAllocator : boost::noncopyable { +  public: +    FixedAllocator() : current_(NULL), end_(NULL) {} + +    void Init(std::size_t count) { +      assert(!current_); +      array_.reset(new T[count]); +      current_ = array_.get(); +      end_ = current_ + count; +    } + +    T &operator[](std::size_t idx) { +      return array_.get()[idx]; +    } + +    T *New() { +      T *ret = current_++; +      UTIL_THROW_IF(ret >= end_, util::Exception, "Allocating past end"); +      return ret; +    } + +    std::size_t Size() const { +      return end_ - array_.get(); +    } + +  private: +    boost::scoped_array<T> array_; +    T *current_, *end_; +}; + +class Graph : boost::noncopyable { +  public: +    typedef LabeledEdge Edge; +    typedef search::Vertex Vertex; + +    Graph() {} + +    void SetCounts(std::size_t vertices, std::size_t edges) { +      vertices_.Init(vertices); +      edges_.Init(edges); +    } + +    Vertex *NewVertex() { +      return vertices_.New(); +    } + +    std::size_t VertexSize() const { return vertices_.Size(); } + +    Vertex &MutableVertex(std::size_t index) { +      return vertices_[index]; +    } + +    Edge *NewEdge() {       +      return edges_.New(); +    } + +    std::size_t EdgeSize() const { return edges_.Size(); } + +    void SetRoot(Vertex *root) { root_ = root; } + +    Vertex &Root() { return *root_; } + +  private: +    FixedAllocator<Vertex> vertices_; +    FixedAllocator<Edge> edges_; +     +    Vertex *root_; +}; + +} // namespace alone + +#endif // ALONE_GRAPH__ diff --git a/klm/alone/just_vocab.cc b/klm/alone/just_vocab.cc new file mode 100644 index 00000000..35aea5ed --- /dev/null +++ b/klm/alone/just_vocab.cc @@ -0,0 +1,14 @@ +#include "alone/read.hh" +#include "util/file_piece.hh" + +#include <iostream> + +int main() { +  util::FilePiece f(0, "stdin", &std::cerr); +  while (true) { +    try { +      alone::JustVocab(f, std::cout); +    } catch (const util::EndOfFileException &e) { break; } +    std::cout << '\n'; +  } +} diff --git a/klm/alone/labeled_edge.hh b/klm/alone/labeled_edge.hh new file mode 100644 index 00000000..94d8cbdf --- /dev/null +++ b/klm/alone/labeled_edge.hh @@ -0,0 +1,30 @@ +#ifndef ALONE_LABELED_EDGE__ +#define ALONE_LABELED_EDGE__ + +#include "search/edge.hh" + +#include <string> +#include <vector> + +namespace alone { + +class LabeledEdge : public search::Edge { +  public: +    LabeledEdge() {} + +    void AppendWord(const std::string *word) { +      words_.push_back(word); +    } + +    const std::vector<const std::string *> &Words() const { +      return words_; +    } + +  private: +    // NULL for non-terminals.   
+    std::vector<const std::string*> words_; +}; + +} // namespace alone + +#endif // ALONE_LABELED_EDGE__ diff --git a/klm/alone/main.cc b/klm/alone/main.cc new file mode 100644 index 00000000..e09ab01d --- /dev/null +++ b/klm/alone/main.cc @@ -0,0 +1,85 @@ +#include "alone/threading.hh" +#include "search/config.hh" +#include "search/context.hh" +#include "util/exception.hh" +#include "util/file_piece.hh" +#include "util/usage.hh" + +#include <boost/lexical_cast.hpp> + +#include <iostream> +#include <memory> + +namespace alone { + +template <class Control> void ReadLoop(const std::string &graph_prefix, Control &control) { +  for (unsigned int sentence = 0; ; ++sentence) { +    std::stringstream name; +    name << graph_prefix << '/' << sentence; +    std::auto_ptr<util::FilePiece> file; +    try { +      file.reset(new util::FilePiece(name.str().c_str())); +    } catch (const util::ErrnoException &e) { +      if (e.Error() == ENOENT) return; +      throw; +    } +    control.Add(file.release()); +  } +} + +template <class Model> void RunWithModelType(const char *graph_prefix, const char *model_file, StringPiece weight_str, unsigned int pop_limit, unsigned int threads) { +  Model model(model_file); +  search::Weights weights(weight_str); +  search::Config config(weights, pop_limit); + +  if (threads > 1) { +#ifdef WITH_THREADS +    Controller<Model> controller(config, model, threads, std::cout); +    ReadLoop(graph_prefix, controller); +#else +    UTIL_THROW(util::Exception, "Threading support not compiled in."); +#endif +  } else { +    InThread<Model> controller(config, model, std::cout); +    ReadLoop(graph_prefix, controller); +  } +} + +void Run(const char *graph_prefix, const char *lm_name, StringPiece weight_str, unsigned int pop_limit, unsigned int threads) { +  lm::ngram::ModelType model_type; +  if (!lm::ngram::RecognizeBinary(lm_name, model_type)) model_type = lm::ngram::PROBING; +  switch (model_type) { +    case lm::ngram::PROBING: +      RunWithModelType<lm::ngram::ProbingModel>(graph_prefix, lm_name, weight_str, pop_limit, threads); +      break; +    case lm::ngram::REST_PROBING: +      RunWithModelType<lm::ngram::RestProbingModel>(graph_prefix, lm_name, weight_str, pop_limit, threads); +      break; +    default: +      UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); +  } +} + +} // namespace alone + +int main(int argc, char *argv[]) { +  if (argc < 5 || argc > 6) { +    std::cerr << argv[0] << " graph_prefix lm \"weights\" pop [threads]" << std::endl; +    return 1; +  } + +#ifdef WITH_THREADS +  unsigned thread_count = boost::thread::hardware_concurrency(); +#else +  unsigned thread_count = 1; +#endif +  if (argc == 6) { +    thread_count = boost::lexical_cast<unsigned>(argv[5]); +    UTIL_THROW_IF(!thread_count, util::Exception, "Thread count 0"); +  } +  UTIL_THROW_IF(!thread_count, util::Exception, "Boost doesn't know how many threads there are.  
Pass it on the command line."); +  alone::Run(argv[1], argv[2], argv[3], boost::lexical_cast<unsigned int>(argv[4]), thread_count); + +  util::PrintUsage(std::cerr); +  return 0; +} diff --git a/klm/alone/read.cc b/klm/alone/read.cc new file mode 100644 index 00000000..0b20be35 --- /dev/null +++ b/klm/alone/read.cc @@ -0,0 +1,118 @@ +#include "alone/read.hh" + +#include "alone/graph.hh" +#include "alone/vocab.hh" +#include "search/arity.hh" +#include "search/context.hh" +#include "search/weights.hh" +#include "util/file_piece.hh" + +#include <boost/unordered_set.hpp> +#include <boost/unordered_map.hpp> + +#include <cstdlib> + +namespace alone { + +namespace { + +template <class Model> Graph::Edge &ReadEdge(search::Context<Model> &context, util::FilePiece &from, Graph &to, Vocab &vocab, bool final) { +  Graph::Edge *ret = to.NewEdge(); + +  StringPiece got; + +  std::vector<lm::WordIndex> words; +  unsigned long int terminals = 0; +  while ("|||" != (got = from.ReadDelimited())) { +    if ('[' == *got.data() && ']' == got.data()[got.size() - 1]) { +      // non-terminal +      char *end_ptr; +      unsigned long int child = std::strtoul(got.data() + 1, &end_ptr, 10); +      UTIL_THROW_IF(end_ptr != got.data() + got.size() - 1, FormatException, "Bad non-terminal" << got); +      UTIL_THROW_IF(child >= to.VertexSize(), FormatException, "Reference to vertex " << child << " but we only have " << to.VertexSize() << " vertices.  Is the file in bottom-up format?"); +      ret->Add(to.MutableVertex(child)); +      words.push_back(lm::kMaxWordIndex); +      ret->AppendWord(NULL); +    } else { +      const std::pair<const std::string, lm::WordIndex> &found = vocab.FindOrAdd(got); +      words.push_back(found.second); +      ret->AppendWord(&found.first); +      ++terminals; +    } +  } +  if (final) { +    // This is not counted for the word penalty.   +    words.push_back(vocab.EndSentence().second); +    ret->AppendWord(&vocab.EndSentence().first); +  } +  // Hard-coded word penalty.   
+float additive = context.GetWeights().DotNoLM(from.ReadLine()) - context.GetWeights().WordPenalty() * static_cast<float>(terminals) / M_LN10; +  ret->InitRule().Init(context, additive, words, final); +  unsigned int arity = ret->GetRule().Arity(); +  UTIL_THROW_IF(arity > search::kMaxArity, util::Exception, "Edit search/arity.hh and increase " << search::kMaxArity << " to at least " << arity); +  return *ret; +} + +} // namespace + +// TODO: refactor +void JustVocab(util::FilePiece &from, std::ostream &out) { +  boost::unordered_set<std::string> seen; +  unsigned long int vertices = from.ReadULong(); +  from.ReadULong(); // edges +  UTIL_THROW_IF(vertices == 0, FormatException, "Vertex count is zero"); +  UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected newline after counts"); +  std::string temp; +  for (unsigned long int i = 0; i < vertices; ++i) { +    unsigned long int edge_count = from.ReadULong(); +    UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected newline after edge count"); +    for (unsigned long int e = 0; e < edge_count; ++e) { +      StringPiece got; +      while ("|||" != (got = from.ReadDelimited())) { +        if ('[' == *got.data() && ']' == got.data()[got.size() - 1]) continue; +        temp.assign(got.data(), got.size()); +        if (seen.insert(temp).second) out << temp << ' '; +      } +      from.ReadLine(); // weights +    } +  } +  // Eat sentence +  from.ReadLine(); +} + +template <class Model> bool ReadCDec(search::Context<Model> &context, util::FilePiece &from, Graph &to, Vocab &vocab) { +  unsigned long int vertices; +  try { +    vertices = from.ReadULong(); +  } catch (const util::EndOfFileException &e) { return false; } +  unsigned long int edges = from.ReadULong(); +  UTIL_THROW_IF(vertices < 2, FormatException, "Vertex count is " << vertices); +  UTIL_THROW_IF(edges == 0, FormatException, "Edge count is " << edges); +  --vertices; +  --edges; +  UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected newline after counts"); +  to.SetCounts(vertices, edges); +  Graph::Vertex *vertex; +  for (unsigned long int i = 0; ; ++i) { +    vertex = to.NewVertex(); +    unsigned long int edge_count = from.ReadULong(); +    bool root = (i == vertices - 1); +    UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected newline after edge count"); +    for (unsigned long int e = 0; e < edge_count; ++e) { +      vertex->Add(ReadEdge(context, from, to, vocab, root)); +    } +    vertex->FinishedAdding(); +    if (root) break; +  } +  to.SetRoot(vertex); +  StringPiece str = from.ReadLine(); +  UTIL_THROW_IF("1" != str, FormatException, "Expected one edge to root"); +  // The edge +  from.ReadLine(); +  return true; +} + +template bool ReadCDec(search::Context<lm::ngram::ProbingModel> &context, util::FilePiece &from, Graph &to, Vocab &vocab); +template bool ReadCDec(search::Context<lm::ngram::RestProbingModel> &context, util::FilePiece &from, Graph &to, Vocab &vocab); + +} // namespace alone diff --git a/klm/alone/read.hh b/klm/alone/read.hh new file mode 100644 index 00000000..10769a86 --- /dev/null +++ b/klm/alone/read.hh @@ -0,0 +1,29 @@ +#ifndef ALONE_READ__ +#define ALONE_READ__ + +#include "util/exception.hh" + +#include <iosfwd> + +namespace util { class FilePiece; } + +namespace search { template <class Model> class Context; } + +namespace alone { + +class Graph; +class Vocab; + +class FormatException : public util::Exception { +  public: +    FormatException() {} +    ~FormatException() throw() {} +}; + +void JustVocab(util::FilePiece &from,
std::ostream &to); + +template <class Model> bool ReadCDec(search::Context<Model> &context, util::FilePiece &from, Graph &to, Vocab &vocab); + +} // namespace alone + +#endif // ALONE_READ__ diff --git a/klm/alone/threading.cc b/klm/alone/threading.cc new file mode 100644 index 00000000..475386b6 --- /dev/null +++ b/klm/alone/threading.cc @@ -0,0 +1,80 @@ +#include "alone/threading.hh" + +#include "alone/assemble.hh" +#include "alone/graph.hh" +#include "alone/read.hh" +#include "alone/vocab.hh" +#include "lm/model.hh" +#include "search/context.hh" +#include "search/vertex_generator.hh" + +#include <boost/ref.hpp> +#include <boost/scoped_ptr.hpp> +#include <boost/utility/in_place_factory.hpp> + +#include <sstream> + +namespace alone { +template <class Model> void Decode(const search::Config &config, const Model &model, util::FilePiece *in_ptr, std::ostream &out) { +  search::Context<Model> context(config, model); +  Graph graph; +  Vocab vocab(model.GetVocabulary()); +  { +    boost::scoped_ptr<util::FilePiece> in(in_ptr); +    ReadCDec(context, *in, graph, vocab); +  } + +  for (std::size_t i = 0; i < graph.VertexSize(); ++i) { +    search::VertexGenerator(context, graph.MutableVertex(i)); +  } +  search::PartialVertex top = graph.Root().RootPartial(); +  if (top.Empty()) { +    out << "NO PATH FOUND"; +  } else { +    search::PartialVertex continuation; +    while (!top.Complete()) { +      top.Split(continuation); +      top = continuation; +    } +    out << top.End() << " ||| " << top.End().Bound() << std::endl; +  } +} + +template void Decode(const search::Config &config, const lm::ngram::ProbingModel &model, util::FilePiece *in_ptr, std::ostream &out); +template void Decode(const search::Config &config, const lm::ngram::RestProbingModel &model, util::FilePiece *in_ptr, std::ostream &out); + +#ifdef WITH_THREADS +template <class Model> void DecodeHandler<Model>::operator()(Input message) { +  std::stringstream assemble; +  Decode(config_, model_, message.file, assemble); +  Produce(message.sentence_id, assemble.str()); +} + +template <class Model> void DecodeHandler<Model>::Produce(unsigned int sentence_id, const std::string &str) { +  Output out; +  out.sentence_id = sentence_id; +  out.str = new std::string(str); +  out_.Produce(out); +} + +void PrintHandler::operator()(Output message) { +  unsigned int relative = message.sentence_id - done_; +  if (waiting_.size() <= relative) waiting_.resize(relative + 1); +  waiting_[relative] = message.str; +  for (std::string *lead; !waiting_.empty() && (lead = waiting_[0]); waiting_.pop_front(), ++done_) { +    out_ << *lead; +    delete lead; +  } +} + +template <class Model> Controller<Model>::Controller(const search::Config &config, const Model &model, size_t decode_workers, std::ostream &to) :  +  sentence_id_(0), +  printer_(decode_workers, 1, boost::ref(to), Output::Poison()), +  decoder_(3, decode_workers, boost::in_place(boost::ref(config), boost::ref(model), boost::ref(printer_.In())), Input::Poison()) {} + +template class Controller<lm::ngram::RestProbingModel>; +template class Controller<lm::ngram::ProbingModel>; + +#endif + +} // namespace alone diff --git a/klm/alone/threading.hh b/klm/alone/threading.hh new file mode 100644 index 00000000..0ab0f739 --- /dev/null +++ b/klm/alone/threading.hh @@ -0,0 +1,129 @@ +#ifndef ALONE_THREADING__ +#define ALONE_THREADING__ + +#ifdef WITH_THREADS +#include "util/pcqueue.hh" +#include "util/pool.hh" +#endif + +#include <iosfwd> +#include <queue> +#include <string> + +namespace util { +class 
FilePiece; +} // namespace util + +namespace search { +class Config; +template <class Model> class Context; +} // namespace search + +namespace alone { + +template <class Model> void Decode(const search::Config &config, const Model &model, util::FilePiece *in_ptr, std::ostream &out); + +class Graph; + +#ifdef WITH_THREADS +struct SentenceID { +  unsigned int sentence_id; +  bool operator==(const SentenceID &other) const { +    return sentence_id == other.sentence_id; +  } +}; + +struct Input : public SentenceID { +  util::FilePiece *file; +  static Input Poison() { +    Input ret; +    ret.sentence_id = static_cast<unsigned int>(-1); +    ret.file = NULL; +    return ret; +  } +}; + +struct Output : public SentenceID { +  std::string *str; +  static Output Poison() { +    Output ret; +    ret.sentence_id = static_cast<unsigned int>(-1); +    ret.str = NULL; +    return ret; +  } +}; + +template <class Model> class DecodeHandler { +  public: +    typedef Input Request; + +    DecodeHandler(const search::Config &config, const Model &model, util::PCQueue<Output> &out) : config_(config), model_(model), out_(out) {} + +    void operator()(Input message); + +  private: +    void Produce(unsigned int sentence_id, const std::string &str); + +    const search::Config &config_; + +    const Model &model_; +     +    util::PCQueue<Output> &out_; +}; + +class PrintHandler { +  public: +    typedef Output Request; + +    explicit PrintHandler(std::ostream &o) : out_(o), done_(0) {} + +    void operator()(Output message); + +  private: +    std::ostream &out_; +    std::deque<std::string*> waiting_; +    unsigned int done_; +}; + +template <class Model> class Controller { +  public: +    // This config must remain valid.    +    explicit Controller(const search::Config &config, const Model &model, size_t decode_workers, std::ostream &to); + +    // Takes ownership of in.     +    void Add(util::FilePiece *in) { +      Input input; +      input.sentence_id = sentence_id_++; +      input.file = in; +      decoder_.Produce(input); +    } + +  private: +    unsigned int sentence_id_; + +    util::Pool<PrintHandler> printer_; + +    util::Pool<DecodeHandler<Model> > decoder_; +}; +#endif + +// Same API as controller.   +template <class Model> class InThread { +  public: +    InThread(const search::Config &config, const Model &model, std::ostream &to) : config_(config), model_(model), to_(to) {} + +    // Takes ownership of in.   
+    void Add(util::FilePiece *in) { +      Decode(config_, model_, in, to_); +    } + +  private: +    const search::Config &config_; + +    const Model &model_; + +    std::ostream &to_; +}; + +} // namespace alone +#endif // ALONE_THREADING__ diff --git a/klm/alone/vocab.cc b/klm/alone/vocab.cc new file mode 100644 index 00000000..ffe55301 --- /dev/null +++ b/klm/alone/vocab.cc @@ -0,0 +1,19 @@ +#include "alone/vocab.hh" + +#include "lm/virtual_interface.hh" +#include "util/string_piece.hh" + +namespace alone { + +Vocab::Vocab(const lm::base::Vocabulary &backing) : backing_(backing), end_sentence_(FindOrAdd("</s>")) {} + +const std::pair<const std::string, lm::WordIndex> &Vocab::FindOrAdd(const StringPiece &str) { +  Map::const_iterator i(FindStringPiece(map_, str)); +  if (i != map_.end()) return *i; +  std::pair<std::string, lm::WordIndex> to_ins; +  to_ins.first.assign(str.data(), str.size()); +  to_ins.second = backing_.Index(str); +  return *map_.insert(to_ins).first; +} + +} // namespace alone diff --git a/klm/alone/vocab.hh b/klm/alone/vocab.hh new file mode 100644 index 00000000..3ac0f542 --- /dev/null +++ b/klm/alone/vocab.hh @@ -0,0 +1,34 @@ +#ifndef ALONE_VOCAB__ +#define ALONE_VOCAB__ + +#include "lm/word_index.hh" +#include "util/string_piece.hh" + +#include <boost/functional/hash/hash.hpp> +#include <boost/unordered_map.hpp> + +#include <string> + +namespace lm { namespace base { class Vocabulary; } } + +namespace alone { + +class Vocab { +  public: +    explicit Vocab(const lm::base::Vocabulary &backing); + +    const std::pair<const std::string, lm::WordIndex> &FindOrAdd(const StringPiece &str); + +    const std::pair<const std::string, lm::WordIndex> &EndSentence() const { return end_sentence_; } + +  private: +    typedef boost::unordered_map<std::string, lm::WordIndex> Map; +    Map map_; + +    const lm::base::Vocabulary &backing_; + +    const std::pair<const std::string, lm::WordIndex> &end_sentence_; +}; + +} // namespace alone +#endif // ALONE_VOCAB__ diff --git a/klm/lm/fragment.cc b/klm/lm/fragment.cc new file mode 100644 index 00000000..0267cd4e --- /dev/null +++ b/klm/lm/fragment.cc @@ -0,0 +1,38 @@ +#include "lm/binary_format.hh" +#include "lm/model.hh" +#include "lm/left.hh" +#include "util/tokenize_piece.hh" +#include <iostream> + +template <class Model> void Query(const char *name) { +  Model model(name); +  std::string line; +  lm::ngram::ChartState ignored; +  while (getline(std::cin, line)) { +    lm::ngram::RuleScore<Model> scorer(model, ignored); +    for (util::TokenIter<util::SingleCharacter, true> i(line, ' '); i; ++i) { +      scorer.Terminal(model.GetVocabulary().Index(*i)); +    } +    std::cout << scorer.Finish() << '\n'; +  } +} + +int main(int argc, char *argv[]) { +  if (argc != 2) { +    std::cerr << "Expected model file name." << std::endl; +    return 1; +  } +  const char *name = argv[1]; +  lm::ngram::ModelType model_type = lm::ngram::PROBING; +  lm::ngram::RecognizeBinary(name, model_type); +  switch (model_type) { +    case lm::ngram::PROBING: +      Query<lm::ngram::ProbingModel>(name); +      break; +    case lm::ngram::REST_PROBING: +      Query<lm::ngram::RestProbingModel>(name); +      break; +    default: +      std::cerr << "Model type not supported yet."
<< std::endl; +  } +} diff --git a/klm/lm/partial.hh b/klm/lm/partial.hh new file mode 100644 index 00000000..1dede359 --- /dev/null +++ b/klm/lm/partial.hh @@ -0,0 +1,167 @@ +#ifndef LM_PARTIAL__ +#define LM_PARTIAL__ + +#include "lm/return.hh" +#include "lm/state.hh" + +#include <algorithm> + +#include <assert.h> + +namespace lm { +namespace ngram { + +struct ExtendReturn { +  float adjust; +  bool make_full; +  unsigned char next_use; +}; + +template <class Model> ExtendReturn ExtendLoop( +    const Model &model, +    unsigned char seen, const WordIndex *add_rbegin, const WordIndex *add_rend, const float *backoff_start, +    const uint64_t *pointers, const uint64_t *pointers_end, +    uint64_t *&pointers_write, +    float *backoff_write) { +  unsigned char add_length = add_rend - add_rbegin; + +  float backoff_buf[2][KENLM_MAX_ORDER - 1]; +  float *backoff_in = backoff_buf[0], *backoff_out = backoff_buf[1]; +  std::copy(backoff_start, backoff_start + add_length, backoff_in); + +  ExtendReturn value; +  value.make_full = false; +  value.adjust = 0.0; +  value.next_use = add_length; + +  unsigned char i = 0; +  unsigned char length = pointers_end - pointers; +  // pointers_write is NULL means that the existing left state is full, so we should use completed probabilities.   +  if (pointers_write) { +    // Using full context, writing to new left state.    +    for (; i < length; ++i) { +      FullScoreReturn ret(model.ExtendLeft( +          add_rbegin, add_rbegin + value.next_use, +          backoff_in, +          pointers[i], i + seen + 1, +          backoff_out, +          value.next_use)); +      std::swap(backoff_in, backoff_out); +      if (ret.independent_left) { +        value.adjust += ret.prob; +        value.make_full = true; +        ++i; +        break; +      } +      value.adjust += ret.rest; +      *pointers_write++ = ret.extend_left; +      if (value.next_use != add_length) { +        value.make_full = true; +        ++i; +        break; +      } +    } +  } +  // Using some of the new context.   +  for (; i < length && value.next_use; ++i) { +    FullScoreReturn ret(model.ExtendLeft( +        add_rbegin, add_rbegin + value.next_use, +        backoff_in, +        pointers[i], i + seen + 1, +        backoff_out, +        value.next_use)); +    std::swap(backoff_in, backoff_out); +    value.adjust += ret.prob; +  } +  float unrest = model.UnRest(pointers + i, pointers_end, i + seen + 1); +  // Using none of the new context.   +  value.adjust += unrest; + +  std::copy(backoff_in, backoff_in + value.next_use, backoff_write); +  return value; +} + +template <class Model> float RevealBefore(const Model &model, const Right &reveal, const unsigned char seen, bool reveal_full, Left &left, Right &right) { +  assert(seen < reveal.length || reveal_full); +  uint64_t *pointers_write = reveal_full ? NULL : left.pointers; +  float backoff_buffer[KENLM_MAX_ORDER - 1]; +  ExtendReturn value(ExtendLoop( +      model, +      seen, reveal.words + seen, reveal.words + reveal.length, reveal.backoff + seen, +      left.pointers, left.pointers + left.length, +      pointers_write, +      left.full ? 
backoff_buffer : (right.backoff + right.length))); +  if (reveal_full) { +    left.length = 0; +    value.make_full = true; +  } else { +    left.length = pointers_write - left.pointers; +    value.make_full |= (left.length == model.Order() - 1); +  } +  if (left.full) { +    for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += backoff_buffer[i]; +  } else { +    // If left wasn't full when it came in, put words into right state.   +    std::copy(reveal.words + seen, reveal.words + seen + value.next_use, right.words + right.length); +    right.length += value.next_use; +    left.full = value.make_full || (right.length == model.Order() - 1); +  } +  return value.adjust; +} + +template <class Model> float RevealAfter(const Model &model, Left &left, Right &right, const Left &reveal, unsigned char seen) { +  assert(seen < reveal.length || reveal.full); +  uint64_t *pointers_write = left.full ? NULL : (left.pointers + left.length); +  ExtendReturn value(ExtendLoop( +      model, +      seen, right.words, right.words + right.length, right.backoff, +      reveal.pointers + seen, reveal.pointers + reveal.length, +      pointers_write, +      right.backoff)); +  if (reveal.full) { +    for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += right.backoff[i]; +    right.length = 0; +    value.make_full = true; +  } else { +    right.length = value.next_use; +    value.make_full |= (right.length == model.Order() - 1); +  } +  if (!left.full) { +    left.length = pointers_write - left.pointers; +    left.full = value.make_full || (left.length == model.Order() - 1); +  } +  return value.adjust; +} + +template <class Model> float Subsume(const Model &model, Left &first_left, const Right &first_right, const Left &second_left, Right &second_right, const unsigned int between_length) { +  assert(first_right.length < KENLM_MAX_ORDER); +  assert(second_left.length < KENLM_MAX_ORDER); +  assert(between_length < KENLM_MAX_ORDER - 1); +  uint64_t *pointers_write = first_left.full ? NULL : (first_left.pointers + first_left.length); +  float backoff_buffer[KENLM_MAX_ORDER - 1]; +  ExtendReturn value(ExtendLoop( +        model, +        between_length, first_right.words, first_right.words + first_right.length, first_right.backoff, +        second_left.pointers, second_left.pointers + second_left.length, +        pointers_write, +        second_left.full ? 
backoff_buffer : (second_right.backoff + second_right.length))); +  if (second_left.full) { +    for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += backoff_buffer[i]; +  } else { +    std::copy(first_right.words, first_right.words + value.next_use, second_right.words + second_right.length); +    second_right.length += value.next_use; +    value.make_full |= (second_right.length == model.Order() - 1); +  } +  if (!first_left.full) { +    first_left.length = pointers_write - first_left.pointers; +    first_left.full = value.make_full || second_left.full || (first_left.length == model.Order() - 1); +  } +  assert(first_left.length < KENLM_MAX_ORDER); +  assert(second_right.length < KENLM_MAX_ORDER); +  return value.adjust; +} + +} // namespace ngram +} // namespace lm + +#endif // LM_PARTIAL__ diff --git a/klm/lm/partial_test.cc b/klm/lm/partial_test.cc new file mode 100644 index 00000000..8d309c85 --- /dev/null +++ b/klm/lm/partial_test.cc @@ -0,0 +1,199 @@ +#include "lm/partial.hh" + +#include "lm/left.hh" +#include "lm/model.hh" +#include "util/tokenize_piece.hh" + +#define BOOST_TEST_MODULE PartialTest +#include <boost/test/unit_test.hpp> +#include <boost/test/floating_point_comparison.hpp> + +namespace lm { +namespace ngram { +namespace { + +const char *TestLocation() { +  if (boost::unit_test::framework::master_test_suite().argc < 2) { +    return "test.arpa"; +  } +  return boost::unit_test::framework::master_test_suite().argv[1]; +} + +Config SilentConfig() { +  Config config; +  config.arpa_complain = Config::NONE; +  config.messages = NULL; +  return config; +} + +struct ModelFixture { +  ModelFixture() : m(TestLocation(), SilentConfig()) {} + +  RestProbingModel m; +}; + +BOOST_FIXTURE_TEST_SUITE(suite, ModelFixture) + +BOOST_AUTO_TEST_CASE(SimpleBefore) { +  Left left; +  left.full = false; +  left.length = 0; +  Right right; +  right.length = 0; + +  Right reveal; +  reveal.length = 1; +  WordIndex period = m.GetVocabulary().Index("."); +  reveal.words[0] = period; +  reveal.backoff[0] = -0.845098; + +  BOOST_CHECK_CLOSE(0.0, RevealBefore(m, reveal, 0, false, left, right), 0.001); +  BOOST_CHECK_EQUAL(0, left.length); +  BOOST_CHECK(!left.full); +  BOOST_CHECK_EQUAL(1, right.length); +  BOOST_CHECK_EQUAL(period, right.words[0]); +  BOOST_CHECK_CLOSE(-0.845098, right.backoff[0], 0.001); + +  WordIndex more = m.GetVocabulary().Index("more"); +  reveal.words[1] = more; +  reveal.backoff[1] =  -0.4771212; +  reveal.length = 2; +  BOOST_CHECK_CLOSE(0.0, RevealBefore(m, reveal, 1, false, left, right), 0.001); +  BOOST_CHECK_EQUAL(0, left.length); +  BOOST_CHECK(!left.full); +  BOOST_CHECK_EQUAL(2, right.length); +  BOOST_CHECK_EQUAL(period, right.words[0]); +  BOOST_CHECK_EQUAL(more, right.words[1]); +  BOOST_CHECK_CLOSE(-0.845098, right.backoff[0], 0.001); +  BOOST_CHECK_CLOSE(-0.4771212, right.backoff[1], 0.001); +} + +BOOST_AUTO_TEST_CASE(AlsoWouldConsider) { +  WordIndex would = m.GetVocabulary().Index("would"); +  WordIndex consider = m.GetVocabulary().Index("consider"); + +  ChartState current; +  current.left.length = 1; +  current.left.pointers[0] = would; +  current.left.full = false; +  current.right.length = 1; +  current.right.words[0] = would; +  current.right.backoff[0] = -0.30103; + +  Left after; +  after.full = false; +  after.length = 1; +  after.pointers[0] = consider; + +  // adjustment for would consider +  BOOST_CHECK_CLOSE(-1.687872 - -0.2922095 - 0.30103, RevealAfter(m, current.left, current.right, after, 0), 0.001); + +  BOOST_CHECK_EQUAL(2, 
current.left.length); +  BOOST_CHECK_EQUAL(would, current.left.pointers[0]); +  BOOST_CHECK_EQUAL(false, current.left.full); + +  WordIndex also = m.GetVocabulary().Index("also"); +  Right before; +  before.length = 1; +  before.words[0] = also; +  before.backoff[0] = -0.30103; +  // r(would) = -0.2922095 [i would], r(would -> consider) = -1.988902 [b(would) + p(consider)] +  // p(also -> would) = -2, p(also would -> consider) = -3 +  BOOST_CHECK_CLOSE(-2 + 0.2922095 -3 + 1.988902, RevealBefore(m, before, 0, false, current.left, current.right), 0.001); +  BOOST_CHECK_EQUAL(0, current.left.length); +  BOOST_CHECK(current.left.full); +  BOOST_CHECK_EQUAL(2, current.right.length); +  BOOST_CHECK_EQUAL(would, current.right.words[0]); +  BOOST_CHECK_EQUAL(also, current.right.words[1]); +} + +BOOST_AUTO_TEST_CASE(EndSentence) { +  WordIndex loin = m.GetVocabulary().Index("loin"); +  WordIndex period = m.GetVocabulary().Index("."); +  WordIndex eos = m.GetVocabulary().EndSentence(); + +  ChartState between; +  between.left.length = 1; +  between.left.pointers[0] = eos; +  between.left.full = true; +  between.right.length = 0; + +  Right before; +  before.words[0] = period; +  before.words[1] = loin; +  before.backoff[0] = -0.845098; +  before.backoff[1] = 0.0; +   +  before.length = 1; +  BOOST_CHECK_CLOSE(-0.0410707, RevealBefore(m, before, 0, true, between.left, between.right), 0.001); +  BOOST_CHECK_EQUAL(0, between.left.length); +} + +float ScoreFragment(const RestProbingModel &model, unsigned int *begin, unsigned int *end, ChartState &out) { +  RuleScore<RestProbingModel> scorer(model, out); +  for (unsigned int *i = begin; i < end; ++i) { +    scorer.Terminal(*i); +  } +  return scorer.Finish(); +} + +void CheckAdjustment(const RestProbingModel &model, float expect, const Right &before_in, bool before_full, ChartState between, const Left &after_in) { +  Right before(before_in); +  Left after(after_in); +  after.full = false; +  float got = 0.0; +  for (unsigned int i = 1; i < 5; ++i) { +    if (before_in.length >= i) { +      before.length = i; +      got += RevealBefore(model, before, i - 1, false, between.left, between.right); +    } +    if (after_in.length >= i) { +      after.length = i; +      got += RevealAfter(model, between.left, between.right, after, i - 1); +    } +  } +  if (after_in.full) { +    after.full = true; +    got += RevealAfter(model, between.left, between.right, after, after.length); +  } +  if (before_full) { +    got += RevealBefore(model, before, before.length, true, between.left, between.right); +  } +  // Sometimes they're zero and BOOST_CHECK_CLOSE fails for this.  
+  BOOST_CHECK(fabs(expect - got) < 0.001); +} + +void FullDivide(const RestProbingModel &model, StringPiece str) { +  std::vector<WordIndex> indices; +  for (util::TokenIter<util::SingleCharacter, true> i(str, ' '); i; ++i) { +    indices.push_back(model.GetVocabulary().Index(*i)); +  } +  ChartState full_state; +  float full = ScoreFragment(model, &indices.front(), &indices.back() + 1, full_state); + +  ChartState before_state; +  before_state.left.full = false; +  RuleScore<RestProbingModel> before_scorer(model, before_state); +  float before_score = 0.0; +  for (unsigned int before = 0; before < indices.size(); ++before) { +    for (unsigned int after = before; after <= indices.size(); ++after) { +      ChartState after_state, between_state; +      float after_score = ScoreFragment(model, &indices.front() + after, &indices.front() + indices.size(), after_state); +      float between_score = ScoreFragment(model, &indices.front() + before, &indices.front() + after, between_state); +      CheckAdjustment(model, full - before_score - after_score - between_score, before_state.right, before_state.left.full, between_state, after_state.left); +    } +    before_scorer.Terminal(indices[before]); +    before_score = before_scorer.Finish(); +  } +} + +BOOST_AUTO_TEST_CASE(Strings) { +  FullDivide(m, "also would consider"); +  FullDivide(m, "looking on a little more loin . </s>"); +  FullDivide(m, "in biarritz watching considering looking . on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown </s>"); +} + +BOOST_AUTO_TEST_SUITE_END() +} // namespace +} // namespace ngram +} // namespace lm diff --git a/klm/search/Jamfile b/klm/search/Jamfile new file mode 100644 index 00000000..e8b14363 --- /dev/null +++ b/klm/search/Jamfile @@ -0,0 +1,5 @@ +lib search : weights.cc vertex.cc vertex_generator.cc edge_queue.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : <include>.. 
; + +import testing ; + +unit-test weights_test : weights_test.cc search /top//boost_unit_test_framework ; diff --git a/klm/search/arity.hh b/klm/search/arity.hh new file mode 100644 index 00000000..09c2c671 --- /dev/null +++ b/klm/search/arity.hh @@ -0,0 +1,8 @@ +#ifndef SEARCH_ARITY__ +#define SEARCH_ARITY__ +namespace search { + +const unsigned int kMaxArity = 2; + +} // namespace search +#endif // SEARCH_ARITY__ diff --git a/klm/search/config.hh b/klm/search/config.hh new file mode 100644 index 00000000..ef8e2354 --- /dev/null +++ b/klm/search/config.hh @@ -0,0 +1,25 @@ +#ifndef SEARCH_CONFIG__ +#define SEARCH_CONFIG__ + +#include "search/weights.hh" +#include "util/string_piece.hh" + +namespace search { + +class Config { +  public: +    Config(const Weights &weights, unsigned int pop_limit) : +      weights_(weights), pop_limit_(pop_limit) {} + +    const Weights &GetWeights() const { return weights_; } + +    unsigned int PopLimit() const { return pop_limit_; } + +  private: +    Weights weights_; +    unsigned int pop_limit_; +}; + +} // namespace search + +#endif // SEARCH_CONFIG__ diff --git a/klm/search/context.hh b/klm/search/context.hh new file mode 100644 index 00000000..27940053 --- /dev/null +++ b/klm/search/context.hh @@ -0,0 +1,65 @@ +#ifndef SEARCH_CONTEXT__ +#define SEARCH_CONTEXT__ + +#include "lm/model.hh" +#include "search/config.hh" +#include "search/final.hh" +#include "search/types.hh" +#include "search/vertex.hh" +#include "util/exception.hh" + +#include <boost/pool/object_pool.hpp> +#include <boost/ptr_container/ptr_vector.hpp> + +#include <vector> + +namespace search { + +class Weights; + +class ContextBase { +  public: +    explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {} + +    Final *NewFinal() { +     Final *ret = final_pool_.construct(); +     assert(ret); +     return ret; +    } + +    VertexNode *NewVertexNode() { +      VertexNode *ret = vertex_node_pool_.construct(); +      assert(ret); +      return ret; +    } + +    void DeleteVertexNode(VertexNode *node) { +      vertex_node_pool_.destroy(node); +    } + +    unsigned int PopLimit() const { return pop_limit_; } + +    const Weights &GetWeights() const { return weights_; } + +  private: +    boost::object_pool<Final> final_pool_; +    boost::object_pool<VertexNode> vertex_node_pool_; + +    unsigned int pop_limit_; + +    const Weights &weights_; +}; + +template <class Model> class Context : public ContextBase { +  public: +    Context(const Config &config, const Model &model) : ContextBase(config), model_(model) {} + +    const Model &LanguageModel() const { return model_; } + +  private: +    const Model &model_; +}; + +} // namespace search + +#endif // SEARCH_CONTEXT__ diff --git a/klm/search/edge.hh b/klm/search/edge.hh new file mode 100644 index 00000000..77ab0ade --- /dev/null +++ b/klm/search/edge.hh @@ -0,0 +1,31 @@ +#ifndef SEARCH_EDGE__ +#define SEARCH_EDGE__ + +#include "lm/state.hh" +#include "search/arity.hh" +#include "search/rule.hh" +#include "search/types.hh" +#include "search/vertex.hh" + +#include <queue> + +namespace search { + +struct PartialEdge { +  Score score; +  // Terminals +  lm::ngram::ChartState between[kMaxArity + 1]; +  // Non-terminals +  PartialVertex nt[kMaxArity]; + +  const lm::ngram::ChartState &CompletedState() const { +    return between[0]; +  } + +  bool operator<(const PartialEdge &other) const { +    return score < other.score; +  } +}; + +} // namespace search +#endif // SEARCH_EDGE__ diff --git 
a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc new file mode 100644 index 00000000..56239dfb --- /dev/null +++ b/klm/search/edge_generator.cc @@ -0,0 +1,120 @@ +#include "search/edge_generator.hh" + +#include "lm/left.hh" +#include "lm/partial.hh" +#include "search/context.hh" +#include "search/vertex.hh" +#include "search/vertex_generator.hh" + +#include <numeric> + +namespace search { + +EdgeGenerator::EdgeGenerator(PartialEdge &root, unsigned char arity, Note note) : arity_(arity), note_(note) { +/*  for (unsigned char i = 0; i < edge.Arity(); ++i) { +    root.nt[i] = edge.GetVertex(i).RootPartial(); +  } +  for (unsigned char i = edge.Arity(); i < 2; ++i) { +    root.nt[i] = kBlankPartialVertex; +  }*/ +  generate_.push(&root); +  top_score_ = root.score; +} + +namespace { + +template <class Model> float FastScore(const Context<Model> &context, unsigned char victim, unsigned char arity, const PartialEdge &previous, PartialEdge &update) { +  memcpy(update.between, previous.between, sizeof(lm::ngram::ChartState) * (arity + 1)); + +  float ret = 0.0; +  lm::ngram::ChartState *before, *after; +  if (victim == 0) { +    before = &update.between[0]; +    after = &update.between[(arity == 2 && previous.nt[1].Complete()) ? 2 : 1]; +  } else { +    assert(victim == 1); +    assert(arity == 2); +    before = &update.between[previous.nt[0].Complete() ? 0 : 1]; +    after = &update.between[2]; +  } +  const lm::ngram::ChartState &previous_reveal = previous.nt[victim].State(); +  const PartialVertex &update_nt = update.nt[victim]; +  const lm::ngram::ChartState &update_reveal = update_nt.State(); +  float just_after = 0.0; +  if ((update_reveal.left.length > previous_reveal.left.length) || (update_reveal.left.full && !previous_reveal.left.full)) { +    just_after += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length); +  } +  if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous.nt[victim].RightFull())) { +    ret += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right); +  } +  if (update_nt.Complete()) { +    if (update_reveal.left.full) { +      before->left.full = true; +    } else { +      assert(update_reveal.left.length == update_reveal.right.length); +      ret += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length); +    } +    if (victim == 0) { +      update.between[0].right = after->right; +    } else { +      update.between[2].left = before->left; +    } +  } +  return previous.score + (ret + just_after) * context.GetWeights().LM(); +} + +} // namespace + +template <class Model> PartialEdge *EdgeGenerator::Pop(Context<Model> &context, boost::pool<> &partial_edge_pool) { +  assert(!generate_.empty()); +  PartialEdge &top = *generate_.top(); +  generate_.pop(); +  unsigned int victim = 0; +  unsigned char lowest_length = 255; +  for (unsigned char i = 0; i != arity_; ++i) { +    if (!top.nt[i].Complete() && top.nt[i].Length() < lowest_length) { +      lowest_length = top.nt[i].Length(); +      victim = i; +    } +  } +  if (lowest_length == 255) { +    // All states report complete.   +    top.between[0].right = top.between[arity_].right; +    // Now top.between[0] is the full edge state.   +    top_score_ = generate_.empty() ? 
-kScoreInf : generate_.top()->score; +    return &top; +  } + +  unsigned int stay = !victim; +  PartialEdge &continuation = *static_cast<PartialEdge*>(partial_edge_pool.malloc()); +  float old_bound = top.nt[victim].Bound(); +  // The alternate's score will change because alternate.nt[victim] changes. +  bool split = top.nt[victim].Split(continuation.nt[victim]); +  // top is now the alternate. + +  continuation.nt[stay] = top.nt[stay]; +  continuation.score = FastScore(context, victim, arity_, top, continuation); +  // TODO: dedupe? +  generate_.push(&continuation); + +  if (split) { +    // We have an alternate. +    top.score += top.nt[victim].Bound() - old_bound; +    // TODO: dedupe? +    generate_.push(&top); +  } else { +    partial_edge_pool.free(&top); +  } + +  top_score_ = generate_.top()->score; +  return NULL; +} + +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::TrieModel> &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::QuantTrieModel> &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::ArrayTrieModel> &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::QuantArrayTrieModel> &context, boost::pool<> &partial_edge_pool); + +} // namespace search diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh new file mode 100644 index 00000000..875ccc5e --- /dev/null +++ b/klm/search/edge_generator.hh @@ -0,0 +1,58 @@ +#ifndef SEARCH_EDGE_GENERATOR__ +#define SEARCH_EDGE_GENERATOR__ + +#include "search/edge.hh" +#include "search/note.hh" + +#include <boost/pool/pool.hpp> +#include <boost/unordered_map.hpp> + +#include <functional> +#include <queue> + +namespace lm { +namespace ngram { +class ChartState; +} // namespace ngram +} // namespace lm + +namespace search { + +template <class Model> class Context; + +class VertexGenerator; + +struct PartialEdgePointerLess : std::binary_function<const PartialEdge *, const PartialEdge *, bool> { +  bool operator()(const PartialEdge *first, const PartialEdge *second) const { +    return *first < *second; +  } +}; + +class EdgeGenerator { +  public: +    EdgeGenerator(PartialEdge &root, unsigned char arity, Note note); + +    Score TopScore() const { +      return top_score_; +    } + +    Note GetNote() const { +      return note_; +    } + +    // Pop.  If there's a complete hypothesis, return it.  Otherwise return NULL.
+    template <class Model> PartialEdge *Pop(Context<Model> &context, boost::pool<> &partial_edge_pool); + +  private: +    Score top_score_; + +    unsigned char arity_; + +    typedef std::priority_queue<PartialEdge*, std::vector<PartialEdge*>, PartialEdgePointerLess> Generate; +    Generate generate_; + +    Note note_; +}; + +} // namespace search +#endif // SEARCH_EDGE_GENERATOR__ diff --git a/klm/search/edge_queue.cc b/klm/search/edge_queue.cc new file mode 100644 index 00000000..e3ae6ebf --- /dev/null +++ b/klm/search/edge_queue.cc @@ -0,0 +1,25 @@ +#include "search/edge_queue.hh" + +#include "lm/left.hh" +#include "search/context.hh" + +#include <stdint.h> + +namespace search { + +EdgeQueue::EdgeQueue(unsigned int pop_limit_hint) : partial_edge_pool_(sizeof(PartialEdge), pop_limit_hint * 2) { +  take_ = static_cast<PartialEdge*>(partial_edge_pool_.malloc()); +} + +/*void EdgeQueue::AddEdge(PartialEdge &root, unsigned char arity, Note note) { +  // Ignore empty edges.   +  for (unsigned char i = 0; i < edge.Arity(); ++i) { +    PartialVertex root(edge.GetVertex(i).RootPartial()); +    if (root.Empty()) return; +    total_score += root.Bound(); +  } +  PartialEdge &allocated = *static_cast<PartialEdge*>(partial_edge_pool_.malloc()); +  allocated.score = total_score; +}*/ + +} // namespace search diff --git a/klm/search/edge_queue.hh b/klm/search/edge_queue.hh new file mode 100644 index 00000000..187eaed7 --- /dev/null +++ b/klm/search/edge_queue.hh @@ -0,0 +1,73 @@ +#ifndef SEARCH_EDGE_QUEUE__ +#define SEARCH_EDGE_QUEUE__ + +#include "search/edge.hh" +#include "search/edge_generator.hh" +#include "search/note.hh" + +#include <boost/pool/pool.hpp> +#include <boost/pool/object_pool.hpp> + +#include <queue> + +namespace search { + +template <class Model> class Context; + +class EdgeQueue { +  public: +    explicit EdgeQueue(unsigned int pop_limit_hint); + +    PartialEdge &InitializeEdge() { +      return *take_; +    } + +    void AddEdge(unsigned char arity, Note note) { +      generate_.push(edge_pool_.construct(*take_, arity, note)); +      take_ = static_cast<PartialEdge*>(partial_edge_pool_.malloc()); +    } + +    bool Empty() const { return generate_.empty(); } + +    /* Generate hypotheses and send them to output.  Normally, output is a +     * VertexGenerator, but the decoder may want to route edges to different +     * vertices i.e. if they have different LHS non-terminal labels.   
+     */ +    template <class Model, class Output> void Search(Context<Model> &context, Output &output) { +      int to_pop = context.PopLimit(); +      while (to_pop > 0 && !generate_.empty()) { +        EdgeGenerator *top = generate_.top(); +        generate_.pop(); +        PartialEdge *ret = top->Pop(context, partial_edge_pool_); +        if (ret) { +          output.NewHypothesis(*ret, top->GetNote()); +          --to_pop; +          if (top->TopScore() != -kScoreInf) { +            generate_.push(top); +          } +        } else { +          generate_.push(top); +        } +      } +      output.FinishedSearch(); +    } + +  private: +    boost::object_pool<EdgeGenerator> edge_pool_; + +    struct LessByTopScore : public std::binary_function<const EdgeGenerator *, const EdgeGenerator *, bool> { +      bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const { +        return first->TopScore() < second->TopScore(); +      } +    }; + +    typedef std::priority_queue<EdgeGenerator*, std::vector<EdgeGenerator*>, LessByTopScore> Generate; +    Generate generate_; + +    boost::pool<> partial_edge_pool_; + +    PartialEdge *take_; +}; + +} // namespace search +#endif // SEARCH_EDGE_QUEUE__ diff --git a/klm/search/final.hh b/klm/search/final.hh new file mode 100644 index 00000000..1b3092ac --- /dev/null +++ b/klm/search/final.hh @@ -0,0 +1,39 @@ +#ifndef SEARCH_FINAL__ +#define SEARCH_FINAL__ + +#include "search/arity.hh" +#include "search/note.hh" +#include "search/types.hh" + +#include <boost/array.hpp> + +namespace search { + +class Final { +  public: +    typedef boost::array<const Final*, search::kMaxArity> ChildArray; + +    void Reset(Score bound, Note note, const Final &left, const Final &right) { +      bound_ = bound; +      note_ = note; +      children_[0] = &left; +      children_[1] = &right; +    } + +    const ChildArray &Children() const { return children_; } + +    Note GetNote() const { return note_; } + +    Score Bound() const { return bound_; } + +  private: +    Score bound_; + +    Note note_; + +    ChildArray children_; +}; + +} // namespace search + +#endif // SEARCH_FINAL__ diff --git a/klm/search/note.hh b/klm/search/note.hh new file mode 100644 index 00000000..50bed06e --- /dev/null +++ b/klm/search/note.hh @@ -0,0 +1,12 @@ +#ifndef SEARCH_NOTE__ +#define SEARCH_NOTE__ + +namespace search { + +union Note { +  const void *vp; +}; + +} // namespace search + +#endif // SEARCH_NOTE__ diff --git a/klm/search/rule.cc b/klm/search/rule.cc new file mode 100644 index 00000000..5b00207e --- /dev/null +++ b/klm/search/rule.cc @@ -0,0 +1,43 @@ +#include "search/rule.hh" + +#include "search/context.hh" +#include "search/final.hh" + +#include <ostream> + +#include <math.h> + +namespace search { + +template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing) { +  unsigned int oov_count = 0; +  float prob = 0.0; +  const Model &model = context.LanguageModel(); +  const lm::WordIndex oov = model.GetVocabulary().NotFound(); +  for (std::vector<lm::WordIndex>::const_iterator word = words.begin(); ; ++word) { +    lm::ngram::RuleScore<Model> scorer(model, *(writing++)); +    // TODO: optimize +    if (prepend_bos && (word == words.begin())) { +      scorer.BeginSentence(); +    } +    for (; ; ++word) { +      if (word == words.end()) { +        prob += scorer.Finish(); +        return static_cast<float>(oov_count) * context.GetWeights().OOV() + prob * 
context.GetWeights().LM(); +      } +      if (*word == kNonTerminal) break; +      if (*word == oov) ++oov_count; +      scorer.Terminal(*word); +    } +    prob += scorer.Finish(); +  } +} + +template float ScoreRule(const Context<lm::ngram::RestProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::ProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::TrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::QuantTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::ArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::QuantArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); + +} // namespace search diff --git a/klm/search/rule.hh b/klm/search/rule.hh new file mode 100644 index 00000000..0ce2794d --- /dev/null +++ b/klm/search/rule.hh @@ -0,0 +1,20 @@ +#ifndef SEARCH_RULE__ +#define SEARCH_RULE__ + +#include "lm/left.hh" +#include "lm/word_index.hh" +#include "search/types.hh" + +#include <vector> + +namespace search { + +template <class Model> class Context; + +const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; + +template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *state_out); + +} // namespace search + +#endif // SEARCH_RULE__ diff --git a/klm/search/source.hh b/klm/search/source.hh new file mode 100644 index 00000000..11839f7b --- /dev/null +++ b/klm/search/source.hh @@ -0,0 +1,48 @@ +#ifndef SEARCH_SOURCE__ +#define SEARCH_SOURCE__ + +#include "search/types.hh" + +#include <assert.h> +#include <vector> + +namespace search { + +template <class Final> class Source { +  public: +    Source() : bound_(kScoreInf) {} + +    Index Size() const { +      return final_.size(); +    } + +    Score Bound() const { +      return bound_; +    } + +    const Final &operator[](Index index) const { +      return *final_[index]; +    } + +    Score ScoreOrBound(Index index) const { +      return Size() > index ? final_[index]->Total() : Bound(); +    } + +  protected: +    void AddFinal(const Final &store) { +      final_.push_back(&store); +    } + +    void SetBound(Score to) { +      assert(to <= bound_ + 0.001); +      bound_ = to; +    } + +  private: +    std::vector<const Final *> final_; + +    Score bound_; +}; + +} // namespace search +#endif // SEARCH_SOURCE__ diff --git a/klm/search/types.hh b/klm/search/types.hh new file mode 100644 index 00000000..9726379f --- /dev/null +++ b/klm/search/types.hh @@ -0,0 +1,18 @@ +#ifndef SEARCH_TYPES__ +#define SEARCH_TYPES__ + +#include <cmath> + +namespace search { + +typedef float Score; +const Score kScoreInf = INFINITY; + +// This could have been an enum but gcc wants 4 bytes.   
+typedef bool ExtendDirection; +const ExtendDirection kExtendLeft = 0; +const ExtendDirection kExtendRight = 1; + +} // namespace search + +#endif // SEARCH_TYPES__ diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc new file mode 100644 index 00000000..cc53c0dd --- /dev/null +++ b/klm/search/vertex.cc @@ -0,0 +1,48 @@ +#include "search/vertex.hh" + +#include "search/context.hh" + +#include <algorithm> +#include <functional> + +#include <assert.h> + +namespace search { + +namespace { + +struct GreaterByBound : public std::binary_function<const VertexNode *, const VertexNode *, bool> { +  bool operator()(const VertexNode *first, const VertexNode *second) const { +    return first->Bound() > second->Bound(); +  } +}; + +} // namespace + +void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) { +  if (Complete()) { +    assert(end_); +    assert(extend_.empty()); +    bound_ = end_->Bound(); +    return; +  } +  if (extend_.size() == 1 && parent_ptr) { +    *parent_ptr = extend_[0]; +    extend_[0]->SortAndSet(context, parent_ptr); +    context.DeleteVertexNode(this); +    return; +  } +  for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { +    (*i)->SortAndSet(context, &*i); +  } +  std::sort(extend_.begin(), extend_.end(), GreaterByBound()); +  bound_ = extend_.front()->Bound(); +} + +namespace { +VertexNode kBlankVertexNode; +} // namespace + +PartialVertex kBlankPartialVertex(kBlankVertexNode); + +} // namespace search diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh new file mode 100644 index 00000000..e1a9ad11 --- /dev/null +++ b/klm/search/vertex.hh @@ -0,0 +1,158 @@ +#ifndef SEARCH_VERTEX__ +#define SEARCH_VERTEX__ + +#include "lm/left.hh" +#include "search/final.hh" +#include "search/types.hh" + +#include <boost/unordered_set.hpp> + +#include <queue> +#include <vector> + +#include <stdint.h> + +namespace search { + +class ContextBase; + +class VertexNode { +  public: +    VertexNode() : end_(NULL) {} + +    void InitRoot() { +      extend_.clear(); +      state_.left.full = false; +      state_.left.length = 0; +      state_.right.length = 0; +      right_full_ = false; +      bound_ = -kScoreInf; +      end_ = NULL; +    } + +    lm::ngram::ChartState &MutableState() { return state_; } +    bool &MutableRightFull() { return right_full_; } + +    void AddExtend(VertexNode *next) { +      extend_.push_back(next); +    } + +    void SetEnd(Final *end) { end_ = end; } +     +    Final &MutableEnd() { return *end_; } + +    void SortAndSet(ContextBase &context, VertexNode **parent_pointer); + +    // Should only happen to a root node when the entire vertex is empty.    +    bool Empty() const { +      return !end_ && extend_.empty(); +    } + +    bool Complete() const { +      return end_; +    } + +    const lm::ngram::ChartState &State() const { return state_; } +    bool RightFull() const { return right_full_; } + +    Score Bound() const { +      return bound_; +    } + +    unsigned char Length() const { +      return state_.left.length + state_.right.length; +    } + +    // May be NULL. 
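SortAndSet in vertex.cc above fills in bounds bottom-up: each node's children are sorted best-first with GreaterByBound and the node takes the front child's bound, while single-child chains are spliced out through parent_ptr and returned to the context. A toy sketch of the bound propagation only (the Node type here is hypothetical; the real VertexNode also carries a ChartState, a right_full_ bit, and an end Final):

    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    struct Node {
      float bound;
      std::vector<Node*> extend;  // Children; empty means the node is complete.
    };

    bool GreaterByBound(const Node *first, const Node *second) {
      return first->bound > second->bound;
    }

    // Post-order: sort children best-first, then take the front child's bound.
    void SortAndSet(Node &node) {
      if (node.extend.empty()) return;  // Complete: bound was set from its Final.
      for (std::size_t i = 0; i < node.extend.size(); ++i)
        SortAndSet(*node.extend[i]);
      std::sort(node.extend.begin(), node.extend.end(), GreaterByBound);
      node.bound = node.extend.front()->bound;
    }

    int main() {
      Node left, right, root;
      left.bound = -1.5f;
      right.bound = -0.5f;
      root.extend.push_back(&left);
      root.extend.push_back(&right);
      SortAndSet(root);
      std::cout << root.bound << '\n';  // -0.5: best reachable completion.
      return 0;
    }

The single-child splice in the real code looks like an optimization only: it shortens the chain a PartialVertex must descend without changing any bound.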
+    const Final *End() const { return end_; } + +    const VertexNode &operator[](size_t index) const { +      return *extend_[index]; +    } + +    size_t Size() const { +      return extend_.size(); +    } + +  private: +    std::vector<VertexNode*> extend_; + +    lm::ngram::ChartState state_; +    bool right_full_; + +    Score bound_; +    Final *end_; +}; + +class PartialVertex { +  public: +    PartialVertex() {} + +    explicit PartialVertex(const VertexNode &back) : back_(&back), index_(0) {} + +    bool Empty() const { return back_->Empty(); } + +    bool Complete() const { return back_->Complete(); } + +    const lm::ngram::ChartState &State() const { return back_->State(); } +    bool RightFull() const { return back_->RightFull(); } + +    Score Bound() const { return Complete() ? back_->End()->Bound() : (*back_)[index_].Bound(); } + +    unsigned char Length() const { return back_->Length(); } + +    bool HasAlternative() const { +      return index_ + 1 < back_->Size(); +    } + +    // Split into continuation and alternative, rendering this the alternative. +    bool Split(PartialVertex &continuation) { +      assert(!Complete()); +      continuation.back_ = &((*back_)[index_]); +      continuation.index_ = 0; +      if (index_ + 1 < back_->Size()) { +        ++index_; +        return true; +      } +      return false; +    } + +    const Final &End() const { +      return *back_->End(); +    } + +  private: +    const VertexNode *back_; +    unsigned int index_; +}; + +extern PartialVertex kBlankPartialVertex; + +class Vertex { +  public: +    Vertex() {} + +    PartialVertex RootPartial() const { return PartialVertex(root_); } + +    const Final *BestChild() const { +      PartialVertex top(RootPartial()); +      if (top.Empty()) { +        return NULL; +      } else { +        PartialVertex continuation; +        while (!top.Complete()) { +          top.Split(continuation); +          top = continuation; +        } +        return &top.End(); +      } +    } + +  private: +    friend class VertexGenerator; + +    VertexNode root_; +}; + +} // namespace search +#endif // SEARCH_VERTEX__ diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc new file mode 100644 index 00000000..d94e6e06 --- /dev/null +++ b/klm/search/vertex_generator.cc @@ -0,0 +1,83 @@ +#include "search/vertex_generator.hh" + +#include "lm/left.hh" +#include "search/context.hh" +#include "search/edge.hh" + +#include <stdint.h> + +namespace search { + +VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) { +  gen.root_.InitRoot(); +  root_.under = &gen.root_; +} + +namespace { +const uint64_t kCompleteAdd = static_cast<uint64_t>(-1); +} // namespace + +void VertexGenerator::NewHypothesis(const PartialEdge &partial, Note note) { +  const lm::ngram::ChartState &state = partial.CompletedState(); +  std::pair<Existing::iterator, bool> got(existing_.insert(std::pair<uint64_t, Final*>(hash_value(state), NULL))); +  if (!got.second) { +    // Found it already.   
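PartialVertex, defined above, is a cursor into that sorted tree: Split hands the current best child back as the continuation and advances index_ to the next alternative, so BestChild reaches the best complete hypothesis by repeatedly taking continuations until Complete(). A self-contained sketch of that descent on the same kind of toy node (hypothetical names):

    #include <iostream>
    #include <vector>

    struct Node {
      float bound;
      std::vector<Node*> extend;  // Sorted best-first; empty means complete.
    };

    // Mirrors Vertex::BestChild: each PartialVertex::Split returns the current
    // best child as the continuation, so following index 0 from the root lands
    // on the best complete hypothesis.
    const Node *BestLeaf(const Node *node) {
      while (!node->extend.empty())
        node = node->extend[0];
      return node;
    }

    int main() {
      Node leaf, mid, root;
      leaf.bound = -0.5f;
      mid.bound = -0.5f;  root.bound = -0.5f;
      mid.extend.push_back(&leaf);
      root.extend.push_back(&mid);
      std::cout << BestLeaf(&root)->bound << '\n';  // -0.5
      return 0;
    }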
+    Final &exists = *got.first->second; +    if (exists.Bound() < partial.score) { +      exists.Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End()); +    } +    return; +  } +  unsigned char left = 0, right = 0; +  Trie *node = &root_; +  while (true) { +    if (left == state.left.length) { +      node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, true, right, false); +      for (; right < state.right.length; ++right) { +        node = &FindOrInsert(*node, state.right.words[right], state, left, true, right + 1, false); +      } +      break; +    } +    node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, false); +    left++; +    if (right == state.right.length) { +      node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, false, right, true); +      for (; left < state.left.length; ++left) { +        node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, true); +      } +      break; +    } +    node = &FindOrInsert(*node, state.right.words[right], state, left, false, right + 1, false); +    right++; +  } + +  node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); +  got.first->second = CompleteTransition(*node, state, note, partial); +} + +VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { +  VertexGenerator::Trie &next = node.extend[added]; +  if (!next.under) { +    next.under = context_.NewVertexNode(); +    lm::ngram::ChartState &writing = next.under->MutableState(); +    writing = state; +    writing.left.full &= left_full && state.left.full; +    next.under->MutableRightFull() = right_full && state.left.full; +    writing.left.length = left; +    writing.right.length = right; +    node.under->AddExtend(next.under); +  } +  return next; +} + +Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial) { +  VertexNode &node = *starter.under; +  assert(node.State().left.full == state.left.full); +  assert(!node.End()); +  Final *final = context_.NewFinal(); +  final->Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End()); +  node.SetEnd(final); +  return final; +} + +} // namespace search diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh new file mode 100644 index 00000000..6b98da3e --- /dev/null +++ b/klm/search/vertex_generator.hh @@ -0,0 +1,59 @@ +#ifndef SEARCH_VERTEX_GENERATOR__ +#define SEARCH_VERTEX_GENERATOR__ + +#include "search/note.hh" +#include "search/vertex.hh" + +#include <boost/unordered_map.hpp> + +#include <queue> + +namespace lm { +namespace ngram { +class ChartState; +} // namespace ngram +} // namespace lm + +namespace search { + +class ContextBase; +class Final; +struct PartialEdge; + +class VertexGenerator { +  public: +    VertexGenerator(ContextBase &context, Vertex &gen); + +    void NewHypothesis(const PartialEdge &partial, Note note); + +    void FinishedSearch() { +      root_.under->SortAndSet(context_, NULL); +    } + +    const Vertex &Generating() const { return gen_; } + +  private: +    // Parallel structure to VertexNode.   
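The existing_ map guarding NewHypothesis above is the recombination step: two partial hypotheses whose ChartStates hash equal are the same search item, so only one Final is kept and Reset overwrites it when a better score arrives. A sketch of that guard with deliberately simplified types (std::map and a toy Hyp where the real code uses boost::unordered_map keyed on hash_value(ChartState) and stores Final*):

    #include <iostream>
    #include <map>
    #include <utility>
    #include <stdint.h>

    struct Hyp { float score; };

    void NewHypothesis(std::map<uint64_t, Hyp> &existing, uint64_t state_hash, float score) {
      std::pair<std::map<uint64_t, Hyp>::iterator, bool> got(
          existing.insert(std::make_pair(state_hash, Hyp())));
      if (!got.second) {
        // Found it already: recombine, keeping only the better score.
        if (got.first->second.score < score) got.first->second.score = score;
        return;
      }
      // First hypothesis with this state: the real code now builds trie nodes.
      got.first->second.score = score;
    }

    int main() {
      std::map<uint64_t, Hyp> existing;
      NewHypothesis(existing, 42, -3.0f);
      NewHypothesis(existing, 42, -1.0f);  // Same state: -1.0 wins.
      std::cout << existing[42].score << '\n';  // -1
      return 0;
    }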
+    struct Trie { +      Trie() : under(NULL) {} + +      VertexNode *under; +      boost::unordered_map<uint64_t, Trie> extend; +    }; + +    Trie &FindOrInsert(Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full); + +    Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial); + +    ContextBase &context_; + +    Vertex &gen_; + +    Trie root_; + +    typedef boost::unordered_map<uint64_t, Final*> Existing; +    Existing existing_; +}; + +} // namespace search +#endif // SEARCH_VERTEX_GENERATOR__ diff --git a/klm/search/weights.cc b/klm/search/weights.cc new file mode 100644 index 00000000..d65471ad --- /dev/null +++ b/klm/search/weights.cc @@ -0,0 +1,71 @@ +#include "search/weights.hh" +#include "util/tokenize_piece.hh" + +#include <cstdlib> + +namespace search { + +namespace { +struct Insert { +  void operator()(boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) const { +    std::string copy(name.data(), name.size()); +    map[copy] = score; +  } +}; + +struct DotProduct { +  search::Score total; +  DotProduct() : total(0.0) {} + +  void operator()(const boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) { +    boost::unordered_map<std::string, search::Score>::const_iterator i(FindStringPiece(map, name)); +    if (i != map.end())  +      total += score * i->second; +  } +}; + +template <class Map, class Op> void Parse(StringPiece text, Map &map, Op &op) { +  for (util::TokenIter<util::SingleCharacter, true> spaces(text, ' '); spaces; ++spaces) { +    util::TokenIter<util::SingleCharacter> equals(*spaces, '='); +    UTIL_THROW_IF(!equals, WeightParseException, "Bad weight token " << *spaces); +    StringPiece name(*equals); +    UTIL_THROW_IF(!++equals, WeightParseException, "Bad weight token " << *spaces); +    char *end; +    // Assumes proper termination.   +    double value = std::strtod(equals->data(), &end); +    UTIL_THROW_IF(end != equals->data() + equals->size(), WeightParseException, "Failed to parse weight" << *equals); +    UTIL_THROW_IF(++equals, WeightParseException, "Too many equals in " << *spaces); +    op(map, name, value); +  } +} + +} // namespace + +Weights::Weights(StringPiece text) { +  Insert op; +  Parse<Map, Insert>(text, map_, op); +  lm_ = Steal("LanguageModel"); +  oov_ = Steal("OOV"); +  word_penalty_ = Steal("WordPenalty"); +} + +Weights::Weights(Score lm, Score oov, Score word_penalty) : lm_(lm), oov_(oov), word_penalty_(word_penalty) {} + +search::Score Weights::DotNoLM(StringPiece text) const { +  DotProduct dot; +  Parse<const Map, DotProduct>(text, map_, dot); +  return dot.total; +} + +float Weights::Steal(const std::string &str) { +  Map::iterator i(map_.find(str)); +  if (i == map_.end()) { +    return 0.0; +  } else { +    float ret = i->second; +    map_.erase(i); +    return ret; +  } +} + +} // namespace search diff --git a/klm/search/weights.hh b/klm/search/weights.hh new file mode 100644 index 00000000..df1c419f --- /dev/null +++ b/klm/search/weights.hh @@ -0,0 +1,52 @@ +// For now, the individual features are not kept.   
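The weight format accepted by Parse in weights.cc above is space-separated name=value tokens, each value handed to strtod; the Weights constructor then Steals LanguageModel, OOV, and WordPenalty out of the map, so they are reported by the accessors but excluded from DotNoLM. A usage sketch against this header (the weight values themselves are made up):

    #include <iostream>
    #include "search/weights.hh"

    int main() {
      search::Weights w("LanguageModel=3 OOV=-1 WordPenalty=-0.5 glue=2");
      // The three stolen weights come back through the accessors...
      std::cout << w.LM() << ' ' << w.OOV() << ' ' << w.WordPenalty() << '\n';  // 3 -1 -0.5
      // ...while everything else stays in the map and joins the dot product.
      std::cout << w.DotNoLM("glue=4") << '\n';  // 2 * 4 = 8
      return 0;
    }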
+#ifndef SEARCH_WEIGHTS__ +#define SEARCH_WEIGHTS__ + +#include "search/types.hh" +#include "util/exception.hh" +#include "util/string_piece.hh" + +#include <boost/unordered_map.hpp> + +#include <string> + +namespace search { + +class WeightParseException : public util::Exception { +  public: +    WeightParseException() {} +    ~WeightParseException() throw() {} +}; + +class Weights { +  public: +    // Parses weights, sets lm_weight_, removes it from map_. +    explicit Weights(StringPiece text); + +    // Just the three scores we care about adding.    +    Weights(Score lm, Score oov, Score word_penalty); + +    Score DotNoLM(StringPiece text) const; + +    Score LM() const { return lm_; } + +    Score OOV() const { return oov_; } + +    Score WordPenalty() const { return word_penalty_; } + +    // Mostly for testing.   +    const boost::unordered_map<std::string, Score> &GetMap() const { return map_; } + +  private: +    float Steal(const std::string &str); + +    typedef boost::unordered_map<std::string, Score> Map; + +    Map map_; + +    Score lm_, oov_, word_penalty_; +}; + +} // namespace search + +#endif // SEARCH_WEIGHTS__ diff --git a/klm/search/weights_test.cc b/klm/search/weights_test.cc new file mode 100644 index 00000000..4811ff06 --- /dev/null +++ b/klm/search/weights_test.cc @@ -0,0 +1,38 @@ +#include "search/weights.hh" + +#define BOOST_TEST_MODULE WeightTest +#include <boost/test/unit_test.hpp> +#include <boost/test/floating_point_comparison.hpp> + +namespace search { +namespace { + +#define CHECK_WEIGHT(value, string) \ +  i = parsed.find(string); \ +  BOOST_REQUIRE(i != parsed.end()); \ +  BOOST_CHECK_CLOSE((value), i->second, 0.001); + +BOOST_AUTO_TEST_CASE(parse) { +  // These are not real feature weights.   +  Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); +  const boost::unordered_map<std::string, search::Score> &parsed = w.GetMap(); +  boost::unordered_map<std::string, search::Score>::const_iterator i; +  CHECK_WEIGHT(0.0, "rarity"); +  CHECK_WEIGHT(0.0, "phrase-SGT"); +  CHECK_WEIGHT(9.45117, "phrase-TGS"); +  CHECK_WEIGHT(2.33833, "lexical-SGT"); +  BOOST_CHECK(parsed.end() == parsed.find("lm")); +  BOOST_CHECK_CLOSE(3.0, w.LM(), 0.001); +  CHECK_WEIGHT(-28.3317, "lexical-TGS"); +  CHECK_WEIGHT(5.0, "glue?"); +} + +BOOST_AUTO_TEST_CASE(dot) { +  Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); +  BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0"), 0.001); +  BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0 LanguageModel=10"), 0.001); +  BOOST_CHECK_CLOSE(9.45117 * 3.0 + 28.3317 * 17.4, w.DotNoLM("rarity=5 phrase-TGS=3.0 LanguageModel=10 lexical-TGS=-17.4"), 0.001); +} + +} // namespace +} // namespace search diff --git a/klm/util/have.hh b/klm/util/have.hh index 1d76a7fc..b8181e99 100644 --- a/klm/util/have.hh +++ b/klm/util/have.hh @@ -13,7 +13,7 @@  #endif  #ifndef HAVE_BOOST -//#define HAVE_BOOST +#define HAVE_BOOST  #endif  #ifndef HAVE_THREADS | 
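As a worked check of the dot test above: in DotNoLM("rarity=5 phrase-TGS=3.0 LanguageModel=10 lexical-TGS=-17.4"), rarity contributes 0 * 5, phrase-TGS contributes 9.45117 * 3.0, lexical-TGS contributes (-28.3317) * (-17.4) = +28.3317 * 17.4, and LanguageModel contributes nothing because Steal removed it from the map at construction, giving exactly the expected 9.45117 * 3.0 + 28.3317 * 17.4.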