diff options
Diffstat (limited to 'vest/mbr_kbest.cc')
-rw-r--r-- | vest/mbr_kbest.cc | 140 |
1 files changed, 140 insertions, 0 deletions
diff --git a/vest/mbr_kbest.cc b/vest/mbr_kbest.cc new file mode 100644 index 00000000..5d70b4e2 --- /dev/null +++ b/vest/mbr_kbest.cc @@ -0,0 +1,140 @@ +#include <iostream> +#include <vector> + +#include <boost/program_options.hpp> + +#include "prob.h" +#include "tdict.h" +#include "scorer.h" +#include "filelib.h" +#include "stringlib.h" + +using namespace std; + +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("scale,a",po::value<double>()->default_value(1.0), "Posterior scaling factor (alpha)") + ("loss_function,l",po::value<string>()->default_value("bleu"), "Loss function") + ("input,i",po::value<string>()->default_value("-"), "File to read k-best lists from") + ("output_list,L", "Show reranked list as output") + ("help,h", "Help"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + bool flag = false; + if (flag || conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +struct LossComparer { + bool operator()(const pair<vector<WordID>, double>& a, const pair<vector<WordID>, double>& b) const { + return a.second < b.second; + } +}; + +bool ReadKBestList(istream* in, string* sent_id, vector<pair<vector<WordID>, prob_t> >* list) { + static string cache_id; + static pair<vector<WordID>, prob_t> cache_pair; + list->clear(); + string cur_id; + if (cache_pair.first.size() > 0) { + list->push_back(cache_pair); + cur_id = cache_id; + cache_pair.first.clear(); + } + string line; + string tstr; + while(*in) { + getline(*in, line); + if (line.empty()) continue; + size_t p1 = line.find(" ||| "); + if (p1 == string::npos) { cerr << "Bad format: " << line << endl; abort(); } + size_t p2 = line.find(" ||| ", p1 + 4); + if (p2 == string::npos) { cerr << "Bad format: " << line << endl; abort(); } + size_t p3 = line.rfind(" ||| "); + cache_id = line.substr(0, p1); + tstr = line.substr(p1 + 5, p2 - p1 - 5); + double val = strtod(line.substr(p3 + 5).c_str(), NULL); + TD::ConvertSentence(tstr, &cache_pair.first); + cache_pair.second.logeq(val); + if (cur_id.empty()) cur_id = cache_id; + if (cur_id == cache_id) { + list->push_back(cache_pair); + *sent_id = cur_id; + cache_pair.first.clear(); + } else { break; } + } + return !list->empty(); +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + const string metric = conf["loss_function"].as<string>(); + const bool output_list = conf.count("output_list") > 0; + const string file = conf["input"].as<string>(); + const double mbr_scale = conf["scale"].as<double>(); + cerr << "Posterior scaling factor (alpha) = " << mbr_scale << endl; + + ScoreType type = ScoreTypeFromString(metric); + vector<pair<vector<WordID>, prob_t> > list; + ReadFile rf(file); + string sent_id; + while(ReadKBestList(rf.stream(), &sent_id, &list)) { + vector<prob_t> joints(list.size()); + const prob_t max_score = pow(list.front().second, mbr_scale); + prob_t marginal = prob_t::Zero(); + for (int i = 0 ; i < list.size(); ++i) { + const prob_t joint = pow(list[i].second, mbr_scale) / max_score; + joints[i] = joint; + // cerr << "list[" << i << "] joint=" << log(joint) << endl; + marginal += joint; + } + int mbr_idx = -1; + vector<double> mbr_scores(output_list ? list.size() : 0); + double mbr_loss = numeric_limits<double>::max(); + for (int i = 0 ; i < list.size(); ++i) { + vector<vector<WordID> > refs(1, list[i].first); + //cerr << i << ": " << list[i].second <<"\t" << TD::GetString(list[i].first) << endl; + SentenceScorer* scorer = SentenceScorer::CreateSentenceScorer(type, refs); + double wl_acc = 0; + for (int j = 0; j < list.size(); ++j) { + if (i != j) { + Score* s = scorer->ScoreCandidate(list[j].first); + double loss = 1.0 - s->ComputeScore(); + if (type == TER || type == AER) loss = 1.0 - loss; + delete s; + double weighted_loss = loss * (joints[j] / marginal); + wl_acc += weighted_loss; + if ((!output_list) && wl_acc > mbr_loss) break; + } + } + if (output_list) mbr_scores[i] = wl_acc; + if (wl_acc < mbr_loss) { + mbr_loss = wl_acc; + mbr_idx = i; + } + delete scorer; + } + // cerr << "ML translation: " << TD::GetString(list[0].first) << endl; + cerr << "MBR Best idx: " << mbr_idx << endl; + if (output_list) { + for (int i = 0; i < list.size(); ++i) + list[i].second.logeq(mbr_scores[i]); + sort(list.begin(), list.end(), LossComparer()); + for (int i = 0; i < list.size(); ++i) + cout << sent_id << " ||| " + << TD::GetString(list[i].first) << " ||| " + << log(list[i].second) << endl; + } else { + cout << TD::GetString(list[mbr_idx].first) << endl; + } + } + return 0; +} + |