summaryrefslogtreecommitdiff
path: root/klm/search/nbest.cc
blob: ec3322c978b6415fac2eab91bbc2b67473f3716c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#include "search/nbest.hh"

#include "util/pool.hh"

#include <algorithm>
#include <functional>
#include <queue>

#include <assert.h>
#include <math.h>

namespace search {

// Seed the priority queue with at most `keep` partial edges: the greatest
// ones according to std::greater<PartialEdge>.  When `partials` holds more
// than `keep` entries it is reordered in place by std::nth_element so that
// the selected entries occupy its prefix; otherwise it is left untouched and
// every entry is seeded.  Each seed gets its own storage from `entry_pool`.
// NOTE(review): keep == 0 would leave the queue empty, and
// TopAfterConstructor does not guard against that -- confirm callers never
// pass 0.
NBestList::NBestList(std::vector<PartialEdge> &partials, util::Pool &entry_pool, std::size_t keep) {
  assert(!partials.empty());
  std::vector<PartialEdge>::iterator stop = partials.end();
  if (keep < partials.size()) {
    stop = partials.begin() + keep;
    // Partial selection only: everything before `stop` compares >= everything
    // after it, but neither side is sorted.
    std::nth_element(partials.begin(), stop, partials.end(), std::greater<PartialEdge>());
  }
  for (std::vector<PartialEdge>::const_iterator edge = partials.begin(); edge != stop; ++edge) {
    queue_.push(QueueEntry(entry_pool.Allocate(QueueEntry::Size(edge->GetArity())), *edge));
  }
}

// Score of the best entry currently queued.  Only valid before anything has
// been revealed, i.e. directly after construction.
Score NBestList::TopAfterConstructor() const {
  assert(revealed_.empty());
  QueueEntry best(queue_.top());
  return best.GetScore();
}

// Reveal entries one at a time until `n` are available or the queue runs dry.
// Returns every revealed entry, which may be fewer than `n` (and may be more,
// if earlier calls already revealed more).
const std::vector<Applied> &NBestList::Extract(util::Pool &pool, std::size_t n) {
  while (!queue_.empty()) {
    if (revealed_.size() >= n) break;
    MoveTop(pool);
  }
  return revealed_;
}

// Score difference between entry `index + 1` and entry `index`, or -INFINITY
// when no entry `index + 1` exists.  `index` may be at most revealed_.size();
// if it equals revealed_.size(), the entry at `index` is revealed first via
// MoveTop.
Score NBestList::Visit(util::Pool &pool, std::size_t index) {
  const std::size_t next = index + 1;
  // Both entries already revealed: answer from the revealed list alone.
  if (next < revealed_.size()) {
    return revealed_[next].GetScore() - revealed_[index].GetScore();
  }
  if (!queue_.empty()) {
    if (next != revealed_.size()) {
      // `index` itself has not been revealed yet; reveal it now.
      assert(index == revealed_.size());
      MoveTop(pool);
      if (queue_.empty()) return -INFINITY;
    }
    // The successor of `index` is whatever is now at the top of the queue.
    return queue_.top().GetScore() - revealed_[index].GetScore();
  }
  return -INFINITY;
}

// Return revealed entry `index`.  Callers may request an entry that is already
// revealed, or exactly the next one; in the latter case it is revealed here.
Applied NBestList::Get(util::Pool &pool, std::size_t index) {
  assert(index <= revealed_.size());
  if (revealed_.size() == index) {
    MoveTop(pool);
  }
  return revealed_[index];
}

// Pop the best queued entry and append it to revealed_.  Before that, push
// successor entries onto the queue: each successor copies the popped entry's
// children but advances one child to that child's next entry, scoring it as
// the popped score plus the delta reported by the child's Visit.  The popped
// entry's children array is then rewritten in place from RevealedRef values
// into Applied values.
// Precondition: queue_ is non-empty.
void NBestList::MoveTop(util::Pool &pool) {
  assert(!queue_.empty());
  QueueEntry entry(queue_.top());
  queue_.pop();
  RevealedRef *const children_begin = entry.Children();
  RevealedRef *const children_end = children_begin + entry.GetArity();
  Score basis = entry.GetScore();
  for (RevealedRef *child = children_begin; child != children_end; ++child) {
    // Delta from this child's current entry to its next one; -INFINITY means
    // the child has no further entries, so no successor is generated.
    Score change = child->in_->Visit(pool, child->index_);
    if (change != -INFINITY) {
      // Advancing a child is expected not to improve the score (small slack,
      // presumably for floating-point error).
      assert(change < 0.001);
      // Successor: same children as `entry`, except this child moves to
      // index_ + 1.  Storage comes from the caller-supplied pool.
      QueueEntry new_entry(pool.Allocate(QueueEntry::Size(entry.GetArity())), basis + change, entry.GetArity(), entry.GetNote());
      std::copy(children_begin, child, new_entry.Children());
      RevealedRef *update = new_entry.Children() + (child - children_begin);
      update->in_ = child->in_;
      update->index_ = child->index_ + 1;
      std::copy(child + 1, children_end, update + 1);
      queue_.push(new_entry);
    }
    // Gesmundo, A. and Henderson, J. Faster Cube Pruning, IWSLT 2010.
    // Stop after the first child whose index is nonzero: only children up to
    // and including it are advanced.  Presumably (assuming seeds start with
    // every child at index 0) this makes each index combination reachable
    // from exactly one predecessor, so no duplicates are queued.
    if (child->index_) break;
  }

  // Convert QueueEntry to Applied.  This leaves some unused memory.
  // The children array is reused in place: element i is read as a RevealedRef
  // (copied into `from` before it is overwritten) and then replaced by the
  // Applied returned by that child's Get, which may itself reveal entries in
  // the child list.
  // NOTE(review): assumes sizeof(Applied) <= sizeof(RevealedRef), so writing
  // Applied i never clobbers a RevealedRef j > i that is still unread --
  // confirm against the header.
  void *overwrite = entry.Children();
  for (unsigned int i = 0; i < entry.GetArity(); ++i) {
    RevealedRef from(*(static_cast<const RevealedRef*>(overwrite) + i));
    *(static_cast<Applied*>(overwrite) + i) = from.in_->Get(pool, from.index_);
  }
  // The revealed Applied is built from the same storage as the popped entry.
  revealed_.push_back(Applied(entry.Base()));
}

// Build an NBestList over `partials` (keeping at most config_.keep of them)
// and package it with the shared completed state and the best score.
// `partials` may be reordered by the NBestList constructor.
NBestComplete NBest::Complete(std::vector<PartialEdge> &partials) {
  assert(!partials.empty());
  NBestList *const created = list_pool_.construct(partials, entry_pool_, config_.keep);
  const Score best = created->TopAfterConstructor();
  // Every partial edge carries the same completed state, so the front one
  // stands in for all of them.
  return NBestComplete(created, partials.front().CompletedState(), best);
}

// Reveal up to config_.size entries from the list behind `history`.
// Presumably `history` is the NBestList pointer handed out by Complete();
// confirm with callers.
const std::vector<Applied> &NBest::Extract(History history) {
  NBestList *const list = static_cast<NBestList*>(history);
  return list->Extract(entry_pool_, config_.size);
}

} // namespace search