summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2015-05-08 21:43:45 +0200
committerPatrick Simianer <p@simianer.de>2015-05-08 21:43:45 +0200
commitd0b8fa29b83e6424e6d5848dbc42734b03896304 (patch)
treeb1f2529183e99b92b20f972d3c5e6739ad855adf
parentf678b442e8a0c2e685652d2b7006ccccce989c81 (diff)
parent64aac199c4a8821772dfaaaa9d162f4a3f5bf121 (diff)
Merge branch 'net' of github.com:pks/cdec-dtrain into net
-rw-r--r--.gitignore1
-rw-r--r--training/dtrain/Makefile.am5
-rw-r--r--training/dtrain/dtrain_net.cc23
-rw-r--r--training/dtrain/dtrain_net_interface.cc120
-rwxr-xr-xtraining/dtrain/feed1.rb24
-rw-r--r--training/dtrain/sample_net.h61
6 files changed, 223 insertions, 11 deletions
diff --git a/.gitignore b/.gitignore
index 0bfe57c2..c4b26bf5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -217,6 +217,7 @@ training/dpmert/sentclient
training/dpmert/sentserver
training/dtrain/dtrain
training/dtrain/dtrain_net
+training/dtrain/dtrain_net_interface
training/latent_svm/latent_svm
training/minrisk/minrisk_optimize
training/mira/ada_opt_sm
diff --git a/training/dtrain/Makefile.am b/training/dtrain/Makefile.am
index 590218be..d93597c9 100644
--- a/training/dtrain/Makefile.am
+++ b/training/dtrain/Makefile.am
@@ -1,4 +1,4 @@
-bin_PROGRAMS = dtrain dtrain_net
+bin_PROGRAMS = dtrain dtrain_net dtrain_net_interface
dtrain_SOURCES = dtrain.cc dtrain.h sample.h score.h update.h
dtrain_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a
@@ -6,5 +6,8 @@ dtrain_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mte
dtrain_net_SOURCES = dtrain_net.cc dtrain_net.h dtrain.h sample.h score.h update.h
dtrain_net_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a /usr/lib64/libnanomsg.so
+dtrain_net_interface_SOURCES = dtrain_net_interface.cc dtrain_net.h dtrain.h sample_net.h score.h update.h
+dtrain_net_interface_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a /usr/lib64/libnanomsg.so
+
AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval
diff --git a/training/dtrain/dtrain_net.cc b/training/dtrain/dtrain_net.cc
index 946b7587..306da957 100644
--- a/training/dtrain/dtrain_net.cc
+++ b/training/dtrain/dtrain_net.cc
@@ -67,16 +67,19 @@ main(int argc, char** argv)
} else {
vector<string> parts;
boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| "));
- source = parts[0];
- parts.erase(parts.begin());
- for (auto s: parts) {
- vector<WordID> r;
- vector<string> toks;
- boost::split(toks, s, boost::is_any_of(" "));
- for (auto tok: toks)
- r.push_back(TD::Convert(tok));
- refs.emplace_back(MakeNgrams(r, N));
- rsz.push_back(r.size());
+ if (parts[0] == "act:translate") {
+ } else {
+ source = parts[0];
+ parts.erase(parts.begin());
+ for (auto s: parts) {
+ vector<WordID> r;
+ vector<string> toks;
+ boost::split(toks, s, boost::is_any_of(" "));
+ for (auto tok: toks)
+ r.push_back(TD::Convert(tok));
+ refs.emplace_back(MakeNgrams(r, N));
+ rsz.push_back(r.size());
+ }
}
}
}
diff --git a/training/dtrain/dtrain_net_interface.cc b/training/dtrain/dtrain_net_interface.cc
new file mode 100644
index 00000000..f484b56b
--- /dev/null
+++ b/training/dtrain/dtrain_net_interface.cc
@@ -0,0 +1,120 @@
+#include "dtrain_net.h"
+#include "sample_net.h"
+#include "score.h"
+#include "update.h"
+
+#include <nanomsg/nn.h>
+#include <nanomsg/pair.h>
+#include "nn.hpp"
+
+using namespace dtrain;
+
+int
+main(int argc, char** argv)
+{
+ // get configuration
+ po::variables_map conf;
+ if (!dtrain_net_init(argc, argv, &conf))
+ exit(1); // something is wrong
+ const size_t k = conf["k"].as<size_t>();
+ const size_t N = conf["N"].as<size_t>();
+ const weight_t margin = conf["margin"].as<weight_t>();
+ const string master_addr = conf["addr"].as<string>();
+
+ // setup decoder
+ register_feature_functions();
+ SetSilent(true);
+ ReadFile f(conf["decoder_conf"].as<string>());
+ Decoder decoder(f.stream());
+ ScoredKbest* observer = new ScoredKbest(k, new PerSentenceBleuScorer(N));
+
+ // weights
+ vector<weight_t>& decoder_weights = decoder.CurrentWeightVector();
+ SparseVector<weight_t> lambdas, w_average;
+ if (conf.count("input_weights")) {
+ Weights::InitFromFile(conf["input_weights"].as<string>(), &decoder_weights);
+ Weights::InitSparseVector(decoder_weights, &lambdas);
+ }
+
+ cerr << _p4;
+ // output configuration
+ cerr << "dtrain_net" << endl << "Parameters:" << endl;
+ cerr << setw(25) << "k " << k << endl;
+ cerr << setw(25) << "N " << N << endl;
+ cerr << setw(25) << "margin " << margin << endl;
+ cerr << setw(25) << "decoder conf " << "'"
+ << conf["decoder_conf"].as<string>() << "'" << endl;
+
+ // socket
+ nn::socket sock(AF_SP, NN_PAIR);
+ sock.connect(master_addr.c_str());
+
+ size_t i = 0;
+ while(true)
+ {
+ char *buf = NULL;
+ string source;
+ vector<Ngrams> refs;
+ vector<size_t> rsz;
+ bool next = true;
+ size_t sz = sock.recv(&buf, NN_MSG, 0);
+ if (buf) {
+ const string in(buf, buf+sz);
+ nn::freemsg(buf);
+ if (in == "shutdown") {
+ next = false;
+ } else {
+ vector<string> parts;
+ boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| "));
+ if (parts[0] == "act:translate") {
+ cerr << "translating ..." << endl;
+ lambdas.init_vector(&decoder_weights);
+ observer->dont_score = true;
+ decoder.Decode(parts[1], observer);
+ observer->dont_score = false;
+ vector<ScoredHyp>* samples = observer->GetSamples();
+ ostringstream os;
+ PrintWordIDVec((*samples)[0].w, os);
+ sock.send(os.str().c_str(), os.str().size()+1, 0);
+ cerr << "done" << endl;
+ continue;
+ } else {
+ cerr << "learning ..." << endl;
+ source = parts[0];
+ parts.erase(parts.begin());
+ for (auto s: parts) {
+ vector<WordID> r;
+ vector<string> toks;
+ boost::split(toks, s, boost::is_any_of(" "));
+ for (auto tok: toks)
+ r.push_back(TD::Convert(tok));
+ refs.emplace_back(MakeNgrams(r, N));
+ rsz.push_back(r.size());
+ }
+ }
+ }
+ }
+
+ if (!next)
+ break;
+
+ // decode
+ lambdas.init_vector(&decoder_weights);
+ observer->SetReference(refs, rsz);
+ decoder.Decode(source, observer);
+ vector<ScoredHyp>* samples = observer->GetSamples();
+
+ // get pairs and update
+ SparseVector<weight_t> updates;
+ CollectUpdates(samples, updates, margin);
+ lambdas.plus_eq_v_times_s(updates, 1.0); // fixme
+ string s = "x";
+ sock.send(s.c_str(), s.size()+1, 0);
+ i++;
+
+ cerr << "done" << endl;
+ } // input loop
+
+ return 0;
+}
+
diff --git a/training/dtrain/feed1.rb b/training/dtrain/feed1.rb
new file mode 100755
index 00000000..a76e4c1e
--- /dev/null
+++ b/training/dtrain/feed1.rb
@@ -0,0 +1,24 @@
+#!/usr/bin/env ruby
+
+require 'nanomsg'
+
+#port = ARGV[0]
+port = 60667
+sock = NanoMsg::PairSocket.new
+addr = "tcp://127.0.0.1:#{port}"
+#addr = "ipc:///tmp/xxx.ipc"
+sock.bind addr
+
+#puts sock.recv
+while true
+ line = STDIN.gets
+ if !line
+ sock.send 'shutdown'
+ break
+ end
+ sock.send line.strip
+ sleep 1
+ puts "got translation: #{sock.recv}\n\n"
+ #sock.send "a=1 b=2"
+end
+
diff --git a/training/dtrain/sample_net.h b/training/dtrain/sample_net.h
new file mode 100644
index 00000000..497149d9
--- /dev/null
+++ b/training/dtrain/sample_net.h
@@ -0,0 +1,61 @@
+#ifndef _DTRAIN_SAMPLE_NET_H_
+#define _DTRAIN_SAMPLE_NET_H_
+
+#include "kbest.h"
+
+#include "score.h"
+
+namespace dtrain
+{
+
+struct ScoredKbest : public DecoderObserver
+{
+ const size_t k_;
+ size_t feature_count_, effective_sz_;
+ vector<ScoredHyp> samples_;
+ PerSentenceBleuScorer* scorer_;
+ vector<Ngrams>* ref_ngs_;
+ vector<size_t>* ref_ls_;
+ bool dont_score;
+
+ ScoredKbest(const size_t k, PerSentenceBleuScorer* scorer) :
+ k_(k), scorer_(scorer), dont_score(false) {}
+
+ virtual void
+ NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg)
+ {
+ samples_.clear(); effective_sz_ = feature_count_ = 0;
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal,
+ KBest::FilterUnique, prob_t, EdgeProb> kbest(*hg, k_);
+ for (size_t i = 0; i < k_; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal,
+ KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d =
+ kbest.LazyKthBest(hg->nodes_.size() - 1, i);
+ if (!d) break;
+ ScoredHyp h;
+ h.w = d->yield;
+ h.f = d->feature_values;
+ h.model = log(d->score);
+ h.rank = i;
+ if (!dont_score)
+ h.gold = scorer_->Score(h.w, *ref_ngs_, *ref_ls_);
+ samples_.push_back(h);
+ effective_sz_++;
+ feature_count_ += h.f.size();
+ }
+ }
+
+ vector<ScoredHyp>* GetSamples() { return &samples_; }
+ inline void SetReference(vector<Ngrams>& ngs, vector<size_t>& ls)
+ {
+ ref_ngs_ = &ngs;
+ ref_ls_ = &ls;
+ }
+ inline size_t GetFeatureCount() { return feature_count_; }
+ inline size_t GetSize() { return effective_sz_; }
+};
+
+} // namespace
+
+#endif
+