summaryrefslogtreecommitdiff
path: root/decoder/ff_tagger.cc
blob: 05de8ba3ae47b6c8f6e437394f56b66c9a9d891c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#include "ff_tagger.h"

#include "tdict.h"
#include "sentence_metadata.h"

#include <sstream>

using namespace std;

// Tag-bigram identity feature. Each hypergraph node carries sizeof(WordID)
// bytes of state, in which TraversalFeaturesImpl stores a single WordID.
// The configuration string is accepted for interface uniformity but unused.
Tagger_BigramIdentity::Tagger_BigramIdentity(const std::string& /* param */)
    : FeatureFunction(sizeof(WordID)) {}

void Tagger_BigramIdentity::FireFeature(const WordID& left,
                                 const WordID& right,
                                 SparseVector<double>* features) const {
  int& fid = fmap_[left][right];
  if (!fid) {
    ostringstream os;
    if (right == 0) {
      os << "Uni:" << TD::Convert(left);
    } else {
      os << "Bi:";
      if (left < 0) { os << "BOS"; } else { os << TD::Convert(left); }
      os << '_';
      if (right < 0) { os << "EOS"; } else { os << TD::Convert(right); }
    }
    fid = FD::Convert(os.str());
  }
  features->set_value(fid, 1.0);
}

// Fires tag n-gram indicator features for one hypergraph edge and writes the
// edge's state (a single WordID) into *context.
//
// - Arity 0: the edge carries one tag (edge.rule_->e_[0]). It fires that
//   tag's unigram, and the tag becomes the state.
// - Arity 2: the edge joins two antecedents whose states are `left` and
//   `right`. It fires the (left, right) bigram, plus the boundary bigrams:
//   BOS (encoded as -1) before `left` when the edge spans [0, 2), and EOS
//   (-1) after `right` when the edge spans the whole source sentence. The
//   right antecedent's state is propagated as this edge's state.
//
// An arity-0 edge that by itself spans the whole source (e.g. a one-word
// sentence) is already a complete tagging and has no arity-2 parent.
// Without the explicit handling below, the BOS/EOS bigrams would never fire
// for such inputs, unlike for every multi-word sentence.
//
// NOTE(review): edges of any other arity fire nothing and leave *context
// unwritten. Presumably the tagger hypergraph never contains them; confirm.
void Tagger_BigramIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta,
                                     const Hypergraph::Edge& edge,
                                     const std::vector<const void*>& ant_contexts,
                                     SparseVector<double>* features,
                                     SparseVector<double>* estimated_features,
                                     void* context) const {
  WordID& out_context = *static_cast<WordID*>(context);
  const int arity = edge.Arity();
  // True when this edge covers the entire source, i.e. it completes a tagging
  // and is where the EOS bigram belongs.
  const bool spans_sentence =
      (edge.i_ == 0 && edge.j_ == smeta.GetSourceLength());
  if (arity == 0) {
    const WordID tag = edge.rule_->e_[0];
    out_context = tag;
    FireFeature(tag, 0, features);
    if (spans_sentence) {
      // Single-edge complete tagging: emit the boundary bigrams that an
      // arity-2 parent would otherwise have provided.
      FireFeature(-1, tag, features);
      FireFeature(tag, -1, features);
    }
  } else if (arity == 2) {
    const WordID left = *static_cast<const WordID*>(ant_contexts[0]);
    const WordID right = *static_cast<const WordID*>(ant_contexts[1]);
    // The [0, 2) edge is the first to see the sentence-initial tag in a
    // bigram context.
    if (edge.i_ == 0 && edge.j_ == 2)
      FireFeature(-1, left, features);
    FireFeature(left, right, features);
    if (spans_sentence)
      FireFeature(right, -1, features);
    out_context = right;
  }
}

// Stateless (source word, target word) identity feature. The configuration
// string is accepted for interface uniformity but unused.
LexicalPairIdentity::LexicalPairIdentity(const std::string& /* param */) {}

// Sets the indicator feature "Id:<src>:<trg>" to 1.0 in *features.
// Feature ids are memoized per (src, trg) in fmap_; 0 means not yet resolved.
// Before the name is built, the tokens "=", ";" and "," are replaced by
// "__EQ", "__SC" and "__CO" respectively. Presumably this keeps them out of
// feature names where they could clash with delimiters in feature/weight
// files; confirm against the file formats.
void LexicalPairIdentity::FireFeature(WordID src,
                                      WordID trg,
                                      SparseVector<double>* features) const {
  int& feature_id = fmap_[src][trg];
  if (feature_id == 0) {
    // Populated on first use; the TD::Convert calls also intern these tokens.
    static std::map<WordID, WordID> escape;
    if (escape.empty()) {
      escape[TD::Convert("=")] = TD::Convert("__EQ");
      escape[TD::Convert(";")] = TD::Convert("__SC");
      escape[TD::Convert(",")] = TD::Convert("__CO");
    }
    std::map<WordID, WordID>::iterator hit = escape.find(src);
    const WordID src_word = (hit == escape.end()) ? src : hit->second;
    hit = escape.find(trg);
    const WordID trg_word = (hit == escape.end()) ? trg : hit->second;

    std::ostringstream name;
    name << "Id:" << TD::Convert(src_word) << ':' << TD::Convert(trg_word);
    feature_id = FD::Convert(name.str());
  }
  features->set_value(feature_id, 1.0);
}

// Fires one (source word, target word) identity feature for every pairing of
// a target-side id and a source-side id in the edge's rule. Ids <= 0 on
// either side are skipped (negative ids presumably denote nonterminals).
// Iteration is target-major, then source. The feature keeps no state.
void LexicalPairIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta,
                                     const Hypergraph::Edge& edge,
                                     const std::vector<const void*>& ant_contexts,
                                     SparseVector<double>* features,
                                     SparseVector<double>* estimated_features,
                                     void* context) const {
  typedef std::vector<WordID>::const_iterator WordIter;
  const std::vector<WordID>& target = edge.rule_->e_;
  const std::vector<WordID>& source = edge.rule_->f_;
  for (WordIter e = target.begin(); e != target.end(); ++e) {
    if (*e <= 0) continue;
    for (WordIter f = source.begin(); f != source.end(); ++f) {
      if (*f > 0) FireFeature(*f, *e, features);
    }
  }
}

// Stateless target-word identity feature. The configuration string is
// accepted for interface uniformity but unused.
OutputIdentity::OutputIdentity(const std::string& /* param */) {}

// Sets the indicator feature "T:<trg>" to 1.0 in *features.
// Feature ids are memoized per target word in fmap_; 0 means not yet
// resolved. The tokens "=", ";" and "," are replaced by "__EQ", "__SC" and
// "__CO" before the name is built. Presumably this keeps them out of feature
// names where they could clash with file-format delimiters; confirm.
void OutputIdentity::FireFeature(WordID trg,
                                 SparseVector<double>* features) const {
  int& feature_id = fmap_[trg];
  if (feature_id == 0) {
    // Populated on first use; the TD::Convert calls also intern these tokens.
    static std::map<WordID, WordID> escape;
    if (escape.empty()) {
      escape[TD::Convert("=")] = TD::Convert("__EQ");
      escape[TD::Convert(";")] = TD::Convert("__SC");
      escape[TD::Convert(",")] = TD::Convert("__CO");
    }
    const std::map<WordID, WordID>::iterator hit = escape.find(trg);
    const WordID word = (hit == escape.end()) ? trg : hit->second;

    std::ostringstream name;
    name << "T:" << TD::Convert(word);
    feature_id = FD::Convert(name.str());
  }
  features->set_value(feature_id, 1.0);
}

// Fires the target-word identity feature for every target-side id > 0 in the
// edge's rule, in order. Ids <= 0 are skipped (negative ids presumably denote
// nonterminals). The feature keeps no state.
void OutputIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta,
                                     const Hypergraph::Edge& edge,
                                     const std::vector<const void*>& ant_contexts,
                                     SparseVector<double>* features,
                                     SparseVector<double>* estimated_features,
                                     void* context) const {
  const std::vector<WordID>& target = edge.rule_->e_;
  for (std::vector<WordID>::const_iterator it = target.begin();
       it != target.end(); ++it) {
    if (*it > 0) FireFeature(*it, features);
  }
}