From ea79e535d69f6854d01c62e3752971fb6730d8e7 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 1 Oct 2012 22:01:07 -0400 Subject: mbr fix for non-deduped lists --- mteval/mbr_kbest.cc | 38 ++++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 10 deletions(-) (limited to 'mteval/mbr_kbest.cc') diff --git a/mteval/mbr_kbest.cc b/mteval/mbr_kbest.cc index 2bd31566..2519bc01 100644 --- a/mteval/mbr_kbest.cc +++ b/mteval/mbr_kbest.cc @@ -1,7 +1,9 @@ #include #include +#include #include +#include #include "prob.h" #include "tdict.h" @@ -10,6 +12,7 @@ #include "stringlib.h" using namespace std; +using namespace std::tr1; namespace po = boost::program_options; @@ -31,27 +34,33 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { } } +struct ScoreComparer { + bool operator()(const pair, prob_t>& a, const pair, prob_t>& b) const { + return a.second > b.second; + } +}; + struct LossComparer { bool operator()(const pair, prob_t>& a, const pair, prob_t>& b) const { return a.second < b.second; } }; -bool ReadKBestList(istream* in, string* sent_id, vector, prob_t> >* list) { +bool ReadKBestList(const double mbr_scale, istream* in, string* sent_id, vector, prob_t> >* list) { static string cache_id; static pair, prob_t> cache_pair; list->clear(); string cur_id; + unordered_map, unsigned, boost::hash > > sent2id; if (cache_pair.first.size() > 0) { list->push_back(cache_pair); + sent2id[cache_pair.first] = 0; cur_id = cache_id; cache_pair.first.clear(); } string line; string tstr; - while(*in) { - getline(*in, line); - if (line.empty()) continue; + while(getline(*in, line)) { size_t p1 = line.find(" ||| "); if (p1 == string::npos) { cerr << "Bad format: " << line << endl; abort(); } size_t p2 = line.find(" ||| ", p1 + 4); @@ -59,16 +68,25 @@ bool ReadKBestList(istream* in, string* sent_id, vector, pro size_t p3 = line.rfind(" ||| "); cache_id = line.substr(0, p1); tstr = line.substr(p1 + 5, p2 - p1 - 5); - double val = strtod(line.substr(p3 + 5).c_str(), NULL); + double val = strtod(line.substr(p3 + 5).c_str(), NULL) * mbr_scale; TD::ConvertSentence(tstr, &cache_pair.first); cache_pair.second.logeq(val); if (cur_id.empty()) cur_id = cache_id; if (cur_id == cache_id) { - list->push_back(cache_pair); + unordered_map, unsigned, boost::hash > >::iterator it = + sent2id.find(cache_pair.first); + if (it == sent2id.end()) { + sent2id.insert(make_pair(cache_pair.first, unsigned(list->size()))); + list->push_back(cache_pair); + } else { + (*list)[it->second].second += cache_pair.second; + // cerr << "Cruch: " << line << "\n newp=" << (*list)[it->second].second << endl; + } *sent_id = cur_id; cache_pair.first.clear(); } else { break; } } + sort(list->begin(), list->end(), ScoreComparer()); return !list->empty(); } @@ -87,14 +105,14 @@ int main(int argc, char** argv) { vector, prob_t> > list; ReadFile rf(file); string sent_id; - while(ReadKBestList(rf.stream(), &sent_id, &list)) { + while(ReadKBestList(mbr_scale, rf.stream(), &sent_id, &list)) { vector joints(list.size()); - const prob_t max_score = pow(list.front().second, mbr_scale); + const prob_t max_score = list.front().second; prob_t marginal = prob_t::Zero(); for (int i = 0 ; i < list.size(); ++i) { - const prob_t joint = pow(list[i].second, mbr_scale) / max_score; + const prob_t joint = list[i].second / max_score; joints[i] = joint; - // cerr << "list[" << i << "] joint=" << log(joint) << endl; + //cerr << "list[" << i << "] joint=" << log(joint) << endl; marginal += joint; } int mbr_idx = -1; -- cgit v1.2.3