From 671c21451542e2dd20e45b4033d44d8e8735f87b Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 3 Dec 2009 16:33:55 -0500 Subject: initial check in --- vest/mr_vest_map.cc | 98 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 vest/mr_vest_map.cc (limited to 'vest/mr_vest_map.cc') diff --git a/vest/mr_vest_map.cc b/vest/mr_vest_map.cc new file mode 100644 index 00000000..80e84218 --- /dev/null +++ b/vest/mr_vest_map.cc @@ -0,0 +1,98 @@ +#include +#include +#include +#include + +#include +#include + +#include "filelib.h" +#include "stringlib.h" +#include "sparse_vector.h" +#include "scorer.h" +#include "viterbi_envelope.h" +#include "inside_outside.h" +#include "error_surface.h" +#include "hg.h" +#include "hg_io.h" + +using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("reference,r",po::value >(), "[REQD] Reference translation (tokenized text)") + ("loss_function,l",po::value()->default_value("ibm_bleu"), "Loss function being optimized") + ("help,h", "Help"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + bool flag = false; + if (!conf->count("reference")) { + cerr << "Please specify one or more references using -r \n"; + flag = true; + } + if (flag || conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +bool ReadSparseVectorString(const string& s, SparseVector* v) { + vector fields; + Tokenize(s, ';', &fields); + if (fields.empty()) return false; + for (int i = 0; i < fields.size(); ++i) { + vector pair(2); + Tokenize(fields[i], '=', &pair); + if (pair.size() != 2) { + cerr << "Error parsing vector string: " << fields[i] << endl; + return false; + } + v->set_value(FD::Convert(pair[0]), atof(pair[1].c_str())); + } + return true; +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + const string loss_function = conf["loss_function"].as(); + ScoreType type = ScoreTypeFromString(loss_function); + DocScorer ds(type, conf["reference"].as >()); + cerr << "Loaded " << ds.size() << " references for scoring with " << loss_function << endl; + Hypergraph hg; + string last_file; + while(cin) { + string line; + getline(cin, line); + if (line.empty()) continue; + istringstream is(line); + int sent_id; + string file, s_origin, s_axis; + is >> file >> sent_id >> s_origin >> s_axis; + SparseVector origin; + assert(ReadSparseVectorString(s_origin, &origin)); + SparseVector axis; + assert(ReadSparseVectorString(s_axis, &axis)); + // cerr << "File: " << file << "\nAxis: " << axis << "\n X: " << origin << endl; + if (last_file != file) { + last_file = file; + ReadFile rf(file); + HypergraphIO::ReadFromJSON(rf.stream(), &hg); + } + ViterbiEnvelopeWeightFunction wf(origin, axis); + ViterbiEnvelope ve = Inside(hg, NULL, wf); + ErrorSurface es; + ds[sent_id]->ComputeErrorSurface(ve, &es); + //cerr << "Viterbi envelope has " << ve.size() << " segments\n"; + cerr << "Error surface has " << es.size() << " segments\n"; + string val; + es.Serialize(&val); + cout << 'M' << ' ' << s_origin << ' ' << s_axis << '\t'; + B64::b64encode(val.c_str(), val.size(), &cout); + cout << endl; + } + return 0; +} -- cgit v1.2.3