summary refs log tree commit diff
diff options
context:
space:
mode:
author Chris Dyer <cdyer@cs.cmu.edu> 2012-10-01 22:01:07 -0400
committer Chris Dyer <cdyer@cs.cmu.edu> 2012-10-01 22:01:07 -0400
commit 0870d4a1f5e14cc7daf553b180d599f09f6614a2 (patch)
tree 74aae175745f5895d45d997d291488107275bb64
parent 38ae810fe374ff7fb548b1a15f7c2ee7dcd94000 (diff)
mbr fix for non-deduped lists
-rw-r--r-- mteval/mbr_kbest.cc 38
1 files changed, 28 insertions, 10 deletions
diff --git a/mteval/mbr_kbest.cc b/mteval/mbr_kbest.cc
index 2bd31566..2519bc01 100644
--- a/mteval/mbr_kbest.cc
+++ b/mteval/mbr_kbest.cc
@@ -1,7 +1,9 @@
#include <iostream>
#include <vector>
+#include <tr1/unordered_map>
#include <boost/program_options.hpp>
+#include <boost/functional/hash.hpp>
#include "prob.h"
#include "tdict.h"
@@ -10,6 +12,7 @@
#include "stringlib.h"
using namespace std;
+using namespace std::tr1;
namespace po = boost::program_options;
@@ -31,27 +34,33 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
}
}
+struct ScoreComparer {  // comparator on (word-id sequence, prob_t) pairs; looks only at .second
+ bool operator()(const pair<vector<WordID>, prob_t>& a, const pair<vector<WordID>, prob_t>& b) const {
+ return a.second > b.second;  // "greater-than" ordering, so sorting with it puts the largest .second first
+ }
+};
+
struct LossComparer {  // comparator on (word-id sequence, prob_t) pairs; looks only at .second
bool operator()(const pair<vector<WordID>, prob_t>& a, const pair<vector<WordID>, prob_t>& b) const {
return a.second < b.second;  // strict "less-than" on .second (the mirror image of ScoreComparer)
}
};
-bool ReadKBestList(istream* in, string* sent_id, vector<pair<vector<WordID>, prob_t> >* list) {
+bool ReadKBestList(const double mbr_scale, istream* in, string* sent_id, vector<pair<vector<WordID>, prob_t> >* list) {  // reads the run of lines sharing one id into *list, merging repeated hypotheses; returns false when nothing was read
static string cache_id;  // static: keeps the id of the last parsed line between calls
static pair<vector<WordID>, prob_t> cache_pair;  // static: holds the entry that ended the previous call's list, if any
list->clear();
string cur_id;
+ unordered_map<vector<WordID>, unsigned, boost::hash<vector<WordID> > > sent2id;  // word-id sequence -> index of its entry in *list
if (cache_pair.first.size() > 0) {  // an entry was left over from the previous call: it starts this list
list->push_back(cache_pair);
+ sent2id[cache_pair.first] = 0;  // the carried-over entry was just pushed at index 0
cur_id = cache_id;
cache_pair.first.clear();
}
string line;
string tstr;
- while(*in) {
- getline(*in, line);
- if (line.empty()) continue;
+ while(getline(*in, line)) {  // stops when getline fails; NOTE(review): empty lines are no longer skipped and reach the "Bad format" abort below
size_t p1 = line.find(" ||| ");  // first separator: the id is everything before it
if (p1 == string::npos) { cerr << "Bad format: " << line << endl; abort(); }
size_t p2 = line.find(" ||| ", p1 + 4);  // second separator: the hypothesis text lies between p1 and p2
@@ -59,16 +68,25 @@ bool ReadKBestList(istream* in, string* sent_id, vector<pair<vector<WordID>, pro
size_t p3 = line.rfind(" ||| ");  // last separator: the score is whatever follows it
cache_id = line.substr(0, p1);
tstr = line.substr(p1 + 5, p2 - p1 - 5);
- double val = strtod(line.substr(p3 + 5).c_str(), NULL);
+ double val = strtod(line.substr(p3 + 5).c_str(), NULL) * mbr_scale;  // scaling is applied to the parsed number before it is stored
TD::ConvertSentence(tstr, &cache_pair.first);
cache_pair.second.logeq(val);  // presumably stores val as a log-domain value; confirm against prob_t
if (cur_id.empty()) cur_id = cache_id;  // first line seen by this call fixes the id being collected
if (cur_id == cache_id) {
- list->push_back(cache_pair);
+ unordered_map<vector<WordID>, unsigned, boost::hash<vector<WordID> > >::iterator it =
+ sent2id.find(cache_pair.first);
+ if (it == sent2id.end()) {  // hypothesis not seen yet for this id
+ sent2id.insert(make_pair(cache_pair.first, unsigned(list->size())));  // index it will occupy after the push_back below
+ list->push_back(cache_pair);
+ } else {  // repeated hypothesis: fold its prob_t into the existing entry instead of adding a new one
+ (*list)[it->second].second += cache_pair.second;
+ // cerr << "Cruch: " << line << "\n newp=" << (*list)[it->second].second << endl;
+ }
*sent_id = cur_id;
cache_pair.first.clear();  // consumed; an empty .first marks "nothing carried over" for the next call
} else { break; }  // id changed: cache_id/cache_pair keep this line's entry for the next call
}
+ sort(list->begin(), list->end(), ScoreComparer());  // largest .second first, so list->front() is the top-scoring entry
return !list->empty();
}
@@ -87,14 +105,14 @@ int main(int argc, char** argv) {
vector<pair<vector<WordID>, prob_t> > list;
ReadFile rf(file);
string sent_id;
- while(ReadKBestList(rf.stream(), &sent_id, &list)) {
+ while(ReadKBestList(mbr_scale, rf.stream(), &sent_id, &list)) {
vector<prob_t> joints(list.size());
- const prob_t max_score = pow(list.front().second, mbr_scale);
+ const prob_t max_score = list.front().second;
prob_t marginal = prob_t::Zero();
for (int i = 0 ; i < list.size(); ++i) {
- const prob_t joint = pow(list[i].second, mbr_scale) / max_score;
+ const prob_t joint = list[i].second / max_score;
joints[i] = joint;
- // cerr << "list[" << i << "] joint=" << log(joint) << endl;
+ //cerr << "list[" << i << "] joint=" << log(joint) << endl;
marginal += joint;
}
int mbr_idx = -1;