diff options
Diffstat (limited to 'klm/search')
| -rw-r--r-- | klm/search/Makefile.am | 23 | ||||
| -rw-r--r-- | klm/search/applied.hh | 86 | ||||
| -rw-r--r-- | klm/search/config.hh | 38 | ||||
| -rw-r--r-- | klm/search/context.hh | 49 | ||||
| -rw-r--r-- | klm/search/dedupe.hh | 131 | ||||
| -rw-r--r-- | klm/search/edge.hh | 54 | ||||
| -rw-r--r-- | klm/search/edge_generator.cc | 111 | ||||
| -rw-r--r-- | klm/search/edge_generator.hh | 56 | ||||
| -rw-r--r-- | klm/search/header.hh | 64 | ||||
| -rw-r--r-- | klm/search/nbest.cc | 106 | ||||
| -rw-r--r-- | klm/search/nbest.hh | 81 | ||||
| -rw-r--r-- | klm/search/rule.cc | 43 | ||||
| -rw-r--r-- | klm/search/rule.hh | 25 | ||||
| -rw-r--r-- | klm/search/types.hh | 31 | ||||
| -rw-r--r-- | klm/search/vertex.cc | 55 | ||||
| -rw-r--r-- | klm/search/vertex.hh | 164 | ||||
| -rw-r--r-- | klm/search/vertex_generator.cc | 68 | ||||
| -rw-r--r-- | klm/search/vertex_generator.hh | 99 | 
18 files changed, 1284 insertions, 0 deletions
diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am new file mode 100644 index 00000000..03554276 --- /dev/null +++ b/klm/search/Makefile.am @@ -0,0 +1,23 @@ +noinst_LIBRARIES = libksearch.a + +libksearch_a_SOURCES = \ +  applied.hh \ +  config.hh \ +  context.hh \ +  dedupe.hh \ +  edge.hh \ +  edge_generator.hh \ +  header.hh \ +  nbest.hh \ +  rule.hh \ +  types.hh \ +  vertex.hh \ +  vertex_generator.hh \ +  edge_generator.cc \ +	nbest.cc \ +  rule.cc \ +  vertex.cc \ +  vertex_generator.cc + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/klm + diff --git a/klm/search/applied.hh b/klm/search/applied.hh new file mode 100644 index 00000000..bd659e5c --- /dev/null +++ b/klm/search/applied.hh @@ -0,0 +1,86 @@ +#ifndef SEARCH_APPLIED__ +#define SEARCH_APPLIED__ + +#include "search/edge.hh" +#include "search/header.hh" +#include "util/pool.hh" + +#include <math.h> + +namespace search { + +// A full hypothesis: a score, arity of the rule, a pointer to the decoder's rule (Note), and pointers to non-terminals that were substituted.   +template <class Below> class GenericApplied : public Header { +  public: +    GenericApplied() {} + +    GenericApplied(void *location, PartialEdge partial)  +      : Header(location) { +      memcpy(Base(), partial.Base(), kHeaderSize); +      Below *child_out = Children(); +      const PartialVertex *part = partial.NT(); +      const PartialVertex *const part_end_loop = part + partial.GetArity(); +      for (; part != part_end_loop; ++part, ++child_out) +        *child_out = Below(part->End()); +    } +     +    GenericApplied(void *location, Score score, Arity arity, Note note) : Header(location, arity) { +      SetScore(score); +      SetNote(note); +    } + +    explicit GenericApplied(History from) : Header(from) {} + + +    // These are arrays of length GetArity(). 
+    Below *Children() { +      return reinterpret_cast<Below*>(After()); +    } +    const Below *Children() const { +      return reinterpret_cast<const Below*>(After()); +    } + +    static std::size_t Size(Arity arity) { +      return kHeaderSize + arity * sizeof(const Below); +    } +}; + +// Applied rule that references itself.   +class Applied : public GenericApplied<Applied> { +  private: +    typedef GenericApplied<Applied> P; + +  public: +    Applied() {} +    Applied(void *location, PartialEdge partial) : P(location, partial) {} +    Applied(History from) : P(from) {} +}; + +// How to build single-best hypotheses.   +class SingleBest { +  public: +    typedef PartialEdge Combine; + +    void Add(PartialEdge &existing, PartialEdge add) const { +      if (!existing.Valid() || existing.GetScore() < add.GetScore()) +        existing = add; +    } + +    NBestComplete Complete(PartialEdge partial) { +      if (!partial.Valid())  +        return NBestComplete(NULL, lm::ngram::ChartState(), -INFINITY); +      void *place_final = pool_.Allocate(Applied::Size(partial.GetArity())); +      Applied(place_final, partial); +      return NBestComplete( +          place_final, +          partial.CompletedState(), +          partial.GetScore()); +    } + +  private: +    util::Pool pool_; +}; + +} // namespace search + +#endif // SEARCH_APPLIED__ diff --git a/klm/search/config.hh b/klm/search/config.hh new file mode 100644 index 00000000..ba18c09e --- /dev/null +++ b/klm/search/config.hh @@ -0,0 +1,38 @@ +#ifndef SEARCH_CONFIG__ +#define SEARCH_CONFIG__ + +#include "search/types.hh" + +namespace search { + +struct NBestConfig { +  explicit NBestConfig(unsigned int in_size) { +    keep = in_size; +    size = in_size; +  } +   +  unsigned int keep, size; +}; + +class Config { +  public: +    Config(Score lm_weight, unsigned int pop_limit, const NBestConfig &nbest) : +      lm_weight_(lm_weight), pop_limit_(pop_limit), nbest_(nbest) {} + +    Score LMWeight() const { 
return lm_weight_; } + +    unsigned int PopLimit() const { return pop_limit_; } + +    const NBestConfig &GetNBest() const { return nbest_; } + +  private: +    Score lm_weight_; + +    unsigned int pop_limit_; + +    NBestConfig nbest_; +}; + +} // namespace search + +#endif // SEARCH_CONFIG__ diff --git a/klm/search/context.hh b/klm/search/context.hh new file mode 100644 index 00000000..08f21bbf --- /dev/null +++ b/klm/search/context.hh @@ -0,0 +1,49 @@ +#ifndef SEARCH_CONTEXT__ +#define SEARCH_CONTEXT__ + +#include "search/config.hh" +#include "search/vertex.hh" + +#include <boost/pool/object_pool.hpp> + +namespace search { + +class ContextBase { +  public: +    explicit ContextBase(const Config &config) : config_(config) {} + +    VertexNode *NewVertexNode() { +      VertexNode *ret = vertex_node_pool_.construct(); +      assert(ret); +      return ret; +    } + +    void DeleteVertexNode(VertexNode *node) { +      vertex_node_pool_.destroy(node); +    } + +    unsigned int PopLimit() const { return config_.PopLimit(); } + +    Score LMWeight() const { return config_.LMWeight(); } + +    const Config &GetConfig() const { return config_; } + +  private: +    boost::object_pool<VertexNode> vertex_node_pool_; + +    Config config_; +}; + +template <class Model> class Context : public ContextBase { +  public: +    Context(const Config &config, const Model &model) : ContextBase(config), model_(model) {} + +    const Model &LanguageModel() const { return model_; } + +  private: +    const Model &model_; +}; + +} // namespace search + +#endif // SEARCH_CONTEXT__ diff --git a/klm/search/dedupe.hh b/klm/search/dedupe.hh new file mode 100644 index 00000000..7eaa3b95 --- /dev/null +++ b/klm/search/dedupe.hh @@ -0,0 +1,131 @@ +#ifndef SEARCH_DEDUPE__ +#define SEARCH_DEDUPE__ + +#include "lm/state.hh" +#include "search/edge_generator.hh" + +#include <boost/pool/object_pool.hpp> +#include <boost/unordered_map.hpp> + +namespace search { + +class Dedupe { +  public: +    
Dedupe() {} + +    PartialEdge AllocateEdge(Arity arity) { +      return behind_.AllocateEdge(arity); +    } + +    void AddEdge(PartialEdge edge) { +      edge.MutableFlags() = 0; + +      uint64_t hash = 0; +      const PartialVertex *v = edge.NT(); +      const PartialVertex *v_end = v + edge.GetArity(); +      for (; v != v_end; ++v) { +        const void *ptr = v->Identify(); +        hash = util::MurmurHashNative(&ptr, sizeof(const void*), hash); +      } +       +      const lm::ngram::ChartState *c = edge.Between(); +      const lm::ngram::ChartState *const c_end = c + edge.GetArity() + 1; +      for (; c != c_end; ++c) hash = hash_value(*c, hash); + +      std::pair<Table::iterator, bool> ret(table_.insert(std::make_pair(hash, edge))); +      if (!ret.second) FoundDupe(ret.first->second, edge); +    } + +    bool Empty() const { return behind_.Empty(); } + +    template <class Model, class Output> void Search(Context<Model> &context, Output &output) { +      for (Table::const_iterator i(table_.begin()); i != table_.end(); ++i) { +        behind_.AddEdge(i->second); +      } +      Unpack<Output> unpack(output, *this); +      behind_.Search(context, unpack); +    } + +  private: +    void FoundDupe(PartialEdge &table, PartialEdge adding) { +      if (table.GetFlags() & kPackedFlag) { +        Packed &packed = *static_cast<Packed*>(table.GetNote().mut); +        if (table.GetScore() >= adding.GetScore()) { +          packed.others.push_back(adding); +          return; +        } +        Note original(packed.original); +        packed.original = adding.GetNote(); +        adding.SetNote(table.GetNote()); +        table.SetNote(original); +        packed.others.push_back(table); +        packed.starting = adding.GetScore(); +        table = adding; +        table.MutableFlags() |= kPackedFlag; +        return; +      } +      PartialEdge loser; +      if (adding.GetScore() > table.GetScore()) { +        loser = table; +        table = adding; +      } else { 
+        loser = adding; +      } +      // table is winner, loser is loser... +      packed_.construct(table, loser); +    } + +    struct Packed { +      Packed(PartialEdge winner, PartialEdge loser)  +        : original(winner.GetNote()), starting(winner.GetScore()), others(1, loser) { +        winner.MutableNote().vp = this; +        winner.MutableFlags() |= kPackedFlag; +        loser.MutableFlags() &= ~kPackedFlag; +      } +      Note original; +      Score starting; +      std::vector<PartialEdge> others; +    }; + +    template <class Output> class Unpack { +      public: +        explicit Unpack(Output &output, Dedupe &owner) : output_(output), owner_(owner) {} + +        void NewHypothesis(PartialEdge edge) { +          if (edge.GetFlags() & kPackedFlag) { +            Packed &packed = *reinterpret_cast<Packed*>(edge.GetNote().mut); +            edge.SetNote(packed.original); +            edge.MutableFlags() = 0; +            std::size_t copy_size = sizeof(PartialVertex) * edge.GetArity() + sizeof(lm::ngram::ChartState); +            for (std::vector<PartialEdge>::iterator i = packed.others.begin(); i != packed.others.end(); ++i) { +              PartialEdge copy(owner_.AllocateEdge(edge.GetArity())); +              copy.SetScore(edge.GetScore() - packed.starting + i->GetScore()); +              copy.MutableFlags() = 0; +              copy.SetNote(i->GetNote()); +              memcpy(copy.NT(), edge.NT(), copy_size); +              output_.NewHypothesis(copy); +            } +          } +          output_.NewHypothesis(edge); +        } + +        void FinishedSearch() { +          output_.FinishedSearch(); +        } + +      private: +        Output &output_; +        Dedupe &owner_; +    }; + +    EdgeGenerator behind_; + +    typedef boost::unordered_map<uint64_t, PartialEdge> Table; +    Table table_; + +    boost::object_pool<Packed> packed_; + +    static const uint16_t kPackedFlag = 1; +}; +} // namespace search +#endif // SEARCH_DEDUPE__ diff 
--git a/klm/search/edge.hh b/klm/search/edge.hh new file mode 100644 index 00000000..187904bf --- /dev/null +++ b/klm/search/edge.hh @@ -0,0 +1,54 @@ +#ifndef SEARCH_EDGE__ +#define SEARCH_EDGE__ + +#include "lm/state.hh" +#include "search/header.hh" +#include "search/types.hh" +#include "search/vertex.hh" +#include "util/pool.hh" + +#include <functional> + +#include <stdint.h> + +namespace search { + +// Copyable, but the copy will be shallow. +class PartialEdge : public Header { +  public: +    // Allow default construction for STL.   +    PartialEdge() {} + +    PartialEdge(util::Pool &pool, Arity arity)  +      : Header(pool.Allocate(Size(arity, arity + 1)), arity) {} +     +    PartialEdge(util::Pool &pool, Arity arity, Arity chart_states)  +      : Header(pool.Allocate(Size(arity, chart_states)), arity) {} + +    // Non-terminals +    const PartialVertex *NT() const { +      return reinterpret_cast<const PartialVertex*>(After()); +    } +    PartialVertex *NT() { +      return reinterpret_cast<PartialVertex*>(After()); +    } + +    const lm::ngram::ChartState &CompletedState() const { +      return *Between(); +    } +    const lm::ngram::ChartState *Between() const { +      return reinterpret_cast<const lm::ngram::ChartState*>(After() + GetArity() * sizeof(PartialVertex)); +    } +    lm::ngram::ChartState *Between() { +      return reinterpret_cast<lm::ngram::ChartState*>(After() + GetArity() * sizeof(PartialVertex)); +    } + +  private: +    static std::size_t Size(Arity arity, Arity chart_states) { +      return kHeaderSize + arity * sizeof(PartialVertex) + chart_states * sizeof(lm::ngram::ChartState); +    } +}; + + +} // namespace search +#endif // SEARCH_EDGE__ diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc new file mode 100644 index 00000000..eacf5de5 --- /dev/null +++ b/klm/search/edge_generator.cc @@ -0,0 +1,111 @@ +#include "search/edge_generator.hh" + +#include "lm/left.hh" +#include "lm/model.hh" +#include 
"lm/partial.hh" +#include "search/context.hh" +#include "search/vertex.hh" + +#include <numeric> + +namespace search { + +namespace { + +template <class Model> void FastScore(const Context<Model> &context, Arity victim, Arity before_idx, Arity incomplete, const PartialVertex &previous_vertex, PartialEdge update) { +  lm::ngram::ChartState *between = update.Between(); +  lm::ngram::ChartState *before = &between[before_idx], *after = &between[before_idx + 1]; + +  float adjustment = 0.0; +  const lm::ngram::ChartState &previous_reveal = previous_vertex.State(); +  const PartialVertex &update_nt = update.NT()[victim]; +  const lm::ngram::ChartState &update_reveal = update_nt.State(); +  if ((update_reveal.left.length > previous_reveal.left.length) || (update_reveal.left.full && !previous_reveal.left.full)) { +    adjustment += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length); +  } +  if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous_vertex.RightFull())) { +    adjustment += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right); +  } +  if (update_nt.Complete()) { +    if (update_reveal.left.full) { +      before->left.full = true; +    } else { +      assert(update_reveal.left.length == update_reveal.right.length); +      adjustment += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length); +    } +    before->right = after->right; +    // Shift the others shifted one down, covering after.   
+    for (lm::ngram::ChartState *cover = after; cover < between + incomplete; ++cover) { +      *cover = *(cover + 1); +    } +  } +  update.SetScore(update.GetScore() + adjustment * context.LMWeight()); +} + +} // namespace + +template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) { +  assert(!generate_.empty()); +  PartialEdge top = generate_.top(); +  generate_.pop(); +  PartialVertex *const top_nt = top.NT(); +  const Arity arity = top.GetArity(); + +  Arity victim = 0; +  Arity victim_completed; +  Arity incomplete; +  // Select victim or return if complete.    +  { +    Arity completed = 0; +    unsigned char lowest_length = 255; +    for (Arity i = 0; i != arity; ++i) { +      if (top_nt[i].Complete()) { +        ++completed; +      } else if (top_nt[i].Length() < lowest_length) { +        lowest_length = top_nt[i].Length(); +        victim = i; +        victim_completed = completed; +      } +    } +    if (lowest_length == 255) { +      return top; +    } +    incomplete = arity - completed; +  } + +  PartialVertex old_value(top_nt[victim]); +  PartialVertex alternate_changed; +  if (top_nt[victim].Split(alternate_changed)) { +    PartialEdge alternate(partial_edge_pool_, arity, incomplete + 1); +    alternate.SetScore(top.GetScore() + alternate_changed.Bound() - old_value.Bound()); + +    alternate.SetNote(top.GetNote()); + +    PartialVertex *alternate_nt = alternate.NT(); +    for (Arity i = 0; i < victim; ++i) alternate_nt[i] = top_nt[i]; +    alternate_nt[victim] = alternate_changed; +    for (Arity i = victim + 1; i < arity; ++i) alternate_nt[i] = top_nt[i]; + +    memcpy(alternate.Between(), top.Between(), sizeof(lm::ngram::ChartState) * (incomplete + 1)); + +    // TODO: dedupe?   +    generate_.push(alternate); +  } + +  // top is now the continuation. +  FastScore(context, victim, victim - victim_completed, incomplete, old_value, top); +  // TODO: dedupe?   
+  generate_.push(top); + +  // Invalid indicates no new hypothesis generated.   +  return PartialEdge(); +} + +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::TrieModel> &context); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::QuantTrieModel> &context); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::ArrayTrieModel> &context); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::QuantArrayTrieModel> &context); + +} // namespace search diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh new file mode 100644 index 00000000..203942c6 --- /dev/null +++ b/klm/search/edge_generator.hh @@ -0,0 +1,56 @@ +#ifndef SEARCH_EDGE_GENERATOR__ +#define SEARCH_EDGE_GENERATOR__ + +#include "search/edge.hh" +#include "search/types.hh" + +#include <queue> + +namespace lm { +namespace ngram { +class ChartState; +} // namespace ngram +} // namespace lm + +namespace search { + +template <class Model> class Context; + +class EdgeGenerator { +  public: +    EdgeGenerator() {} + +    PartialEdge AllocateEdge(Arity arity) { +      return PartialEdge(partial_edge_pool_, arity); +    } + +    void AddEdge(PartialEdge edge) { +      generate_.push(edge); +    } + +    bool Empty() const { return generate_.empty(); } + +    // Pop.  If there's a complete hypothesis, return it.  Otherwise return an invalid PartialEdge. 
+    template <class Model> PartialEdge Pop(Context<Model> &context); + +    template <class Model, class Output> void Search(Context<Model> &context, Output &output) { +      unsigned to_pop = context.PopLimit(); +      while (to_pop > 0 && !generate_.empty()) { +        PartialEdge got(Pop(context)); +        if (got.Valid()) { +          output.NewHypothesis(got); +          --to_pop; +        } +      } +      output.FinishedSearch(); +    } + +  private: +    util::Pool partial_edge_pool_; + +    typedef std::priority_queue<PartialEdge> Generate; +    Generate generate_; +}; + +} // namespace search +#endif // SEARCH_EDGE_GENERATOR__ diff --git a/klm/search/header.hh b/klm/search/header.hh new file mode 100644 index 00000000..69f0eed0 --- /dev/null +++ b/klm/search/header.hh @@ -0,0 +1,64 @@ +#ifndef SEARCH_HEADER__ +#define SEARCH_HEADER__ + +// Header consisting of Score, Arity, and Note + +#include "search/types.hh" + +#include <stdint.h> + +namespace search { + +// Copying is shallow.   
+class Header { +  public: +    bool Valid() const { return base_; } + +    Score GetScore() const { +      return *reinterpret_cast<const float*>(base_); +    } +    void SetScore(Score to) { +      *reinterpret_cast<float*>(base_) = to; +    } +    bool operator<(const Header &other) const { +      return GetScore() < other.GetScore(); +    } +    bool operator>(const Header &other) const { +      return GetScore() > other.GetScore(); +    } + +    Arity GetArity() const { +      return *reinterpret_cast<const Arity*>(base_ + sizeof(Score)); +    } + +    Note GetNote() const { +      return *reinterpret_cast<const Note*>(base_ + sizeof(Score) + sizeof(Arity)); +    } +    void SetNote(Note to) { +      *reinterpret_cast<Note*>(base_ + sizeof(Score) + sizeof(Arity)) = to; +    } + +    uint8_t *Base() { return base_; } +    const uint8_t *Base() const { return base_; } + +  protected: +    Header() : base_(NULL) {} + +    explicit Header(void *base) : base_(static_cast<uint8_t*>(base)) {} + +    Header(void *base, Arity arity) : base_(static_cast<uint8_t*>(base)) { +      *reinterpret_cast<Arity*>(base_ + sizeof(Score)) = arity; +    } + +    static const std::size_t kHeaderSize = sizeof(Score) + sizeof(Arity) + sizeof(Note); + +    uint8_t *After() { return base_ + kHeaderSize; } +    const uint8_t *After() const { return base_ + kHeaderSize; } + +  private: +    uint8_t *base_; +}; + +} // namespace search + +#endif // SEARCH_HEADER__ diff --git a/klm/search/nbest.cc b/klm/search/nbest.cc new file mode 100644 index 00000000..ec3322c9 --- /dev/null +++ b/klm/search/nbest.cc @@ -0,0 +1,106 @@ +#include "search/nbest.hh" + +#include "util/pool.hh" + +#include <algorithm> +#include <functional> +#include <queue> + +#include <assert.h> +#include <math.h> + +namespace search { + +NBestList::NBestList(std::vector<PartialEdge> &partials, util::Pool &entry_pool, std::size_t keep) { +  assert(!partials.empty()); +  std::vector<PartialEdge>::iterator end; +  if 
(partials.size() > keep) { +    end = partials.begin() + keep; +    std::nth_element(partials.begin(), end, partials.end(), std::greater<PartialEdge>()); +  } else { +    end = partials.end(); +  } +  for (std::vector<PartialEdge>::const_iterator i(partials.begin()); i != end; ++i) { +    queue_.push(QueueEntry(entry_pool.Allocate(QueueEntry::Size(i->GetArity())), *i)); +  } +} + +Score NBestList::TopAfterConstructor() const { +  assert(revealed_.empty()); +  return queue_.top().GetScore(); +} + +const std::vector<Applied> &NBestList::Extract(util::Pool &pool, std::size_t n) { +  while (revealed_.size() < n && !queue_.empty()) { +    MoveTop(pool); +  } +  return revealed_; +} + +Score NBestList::Visit(util::Pool &pool, std::size_t index) { +  if (index + 1 < revealed_.size()) +    return revealed_[index + 1].GetScore() - revealed_[index].GetScore(); +  if (queue_.empty())  +    return -INFINITY; +  if (index + 1 == revealed_.size()) +    return queue_.top().GetScore() - revealed_[index].GetScore(); +  assert(index == revealed_.size()); + +  MoveTop(pool); + +  if (queue_.empty()) return -INFINITY; +  return queue_.top().GetScore() - revealed_[index].GetScore(); +} + +Applied NBestList::Get(util::Pool &pool, std::size_t index) { +  assert(index <= revealed_.size()); +  if (index == revealed_.size()) MoveTop(pool); +  return revealed_[index]; +} + +void NBestList::MoveTop(util::Pool &pool) { +  assert(!queue_.empty()); +  QueueEntry entry(queue_.top()); +  queue_.pop(); +  RevealedRef *const children_begin = entry.Children(); +  RevealedRef *const children_end = children_begin + entry.GetArity(); +  Score basis = entry.GetScore(); +  for (RevealedRef *child = children_begin; child != children_end; ++child) { +    Score change = child->in_->Visit(pool, child->index_); +    if (change != -INFINITY) { +      assert(change < 0.001); +      QueueEntry new_entry(pool.Allocate(QueueEntry::Size(entry.GetArity())), basis + change, entry.GetArity(), entry.GetNote()); +      
std::copy(children_begin, child, new_entry.Children()); +      RevealedRef *update = new_entry.Children() + (child - children_begin); +      update->in_ = child->in_; +      update->index_ = child->index_ + 1; +      std::copy(child + 1, children_end, update + 1); +      queue_.push(new_entry); +    } +    // Gesmundo, A. and Henderson, J. Faster Cube Pruning, IWSLT 2010. +    if (child->index_) break; +  } + +  // Convert QueueEntry to Applied.  This leaves some unused memory.   +  void *overwrite = entry.Children(); +  for (unsigned int i = 0; i < entry.GetArity(); ++i) { +    RevealedRef from(*(static_cast<const RevealedRef*>(overwrite) + i)); +    *(static_cast<Applied*>(overwrite) + i) = from.in_->Get(pool, from.index_); +  } +  revealed_.push_back(Applied(entry.Base())); +} + +NBestComplete NBest::Complete(std::vector<PartialEdge> &partials) { +  assert(!partials.empty()); +  NBestList *list = list_pool_.construct(partials, entry_pool_, config_.keep); +  return NBestComplete( +      list, +      partials.front().CompletedState(), // All partials have the same state +      list->TopAfterConstructor()); +} + +const std::vector<Applied> &NBest::Extract(History history) { +  return static_cast<NBestList*>(history)->Extract(entry_pool_, config_.size); +} + +} // namespace search diff --git a/klm/search/nbest.hh b/klm/search/nbest.hh new file mode 100644 index 00000000..cb7651bc --- /dev/null +++ b/klm/search/nbest.hh @@ -0,0 +1,81 @@ +#ifndef SEARCH_NBEST__ +#define SEARCH_NBEST__ + +#include "search/applied.hh" +#include "search/config.hh" +#include "search/edge.hh" + +#include <boost/pool/object_pool.hpp> + +#include <cstddef> +#include <queue> +#include <vector> + +#include <assert.h> + +namespace search { + +class NBestList; + +class NBestList { +  private: +    class RevealedRef { +      public:  +        explicit RevealedRef(History history)  +          : in_(static_cast<NBestList*>(history)), index_(0) {} + +      private: +        friend class NBestList; + 
+        NBestList *in_; +        std::size_t index_; +    }; +     +    typedef GenericApplied<RevealedRef> QueueEntry; + +  public: +    NBestList(std::vector<PartialEdge> &existing, util::Pool &entry_pool, std::size_t keep); + +    Score TopAfterConstructor() const; + +    const std::vector<Applied> &Extract(util::Pool &pool, std::size_t n); + +  private: +    Score Visit(util::Pool &pool, std::size_t index); + +    Applied Get(util::Pool &pool, std::size_t index); + +    void MoveTop(util::Pool &pool); + +    typedef std::vector<Applied> Revealed; +    Revealed revealed_; + +    typedef std::priority_queue<QueueEntry> Queue; +    Queue queue_; +}; + +class NBest { +  public: +    typedef std::vector<PartialEdge> Combine; + +    explicit NBest(const NBestConfig &config) : config_(config) {} + +    void Add(std::vector<PartialEdge> &existing, PartialEdge addition) const { +      existing.push_back(addition); +    } + +    NBestComplete Complete(std::vector<PartialEdge> &partials); + +    const std::vector<Applied> &Extract(History root); + +  private: +    const NBestConfig config_; + +    boost::object_pool<NBestList> list_pool_; + +    util::Pool entry_pool_; +}; + +} // namespace search + +#endif // SEARCH_NBEST__ diff --git a/klm/search/rule.cc b/klm/search/rule.cc new file mode 100644 index 00000000..0244a09f --- /dev/null +++ b/klm/search/rule.cc @@ -0,0 +1,43 @@ +#include "search/rule.hh" + +#include "lm/model.hh" +#include "search/context.hh" + +#include <ostream> + +#include <math.h> + +namespace search { + +template <class Model> ScoreRuleRet ScoreRule(const Model &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing) { +  ScoreRuleRet ret; +  ret.prob = 0.0; +  ret.oov = 0; +  const lm::WordIndex oov = model.GetVocabulary().NotFound(), bos = model.GetVocabulary().BeginSentence(); +  lm::ngram::RuleScore<Model> scorer(model, *(writing++)); +  std::vector<lm::WordIndex>::const_iterator word = words.begin(); +  if (word != 
words.end() && *word == bos) { +    scorer.BeginSentence(); +    ++word; +  } +  for (; word != words.end(); ++word) { +    if (*word == kNonTerminal) { +      ret.prob += scorer.Finish(); +      scorer.Reset(*(writing++)); +    } else { +      if (*word == oov) ++ret.oov; +      scorer.Terminal(*word); +    } +  } +  ret.prob += scorer.Finish(); +  return ret; +} + +template ScoreRuleRet ScoreRule(const lm::ngram::RestProbingModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::ProbingModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::TrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::QuantTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::ArrayTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::QuantArrayTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); + +} // namespace search diff --git a/klm/search/rule.hh b/klm/search/rule.hh new file mode 100644 index 00000000..43ca6162 --- /dev/null +++ b/klm/search/rule.hh @@ -0,0 +1,25 @@ +#ifndef SEARCH_RULE__ +#define SEARCH_RULE__ + +#include "lm/left.hh" +#include "lm/word_index.hh" +#include "search/types.hh" + +#include <vector> + +namespace search { + +const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; + +struct ScoreRuleRet { +  Score prob; +  unsigned int oov; +}; + +// Pass <s> and </s> normally.   +// Indicate non-terminals with kNonTerminal.   
+template <class Model> ScoreRuleRet ScoreRule(const Model &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *state_out); + +} // namespace search + +#endif // SEARCH_RULE__ diff --git a/klm/search/types.hh b/klm/search/types.hh new file mode 100644 index 00000000..f9c849b3 --- /dev/null +++ b/klm/search/types.hh @@ -0,0 +1,31 @@ +#ifndef SEARCH_TYPES__ +#define SEARCH_TYPES__ + +#include <stdint.h> + +namespace lm { namespace ngram { class ChartState; } } + +namespace search { + +typedef float Score; + +typedef uint32_t Arity; + +union Note { +  const void *vp; +}; + +typedef void *History; + +struct NBestComplete { +  NBestComplete(History in_history, const lm::ngram::ChartState &in_state, Score in_score)  +    : history(in_history), state(&in_state), score(in_score) {} + +  History history; +  const lm::ngram::ChartState *state; +  Score score; +}; + +} // namespace search + +#endif // SEARCH_TYPES__ diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc new file mode 100644 index 00000000..45842982 --- /dev/null +++ b/klm/search/vertex.cc @@ -0,0 +1,55 @@ +#include "search/vertex.hh" + +#include "search/context.hh" + +#include <algorithm> +#include <functional> + +#include <assert.h> + +namespace search { + +namespace { + +struct GreaterByBound : public std::binary_function<const VertexNode *, const VertexNode *, bool> { +  bool operator()(const VertexNode *first, const VertexNode *second) const { +    return first->Bound() > second->Bound(); +  } +}; + +} // namespace + +void VertexNode::RecursiveSortAndSet(ContextBase &context, VertexNode *&parent_ptr) { +  if (Complete()) { +    assert(end_); +    assert(extend_.empty()); +    return; +  } +  if (extend_.size() == 1) { +    parent_ptr = extend_[0]; +    extend_[0]->RecursiveSortAndSet(context, parent_ptr); +    context.DeleteVertexNode(this); +    return; +  } +  for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { +    
(*i)->RecursiveSortAndSet(context, *i); +  } +  std::sort(extend_.begin(), extend_.end(), GreaterByBound()); +  bound_ = extend_.front()->Bound(); +} + +void VertexNode::SortAndSet(ContextBase &context) { +  // This is the root.  The root might be empty.   +  if (extend_.empty()) { +    bound_ = -INFINITY; +    return; +  } +  // The root cannot be replaced.  There's always one transition.   +  for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { +    (*i)->RecursiveSortAndSet(context, *i); +  } +  std::sort(extend_.begin(), extend_.end(), GreaterByBound()); +  bound_ = extend_.front()->Bound(); +} + +} // namespace search diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh new file mode 100644 index 00000000..ca9a4fcd --- /dev/null +++ b/klm/search/vertex.hh @@ -0,0 +1,164 @@ +#ifndef SEARCH_VERTEX__ +#define SEARCH_VERTEX__ + +#include "lm/left.hh" +#include "search/types.hh" + +#include <boost/unordered_set.hpp> + +#include <queue> +#include <vector> + +#include <math.h> +#include <stdint.h> + +namespace search { + +class ContextBase; + +class VertexNode { +  public: +    VertexNode() : end_() {} + +    void InitRoot() { +      extend_.clear(); +      state_.left.full = false; +      state_.left.length = 0; +      state_.right.length = 0; +      right_full_ = false; +      end_ = History(); +    } + +    lm::ngram::ChartState &MutableState() { return state_; } +    bool &MutableRightFull() { return right_full_; } + +    void AddExtend(VertexNode *next) { +      extend_.push_back(next); +    } + +    void SetEnd(History end, Score score) { +      assert(!end_); +      end_ = end; +      bound_ = score; +    } +     +    void SortAndSet(ContextBase &context); + +    // Should only happen to a root node when the entire vertex is empty.    
+    bool Empty() const { +      return !end_ && extend_.empty(); +    } + +    bool Complete() const { +      return end_; +    } + +    const lm::ngram::ChartState &State() const { return state_; } +    bool RightFull() const { return right_full_; } + +    Score Bound() const { +      return bound_; +    } + +    unsigned char Length() const { +      return state_.left.length + state_.right.length; +    } + +    // Will be invalid unless this is a leaf.    +    History End() const { return end_; } + +    const VertexNode &operator[](size_t index) const { +      return *extend_[index]; +    } + +    size_t Size() const { +      return extend_.size(); +    } + +  private: +    void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent); + +    std::vector<VertexNode*> extend_; + +    lm::ngram::ChartState state_; +    bool right_full_; + +    Score bound_; +    History end_; +}; + +class PartialVertex { +  public: +    PartialVertex() {} + +    explicit PartialVertex(const VertexNode &back) : back_(&back), index_(0) {} + +    bool Empty() const { return back_->Empty(); } + +    bool Complete() const { return back_->Complete(); } + +    const lm::ngram::ChartState &State() const { return back_->State(); } +    bool RightFull() const { return back_->RightFull(); } + +    Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); } + +    unsigned char Length() const { return back_->Length(); } + +    bool HasAlternative() const { +      return index_ + 1 < back_->Size(); +    } + +    // Split into continuation and alternative, rendering this the continuation. 
+    bool Split(PartialVertex &alternative) { +      assert(!Complete()); +      bool ret; +      if (index_ + 1 < back_->Size()) { +        alternative.index_ = index_ + 1; +        alternative.back_ = back_; +        ret = true; +      } else { +        ret = false; +      } +      back_ = &((*back_)[index_]); +      index_ = 0; +      return ret; +    } + +    History End() const { +      return back_->End(); +    } + +  private: +    const VertexNode *back_; +    unsigned int index_; +}; + +template <class Output> class VertexGenerator; + +class Vertex { +  public: +    Vertex() {} + +    PartialVertex RootPartial() const { return PartialVertex(root_); } + +    History BestChild() const { +      PartialVertex top(RootPartial()); +      if (top.Empty()) { +        return History(); +      } else { +        PartialVertex continuation; +        while (!top.Complete()) { +          top.Split(continuation); +        } +        return top.End(); +      } +    } + +  private: +    template <class Output> friend class VertexGenerator; +    template <class Output> friend class RootVertexGenerator; +    VertexNode root_; +}; + +} // namespace search +#endif // SEARCH_VERTEX__ diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc new file mode 100644 index 00000000..73139ffc --- /dev/null +++ b/klm/search/vertex_generator.cc @@ -0,0 +1,68 @@ +#include "search/vertex_generator.hh" + +#include "lm/left.hh" +#include "search/context.hh" +#include "search/edge.hh" + +#include <boost/unordered_map.hpp> +#include <boost/version.hpp> + +#include <stdint.h> + +namespace search { + +#if BOOST_VERSION > 104200 +namespace { + +const uint64_t kCompleteAdd = static_cast<uint64_t>(-1); + +Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { +  Trie &next = node.extend[added]; +  if (!next.under) { +    next.under = 
context.NewVertexNode(); +    lm::ngram::ChartState &writing = next.under->MutableState(); +    writing = state; +    writing.left.full &= left_full && state.left.full; +    next.under->MutableRightFull() = right_full && state.left.full; +    writing.left.length = left; +    writing.right.length = right; +    node.under->AddExtend(next.under); +  } +  return next; +} + +} // namespace + +void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end) { +  const lm::ngram::ChartState &state = *end.state; +   +  unsigned char left = 0, right = 0; +  Trie *node = &root; +  while (true) { +    if (left == state.left.length) { +      node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, true, right, false); +      for (; right < state.right.length; ++right) { +        node = &FindOrInsert(context, *node, state.right.words[right], state, left, true, right + 1, false); +      } +      break; +    } +    node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, false); +    left++; +    if (right == state.right.length) { +      node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, false, right, true); +      for (; left < state.left.length; ++left) { +        node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, true); +      } +      break; +    } +    node = &FindOrInsert(context, *node, state.right.words[right], state, left, false, right + 1, false); +    right++; +  } + +  node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); +  node->under->SetEnd(end.history, end.score); +} + +#endif // BOOST_VERSION + +} // namespace search diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh new file mode 100644 index 00000000..646b8189 --- /dev/null +++ b/klm/search/vertex_generator.hh @@ -0,0 +1,99 @@ +#ifndef SEARCH_VERTEX_GENERATOR__ 
+#define SEARCH_VERTEX_GENERATOR__
+
+#include "search/edge.hh"
+#include "search/types.hh"
+#include "search/vertex.hh"
+#include "util/exception.hh"
+
+#include <boost/unordered_map.hpp>
+#include <boost/version.hpp>
+
+namespace lm {
+namespace ngram {
+class ChartState;
+} // namespace ngram
+} // namespace lm
+
+namespace search {
+
+class ContextBase;
+
+#if BOOST_VERSION > 104200
+// Parallel structure to VertexNode: `under` is the VertexNode this trie node stands for, `extend` maps a 64-bit key to each child.
+struct Trie {
+  Trie() : under(NULL) {}
+
+  VertexNode *under; // NULL until a VertexNode is attached.
+  boost::unordered_map<uint64_t, Trie> extend; // Trie is still incomplete here; the Boost version guard above presumably exists for this -- confirm.
+};
+
+void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end); // Defined in vertex_generator.cc.
+
+#endif // BOOST_VERSION
+
+// Output makes the single-best or n-best list.  It must provide a Combine type plus Add(Combine &, PartialEdge) and Complete(Combine &).
+template <class Output> class VertexGenerator {
+  public:
+    VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) {
+      gen.root_.InitRoot(); // Start from an empty root; all three arguments are kept by reference.
+    }
+
+    void NewHypothesis(PartialEdge partial) { // Buckets the hypothesis by the hash of its completed state and hands bucket and edge to Output::Add.
+      nbest_.Add(existing_[hash_value(partial.CompletedState())], partial); // operator[] default-constructs the bucket on first use.
+    }
+
+    void FinishedSearch() { // Builds the vertex tree from one completed hypothesis per bucket, then sorts it.
+#if BOOST_VERSION > 104200
+      Trie root;
+      root.under = &gen_.root_; // The trie root stands for the vertex's own root node.
+      for (typename Existing::iterator i(existing_.begin()); i != existing_.end(); ++i) {
+        AddHypothesis(context_, root, nbest_.Complete(i->second));
+      }
+      existing_.clear();
+      root.under->SortAndSet(context_);
+#else // NOTE(review): the guard is a strict '> 104200' while the message below says ">= 1.42.0" -- confirm which is intended.
+      UTIL_THROW(util::Exception, "Upgrade Boost to >= 1.42.0 to use incremental search.");
+#endif
+    }
+
+    const Vertex &Generating() const { return gen_; }
+
+  private:
+    ContextBase &context_;
+
+    Vertex &gen_;
+
+    typedef boost::unordered_map<uint64_t, typename Output::Combine> Existing; // Keyed by the hash value itself, not by the state.
+    Existing existing_;
+
+    Output &nbest_;
+};
+
+// Special case for root vertex: everything should come together into the root
+// node.  In theory, this should happen naturally due to state collapsing with
+// <s> and </s>.  If that's the case, VertexGenerator is fine, though it will
+// make one connection.
+template <class Output> class RootVertexGenerator {
+  public:
+    RootVertexGenerator(Vertex &gen, Output &out) : gen_(gen), out_(out) {}
+
+    void NewHypothesis(PartialEdge partial) { // Every hypothesis goes into the single combine_ bucket.
+      out_.Add(combine_, partial);
+    }
+
+    void FinishedSearch() { // Resets the root and makes it a leaf holding the completed bucket's history and score.
+      gen_.root_.InitRoot();
+      NBestComplete completed(out_.Complete(combine_));
+      gen_.root_.SetEnd(completed.history, completed.score);
+    }
+
+  private:
+    Vertex &gen_;
+
+    typename Output::Combine combine_;
+    Output &out_;
+};
+
+} // namespace search
+#endif // SEARCH_VERTEX_GENERATOR__  |
