summary refs log tree commit diff
path: root/word-aligner
diff options
context:
space:
mode:
author pks <pks@users.noreply.github.com> 2019-05-12 20:10:37 +0200
committer GitHub <noreply@github.com> 2019-05-12 20:10:37 +0200
commit 4a13b41700f34c15c30b551f98dbea9cb41f67c3 (patch)
tree 0218f41c350a626f5af9909d77406309fa873fdf /word-aligner
parent e9268eb3dcd867f3baf67a7bb3d2aad56196ecde (diff)
parent f64746ac87fc7338629b19de9fa2da0f03fa2790 (diff)
Merge branch 'net' into origin/net
Diffstat (limited to 'word-aligner')
-rw-r--r-- word-aligner/Makefile.am 9
-rw-r--r-- word-aligner/net_fa.cc 134
2 files changed, 141 insertions, 2 deletions
diff --git a/word-aligner/Makefile.am b/word-aligner/Makefile.am
index 071e4977..2bffa267 100644
--- a/word-aligner/Makefile.am
+++ b/word-aligner/Makefile.am
@@ -1,4 +1,4 @@
-bin_PROGRAMS = fast_align binderiv
+bin_PROGRAMS = fast_align binderiv net_fa
fast_align_SOURCES = fast_align.cc ttables.cc da.h ttables.h
fast_align_LDADD = ../utils/libutils.a
@@ -7,6 +7,11 @@ fast_align_LDFLAGS = $(STATIC_FLAGS)
binderiv_SOURCES = binderiv.cc
binderiv_LDADD = ../utils/libutils.a
+net_fa_SOURCES = net_fa.cc ttables.cc da.h ttables.h nn.hpp
+net_fa_LDADD = ../utils/libutils.a
+net_fa_LDFLAGS = $(STATIC_FLAGS) /srv/postedit/lib/nanomsg-0.5-beta/lib/libnanomsg.so
+
EXTRA_DIST = aligner.pl ortho-norm support makefiles stemmers
-AM_CPPFLAGS = -W -Wall -I$(top_srcdir) -I$(top_srcdir)/utils -I$(top_srcdir)/training
+AM_CPPFLAGS = -W -Wall -I$(top_srcdir) -I$(top_srcdir)/utils -I$(top_srcdir)/training -I/srv/postedit/lib/nanomsg-0.5-beta/include -I/srv/postedit/lib/cppnanomsg
+
diff --git a/word-aligner/net_fa.cc b/word-aligner/net_fa.cc
new file mode 100644
index 00000000..f71c3f04
--- /dev/null
+++ b/word-aligner/net_fa.cc
@@ -0,0 +1,134 @@
+#include <iostream>
+#include <cmath>
+#include <utility>
+#ifndef HAVE_OLD_CPP
+# include <unordered_map>
+#else
+# include <tr1/unordered_map>
+namespace std { using std::tr1::unordered_map; }
+#endif
+
+#include <boost/functional/hash.hpp>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "m.h"
+#include "corpus_tools.h"
+#include "stringlib.h"
+#include "filelib.h"
+#include "ttables.h"
+#include "tdict.h"
+#include "da.h"
+
+#include <nanomsg/nn.h>
+#include <nanomsg/pair.h>
+#include "nn.hpp"
+
+namespace po = boost::program_options;
+using namespace std;
+
+// Parses the command line into *conf.
+//
+// Registers the alignment-model options plus --sock_url, and a separate
+// --help switch. Returns true only when --help was not given and
+// --force_align was supplied; otherwise prints the usage text and the full
+// option list to stderr and returns false.
+bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+  po::options_description model_opts("Configuration options");
+  model_opts.add_options()
+    ("diagonal_tension,T", po::value<double>()->default_value(4.0), "How sharp or flat around the diagonal is the alignment distribution (<1 = flat >1 = sharp)")
+    ("mean_srclen_multiplier,m",po::value<double>()->default_value(1), "When --force_align, use this source length multiplier")
+    ("force_align,f",po::value<string>(), "Load previously written parameters to 'force align' input. Set --diagonal_tension and --mean_srclen_multiplier as estimated during training.")
+    ("favor_diagonal,d", "Use a static alignment distribution that assigns higher probabilities to alignments near the diagonal")
+    ("prob_align_null", po::value<double>()->default_value(0.08), "When --favor_diagonal is set, what's the probability of a null alignment?")
+    ("no_null_word,N","Do not generate from a null token")
+    ("sock_url", po::value<string>()->default_value("tcp://127.0.0.1:60666"), "Socket url.");
+
+  po::options_description cli_only("Command line options");
+  cli_only.add_options()
+    ("help,h", "Print this help message and exit");
+
+  // config_view mirrors the model options only; all_opts is what actually
+  // gets parsed and printed.
+  po::options_description config_view, all_opts;
+  config_view.add(model_opts);
+  all_opts.add(model_opts).add(cli_only);
+
+  po::store(parse_command_line(argc, argv, all_opts), *conf);
+  po::notify(*conf);
+
+  const bool wants_help = conf->count("help") != 0;
+  const bool has_params = conf->count("force_align") != 0;
+  if (!wants_help && has_params)
+    return true;
+
+  cerr << "Usage " << argv[0] << " [OPTIONS] -f params\n"
+       << all_opts << endl;
+  return false;
+}
+
+// Entry point of the net_fa alignment server.
+//
+// Loads a translation table from the file named by --force_align, binds a
+// nanomsg NN_PAIR socket on --sock_url, sends "hello", and then serves
+// requests in a loop: each received message is split into a source and a
+// target sentence by CorpusTools::ReadLine, and for every target word the
+// best-scoring source position is reported back as " i-j" (0-based source
+// index, 0-based target index). Target words whose best score belongs to
+// the null word produce no link. The message "shutdown" is answered with
+// "off" and ends the loop.
+//
+// Returns 1 when InitCommandLine rejects the arguments, 0 after shutdown.
+int main(int argc, char** argv) {
+  po::variables_map conf;
+  if (!InitCommandLine(argc, argv, &conf)) return 1;
+  const double diagonal_tension = conf["diagonal_tension"].as<double>();
+  const double mean_srclen_multiplier = conf["mean_srclen_multiplier"].as<double>();
+  const bool use_null = (conf.count("no_null_word") == 0);
+  const bool favor_diagonal = (conf.count("favor_diagonal") > 0);
+  const double prob_align_null = conf["prob_align_null"].as<double>();
+  const double prob_align_not_null = 1.0 - prob_align_null;
+  const WordID kNULL = TD::Convert("<eps>");
+  ReadFile s2t_f(conf["force_align"].as<string>());
+  TTable s2t;
+  s2t.DeserializeLogProbsFromText(s2t_f.stream());
+
+  nn::socket sock(AF_SP, NN_PAIR);
+  const string url = conf["sock_url"].as<string>();
+  sock.bind(url.c_str());
+  // Receive timeout so recv() returns periodically instead of blocking
+  // forever (NN_RCVTIMEO is expressed in milliseconds).
+  int to = 100;
+  sock.setsockopt(NN_SOL_SOCKET, NN_RCVTIMEO, &to, sizeof (to));
+  const string hello = "hello";
+  sock.send(hello.c_str(), hello.size()+1, 0);
+
+  while (true) {
+    char *buf = NULL;
+    // NN_MSG asks the library to allocate the buffer; we own it afterwards.
+    // Keep the result signed: a negative value means nothing was received.
+    const int sz = sock.recv(&buf, NN_MSG, 0);
+    if (sz < 0 || !buf) {
+      if (buf) nn::freemsg(buf);
+      continue;
+    }
+    // Copy the payload and release the message right away so that no exit
+    // path (including "shutdown") can leak it.
+    const string line(buf, buf + sz);
+    nn::freemsg(buf);
+
+    if (line == "shutdown") {
+      cerr << "[net_fa] shutting down" << endl;
+      const string shutdown = "off";
+      sock.send(shutdown.c_str(), shutdown.size()+1, 0);
+      break;
+    }
+    cerr << "[net_fa] got '" << line << "'" << endl;
+
+    vector<WordID> src, trg;
+    CorpusTools::ReadLine(line, &src, &trg);
+    // NOTE(review): log_prob is accumulated below but never reported to the
+    // client or logged; kept as in the original.
+    double log_prob = Md::log_poisson(trg.size(), 0.05 + src.size() * mean_srclen_multiplier);
+
+    // compute likelihood and pick the best link for every target word
+    ostringstream ss;
+    for (unsigned j = 0; j < trg.size(); ++j) {
+      const WordID& f_j = trg[j];
+      double sum = 0;
+      int a_j = 0;         // 0 = null word, i >= 1 = source position i-1
+      double max_pat = 0;
+      double prob_a_i = 1.0 / (src.size() + use_null);  // uniform (model 1)
+      if (use_null) {
+        if (favor_diagonal) prob_a_i = prob_align_null;
+        max_pat = s2t.prob(kNULL, f_j) * prob_a_i;
+        sum += max_pat;
+      }
+      double az = 0;
+      if (favor_diagonal)
+        az = DiagonalAlignment::ComputeZ(j+1, trg.size(), src.size(), diagonal_tension) / prob_align_not_null;
+      for (unsigned i = 1; i <= src.size(); ++i) {
+        if (favor_diagonal)
+          prob_a_i = DiagonalAlignment::UnnormalizedProb(j + 1, i, trg.size(), src.size(), diagonal_tension) / az;
+        const double pat = s2t.prob(src[i-1], f_j) * prob_a_i;
+        if (pat > max_pat) { max_pat = pat; a_j = i; }
+        sum += pat;
+      }
+      log_prob += log(sum);
+      if (a_j > 0)
+        ss << ' ' << (a_j - 1) << '-' << j;
+    }
+    const string a = ss.str();
+    cerr << "[net_fa] sending '" << a << "'" << endl;
+    // The reply is sent including its terminating NUL byte.
+    sock.send(a.c_str(), a.size()+1, 0);
+  } // loop
+
+  return 0;
+}
+