summary refs log tree commit diff
path: root/mteval
diff options
context:
space:
mode:
author Kenneth Heafield <github@kheafield.com> 2012-10-22 12:07:20 +0100
committer Kenneth Heafield <github@kheafield.com> 2012-10-22 12:07:20 +0100
commit 5f98fe5c4f2a2090eeb9d30c030305a70a8347d1 (patch)
tree 9b6002f850e6dea1e3400c6b19bb31a9cdf3067f /mteval
parent cf9994131993b40be62e90e213b1e11e6b550143 (diff)
parent 21825a09d97c2e0afd20512f306fb25fed55e529 (diff)
Merge remote branch 'upstream/master'
Conflicts: Jamroot bjam decoder/Jamfile decoder/cdec.cc dpmert/Jamfile jam-files/sanity.jam klm/lm/Jamfile klm/util/Jamfile mira/Jamfile
Diffstat (limited to 'mteval')
-rw-r--r-- mteval/Jamfile 8
-rw-r--r-- mteval/mbr_kbest.cc 38
-rw-r--r-- mteval/ns_docscorer.cc 4
-rw-r--r-- mteval/ns_docscorer.h 2
4 files changed, 31 insertions, 21 deletions
diff --git a/mteval/Jamfile b/mteval/Jamfile
deleted file mode 100644
index 3ed2c2cc..00000000
--- a/mteval/Jamfile
+++ /dev/null
@@ -1,8 +0,0 @@
-import testing ;
-
-lib mteval : ter.cc comb_scorer.cc aer_scorer.cc scorer.cc external_scorer.cc ns.cc ns_ter.cc ns_ext.cc ns_comb.cc ns_docscorer.cc ns_cer.cc ..//utils : <include>. : : <include>. <library>..//z ;
-exe fast_score : fast_score.cc mteval ..//utils ..//boost_program_options ;
-exe mbr_kbest : mbr_kbest.cc mteval ..//utils ..//boost_program_options ;
-alias programs : fast_score mbr_kbest ;
-
-unit-test scorer_test : scorer_test.cc mteval ..//utils ..//z ..//boost_unit_test_framework : <testing.arg>$(TOP)/mteval/test_data ;
diff --git a/mteval/mbr_kbest.cc b/mteval/mbr_kbest.cc
index 2bd31566..2519bc01 100644
--- a/mteval/mbr_kbest.cc
+++ b/mteval/mbr_kbest.cc
@@ -1,7 +1,9 @@
#include <iostream>
#include <vector>
+#include <tr1/unordered_map>
#include <boost/program_options.hpp>
+#include <boost/functional/hash.hpp>
#include "prob.h"
#include "tdict.h"
@@ -10,6 +12,7 @@
#include "stringlib.h"
using namespace std;
+using namespace std::tr1;
namespace po = boost::program_options;
@@ -31,27 +34,33 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
}
}
+struct ScoreComparer {
+ bool operator()(const pair<vector<WordID>, prob_t>& a, const pair<vector<WordID>, prob_t>& b) const {
+ return a.second > b.second;
+ }
+};
+
struct LossComparer {
bool operator()(const pair<vector<WordID>, prob_t>& a, const pair<vector<WordID>, prob_t>& b) const {
return a.second < b.second;
}
};
-bool ReadKBestList(istream* in, string* sent_id, vector<pair<vector<WordID>, prob_t> >* list) {
+bool ReadKBestList(const double mbr_scale, istream* in, string* sent_id, vector<pair<vector<WordID>, prob_t> >* list) {
static string cache_id;
static pair<vector<WordID>, prob_t> cache_pair;
list->clear();
string cur_id;
+ unordered_map<vector<WordID>, unsigned, boost::hash<vector<WordID> > > sent2id;
if (cache_pair.first.size() > 0) {
list->push_back(cache_pair);
+ sent2id[cache_pair.first] = 0;
cur_id = cache_id;
cache_pair.first.clear();
}
string line;
string tstr;
- while(*in) {
- getline(*in, line);
- if (line.empty()) continue;
+ while(getline(*in, line)) {
size_t p1 = line.find(" ||| ");
if (p1 == string::npos) { cerr << "Bad format: " << line << endl; abort(); }
size_t p2 = line.find(" ||| ", p1 + 4);
@@ -59,16 +68,25 @@ bool ReadKBestList(istream* in, string* sent_id, vector<pair<vector<WordID>, pro
size_t p3 = line.rfind(" ||| ");
cache_id = line.substr(0, p1);
tstr = line.substr(p1 + 5, p2 - p1 - 5);
- double val = strtod(line.substr(p3 + 5).c_str(), NULL);
+ double val = strtod(line.substr(p3 + 5).c_str(), NULL) * mbr_scale;
TD::ConvertSentence(tstr, &cache_pair.first);
cache_pair.second.logeq(val);
if (cur_id.empty()) cur_id = cache_id;
if (cur_id == cache_id) {
- list->push_back(cache_pair);
+ unordered_map<vector<WordID>, unsigned, boost::hash<vector<WordID> > >::iterator it =
+ sent2id.find(cache_pair.first);
+ if (it == sent2id.end()) {
+ sent2id.insert(make_pair(cache_pair.first, unsigned(list->size())));
+ list->push_back(cache_pair);
+ } else {
+ (*list)[it->second].second += cache_pair.second;
+ // cerr << "Cruch: " << line << "\n newp=" << (*list)[it->second].second << endl;
+ }
*sent_id = cur_id;
cache_pair.first.clear();
} else { break; }
}
+ sort(list->begin(), list->end(), ScoreComparer());
return !list->empty();
}
@@ -87,14 +105,14 @@ int main(int argc, char** argv) {
vector<pair<vector<WordID>, prob_t> > list;
ReadFile rf(file);
string sent_id;
- while(ReadKBestList(rf.stream(), &sent_id, &list)) {
+ while(ReadKBestList(mbr_scale, rf.stream(), &sent_id, &list)) {
vector<prob_t> joints(list.size());
- const prob_t max_score = pow(list.front().second, mbr_scale);
+ const prob_t max_score = list.front().second;
prob_t marginal = prob_t::Zero();
for (int i = 0 ; i < list.size(); ++i) {
- const prob_t joint = pow(list[i].second, mbr_scale) / max_score;
+ const prob_t joint = list[i].second / max_score;
joints[i] = joint;
- // cerr << "list[" << i << "] joint=" << log(joint) << endl;
+ //cerr << "list[" << i << "] joint=" << log(joint) << endl;
marginal += joint;
}
int mbr_idx = -1;
diff --git a/mteval/ns_docscorer.cc b/mteval/ns_docscorer.cc
index 28a2fd09..f72ad115 100644
--- a/mteval/ns_docscorer.cc
+++ b/mteval/ns_docscorer.cc
@@ -16,7 +16,7 @@ void DocumentScorer::Init(const EvaluationMetric* metric,
const string& src_file,
bool verbose) {
scorers_.clear();
- cerr << "Loading references (" << ref_files.size() << " files)\n";
+ if (verbose) cerr << "Loading references (" << ref_files.size() << " files)\n";
assert(src_file.empty());
std::vector<ReadFile> ifs(ref_files.begin(),ref_files.end());
for (int i=0; i < ref_files.size(); ++i) ifs[i].Init(ref_files[i]);
@@ -55,6 +55,6 @@ void DocumentScorer::Init(const EvaluationMetric* metric,
++line;
}
}
- cerr << "Loaded reference translations for " << scorers_.size() << " sentences.\n";
+ if (verbose) cerr << "Loaded reference translations for " << scorers_.size() << " sentences.\n";
}
diff --git a/mteval/ns_docscorer.h b/mteval/ns_docscorer.h
index 170ac627..a5757258 100644
--- a/mteval/ns_docscorer.h
+++ b/mteval/ns_docscorer.h
@@ -5,7 +5,7 @@
#include <string>
#include <boost/shared_ptr.hpp>
-struct EvaluationMetric;
+class EvaluationMetric;
struct SegmentEvaluator;
class DocumentScorer {
public: