diff options
author | pks <pks@users.noreply.github.com> | 2019-05-12 20:10:37 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-05-12 20:10:37 +0200 |
commit | 4a13b41700f34c15c30b551f98dbea9cb41f67c3 (patch) | |
tree | 0218f41c350a626f5af9909d77406309fa873fdf /word-aligner | |
parent | e9268eb3dcd867f3baf67a7bb3d2aad56196ecde (diff) | |
parent | f64746ac87fc7338629b19de9fa2da0f03fa2790 (diff) |
Merge branch 'net' into origin/net
Diffstat (limited to 'word-aligner')
-rw-r--r-- | word-aligner/Makefile.am | 9 | ||||
-rw-r--r-- | word-aligner/net_fa.cc | 134 |
2 files changed, 141 insertions, 2 deletions
diff --git a/word-aligner/Makefile.am b/word-aligner/Makefile.am index 071e4977..2bffa267 100644 --- a/word-aligner/Makefile.am +++ b/word-aligner/Makefile.am @@ -1,4 +1,4 @@ -bin_PROGRAMS = fast_align binderiv +bin_PROGRAMS = fast_align binderiv net_fa fast_align_SOURCES = fast_align.cc ttables.cc da.h ttables.h fast_align_LDADD = ../utils/libutils.a @@ -7,6 +7,11 @@ fast_align_LDFLAGS = $(STATIC_FLAGS) binderiv_SOURCES = binderiv.cc binderiv_LDADD = ../utils/libutils.a +net_fa_SOURCES = net_fa.cc ttables.cc da.h ttables.h nn.hpp +net_fa_LDADD = ../utils/libutils.a +net_fa_LDFLAGS = $(STATIC_FLAGS) /srv/postedit/lib/nanomsg-0.5-beta/lib/libnanomsg.so + EXTRA_DIST = aligner.pl ortho-norm support makefiles stemmers -AM_CPPFLAGS = -W -Wall -I$(top_srcdir) -I$(top_srcdir)/utils -I$(top_srcdir)/training +AM_CPPFLAGS = -W -Wall -I$(top_srcdir) -I$(top_srcdir)/utils -I$(top_srcdir)/training -I/srv/postedit/lib/nanomsg-0.5-beta/include -I/srv/postedit/lib/cppnanomsg + diff --git a/word-aligner/net_fa.cc b/word-aligner/net_fa.cc new file mode 100644 index 00000000..f71c3f04 --- /dev/null +++ b/word-aligner/net_fa.cc @@ -0,0 +1,134 @@ +#include <iostream> +#include <cmath> +#include <utility> +#ifndef HAVE_OLD_CPP +# include <unordered_map> +#else +# include <tr1/unordered_map> +namespace std { using std::tr1::unordered_map; } +#endif + +#include <boost/functional/hash.hpp> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "m.h" +#include "corpus_tools.h" +#include "stringlib.h" +#include "filelib.h" +#include "ttables.h" +#include "tdict.h" +#include "da.h" + +#include <nanomsg/nn.h> +#include <nanomsg/pair.h> +#include "nn.hpp" + +namespace po = boost::program_options; +using namespace std; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("diagonal_tension,T", po::value<double>()->default_value(4.0), "How sharp or flat around the diagonal is the alignment distribution (<1 = flat >1 = sharp)") + ("mean_srclen_multiplier,m",po::value<double>()->default_value(1), "When --force_align, use this source length multiplier") + ("force_align,f",po::value<string>(), "Load previously written parameters to 'force align' input. Set --diagonal_tension and --mean_srclen_multiplier as estimated during training.") + ("favor_diagonal,d", "Use a static alignment distribution that assigns higher probabilities to alignments near the diagonal") + ("prob_align_null", po::value<double>()->default_value(0.08), "When --favor_diagonal is set, what's the probability of a null alignment?") + ("no_null_word,N","Do not generate from a null token") + ("sock_url", po::value<string>()->default_value("tcp://127.0.0.1:60666"), "Socket url."); + po::options_description clo("Command line options"); + clo.add_options() + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + po::notify(*conf); + + if (conf->count("help") || conf->count("force_align")==0) { + cerr << "Usage " << argv[0] << " [OPTIONS] -f params\n"; + cerr << dcmdline_options << endl; + return false; + } + + return true; +} + +int main(int argc, char** argv) { + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) return 1; + const double diagonal_tension = conf["diagonal_tension"].as<double>(); + const double mean_srclen_multiplier = conf["mean_srclen_multiplier"].as<double>(); + const bool use_null = (conf.count("no_null_word") == 0); + const bool favor_diagonal = conf.count("favor_diagonal"); + const double prob_align_null = conf["prob_align_null"].as<double>(); + const double prob_align_not_null = 1.0 - prob_align_null; + const WordID kNULL = TD::Convert("<eps>"); + ReadFile s2t_f(conf["force_align"].as<string>()); + TTable s2t; + s2t.DeserializeLogProbsFromText(s2t_f.stream()); + + nn::socket sock(AF_SP, NN_PAIR); + string url = conf["sock_url"].as<string>(); + sock.bind(url.c_str()); + int to = 100; + sock.setsockopt(NN_SOL_SOCKET, NN_RCVTIMEO, &to, sizeof (to)); + string hello = "hello"; + sock.send(hello.c_str(), hello.size()+1, 0); + + while (true) + { + char *buf = NULL; + size_t sz = sock.recv(&buf, NN_MSG, 0); + if (!buf) + continue; + string line(buf, buf+sz); + if (line == "shutdown") { + cerr << "[net_fa] shutting down" << endl; + string shutdown = "off"; + sock.send(shutdown.c_str(), shutdown.size()+1, 0); + break; + } + cerr << "[net_fa] got '" << line << "'" << endl; + nn::freemsg(buf); + vector<WordID> src, trg; + CorpusTools::ReadLine(line, &src, &trg); + double log_prob = Md::log_poisson(trg.size(), 0.05 + src.size() * mean_srclen_multiplier); + + // compute likelihood + ostringstream ss; + for (unsigned j = 0; j < trg.size(); ++j) { + const WordID& f_j = trg[j]; + double sum = 0; + int a_j = 0; + double max_pat = 0; + double prob_a_i = 1.0 / (src.size() + use_null); // uniform (model 1) + if (use_null) { + if (favor_diagonal) prob_a_i = prob_align_null; + max_pat = s2t.prob(kNULL, f_j) * prob_a_i; + sum += max_pat; + } + double az = 0; + if (favor_diagonal) + az = DiagonalAlignment::ComputeZ(j+1, trg.size(), src.size(), diagonal_tension) / prob_align_not_null; + for (unsigned i = 1; i <= src.size(); ++i) { + if (favor_diagonal) + prob_a_i = DiagonalAlignment::UnnormalizedProb(j + 1, i, trg.size(), src.size(), diagonal_tension) / az; + double pat = s2t.prob(src[i-1], f_j) * prob_a_i; + if (pat > max_pat) { max_pat = pat; a_j = i; } + sum += pat; + } + log_prob += log(sum); + if (a_j > 0) + ss << ' ' << (a_j - 1) << '-' << j; + } + string a = ss.str(); + cerr << "[net_fa] sending '" << a << "'" << endl; + sock.send(a.c_str(), a.size()+1, 0); + } // loop + + return 0; +} + |