diff options
author | Chris Dyer <cdyer@allegro.clab.cs.cmu.edu> | 2014-02-26 00:03:42 -0500 |
---|---|---|
committer | Chris Dyer <cdyer@allegro.clab.cs.cmu.edu> | 2014-02-26 00:03:42 -0500 |
commit | 7c0ee6a2e22a1ace580ed1dcad65a4c591783135 (patch) | |
tree | b4b872eb7ecbe44923388c1c76e3a547a4b1e9d1 | |
parent | d843587027d815f3a1c9b8dd5394f3fe04ac85fa (diff) |
support multiple inputs in mbr
-rw-r--r-- | mteval/mbr_kbest.cc | 102 |
1 files changed, 65 insertions, 37 deletions
diff --git a/mteval/mbr_kbest.cc b/mteval/mbr_kbest.cc index 76d2c7fc..eb36e009 100644 --- a/mteval/mbr_kbest.cc +++ b/mteval/mbr_kbest.cc @@ -23,9 +23,10 @@ namespace po = boost::program_options; void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() - ("scale,a",po::value<double>()->default_value(1.0), "Posterior scaling factor (alpha)") + ("input,i",po::value<vector<string> >(), "Files to read k-best lists from") + ("scale,a",po::value<vector<double> >(), "Posterior scaling factors (per file)") + ("offset,b",po::value<vector<double> >(), "Log posterior offsets (per file)") ("evaluation_metric,m",po::value<string>()->default_value("ibm_bleu"), "Evaluation metric") - ("input,i",po::value<string>()->default_value("-"), "File to read k-best lists from") ("output_list,L", "Show reranked list as output") ("help,h", "Help"); po::options_description dcmdline_options; @@ -50,47 +51,54 @@ struct LossComparer { } }; -bool ReadKBestList(const double mbr_scale, istream* in, string* sent_id, vector<pair<vector<WordID>, prob_t> >* list) { +bool ReadKBestList(const vector<double>& mbr_scale, + const vector<double>& mbr_offset, + const vector<ReadFile*>& rfs, + string* sent_id, + vector<pair<vector<WordID>, prob_t> >* list) { static string cache_id; - static pair<vector<WordID>, prob_t> cache_pair; + pair<vector<WordID>, prob_t> tmp_pair; + static vector<pair<vector<WordID>, prob_t> > cache_pair(rfs.size()); list->clear(); string cur_id; - unordered_map<vector<WordID>, unsigned, boost::hash<vector<WordID> > > sent2id; - if (cache_pair.first.size() > 0) { - list->push_back(cache_pair); - sent2id[cache_pair.first] = 0; + if (cache_pair[0].first.size() > 0) { + for (unsigned i = 0; i < cache_pair.size(); ++i) + list->push_back(cache_pair[i]); cur_id = cache_id; - cache_pair.first.clear(); + cache_pair.clear(); + cache_pair.resize(rfs.size()); } string line; string tstr; - while(getline(*in, line)) { - size_t p1 = line.find(" ||| "); - if (p1 == string::npos) { cerr << "Bad format: " << line << endl; abort(); } - size_t p2 = line.find(" ||| ", p1 + 4); - if (p2 == string::npos) { cerr << "Bad format: " << line << endl; abort(); } - size_t p3 = line.rfind(" ||| "); - cache_id = line.substr(0, p1); - tstr = line.substr(p1 + 5, p2 - p1 - 5); - double val = strtod(line.substr(p3 + 5).c_str(), NULL) * mbr_scale; - TD::ConvertSentence(tstr, &cache_pair.first); - cache_pair.second.logeq(val); - if (cur_id.empty()) cur_id = cache_id; - if (cur_id == cache_id) { - unordered_map<vector<WordID>, unsigned, boost::hash<vector<WordID> > >::iterator it = - sent2id.find(cache_pair.first); - if (it == sent2id.end()) { - sent2id.insert(make_pair(cache_pair.first, unsigned(list->size()))); - list->push_back(cache_pair); + for (unsigned fi = 0; fi < rfs.size(); ++fi) { + istream& in = *rfs[fi]->stream(); + while(getline(in, line)) { + size_t p1 = line.find(" ||| "); + if (p1 == string::npos) { cerr << "Bad format: " << line << endl; abort(); } + size_t p2 = line.find(" ||| ", p1 + 4); + if (p2 == string::npos) { cerr << "Bad format: " << line << endl; abort(); } + size_t p3 = line.rfind(" ||| "); + cache_id = line.substr(0, p1); + tstr = line.substr(p1 + 5, p2 - p1 - 5); + double val = strtod(line.substr(p3 + 5).c_str(), NULL) * mbr_scale[fi] + mbr_offset[fi]; + TD::ConvertSentence(tstr, &tmp_pair.first); + tmp_pair.second.logeq(val); + if (cur_id.empty()) cur_id = cache_id; + if (cur_id == cache_id) { + list->push_back(tmp_pair); + *sent_id = cur_id; + tmp_pair.first.clear(); } else { - (*list)[it->second].second += cache_pair.second; - // cerr << "Cruch: " << line << "\n newp=" << (*list)[it->second].second << endl; + swap(cache_pair[fi], tmp_pair); + break; } - *sent_id = cur_id; - cache_pair.first.clear(); - } else { break; } + } } sort(list->begin(), list->end(), ScoreComparer()); + // for (unsigned i = 0; i < list->size(); ++i) { + // cerr << TD::GetString((*list)[i].first) << " ||| " << (*list)[i].second << endl; + //} + //cerr << endl; return !list->empty(); } @@ -102,14 +110,34 @@ int main(int argc, char** argv) { const bool is_loss = (UppercaseString(smetric) == "TER"); const bool output_list = conf.count("output_list") > 0; - const string file = conf["input"].as<string>(); - const double mbr_scale = conf["scale"].as<double>(); - cerr << "Posterior scaling factor (alpha) = " << mbr_scale << endl; + vector<string> file; + if (conf.count("input") == 0) + file.push_back("-"); + else + file = conf["input"].as<vector<string> >(); + vector<double> mbr_scale; + if (conf.count("scale")) mbr_scale = conf["scale"].as<vector<double> >(); + vector<double> mbr_offset; + if (conf.count("offset")) mbr_offset = conf["offset"].as<vector<double> >(); + if (file.size() > mbr_scale.size()) mbr_scale.resize(file.size(), 1.0); + if (file.size() > mbr_offset.size()) mbr_offset.resize(file.size(), 0.0); + if (file.size() != mbr_scale.size()) { + cerr << file.size() << " files specified but " << mbr_scale.size() << " scale factors given!\n"; + return 1; + } + if (file.size() != mbr_offset.size()) { + cerr << file.size() << " files specified but " << mbr_offset.size() << " scale factors given!\n"; + return 1; + } + for (unsigned i = 0; i < file.size(); ++i) + cerr << "Kbest file " << (i+1) << ": " << file[i] << "\t(scale=" << mbr_scale[i] << ", offset=" << mbr_offset[i] << ")\n"; vector<pair<vector<WordID>, prob_t> > list; - ReadFile rf(file); + vector<ReadFile*> rfs(file.size()); + for (unsigned i = 0; i < file.size(); ++i) + rfs[i] = new ReadFile(file[i]); string sent_id; - while(ReadKBestList(mbr_scale, rf.stream(), &sent_id, &list)) { + while(ReadKBestList(mbr_scale, mbr_offset, rfs, &sent_id, &list)) { vector<prob_t> joints(list.size()); const prob_t max_score = list.front().second; prob_t marginal = prob_t::Zero(); |