summaryrefslogtreecommitdiff
path: root/decoder/viterbi.h
blob: 3092f6daeb0c67b88f15a6a66cdbd53fc188dcb1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
#ifndef _VITERBI_H_
#define _VITERBI_H_

#include <vector>
#include "prob.h"
#include "hg.h"
#include "tdict.h"

// NOTE(review): only declared here; the definition lives elsewhere (presumably
// viterbi.cc — TODO confirm). From the signature it returns a std::string
// report about `hg` labeled with `name` (default "forest"); the three flags
// presumably toggle inclusion of the Viterbi target string, target tree, and
// derivation tree respectively — confirm against the definition.
std::string viterbi_stats(Hypergraph const& hg, std::string const& name="forest", bool estring=true, bool etree=false, bool derivation_tree=false);

/// computes for each hg node the best (according to WeightType/WeightFunction) derivation, and some homomorphism (bottom up expression tree applied through Traversal) of it. T is the "return type" of Traversal, which is called only once for the best edge for a node's result (i.e. result will start default constructed)
//TODO: make T a typename inside Traversal and WeightType a typename inside WeightFunction?
// Traversal must implement:
//  typedef T Result;
//  void operator()(Hypergraph::Edge const& e,const vector<const Result*>& ants, Result* result) const;
// WeightFunction must implement:
//  typedef prob_t Weight;
//  Weight operator()(Hypergraph::Edge const& e) const;
template<class Traversal,class WeightFunction>
typename WeightFunction::Weight Viterbi(const Hypergraph& hg,
                   typename Traversal::Result* result,
                   const Traversal& traverse,
                   const WeightFunction& weight) {
  typedef typename Traversal::Result T;
  typedef typename WeightFunction::Weight WeightType;
  const int num_nodes = hg.nodes_.size();
  std::vector<T> vit_result(num_nodes);
  std::vector<WeightType> vit_weight(num_nodes, WeightType());

  for (int i = 0; i < num_nodes; ++i) {
    const Hypergraph::Node& cur_node = hg.nodes_[i];
    WeightType* const cur_node_best_weight = &vit_weight[i];
    T*          const cur_node_best_result = &vit_result[i];

    const unsigned num_in_edges = cur_node.in_edges_.size();
    if (num_in_edges == 0) {
      *cur_node_best_weight = WeightType(1);
      continue;
    }
    Hypergraph::Edge const* edge_best=0;
    for (unsigned j = 0; j < num_in_edges; ++j) {
      const Hypergraph::Edge& edge = hg.edges_[cur_node.in_edges_[j]];
      WeightType score = weight(edge);
      for (unsigned k = 0; k < edge.tail_nodes_.size(); ++k)
        score *= vit_weight[edge.tail_nodes_[k]];
      if (!edge_best || *cur_node_best_weight < score) {
        *cur_node_best_weight = score;
        edge_best=&edge;
      }
    }
    assert(edge_best);
    Hypergraph::Edge const& edgeb=*edge_best;
    std::vector<const T*> antsb(edgeb.tail_nodes_.size());
    for (unsigned k = 0; k < edgeb.tail_nodes_.size(); ++k)
      antsb[k] = &vit_result[edgeb.tail_nodes_[k]];
    traverse(edgeb, antsb, cur_node_best_result);
  }
  if (vit_result.empty())
    return WeightType(0);
  std::swap(*result, vit_result.back());
  return vit_weight.back();
}


/*
template<typename Traversal,typename WeightFunction>
typename WeightFunction::Weight Viterbi(const Hypergraph& hg,
                   typename Traversal::Result* result)
{
  Traversal traverse;
  WeightFunction weight;
  return Viterbi(hg,result,traverse,weight);
}

template<class Traversal,class WeightFunction=EdgeProb>
typename WeightFunction::Weight Viterbi(const Hypergraph& hg,
                   typename Traversal::Result* result,
                   Traversal const& traverse=Traversal()
  )
{
  WeightFunction weight;
  return Viterbi(hg,result,traverse,weight);
}
*/

/// Overload of Viterbi() (an overload, not a specialization) that scores every
/// edge with EdgeProb, so callers need only supply the traversal, which itself
/// defaults to a default-constructed Traversal. Returns the last node's best
/// weight as prob_t.
template<class Traversal>
prob_t Viterbi(const Hypergraph& hg,
               typename Traversal::Result* result,
               Traversal const& traverse = Traversal()) {
  EdgeProb edge_weight;
  return Viterbi(hg, result, traverse, edge_weight);
}

struct PathLengthTraversal {
  typedef int Result;
  void operator()(const Hypergraph::Edge& edge,
                  const std::vector<const int*>& ants,
                  int* result) const {
    (void) edge;
    *result = 1;
    for (unsigned i = 0; i < ants.size(); ++i) *result += *ants[i];
  }
};

struct ESentenceTraversal {
  typedef std::vector<WordID> Result;
  void operator()(const Hypergraph::Edge& edge,
                  const std::vector<const Result*>& ants,
                  Result* result) const {
    edge.rule_->ESubstitute(ants, result);
  }
};

/// Target-side length of the best derivation: the rule's ELength() minus its
/// Arity() (presumably the nonterminal slots, which the children fill in), plus
/// the lengths computed for each antecedent.
struct ELengthTraversal {
  typedef int Result;
  void operator()(const Hypergraph::Edge& edge,
                  const std::vector<const int*>& ants,
                  int* result) const {
    *result = edge.rule_->ELength() - edge.rule_->Arity();
    for (std::vector<const int*>::const_iterator it = ants.begin();
         it != ants.end(); ++it)
      *result += **it;
  }
};

struct FSentenceTraversal {
  typedef std::vector<WordID> Result;
  void operator()(const Hypergraph::Edge& edge,
                  const std::vector<const Result*>& ants,
                  Result* result) const {
    edge.rule_->FSubstitute(ants, result);
  }
};

// Creates target-side bracketed tree strings of the form
//   (S (X the man) (X said (X he (X would (X go)))))
// For each winning edge, the target yield is built with ESubstitute and the LHS
// label is looked up via TD::Convert of the negated LHS id. Nonterminal ids
// appear to be stored negated, which is inferred from the negation. Unless the
// label is "Goal", the yield is wrapped as left + label + space + yield + right
// and re-tokenized with TD::ConvertSentence. The Goal node passes its yield
// through unbracketed.
// NOTE(review): ConvertSentence presumably splits on whitespace, so "(" stays
// glued to the label and ")" to the last word.
struct ETreeTraversal {
  ETreeTraversal() : left("("), space(" "), right(")") {}
  const std::string left;
  const std::string space;
  const std::string right;
  typedef std::vector<WordID> Result;
  void operator()(const Hypergraph::Edge& edge,
                  const std::vector<const Result*>& ants,
                  Result* result) const {
    Result yield;
    edge.rule_->ESubstitute(ants, &yield);
    const std::string lhs_label = TD::Convert(-edge.rule_->GetLHS());
    if (lhs_label != "Goal") {
      std::string bracketed = left;
      bracketed += lhs_label;
      bracketed += space;
      bracketed += TD::GetString(yield);
      bracketed += right;
      TD::ConvertSentence(bracketed, result);
      return;
    }
    result->swap(yield);
  }
};

// Source-side counterpart of ETreeTraversal. It builds bracketed tree strings
// such as (S (X ...) ...) from each winning edge's FSubstitute yield. The LHS
// label is looked up via TD::Convert of the negated LHS id. The "Goal" node
// passes its yield through; every other node is wrapped as
// left + label + space + yield + right and re-tokenized with
// TD::ConvertSentence.
struct FTreeTraversal {
  FTreeTraversal() : left("("), space(" "), right(")") {}
  const std::string left;
  const std::string space;
  const std::string right;
  typedef std::vector<WordID> Result;
  void operator()(const Hypergraph::Edge& edge,
                  const std::vector<const Result*>& ants,
                  Result* result) const {
    Result yield;
    edge.rule_->FSubstitute(ants, &yield);
    const std::string lhs_label = TD::Convert(-edge.rule_->GetLHS());
    if (lhs_label == "Goal") {
      result->swap(yield);
    } else {
      const std::string bracketed =
          left + lhs_label + space + TD::GetString(yield) + right;
      TD::ConvertSentence(bracketed, result);
    }
  }
};

/// Collects the edges of the best derivation in bottom-up order: the
/// antecedents' edge lists are concatenated left to right, then the winning
/// edge itself is appended last.
struct ViterbiPathTraversal {
  typedef std::vector<Hypergraph::Edge const*> Result;
  void operator()(const Hypergraph::Edge& edge,
                  std::vector<Result const*> const& ants,
                  Result* result) const {
    for (std::vector<Result const*>::const_iterator child = ants.begin();
         child != ants.end(); ++child)
      result->insert(result->end(), (*child)->begin(), (*child)->end());
    result->push_back(&edge);
  }
};

struct FeatureVectorTraversal {
  typedef FeatureVector Result;
  void operator()(Hypergraph::Edge const& edge,
                  std::vector<Result const*> const& ants,
                  Result* result) const {
    for (unsigned i = 0; i < ants.size(); ++i)
      *result+=*ants[i];
    *result+=edge.feature_values_;
  }
};


// Convenience entry points declared here and defined elsewhere (presumably
// viterbi.cc — TODO confirm). The one-line descriptions below are inferred
// from the names and the traversals above — NOTE(review): verify against the
// definitions.

// Presumably renders hg in the format used by Joshua's visualization tooling.
std::string JoshuaVisualizationString(const Hypergraph& hg);
// Presumably the best target (E) sentence into *result, returning its probability.
prob_t ViterbiESentence(const Hypergraph& hg, std::vector<WordID>* result);
// Presumably the best derivation as a bracketed target-side tree string.
std::string ViterbiETree(const Hypergraph& hg);
// Presumably the best source (F) sentence into *result, returning its probability.
prob_t ViterbiFSentence(const Hypergraph& hg, std::vector<WordID>* result);
// Presumably the best derivation as a bracketed source-side tree string.
std::string ViterbiFTree(const Hypergraph& hg);
// Presumably the target-side length of the best derivation.
int ViterbiELength(const Hypergraph& hg);
// Presumably the number of edges in the best derivation.
int ViterbiPathLength(const Hypergraph& hg);

/// Returns the features summed over all edges in the Viterbi derivation.
/// If weights is non-null, also checks that the Viterbi probability agrees with
/// features.dot(*weights): a disagreement throws if fatal_dotprod_disagreement,
/// otherwise a warning is written to cerr.
FeatureVector ViterbiFeatures(Hypergraph const& hg,WeightVector const* weights=0,bool fatal_dotprod_disagreement=false);

#endif