From 64aac199c4a8821772dfaaaa9d162f4a3f5bf121 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Thu, 2 Apr 2015 14:06:02 +0200 Subject: enable translation or learning: dtrain_net_interface --- training/dtrain/Makefile.am | 5 +- training/dtrain/dtrain_net.cc | 23 +++--- training/dtrain/dtrain_net_interface.cc | 120 ++++++++++++++++++++++++++++++++ training/dtrain/feed1.rb | 24 +++++++ training/dtrain/sample_net.h | 61 ++++++++++++++++ 5 files changed, 222 insertions(+), 11 deletions(-) create mode 100644 training/dtrain/dtrain_net_interface.cc create mode 100755 training/dtrain/feed1.rb create mode 100644 training/dtrain/sample_net.h (limited to 'training') diff --git a/training/dtrain/Makefile.am b/training/dtrain/Makefile.am index 590218be..d93597c9 100644 --- a/training/dtrain/Makefile.am +++ b/training/dtrain/Makefile.am @@ -1,4 +1,4 @@ -bin_PROGRAMS = dtrain dtrain_net +bin_PROGRAMS = dtrain dtrain_net dtrain_net_interface dtrain_SOURCES = dtrain.cc dtrain.h sample.h score.h update.h dtrain_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a @@ -6,5 +6,8 @@ dtrain_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mte dtrain_net_SOURCES = dtrain_net.cc dtrain_net.h dtrain.h sample.h score.h update.h dtrain_net_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a /usr/lib64/libnanomsg.so +dtrain_net_interface_SOURCES = dtrain_net_interface.cc dtrain_net.h dtrain.h sample_net.h score.h update.h +dtrain_net_interface_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a /usr/lib64/libnanomsg.so + AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/training/dtrain/dtrain_net.cc b/training/dtrain/dtrain_net.cc index 946b7587..306da957 100644 --- a/training/dtrain/dtrain_net.cc +++ b/training/dtrain/dtrain_net.cc @@ -67,16 +67,19 @@ main(int argc, char** argv) } else { vector parts; boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| ")); - source = parts[0]; - parts.erase(parts.begin()); - for (auto s: parts) { - vector r; - vector toks; - boost::split(toks, s, boost::is_any_of(" ")); - for (auto tok: toks) - r.push_back(TD::Convert(tok)); - refs.emplace_back(MakeNgrams(r, N)); - rsz.push_back(r.size()); + if (parts[0] == "act:translate") { + } else { + source = parts[0]; + parts.erase(parts.begin()); + for (auto s: parts) { + vector r; + vector toks; + boost::split(toks, s, boost::is_any_of(" ")); + for (auto tok: toks) + r.push_back(TD::Convert(tok)); + refs.emplace_back(MakeNgrams(r, N)); + rsz.push_back(r.size()); + } } } } diff --git a/training/dtrain/dtrain_net_interface.cc b/training/dtrain/dtrain_net_interface.cc new file mode 100644 index 00000000..f484b56b --- /dev/null +++ b/training/dtrain/dtrain_net_interface.cc @@ -0,0 +1,120 @@ +#include "dtrain_net.h" +#include "sample_net.h" +#include "score.h" +#include "update.h" + +#include +#include +#include "nn.hpp" + +using namespace dtrain; + +int +main(int argc, char** argv) +{ + // get configuration + po::variables_map conf; + if (!dtrain_net_init(argc, argv, &conf)) + exit(1); // something is wrong + const size_t k = conf["k"].as(); + const size_t N = conf["N"].as(); + const weight_t margin = conf["margin"].as(); + const string master_addr = conf["addr"].as(); + + // setup decoder + register_feature_functions(); + SetSilent(true); + ReadFile f(conf["decoder_conf"].as()); + Decoder decoder(f.stream()); + ScoredKbest* observer = new ScoredKbest(k, new PerSentenceBleuScorer(N)); + + // weights + vector& decoder_weights = decoder.CurrentWeightVector(); + SparseVector lambdas, w_average; + if (conf.count("input_weights")) { + Weights::InitFromFile(conf["input_weights"].as(), &decoder_weights); + Weights::InitSparseVector(decoder_weights, &lambdas); + } + + cerr << _p4; + // output configuration + cerr << "dtrain_net" << endl << "Parameters:" << endl; + cerr << setw(25) << "k " << k << endl; + cerr << setw(25) << "N " << N << endl; + cerr << setw(25) << "margin " << margin << endl; + cerr << setw(25) << "decoder conf " << "'" + << conf["decoder_conf"].as() << "'" << endl; + + // socket + nn::socket sock(AF_SP, NN_PAIR); + sock.connect(master_addr.c_str()); + + size_t i = 0; + while(true) + { + char *buf = NULL; + string source; + vector refs; + vector rsz; + bool next = true; + size_t sz = sock.recv(&buf, NN_MSG, 0); + if (buf) { + const string in(buf, buf+sz); + nn::freemsg(buf); + if (in == "shutdown") { + next = false; + } else { + vector parts; + boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| ")); + if (parts[0] == "act:translate") { + cerr << "translating ..." << endl; + lambdas.init_vector(&decoder_weights); + observer->dont_score = true; + decoder.Decode(parts[1], observer); + observer->dont_score = false; + vector* samples = observer->GetSamples(); + ostringstream os; + PrintWordIDVec((*samples)[0].w, os); + sock.send(os.str().c_str(), os.str().size()+1, 0); + cerr << "done" << endl; + continue; + } else { + cerr << "learning ..." << endl; + source = parts[0]; + parts.erase(parts.begin()); + for (auto s: parts) { + vector r; + vector toks; + boost::split(toks, s, boost::is_any_of(" ")); + for (auto tok: toks) + r.push_back(TD::Convert(tok)); + refs.emplace_back(MakeNgrams(r, N)); + rsz.push_back(r.size()); + } + } + } + } + + if (!next) + break; + + // decode + lambdas.init_vector(&decoder_weights); + observer->SetReference(refs, rsz); + decoder.Decode(source, observer); + vector* samples = observer->GetSamples(); + + // get pairs and update + SparseVector updates; + CollectUpdates(samples, updates, margin); + lambdas.plus_eq_v_times_s(updates, 1.0); // fixme + string s = "x"; + sock.send(s.c_str(), s.size()+1, 0); + i++; + + cerr << "done" << endl; + } // input loop + + return 0; +} + diff --git a/training/dtrain/feed1.rb b/training/dtrain/feed1.rb new file mode 100755 index 00000000..a76e4c1e --- /dev/null +++ b/training/dtrain/feed1.rb @@ -0,0 +1,24 @@ +#!/usr/bin/env ruby + +require 'nanomsg' + +#port = ARGV[0] +port = 60667 +sock = NanoMsg::PairSocket.new +addr = "tcp://127.0.0.1:#{port}" +#addr = "ipc:///tmp/xxx.ipc" +sock.bind addr + +#puts sock.recv +while true + line = STDIN.gets + if !line + sock.send 'shutdown' + break + end + sock.send line.strip + sleep 1 + puts "got translation: #{sock.recv}\n\n" + #sock.send "a=1 b=2" +end + diff --git a/training/dtrain/sample_net.h b/training/dtrain/sample_net.h new file mode 100644 index 00000000..497149d9 --- /dev/null +++ b/training/dtrain/sample_net.h @@ -0,0 +1,61 @@ +#ifndef _DTRAIN_SAMPLE_NET_H_ +#define _DTRAIN_SAMPLE_NET_H_ + +#include "kbest.h" + +#include "score.h" + +namespace dtrain +{ + +struct ScoredKbest : public DecoderObserver +{ + const size_t k_; + size_t feature_count_, effective_sz_; + vector samples_; + PerSentenceBleuScorer* scorer_; + vector* ref_ngs_; + vector* ref_ls_; + bool dont_score; + + ScoredKbest(const size_t k, PerSentenceBleuScorer* scorer) : + k_(k), scorer_(scorer), dont_score(false) {} + + virtual void + NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) + { + samples_.clear(); effective_sz_ = feature_count_ = 0; + KBest::KBestDerivations, ESentenceTraversal, + KBest::FilterUnique, prob_t, EdgeProb> kbest(*hg, k_); + for (size_t i = 0; i < k_; ++i) { + const KBest::KBestDerivations, ESentenceTraversal, + KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d = + kbest.LazyKthBest(hg->nodes_.size() - 1, i); + if (!d) break; + ScoredHyp h; + h.w = d->yield; + h.f = d->feature_values; + h.model = log(d->score); + h.rank = i; + if (!dont_score) + h.gold = scorer_->Score(h.w, *ref_ngs_, *ref_ls_); + samples_.push_back(h); + effective_sz_++; + feature_count_ += h.f.size(); + } + } + + vector* GetSamples() { return &samples_; } + inline void SetReference(vector& ngs, vector& ls) + { + ref_ngs_ = &ngs; + ref_ls_ = &ls; + } + inline size_t GetFeatureCount() { return feature_count_; } + inline size_t GetSize() { return effective_sz_; } +}; + +} // namespace + +#endif + -- cgit v1.2.3