summaryrefslogtreecommitdiff
path: root/vest/mbr_kbest.cc
blob: 5d70b4e2dcccbbe0ada5f061bbfc16b74f4952f3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <string>
#include <vector>

#include <boost/program_options.hpp>

#include "prob.h"
#include "tdict.h"
#include "scorer.h"
#include "filelib.h"
#include "stringlib.h"

using namespace std;

namespace po = boost::program_options;

/// Parses the command line into @p conf.
///
/// Options:
///   --scale,-a          posterior scaling factor alpha (default 1.0)
///   --loss_function,-l  metric name, later passed to ScoreTypeFromString
///                       (default "bleu")
///   --input,-i          file to read k-best lists from (default "-")
///   --output_list,-L    print the whole reranked list, not just the 1-best
///   --help,-h           print the option summary to stderr and exit(1)
///
/// @param argc  argument count as received by main()
/// @param argv  argument vector as received by main()
/// @param conf  receives the parsed option values
void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
  po::options_description opts("Configuration options");
  opts.add_options()
        ("scale,a",po::value<double>()->default_value(1.0), "Posterior scaling factor (alpha)")
        ("loss_function,l",po::value<string>()->default_value("bleu"), "Loss function")
        ("input,i",po::value<string>()->default_value("-"), "File to read k-best lists from")
        ("output_list,L", "Show reranked list as output")
        ("help,h", "Help");
  po::options_description dcmdline_options;
  dcmdline_options.add(opts);
  po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
  if (conf->count("help")) {
    cerr << dcmdline_options << endl;
    exit(1);
  }
  // Finalize the map (runs notifiers / required-option checks). Done after
  // the help check so that --help always works.
  po::notify(*conf);
}

/// Strict-weak-ordering functor for (hypothesis, score) pairs: orders by
/// ascending `.second`, so in main() the lowest expected loss comes first.
///
/// Templated on the pair type so that elements such as
/// pair<vector<WordID>, prob_t> are compared on their stored score directly,
/// instead of each argument being converted into a different pair type
/// (which copies its hypothesis vector) on every comparison.
struct LossComparer {
  template <typename Pair>
  bool operator()(const Pair& a, const Pair& b) const {
    return a.second < b.second;
  }
};

/// Reads the next k-best list, i.e. the run of consecutive lines that share
/// one sentence id, from @p in.
///
/// Each non-empty line must look like
///   <sent_id> ||| <hypothesis> ||| ... ||| <score>
/// The score is the text after the last " ||| ". It is parsed with strtod
/// and stored through prob_t::logeq.
///
/// The end of a list is only detected by reading the first line of the next
/// one. That line is kept in function-static storage and becomes the first
/// entry of the next call. The function is therefore not reentrant and must
/// only ever be used with a single stream.
///
/// @param in       stream to read from
/// @param sent_id  set to the id of the returned list; left unchanged when
///                 nothing is returned
/// @param list     cleared, then filled with (hypothesis, score) pairs
/// @return true if a non-empty list was read, false at end of input
/// Aborts the process on a line lacking two " ||| " separators.
bool ReadKBestList(istream* in, string* sent_id, vector<pair<vector<WordID>, prob_t> >* list) {
  // Carry-over between calls: the already-parsed first line of the next list.
  // cache_pair.first is kept empty whenever have_cache is false.
  static string cache_id;
  static pair<vector<WordID>, prob_t> cache_pair;
  static bool have_cache = false;

  list->clear();
  string cur_id;
  if (have_cache) {
    list->push_back(cache_pair);
    cur_id = cache_id;
    cache_pair.first.clear();
    have_cache = false;
  }

  string line;
  // Test the result of getline itself. Testing the stream before reading
  // would re-process the last line when the input has no trailing newline,
  // because a failed getline leaves `line` untouched.
  while (getline(*in, line)) {
    if (line.empty()) continue;
    const size_t p1 = line.find(" ||| ");
    if (p1 == string::npos) { cerr << "Bad format: " << line << endl; abort(); }
    // Searching from p1 + 4 lets the second separator share the first one's
    // trailing space, so an empty hypothesis ("id ||| ||| ...") is accepted.
    const size_t p2 = line.find(" ||| ", p1 + 4);
    if (p2 == string::npos) { cerr << "Bad format: " << line << endl; abort(); }
    const size_t p3 = line.rfind(" ||| ");
    const size_t hyp_begin = p1 + 5;
    // With overlapping separators p2 < hyp_begin; guard against the unsigned
    // length computation wrapping around.
    const string tstr = (p2 > hyp_begin) ? line.substr(hyp_begin, p2 - hyp_begin) : string();
    const double val = strtod(line.substr(p3 + 5).c_str(), NULL);

    cache_id = line.substr(0, p1);
    TD::ConvertSentence(tstr, &cache_pair.first);
    cache_pair.second.logeq(val);
    if (list->empty()) cur_id = cache_id;
    if (cache_id != cur_id) {
      // First line of the following list: hand it to the next call.
      have_cache = true;
      break;
    }
    list->push_back(cache_pair);
    cache_pair.first.clear();
  }

  // Also covers lists made up solely of the carried-over line.
  if (!list->empty()) *sent_id = cur_id;
  return !list->empty();
}

int main(int argc, char** argv) {
  po::variables_map conf;
  InitCommandLine(argc, argv, &conf);
  const string metric = conf["loss_function"].as<string>();
  const bool output_list = conf.count("output_list") > 0;
  const string file = conf["input"].as<string>();
  const double mbr_scale = conf["scale"].as<double>();
  cerr << "Posterior scaling factor (alpha) = " << mbr_scale << endl;

  ScoreType type = ScoreTypeFromString(metric);
  vector<pair<vector<WordID>, prob_t> > list;
  ReadFile rf(file);
  string sent_id;
  while(ReadKBestList(rf.stream(), &sent_id, &list)) {
    vector<prob_t> joints(list.size());
    const prob_t max_score = pow(list.front().second, mbr_scale);
    prob_t marginal = prob_t::Zero();
    for (int i = 0 ; i < list.size(); ++i) {
      const prob_t joint = pow(list[i].second, mbr_scale) / max_score;
      joints[i] = joint;
      // cerr << "list[" << i << "] joint=" << log(joint) << endl;
      marginal += joint;
    }
    int mbr_idx = -1;
    vector<double> mbr_scores(output_list ? list.size() : 0);
    double mbr_loss = numeric_limits<double>::max();
    for (int i = 0 ; i < list.size(); ++i) {
      vector<vector<WordID> > refs(1, list[i].first);
      //cerr << i << ": " << list[i].second <<"\t" << TD::GetString(list[i].first) << endl;
      SentenceScorer* scorer = SentenceScorer::CreateSentenceScorer(type, refs);
      double wl_acc = 0;
      for (int j = 0; j < list.size(); ++j) {
        if (i != j) {
          Score* s = scorer->ScoreCandidate(list[j].first);
          double loss = 1.0 - s->ComputeScore();
          if (type == TER || type == AER) loss = 1.0 - loss;
          delete s;
          double weighted_loss = loss * (joints[j] / marginal);
          wl_acc += weighted_loss;
          if ((!output_list) && wl_acc > mbr_loss) break;
        }
      }
      if (output_list) mbr_scores[i] = wl_acc;
      if (wl_acc < mbr_loss) {
        mbr_loss = wl_acc;
        mbr_idx = i;
      }
      delete scorer;
    }
    // cerr << "ML translation: " << TD::GetString(list[0].first) << endl;
    cerr << "MBR Best idx: " << mbr_idx << endl;
    if (output_list) {
      for (int i = 0; i < list.size(); ++i)
        list[i].second.logeq(mbr_scores[i]);
      sort(list.begin(), list.end(), LossComparer());
      for (int i = 0; i < list.size(); ++i)
        cout << sent_id << " ||| "
             << TD::GetString(list[i].first) << " ||| "
             << log(list[i].second) << endl;
    } else {
      cout << TD::GetString(list[mbr_idx].first) << endl;
    }
  }
  return 0;
}