diff options
| author | Patrick Simianer <p@simianer.de> | 2014-02-28 14:09:55 +0100 | 
|---|---|---|
| committer | Patrick Simianer <p@simianer.de> | 2014-02-28 14:09:55 +0100 | 
| commit | 1abb6039dc6f86f930f9cb1ace048cb0fbf16b0b (patch) | |
| tree | b396b79a1cba971ede206ef7e0054573fc050455 /mteval | |
| parent | bb5b6464826c765f4795381830acae158987f46b (diff) | |
| parent | 324a978e7d766a3864e42efc0938fb6c9ef5a01c (diff) | |
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'mteval')
| -rw-r--r-- | mteval/mbr_kbest.cc | 102 | 
1 files changed, 65 insertions, 37 deletions
diff --git a/mteval/mbr_kbest.cc b/mteval/mbr_kbest.cc index 76d2c7fc..eb36e009 100644 --- a/mteval/mbr_kbest.cc +++ b/mteval/mbr_kbest.cc @@ -23,9 +23,10 @@ namespace po = boost::program_options;  void InitCommandLine(int argc, char** argv, po::variables_map* conf) {    po::options_description opts("Configuration options");    opts.add_options() -        ("scale,a",po::value<double>()->default_value(1.0), "Posterior scaling factor (alpha)") +        ("input,i",po::value<vector<string> >(), "Files to read k-best lists from") +        ("scale,a",po::value<vector<double> >(), "Posterior scaling factors (per file)") +        ("offset,b",po::value<vector<double> >(), "Log posterior offsets (per file)")          ("evaluation_metric,m",po::value<string>()->default_value("ibm_bleu"), "Evaluation metric") -        ("input,i",po::value<string>()->default_value("-"), "File to read k-best lists from")          ("output_list,L", "Show reranked list as output")          ("help,h", "Help");    po::options_description dcmdline_options; @@ -50,47 +51,54 @@ struct LossComparer {    }  }; -bool ReadKBestList(const double mbr_scale, istream* in, string* sent_id, vector<pair<vector<WordID>, prob_t> >* list) { +bool ReadKBestList(const vector<double>& mbr_scale, +                   const vector<double>& mbr_offset, +                   const vector<ReadFile*>& rfs, +                   string* sent_id, +                   vector<pair<vector<WordID>, prob_t> >* list) {    static string cache_id; -  static pair<vector<WordID>, prob_t> cache_pair; +  pair<vector<WordID>, prob_t> tmp_pair; +  static vector<pair<vector<WordID>, prob_t> > cache_pair(rfs.size());    list->clear();    string cur_id; -  unordered_map<vector<WordID>, unsigned, boost::hash<vector<WordID> > > sent2id; -  if (cache_pair.first.size() > 0) { -    list->push_back(cache_pair); -    sent2id[cache_pair.first] = 0; +  if (cache_pair[0].first.size() > 0) { +    for (unsigned i = 0; i < cache_pair.size(); ++i) +      list->push_back(cache_pair[i]);      cur_id = cache_id; -    cache_pair.first.clear(); +    cache_pair.clear(); +    cache_pair.resize(rfs.size());    }    string line;    string tstr; -  while(getline(*in, line)) { -    size_t p1 = line.find(" ||| "); -    if (p1 == string::npos) { cerr << "Bad format: " << line << endl; abort(); } -    size_t p2 = line.find(" ||| ", p1 + 4); -    if (p2 == string::npos) { cerr << "Bad format: " << line << endl; abort(); } -    size_t p3 = line.rfind(" ||| "); -    cache_id = line.substr(0, p1); -    tstr = line.substr(p1 + 5, p2 - p1 - 5); -    double val = strtod(line.substr(p3 + 5).c_str(), NULL) * mbr_scale; -    TD::ConvertSentence(tstr, &cache_pair.first); -    cache_pair.second.logeq(val); -    if (cur_id.empty()) cur_id = cache_id; -    if (cur_id == cache_id) { -      unordered_map<vector<WordID>, unsigned, boost::hash<vector<WordID> > >::iterator it = -        sent2id.find(cache_pair.first); -      if (it == sent2id.end()) { -        sent2id.insert(make_pair(cache_pair.first, unsigned(list->size()))); -        list->push_back(cache_pair); +  for (unsigned fi = 0; fi < rfs.size(); ++fi) { +    istream& in = *rfs[fi]->stream(); +    while(getline(in, line)) { +      size_t p1 = line.find(" ||| "); +      if (p1 == string::npos) { cerr << "Bad format: " << line << endl; abort(); } +      size_t p2 = line.find(" ||| ", p1 + 4); +      if (p2 == string::npos) { cerr << "Bad format: " << line << endl; abort(); } +      size_t p3 = line.rfind(" ||| "); +      cache_id = line.substr(0, p1); +      tstr = line.substr(p1 + 5, p2 - p1 - 5); +      double val = strtod(line.substr(p3 + 5).c_str(), NULL) * mbr_scale[fi] + mbr_offset[fi]; +      TD::ConvertSentence(tstr, &tmp_pair.first); +      tmp_pair.second.logeq(val); +      if (cur_id.empty()) cur_id = cache_id; +      if (cur_id == cache_id) { +        list->push_back(tmp_pair); +        *sent_id = cur_id; +        tmp_pair.first.clear();        } else { -        (*list)[it->second].second += cache_pair.second; -        // cerr << "Cruch: " << line << "\n newp=" << (*list)[it->second].second << endl; +        swap(cache_pair[fi], tmp_pair); +        break;        } -      *sent_id = cur_id; -      cache_pair.first.clear(); -    } else { break; } +    }    }    sort(list->begin(), list->end(), ScoreComparer()); +  // for (unsigned i = 0; i < list->size(); ++i) { +  //  cerr << TD::GetString((*list)[i].first) << " ||| " << (*list)[i].second << endl; +  //} +  //cerr << endl;    return !list->empty();  } @@ -102,14 +110,34 @@ int main(int argc, char** argv) {    const bool is_loss = (UppercaseString(smetric) == "TER");    const bool output_list = conf.count("output_list") > 0; -  const string file = conf["input"].as<string>(); -  const double mbr_scale = conf["scale"].as<double>(); -  cerr << "Posterior scaling factor (alpha) = " << mbr_scale << endl; +  vector<string> file; +  if (conf.count("input") == 0) +    file.push_back("-"); +  else +    file = conf["input"].as<vector<string> >(); +  vector<double> mbr_scale; +  if (conf.count("scale")) mbr_scale = conf["scale"].as<vector<double> >(); +  vector<double> mbr_offset; +  if (conf.count("offset")) mbr_offset = conf["offset"].as<vector<double> >(); +  if (file.size() > mbr_scale.size()) mbr_scale.resize(file.size(), 1.0); +  if (file.size() > mbr_offset.size()) mbr_offset.resize(file.size(), 0.0); +  if (file.size() != mbr_scale.size()) { +    cerr << file.size() << " files specified but " << mbr_scale.size() << " scale factors given!\n"; +    return 1; +  } +  if (file.size() != mbr_offset.size()) { +    cerr << file.size() << " files specified but " << mbr_offset.size() << " scale factors given!\n"; +    return 1; +  } +  for (unsigned i = 0; i < file.size(); ++i) +    cerr << "Kbest file " << (i+1) << ": " << file[i] << "\t(scale=" << mbr_scale[i] << ", offset=" << mbr_offset[i] << ")\n";    vector<pair<vector<WordID>, prob_t> > list; -  ReadFile rf(file); +  vector<ReadFile*> rfs(file.size()); +  for (unsigned i = 0; i < file.size(); ++i) +    rfs[i] = new ReadFile(file[i]);    string sent_id; -  while(ReadKBestList(mbr_scale, rf.stream(), &sent_id, &list)) { +  while(ReadKBestList(mbr_scale, mbr_offset, rfs, &sent_id, &list)) {      vector<prob_t> joints(list.size());      const prob_t max_score = list.front().second;      prob_t marginal = prob_t::Zero();  | 
