diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/dtrain/Makefile.am | 5 | ||||
| -rw-r--r-- | training/dtrain/dtrain_net.cc | 23 | ||||
| -rw-r--r-- | training/dtrain/dtrain_net_interface.cc | 120 | ||||
| -rwxr-xr-x | training/dtrain/feed1.rb | 24 | ||||
| -rw-r--r-- | training/dtrain/sample_net.h | 61 | 
5 files changed, 222 insertions, 11 deletions
| diff --git a/training/dtrain/Makefile.am b/training/dtrain/Makefile.am index 590218be..d93597c9 100644 --- a/training/dtrain/Makefile.am +++ b/training/dtrain/Makefile.am @@ -1,4 +1,4 @@ -bin_PROGRAMS = dtrain dtrain_net +bin_PROGRAMS = dtrain dtrain_net dtrain_net_interface  dtrain_SOURCES = dtrain.cc dtrain.h sample.h score.h update.h  dtrain_LDADD   = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a @@ -6,5 +6,8 @@ dtrain_LDADD   = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mte  dtrain_net_SOURCES = dtrain_net.cc dtrain_net.h dtrain.h sample.h score.h update.h  dtrain_net_LDADD   = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a /usr/lib64/libnanomsg.so +dtrain_net_interface_SOURCES = dtrain_net_interface.cc dtrain_net.h dtrain.h sample_net.h score.h update.h +dtrain_net_interface_LDADD   = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a /usr/lib64/libnanomsg.so +  AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/training/dtrain/dtrain_net.cc b/training/dtrain/dtrain_net.cc index 946b7587..306da957 100644 --- a/training/dtrain/dtrain_net.cc +++ b/training/dtrain/dtrain_net.cc @@ -67,16 +67,19 @@ main(int argc, char** argv)        } else {          vector<string> parts;          boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| ")); -        source = parts[0]; -        parts.erase(parts.begin()); -        for (auto s: parts) { -          vector<WordID> r; -          vector<string> toks; -          boost::split(toks, s, boost::is_any_of(" ")); -          for (auto tok: toks) -            r.push_back(TD::Convert(tok)); -          refs.emplace_back(MakeNgrams(r, N)); -          rsz.push_back(r.size()); +        if (parts[0] == "act:translate") { +        } else { +          source = parts[0]; +          parts.erase(parts.begin()); +          for (auto s: parts) { +            vector<WordID> r; +            vector<string> toks; +            boost::split(toks, s, boost::is_any_of(" ")); +            for (auto tok: toks) +              r.push_back(TD::Convert(tok)); +            refs.emplace_back(MakeNgrams(r, N)); +            rsz.push_back(r.size()); +          }          }        }      } diff --git a/training/dtrain/dtrain_net_interface.cc b/training/dtrain/dtrain_net_interface.cc new file mode 100644 index 00000000..f484b56b --- /dev/null +++ b/training/dtrain/dtrain_net_interface.cc @@ -0,0 +1,120 @@ +#include "dtrain_net.h" +#include "sample_net.h" +#include "score.h" +#include "update.h" + +#include <nanomsg/nn.h> +#include <nanomsg/pair.h> +#include "nn.hpp" + +using namespace dtrain; + +int +main(int argc, char** argv) +{ +  // get configuration +  po::variables_map conf; +  if (!dtrain_net_init(argc, argv, &conf)) +    exit(1); // something is wrong +  const size_t k              = conf["k"].as<size_t>(); +  const size_t N              = conf["N"].as<size_t>(); +  const weight_t margin       = conf["margin"].as<weight_t>(); +  const string master_addr    = conf["addr"].as<string>(); + +  // setup decoder +  register_feature_functions(); +  SetSilent(true); +  ReadFile f(conf["decoder_conf"].as<string>()); +  Decoder decoder(f.stream()); +  ScoredKbest* observer = new ScoredKbest(k, new PerSentenceBleuScorer(N)); + +  // weights +  vector<weight_t>& decoder_weights = decoder.CurrentWeightVector(); +  SparseVector<weight_t> lambdas, w_average; +  if (conf.count("input_weights")) { +    Weights::InitFromFile(conf["input_weights"].as<string>(), &decoder_weights); +    Weights::InitSparseVector(decoder_weights, &lambdas); +  } + +  cerr << _p4; +  // output configuration +  cerr << "dtrain_net" << endl << "Parameters:" << endl; +  cerr << setw(25) << "k " << k << endl; +  cerr << setw(25) << "N " << N << endl; +  cerr << setw(25) << "margin " << margin << endl; +  cerr << setw(25) << "decoder conf " << "'" +       << conf["decoder_conf"].as<string>() << "'" << endl; + +  // socket +  nn::socket sock(AF_SP, NN_PAIR); +  sock.connect(master_addr.c_str()); + +  size_t i = 0; +  while(true) +  { +    char *buf = NULL; +    string source; +    vector<Ngrams> refs; +    vector<size_t> rsz; +    bool next = true; +    size_t sz = sock.recv(&buf, NN_MSG, 0); +    if (buf) { +      const string in(buf, buf+sz); +      nn::freemsg(buf); +      if (in == "shutdown") { +        next = false; +      } else { +        vector<string> parts; +        boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| ")); +        if (parts[0] == "act:translate") { +          cerr << "translating ..." << endl; +          lambdas.init_vector(&decoder_weights); +          observer->dont_score = true; +          decoder.Decode(parts[1], observer); +          observer->dont_score = false; +          vector<ScoredHyp>* samples = observer->GetSamples(); +          ostringstream os; +          PrintWordIDVec((*samples)[0].w, os); +          sock.send(os.str().c_str(), os.str().size()+1, 0); +          cerr << "done" << endl; +          continue; +        } else { +          cerr << "learning ..." << endl; +          source = parts[0]; +          parts.erase(parts.begin()); +          for (auto s: parts) { +            vector<WordID> r; +            vector<string> toks; +            boost::split(toks, s, boost::is_any_of(" ")); +            for (auto tok: toks) +              r.push_back(TD::Convert(tok)); +            refs.emplace_back(MakeNgrams(r, N)); +            rsz.push_back(r.size()); +          } +        } +      } +    } +     +    if (!next) +      break; + +    // decode +    lambdas.init_vector(&decoder_weights); +    observer->SetReference(refs, rsz); +    decoder.Decode(source, observer); +    vector<ScoredHyp>* samples = observer->GetSamples(); + +    // get pairs and update +    SparseVector<weight_t> updates; +    CollectUpdates(samples, updates, margin); +    lambdas.plus_eq_v_times_s(updates, 1.0); // fixme +    string s = "x"; +    sock.send(s.c_str(), s.size()+1, 0); +    i++; + +    cerr << "done" << endl; +  } // input loop + +  return 0; +} + diff --git a/training/dtrain/feed1.rb b/training/dtrain/feed1.rb new file mode 100755 index 00000000..a76e4c1e --- /dev/null +++ b/training/dtrain/feed1.rb @@ -0,0 +1,24 @@ +#!/usr/bin/env ruby + +require 'nanomsg' + +#port = ARGV[0] +port = 60667 +sock = NanoMsg::PairSocket.new +addr = "tcp://127.0.0.1:#{port}" +#addr = "ipc:///tmp/xxx.ipc" +sock.bind addr + +#puts sock.recv +while true +  line = STDIN.gets +  if !line +    sock.send 'shutdown' +    break +  end +  sock.send line.strip +  sleep 1 +  puts "got translation: #{sock.recv}\n\n" +  #sock.send "a=1 b=2" +end + diff --git a/training/dtrain/sample_net.h b/training/dtrain/sample_net.h new file mode 100644 index 00000000..497149d9 --- /dev/null +++ b/training/dtrain/sample_net.h @@ -0,0 +1,61 @@ +#ifndef _DTRAIN_SAMPLE_NET_H_ +#define _DTRAIN_SAMPLE_NET_H_ + +#include "kbest.h" + +#include "score.h" + +namespace dtrain +{ + +struct ScoredKbest : public DecoderObserver +{ +  const size_t k_; +  size_t feature_count_, effective_sz_; +  vector<ScoredHyp> samples_; +  PerSentenceBleuScorer* scorer_; +  vector<Ngrams>* ref_ngs_; +  vector<size_t>* ref_ls_; +  bool dont_score; + +  ScoredKbest(const size_t k, PerSentenceBleuScorer* scorer) : +    k_(k), scorer_(scorer), dont_score(false) {} + +  virtual void +  NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) +  { +    samples_.clear(); effective_sz_ = feature_count_ = 0; +    KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, +      KBest::FilterUnique, prob_t, EdgeProb> kbest(*hg, k_); +    for (size_t i = 0; i < k_; ++i) { +      const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, +            KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d = +              kbest.LazyKthBest(hg->nodes_.size() - 1, i); +      if (!d) break; +      ScoredHyp h; +      h.w = d->yield; +      h.f = d->feature_values; +      h.model = log(d->score); +      h.rank = i; +      if (!dont_score) +        h.gold = scorer_->Score(h.w, *ref_ngs_, *ref_ls_); +      samples_.push_back(h); +      effective_sz_++; +      feature_count_ += h.f.size(); +    } +  } + +  vector<ScoredHyp>* GetSamples() { return &samples_; } +  inline void SetReference(vector<Ngrams>& ngs, vector<size_t>& ls) +  { +    ref_ngs_ = &ngs; +    ref_ls_ = &ls; +  } +  inline size_t GetFeatureCount() { return feature_count_; } +  inline size_t GetSize() { return effective_sz_; } +}; + +} // namespace + +#endif + | 
