diff options
Diffstat (limited to 'klm/search')
| -rw-r--r-- | klm/search/Makefile.am | 4 | ||||
| -rw-r--r-- | klm/search/applied.hh | 86 | ||||
| -rw-r--r-- | klm/search/config.hh | 25 | ||||
| -rw-r--r-- | klm/search/context.hh | 28 | ||||
| -rw-r--r-- | klm/search/dedupe.hh | 131 | ||||
| -rw-r--r-- | klm/search/edge_generator.cc | 3 | ||||
| -rw-r--r-- | klm/search/edge_generator.hh | 1 | ||||
| -rw-r--r-- | klm/search/final.hh | 36 | ||||
| -rw-r--r-- | klm/search/header.hh | 9 | ||||
| -rw-r--r-- | klm/search/nbest.cc | 106 | ||||
| -rw-r--r-- | klm/search/nbest.hh | 81 | ||||
| -rw-r--r-- | klm/search/note.hh | 12 | ||||
| -rw-r--r-- | klm/search/rule.cc | 52 | ||||
| -rw-r--r-- | klm/search/rule.hh | 11 | ||||
| -rw-r--r-- | klm/search/types.hh | 17 | ||||
| -rw-r--r-- | klm/search/vertex.cc | 27 | ||||
| -rw-r--r-- | klm/search/vertex.hh | 37 | ||||
| -rw-r--r-- | klm/search/vertex_generator.cc | 44 | ||||
| -rw-r--r-- | klm/search/vertex_generator.hh | 72 | ||||
| -rw-r--r-- | klm/search/weights.cc | 71 | ||||
| -rw-r--r-- | klm/search/weights.hh | 52 | ||||
| -rw-r--r-- | klm/search/weights_test.cc | 38 | 
22 files changed, 604 insertions, 339 deletions
| diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am index ccc5b7f6..5aea33c2 100644 --- a/klm/search/Makefile.am +++ b/klm/search/Makefile.am @@ -2,10 +2,10 @@ noinst_LIBRARIES = libksearch.a  libksearch_a_SOURCES = \    edge_generator.cc \ +	nbest.cc \    rule.cc \    vertex.cc \ -  vertex_generator.cc \ -  weights.cc +  vertex_generator.cc  AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. diff --git a/klm/search/applied.hh b/klm/search/applied.hh new file mode 100644 index 00000000..bd659e5c --- /dev/null +++ b/klm/search/applied.hh @@ -0,0 +1,86 @@ +#ifndef SEARCH_APPLIED__ +#define SEARCH_APPLIED__ + +#include "search/edge.hh" +#include "search/header.hh" +#include "util/pool.hh" + +#include <math.h> + +namespace search { + +// A full hypothesis: a score, arity of the rule, a pointer to the decoder's rule (Note), and pointers to non-terminals that were substituted.   +template <class Below> class GenericApplied : public Header { +  public: +    GenericApplied() {} + +    GenericApplied(void *location, PartialEdge partial)  +      : Header(location) { +      memcpy(Base(), partial.Base(), kHeaderSize); +      Below *child_out = Children(); +      const PartialVertex *part = partial.NT(); +      const PartialVertex *const part_end_loop = part + partial.GetArity(); +      for (; part != part_end_loop; ++part, ++child_out) +        *child_out = Below(part->End()); +    } +     +    GenericApplied(void *location, Score score, Arity arity, Note note) : Header(location, arity) { +      SetScore(score); +      SetNote(note); +    } + +    explicit GenericApplied(History from) : Header(from) {} + + +    // These are arrays of length GetArity(). +    Below *Children() { +      return reinterpret_cast<Below*>(After()); +    } +    const Below *Children() const { +      return reinterpret_cast<const Below*>(After()); +    } + +    static std::size_t Size(Arity arity) { +      return kHeaderSize + arity * sizeof(const Below); +    } +}; + +// Applied rule that references itself.   +class Applied : public GenericApplied<Applied> { +  private: +    typedef GenericApplied<Applied> P; + +  public: +    Applied() {} +    Applied(void *location, PartialEdge partial) : P(location, partial) {} +    Applied(History from) : P(from) {} +}; + +// How to build single-best hypotheses.   +class SingleBest { +  public: +    typedef PartialEdge Combine; + +    void Add(PartialEdge &existing, PartialEdge add) const { +      if (!existing.Valid() || existing.GetScore() < add.GetScore()) +        existing = add; +    } + +    NBestComplete Complete(PartialEdge partial) { +      if (!partial.Valid())  +        return NBestComplete(NULL, lm::ngram::ChartState(), -INFINITY); +      void *place_final = pool_.Allocate(Applied::Size(partial.GetArity())); +      Applied(place_final, partial); +      return NBestComplete( +          place_final, +          partial.CompletedState(), +          partial.GetScore()); +    } + +  private: +    util::Pool pool_; +}; + +} // namespace search + +#endif // SEARCH_APPLIED__ diff --git a/klm/search/config.hh b/klm/search/config.hh index ef8e2354..ba18c09e 100644 --- a/klm/search/config.hh +++ b/klm/search/config.hh @@ -1,23 +1,36 @@  #ifndef SEARCH_CONFIG__  #define SEARCH_CONFIG__ -#include "search/weights.hh" -#include "util/string_piece.hh" +#include "search/types.hh"  namespace search { +struct NBestConfig { +  explicit NBestConfig(unsigned int in_size) { +    keep = in_size; +    size = in_size; +  } +   +  unsigned int keep, size; +}; +  class Config {    public: -    Config(const Weights &weights, unsigned int pop_limit) : -      weights_(weights), pop_limit_(pop_limit) {} +    Config(Score lm_weight, unsigned int pop_limit, const NBestConfig &nbest) : +      lm_weight_(lm_weight), pop_limit_(pop_limit), nbest_(nbest) {} -    const Weights &GetWeights() const { return weights_; } +    Score LMWeight() const { return lm_weight_; }      unsigned int PopLimit() const { return pop_limit_; } +    const NBestConfig &GetNBest() const { return nbest_; } +    private: -    Weights weights_; +    Score lm_weight_; +      unsigned int pop_limit_; + +    NBestConfig nbest_;  };  } // namespace search diff --git a/klm/search/context.hh b/klm/search/context.hh index 62163144..08f21bbf 100644 --- a/klm/search/context.hh +++ b/klm/search/context.hh @@ -1,30 +1,16 @@  #ifndef SEARCH_CONTEXT__  #define SEARCH_CONTEXT__ -#include "lm/model.hh"  #include "search/config.hh" -#include "search/final.hh" -#include "search/types.hh"  #include "search/vertex.hh" -#include "util/exception.hh" -#include "util/pool.hh"  #include <boost/pool/object_pool.hpp> -#include <boost/ptr_container/ptr_vector.hpp> - -#include <vector>  namespace search { -class Weights; -  class ContextBase {    public: -    explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {} - -    util::Pool &FinalPool() { -      return final_pool_; -    } +    explicit ContextBase(const Config &config) : config_(config) {}      VertexNode *NewVertexNode() {        VertexNode *ret = vertex_node_pool_.construct(); @@ -36,18 +22,16 @@ class ContextBase {        vertex_node_pool_.destroy(node);      } -    unsigned int PopLimit() const { return pop_limit_; } +    unsigned int PopLimit() const { return config_.PopLimit(); } -    const Weights &GetWeights() const { return weights_; } +    Score LMWeight() const { return config_.LMWeight(); } -  private: -    util::Pool final_pool_; +    const Config &GetConfig() const { return config_; } +  private:      boost::object_pool<VertexNode> vertex_node_pool_; -    unsigned int pop_limit_; - -    const Weights &weights_; +    Config config_;  };  template <class Model> class Context : public ContextBase { diff --git a/klm/search/dedupe.hh b/klm/search/dedupe.hh new file mode 100644 index 00000000..7eaa3b95 --- /dev/null +++ b/klm/search/dedupe.hh @@ -0,0 +1,131 @@ +#ifndef SEARCH_DEDUPE__ +#define SEARCH_DEDUPE__ + +#include "lm/state.hh" +#include "search/edge_generator.hh" + +#include <boost/pool/object_pool.hpp> +#include <boost/unordered_map.hpp> + +namespace search { + +class Dedupe { +  public: +    Dedupe() {} + +    PartialEdge AllocateEdge(Arity arity) { +      return behind_.AllocateEdge(arity); +    } + +    void AddEdge(PartialEdge edge) { +      edge.MutableFlags() = 0; + +      uint64_t hash = 0; +      const PartialVertex *v = edge.NT(); +      const PartialVertex *v_end = v + edge.GetArity(); +      for (; v != v_end; ++v) { +        const void *ptr = v->Identify(); +        hash = util::MurmurHashNative(&ptr, sizeof(const void*), hash); +      } +       +      const lm::ngram::ChartState *c = edge.Between(); +      const lm::ngram::ChartState *const c_end = c + edge.GetArity() + 1; +      for (; c != c_end; ++c) hash = hash_value(*c, hash); + +      std::pair<Table::iterator, bool> ret(table_.insert(std::make_pair(hash, edge))); +      if (!ret.second) FoundDupe(ret.first->second, edge); +    } + +    bool Empty() const { return behind_.Empty(); } + +    template <class Model, class Output> void Search(Context<Model> &context, Output &output) { +      for (Table::const_iterator i(table_.begin()); i != table_.end(); ++i) { +        behind_.AddEdge(i->second); +      } +      Unpack<Output> unpack(output, *this); +      behind_.Search(context, unpack); +    } + +  private: +    void FoundDupe(PartialEdge &table, PartialEdge adding) { +      if (table.GetFlags() & kPackedFlag) { +        Packed &packed = *static_cast<Packed*>(table.GetNote().mut); +        if (table.GetScore() >= adding.GetScore()) { +          packed.others.push_back(adding); +          return; +        } +        Note original(packed.original); +        packed.original = adding.GetNote(); +        adding.SetNote(table.GetNote()); +        table.SetNote(original); +        packed.others.push_back(table); +        packed.starting = adding.GetScore(); +        table = adding; +        table.MutableFlags() |= kPackedFlag; +        return; +      } +      PartialEdge loser; +      if (adding.GetScore() > table.GetScore()) { +        loser = table; +        table = adding; +      } else { +        loser = adding; +      } +      // table is winner, loser is loser... +      packed_.construct(table, loser); +    } + +    struct Packed { +      Packed(PartialEdge winner, PartialEdge loser)  +        : original(winner.GetNote()), starting(winner.GetScore()), others(1, loser) { +        winner.MutableNote().vp = this; +        winner.MutableFlags() |= kPackedFlag; +        loser.MutableFlags() &= ~kPackedFlag; +      } +      Note original; +      Score starting; +      std::vector<PartialEdge> others; +    }; + +    template <class Output> class Unpack { +      public: +        explicit Unpack(Output &output, Dedupe &owner) : output_(output), owner_(owner) {} + +        void NewHypothesis(PartialEdge edge) { +          if (edge.GetFlags() & kPackedFlag) { +            Packed &packed = *reinterpret_cast<Packed*>(edge.GetNote().mut); +            edge.SetNote(packed.original); +            edge.MutableFlags() = 0; +            std::size_t copy_size = sizeof(PartialVertex) * edge.GetArity() + sizeof(lm::ngram::ChartState); +            for (std::vector<PartialEdge>::iterator i = packed.others.begin(); i != packed.others.end(); ++i) { +              PartialEdge copy(owner_.AllocateEdge(edge.GetArity())); +              copy.SetScore(edge.GetScore() - packed.starting + i->GetScore()); +              copy.MutableFlags() = 0; +              copy.SetNote(i->GetNote()); +              memcpy(copy.NT(), edge.NT(), copy_size); +              output_.NewHypothesis(copy); +            } +          } +          output_.NewHypothesis(edge); +        } + +        void FinishedSearch() { +          output_.FinishedSearch(); +        } + +      private: +        Output &output_; +        Dedupe &owner_; +    }; + +    EdgeGenerator behind_; + +    typedef boost::unordered_map<uint64_t, PartialEdge> Table; +    Table table_; + +    boost::object_pool<Packed> packed_; + +    static const uint16_t kPackedFlag = 1; +}; +} // namespace search +#endif // SEARCH_DEDUPE__ diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc index 260159b1..eacf5de5 100644 --- a/klm/search/edge_generator.cc +++ b/klm/search/edge_generator.cc @@ -1,6 +1,7 @@  #include "search/edge_generator.hh"  #include "lm/left.hh" +#include "lm/model.hh"  #include "lm/partial.hh"  #include "search/context.hh"  #include "search/vertex.hh" @@ -38,7 +39,7 @@ template <class Model> void FastScore(const Context<Model> &context, Arity victi        *cover = *(cover + 1);      }    } -  update.SetScore(update.GetScore() + adjustment * context.GetWeights().LM()); +  update.SetScore(update.GetScore() + adjustment * context.LMWeight());  }  } // namespace diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh index 582c78b7..203942c6 100644 --- a/klm/search/edge_generator.hh +++ b/klm/search/edge_generator.hh @@ -2,7 +2,6 @@  #define SEARCH_EDGE_GENERATOR__  #include "search/edge.hh" -#include "search/note.hh"  #include "search/types.hh"  #include <queue> diff --git a/klm/search/final.hh b/klm/search/final.hh deleted file mode 100644 index 50e62cf2..00000000 --- a/klm/search/final.hh +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef SEARCH_FINAL__ -#define SEARCH_FINAL__ - -#include "search/header.hh" -#include "util/pool.hh" - -namespace search { - -// A full hypothesis with pointers to children. -class Final : public Header { -  public: -    Final() {} - -    Final(util::Pool &pool, Score score, Arity arity, Note note)  -      : Header(pool.Allocate(Size(arity)), arity) { -      SetScore(score); -      SetNote(note); -    } - -    // These are arrays of length GetArity(). -    Final *Children() { -      return reinterpret_cast<Final*>(After()); -    } -    const Final *Children() const { -      return reinterpret_cast<const Final*>(After()); -    } - -  private: -    static std::size_t Size(Arity arity) { -      return kHeaderSize + arity * sizeof(const Final); -    } -}; - -} // namespace search - -#endif // SEARCH_FINAL__ diff --git a/klm/search/header.hh b/klm/search/header.hh index 25550dbe..69f0eed0 100644 --- a/klm/search/header.hh +++ b/klm/search/header.hh @@ -3,7 +3,6 @@  // Header consisting of Score, Arity, and Note -#include "search/note.hh"  #include "search/types.hh"  #include <stdint.h> @@ -24,6 +23,9 @@ class Header {      bool operator<(const Header &other) const {        return GetScore() < other.GetScore();      } +    bool operator>(const Header &other) const { +      return GetScore() > other.GetScore(); +    }      Arity GetArity() const {        return *reinterpret_cast<const Arity*>(base_ + sizeof(Score)); @@ -36,9 +38,14 @@ class Header {        *reinterpret_cast<Note*>(base_ + sizeof(Score) + sizeof(Arity)) = to;      } +    uint8_t *Base() { return base_; } +    const uint8_t *Base() const { return base_; } +    protected:      Header() : base_(NULL) {} +    explicit Header(void *base) : base_(static_cast<uint8_t*>(base)) {} +      Header(void *base, Arity arity) : base_(static_cast<uint8_t*>(base)) {        *reinterpret_cast<Arity*>(base_ + sizeof(Score)) = arity;      } diff --git a/klm/search/nbest.cc b/klm/search/nbest.cc new file mode 100644 index 00000000..ec3322c9 --- /dev/null +++ b/klm/search/nbest.cc @@ -0,0 +1,106 @@ +#include "search/nbest.hh" + +#include "util/pool.hh" + +#include <algorithm> +#include <functional> +#include <queue> + +#include <assert.h> +#include <math.h> + +namespace search { + +NBestList::NBestList(std::vector<PartialEdge> &partials, util::Pool &entry_pool, std::size_t keep) { +  assert(!partials.empty()); +  std::vector<PartialEdge>::iterator end; +  if (partials.size() > keep) { +    end = partials.begin() + keep; +    std::nth_element(partials.begin(), end, partials.end(), std::greater<PartialEdge>()); +  } else { +    end = partials.end(); +  } +  for (std::vector<PartialEdge>::const_iterator i(partials.begin()); i != end; ++i) { +    queue_.push(QueueEntry(entry_pool.Allocate(QueueEntry::Size(i->GetArity())), *i)); +  } +} + +Score NBestList::TopAfterConstructor() const { +  assert(revealed_.empty()); +  return queue_.top().GetScore(); +} + +const std::vector<Applied> &NBestList::Extract(util::Pool &pool, std::size_t n) { +  while (revealed_.size() < n && !queue_.empty()) { +    MoveTop(pool); +  } +  return revealed_; +} + +Score NBestList::Visit(util::Pool &pool, std::size_t index) { +  if (index + 1 < revealed_.size()) +    return revealed_[index + 1].GetScore() - revealed_[index].GetScore(); +  if (queue_.empty())  +    return -INFINITY; +  if (index + 1 == revealed_.size()) +    return queue_.top().GetScore() - revealed_[index].GetScore(); +  assert(index == revealed_.size()); + +  MoveTop(pool); + +  if (queue_.empty()) return -INFINITY; +  return queue_.top().GetScore() - revealed_[index].GetScore(); +} + +Applied NBestList::Get(util::Pool &pool, std::size_t index) { +  assert(index <= revealed_.size()); +  if (index == revealed_.size()) MoveTop(pool); +  return revealed_[index]; +} + +void NBestList::MoveTop(util::Pool &pool) { +  assert(!queue_.empty()); +  QueueEntry entry(queue_.top()); +  queue_.pop(); +  RevealedRef *const children_begin = entry.Children(); +  RevealedRef *const children_end = children_begin + entry.GetArity(); +  Score basis = entry.GetScore(); +  for (RevealedRef *child = children_begin; child != children_end; ++child) { +    Score change = child->in_->Visit(pool, child->index_); +    if (change != -INFINITY) { +      assert(change < 0.001); +      QueueEntry new_entry(pool.Allocate(QueueEntry::Size(entry.GetArity())), basis + change, entry.GetArity(), entry.GetNote()); +      std::copy(children_begin, child, new_entry.Children()); +      RevealedRef *update = new_entry.Children() + (child - children_begin); +      update->in_ = child->in_; +      update->index_ = child->index_ + 1; +      std::copy(child + 1, children_end, update + 1); +      queue_.push(new_entry); +    } +    // Gesmundo, A. and Henderson, J. Faster Cube Pruning, IWSLT 2010. +    if (child->index_) break; +  } + +  // Convert QueueEntry to Applied.  This leaves some unused memory.   +  void *overwrite = entry.Children(); +  for (unsigned int i = 0; i < entry.GetArity(); ++i) { +    RevealedRef from(*(static_cast<const RevealedRef*>(overwrite) + i)); +    *(static_cast<Applied*>(overwrite) + i) = from.in_->Get(pool, from.index_); +  } +  revealed_.push_back(Applied(entry.Base())); +} + +NBestComplete NBest::Complete(std::vector<PartialEdge> &partials) { +  assert(!partials.empty()); +  NBestList *list = list_pool_.construct(partials, entry_pool_, config_.keep); +  return NBestComplete( +      list, +      partials.front().CompletedState(), // All partials have the same state +      list->TopAfterConstructor()); +} + +const std::vector<Applied> &NBest::Extract(History history) { +  return static_cast<NBestList*>(history)->Extract(entry_pool_, config_.size); +} + +} // namespace search diff --git a/klm/search/nbest.hh b/klm/search/nbest.hh new file mode 100644 index 00000000..cb7651bc --- /dev/null +++ b/klm/search/nbest.hh @@ -0,0 +1,81 @@ +#ifndef SEARCH_NBEST__ +#define SEARCH_NBEST__ + +#include "search/applied.hh" +#include "search/config.hh" +#include "search/edge.hh" + +#include <boost/pool/object_pool.hpp> + +#include <cstddef> +#include <queue> +#include <vector> + +#include <assert.h> + +namespace search { + +class NBestList; + +class NBestList { +  private: +    class RevealedRef { +      public:  +        explicit RevealedRef(History history)  +          : in_(static_cast<NBestList*>(history)), index_(0) {} + +      private: +        friend class NBestList; + +        NBestList *in_; +        std::size_t index_; +    }; +     +    typedef GenericApplied<RevealedRef> QueueEntry; + +  public: +    NBestList(std::vector<PartialEdge> &existing, util::Pool &entry_pool, std::size_t keep); + +    Score TopAfterConstructor() const; + +    const std::vector<Applied> &Extract(util::Pool &pool, std::size_t n); + +  private: +    Score Visit(util::Pool &pool, std::size_t index); + +    Applied Get(util::Pool &pool, std::size_t index); + +    void MoveTop(util::Pool &pool); + +    typedef std::vector<Applied> Revealed; +    Revealed revealed_; + +    typedef std::priority_queue<QueueEntry> Queue; +    Queue queue_; +}; + +class NBest { +  public: +    typedef std::vector<PartialEdge> Combine; + +    explicit NBest(const NBestConfig &config) : config_(config) {} + +    void Add(std::vector<PartialEdge> &existing, PartialEdge addition) const { +      existing.push_back(addition); +    } + +    NBestComplete Complete(std::vector<PartialEdge> &partials); + +    const std::vector<Applied> &Extract(History root); + +  private: +    const NBestConfig config_; + +    boost::object_pool<NBestList> list_pool_; + +    util::Pool entry_pool_; +}; + +} // namespace search + +#endif // SEARCH_NBEST__ diff --git a/klm/search/note.hh b/klm/search/note.hh deleted file mode 100644 index 50bed06e..00000000 --- a/klm/search/note.hh +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef SEARCH_NOTE__ -#define SEARCH_NOTE__ - -namespace search { - -union Note { -  const void *vp; -}; - -} // namespace search - -#endif // SEARCH_NOTE__ diff --git a/klm/search/rule.cc b/klm/search/rule.cc index 5b00207e..0244a09f 100644 --- a/klm/search/rule.cc +++ b/klm/search/rule.cc @@ -1,7 +1,7 @@  #include "search/rule.hh" +#include "lm/model.hh"  #include "search/context.hh" -#include "search/final.hh"  #include <ostream> @@ -9,35 +9,35 @@  namespace search { -template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing) { -  unsigned int oov_count = 0; -  float prob = 0.0; -  const Model &model = context.LanguageModel(); -  const lm::WordIndex oov = model.GetVocabulary().NotFound(); -  for (std::vector<lm::WordIndex>::const_iterator word = words.begin(); ; ++word) { -    lm::ngram::RuleScore<Model> scorer(model, *(writing++)); -    // TODO: optimize -    if (prepend_bos && (word == words.begin())) { -      scorer.BeginSentence(); -    } -    for (; ; ++word) { -      if (word == words.end()) { -        prob += scorer.Finish(); -        return static_cast<float>(oov_count) * context.GetWeights().OOV() + prob * context.GetWeights().LM(); -      } -      if (*word == kNonTerminal) break; -      if (*word == oov) ++oov_count; +template <class Model> ScoreRuleRet ScoreRule(const Model &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing) { +  ScoreRuleRet ret; +  ret.prob = 0.0; +  ret.oov = 0; +  const lm::WordIndex oov = model.GetVocabulary().NotFound(), bos = model.GetVocabulary().BeginSentence(); +  lm::ngram::RuleScore<Model> scorer(model, *(writing++)); +  std::vector<lm::WordIndex>::const_iterator word = words.begin(); +  if (word != words.end() && *word == bos) { +    scorer.BeginSentence(); +    ++word; +  } +  for (; word != words.end(); ++word) { +    if (*word == kNonTerminal) { +      ret.prob += scorer.Finish(); +      scorer.Reset(*(writing++)); +    } else { +      if (*word == oov) ++ret.oov;        scorer.Terminal(*word);      } -    prob += scorer.Finish();    } +  ret.prob += scorer.Finish(); +  return ret;  } -template float ScoreRule(const Context<lm::ngram::RestProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context<lm::ngram::ProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context<lm::ngram::TrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context<lm::ngram::QuantTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context<lm::ngram::ArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context<lm::ngram::QuantArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::RestProbingModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::ProbingModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::TrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::QuantTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::ArrayTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::QuantArrayTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);  } // namespace search diff --git a/klm/search/rule.hh b/klm/search/rule.hh index 0ce2794d..43ca6162 100644 --- a/klm/search/rule.hh +++ b/klm/search/rule.hh @@ -9,11 +9,16 @@  namespace search { -template <class Model> class Context; -  const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; -template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *state_out); +struct ScoreRuleRet { +  Score prob; +  unsigned int oov; +}; + +// Pass <s> and </s> normally.   +// Indicate non-terminals with kNonTerminal.   +template <class Model> ScoreRuleRet ScoreRule(const Model &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *state_out);  } // namespace search diff --git a/klm/search/types.hh b/klm/search/types.hh index 06eb5bfa..f9c849b3 100644 --- a/klm/search/types.hh +++ b/klm/search/types.hh @@ -3,12 +3,29 @@  #include <stdint.h> +namespace lm { namespace ngram { class ChartState; } } +  namespace search {  typedef float Score;  typedef uint32_t Arity; +union Note { +  const void *vp; +}; + +typedef void *History; + +struct NBestComplete { +  NBestComplete(History in_history, const lm::ngram::ChartState &in_state, Score in_score)  +    : history(in_history), state(&in_state), score(in_score) {} + +  History history; +  const lm::ngram::ChartState *state; +  Score score; +}; +  } // namespace search  #endif // SEARCH_TYPES__ diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc index 11f4631f..45842982 100644 --- a/klm/search/vertex.cc +++ b/klm/search/vertex.cc @@ -19,21 +19,34 @@ struct GreaterByBound : public std::binary_function<const VertexNode *, const Ve  } // namespace -void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) { +void VertexNode::RecursiveSortAndSet(ContextBase &context, VertexNode *&parent_ptr) {    if (Complete()) { -    assert(end_.Valid()); +    assert(end_);      assert(extend_.empty()); -    bound_ = end_.GetScore();      return;    } -  if (extend_.size() == 1 && parent_ptr) { -    *parent_ptr = extend_[0]; -    extend_[0]->SortAndSet(context, parent_ptr); +  if (extend_.size() == 1) { +    parent_ptr = extend_[0]; +    extend_[0]->RecursiveSortAndSet(context, parent_ptr);      context.DeleteVertexNode(this);      return;    }    for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { -    (*i)->SortAndSet(context, &*i); +    (*i)->RecursiveSortAndSet(context, *i); +  } +  std::sort(extend_.begin(), extend_.end(), GreaterByBound()); +  bound_ = extend_.front()->Bound(); +} + +void VertexNode::SortAndSet(ContextBase &context) { +  // This is the root.  The root might be empty.   +  if (extend_.empty()) { +    bound_ = -INFINITY; +    return; +  } +  // The root cannot be replaced.  There's always one transition.   +  for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { +    (*i)->RecursiveSortAndSet(context, *i);    }    std::sort(extend_.begin(), extend_.end(), GreaterByBound());    bound_ = extend_.front()->Bound(); diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh index 52bc1dfe..10b3339b 100644 --- a/klm/search/vertex.hh +++ b/klm/search/vertex.hh @@ -2,7 +2,6 @@  #define SEARCH_VERTEX__  #include "lm/left.hh" -#include "search/final.hh"  #include "search/types.hh"  #include <boost/unordered_set.hpp> @@ -10,6 +9,7 @@  #include <queue>  #include <vector> +#include <math.h>  #include <stdint.h>  namespace search { @@ -18,7 +18,7 @@ class ContextBase;  class VertexNode {    public: -    VertexNode() {} +    VertexNode() : end_() {}      void InitRoot() {        extend_.clear(); @@ -26,7 +26,7 @@ class VertexNode {        state_.left.length = 0;        state_.right.length = 0;        right_full_ = false; -      end_ = Final(); +      end_ = History();      }      lm::ngram::ChartState &MutableState() { return state_; } @@ -36,20 +36,21 @@ class VertexNode {        extend_.push_back(next);      } -    void SetEnd(Final end) { -      assert(!end_.Valid()); +    void SetEnd(History end, Score score) { +      assert(!end_);        end_ = end; +      bound_ = score;      } -    void SortAndSet(ContextBase &context, VertexNode **parent_pointer); +    void SortAndSet(ContextBase &context);      // Should only happen to a root node when the entire vertex is empty.         bool Empty() const { -      return !end_.Valid() && extend_.empty(); +      return !end_ && extend_.empty();      }      bool Complete() const { -      return end_.Valid(); +      return end_;      }      const lm::ngram::ChartState &State() const { return state_; } @@ -64,7 +65,7 @@ class VertexNode {      }      // Will be invalid unless this is a leaf.    -    const Final End() const { return end_; } +    const History End() const { return end_; }      const VertexNode &operator[](size_t index) const {        return *extend_[index]; @@ -75,13 +76,15 @@ class VertexNode {      }    private: +    void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent); +      std::vector<VertexNode*> extend_;      lm::ngram::ChartState state_;      bool right_full_;      Score bound_; -    Final end_; +    History end_;  };  class PartialVertex { @@ -97,7 +100,7 @@ class PartialVertex {      const lm::ngram::ChartState &State() const { return back_->State(); }      bool RightFull() const { return back_->RightFull(); } -    Score Bound() const { return Complete() ? back_->End().GetScore() : (*back_)[index_].Bound(); } +    Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); }      unsigned char Length() const { return back_->Length(); } @@ -121,7 +124,7 @@ class PartialVertex {        return ret;      } -    const Final End() const { +    const History End() const {        return back_->End();      } @@ -130,16 +133,18 @@ class PartialVertex {      unsigned int index_;  }; +template <class Output> class VertexGenerator; +  class Vertex {    public:      Vertex() {}      PartialVertex RootPartial() const { return PartialVertex(root_); } -    const Final BestChild() const { +    const History BestChild() const {        PartialVertex top(RootPartial());        if (top.Empty()) { -        return Final(); +        return History();        } else {          PartialVertex continuation;          while (!top.Complete()) { @@ -150,8 +155,8 @@ class Vertex {      }    private: -    friend class VertexGenerator; - +    template <class Output> friend class VertexGenerator; +    template <class Output> friend class RootVertexGenerator;      VertexNode root_;  }; diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc index 0945fe55..73139ffc 100644 --- a/klm/search/vertex_generator.cc +++ b/klm/search/vertex_generator.cc @@ -4,26 +4,18 @@  #include "search/context.hh"  #include "search/edge.hh" +#include <boost/unordered_map.hpp> +#include <boost/version.hpp> +  #include <stdint.h>  namespace search { -VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) { -  gen.root_.InitRoot(); -} - +#if BOOST_VERSION > 104200  namespace {  const uint64_t kCompleteAdd = static_cast<uint64_t>(-1); -// Parallel structure to VertexNode.   -struct Trie { -  Trie() : under(NULL) {} - -  VertexNode *under; -  boost::unordered_map<uint64_t, Trie> extend; -}; -  Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) {    Trie &next = node.extend[added];    if (!next.under) { @@ -39,19 +31,10 @@ Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::n    return next;  } -void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) { -  Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote()); -  Final *child_out = final.Children(); -  const PartialVertex *part = partial.NT(); -  const PartialVertex *const part_end_loop = part + partial.GetArity(); -  for (; part != part_end_loop; ++part, ++child_out) -    *child_out = part->End(); - -  starter.under->SetEnd(final); -} +} // namespace -void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) { -  const lm::ngram::ChartState &state = partial.CompletedState(); +void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end) { +  const lm::ngram::ChartState &state = *end.state;    unsigned char left = 0, right = 0;    Trie *node = &root; @@ -77,18 +60,9 @@ void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) {    }    node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); -  CompleteTransition(context, *node, partial); +  node->under->SetEnd(end.history, end.score);  } -} // namespace - -void VertexGenerator::FinishedSearch() { -  Trie root; -  root.under = &gen_.root_; -  for (Existing::const_iterator i(existing_.begin()); i != existing_.end(); ++i) { -    AddHypothesis(context_, root, i->second); -  } -  root.under->SortAndSet(context_, NULL); -} +#endif // BOOST_VERSION  } // namespace search diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh index 60e86112..da563c2d 100644 --- a/klm/search/vertex_generator.hh +++ b/klm/search/vertex_generator.hh @@ -2,9 +2,11 @@  #define SEARCH_VERTEX_GENERATOR__  #include "search/edge.hh" +#include "search/types.hh"  #include "search/vertex.hh"  #include <boost/unordered_map.hpp> +#include <boost/version.hpp>  namespace lm {  namespace ngram { @@ -15,21 +17,44 @@ class ChartState;  namespace search {  class ContextBase; -class Final; -class VertexGenerator { +#if BOOST_VERSION > 104200 +// Parallel structure to VertexNode.   +struct Trie { +  Trie() : under(NULL) {} + +  VertexNode *under; +  boost::unordered_map<uint64_t, Trie> extend; +}; + +void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end); + +#endif // BOOST_VERSION + +// Output makes the single-best or n-best list.    +template <class Output> class VertexGenerator {    public: -    VertexGenerator(ContextBase &context, Vertex &gen); +    VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) { +      gen.root_.InitRoot(); +    }      void NewHypothesis(PartialEdge partial) { -      const lm::ngram::ChartState &state = partial.CompletedState(); -      std::pair<Existing::iterator, bool> ret(existing_.insert(std::make_pair(hash_value(state), partial))); -      if (!ret.second && ret.first->second < partial) { -        ret.first->second = partial; -      } +      nbest_.Add(existing_[hash_value(partial.CompletedState())], partial);      } -    void FinishedSearch(); +    void FinishedSearch() { +#if BOOST_VERSION > 104200 +      Trie root; +      root.under = &gen_.root_; +      for (typename Existing::iterator i(existing_.begin()); i != existing_.end(); ++i) { +        AddHypothesis(context_, root, nbest_.Complete(i->second)); +      } +      existing_.clear(); +      root.under->SortAndSet(context_); +#else +      UTIL_THROW(util::Exception, "Upgrade Boost to >= 1.42.0 to use incremental search."); +#endif +    }      const Vertex &Generating() const { return gen_; } @@ -38,8 +63,35 @@ class VertexGenerator {      Vertex &gen_; -    typedef boost::unordered_map<uint64_t, PartialEdge> Existing; +    typedef boost::unordered_map<uint64_t, typename Output::Combine> Existing;      Existing existing_; + +    Output &nbest_; +}; + +// Special case for root vertex: everything should come together into the root +// node.  In theory, this should happen naturally due to state collapsing with +// <s> and </s>.  If that's the case, VertexGenerator is fine, though it will +// make one connection.   +template <class Output> class RootVertexGenerator { +  public: +    RootVertexGenerator(Vertex &gen, Output &out) : gen_(gen), out_(out) {} + +    void NewHypothesis(PartialEdge partial) { +      out_.Add(combine_, partial); +    } + +    void FinishedSearch() { +      gen_.root_.InitRoot(); +      NBestComplete completed(out_.Complete(combine_)); +      gen_.root_.SetEnd(completed.history, completed.score); +    } + +  private: +    Vertex &gen_; +     +    typename Output::Combine combine_; +    Output &out_;  };  } // namespace search diff --git a/klm/search/weights.cc b/klm/search/weights.cc deleted file mode 100644 index d65471ad..00000000 --- a/klm/search/weights.cc +++ /dev/null @@ -1,71 +0,0 @@ -#include "search/weights.hh" -#include "util/tokenize_piece.hh" - -#include <cstdlib> - -namespace search { - -namespace { -struct Insert { -  void operator()(boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) const { -    std::string copy(name.data(), name.size()); -    map[copy] = score; -  } -}; - -struct DotProduct { -  search::Score total; -  DotProduct() : total(0.0) {} - -  void operator()(const boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) { -    boost::unordered_map<std::string, search::Score>::const_iterator i(FindStringPiece(map, name)); -    if (i != map.end())  -      total += score * i->second; -  } -}; - -template <class Map, class Op> void Parse(StringPiece text, Map &map, Op &op) { -  for (util::TokenIter<util::SingleCharacter, true> spaces(text, ' '); spaces; ++spaces) { -    util::TokenIter<util::SingleCharacter> equals(*spaces, '='); -    UTIL_THROW_IF(!equals, WeightParseException, "Bad weight token " << *spaces); -    StringPiece name(*equals); -    UTIL_THROW_IF(!++equals, WeightParseException, "Bad weight token " << *spaces); -    char *end; -    // Assumes proper termination.   -    double value = std::strtod(equals->data(), &end); -    UTIL_THROW_IF(end != equals->data() + equals->size(), WeightParseException, "Failed to parse weight" << *equals); -    UTIL_THROW_IF(++equals, WeightParseException, "Too many equals in " << *spaces); -    op(map, name, value); -  } -} - -} // namespace - -Weights::Weights(StringPiece text) { -  Insert op; -  Parse<Map, Insert>(text, map_, op); -  lm_ = Steal("LanguageModel"); -  oov_ = Steal("OOV"); -  word_penalty_ = Steal("WordPenalty"); -} - -Weights::Weights(Score lm, Score oov, Score word_penalty) : lm_(lm), oov_(oov), word_penalty_(word_penalty) {} - -search::Score Weights::DotNoLM(StringPiece text) const { -  DotProduct dot; -  Parse<const Map, DotProduct>(text, map_, dot); -  return dot.total; -} - -float Weights::Steal(const std::string &str) { -  Map::iterator i(map_.find(str)); -  if (i == map_.end()) { -    return 0.0; -  } else { -    float ret = i->second; -    map_.erase(i); -    return ret; -  } -} - -} // namespace search diff --git a/klm/search/weights.hh b/klm/search/weights.hh deleted file mode 100644 index df1c419f..00000000 --- a/klm/search/weights.hh +++ /dev/null @@ -1,52 +0,0 @@ -// For now, the individual features are not kept.   -#ifndef SEARCH_WEIGHTS__ -#define SEARCH_WEIGHTS__ - -#include "search/types.hh" -#include "util/exception.hh" -#include "util/string_piece.hh" - -#include <boost/unordered_map.hpp> - -#include <string> - -namespace search { - -class WeightParseException : public util::Exception { -  public: -    WeightParseException() {} -    ~WeightParseException() throw() {} -}; - -class Weights { -  public: -    // Parses weights, sets lm_weight_, removes it from map_. -    explicit Weights(StringPiece text); - -    // Just the three scores we care about adding.    -    Weights(Score lm, Score oov, Score word_penalty); - -    Score DotNoLM(StringPiece text) const; - -    Score LM() const { return lm_; } - -    Score OOV() const { return oov_; } - -    Score WordPenalty() const { return word_penalty_; } - -    // Mostly for testing.   -    const boost::unordered_map<std::string, Score> &GetMap() const { return map_; } - -  private: -    float Steal(const std::string &str); - -    typedef boost::unordered_map<std::string, Score> Map; - -    Map map_; - -    Score lm_, oov_, word_penalty_; -}; - -} // namespace search - -#endif // SEARCH_WEIGHTS__ diff --git a/klm/search/weights_test.cc b/klm/search/weights_test.cc deleted file mode 100644 index 4811ff06..00000000 --- a/klm/search/weights_test.cc +++ /dev/null @@ -1,38 +0,0 @@ -#include "search/weights.hh" - -#define BOOST_TEST_MODULE WeightTest -#include <boost/test/unit_test.hpp> -#include <boost/test/floating_point_comparison.hpp> - -namespace search { -namespace { - -#define CHECK_WEIGHT(value, string) \ -  i = parsed.find(string); \ -  BOOST_REQUIRE(i != parsed.end()); \ -  BOOST_CHECK_CLOSE((value), i->second, 0.001); - -BOOST_AUTO_TEST_CASE(parse) { -  // These are not real feature weights.   -  Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); -  const boost::unordered_map<std::string, search::Score> &parsed = w.GetMap(); -  boost::unordered_map<std::string, search::Score>::const_iterator i; -  CHECK_WEIGHT(0.0, "rarity"); -  CHECK_WEIGHT(0.0, "phrase-SGT"); -  CHECK_WEIGHT(9.45117, "phrase-TGS"); -  CHECK_WEIGHT(2.33833, "lexical-SGT"); -  BOOST_CHECK(parsed.end() == parsed.find("lm")); -  BOOST_CHECK_CLOSE(3.0, w.LM(), 0.001); -  CHECK_WEIGHT(-28.3317, "lexical-TGS"); -  CHECK_WEIGHT(5.0, "glue?"); -} - -BOOST_AUTO_TEST_CASE(dot) { -  Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); -  BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0"), 0.001); -  BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0 LanguageModel=10"), 0.001); -  BOOST_CHECK_CLOSE(9.45117 * 3.0 + 28.3317 * 17.4, w.DotNoLM("rarity=5 phrase-TGS=3.0 LanguageModel=10 lexical-TGS=-17.4"), 0.001); -} - -} // namespace -} // namespace search | 
