summaryrefslogtreecommitdiff
path: root/decoder/ff_tagger.cc
blob: 019315a2168ddfdb69da3e0419c977ddf7c21a54 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#include "ff_tagger.h"

#include <cstdlib>
#include <iostream>
#include <sstream>

#include "tdict.h"
#include "sentence_metadata.h"
#include "stringlib.h"

using namespace std;

// Builds the bigram tagger feature. The base class is told that the
// per-node state occupies sizeof(WordID) bytes. Passing "no_uni" (compared
// after lowercasing, so any letter case is accepted) sets no_uni_, which
// makes FireFeature skip unigram features.
Tagger_BigramIndicator::Tagger_BigramIndicator(const std::string& param) :
  FeatureFunction(sizeof(WordID)) {
  const bool suppress_unigrams = (LowercaseString(param) == "no_uni");
  no_uni_ = suppress_unigrams;
}

void Tagger_BigramIndicator::FireFeature(const WordID& left,
                                 const WordID& right,
                                 SparseVector<double>* features) const {
  if (no_uni_ && right == 0) return;
  int& fid = fmap_[left][right];
  if (!fid) {
    ostringstream os;
    if (right == 0) {
      os << "Uni:" << TD::Convert(left);
    } else {
      os << "Bi:";
      if (left < 0) { os << "BOS"; } else { os << TD::Convert(left); }
      os << '_';
      if (right < 0) { os << "EOS"; } else { os << TD::Convert(right); }
    }
    fid = FD::Convert(os.str());
  }
  features->set_value(fid, 1.0);
}

// Scores one hyperedge and writes this node's state (a single WordID) into
// `context`.
//
//   arity 0: the state is the first target symbol of the rule; a unigram
//            feature is fired for it (FireFeature with right == 0).
//   arity 1: the state of the only antecedent is copied through unchanged.
//   arity 2: a bigram feature is fired for (left state, right state) and the
//            right state becomes this node's state. In addition a
//            (-1, left) feature is fired when the edge spans [0, 2) and a
//            (right, -1) feature when it spans [0, source length) -- these
//            are rendered as BOS/EOS bigrams by FireFeature.
//
// Any other arity is a programming error. The original code only guarded
// this with assert(), which disappears under NDEBUG and would leave
// `context` unwritten; the failure is now reported and the process aborted
// in every build type.
//
// estimated_features is not used.
void Tagger_BigramIndicator::TraversalFeaturesImpl(const SentenceMetadata& smeta,
                                     const Hypergraph::Edge& edge,
                                     const std::vector<const void*>& ant_contexts,
                                     SparseVector<double>* features,
                                     SparseVector<double>* estimated_features,
                                     void* context) const {
  WordID& out_context = *static_cast<WordID*>(context);
  const int arity = edge.Arity();
  switch (arity) {
    case 0:
      out_context = edge.rule_->e_[0];
      FireFeature(out_context, 0, features);
      break;
    case 1:
      out_context = *static_cast<const WordID*>(ant_contexts[0]);
      break;
    case 2: {
      const WordID left = *static_cast<const WordID*>(ant_contexts[0]);
      const WordID right = *static_cast<const WordID*>(ant_contexts[1]);
      if (edge.i_ == 0 && edge.j_ == 2)
        FireFeature(-1, left, features);
      FireFeature(left, right, features);
      if (edge.i_ == 0 && edge.j_ == smeta.GetSourceLength())
        FireFeature(right, -1, features);
      out_context = right;
      break;
    }
    default:
      // Fail loudly even when assertions are compiled out.
      cerr << "Tagger_BigramIndicator: unexpected edge arity " << arity << endl;
      abort();
  }
}

// Passes the sentence metadata straight through to the FactoredLexiconHelper
// held in lexmap_; this class keeps no per-sentence state of its own here.
// NOTE(review): presumably invoked once per input sentence -- confirm
// against the caller.
void LexicalPairIndicator::PrepareForInput(const SentenceMetadata& smeta) {
  lexmap_->PrepareForInput(smeta);
}

// Configures the lexical pair feature from a whitespace-separated parameter
// string.
//
//   ""                              -> feature prefix "Id" and a
//                                      default-constructed lexicon helper.
//   "<name> <corpus.src> <trgmap>"  -> feature prefix <name>; the two file
//                                      arguments are handed to
//                                      FactoredLexiconHelper.
//
// Any other number of fields prints a usage message to stderr and aborts.
LexicalPairIndicator::LexicalPairIndicator(const std::string& param) {
  name_ = "Id";
  if (param.empty()) {
    lexmap_.reset(new FactoredLexiconHelper);
    return;
  }

  // Expected layout: name corpus.f emap.txt
  vector<string> fields;
  SplitOnWhitespace(param, &fields);
  if (fields.size() != 3) {
    cerr << "LexicalPairIndicator takes 3 parameters: <name> <corpus.src.txt> <trgmap.txt>\n"
         << " * may be used for corpus.src.txt or trgmap.txt to use surface forms\n"
         << " Received: " << param << endl;
    abort();
  }
  name_ = fields[0];
  lexmap_.reset(new FactoredLexiconHelper(fields[1], fields[2]));
}

// Sets the indicator feature "<name_>:<src>:<trg>" to 1.0.
// Feature ids are memoised in fmap_ keyed by (src, trg); a stored id of 0
// means the feature has not been registered yet (assumes FD::Convert never
// returns 0 for a real feature -- TODO confirm).
void LexicalPairIndicator::FireFeature(WordID src,
                                      WordID trg,
                                      SparseVector<double>* features) const {
  int& feature_id = fmap_[src][trg];
  if (feature_id == 0) {
    ostringstream label;
    label << name_
          << ':' << TD::Convert(src)
          << ':' << TD::Convert(trg);
    feature_id = FD::Convert(label.str());
  }
  features->set_value(feature_id, 1.0);
}

// Fires a source/target pair feature on terminal (arity 0) edges only;
// edges with antecedents contribute nothing.
//
// The source word is looked up through lexmap_ at position edge.i_, and the
// rule's single target word is mapped through
// lexmap_->CoarsenedTargetWordForTarget before the pair is fired. The rule
// is asserted to have exactly one target symbol.
//
// smeta, ant_contexts, estimated_features and context are unused.
void LexicalPairIndicator::TraversalFeaturesImpl(const SentenceMetadata& smeta,
                                     const Hypergraph::Edge& edge,
                                     const std::vector<const void*>& ant_contexts,
                                     SparseVector<double>* features,
                                     SparseVector<double>* estimated_features,
                                     void* context) const {
  if (edge.Arity() != 0) return;

  const WordID source_word = lexmap_->SourceWordAtPosition(edge.i_);
  const vector<WordID>& target_side = edge.rule_->e_;
  assert(target_side.size() == 1);
  const WordID coarse_target =
      lexmap_->CoarsenedTargetWordForTarget(target_side[0]);
  FireFeature(source_word, coarse_target, features);
}

// Constructs the output-word indicator feature. The parameter string is
// accepted but not used: the body is empty.
OutputIndicator::OutputIndicator(const std::string& param) {}

// Sets the indicator feature "T:<word>" for target word trg to 1.0.
//
// The words "=", ";" and "," are replaced by "__EQ", "__SC" and "__CO" in
// the feature name (presumably because those characters are special in the
// weights/feature file syntax -- confirm). The replacement table is a
// function-local static filled on first use.
//
// Feature ids are memoised in fmap_ under the original (unescaped) word id;
// a stored id of 0 means "not registered yet".
void OutputIndicator::FireFeature(WordID trg,
                                 SparseVector<double>* features) const {
  int& feature_id = fmap_[trg];
  if (feature_id == 0) {
    static map<WordID, WordID> escape;
    if (escape.empty()) {
      escape[TD::Convert("=")] = TD::Convert("__EQ");
      escape[TD::Convert(";")] = TD::Convert("__SC");
      escape[TD::Convert(",")] = TD::Convert("__CO");
    }
    // Single lookup: swap in the escaped id when one exists.
    const map<WordID, WordID>::const_iterator hit = escape.find(trg);
    if (hit != escape.end()) trg = hit->second;

    ostringstream label;
    label << "T:" << TD::Convert(trg);
    feature_id = FD::Convert(label.str());
  }
  features->set_value(feature_id, 1.0);
}

// Fires an output indicator (see FireFeature) for every symbol on the target
// side of the edge's rule whose id is strictly positive. Symbols with
// id <= 0 are skipped (presumably non-terminal placeholders -- confirm
// against the rule representation).
//
// smeta, ant_contexts, estimated_features and context are unused.
void OutputIndicator::TraversalFeaturesImpl(const SentenceMetadata& smeta,
                                     const Hypergraph::Edge& edge,
                                     const std::vector<const void*>& ant_contexts,
                                     SparseVector<double>* features,
                                     SparseVector<double>* estimated_features,
                                     void* context) const {
  const vector<WordID>& target_side = edge.rule_->e_;
  // Iterate with iterators rather than an int index so that no signed value
  // is compared against the unsigned size().
  for (vector<WordID>::const_iterator sym = target_side.begin();
       sym != target_side.end(); ++sym) {
    if (*sym > 0) FireFeature(*sym, features);
  }
}