#include #include #include #ifndef HAVE_OLD_CPP # include #else # include namespace std { using std::tr1::unordered_map; } #endif #include #include #include #include "m.h" #include "corpus_tools.h" #include "stringlib.h" #include "filelib.h" #include "ttables.h" #include "tdict.h" #include "da.h" #include #include #include "nn.hpp" namespace po = boost::program_options; using namespace std; bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() ("diagonal_tension,T", po::value()->default_value(4.0), "How sharp or flat around the diagonal is the alignment distribution (<1 = flat >1 = sharp)") ("mean_srclen_multiplier,m",po::value()->default_value(1), "When --force_align, use this source length multiplier") ("force_align,f",po::value(), "Load previously written parameters to 'force align' input. Set --diagonal_tension and --mean_srclen_multiplier as estimated during training.") ("favor_diagonal,d", "Use a static alignment distribution that assigns higher probabilities to alignments near the diagonal") ("prob_align_null", po::value()->default_value(0.08), "When --favor_diagonal is set, what's the probability of a null alignment?") ("no_null_word,N","Do not generate from a null token") ("sock_url", po::value()->default_value("tcp://127.0.0.1:60666"), "Socket url."); po::options_description clo("Command line options"); clo.add_options() ("help,h", "Print this help message and exit"); po::options_description dconfig_options, dcmdline_options; dconfig_options.add(opts); dcmdline_options.add(opts).add(clo); po::store(parse_command_line(argc, argv, dcmdline_options), *conf); po::notify(*conf); if (conf->count("help") || conf->count("force_align")==0) { cerr << "Usage " << argv[0] << " [OPTIONS] -f params\n"; cerr << dcmdline_options << endl; return false; } return true; } int main(int argc, char** argv) { po::variables_map conf; if (!InitCommandLine(argc, argv, &conf)) return 1; const double diagonal_tension = conf["diagonal_tension"].as(); const double mean_srclen_multiplier = conf["mean_srclen_multiplier"].as(); const bool use_null = (conf.count("no_null_word") == 0); const bool favor_diagonal = conf.count("favor_diagonal"); const double prob_align_null = conf["prob_align_null"].as(); const double prob_align_not_null = 1.0 - prob_align_null; const WordID kNULL = TD::Convert(""); ReadFile s2t_f(conf["force_align"].as()); TTable s2t; s2t.DeserializeLogProbsFromText(s2t_f.stream()); nn::socket sock(AF_SP, NN_PAIR); string url = conf["sock_url"].as(); sock.bind(url.c_str()); int to = 100; sock.setsockopt(NN_SOL_SOCKET, NN_RCVTIMEO, &to, sizeof (to)); string hello = "hello"; sock.send(hello.c_str(), hello.size()+1, 0); while (true) { char *buf = NULL; size_t sz = sock.recv(&buf, NN_MSG, 0); if (!buf) continue; string line(buf, buf+sz); if (line == "shutdown") { cerr << "[net_fa] shutting down" << endl; string shutdown = "off"; sock.send(shutdown.c_str(), shutdown.size()+1, 0); break; } cerr << "[net_fa] got '" << line << "'" << endl; nn::freemsg(buf); vector src, trg; CorpusTools::ReadLine(line, &src, &trg); double log_prob = Md::log_poisson(trg.size(), 0.05 + src.size() * mean_srclen_multiplier); // compute likelihood ostringstream ss; for (unsigned j = 0; j < trg.size(); ++j) { const WordID& f_j = trg[j]; double sum = 0; int a_j = 0; double max_pat = 0; double prob_a_i = 1.0 / (src.size() + use_null); // uniform (model 1) if (use_null) { if (favor_diagonal) prob_a_i = prob_align_null; max_pat = s2t.prob(kNULL, f_j) * prob_a_i; sum += max_pat; } double az = 0; if (favor_diagonal) az = DiagonalAlignment::ComputeZ(j+1, trg.size(), src.size(), diagonal_tension) / prob_align_not_null; for (unsigned i = 1; i <= src.size(); ++i) { if (favor_diagonal) prob_a_i = DiagonalAlignment::UnnormalizedProb(j + 1, i, trg.size(), src.size(), diagonal_tension) / az; double pat = s2t.prob(src[i-1], f_j) * prob_a_i; if (pat > max_pat) { max_pat = pat; a_j = i; } sum += pat; } log_prob += log(sum); if (a_j > 0) ss << ' ' << (a_j - 1) << '-' << j; } string a = ss.str(); cerr << "[net_fa] sending '" << a << "'" << endl; sock.send(a.c_str(), a.size()+1, 0); } // loop return 0; }