summary refs log tree commit diff
path: root/klm/search/edge_generator.cc
blob: 56239dfbba2d3e17f61059ca5111b6c1c03275b4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#include "search/edge_generator.hh"

#include "lm/left.hh"
#include "lm/partial.hh"
#include "search/context.hh"
#include "search/vertex.hh"
#include "search/vertex_generator.hh"

#include <cassert>
#include <cstring>
#include <new>
#include <numeric>

namespace search {

// Seed the generator with a single root hypothesis.
//
// Only the address of root is kept (it is pushed onto the queue), so root
// must stay alive for as long as this generator may pop it.  Pop() can hand
// a popped edge back to the pool it is given, so root should come from that
// same pool.  NOTE(review): confirm against the callers.
//
// root.nt is not filled in here; an earlier revision derived it from the
// edge's vertices.  NOTE(review): presumably callers populate it now.
EdgeGenerator::EdgeGenerator(PartialEdge &root, unsigned char arity, Note note)
    : arity_(arity), note_(note) {
  // With exactly one entry queued, the best available score is the root's.
  top_score_ = root.score;
  generate_.push(&root);
}

namespace {

template <class Model> float FastScore(const Context<Model> &context, unsigned char victim, unsigned char arity, const PartialEdge &previous, PartialEdge &update) {
  memcpy(update.between, previous.between, sizeof(lm::ngram::ChartState) * (arity + 1));

  float ret = 0.0;
  lm::ngram::ChartState *before, *after;
  if (victim == 0) {
    before = &update.between[0];
    after = &update.between[(arity == 2 && previous.nt[1].Complete()) ? 2 : 1];
  } else {
    assert(victim == 1);
    assert(arity == 2);
    before = &update.between[previous.nt[0].Complete() ? 0 : 1];
    after = &update.between[2];
  }
  const lm::ngram::ChartState &previous_reveal = previous.nt[victim].State();
  const PartialVertex &update_nt = update.nt[victim];
  const lm::ngram::ChartState &update_reveal = update_nt.State();
  float just_after = 0.0;
  if ((update_reveal.left.length > previous_reveal.left.length) || (update_reveal.left.full && !previous_reveal.left.full)) {
    just_after += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length);
  }
  if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous.nt[victim].RightFull())) {
    ret += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right);
  }
  if (update_nt.Complete()) {
    if (update_reveal.left.full) {
      before->left.full = true;
    } else {
      assert(update_reveal.left.length == update_reveal.right.length);
      ret += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length);
    }
    if (victim == 0) {
      update.between[0].right = after->right;
    } else {
      update.between[2].left = before->left;
    }
  }
  return previous.score + (ret + just_after) * context.GetWeights().LM();
}

} // namespace

// Take one step of the search.
//
// Pops the best partial edge from the queue.  If every non-terminal in it is
// complete, finalizes top.between[0] as the full edge state, updates
// top_score_ (to -kScoreInf when the queue is now empty) and returns the
// edge.  Otherwise the incomplete non-terminal with the smallest Length() is
// split: a continuation edge is allocated from partial_edge_pool, scored and
// queued; the popped edge is re-queued as the alternate if the split left
// one, or released back to partial_edge_pool if not.  In that case NULL is
// returned and the caller should call Pop again.
//
// Throws std::bad_alloc if partial_edge_pool cannot supply memory for the
// continuation; the queue is restored to its state before the call.
template <class Model> PartialEdge *EdgeGenerator::Pop(Context<Model> &context, boost::pool<> &partial_edge_pool) {
  assert(!generate_.empty());
  PartialEdge &top = *generate_.top();
  generate_.pop();
  // Choose the incomplete non-terminal with the shortest length to expand.
  // 255 doubles as the "nothing incomplete" sentinel.
  unsigned int victim = 0;
  unsigned char lowest_length = 255;
  for (unsigned char i = 0; i != arity_; ++i) {
    if (!top.nt[i].Complete() && top.nt[i].Length() < lowest_length) {
      lowest_length = top.nt[i].Length();
      victim = i;
    }
  }
  if (lowest_length == 255) {
    // All states report complete.
    top.between[0].right = top.between[arity_].right;
    // Now top.between[0] is the full edge state.
    top_score_ = generate_.empty() ? -kScoreInf : generate_.top()->score;
    return &top;
  }

  // boost::pool<>::malloc returns 0 when out of memory; binding a reference
  // to that would be undefined behavior.  Put top back first so that the
  // generator is left exactly as it was before this call.
  void *raw = partial_edge_pool.malloc();
  if (!raw) {
    generate_.push(&top);
    throw std::bad_alloc();
  }

  unsigned int stay = !victim;
  PartialEdge &continuation = *static_cast<PartialEdge*>(raw);
  float old_bound = top.nt[victim].Bound();
  // The alternate's score will change because alternate.nt[victim] changes.
  bool split = top.nt[victim].Split(continuation.nt[victim]);
  // top is now the alternate.

  continuation.nt[stay] = top.nt[stay];
  // Must run before top.score is adjusted below: FastScore reads top.score
  // as the score of the edge prior to the split.
  continuation.score = FastScore(context, victim, arity_, top, continuation);
  // TODO: dedupe?
  generate_.push(&continuation);

  if (split) {
    // We have an alternate.
    top.score += top.nt[victim].Bound() - old_bound;
    // TODO: dedupe?
    generate_.push(&top);
  } else {
    partial_edge_pool.free(&top);
  }

  // The queue holds at least the continuation, so top() is valid here.
  top_score_ = generate_.top()->score;
  return NULL;
}

// Explicit instantiations of Pop, one per language model type listed below.
// The template is defined in this translation unit, so these lines are what
// emit the code for each model type here.
// NOTE(review): presumably other files link against exactly these; a new
// model type would then need its own line -- confirm against the header.
template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context, boost::pool<> &partial_edge_pool);
template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context, boost::pool<> &partial_edge_pool);
template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::TrieModel> &context, boost::pool<> &partial_edge_pool);
template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::QuantTrieModel> &context, boost::pool<> &partial_edge_pool);
template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::ArrayTrieModel> &context, boost::pool<> &partial_edge_pool);
template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::QuantArrayTrieModel> &context, boost::pool<> &partial_edge_pool);

} // namespace search