summaryrefslogtreecommitdiff
path: root/mteval/mbr_kbest.cc
diff options
context:
space:
mode:
Diffstat (limited to 'mteval/mbr_kbest.cc')
-rw-r--r--mteval/mbr_kbest.cc102
1 files changed, 65 insertions, 37 deletions
diff --git a/mteval/mbr_kbest.cc b/mteval/mbr_kbest.cc
index 76d2c7fc..eb36e009 100644
--- a/mteval/mbr_kbest.cc
+++ b/mteval/mbr_kbest.cc
@@ -23,9 +23,10 @@ namespace po = boost::program_options;
void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
po::options_description opts("Configuration options");
opts.add_options()
- ("scale,a",po::value<double>()->default_value(1.0), "Posterior scaling factor (alpha)")
+ ("input,i",po::value<vector<string> >(), "Files to read k-best lists from")
+ ("scale,a",po::value<vector<double> >(), "Posterior scaling factors (per file)")
+ ("offset,b",po::value<vector<double> >(), "Log posterior offsets (per file)")
("evaluation_metric,m",po::value<string>()->default_value("ibm_bleu"), "Evaluation metric")
- ("input,i",po::value<string>()->default_value("-"), "File to read k-best lists from")
("output_list,L", "Show reranked list as output")
("help,h", "Help");
po::options_description dcmdline_options;
@@ -50,47 +51,54 @@ struct LossComparer {
}
};
-bool ReadKBestList(const double mbr_scale, istream* in, string* sent_id, vector<pair<vector<WordID>, prob_t> >* list) {
+bool ReadKBestList(const vector<double>& mbr_scale,
+ const vector<double>& mbr_offset,
+ const vector<ReadFile*>& rfs,
+ string* sent_id,
+ vector<pair<vector<WordID>, prob_t> >* list) {
static string cache_id;
- static pair<vector<WordID>, prob_t> cache_pair;
+ pair<vector<WordID>, prob_t> tmp_pair;
+ static vector<pair<vector<WordID>, prob_t> > cache_pair(rfs.size());
list->clear();
string cur_id;
- unordered_map<vector<WordID>, unsigned, boost::hash<vector<WordID> > > sent2id;
- if (cache_pair.first.size() > 0) {
- list->push_back(cache_pair);
- sent2id[cache_pair.first] = 0;
+ if (cache_pair[0].first.size() > 0) {
+ for (unsigned i = 0; i < cache_pair.size(); ++i)
+ list->push_back(cache_pair[i]);
cur_id = cache_id;
- cache_pair.first.clear();
+ cache_pair.clear();
+ cache_pair.resize(rfs.size());
}
string line;
string tstr;
- while(getline(*in, line)) {
- size_t p1 = line.find(" ||| ");
- if (p1 == string::npos) { cerr << "Bad format: " << line << endl; abort(); }
- size_t p2 = line.find(" ||| ", p1 + 4);
- if (p2 == string::npos) { cerr << "Bad format: " << line << endl; abort(); }
- size_t p3 = line.rfind(" ||| ");
- cache_id = line.substr(0, p1);
- tstr = line.substr(p1 + 5, p2 - p1 - 5);
- double val = strtod(line.substr(p3 + 5).c_str(), NULL) * mbr_scale;
- TD::ConvertSentence(tstr, &cache_pair.first);
- cache_pair.second.logeq(val);
- if (cur_id.empty()) cur_id = cache_id;
- if (cur_id == cache_id) {
- unordered_map<vector<WordID>, unsigned, boost::hash<vector<WordID> > >::iterator it =
- sent2id.find(cache_pair.first);
- if (it == sent2id.end()) {
- sent2id.insert(make_pair(cache_pair.first, unsigned(list->size())));
- list->push_back(cache_pair);
+ for (unsigned fi = 0; fi < rfs.size(); ++fi) {
+ istream& in = *rfs[fi]->stream();
+ while(getline(in, line)) {
+ size_t p1 = line.find(" ||| ");
+ if (p1 == string::npos) { cerr << "Bad format: " << line << endl; abort(); }
+ size_t p2 = line.find(" ||| ", p1 + 4);
+ if (p2 == string::npos) { cerr << "Bad format: " << line << endl; abort(); }
+ size_t p3 = line.rfind(" ||| ");
+ cache_id = line.substr(0, p1);
+ tstr = line.substr(p1 + 5, p2 - p1 - 5);
+ double val = strtod(line.substr(p3 + 5).c_str(), NULL) * mbr_scale[fi] + mbr_offset[fi];
+ TD::ConvertSentence(tstr, &tmp_pair.first);
+ tmp_pair.second.logeq(val);
+ if (cur_id.empty()) cur_id = cache_id;
+ if (cur_id == cache_id) {
+ list->push_back(tmp_pair);
+ *sent_id = cur_id;
+ tmp_pair.first.clear();
} else {
- (*list)[it->second].second += cache_pair.second;
- // cerr << "Cruch: " << line << "\n newp=" << (*list)[it->second].second << endl;
+ swap(cache_pair[fi], tmp_pair);
+ break;
}
- *sent_id = cur_id;
- cache_pair.first.clear();
- } else { break; }
+ }
}
sort(list->begin(), list->end(), ScoreComparer());
+ // for (unsigned i = 0; i < list->size(); ++i) {
+ // cerr << TD::GetString((*list)[i].first) << " ||| " << (*list)[i].second << endl;
+ //}
+ //cerr << endl;
return !list->empty();
}
@@ -102,14 +110,34 @@ int main(int argc, char** argv) {
const bool is_loss = (UppercaseString(smetric) == "TER");
const bool output_list = conf.count("output_list") > 0;
- const string file = conf["input"].as<string>();
- const double mbr_scale = conf["scale"].as<double>();
- cerr << "Posterior scaling factor (alpha) = " << mbr_scale << endl;
+ vector<string> file;
+ if (conf.count("input") == 0)
+ file.push_back("-");
+ else
+ file = conf["input"].as<vector<string> >();
+ vector<double> mbr_scale;
+ if (conf.count("scale")) mbr_scale = conf["scale"].as<vector<double> >();
+ vector<double> mbr_offset;
+ if (conf.count("offset")) mbr_offset = conf["offset"].as<vector<double> >();
+ if (file.size() > mbr_scale.size()) mbr_scale.resize(file.size(), 1.0);
+ if (file.size() > mbr_offset.size()) mbr_offset.resize(file.size(), 0.0);
+ if (file.size() != mbr_scale.size()) {
+ cerr << file.size() << " files specified but " << mbr_scale.size() << " scale factors given!\n";
+ return 1;
+ }
+ if (file.size() != mbr_offset.size()) {
+ cerr << file.size() << " files specified but " << mbr_offset.size() << " scale factors given!\n";
+ return 1;
+ }
+ for (unsigned i = 0; i < file.size(); ++i)
+ cerr << "Kbest file " << (i+1) << ": " << file[i] << "\t(scale=" << mbr_scale[i] << ", offset=" << mbr_offset[i] << ")\n";
vector<pair<vector<WordID>, prob_t> > list;
- ReadFile rf(file);
+ vector<ReadFile*> rfs(file.size());
+ for (unsigned i = 0; i < file.size(); ++i)
+ rfs[i] = new ReadFile(file[i]);
string sent_id;
- while(ReadKBestList(mbr_scale, rf.stream(), &sent_id, &list)) {
+ while(ReadKBestList(mbr_scale, mbr_offset, rfs, &sent_id, &list)) {
vector<prob_t> joints(list.size());
const prob_t max_score = list.front().second;
prob_t marginal = prob_t::Zero();