summaryrefslogtreecommitdiff
path: root/decoder/ff_tagger.cc
blob: 21d0f8120ea1d557db134196922d91d4dcf2f461 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#include "ff_tagger.h"

#include <sstream>

#include "tdict.h"
#include "sentence_metadata.h"
#include "stringlib.h"

using namespace std;

// Builds the bigram tagger feature. `param` is accepted but never read.
// sizeof(WordID) is forwarded to the FeatureFunction base constructor;
// TraversalFeaturesImpl reinterprets every context pointer as a single
// WordID, so this is presumably the per-node state size (confirm against
// the FeatureFunction declaration).
Tagger_BigramIdentity::Tagger_BigramIdentity(const std::string& param)
    : FeatureFunction(sizeof(WordID)) {
}

// Sets the indicator feature for the tag pair (left, right) to 1.0.
//
// Feature ids are memoised in fmap_[left][right]; a stored id of 0 is treated
// as "not yet created", in which case the name is built and registered:
//   right == 0  ->  "Uni:<left>"
//   otherwise   ->  "Bi:<left>_<right>", where a negative left is printed as
//                   "BOS" and a negative right as "EOS".
void Tagger_BigramIdentity::FireFeature(const WordID& left,
                                 const WordID& right,
                                 SparseVector<double>* features) const {
  int& cached_id = fmap_[left][right];
  if (cached_id == 0) {
    std::ostringstream name;
    if (right != 0) {
      name << "Bi:";
      if (left < 0) {
        name << "BOS";
      } else {
        name << TD::Convert(left);
      }
      name << '_';
      if (right < 0) {
        name << "EOS";
      } else {
        name << TD::Convert(right);
      }
    } else {
      // A zero right-hand tag is how callers request the unigram feature.
      name << "Uni:" << TD::Convert(left);
    }
    cached_id = FD::Convert(name.str());
  }
  features->set_value(cached_id, 1.0);
}

// Scores one hypergraph edge and records the tag exposed by this node.
//
//  * Arity 0: the first target symbol of the edge's rule becomes this node's
//    state and its unigram feature fires.
//  * Arity 2: the two antecedent states are read as (left, right) tags and
//    their bigram feature fires. A "BOS" bigram is added when the edge spans
//    [0, 2), and an "EOS" bigram when it spans [0, source length). The right
//    tag becomes this node's state.
//  * Any other arity: nothing fires and the state is left unwritten.
//
// `estimated_features` is not used.
void Tagger_BigramIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta,
                                     const Hypergraph::Edge& edge,
                                     const std::vector<const void*>& ant_contexts,
                                     SparseVector<double>* features,
                                     SparseVector<double>* estimated_features,
                                     void* context) const {
  WordID* state = static_cast<WordID*>(context);
  const int arity = edge.Arity();
  switch (arity) {
    case 0: {
      *state = edge.rule_->e_[0];
      FireFeature(*state, 0, features);
      break;
    }
    case 2: {
      // Copy both antecedent tags before writing the outgoing state.
      const WordID prev_tag = *static_cast<const WordID*>(ant_contexts[0]);
      const WordID next_tag = *static_cast<const WordID*>(ant_contexts[1]);
      const bool starts_at_origin = (edge.i_ == 0);
      if (starts_at_origin && edge.j_ == 2) {
        FireFeature(-1, prev_tag, features);
      }
      FireFeature(prev_tag, next_tag, features);
      if (starts_at_origin && edge.j_ == smeta.GetSourceLength()) {
        FireFeature(next_tag, -1, features);
      }
      *state = next_tag;
      break;
    }
    default:
      break;
  }
}

// Forwards the per-sentence metadata to the lexicon helper held in lexmap_.
void LexicalPairIdentity::PrepareForInput(const SentenceMetadata& smeta) {
  (*lexmap_).PrepareForInput(smeta);
}

// Configures the feature from a whitespace-separated parameter string.
//
//  * Empty string: the feature name prefix stays "Id" and a
//    default-constructed FactoredLexiconHelper is used.
//  * Otherwise exactly three fields are required, in the order
//    <name> <corpus.src.txt> <trgmap.txt>; the first replaces the name prefix
//    and the other two are handed to the FactoredLexiconHelper constructor.
//    Any other field count prints a usage message to stderr and aborts.
LexicalPairIdentity::LexicalPairIdentity(const std::string& param) {
  name_ = "Id";
  if (param.empty()) {
    lexmap_.reset(new FactoredLexiconHelper);
    return;
  }

  // Expected layout: name corpus.f emap.txt
  std::vector<std::string> fields;
  SplitOnWhitespace(param, &fields);
  if (fields.size() != 3) {
    std::cerr << "LexicalPairIdentity takes 3 parameters: <name> <corpus.src.txt> <trgmap.txt>\n";
    std::cerr << " * may be used for corpus.src.txt or trgmap.txt to use surface forms\n";
    std::cerr << " Received: " << param << std::endl;
    abort();
  }
  name_ = fields[0];
  lexmap_.reset(new FactoredLexiconHelper(fields[1], fields[2]));
}

// Sets the indicator feature "<name_>:<src>:<trg>" to 1.0.
// The feature id is memoised in fmap_[src][trg]; a stored id of 0 means the
// name has not been registered yet.
void LexicalPairIdentity::FireFeature(WordID src,
                                      WordID trg,
                                      SparseVector<double>* features) const {
  int& cached_id = fmap_[src][trg];
  if (cached_id == 0) {
    std::ostringstream label;
    label << name_;
    label << ':' << TD::Convert(src);
    label << ':' << TD::Convert(trg);
    cached_id = FD::Convert(label.str());
  }
  features->set_value(cached_id, 1.0);
}

// Fires a source/target pair feature on edges with no antecedents; every
// other edge is ignored.
//
// The source word is looked up at position edge.i_ and the rule's single
// target word is mapped through the lexicon helper before firing.
// `smeta`, `ant_contexts`, `estimated_features` and `context` are not used.
//
// Helper calls relied on here (signatures as noted by the original author):
//   inline WordID SourceWordAtPosition(const int i);
//   inline WordID CoarsenedTargetWordForTarget(const WordID surface_target);
void LexicalPairIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta,
                                     const Hypergraph::Edge& edge,
                                     const std::vector<const void*>& ant_contexts,
                                     SparseVector<double>* features,
                                     SparseVector<double>* estimated_features,
                                     void* context) const {
  if (edge.Arity() != 0) {
    return;
  }
  const WordID source_word = lexmap_->SourceWordAtPosition(edge.i_);
  const std::vector<WordID>& target_side = edge.rule_->e_;
  // Arity-0 rules are expected to emit exactly one target symbol.
  assert(target_side.size() == 1);
  const WordID target_word =
      lexmap_->CoarsenedTargetWordForTarget(target_side[0]);
  FireFeature(source_word, target_word, features);
}

// Constructs the output-identity feature. `param` is accepted but never read,
// and the constructor body does nothing.
OutputIdentity::OutputIdentity(const std::string& param) {
}

// Sets the indicator feature "T:<trg>" to 1.0.
//
// The feature id is memoised in fmap_ under the ORIGINAL target id; a stored
// id of 0 means the name has not been registered yet. When building the name,
// the tokens "=", ";" and "," are replaced by "__EQ", "__SC" and "__CO"
// (presumably because those characters are delimiters in feature/weight
// text formats -- confirm). The substitution table is a function-local static
// filled on first use.
void OutputIdentity::FireFeature(WordID trg,
                                 SparseVector<double>* features) const {
  int& cached_id = fmap_[trg];
  if (cached_id == 0) {
    static std::map<WordID, WordID> substitutes;
    if (substitutes.empty()) {
      substitutes[TD::Convert("=")] = TD::Convert("__EQ");
      substitutes[TD::Convert(";")] = TD::Convert("__SC");
      substitutes[TD::Convert(",")] = TD::Convert("__CO");
    }
    std::map<WordID, WordID>::iterator hit = substitutes.find(trg);
    if (hit != substitutes.end()) {
      trg = hit->second;
    }
    std::ostringstream label;
    label << "T:" << TD::Convert(trg);
    cached_id = FD::Convert(label.str());
  }
  features->set_value(cached_id, 1.0);
}

// Fires the output-identity feature for every target-side symbol of the
// edge's rule whose id is strictly positive.
//
// Symbols with id <= 0 are skipped (presumably non-terminal placeholders --
// confirm against the rule representation). `smeta`, `ant_contexts`,
// `estimated_features` and `context` are not used.
void OutputIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta,
                                     const Hypergraph::Edge& edge,
                                     const std::vector<const void*>& ant_contexts,
                                     SparseVector<double>* features,
                                     SparseVector<double>* estimated_features,
                                     void* context) const {
  const std::vector<WordID>& target_side = edge.rule_->e_;
  // Iterate with the container's own iterator instead of an `int` index so
  // there is no signed/unsigned comparison against size().
  for (std::vector<WordID>::const_iterator sym = target_side.begin();
       sym != target_side.end(); ++sym) {
    if (*sym > 0) {
      FireFeature(*sym, features);
    }
  }
}