summaryrefslogtreecommitdiff
path: root/training/mira/ada_opt_sm.cc
blob: e5d2540145c679d28a80be16deccd447cb9ce3e8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
#include "config.h"

#include <boost/shared_ptr.hpp>
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>

#include "filelib.h"
#include "stringlib.h"
#include "weights.h"
#include "sparse_vector.h"
#include "candidate_set.h"
#include "sentence_metadata.h"
#include "ns.h"
#include "ns_docscorer.h"
#include "verbose.h"
#include "hg.h"
#include "ff_register.h"
#include "decoder.h"
#include "fdict.h"
#include "sampler.h"

using namespace std;
namespace po = boost::program_options;

boost::shared_ptr<MT19937> rng;
vector<training::CandidateSet> kbests;
SparseVector<weight_t> G, u, lambdas;
double pseudo_doc_decay = 0.9;

bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
  po::options_description opts("Configuration options");
  opts.add_options()
    ("decoder_config,c",po::value<string>(),"[REQ] Decoder configuration file")
    ("devset,d",po::value<string>(),"[REQ] Source/reference development set")
    ("weights,w",po::value<string>(),"Initial feature weights file")
    ("mt_metric,m",po::value<string>()->default_value("ibm_bleu"), "Scoring metric (ibm_bleu, nist_bleu, koehn_bleu, ter, combi)")
    ("size",po::value<unsigned>()->default_value(0), "Process rank (for multiprocess mode)")
    ("rank",po::value<unsigned>()->default_value(1), "Number of processes (for multiprocess mode)")
    ("optimizer,o",po::value<unsigned>()->default_value(1), "Optimizer (Adaptive MIRA=1)")
    ("fear,f",po::value<unsigned>()->default_value(1), "Fear selection (model-cost=1, maxcost=2, maxscore=3)")
    ("hope,h",po::value<unsigned>()->default_value(1), "Hope selection (model+cost=1, mincost=2)")
    ("eta0", po::value<double>()->default_value(0.1), "Initial step size")
    ("random_seed,S", po::value<uint32_t>(), "Random seed (if not specified, /dev/random will be used)")
    ("mt_metric_scale,s", po::value<double>()->default_value(1.0), "Scale MT loss function by this amount")
    ("pseudo_doc,e", "Use pseudo-documents for approximate scoring")
    ("k_best_size,k", po::value<unsigned>()->default_value(500), "Size of hypothesis list to search for oracles");
  po::options_description clo("Command line options");
  clo.add_options()
    ("config", po::value<string>(), "Configuration file")
    ("help,H", "Print this help message and exit");
  po::options_description dconfig_options, dcmdline_options;
  dconfig_options.add(opts);
  dcmdline_options.add(opts).add(clo);
  
  po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
  if (conf->count("config")) {
    ifstream config((*conf)["config"].as<string>().c_str());
    po::store(po::parse_config_file(config, dconfig_options), *conf);
  }
  po::notify(*conf);

  if (conf->count("help")
          || !conf->count("decoder_config")
          || !conf->count("devset")) {
    cerr << dcmdline_options << endl;
    return false;
  }
  return true;
}

struct TrainingObserver : public DecoderObserver {
  explicit TrainingObserver(const EvaluationMetric& m, const int k) : metric(m), kbest_size(k), cur_eval() {}

  const EvaluationMetric& metric;
  const int kbest_size;
  const SegmentEvaluator* cur_eval;
  SufficientStats pdoc;
  unsigned hi, vi, fi;  // hope, viterbi, fear

  void SetSegmentEvaluator(const SegmentEvaluator* eval) {
    cur_eval = eval;
  }

  virtual void NotifySourceParseFailure(const SentenceMetadata& smeta) {
    cerr << "Failed to translate sentence with ID = " << smeta.GetSentenceID() << endl;
    abort();
  }

  unsigned CostAugmentedDecode(const training::CandidateSet& cs,
                               const SparseVector<double>& w,
                               double alpha = 0) {
    unsigned best_i = 0;
    double best = -numeric_limits<double>::infinity();
    for (unsigned i = 0; i < cs.size(); ++i) {
      double s = cs[i].fmap.dot(w);
      if (alpha)
        s += alpha * metric.ComputeScore(cs[i].eval_feats + pdoc);
      if (s > best) {
        best = s;
        best_i = i;
      }
    }
    return best_i;
  }

  virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) {
    pdoc *= pseudo_doc_decay;
    const unsigned sent_id = smeta.GetSentenceID();
    kbests[sent_id].AddUniqueKBestCandidates(*hg, kbest_size, cur_eval);
    vi = CostAugmentedDecode(kbests[sent_id], lambdas);
    hi = CostAugmentedDecode(kbests[sent_id], lambdas, 1.0);
    fi = CostAugmentedDecode(kbests[sent_id], lambdas, -1.0);
    cerr << sent_id << " ||| " << TD::GetString(kbests[sent_id][vi].ewords) << " ||| " << metric.ComputeScore(kbests[sent_id][vi].eval_feats + pdoc) << endl;
    pdoc += kbests[sent_id][vi].eval_feats;  // update pseudodoc stats
  }
};

int main(int argc, char** argv) {
  SetSilent(true);  // turn off verbose decoder output
  register_feature_functions();

  po::variables_map conf;
  if (!InitCommandLine(argc, argv, &conf)) return 1;

  if (conf.count("random_seed"))
    rng.reset(new MT19937(conf["random_seed"].as<uint32_t>()));
  else
    rng.reset(new MT19937);
  
  string metric_name = UppercaseString(conf["mt_metric"].as<string>());
  if (metric_name == "COMBI") {
    cerr << "WARNING: 'combi' metric is no longer supported, switching to 'COMB:TER=-0.5;IBM_BLEU=0.5'\n";
    metric_name = "COMB:TER=-0.5;IBM_BLEU=0.5";
  } else if (metric_name == "BLEU") {
    cerr << "WARNING: 'BLEU' is ambiguous, assuming 'IBM_BLEU'\n";
    metric_name = "IBM_BLEU";
  }
  EvaluationMetric* metric = EvaluationMetric::Instance(metric_name);
  DocumentScorer ds(metric, conf["devset"].as<string>());
  cerr << "Loaded " << ds.size() << " references for scoring with " << metric_name << endl;
  kbests.resize(ds.size());
  double eta = 0.001;

  ReadFile ini_rf(conf["decoder_config"].as<string>());
  Decoder decoder(ini_rf.stream());

  vector<weight_t>& dense_weights = decoder.CurrentWeightVector();
  if (conf.count("weights")) {
    Weights::InitFromFile(conf["weights"].as<string>(), &dense_weights);
    Weights::InitSparseVector(dense_weights, &lambdas);
  }

  TrainingObserver observer(*metric, conf["k_best_size"].as<unsigned>());

  unsigned num = 200;
  for (unsigned iter = 1; iter < num; ++iter) {
    lambdas.init_vector(&dense_weights);
    unsigned sent_id = rng->next() * ds.size();
    cerr << "Learning from sentence id: " << sent_id << endl;
    observer.SetSegmentEvaluator(ds[sent_id]);
    decoder.SetId(sent_id);
    decoder.Decode(ds[sent_id]->src, &observer);
    if (observer.vi != observer.hi) {  // viterbi != hope
      SparseVector<double> grad = kbests[sent_id][observer.fi].fmap;
      grad -= kbests[sent_id][observer.hi].fmap;
      cerr << "GRAD: " << grad << endl;
      const SparseVector<double>& g = grad;
#if HAVE_CXX11 && (__GNUC_MINOR__ > 4 || __GNUC__ > 4)
      for (auto& gi : g) {
#else
      for (SparseVector<double>::const_iterator it = g.begin(); it != g.end(); ++it) {
        const pair<unsigned,double>& gi = *it;
#endif
        if (gi.second) {
          u[gi.first] += gi.second;
          G[gi.first] += gi.second * gi.second;
          lambdas.set_value(gi.first, 1.0);  // this is a dummy value to trigger recomputation
        }
      }
      for (SparseVector<double>::iterator it = lambdas.begin(); it != lambdas.end(); ++it) {
        const pair<unsigned,double>& xi = *it;
        double z = fabs(u[xi.first] / iter) - 0.0;
        double s = 1;
        if (u[xi.first] > 0) s = -1;
        if (z > 0 && G[xi.first]) {
          lambdas.set_value(xi.first, eta * s * z * iter / sqrt(G[xi.first]));
        } else {
          lambdas.set_value(xi.first, 0.0);
        }
      }
    }
  }
  cerr << "Optimization complete.\n";
  Weights::WriteToFile("-", dense_weights, true);
  return 0;
}