summaryrefslogtreecommitdiff
path: root/mteval/ns_ext.cc
diff options
context:
space:
mode:
Diffstat (limited to 'mteval/ns_ext.cc')
-rw-r--r--mteval/ns_ext.cc130
1 files changed, 130 insertions, 0 deletions
diff --git a/mteval/ns_ext.cc b/mteval/ns_ext.cc
new file mode 100644
index 00000000..956708af
--- /dev/null
+++ b/mteval/ns_ext.cc
@@ -0,0 +1,130 @@
+#include "ns_ext.h"
+
+#include <cstdio> // popen
+#include <cstdlib>
+#include <cstring>
+#include <unistd.h>
+#include <sstream>
+#include <iostream>
+#include <cassert>
+
+#include "stringlib.h"
+#include "tdict.h"
+
+using namespace std;
+
+struct NScoreServer {
+ NScoreServer(const std::string& cmd);
+ ~NScoreServer();
+
+ float ComputeScore(const std::vector<float>& fields);
+ void Evaluate(const std::vector<std::vector<WordID> >& refs, const std::vector<WordID>& hyp, std::vector<float>* fields);
+
+ private:
+ void RequestResponse(const std::string& request, std::string* response);
+ int p2c[2];
+ int c2p[2];
+};
+
+NScoreServer::NScoreServer(const string& cmd) {
+ cerr << "Invoking " << cmd << " ..." << endl;
+ if (pipe(p2c) < 0) { perror("pipe"); exit(1); }
+ if (pipe(c2p) < 0) { perror("pipe"); exit(1); }
+ pid_t cpid = fork();
+ if (cpid < 0) { perror("fork"); exit(1); }
+ if (cpid == 0) { // child
+ close(p2c[1]);
+ close(c2p[0]);
+ dup2(p2c[0], 0);
+ close(p2c[0]);
+ dup2(c2p[1], 1);
+ close(c2p[1]);
+ cerr << "Exec'ing from child " << cmd << endl;
+ vector<string> vargs;
+ SplitOnWhitespace(cmd, &vargs);
+ const char** cargv = static_cast<const char**>(malloc(sizeof(const char*) * vargs.size()));
+ for (unsigned i = 1; i < vargs.size(); ++i) cargv[i-1] = vargs[i].c_str();
+ cargv[vargs.size() - 1] = NULL;
+ execvp(vargs[0].c_str(), (char* const*)cargv);
+ } else { // parent
+ close(c2p[1]);
+ close(p2c[0]);
+ }
+ string dummy;
+ RequestResponse("SCORE ||| Reference initialization string . ||| Testing initialization string .", &dummy);
+ assert(dummy.size() > 0);
+ cerr << "Connection established.\n";
+}
+
+NScoreServer::~NScoreServer() {
+ // TODO close stuff, join stuff
+}
+
+float NScoreServer::ComputeScore(const vector<float>& fields) {
+ ostringstream os;
+ os << "EVAL |||";
+ for (unsigned i = 0; i < fields.size(); ++i)
+ os << ' ' << fields[i];
+ string sres;
+ RequestResponse(os.str(), &sres);
+ return strtod(sres.c_str(), NULL);
+}
+
+void NScoreServer::Evaluate(const vector<vector<WordID> >& refs, const vector<WordID>& hyp, vector<float>* fields) {
+ ostringstream os;
+ os << "SCORE";
+ for (unsigned i = 0; i < refs.size(); ++i) {
+ os << " |||";
+ for (unsigned j = 0; j < refs[i].size(); ++j) {
+ os << ' ' << TD::Convert(refs[i][j]);
+ }
+ }
+ os << " |||";
+ for (unsigned i = 0; i < hyp.size(); ++i) {
+ os << ' ' << TD::Convert(hyp[i]);
+ }
+ string sres;
+ RequestResponse(os.str(), &sres);
+ istringstream is(sres);
+ float val;
+ fields->clear();
+ while(is >> val)
+ fields->push_back(val);
+}
+
+#define MAX_BUF 16000
+
+void NScoreServer::RequestResponse(const string& request, string* response) {
+// cerr << "@SERVER: " << request << endl;
+ string x = request + "\n";
+ write(p2c[1], x.c_str(), x.size());
+ char buf[MAX_BUF];
+ size_t n = read(c2p[0], buf, MAX_BUF);
+ while (n < MAX_BUF && buf[n-1] != '\n')
+ n += read(c2p[0], &buf[n], MAX_BUF - n);
+
+ buf[n-1] = 0;
+ if (n < 2) {
+ cerr << "Malformed response: " << buf << endl;
+ }
+ *response = Trim(buf, " \t\n");
+// cerr << "@RESPONSE: '" << *response << "'\n";
+}
+
+void ExternalMetric::ComputeSufficientStatistics(const std::vector<WordID>& hyp,
+ const std::vector<std::vector<WordID> >& refs,
+ SufficientStats* out) const {
+ eval_server->Evaluate(refs, hyp, &out->fields);
+}
+
+float ExternalMetric::ComputeScore(const SufficientStats& stats) const {
+ eval_server->ComputeScore(stats.fields);
+}
+
+ExternalMetric::ExternalMetric(const string& metric_name, const std::string& command) :
+ EvaluationMetric(metric_name),
+ eval_server(new NScoreServer(command)) {}
+
+ExternalMetric::~ExternalMetric() {
+ delete eval_server;
+}