summaryrefslogtreecommitdiff
path: root/training/utils
diff options
context:
space:
mode:
Diffstat (limited to 'training/utils')
-rw-r--r--training/utils/Makefile.am37
-rw-r--r--training/utils/candidate_set.cc169
-rw-r--r--training/utils/candidate_set.h60
-rwxr-xr-xtraining/utils/decode-and-evaluate.pl246
-rw-r--r--training/utils/entropy.cc41
-rw-r--r--training/utils/entropy.h22
-rw-r--r--training/utils/grammar_convert.cc348
-rw-r--r--training/utils/lbfgs.h1459
-rw-r--r--training/utils/lbfgs_test.cc117
-rw-r--r--training/utils/libcall.pl71
-rw-r--r--training/utils/online_optimizer.cc16
-rw-r--r--training/utils/online_optimizer.h129
-rw-r--r--training/utils/optimize.cc102
-rw-r--r--training/utils/optimize.h92
-rw-r--r--training/utils/optimize_test.cc118
-rwxr-xr-xtraining/utils/parallelize.pl423
-rw-r--r--training/utils/risk.cc45
-rw-r--r--training/utils/risk.h26
-rw-r--r--training/utils/sentclient.c76
-rw-r--r--training/utils/sentserver.c515
-rw-r--r--training/utils/sentserver.h6
21 files changed, 4118 insertions, 0 deletions
diff --git a/training/utils/Makefile.am b/training/utils/Makefile.am
new file mode 100644
index 00000000..d708a9f5
--- /dev/null
+++ b/training/utils/Makefile.am
@@ -0,0 +1,37 @@
+noinst_LIBRARIES = libtraining_utils.a
+
+bin_PROGRAMS = \
+ sentserver \
+ sentclient \
+ grammar_convert
+
+noinst_PROGRAMS = \
+ lbfgs_test \
+ optimize_test
+
+sentserver_SOURCES = sentserver.c
+sentserver_LDFLAGS = -pthread
+
+sentclient_SOURCES = sentclient.c
+sentclient_LDFLAGS = -pthread
+
+TESTS = lbfgs_test optimize_test
+
+libtraining_utils_a_SOURCES = \
+ candidate_set.cc \
+ entropy.cc \
+ optimize.cc \
+ online_optimizer.cc \
+ risk.cc
+
+optimize_test_SOURCES = optimize_test.cc
+optimize_test_LDADD = libtraining_utils.a $(top_srcdir)/utils/libutils.a
+
+grammar_convert_SOURCES = grammar_convert.cc
+grammar_convert_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a
+
+lbfgs_test_SOURCES = lbfgs_test.cc
+lbfgs_test_LDADD = $(top_srcdir)/utils/libutils.a
+
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/decoder -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I$(top_srcdir)/klm
+
diff --git a/training/utils/candidate_set.cc b/training/utils/candidate_set.cc
new file mode 100644
index 00000000..087efec3
--- /dev/null
+++ b/training/utils/candidate_set.cc
@@ -0,0 +1,169 @@
+#include "candidate_set.h"
+
+#include <tr1/unordered_set>
+
+#include <boost/functional/hash.hpp>
+
+#include "verbose.h"
+#include "ns.h"
+#include "filelib.h"
+#include "wordid.h"
+#include "tdict.h"
+#include "hg.h"
+#include "kbest.h"
+#include "viterbi.h"
+
+using namespace std;
+
+namespace training {
+
+struct ApproxVectorHasher {
+ static const size_t MASK = 0xFFFFFFFFull;
+ union UType {
+ double f; // leave as double
+ size_t i;
+ };
+ static inline double round(const double x) {
+ UType t;
+ t.f = x;
+ size_t r = t.i & MASK;
+ if ((r << 1) > MASK)
+ t.i += MASK - r + 1;
+ else
+ t.i &= (1ull - MASK);
+ return t.f;
+ }
+ size_t operator()(const SparseVector<double>& x) const {
+ size_t h = 0x573915839;
+ for (SparseVector<double>::const_iterator it = x.begin(); it != x.end(); ++it) {
+ UType t;
+ t.f = it->second;
+ if (t.f) {
+ size_t z = (t.i >> 32);
+ boost::hash_combine(h, it->first);
+ boost::hash_combine(h, z);
+ }
+ }
+ return h;
+ }
+};
+
+struct ApproxVectorEquals {
+ bool operator()(const SparseVector<double>& a, const SparseVector<double>& b) const {
+ SparseVector<double>::const_iterator bit = b.begin();
+ for (SparseVector<double>::const_iterator ait = a.begin(); ait != a.end(); ++ait) {
+ if (bit == b.end() ||
+ ait->first != bit->first ||
+ ApproxVectorHasher::round(ait->second) != ApproxVectorHasher::round(bit->second))
+ return false;
+ ++bit;
+ }
+ if (bit != b.end()) return false;
+ return true;
+ }
+};
+
+struct CandidateCompare {
+ bool operator()(const Candidate& a, const Candidate& b) const {
+ ApproxVectorEquals eq;
+ return (a.ewords == b.ewords && eq(a.fmap,b.fmap));
+ }
+};
+
+struct CandidateHasher {
+ size_t operator()(const Candidate& x) const {
+ boost::hash<vector<WordID> > hhasher;
+ ApproxVectorHasher vhasher;
+ size_t ha = hhasher(x.ewords);
+ boost::hash_combine(ha, vhasher(x.fmap));
+ return ha;
+ }
+};
+
+static void ParseSparseVector(string& line, size_t cur, SparseVector<double>* out) {
+ SparseVector<double>& x = *out;
+ size_t last_start = cur;
+ size_t last_comma = string::npos;
+ while(cur <= line.size()) {
+ if (line[cur] == ' ' || cur == line.size()) {
+ if (!(cur > last_start && last_comma != string::npos && cur > last_comma)) {
+ cerr << "[ERROR] " << line << endl << " position = " << cur << endl;
+ exit(1);
+ }
+ const int fid = FD::Convert(line.substr(last_start, last_comma - last_start));
+ if (cur < line.size()) line[cur] = 0;
+ const double val = strtod(&line[last_comma + 1], NULL);
+ x.set_value(fid, val);
+
+ last_comma = string::npos;
+ last_start = cur+1;
+ } else {
+ if (line[cur] == '=')
+ last_comma = cur;
+ }
+ ++cur;
+ }
+}
+
+void CandidateSet::WriteToFile(const string& file) const {
+ WriteFile wf(file);
+ ostream& out = *wf.stream();
+ out.precision(10);
+ string ss;
+ for (unsigned i = 0; i < cs.size(); ++i) {
+ out << TD::GetString(cs[i].ewords) << endl;
+ out << cs[i].fmap << endl;
+ cs[i].eval_feats.Encode(&ss);
+ out << ss << endl;
+ }
+}
+
+void CandidateSet::ReadFromFile(const string& file) {
+ if(!SILENT) cerr << "Reading candidates from " << file << endl;
+ ReadFile rf(file);
+ istream& in = *rf.stream();
+ string cand;
+ string feats;
+ string ss;
+ while(getline(in, cand)) {
+ getline(in, feats);
+ getline(in, ss);
+ assert(in);
+ cs.push_back(Candidate());
+ TD::ConvertSentence(cand, &cs.back().ewords);
+ ParseSparseVector(feats, 0, &cs.back().fmap);
+ cs.back().eval_feats = SufficientStats(ss);
+ }
+ if(!SILENT) cerr << " read " << cs.size() << " candidates\n";
+}
+
+void CandidateSet::Dedup() {
+ if(!SILENT) cerr << "Dedup in=" << cs.size();
+ tr1::unordered_set<Candidate, CandidateHasher, CandidateCompare> u;
+ while(cs.size() > 0) {
+ u.insert(cs.back());
+ cs.pop_back();
+ }
+ tr1::unordered_set<Candidate, CandidateHasher, CandidateCompare>::iterator it = u.begin();
+ while (it != u.end()) {
+ cs.push_back(*it);
+ it = u.erase(it);
+ }
+ if(!SILENT) cerr << " out=" << cs.size() << endl;
+}
+
+void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer) {
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg, kbest_size);
+
+ for (unsigned i = 0; i < kbest_size; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest.LazyKthBest(hg.nodes_.size() - 1, i);
+ if (!d) break;
+ cs.push_back(Candidate(d->yield, d->feature_values));
+ if (scorer)
+ scorer->Evaluate(d->yield, &cs.back().eval_feats);
+ }
+ Dedup();
+}
+
+}
diff --git a/training/utils/candidate_set.h b/training/utils/candidate_set.h
new file mode 100644
index 00000000..9d326ed0
--- /dev/null
+++ b/training/utils/candidate_set.h
@@ -0,0 +1,60 @@
+#ifndef _CANDIDATE_SET_H_
+#define _CANDIDATE_SET_H_
+
+#include <vector>
+#include <algorithm>
+
+#include "ns.h"
+#include "wordid.h"
+#include "sparse_vector.h"
+
+class Hypergraph;
+
+namespace training {
+
+struct Candidate {
+ Candidate() {}
+ Candidate(const std::vector<WordID>& e, const SparseVector<double>& fm) :
+ ewords(e),
+ fmap(fm) {}
+ Candidate(const std::vector<WordID>& e,
+ const SparseVector<double>& fm,
+ const SegmentEvaluator& se) :
+ ewords(e),
+ fmap(fm) {
+ se.Evaluate(ewords, &eval_feats);
+ }
+
+ void swap(Candidate& other) {
+ eval_feats.swap(other.eval_feats);
+ ewords.swap(other.ewords);
+ fmap.swap(other.fmap);
+ }
+
+ std::vector<WordID> ewords;
+ SparseVector<double> fmap;
+ SufficientStats eval_feats;
+};
+
+// represents some kind of collection of translation candidates, e.g.
+// aggregated k-best lists, sample lists, etc.
+class CandidateSet {
+ public:
+ CandidateSet() {}
+ inline size_t size() const { return cs.size(); }
+ const Candidate& operator[](size_t i) const { return cs[i]; }
+
+ void ReadFromFile(const std::string& file);
+ void WriteToFile(const std::string& file) const;
+ void AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer = NULL);
+ // TODO add code to do unique k-best
+ // TODO add code to draw k samples
+
+ private:
+ void Dedup();
+ std::vector<Candidate> cs;
+};
+
+}
+
+#endif
diff --git a/training/utils/decode-and-evaluate.pl b/training/utils/decode-and-evaluate.pl
new file mode 100755
index 00000000..1a332c08
--- /dev/null
+++ b/training/utils/decode-and-evaluate.pl
@@ -0,0 +1,246 @@
+#!/usr/bin/env perl
+use strict;
+my @ORIG_ARGV=@ARGV;
+use Cwd qw(getcwd);
+my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR, "$SCRIPT_DIR/../../environment"; }
+
+# Skip local config (used for distributing jobs) if we're running in local-only mode
+use LocalConfig;
+use Getopt::Long;
+use File::Basename qw(basename);
+my $QSUB_CMD = qsub_args(mert_memory());
+
+require "libcall.pl";
+
+# Default settings
+my $default_jobs = env_default_jobs();
+my $bin_dir = $SCRIPT_DIR;
+die "Bin directory $bin_dir missing/inaccessible" unless -d $bin_dir;
+my $FAST_SCORE="$bin_dir/../../mteval/fast_score";
+die "Can't execute $FAST_SCORE" unless -x $FAST_SCORE;
+my $parallelize = "$bin_dir/parallelize.pl";
+my $libcall = "$bin_dir/libcall.pl";
+my $sentserver = "$bin_dir/sentserver";
+my $sentclient = "$bin_dir/sentclient";
+my $LocalConfig = "$SCRIPT_DIR/../../environment/LocalConfig.pm";
+
+my $SCORER = $FAST_SCORE;
+my $cdec = "$bin_dir/../../decoder/cdec";
+die "Can't find decoder in $cdec" unless -x $cdec;
+die "Can't find $parallelize" unless -x $parallelize;
+die "Can't find $libcall" unless -e $libcall;
+my $decoder = $cdec;
+my $jobs = $default_jobs; # number of decode nodes
+my $pmem = "9g";
+my $help = 0;
+my $config;
+my $test_set;
+my $weights;
+my $use_make = 1;
+my $useqsub;
+my $cpbin=1;
+# Process command-line options
+if (GetOptions(
+ "jobs=i" => \$jobs,
+ "help" => \$help,
+ "qsub" => \$useqsub,
+ "input=s" => \$test_set,
+ "config=s" => \$config,
+ "weights=s" => \$weights,
+) == 0 || @ARGV!=0 || $help) {
+ print_help();
+ exit;
+}
+
+if ($useqsub) {
+ $use_make = 0;
+ die "LocalEnvironment.pm does not have qsub configuration for this host. Cannot run with --qsub!\n" unless has_qsub();
+}
+
+my @missing_args = ();
+
+if (!defined $test_set) { push @missing_args, "--input"; }
+if (!defined $config) { push @missing_args, "--config"; }
+if (!defined $weights) { push @missing_args, "--weights"; }
+die "Please specify missing arguments: " . join (', ', @missing_args) . "\nUse --help for more information.\n" if (@missing_args);
+
+my @tf = localtime(time);
+my $tname = basename($test_set);
+$tname =~ s/\.(sgm|sgml|xml)$//i;
+my $dir = "eval.$tname." . sprintf('%d%02d%02d-%02d%02d%02d', 1900+$tf[5], $tf[4], $tf[3], $tf[2], $tf[1], $tf[0]);
+
+my $time = unchecked_output("date");
+
+check_call("mkdir -p $dir");
+
+split_devset($test_set, "$dir/test.input.raw", "$dir/test.refs");
+my $refs = "-r $dir/test.refs";
+my $newsrc = "$dir/test.input";
+enseg("$dir/test.input.raw", $newsrc);
+my $src_file = $newsrc;
+open F, "<$src_file" or die "Can't read $src_file: $!"; close F;
+
+my $test_trans="$dir/test.trans";
+my $logdir="$dir/logs";
+my $decoderLog="$logdir/decoder.sentserver.log";
+check_call("mkdir -p $logdir");
+
+#decode
+print STDERR "RUNNING DECODER AT ";
+print STDERR unchecked_output("date");
+my $decoder_cmd = "$decoder -c $config --weights $weights";
+my $pcmd;
+if ($use_make) {
+ $pcmd = "cat $src_file | $parallelize --workdir $dir --use-fork -p $pmem -e $logdir -j $jobs --";
+} else {
+ $pcmd = "cat $src_file | $parallelize --workdir $dir -p $pmem -e $logdir -j $jobs --";
+}
+my $cmd = "$pcmd $decoder_cmd 2> $decoderLog 1> $test_trans";
+check_bash_call($cmd);
+print STDERR "DECODER COMPLETED AT ";
+print STDERR unchecked_output("date");
+print STDERR "\nOUTPUT: $test_trans\n\n";
+my $bleu = check_output("cat $test_trans | $SCORER $refs -m ibm_bleu");
+chomp $bleu;
+print STDERR "BLEU: $bleu\n";
+my $ter = check_output("cat $test_trans | $SCORER $refs -m ter");
+chomp $ter;
+print STDERR " TER: $ter\n";
+open TR, ">$dir/test.scores" or die "Can't write $dir/test.scores: $!";
+print TR <<EOT;
+### SCORE REPORT #############################################################
+ OUTPUT=$test_trans
+ SCRIPT INPUT=$test_set
+ DECODER INPUT=$src_file
+ REFERENCES=$dir/test.refs
+------------------------------------------------------------------------------
+ BLEU=$bleu
+ TER=$ter
+##############################################################################
+EOT
+close TR;
+my $sr = unchecked_output("cat $dir/test.scores");
+print STDERR "\n\n$sr\n(A copy of this report can be found in $dir/test.scores)\n\n";
+exit 0;
+
+sub enseg {
+ my $src = shift;
+ my $newsrc = shift;
+ open(SRC, $src);
+ open(NEWSRC, ">$newsrc");
+ my $i=0;
+ while (my $line=<SRC>){
+ chomp $line;
+ if ($line =~ /^\s*<seg/i) {
+ if($line =~ /id="[0-9]+"/) {
+ print NEWSRC "$line\n";
+ } else {
+ die "When using segments with pre-generated <seg> tags, you must include a zero-based id attribute";
+ }
+ } else {
+ print NEWSRC "<seg id=\"$i\">$line</seg>\n";
+ }
+ $i++;
+ }
+ close SRC;
+ close NEWSRC;
+}
+
+sub print_help {
+ my $executable = basename($0); chomp $executable;
+ print << "Help";
+
+Usage: $executable [options] <ini file>
+
+ $executable --config cdec.ini --weights weights.txt [--jobs N] [--qsub] <testset.in-ref>
+
+Options:
+
+ --help
+ Print this message and exit.
+
+ --config <file>
+ A path to the cdec.ini file.
+
+ --weights <file>
+ A file specifying feature weights.
+
+ --dir <dir>
+ Directory for intermediate and output files.
+
+Job control options:
+
+ --jobs <I>
+ Number of decoder processes to run in parallel. [default=$default_jobs]
+
+ --qsub
+ Use qsub to run jobs in parallel (qsub must be configured in
+ environment/LocalEnvironment.pm)
+
+ --pmem <N>
+ Amount of physical memory requested for parallel decoding jobs
+ (used with qsub requests only)
+
+Help
+}
+
+sub convert {
+ my ($str) = @_;
+ my @ps = split /;/, $str;
+ my %dict = ();
+ for my $p (@ps) {
+ my ($k, $v) = split /=/, $p;
+ $dict{$k} = $v;
+ }
+ return %dict;
+}
+
+
+
+sub cmdline {
+ return join ' ',($0,@ORIG_ARGV);
+}
+
+#buggy: last arg gets quoted sometimes?
+my $is_shell_special=qr{[ \t\n\\><|&;"'`~*?{}$!()]};
+my $shell_escape_in_quote=qr{[\\"\$`!]};
+
+sub escape_shell {
+ my ($arg)=@_;
+ return undef unless defined $arg;
+ if ($arg =~ /$is_shell_special/) {
+ $arg =~ s/($shell_escape_in_quote)/\\$1/g;
+ return "\"$arg\"";
+ }
+ return $arg;
+}
+
+sub escaped_shell_args {
+ return map {local $_=$_;chomp;escape_shell($_)} @_;
+}
+
+sub escaped_shell_args_str {
+ return join ' ',&escaped_shell_args(@_);
+}
+
+sub escaped_cmdline {
+ return "$0 ".&escaped_shell_args_str(@ORIG_ARGV);
+}
+
+sub split_devset {
+ my ($infile, $outsrc, $outref) = @_;
+ open F, "<$infile" or die "Can't read $infile: $!";
+ open S, ">$outsrc" or die "Can't write $outsrc: $!";
+ open R, ">$outref" or die "Can't write $outref: $!";
+ while(<F>) {
+ chomp;
+ my ($src, @refs) = split /\s*\|\|\|\s*/;
+ die "Malformed devset line: $_\n" unless scalar @refs > 0;
+ print S "$src\n";
+ print R join(' ||| ', @refs) . "\n";
+ }
+ close R;
+ close S;
+ close F;
+}
+
diff --git a/training/utils/entropy.cc b/training/utils/entropy.cc
new file mode 100644
index 00000000..4fdbe2be
--- /dev/null
+++ b/training/utils/entropy.cc
@@ -0,0 +1,41 @@
+#include "entropy.h"
+
+#include "prob.h"
+#include "candidate_set.h"
+
+using namespace std;
+
+namespace training {
+
+// see Mann and McCallum "Efficient Computation of Entropy Gradient ..." for
+// a mostly clear derivation of:
+// g = E[ F(x,y) * log p(y|x) ] + H(y | x) * E[ F(x,y) ]
+double CandidateSetEntropy::operator()(const vector<double>& params,
+ SparseVector<double>* g) const {
+ prob_t z;
+ vector<double> dps(cands_.size());
+ for (unsigned i = 0; i < cands_.size(); ++i) {
+ dps[i] = cands_[i].fmap.dot(params);
+ const prob_t u(dps[i], init_lnx());
+ z += u;
+ }
+ const double log_z = log(z);
+
+ SparseVector<double> exp_feats;
+ double entropy = 0;
+ for (unsigned i = 0; i < cands_.size(); ++i) {
+ const double log_prob = cands_[i].fmap.dot(params) - log_z;
+ const double prob = exp(log_prob);
+ const double e_logprob = prob * log_prob;
+ entropy -= e_logprob;
+ if (g) {
+ (*g) += cands_[i].fmap * e_logprob;
+ exp_feats += cands_[i].fmap * prob;
+ }
+ }
+ if (g) (*g) += exp_feats * entropy;
+ return entropy;
+}
+
+}
+
diff --git a/training/utils/entropy.h b/training/utils/entropy.h
new file mode 100644
index 00000000..796589ca
--- /dev/null
+++ b/training/utils/entropy.h
@@ -0,0 +1,22 @@
+#ifndef _CSENTROPY_H_
+#define _CSENTROPY_H_
+
+#include <vector>
+#include "sparse_vector.h"
+
+namespace training {
+ class CandidateSet;
+
+ class CandidateSetEntropy {
+ public:
+ explicit CandidateSetEntropy(const CandidateSet& cs) : cands_(cs) {}
+ // compute the entropy (expected log likelihood) of a CandidateSet
+ // (optional) the gradient of the entropy with respect to params
+ double operator()(const std::vector<double>& params,
+ SparseVector<double>* g = NULL) const;
+ private:
+ const CandidateSet& cands_;
+ };
+};
+
+#endif
diff --git a/training/utils/grammar_convert.cc b/training/utils/grammar_convert.cc
new file mode 100644
index 00000000..607a7cb9
--- /dev/null
+++ b/training/utils/grammar_convert.cc
@@ -0,0 +1,348 @@
+/*
+ this program modifies cfg hypergraphs (forests) and extracts kbests?
+ what are: json, split ?
+ */
+#include <iostream>
+#include <algorithm>
+#include <sstream>
+
+#include <boost/lexical_cast.hpp>
+#include <boost/program_options.hpp>
+
+#include "inside_outside.h"
+#include "tdict.h"
+#include "filelib.h"
+#include "hg.h"
+#include "hg_io.h"
+#include "kbest.h"
+#include "viterbi.h"
+#include "weights.h"
+
+namespace po = boost::program_options;
+using namespace std;
+
+WordID kSTART;
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("input,i", po::value<string>()->default_value("-"), "Input file")
+ ("format,f", po::value<string>()->default_value("cfg"), "Input format. Values: cfg, json, split")
+ ("output,o", po::value<string>()->default_value("json"), "Output command. Values: json, 1best")
+ ("reorder,r", "Add Yamada & Knight (2002) reorderings")
+ ("weights,w", po::value<string>(), "Feature weights for k-best derivations [optional]")
+ ("collapse_weights,C", "Collapse order features into a single feature whose value is all of the locally applying feature weights")
+ ("k_derivations,k", po::value<int>(), "Show k derivations and their features")
+ ("max_reorder,m", po::value<int>()->default_value(999), "Move a constituent at most this far")
+ ("help,h", "Print this help message and exit");
+ po::options_description clo("Command line options");
+ po::options_description dcmdline_options;
+ dcmdline_options.add(opts);
+
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ po::notify(*conf);
+
+ if (conf->count("help") || conf->count("input") == 0) {
+ cerr << "\nUsage: grammar_convert [-options]\n\nConverts a grammar file (in Hiero format) into JSON hypergraph.\n";
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+int GetOrCreateNode(const WordID& lhs, map<WordID, int>* lhs2node, Hypergraph* hg) {
+ int& node_id = (*lhs2node)[lhs];
+ if (!node_id)
+ node_id = hg->AddNode(lhs)->id_ + 1;
+ return node_id - 1;
+}
+
+void FilterAndCheckCorrectness(int goal, Hypergraph* hg) {
+ if (goal < 0) {
+ cerr << "Error! [S] not found in grammar!\n";
+ exit(1);
+ }
+ if (hg->nodes_[goal].in_edges_.size() != 1) {
+ cerr << "Error! [S] has more than one rewrite!\n";
+ exit(1);
+ }
+ int old_size = hg->nodes_.size();
+ hg->TopologicallySortNodesAndEdges(goal);
+ if (hg->nodes_.size() != old_size) {
+ cerr << "Warning! During sorting " << (old_size - hg->nodes_.size()) << " disappeared!\n";
+ }
+ vector<double> inside; // inside score at each node
+ double p = Inside<double, TransitionCountWeightFunction>(*hg, &inside);
+ if (!p) {
+ cerr << "Warning! Grammar defines the empty language!\n";
+ hg->clear();
+ return;
+ }
+ vector<bool> prune(hg->edges_.size(), false);
+ int bad_edges = 0;
+ for (unsigned i = 0; i < hg->edges_.size(); ++i) {
+ Hypergraph::Edge& edge = hg->edges_[i];
+ bool bad = false;
+ for (unsigned j = 0; j < edge.tail_nodes_.size(); ++j) {
+ if (!inside[edge.tail_nodes_[j]]) {
+ bad = true;
+ ++bad_edges;
+ }
+ }
+ prune[i] = bad;
+ }
+ cerr << "Removing " << bad_edges << " bad edges from the grammar.\n";
+ for (unsigned i = 0; i < hg->edges_.size(); ++i) {
+ if (prune[i])
+ cerr << " " << hg->edges_[i].rule_->AsString() << endl;
+ }
+ hg->PruneEdges(prune);
+}
+
+void CreateEdge(const TRulePtr& r, const Hypergraph::TailNodeVector& tail, Hypergraph::Node* head_node, Hypergraph* hg) {
+ Hypergraph::Edge* new_edge = hg->AddEdge(r, tail);
+ hg->ConnectEdgeToHeadNode(new_edge, head_node);
+ new_edge->feature_values_ = r->scores_;
+}
+
+// from a category label like "NP_2", return "NP"
+string PureCategory(WordID cat) {
+ assert(cat < 0);
+ string c = TD::Convert(cat*-1);
+ size_t p = c.find("_");
+ if (p == string::npos) return c;
+ return c.substr(0, p);
+};
+
+string ConstituentOrderFeature(const TRule& rule, const vector<int>& pi) {
+ const static string kTERM_VAR = "x";
+ const vector<WordID>& f = rule.f();
+ map<string, int> used;
+ vector<string> terms(f.size());
+ for (int i = 0; i < f.size(); ++i) {
+ const string term = (f[i] < 0 ? PureCategory(f[i]) : kTERM_VAR);
+ int& count = used[term];
+ if (!count) {
+ terms[i] = term;
+ } else {
+ ostringstream os;
+ os << term << count;
+ terms[i] = os.str();
+ }
+ ++count;
+ }
+ ostringstream os;
+ os << PureCategory(rule.GetLHS()) << ':';
+ for (int i = 0; i < f.size(); ++i) {
+ if (i > 0) os << '_';
+ os << terms[pi[i]];
+ }
+ return os.str();
+}
+
+bool CheckPermutationMask(const vector<int>& mask, const vector<int>& pi) {
+ assert(mask.size() == pi.size());
+
+ int req_min = -1;
+ int cur_max = 0;
+ int cur_mask = -1;
+ for (int i = 0; i < mask.size(); ++i) {
+ if (mask[i] != cur_mask) {
+ cur_mask = mask[i];
+ req_min = cur_max - 1;
+ }
+ if (pi[i] > req_min) {
+ if (pi[i] > cur_max) cur_max = pi[i];
+ } else {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+void PermuteYKRecursive(int nodeid, const WordID& parent, const int max_reorder, Hypergraph* hg) {
+ // Hypergraph tmp = *hg;
+ Hypergraph::Node* node = &hg->nodes_[nodeid];
+ if (node->in_edges_.size() != 1) {
+ cerr << "Multiple rewrites of [" << TD::Convert(node->cat_ * -1) << "] (parent is [" << TD::Convert(parent*-1) << "])\n";
+ cerr << " not recursing!\n";
+ return;
+ }
+// for (int eii = 0; eii < node->in_edges_.size(); ++eii) {
+ const int oe_index = node->in_edges_.front();
+ const TRule& rule = *hg->edges_[oe_index].rule_;
+ const Hypergraph::TailNodeVector orig_tail = hg->edges_[oe_index].tail_nodes_;
+ const int tail_size = orig_tail.size();
+ for (int i = 0; i < tail_size; ++i) {
+ PermuteYKRecursive(hg->edges_[oe_index].tail_nodes_[i], node->cat_, max_reorder, hg);
+ }
+ const vector<WordID>& of = rule.f_;
+ if (of.size() == 1) return;
+ // cerr << "Permuting [" << TD::Convert(node->cat_ * -1) << "]\n";
+ // cerr << "ORIG: " << rule.AsString() << endl;
+ vector<WordID> pi(of.size(), 0);
+ for (int i = 0; i < pi.size(); ++i) pi[i] = i;
+
+ vector<int> permutation_mask(of.size(), 0);
+ const bool dont_reorder_across_PU = true; // TODO add configuration
+ if (dont_reorder_across_PU) {
+ int cur = 0;
+ for (int i = 0; i < pi.size(); ++i) {
+ if (of[i] >= 0) continue;
+ const string cat = PureCategory(of[i]);
+ if (cat == "PU" || cat == "PU!H" || cat == "PUNC" || cat == "PUNC!H" || cat == "CC") {
+ ++cur;
+ permutation_mask[i] = cur;
+ ++cur;
+ } else {
+ permutation_mask[i] = cur;
+ }
+ }
+ }
+ int fid = FD::Convert(ConstituentOrderFeature(rule, pi));
+ hg->edges_[oe_index].feature_values_.set_value(fid, 1.0);
+ while (next_permutation(pi.begin(), pi.end())) {
+ if (!CheckPermutationMask(permutation_mask, pi))
+ continue;
+ vector<WordID> nf(pi.size(), 0);
+ Hypergraph::TailNodeVector tail(pi.size(), 0);
+ bool skip = false;
+ for (int i = 0; i < pi.size(); ++i) {
+ int dist = pi[i] - i; if (dist < 0) dist *= -1;
+ if (dist > max_reorder) { skip = true; break; }
+ nf[i] = of[pi[i]];
+ tail[i] = orig_tail[pi[i]];
+ }
+ if (skip) continue;
+ TRulePtr nr(new TRule(rule));
+ nr->f_ = nf;
+ int fid = FD::Convert(ConstituentOrderFeature(rule, pi));
+ nr->scores_.set_value(fid, 1.0);
+ // cerr << "PERM: " << nr->AsString() << endl;
+ CreateEdge(nr, tail, node, hg);
+ }
+ // }
+}
+
+void PermuteYamadaAndKnight(Hypergraph* hg, int max_reorder) {
+ assert(hg->nodes_.back().cat_ == kSTART);
+ assert(hg->nodes_.back().in_edges_.size() == 1);
+ PermuteYKRecursive(hg->nodes_.size() - 1, kSTART, max_reorder, hg);
+}
+
+void CollapseWeights(Hypergraph* hg) {
+ int fid = FD::Convert("Reordering");
+ for (int i = 0; i < hg->edges_.size(); ++i) {
+ Hypergraph::Edge& edge = hg->edges_[i];
+ edge.feature_values_.clear();
+ if (edge.edge_prob_ != prob_t::Zero()) {
+ edge.feature_values_.set_value(fid, log(edge.edge_prob_));
+ }
+ }
+}
+
+void ProcessHypergraph(const vector<double>& w, const po::variables_map& conf, const string& ref, Hypergraph* hg) {
+ if (conf.count("reorder"))
+ PermuteYamadaAndKnight(hg, conf["max_reorder"].as<int>());
+ if (w.size() > 0) { hg->Reweight(w); }
+ if (conf.count("collapse_weights")) CollapseWeights(hg);
+ if (conf["output"].as<string>() == "json") {
+ HypergraphIO::WriteToJSON(*hg, false, &cout);
+ if (!ref.empty()) { cerr << "REF: " << ref << endl; }
+ } else {
+ vector<WordID> onebest;
+ ViterbiESentence(*hg, &onebest);
+ if (ref.empty()) {
+ cout << TD::GetString(onebest) << endl;
+ } else {
+ cout << TD::GetString(onebest) << " ||| " << ref << endl;
+ }
+ }
+ if (conf.count("k_derivations")) {
+ const int k = conf["k_derivations"].as<int>();
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(*hg, k);
+ for (int i = 0; i < k; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest.LazyKthBest(hg->nodes_.size() - 1, i);
+ if (!d) break;
+ cerr << log(d->score) << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl;
+ }
+ }
+}
+
+int main(int argc, char **argv) {
+ kSTART = TD::Convert("S") * -1;
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+ string infile = conf["input"].as<string>();
+ const bool is_split_input = (conf["format"].as<string>() == "split");
+ const bool is_json_input = is_split_input || (conf["format"].as<string>() == "json");
+ const bool collapse_weights = conf.count("collapse_weights");
+ vector<double> w;
+ if (conf.count("weights"))
+ Weights::InitFromFile(conf["weights"].as<string>(), &w);
+
+ if (collapse_weights && !w.size()) {
+ cerr << "--collapse_weights requires a weights file to be specified!\n";
+ exit(1);
+ }
+ ReadFile rf(infile);
+ istream* in = rf.stream();
+ assert(*in);
+ int lc = 0;
+ Hypergraph hg;
+ map<WordID, int> lhs2node;
+ while(*in) {
+ string line;
+ ++lc;
+ getline(*in, line);
+ if (is_json_input) {
+ if (line.empty() || line[0] == '#') continue;
+ string ref;
+ if (is_split_input) {
+ size_t pos = line.rfind("}}");
+ assert(pos != string::npos);
+ size_t rstart = line.find("||| ", pos);
+ assert(rstart != string::npos);
+ ref = line.substr(rstart + 4);
+ line = line.substr(0, pos + 2);
+ }
+ istringstream is(line);
+ if (HypergraphIO::ReadFromJSON(&is, &hg)) {
+ ProcessHypergraph(w, conf, ref, &hg);
+ hg.clear();
+ } else {
+ cerr << "Error reading grammar from JSON: line " << lc << endl;
+ exit(1);
+ }
+ } else {
+ if (line.empty()) {
+ int goal = lhs2node[kSTART] - 1;
+ FilterAndCheckCorrectness(goal, &hg);
+ ProcessHypergraph(w, conf, "", &hg);
+ hg.clear();
+ lhs2node.clear();
+ continue;
+ }
+ if (line[0] == '#') continue;
+ if (line[0] != '[') {
+ cerr << "Line " << lc << ": bad format\n";
+ exit(1);
+ }
+ TRulePtr tr(TRule::CreateRuleMonolingual(line));
+ Hypergraph::TailNodeVector tail;
+ for (int i = 0; i < tr->f_.size(); ++i) {
+ WordID var_cat = tr->f_[i];
+ if (var_cat < 0)
+ tail.push_back(GetOrCreateNode(var_cat, &lhs2node, &hg));
+ }
+ const WordID lhs = tr->GetLHS();
+ int head = GetOrCreateNode(lhs, &lhs2node, &hg);
+ Hypergraph::Edge* edge = hg.AddEdge(tr, tail);
+ edge->feature_values_ = tr->scores_;
+ Hypergraph::Node* node = &hg.nodes_[head];
+ hg.ConnectEdgeToHeadNode(edge, node);
+ }
+ }
+}
+
diff --git a/training/utils/lbfgs.h b/training/utils/lbfgs.h
new file mode 100644
index 00000000..e8baecab
--- /dev/null
+++ b/training/utils/lbfgs.h
@@ -0,0 +1,1459 @@
+#ifndef SCITBX_LBFGS_H
+#define SCITBX_LBFGS_H
+
+#include <cstdio>
+#include <cstddef>
+#include <cmath>
+#include <stdexcept>
+#include <algorithm>
+#include <vector>
+#include <string>
+#include <iostream>
+#include <sstream>
+
+namespace scitbx {
+
+//! Limited-memory Broyden-Fletcher-Goldfarb-Shanno (LBFGS) %minimizer.
+/*! Implementation of the
+ Limited-memory Broyden-Fletcher-Goldfarb-Shanno (LBFGS)
+ algorithm for large-scale multidimensional minimization
+ problems.
+
+ This code was manually derived from Java code which was
+ in turn derived from the Fortran program
+ <code>lbfgs.f</code>. The Java translation was
+ effected mostly mechanically, with some manual
+ clean-up; in particular, array indices start at 0
+ instead of 1. Most of the comments from the Fortran
+ code have been pasted in.
+
+ Information on the original LBFGS Fortran source code is
+ available at
+ http://www.netlib.org/opt/lbfgs_um.shar . The following
+ information is taken verbatim from the Netlib documentation
+ for the Fortran source.
+
+ <pre>
+ file opt/lbfgs_um.shar
+ for unconstrained optimization problems
+ alg limited memory BFGS method
+ by J. Nocedal
+ contact nocedal@eecs.nwu.edu
+ ref D. C. Liu and J. Nocedal, ``On the limited memory BFGS method for
+ , large scale optimization methods'' Mathematical Programming 45
+ , (1989), pp. 503-528.
+ , (Postscript file of this paper is available via anonymous ftp
+ , to eecs.nwu.edu in the directory pub/%lbfgs/lbfgs_um.)
+ </pre>
+
+ @author Jorge Nocedal: original Fortran version, including comments
+ (July 1990).<br>
+ Robert Dodier: Java translation, August 1997.<br>
+ Ralf W. Grosse-Kunstleve: C++ port, March 2002.<br>
+ Chris Dyer: serialize/deserialize functionality
+ */
+namespace lbfgs {
+
+ //! Generic exception class for %lbfgs %error messages.
+ /*! All exceptions thrown by the minimizer are derived from this class.
+ */
+ class error : public std::exception {
+ public:
+ //! Constructor.
+ error(std::string const& msg) throw()
+ : msg_("lbfgs error: " + msg)
+ {}
+ //! Access to error message.
+ virtual const char* what() const throw() { return msg_.c_str(); }
+ protected:
+ virtual ~error() throw() {}
+ std::string msg_;
+ public:
+ static std::string itoa(unsigned long i) {
+ std::ostringstream os;
+ os << i;
+ return os.str();
+ }
+ };
+
+ //! Specific exception class.
+ class error_internal_error : public error {
+ public:
+ //! Constructor.
+ error_internal_error(const char* file, unsigned long line) throw()
+ : error(
+ "Internal Error: " + std::string(file) + "(" + itoa(line) + ")")
+ {}
+ };
+
+ //! Specific exception class.
+ class error_improper_input_parameter : public error {
+ public:
+ //! Constructor.
+ error_improper_input_parameter(std::string const& msg) throw()
+ : error("Improper input parameter: " + msg)
+ {}
+ };
+
+ //! Specific exception class.
+ class error_improper_input_data : public error {
+ public:
+ //! Constructor.
+ error_improper_input_data(std::string const& msg) throw()
+ : error("Improper input data: " + msg)
+ {}
+ };
+
+ //! Specific exception class.
+ class error_search_direction_not_descent : public error {
+ public:
+ //! Constructor.
+ error_search_direction_not_descent() throw()
+ : error("The search direction is not a descent direction.")
+ {}
+ };
+
+ //! Specific exception class.
+ class error_line_search_failed : public error {
+ public:
+ //! Constructor.
+ error_line_search_failed(std::string const& msg) throw()
+ : error("Line search failed: " + msg)
+ {}
+ };
+
+ //! Specific exception class.
+ class error_line_search_failed_rounding_errors
+ : public error_line_search_failed {
+ public:
+ //! Constructor.
+ error_line_search_failed_rounding_errors(std::string const& msg) throw()
+ : error_line_search_failed(msg)
+ {}
+ };
+
+ namespace detail {
+
+ template <typename NumType>
+ inline
+ NumType
+ pow2(NumType const& x) { return x * x; }
+
+ template <typename NumType>
+ inline
+ NumType
+ abs(NumType const& x) {
+ if (x < NumType(0)) return -x;
+ return x;
+ }
+
+ // This class implements an algorithm for multi-dimensional line search.
+ template <typename FloatType, typename SizeType = std::size_t>
+ class mcsrch
+ {
+ protected:
+ int infoc;
+ FloatType dginit;
+ bool brackt;
+ bool stage1;
+ FloatType finit;
+ FloatType dgtest;
+ FloatType width;
+ FloatType width1;
+ FloatType stx;
+ FloatType fx;
+ FloatType dgx;
+ FloatType sty;
+ FloatType fy;
+ FloatType dgy;
+ FloatType stmin;
+ FloatType stmax;
+
+ static FloatType const& max3(
+ FloatType const& x,
+ FloatType const& y,
+ FloatType const& z)
+ {
+ return x < y ? (y < z ? z : y ) : (x < z ? z : x );
+ }
+
+ public:
+ /* Minimize a function along a search direction. This code is
+ a Java translation of the function <code>MCSRCH</code> from
+ <code>lbfgs.f</code>, which in turn is a slight modification
+ of the subroutine <code>CSRCH</code> of More' and Thuente.
+ The changes are to allow reverse communication, and do not
+ affect the performance of the routine. This function, in turn,
+ calls <code>mcstep</code>.<p>
+
+ The Java translation was effected mostly mechanically, with
+ some manual clean-up; in particular, array indices start at 0
+ instead of 1. Most of the comments from the Fortran code have
+ been pasted in here as well.<p>
+
+ The purpose of <code>mcsrch</code> is to find a step which
+ satisfies a sufficient decrease condition and a curvature
+ condition.<p>
+
+ At each stage this function updates an interval of uncertainty
+ with endpoints <code>stx</code> and <code>sty</code>. The
+ interval of uncertainty is initially chosen so that it
+ contains a minimizer of the modified function
+ <pre>
+ f(x+stp*s) - f(x) - ftol*stp*(gradf(x)'s).
+ </pre>
+ If a step is obtained for which the modified function has a
+ nonpositive function value and nonnegative derivative, then
+ the interval of uncertainty is chosen so that it contains a
+ minimizer of <code>f(x+stp*s)</code>.<p>
+
+ The algorithm is designed to find a step which satisfies
+ the sufficient decrease condition
+ <pre>
+ f(x+stp*s) &lt;= f(X) + ftol*stp*(gradf(x)'s),
+ </pre>
+ and the curvature condition
+ <pre>
+ abs(gradf(x+stp*s)'s)) &lt;= gtol*abs(gradf(x)'s).
+ </pre>
+ If <code>ftol</code> is less than <code>gtol</code> and if,
+ for example, the function is bounded below, then there is
+ always a step which satisfies both conditions. If no step can
+ be found which satisfies both conditions, then the algorithm
+ usually stops when rounding errors prevent further progress.
+ In this case <code>stp</code> only satisfies the sufficient
+ decrease condition.<p>
+
+ @author Original Fortran version by Jorge J. More' and
+ David J. Thuente as part of the Minpack project, June 1983,
+ Argonne National Laboratory. Java translation by Robert
+ Dodier, August 1997.
+
+ @param n The number of variables.
+
+ @param x On entry this contains the base point for the line
+ search. On exit it contains <code>x + stp*s</code>.
+
+ @param f On entry this contains the value of the objective
+ function at <code>x</code>. On exit it contains the value
+ of the objective function at <code>x + stp*s</code>.
+
+ @param g On entry this contains the gradient of the objective
+ function at <code>x</code>. On exit it contains the gradient
+ at <code>x + stp*s</code>.
+
+ @param s The search direction.
+
+ @param stp On entry this contains an initial estimate of a
+ satifactory step length. On exit <code>stp</code> contains
+ the final estimate.
+
+ @param ftol Tolerance for the sufficient decrease condition.
+
+ @param xtol Termination occurs when the relative width of the
+ interval of uncertainty is at most <code>xtol</code>.
+
+ @param maxfev Termination occurs when the number of evaluations
+ of the objective function is at least <code>maxfev</code> by
+ the end of an iteration.
+
+ @param info This is an output variable, which can have these
+ values:
+ <ul>
+ <li><code>info = -1</code> A return is made to compute
+ the function and gradient.
+ <li><code>info = 1</code> The sufficient decrease condition
+ and the directional derivative condition hold.
+ </ul>
+
+ @param nfev On exit, this is set to the number of function
+ evaluations.
+
+ @param wa Temporary storage array, of length <code>n</code>.
+ */
+ void run(
+ FloatType const& gtol,
+ FloatType const& stpmin,
+ FloatType const& stpmax,
+ SizeType n,
+ FloatType* x,
+ FloatType f,
+ const FloatType* g,
+ FloatType* s,
+ SizeType is0,
+ FloatType& stp,
+ FloatType ftol,
+ FloatType xtol,
+ SizeType maxfev,
+ int& info,
+ SizeType& nfev,
+ FloatType* wa);
+
+ /* The purpose of this function is to compute a safeguarded step
+ for a linesearch and to update an interval of uncertainty for
+ a minimizer of the function.<p>
+
+ The parameter <code>stx</code> contains the step with the
+ least function value. The parameter <code>stp</code> contains
+ the current step. It is assumed that the derivative at
+ <code>stx</code> is negative in the direction of the step. If
+ <code>brackt</code> is <code>true</code> when
+ <code>mcstep</code> returns then a minimizer has been
+ bracketed in an interval of uncertainty with endpoints
+ <code>stx</code> and <code>sty</code>.<p>
+
+ Variables that must be modified by <code>mcstep</code> are
+ implemented as 1-element arrays.
+
+ @param stx Step at the best step obtained so far.
+ This variable is modified by <code>mcstep</code>.
+ @param fx Function value at the best step obtained so far.
+ This variable is modified by <code>mcstep</code>.
+ @param dx Derivative at the best step obtained so far.
+ The derivative must be negative in the direction of the
+ step, that is, <code>dx</code> and <code>stp-stx</code> must
+ have opposite signs. This variable is modified by
+ <code>mcstep</code>.
+
+ @param sty Step at the other endpoint of the interval of
+ uncertainty. This variable is modified by <code>mcstep</code>.
+ @param fy Function value at the other endpoint of the interval
+ of uncertainty. This variable is modified by
+ <code>mcstep</code>.
+
+ @param dy Derivative at the other endpoint of the interval of
+ uncertainty. This variable is modified by <code>mcstep</code>.
+
+ @param stp Step at the current step. If <code>brackt</code> is set
+ then on input <code>stp</code> must be between <code>stx</code>
+ and <code>sty</code>. On output <code>stp</code> is set to the
+ new step.
+ @param fp Function value at the current step.
+ @param dp Derivative at the current step.
+
+ @param brackt Tells whether a minimizer has been bracketed.
+ If the minimizer has not been bracketed, then on input this
+ variable must be set <code>false</code>. If the minimizer has
+ been bracketed, then on output this variable is
+ <code>true</code>.
+
+ @param stpmin Lower bound for the step.
+ @param stpmax Upper bound for the step.
+
+ If the return value is 1, 2, 3, or 4, then the step has
+ been computed successfully. A return value of 0 indicates
+ improper input parameters.
+
+ @author Jorge J. More, David J. Thuente: original Fortran version,
+ as part of Minpack project. Argonne Nat'l Laboratory, June 1983.
+ Robert Dodier: Java translation, August 1997.
+ */
+ static int mcstep(
+ FloatType& stx,
+ FloatType& fx,
+ FloatType& dx,
+ FloatType& sty,
+ FloatType& fy,
+ FloatType& dy,
+ FloatType& stp,
+ FloatType fp,
+ FloatType dp,
+ bool& brackt,
+ FloatType stpmin,
+ FloatType stpmax);
+
+ void serialize(std::ostream* out) const {
+ out->write((const char*)&infoc,sizeof(infoc));
+ out->write((const char*)&dginit,sizeof(dginit));
+ out->write((const char*)&brackt,sizeof(brackt));
+ out->write((const char*)&stage1,sizeof(stage1));
+ out->write((const char*)&finit,sizeof(finit));
+ out->write((const char*)&dgtest,sizeof(dgtest));
+ out->write((const char*)&width,sizeof(width));
+ out->write((const char*)&width1,sizeof(width1));
+ out->write((const char*)&stx,sizeof(stx));
+ out->write((const char*)&fx,sizeof(fx));
+ out->write((const char*)&dgx,sizeof(dgx));
+ out->write((const char*)&sty,sizeof(sty));
+ out->write((const char*)&fy,sizeof(fy));
+ out->write((const char*)&dgy,sizeof(dgy));
+ out->write((const char*)&stmin,sizeof(stmin));
+ out->write((const char*)&stmax,sizeof(stmax));
+ }
+
+ void deserialize(std::istream* in) const {
+ in->read((char*)&infoc, sizeof(infoc));
+ in->read((char*)&dginit, sizeof(dginit));
+ in->read((char*)&brackt, sizeof(brackt));
+ in->read((char*)&stage1, sizeof(stage1));
+ in->read((char*)&finit, sizeof(finit));
+ in->read((char*)&dgtest, sizeof(dgtest));
+ in->read((char*)&width, sizeof(width));
+ in->read((char*)&width1, sizeof(width1));
+ in->read((char*)&stx, sizeof(stx));
+ in->read((char*)&fx, sizeof(fx));
+ in->read((char*)&dgx, sizeof(dgx));
+ in->read((char*)&sty, sizeof(sty));
+ in->read((char*)&fy, sizeof(fy));
+ in->read((char*)&dgy, sizeof(dgy));
+ in->read((char*)&stmin, sizeof(stmin));
+ in->read((char*)&stmax, sizeof(stmax));
+ }
+ };
+
+ template <typename FloatType, typename SizeType>
+ void mcsrch<FloatType, SizeType>::run(
+ FloatType const& gtol,
+ FloatType const& stpmin,
+ FloatType const& stpmax,
+ SizeType n,
+ FloatType* x,
+ FloatType f,
+ const FloatType* g,
+ FloatType* s,
+ SizeType is0,
+ FloatType& stp,
+ FloatType ftol,
+ FloatType xtol,
+ SizeType maxfev,
+ int& info,
+ SizeType& nfev,
+ FloatType* wa)
+ {
+ if (info != -1) {
+ infoc = 1;
+ if ( n == 0
+ || maxfev == 0
+ || gtol < FloatType(0)
+ || xtol < FloatType(0)
+ || stpmin < FloatType(0)
+ || stpmax < stpmin) {
+ throw error_internal_error(__FILE__, __LINE__);
+ }
+ if (stp <= FloatType(0) || ftol < FloatType(0)) {
+ throw error_internal_error(__FILE__, __LINE__);
+ }
+ // Compute the initial gradient in the search direction
+ // and check that s is a descent direction.
+ dginit = FloatType(0);
+ for (SizeType j = 0; j < n; j++) {
+ dginit += g[j] * s[is0+j];
+ }
+ if (dginit >= FloatType(0)) {
+ throw error_search_direction_not_descent();
+ }
+ brackt = false;
+ stage1 = true;
+ nfev = 0;
+ finit = f;
+ dgtest = ftol*dginit;
+ width = stpmax - stpmin;
+ width1 = FloatType(2) * width;
+ std::copy(x, x+n, wa);
+ // The variables stx, fx, dgx contain the values of the step,
+ // function, and directional derivative at the best step.
+ // The variables sty, fy, dgy contain the value of the step,
+ // function, and derivative at the other endpoint of
+ // the interval of uncertainty.
+ // The variables stp, f, dg contain the values of the step,
+ // function, and derivative at the current step.
+ stx = FloatType(0);
+ fx = finit;
+ dgx = dginit;
+ sty = FloatType(0);
+ fy = finit;
+ dgy = dginit;
+ }
+ for (;;) {
+ if (info != -1) {
+ // Set the minimum and maximum steps to correspond
+ // to the present interval of uncertainty.
+ if (brackt) {
+ stmin = std::min(stx, sty);
+ stmax = std::max(stx, sty);
+ }
+ else {
+ stmin = stx;
+ stmax = stp + FloatType(4) * (stp - stx);
+ }
+ // Force the step to be within the bounds stpmax and stpmin.
+ stp = std::max(stp, stpmin);
+ stp = std::min(stp, stpmax);
+ // If an unusual termination is to occur then let
+ // stp be the lowest point obtained so far.
+ if ( (brackt && (stp <= stmin || stp >= stmax))
+ || nfev >= maxfev - 1 || infoc == 0
+ || (brackt && stmax - stmin <= xtol * stmax)) {
+ stp = stx;
+ }
+ // Evaluate the function and gradient at stp
+ // and compute the directional derivative.
+ // We return to main program to obtain F and G.
+ for (SizeType j = 0; j < n; j++) {
+ x[j] = wa[j] + stp * s[is0+j];
+ }
+ info=-1;
+ break;
+ }
+ info = 0;
+ nfev++;
+ FloatType dg(0);
+ for (SizeType j = 0; j < n; j++) {
+ dg += g[j] * s[is0+j];
+ }
+ FloatType ftest1 = finit + stp*dgtest;
+ // Test for convergence.
+ if ((brackt && (stp <= stmin || stp >= stmax)) || infoc == 0) {
+ throw error_line_search_failed_rounding_errors(
+ "Rounding errors prevent further progress."
+ " There may not be a step which satisfies the"
+ " sufficient decrease and curvature conditions."
+ " Tolerances may be too small.");
+ }
+ if (stp == stpmax && f <= ftest1 && dg <= dgtest) {
+ throw error_line_search_failed(
+ "The step is at the upper bound stpmax().");
+ }
+ if (stp == stpmin && (f > ftest1 || dg >= dgtest)) {
+ throw error_line_search_failed(
+ "The step is at the lower bound stpmin().");
+ }
+ if (nfev >= maxfev) {
+ throw error_line_search_failed(
+ "Number of function evaluations has reached maxfev().");
+ }
+ if (brackt && stmax - stmin <= xtol * stmax) {
+ throw error_line_search_failed(
+ "Relative width of the interval of uncertainty"
+ " is at most xtol().");
+ }
+ // Check for termination.
+ if (f <= ftest1 && abs(dg) <= gtol * (-dginit)) {
+ info = 1;
+ break;
+ }
+ // In the first stage we seek a step for which the modified
+ // function has a nonpositive value and nonnegative derivative.
+ if ( stage1 && f <= ftest1
+ && dg >= std::min(ftol, gtol) * dginit) {
+ stage1 = false;
+ }
+ // A modified function is used to predict the step only if
+ // we have not obtained a step for which the modified
+ // function has a nonpositive function value and nonnegative
+ // derivative, and if a lower function value has been
+ // obtained but the decrease is not sufficient.
+ if (stage1 && f <= fx && f > ftest1) {
+ // Define the modified function and derivative values.
+ FloatType fm = f - stp*dgtest;
+ FloatType fxm = fx - stx*dgtest;
+ FloatType fym = fy - sty*dgtest;
+ FloatType dgm = dg - dgtest;
+ FloatType dgxm = dgx - dgtest;
+ FloatType dgym = dgy - dgtest;
+ // Call cstep to update the interval of uncertainty
+ // and to compute the new step.
+ infoc = mcstep(stx, fxm, dgxm, sty, fym, dgym, stp, fm, dgm,
+ brackt, stmin, stmax);
+ // Reset the function and gradient values for f.
+ fx = fxm + stx*dgtest;
+ fy = fym + sty*dgtest;
+ dgx = dgxm + dgtest;
+ dgy = dgym + dgtest;
+ }
+ else {
+ // Call mcstep to update the interval of uncertainty
+ // and to compute the new step.
+ infoc = mcstep(stx, fx, dgx, sty, fy, dgy, stp, f, dg,
+ brackt, stmin, stmax);
+ }
+ // Force a sufficient decrease in the size of the
+ // interval of uncertainty.
+ if (brackt) {
+ if (abs(sty - stx) >= FloatType(0.66) * width1) {
+ stp = stx + FloatType(0.5) * (sty - stx);
+ }
+ width1 = width;
+ width = abs(sty - stx);
+ }
+ }
+ }
+
+ template <typename FloatType, typename SizeType>
+ int mcsrch<FloatType, SizeType>::mcstep(
+ FloatType& stx,
+ FloatType& fx,
+ FloatType& dx,
+ FloatType& sty,
+ FloatType& fy,
+ FloatType& dy,
+ FloatType& stp,
+ FloatType fp,
+ FloatType dp,
+ bool& brackt,
+ FloatType stpmin,
+ FloatType stpmax)
+ {
+ bool bound;
+ FloatType gamma, p, q, r, s, sgnd, stpc, stpf, stpq, theta;
+ int info = 0;
+ if ( ( brackt && (stp <= std::min(stx, sty)
+ || stp >= std::max(stx, sty)))
+ || dx * (stp - stx) >= FloatType(0) || stpmax < stpmin) {
+ return 0;
+ }
+ // Determine if the derivatives have opposite sign.
+ sgnd = dp * (dx / abs(dx));
+ if (fp > fx) {
+ // First case. A higher function value.
+ // The minimum is bracketed. If the cubic step is closer
+ // to stx than the quadratic step, the cubic step is taken,
+ // else the average of the cubic and quadratic steps is taken.
+ info = 1;
+ bound = true;
+ theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp;
+ s = max3(abs(theta), abs(dx), abs(dp));
+ gamma = s * std::sqrt(pow2(theta / s) - (dx / s) * (dp / s));
+ if (stp < stx) gamma = - gamma;
+ p = (gamma - dx) + theta;
+ q = ((gamma - dx) + gamma) + dp;
+ r = p/q;
+ stpc = stx + r * (stp - stx);
+ stpq = stx
+ + ((dx / ((fx - fp) / (stp - stx) + dx)) / FloatType(2))
+ * (stp - stx);
+ if (abs(stpc - stx) < abs(stpq - stx)) {
+ stpf = stpc;
+ }
+ else {
+ stpf = stpc + (stpq - stpc) / FloatType(2);
+ }
+ brackt = true;
+ }
+ else if (sgnd < FloatType(0)) {
+ // Second case. A lower function value and derivatives of
+ // opposite sign. The minimum is bracketed. If the cubic
+ // step is closer to stx than the quadratic (secant) step,
+ // the cubic step is taken, else the quadratic step is taken.
+ info = 2;
+ bound = false;
+ theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp;
+ s = max3(abs(theta), abs(dx), abs(dp));
+ gamma = s * std::sqrt(pow2(theta / s) - (dx / s) * (dp / s));
+ if (stp > stx) gamma = - gamma;
+ p = (gamma - dp) + theta;
+ q = ((gamma - dp) + gamma) + dx;
+ r = p/q;
+ stpc = stp + r * (stx - stp);
+ stpq = stp + (dp / (dp - dx)) * (stx - stp);
+ if (abs(stpc - stp) > abs(stpq - stp)) {
+ stpf = stpc;
+ }
+ else {
+ stpf = stpq;
+ }
+ brackt = true;
+ }
+ else if (abs(dp) < abs(dx)) {
+ // Third case. A lower function value, derivatives of the
+ // same sign, and the magnitude of the derivative decreases.
+ // The cubic step is only used if the cubic tends to infinity
+ // in the direction of the step or if the minimum of the cubic
+ // is beyond stp. Otherwise the cubic step is defined to be
+ // either stpmin or stpmax. The quadratic (secant) step is also
+ // computed and if the minimum is bracketed then the the step
+ // closest to stx is taken, else the step farthest away is taken.
+ info = 3;
+ bound = true;
+ theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp;
+ s = max3(abs(theta), abs(dx), abs(dp));
+ gamma = s * std::sqrt(
+ std::max(FloatType(0), pow2(theta / s) - (dx / s) * (dp / s)));
+ if (stp > stx) gamma = -gamma;
+ p = (gamma - dp) + theta;
+ q = (gamma + (dx - dp)) + gamma;
+ r = p/q;
+ if (r < FloatType(0) && gamma != FloatType(0)) {
+ stpc = stp + r * (stx - stp);
+ }
+ else if (stp > stx) {
+ stpc = stpmax;
+ }
+ else {
+ stpc = stpmin;
+ }
+ stpq = stp + (dp / (dp - dx)) * (stx - stp);
+ if (brackt) {
+ if (abs(stp - stpc) < abs(stp - stpq)) {
+ stpf = stpc;
+ }
+ else {
+ stpf = stpq;
+ }
+ }
+ else {
+ if (abs(stp - stpc) > abs(stp - stpq)) {
+ stpf = stpc;
+ }
+ else {
+ stpf = stpq;
+ }
+ }
+ }
+ else {
+ // Fourth case. A lower function value, derivatives of the
+ // same sign, and the magnitude of the derivative does
+ // not decrease. If the minimum is not bracketed, the step
+ // is either stpmin or stpmax, else the cubic step is taken.
+ info = 4;
+ bound = false;
+ if (brackt) {
+ theta = FloatType(3) * (fp - fy) / (sty - stp) + dy + dp;
+ s = max3(abs(theta), abs(dy), abs(dp));
+ gamma = s * std::sqrt(pow2(theta / s) - (dy / s) * (dp / s));
+ if (stp > sty) gamma = -gamma;
+ p = (gamma - dp) + theta;
+ q = ((gamma - dp) + gamma) + dy;
+ r = p/q;
+ stpc = stp + r * (sty - stp);
+ stpf = stpc;
+ }
+ else if (stp > stx) {
+ stpf = stpmax;
+ }
+ else {
+ stpf = stpmin;
+ }
+ }
+ // Update the interval of uncertainty. This update does not
+ // depend on the new step or the case analysis above.
+ if (fp > fx) {
+ sty = stp;
+ fy = fp;
+ dy = dp;
+ }
+ else {
+ if (sgnd < FloatType(0)) {
+ sty = stx;
+ fy = fx;
+ dy = dx;
+ }
+ stx = stp;
+ fx = fp;
+ dx = dp;
+ }
+ // Compute the new step and safeguard it.
+ stpf = std::min(stpmax, stpf);
+ stpf = std::max(stpmin, stpf);
+ stp = stpf;
+ if (brackt && bound) {
+ if (sty > stx) {
+ stp = std::min(stx + FloatType(0.66) * (sty - stx), stp);
+ }
+ else {
+ stp = std::max(stx + FloatType(0.66) * (sty - stx), stp);
+ }
+ }
+ return info;
+ }
+
+ /* Compute the sum of a vector times a scalar plus another vector.
+ Adapted from the subroutine <code>daxpy</code> in
+ <code>lbfgs.f</code>.
+ */
+ template <typename FloatType, typename SizeType>
+ void daxpy(
+ SizeType n,
+ FloatType da,
+ const FloatType* dx,
+ SizeType ix0,
+ SizeType incx,
+ FloatType* dy,
+ SizeType iy0,
+ SizeType incy)
+ {
+ SizeType i, ix, iy, m;
+ if (n == 0) return;
+ if (da == FloatType(0)) return;
+ if (!(incx == 1 && incy == 1)) {
+ ix = 0;
+ iy = 0;
+ for (i = 0; i < n; i++) {
+ dy[iy0+iy] += da * dx[ix0+ix];
+ ix += incx;
+ iy += incy;
+ }
+ return;
+ }
+ m = n % 4;
+ for (i = 0; i < m; i++) {
+ dy[iy0+i] += da * dx[ix0+i];
+ }
+ for (; i < n;) {
+ dy[iy0+i] += da * dx[ix0+i]; i++;
+ dy[iy0+i] += da * dx[ix0+i]; i++;
+ dy[iy0+i] += da * dx[ix0+i]; i++;
+ dy[iy0+i] += da * dx[ix0+i]; i++;
+ }
+ }
+
+ template <typename FloatType, typename SizeType>
+ inline
+ void daxpy(
+ SizeType n,
+ FloatType da,
+ const FloatType* dx,
+ SizeType ix0,
+ FloatType* dy)
+ {
+ daxpy(n, da, dx, ix0, SizeType(1), dy, SizeType(0), SizeType(1));
+ }
+
+ /* Compute the dot product of two vectors.
+ Adapted from the subroutine <code>ddot</code>
+ in <code>lbfgs.f</code>.
+ */
+ template <typename FloatType, typename SizeType>
+ FloatType ddot(
+ SizeType n,
+ const FloatType* dx,
+ SizeType ix0,
+ SizeType incx,
+ const FloatType* dy,
+ SizeType iy0,
+ SizeType incy)
+ {
+ SizeType i, ix, iy, m;
+ FloatType dtemp(0);
+ if (n == 0) return FloatType(0);
+ if (!(incx == 1 && incy == 1)) {
+ ix = 0;
+ iy = 0;
+ for (i = 0; i < n; i++) {
+ dtemp += dx[ix0+ix] * dy[iy0+iy];
+ ix += incx;
+ iy += incy;
+ }
+ return dtemp;
+ }
+ m = n % 5;
+ for (i = 0; i < m; i++) {
+ dtemp += dx[ix0+i] * dy[iy0+i];
+ }
+ for (; i < n;) {
+ dtemp += dx[ix0+i] * dy[iy0+i]; i++;
+ dtemp += dx[ix0+i] * dy[iy0+i]; i++;
+ dtemp += dx[ix0+i] * dy[iy0+i]; i++;
+ dtemp += dx[ix0+i] * dy[iy0+i]; i++;
+ dtemp += dx[ix0+i] * dy[iy0+i]; i++;
+ }
+ return dtemp;
+ }
+
+ template <typename FloatType, typename SizeType>
+ inline
+ FloatType ddot(
+ SizeType n,
+ const FloatType* dx,
+ const FloatType* dy)
+ {
+ return ddot(
+ n, dx, SizeType(0), SizeType(1), dy, SizeType(0), SizeType(1));
+ }
+
+ } // namespace detail
+
+ //! Interface to the LBFGS %minimizer.
+ /*! This class solves the unconstrained minimization problem
+ <pre>
+ min f(x), x = (x1,x2,...,x_n),
+ </pre>
+ using the limited-memory BFGS method. The routine is
+ especially effective on problems involving a large number of
+ variables. In a typical iteration of this method an
+ approximation Hk to the inverse of the Hessian
+ is obtained by applying <code>m</code> BFGS updates to a
+ diagonal matrix Hk0, using information from the
+ previous <code>m</code> steps. The user specifies the number
+ <code>m</code>, which determines the amount of storage
+ required by the routine. The user may also provide the
+ diagonal matrices Hk0 (parameter <code>diag</code> in the run()
+ function) if not satisfied with the default choice. The
+ algorithm is described in "On the limited memory BFGS method for
+ large scale optimization", by D. Liu and J. Nocedal, Mathematical
+ Programming B 45 (1989) 503-528.
+
+ The user is required to calculate the function value
+ <code>f</code> and its gradient <code>g</code>. In order to
+ allow the user complete control over these computations,
+ reverse communication is used. The routine must be called
+ repeatedly under the control of the member functions
+ <code>requests_f_and_g()</code>,
+ <code>requests_diag()</code>.
+ If neither requests_f_and_g() nor requests_diag() is
+ <code>true</code> the user should check for convergence
+ (using class traditional_convergence_test or any
+ other custom test). If the convergence test is negative,
+ the minimizer may be called again for the next iteration.
+
+ The steplength (stp()) is determined at each iteration
+ by means of the line search routine <code>mcsrch</code>, which is
+ a slight modification of the routine <code>CSRCH</code> written
+ by More' and Thuente.
+
+ The only variables that are machine-dependent are
+ <code>xtol</code>,
+ <code>stpmin</code> and
+ <code>stpmax</code>.
+
+ Fatal errors cause <code>error</code> exceptions to be thrown.
+ The generic class <code>error</code> is sub-classed (e.g.
+ class <code>error_line_search_failed</code>) to facilitate
+ granular %error handling.
+
+ A note on performance: Using Compaq Fortran V5.4 and
+ Compaq C++ V6.5, the C++ implementation is about 15% slower
+ than the Fortran implementation.
+ */
+ template <typename FloatType, typename SizeType = std::size_t>
+ class minimizer
+ {
+ public:
+ //! Default constructor. Some members are not initialized!
+ minimizer()
+ : n_(0), m_(0), maxfev_(0),
+ gtol_(0), xtol_(0),
+ stpmin_(0), stpmax_(0),
+ ispt(0), iypt(0)
+ {}
+
+ //! Constructor.
+ /*! @param n The number of variables in the minimization problem.
+ Restriction: <code>n &gt; 0</code>.
+
+ @param m The number of corrections used in the BFGS update.
+ Values of <code>m</code> less than 3 are not recommended;
+ large values of <code>m</code> will result in excessive
+ computing time. <code>3 &lt;= m &lt;= 7</code> is
+ recommended.
+ Restriction: <code>m &gt; 0</code>.
+
+ @param maxfev Maximum number of function evaluations
+ <b>per line search</b>.
+ Termination occurs when the number of evaluations
+ of the objective function is at least <code>maxfev</code> by
+ the end of an iteration.
+
+ @param gtol Controls the accuracy of the line search.
+ If the function and gradient evaluations are inexpensive with
+ respect to the cost of the iteration (which is sometimes the
+ case when solving very large problems) it may be advantageous
+ to set <code>gtol</code> to a small value. A typical small
+ value is 0.1.
+ Restriction: <code>gtol</code> should be greater than 1e-4.
+
+ @param xtol An estimate of the machine precision (e.g. 10e-16
+ on a SUN station 3/60). The line search routine will
+ terminate if the relative width of the interval of
+ uncertainty is less than <code>xtol</code>.
+
+ @param stpmin Specifies the lower bound for the step
+ in the line search.
+ The default value is 1e-20. This value need not be modified
+ unless the exponent is too large for the machine being used,
+ or unless the problem is extremely badly scaled (in which
+ case the exponent should be increased).
+
+ @param stpmax specifies the upper bound for the step
+ in the line search.
+ The default value is 1e20. This value need not be modified
+ unless the exponent is too large for the machine being used,
+ or unless the problem is extremely badly scaled (in which
+ case the exponent should be increased).
+ */
+ explicit
+ minimizer(
+ SizeType n,
+ SizeType m = 5,
+ SizeType maxfev = 20,
+ FloatType gtol = FloatType(0.9),
+ FloatType xtol = FloatType(1.e-16),
+ FloatType stpmin = FloatType(1.e-20),
+ FloatType stpmax = FloatType(1.e20))
+ : n_(n), m_(m), maxfev_(maxfev),
+ gtol_(gtol), xtol_(xtol),
+ stpmin_(stpmin), stpmax_(stpmax),
+ iflag_(0), requests_f_and_g_(false), requests_diag_(false),
+ iter_(0), nfun_(0), stp_(0),
+ stp1(0), ftol(0.0001), ys(0), point(0), npt(0),
+ ispt(n+2*m), iypt((n+2*m)+n*m),
+ info(0), bound(0), nfev(0)
+ {
+ if (n_ == 0) {
+ throw error_improper_input_parameter("n = 0.");
+ }
+ if (m_ == 0) {
+ throw error_improper_input_parameter("m = 0.");
+ }
+ if (maxfev_ == 0) {
+ throw error_improper_input_parameter("maxfev = 0.");
+ }
+ if (gtol_ <= FloatType(1.e-4)) {
+ throw error_improper_input_parameter("gtol <= 1.e-4.");
+ }
+ if (xtol_ < FloatType(0)) {
+ throw error_improper_input_parameter("xtol < 0.");
+ }
+ if (stpmin_ < FloatType(0)) {
+ throw error_improper_input_parameter("stpmin < 0.");
+ }
+ if (stpmax_ < stpmin) {
+ throw error_improper_input_parameter("stpmax < stpmin");
+ }
+ w_.resize(n_*(2*m_+1)+2*m_);
+ scratch_array_.resize(n_);
+ }
+
+ //! Number of free parameters (as passed to the constructor).
+ SizeType n() const { return n_; }
+
+ //! Number of corrections kept (as passed to the constructor).
+ SizeType m() const { return m_; }
+
+ /*! \brief Maximum number of evaluations of the objective function
+ per line search (as passed to the constructor).
+ */
+ SizeType maxfev() const { return maxfev_; }
+
+ /*! \brief Control of the accuracy of the line search.
+ (as passed to the constructor).
+ */
+ FloatType gtol() const { return gtol_; }
+
+ //! Estimate of the machine precision (as passed to the constructor).
+ FloatType xtol() const { return xtol_; }
+
+ /*! \brief Lower bound for the step in the line search.
+ (as passed to the constructor).
+ */
+ FloatType stpmin() const { return stpmin_; }
+
+ /*! \brief Upper bound for the step in the line search.
+ (as passed to the constructor).
+ */
+ FloatType stpmax() const { return stpmax_; }
+
+ //! Status indicator for reverse communication.
+ /*! <code>true</code> if the run() function returns to request
+ evaluation of the objective function (<code>f</code>) and
+ gradients (<code>g</code>) for the current point
+ (<code>x</code>). To continue the minimization the
+ run() function is called again with the updated values for
+ <code>f</code> and <code>g</code>.
+ <p>
+ See also: requests_diag()
+ */
+ bool requests_f_and_g() const { return requests_f_and_g_; }
+
+ //! Status indicator for reverse communication.
+ /*! <code>true</code> if the run() function returns to request
+ evaluation of the diagonal matrix (<code>diag</code>)
+ for the current point (<code>x</code>).
+ To continue the minimization the run() function is called
+ again with the updated values for <code>diag</code>.
+ <p>
+ See also: requests_f_and_g()
+ */
+ bool requests_diag() const { return requests_diag_; }
+
+ //! Number of iterations so far.
+ /*! Note that one iteration may involve multiple evaluations
+ of the objective function.
+ <p>
+ See also: nfun()
+ */
+ SizeType iter() const { return iter_; }
+
+ //! Total number of evaluations of the objective function so far.
+ /*! The total number of function evaluations increases by the
+ number of evaluations required for the line search. The total
+ is only increased after a successful line search.
+ <p>
+ See also: iter()
+ */
+ SizeType nfun() const { return nfun_; }
+
+ //! Norm of gradient given gradient array of length n().
+ FloatType euclidean_norm(const FloatType* a) const {
+ return std::sqrt(detail::ddot(n_, a, a));
+ }
+
+ //! Current stepsize.
+ FloatType stp() const { return stp_; }
+
+ //! Execution of one step of the minimization.
+ /*! @param x On initial entry this must be set by the user to
+ the values of the initial estimate of the solution vector.
+
+ @param f Before initial entry or on re-entry under the
+ control of requests_f_and_g(), <code>f</code> must be set
+ by the user to contain the value of the objective function
+ at the current point <code>x</code>.
+
+ @param g Before initial entry or on re-entry under the
+ control of requests_f_and_g(), <code>g</code> must be set
+ by the user to contain the components of the gradient at
+ the current point <code>x</code>.
+
+ The return value is <code>true</code> if either
+ requests_f_and_g() or requests_diag() is <code>true</code>.
+ Otherwise the user should check for convergence
+ (e.g. using class traditional_convergence_test) and
+ call the run() function again to continue the minimization.
+ If the return value is <code>false</code> the user
+ should <b>not</b> update <code>f</code>, <code>g</code> or
+ <code>diag</code> (other overload) before calling
+ the run() function again.
+
+ Note that <code>x</code> is always modified by the run()
+ function. Depending on the situation it can therefore be
+ necessary to evaluate the objective function one more time
+ after the minimization is terminated.
+ */
+ bool run(
+ FloatType* x,
+ FloatType f,
+ const FloatType* g)
+ {
+ return generic_run(x, f, g, false, 0);
+ }
+
+ //! Execution of one step of the minimization.
+ /*! @param x See other overload.
+
+ @param f See other overload.
+
+ @param g See other overload.
+
+ @param diag On initial entry or on re-entry under the
+ control of requests_diag(), <code>diag</code> must be set by
+ the user to contain the values of the diagonal matrix Hk0.
+ The routine will return at each iteration of the algorithm
+ with requests_diag() set to <code>true</code>.
+ <p>
+ Restriction: all elements of <code>diag</code> must be
+ positive.
+ */
+ bool run(
+ FloatType* x,
+ FloatType f,
+ const FloatType* g,
+ const FloatType* diag)
+ {
+ return generic_run(x, f, g, true, diag);
+ }
+
+ void serialize(std::ostream* out) const {
+ out->write((const char*)&n_, sizeof(n_)); // sanity check
+ out->write((const char*)&m_, sizeof(m_)); // sanity check
+ SizeType fs = sizeof(FloatType);
+ out->write((const char*)&fs, sizeof(fs)); // sanity check
+
+ mcsrch_instance.serialize(out);
+ out->write((const char*)&iflag_, sizeof(iflag_));
+ out->write((const char*)&requests_f_and_g_, sizeof(requests_f_and_g_));
+ out->write((const char*)&requests_diag_, sizeof(requests_diag_));
+ out->write((const char*)&iter_, sizeof(iter_));
+ out->write((const char*)&nfun_, sizeof(nfun_));
+ out->write((const char*)&stp_, sizeof(stp_));
+ out->write((const char*)&stp1, sizeof(stp1));
+ out->write((const char*)&ftol, sizeof(ftol));
+ out->write((const char*)&ys, sizeof(ys));
+ out->write((const char*)&point, sizeof(point));
+ out->write((const char*)&npt, sizeof(npt));
+ out->write((const char*)&info, sizeof(info));
+ out->write((const char*)&bound, sizeof(bound));
+ out->write((const char*)&nfev, sizeof(nfev));
+ out->write((const char*)&w_[0], sizeof(FloatType) * w_.size());
+ out->write((const char*)&scratch_array_[0], sizeof(FloatType) * scratch_array_.size());
+ }
+
+ void deserialize(std::istream* in) {
+ SizeType n, m, fs;
+ in->read((char*)&n, sizeof(n));
+ in->read((char*)&m, sizeof(m));
+ in->read((char*)&fs, sizeof(fs));
+ assert(n == n_);
+ assert(m == m_);
+ assert(fs == sizeof(FloatType));
+
+ mcsrch_instance.deserialize(in);
+ in->read((char*)&iflag_, sizeof(iflag_));
+ in->read((char*)&requests_f_and_g_, sizeof(requests_f_and_g_));
+ in->read((char*)&requests_diag_, sizeof(requests_diag_));
+ in->read((char*)&iter_, sizeof(iter_));
+ in->read((char*)&nfun_, sizeof(nfun_));
+ in->read((char*)&stp_, sizeof(stp_));
+ in->read((char*)&stp1, sizeof(stp1));
+ in->read((char*)&ftol, sizeof(ftol));
+ in->read((char*)&ys, sizeof(ys));
+ in->read((char*)&point, sizeof(point));
+ in->read((char*)&npt, sizeof(npt));
+ in->read((char*)&info, sizeof(info));
+ in->read((char*)&bound, sizeof(bound));
+ in->read((char*)&nfev, sizeof(nfev));
+ in->read((char*)&w_[0], sizeof(FloatType) * w_.size());
+ in->read((char*)&scratch_array_[0], sizeof(FloatType) * scratch_array_.size());
+ }
+
+ protected:
+ static void throw_diagonal_element_not_positive(SizeType i) {
+ throw error_improper_input_data(
+ "The " + error::itoa(i) + ". diagonal element of the"
+ " inverse Hessian approximation is not positive.");
+ }
+
+ bool generic_run(
+ FloatType* x,
+ FloatType f,
+ const FloatType* g,
+ bool diagco,
+ const FloatType* diag);
+
+ detail::mcsrch<FloatType, SizeType> mcsrch_instance;
+ const SizeType n_;
+ const SizeType m_;
+ const SizeType maxfev_;
+ const FloatType gtol_;
+ const FloatType xtol_;
+ const FloatType stpmin_;
+ const FloatType stpmax_;
+ int iflag_;
+ bool requests_f_and_g_;
+ bool requests_diag_;
+ SizeType iter_;
+ SizeType nfun_;
+ FloatType stp_;
+ FloatType stp1;
+ FloatType ftol;
+ FloatType ys;
+ SizeType point;
+ SizeType npt;
+ const SizeType ispt;
+ const SizeType iypt;
+ int info;
+ SizeType bound;
+ SizeType nfev;
+ std::vector<FloatType> w_;
+ std::vector<FloatType> scratch_array_;
+ };
+
+ template <typename FloatType, typename SizeType>
+ bool minimizer<FloatType, SizeType>::generic_run(
+ FloatType* x,
+ FloatType f,
+ const FloatType* g,
+ bool diagco,
+ const FloatType* diag)
+ {
+ bool execute_entire_while_loop = false;
+ if (!(requests_f_and_g_ || requests_diag_)) {
+ execute_entire_while_loop = true;
+ }
+ requests_f_and_g_ = false;
+ requests_diag_ = false;
+ FloatType* w = &(*(w_.begin()));
+ if (iflag_ == 0) { // Initialize.
+ nfun_ = 1;
+ if (diagco) {
+ for (SizeType i = 0; i < n_; i++) {
+ if (diag[i] <= FloatType(0)) {
+ throw_diagonal_element_not_positive(i);
+ }
+ }
+ }
+ else {
+ std::fill_n(scratch_array_.begin(), n_, FloatType(1));
+ diag = &(*(scratch_array_.begin()));
+ }
+ for (SizeType i = 0; i < n_; i++) {
+ w[ispt + i] = -g[i] * diag[i];
+ }
+ FloatType gnorm = std::sqrt(detail::ddot(n_, g, g));
+ if (gnorm == FloatType(0)) return false;
+ stp1 = FloatType(1) / gnorm;
+ execute_entire_while_loop = true;
+ }
+ if (execute_entire_while_loop) {
+ bound = iter_;
+ iter_++;
+ info = 0;
+ if (iter_ != 1) {
+ if (iter_ > m_) bound = m_;
+ ys = detail::ddot(
+ n_, w, iypt + npt, SizeType(1), w, ispt + npt, SizeType(1));
+ if (!diagco) {
+ FloatType yy = detail::ddot(
+ n_, w, iypt + npt, SizeType(1), w, iypt + npt, SizeType(1));
+ std::fill_n(scratch_array_.begin(), n_, ys / yy);
+ diag = &(*(scratch_array_.begin()));
+ }
+ else {
+ iflag_ = 2;
+ requests_diag_ = true;
+ return true;
+ }
+ }
+ }
+ if (execute_entire_while_loop || iflag_ == 2) {
+ if (iter_ != 1) {
+ if (diag == 0) {
+ throw error_internal_error(__FILE__, __LINE__);
+ }
+ if (diagco) {
+ for (SizeType i = 0; i < n_; i++) {
+ if (diag[i] <= FloatType(0)) {
+ throw_diagonal_element_not_positive(i);
+ }
+ }
+ }
+ SizeType cp = point;
+ if (point == 0) cp = m_;
+ w[n_ + cp -1] = 1 / ys;
+ SizeType i;
+ for (i = 0; i < n_; i++) {
+ w[i] = -g[i];
+ }
+ cp = point;
+ for (i = 0; i < bound; i++) {
+ if (cp == 0) cp = m_;
+ cp--;
+ FloatType sq = detail::ddot(
+ n_, w, ispt + cp * n_, SizeType(1), w, SizeType(0), SizeType(1));
+ SizeType inmc=n_+m_+cp;
+ SizeType iycn=iypt+cp*n_;
+ w[inmc] = w[n_ + cp] * sq;
+ detail::daxpy(n_, -w[inmc], w, iycn, w);
+ }
+ for (i = 0; i < n_; i++) {
+ w[i] *= diag[i];
+ }
+ for (i = 0; i < bound; i++) {
+ FloatType yr = detail::ddot(
+ n_, w, iypt + cp * n_, SizeType(1), w, SizeType(0), SizeType(1));
+ FloatType beta = w[n_ + cp] * yr;
+ SizeType inmc=n_+m_+cp;
+ beta = w[inmc] - beta;
+ SizeType iscn=ispt+cp*n_;
+ detail::daxpy(n_, beta, w, iscn, w);
+ cp++;
+ if (cp == m_) cp = 0;
+ }
+ std::copy(w, w+n_, w+(ispt + point * n_));
+ }
+ stp_ = FloatType(1);
+ if (iter_ == 1) stp_ = stp1;
+ std::copy(g, g+n_, w);
+ }
+ mcsrch_instance.run(
+ gtol_, stpmin_, stpmax_, n_, x, f, g, w, ispt + point * n_,
+ stp_, ftol, xtol_, maxfev_, info, nfev, &(*(scratch_array_.begin())));
+ if (info == -1) {
+ iflag_ = 1;
+ requests_f_and_g_ = true;
+ return true;
+ }
+ if (info != 1) {
+ throw error_internal_error(__FILE__, __LINE__);
+ }
+ nfun_ += nfev;
+ npt = point*n_;
+ for (SizeType i = 0; i < n_; i++) {
+ w[ispt + npt + i] = stp_ * w[ispt + npt + i];
+ w[iypt + npt + i] = g[i] - w[i];
+ }
+ point++;
+ if (point == m_) point = 0;
+ return false;
+ }
+
+ //! Traditional LBFGS convergence test.
+ /*! This convergence test is equivalent to the test embedded
+ in the <code>lbfgs.f</code> Fortran code. The test assumes that
+ there is a meaningful relation between the Euclidean norm of the
+ parameter vector <code>x</code> and the norm of the gradient
+ vector <code>g</code>. Therefore this test should not be used if
+ this assumption is not correct for a given problem.
+ */
+ template <typename FloatType, typename SizeType = std::size_t>
+ class traditional_convergence_test
+ {
+ public:
+ //! Default constructor.
+ traditional_convergence_test()
+ : n_(0), eps_(0)
+ {}
+
+ //! Constructor.
+ /*! @param n The number of variables in the minimization problem.
+ Restriction: <code>n &gt; 0</code>.
+
+ @param eps Determines the accuracy with which the solution
+ is to be found.
+ */
+ explicit
+ traditional_convergence_test(
+ SizeType n,
+ FloatType eps = FloatType(1.e-5))
+ : n_(n), eps_(eps)
+ {
+ if (n_ == 0) {
+ throw error_improper_input_parameter("n = 0.");
+ }
+ if (eps_ < FloatType(0)) {
+ throw error_improper_input_parameter("eps < 0.");
+ }
+ }
+
+ //! Number of free parameters (as passed to the constructor).
+ SizeType n() const { return n_; }
+
+ /*! \brief Accuracy with which the solution is to be found
+ (as passed to the constructor).
+ */
+ FloatType eps() const { return eps_; }
+
+ //! Execution of the convergence test for the given parameters.
+ /*! Returns <code>true</code> if
+ <pre>
+ ||g|| &lt; eps * max(1,||x||),
+ </pre>
+ where <code>||.||</code> denotes the Euclidean norm.
+
+ @param x Current solution vector.
+
+ @param g Components of the gradient at the current
+ point <code>x</code>.
+ */
+ bool
+ operator()(const FloatType* x, const FloatType* g) const
+ {
+ FloatType xnorm = std::sqrt(detail::ddot(n_, x, x));
+ FloatType gnorm = std::sqrt(detail::ddot(n_, g, g));
+ if (gnorm <= eps_ * std::max(FloatType(1), xnorm)) return true;
+ return false;
+ }
+ protected:
+ const SizeType n_;
+ const FloatType eps_;
+ };
+
+}} // namespace scitbx::lbfgs
+
+template <typename T>
+std::ostream& operator<<(std::ostream& os, const scitbx::lbfgs::minimizer<T>& min) {
+ return os << "ITER=" << min.iter() << "\tNFUN=" << min.nfun() << "\tSTP=" << min.stp() << "\tDIAG=" << min.requests_diag() << "\tF&G=" << min.requests_f_and_g();
+}
+
+
+#endif // SCITBX_LBFGS_H
diff --git a/training/utils/lbfgs_test.cc b/training/utils/lbfgs_test.cc
new file mode 100644
index 00000000..9678e788
--- /dev/null
+++ b/training/utils/lbfgs_test.cc
@@ -0,0 +1,117 @@
+#include <cassert>
+#include <iostream>
+#include <sstream>
+#include <cmath>
+#include "lbfgs.h"
+#include "sparse_vector.h"
+#include "fdict.h"
+
+using namespace std;
+
+double TestOptimizer() {
+ cerr << "TESTING NON-PERSISTENT OPTIMIZER\n";
+
+ // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5
+ // df/dx1 = 8*x1 + x2
+ // df/dx2 = 2*x2 + x1
+ // df/dx3 = 2*x3 + 6
+ double x[3];
+ double g[3];
+ scitbx::lbfgs::minimizer<double> opt(3);
+ scitbx::lbfgs::traditional_convergence_test<double> converged(3);
+ x[0] = 8;
+ x[1] = 8;
+ x[2] = 8;
+ double obj = 0;
+ do {
+ g[0] = 8 * x[0] + x[1];
+ g[1] = 2 * x[1] + x[0];
+ g[2] = 2 * x[2] + 6;
+ obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5;
+ opt.run(x, obj, g);
+ if (!opt.requests_f_and_g()) {
+ if (converged(x,g)) break;
+ opt.run(x, obj, g);
+ }
+ cerr << x[0] << " " << x[1] << " " << x[2] << endl;
+ cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl;
+ cerr << opt << endl;
+ } while (true);
+ return obj;
+}
+
+double TestPersistentOptimizer() {
+ cerr << "\nTESTING PERSISTENT OPTIMIZER\n";
+ // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5
+ // df/dx1 = 8*x1 + x2
+ // df/dx2 = 2*x2 + x1
+ // df/dx3 = 2*x3 + 6
+ double x[3];
+ double g[3];
+ scitbx::lbfgs::traditional_convergence_test<double> converged(3);
+ x[0] = 8;
+ x[1] = 8;
+ x[2] = 8;
+ double obj = 0;
+ string state;
+ do {
+ g[0] = 8 * x[0] + x[1];
+ g[1] = 2 * x[1] + x[0];
+ g[2] = 2 * x[2] + 6;
+ obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5;
+
+ {
+ scitbx::lbfgs::minimizer<double> opt(3);
+ if (state.size() > 0) {
+ istringstream is(state, ios::binary);
+ opt.deserialize(&is);
+ }
+ opt.run(x, obj, g);
+ ostringstream os(ios::binary); opt.serialize(&os); state = os.str();
+ }
+
+ cerr << x[0] << " " << x[1] << " " << x[2] << endl;
+ cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl;
+ } while (!converged(x, g));
+ return obj;
+}
+
+void TestSparseVector() {
+ cerr << "Testing SparseVector<double> serialization.\n";
+ int f1 = FD::Convert("Feature_1");
+ int f2 = FD::Convert("Feature_2");
+ FD::Convert("LanguageModel");
+ int f4 = FD::Convert("SomeFeature");
+ int f5 = FD::Convert("SomeOtherFeature");
+ SparseVector<double> g;
+ g.set_value(f2, log(0.5));
+ g.set_value(f4, log(0.125));
+ g.set_value(f1, 0);
+ g.set_value(f5, 23.777);
+ ostringstream os;
+ double iobj = 1.5;
+ B64::Encode(iobj, g, &os);
+ cerr << iobj << "\t" << g << endl;
+ string data = os.str();
+ cout << data << endl;
+ SparseVector<double> v;
+ double obj;
+ bool decode_b64 = B64::Decode(&obj, &v, &data[0], data.size());
+ cerr << obj << "\t" << v << endl;
+ assert(decode_b64);
+ assert(obj == iobj);
+ assert(g.size() == v.size());
+}
+
+int main() {
+ double o1 = TestOptimizer();
+ double o2 = TestPersistentOptimizer();
+ if (fabs(o1 - o2) > 1e-5) {
+ cerr << "OPTIMIZERS PERFORMED DIFFERENTLY!\n" << o1 << " vs. " << o2 << endl;
+ return 1;
+ }
+ TestSparseVector();
+ cerr << "SUCCESS\n";
+ return 0;
+}
+
diff --git a/training/utils/libcall.pl b/training/utils/libcall.pl
new file mode 100644
index 00000000..c7d0f128
--- /dev/null
+++ b/training/utils/libcall.pl
@@ -0,0 +1,71 @@
+use IPC::Open3;
+use Symbol qw(gensym);
+
+$DUMMY_STDERR = gensym();
+$DUMMY_STDIN = gensym();
+
+# Run the command and ignore failures
+sub unchecked_call {
+ system("@_")
+}
+
+# Run the command and return its output, if any ignoring failures
+sub unchecked_output {
+ return `@_`
+}
+
+# WARNING: Do not use this for commands that will return large amounts
+# of stdout or stderr -- they might block indefinitely
+sub check_output {
+ print STDERR "Executing and gathering output: @_\n";
+
+ my $pid = open3($DUMMY_STDIN, \*PH, $DUMMY_STDERR, @_);
+ my $proc_output = "";
+ while( <PH> ) {
+ $proc_output .= $_;
+ }
+ waitpid($pid, 0);
+ # TODO: Grab signal that the process died from
+ my $child_exit_status = $? >> 8;
+ if($child_exit_status == 0) {
+ return $proc_output;
+ } else {
+ print STDERR "ERROR: Execution of @_ failed.\n";
+ exit(1);
+ }
+}
+
+# Based on Moses' safesystem sub
+sub check_call {
+ print STDERR "Executing: @_\n";
+ system(@_);
+ my $exitcode = $? >> 8;
+ if($exitcode == 0) {
+ return 0;
+ } elsif ($? == -1) {
+ print STDERR "ERROR: Failed to execute: @_\n $!\n";
+ exit(1);
+
+ } elsif ($? & 127) {
+ printf STDERR "ERROR: Execution of: @_\n died with signal %d, %s coredump\n",
+ ($? & 127), ($? & 128) ? 'with' : 'without';
+ exit(1);
+
+ } else {
+ print STDERR "Failed with exit code: $exitcode\n" if $exitcode;
+ exit($exitcode);
+ }
+}
+
+sub check_bash_call {
+ my @args = ( "bash", "-auxeo", "pipefail", "-c", "@_");
+ check_call(@args);
+}
+
+sub check_bash_output {
+ my @args = ( "bash", "-auxeo", "pipefail", "-c", "@_");
+ return check_output(@args);
+}
+
+# perl module weirdness...
+return 1;
diff --git a/training/utils/online_optimizer.cc b/training/utils/online_optimizer.cc
new file mode 100644
index 00000000..3ed95452
--- /dev/null
+++ b/training/utils/online_optimizer.cc
@@ -0,0 +1,16 @@
+#include "online_optimizer.h"
+
+LearningRateSchedule::~LearningRateSchedule() {}
+
+double StandardLearningRate::eta(int k) const {
+ return eta_0_ / (1.0 + k / N_);
+}
+
+double ExponentialDecayLearningRate::eta(int k) const {
+ return eta_0_ * pow(alpha_, k / N_);
+}
+
+OnlineOptimizer::~OnlineOptimizer() {}
+
+void OnlineOptimizer::ResetEpochImpl() {}
+
diff --git a/training/utils/online_optimizer.h b/training/utils/online_optimizer.h
new file mode 100644
index 00000000..28d89344
--- /dev/null
+++ b/training/utils/online_optimizer.h
@@ -0,0 +1,129 @@
+#ifndef _ONL_OPTIMIZE_H_
+#define _ONL_OPTIMIZE_H_
+
+#include <tr1/memory>
+#include <set>
+#include <string>
+#include <cmath>
+#include "sparse_vector.h"
+
+struct LearningRateSchedule {
+ virtual ~LearningRateSchedule();
+ // returns the learning rate for the kth iteration
+ virtual double eta(int k) const = 0;
+};
+
+// TODO in the Tsoruoaka et al. (ACL 2009) paper, they use N
+// to mean the batch size in most places, but it doesn't completely
+// make sense to me in the learning rate schedules-- this needs
+// to be worked out to make sure they didn't mean corpus size
+// in some places and batch size in others (since in the paper they
+// only ever work with batch sizes of 1)
+struct StandardLearningRate : public LearningRateSchedule {
+ StandardLearningRate(
+ size_t batch_size, // batch size, not corpus size!
+ double eta_0 = 0.2) :
+ eta_0_(eta_0),
+ N_(static_cast<double>(batch_size)) {}
+
+ virtual double eta(int k) const;
+
+ private:
+ const double eta_0_;
+ const double N_;
+};
+
+struct ExponentialDecayLearningRate : public LearningRateSchedule {
+ ExponentialDecayLearningRate(
+ size_t batch_size, // batch size, not corpus size!
+ double eta_0 = 0.2,
+ double alpha = 0.85 // recommended by Tsuruoka et al. (ACL 2009)
+ ) : eta_0_(eta_0),
+ N_(static_cast<double>(batch_size)),
+ alpha_(alpha) {
+ assert(alpha > 0);
+ assert(alpha < 1.0);
+ }
+
+ virtual double eta(int k) const;
+
+ private:
+ const double eta_0_;
+ const double N_;
+ const double alpha_;
+};
+
+class OnlineOptimizer {
+ public:
+ virtual ~OnlineOptimizer();
+ OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s,
+ size_t batch_size,
+ const std::vector<int>& frozen_feats = std::vector<int>())
+ : N_(batch_size),schedule_(s),k_() {
+ for (int i = 0; i < frozen_feats.size(); ++i)
+ frozen_.insert(frozen_feats[i]);
+ }
+ void ResetEpoch() { k_ = 0; ResetEpochImpl(); }
+ void UpdateWeights(const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) {
+ ++k_;
+ const double eta = schedule_->eta(k_);
+ UpdateWeightsImpl(eta, approx_g, max_feat, weights);
+ }
+
+ protected:
+ virtual void ResetEpochImpl();
+ virtual void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) = 0;
+ const size_t N_; // number of training instances per batch
+ std::set<int> frozen_; // frozen (non-optimizing) features
+
+ private:
+ std::tr1::shared_ptr<LearningRateSchedule> schedule_;
+ int k_; // iteration count
+};
+
+class CumulativeL1OnlineOptimizer : public OnlineOptimizer {
+ public:
+ CumulativeL1OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s,
+ size_t training_instances, double C,
+ const std::vector<int>& frozen) :
+ OnlineOptimizer(s, training_instances, frozen), C_(C), u_() {}
+
+ protected:
+ void ResetEpochImpl() { u_ = 0; }
+ void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) {
+ u_ += eta * C_ / N_;
+ for (SparseVector<double>::const_iterator it = approx_g.begin();
+ it != approx_g.end(); ++it) {
+ if (frozen_.count(it->first) == 0)
+ weights->add_value(it->first, eta * it->second);
+ }
+ for (int i = 1; i < max_feat; ++i)
+ if (frozen_.count(i) == 0) ApplyPenalty(i, weights);
+ }
+
+ private:
+ void ApplyPenalty(int i, SparseVector<double>* w) {
+ const double z = w->value(i);
+ double w_i = z;
+ double q_i = q_.value(i);
+ if (w_i > 0.0)
+ w_i = std::max(0.0, w_i - (u_ + q_i));
+ else if (w_i < 0.0)
+ w_i = std::min(0.0, w_i + (u_ - q_i));
+ q_i += w_i - z;
+ if (q_i == 0.0)
+ q_.erase(i);
+ else
+ q_.set_value(i, q_i);
+ if (w_i == 0.0)
+ w->erase(i);
+ else
+ w->set_value(i, w_i);
+ }
+
+ const double C_; // reguarlization strength
+ double u_;
+ SparseVector<double> q_;
+};
+
+#endif
diff --git a/training/utils/optimize.cc b/training/utils/optimize.cc
new file mode 100644
index 00000000..41ac90d8
--- /dev/null
+++ b/training/utils/optimize.cc
@@ -0,0 +1,102 @@
+#include "optimize.h"
+
+#include <iostream>
+#include <cassert>
+
+#include "lbfgs.h"
+
+using namespace std;
+
+BatchOptimizer::~BatchOptimizer() {}
+
+void BatchOptimizer::Save(ostream* out) const {
+ out->write((const char*)&eval_, sizeof(eval_));
+ out->write((const char*)&has_converged_, sizeof(has_converged_));
+ SaveImpl(out);
+ unsigned int magic = 0xABCDDCBA; // should be uint32_t
+ out->write((const char*)&magic, sizeof(magic));
+}
+
+void BatchOptimizer::Load(istream* in) {
+ in->read((char*)&eval_, sizeof(eval_));
+ in->read((char*)&has_converged_, sizeof(has_converged_));
+ LoadImpl(in);
+ unsigned int magic = 0; // should be uint32_t
+ in->read((char*)&magic, sizeof(magic));
+ assert(magic == 0xABCDDCBA);
+ cerr << Name() << " EVALUATION #" << eval_ << endl;
+}
+
+void BatchOptimizer::SaveImpl(ostream* out) const {
+ (void)out;
+}
+
+void BatchOptimizer::LoadImpl(istream* in) {
+ (void)in;
+}
+
+string RPropOptimizer::Name() const {
+ return "RPropOptimizer";
+}
+
+void RPropOptimizer::OptimizeImpl(const double& obj,
+ const vector<double>& g,
+ vector<double>* x) {
+ for (int i = 0; i < g.size(); ++i) {
+ const double g_i = g[i];
+ const double sign_i = (signbit(g_i) ? -1.0 : 1.0);
+ const double prod = g_i * prev_g_[i];
+ if (prod > 0.0) {
+ const double dij = min(delta_ij_[i] * eta_plus_, delta_max_);
+ (*x)[i] -= dij * sign_i;
+ delta_ij_[i] = dij;
+ prev_g_[i] = g_i;
+ } else if (prod < 0.0) {
+ delta_ij_[i] = max(delta_ij_[i] * eta_minus_, delta_min_);
+ prev_g_[i] = 0.0;
+ } else {
+ (*x)[i] -= delta_ij_[i] * sign_i;
+ prev_g_[i] = g_i;
+ }
+ }
+}
+
+void RPropOptimizer::SaveImpl(ostream* out) const {
+ const size_t n = prev_g_.size();
+ out->write((const char*)&n, sizeof(n));
+ out->write((const char*)&prev_g_[0], sizeof(double) * n);
+ out->write((const char*)&delta_ij_[0], sizeof(double) * n);
+}
+
+void RPropOptimizer::LoadImpl(istream* in) {
+ size_t n;
+ in->read((char*)&n, sizeof(n));
+ assert(n == prev_g_.size());
+ assert(n == delta_ij_.size());
+ in->read((char*)&prev_g_[0], sizeof(double) * n);
+ in->read((char*)&delta_ij_[0], sizeof(double) * n);
+}
+
+string LBFGSOptimizer::Name() const {
+ return "LBFGSOptimizer";
+}
+
+LBFGSOptimizer::LBFGSOptimizer(int num_feats, int memory_buffers) :
+ opt_(num_feats, memory_buffers) {}
+
+void LBFGSOptimizer::SaveImpl(ostream* out) const {
+ opt_.serialize(out);
+}
+
+void LBFGSOptimizer::LoadImpl(istream* in) {
+ opt_.deserialize(in);
+}
+
+void LBFGSOptimizer::OptimizeImpl(const double& obj,
+ const vector<double>& g,
+ vector<double>* x) {
+ opt_.run(&(*x)[0], obj, &g[0]);
+ if (!opt_.requests_f_and_g()) opt_.run(&(*x)[0], obj, &g[0]);
+ // cerr << opt_ << endl;
+}
+
diff --git a/training/utils/optimize.h b/training/utils/optimize.h
new file mode 100644
index 00000000..07943b44
--- /dev/null
+++ b/training/utils/optimize.h
@@ -0,0 +1,92 @@
+#ifndef _OPTIMIZE_H_
+#define _OPTIMIZE_H_
+
+#include <iostream>
+#include <vector>
+#include <string>
+#include <cassert>
+
+#include "lbfgs.h"
+
+// abstract base class for first order optimizers
+// order of invocation: new, Load(), Optimize(), Save(), delete
+class BatchOptimizer {
+ public:
+ BatchOptimizer() : eval_(1), has_converged_(false) {}
+ virtual ~BatchOptimizer();
+ virtual std::string Name() const = 0;
+ int EvaluationCount() const { return eval_; }
+ bool HasConverged() const { return has_converged_; }
+
+ void Optimize(const double& obj,
+ const std::vector<double>& g,
+ std::vector<double>* x) {
+ assert(g.size() == x->size());
+ ++eval_;
+ OptimizeImpl(obj, g, x);
+ scitbx::lbfgs::traditional_convergence_test<double> converged(g.size());
+ has_converged_ = converged(&(*x)[0], &g[0]);
+ }
+
+ void Save(std::ostream* out) const;
+ void Load(std::istream* in);
+ protected:
+ virtual void SaveImpl(std::ostream* out) const;
+ virtual void LoadImpl(std::istream* in);
+ virtual void OptimizeImpl(const double& obj,
+ const std::vector<double>& g,
+ std::vector<double>* x) = 0;
+
+ int eval_;
+ private:
+ bool has_converged_;
+};
+
+class RPropOptimizer : public BatchOptimizer {
+ public:
+ explicit RPropOptimizer(int num_vars,
+ double eta_plus = 1.2,
+ double eta_minus = 0.5,
+ double delta_0 = 0.1,
+ double delta_max = 50.0,
+ double delta_min = 1e-6) :
+ prev_g_(num_vars, 0.0),
+ delta_ij_(num_vars, delta_0),
+ eta_plus_(eta_plus),
+ eta_minus_(eta_minus),
+ delta_max_(delta_max),
+ delta_min_(delta_min) {
+ assert(eta_plus > 1.0);
+ assert(eta_minus > 0.0 && eta_minus < 1.0);
+ assert(delta_max > 0.0);
+ assert(delta_min > 0.0);
+ }
+ std::string Name() const;
+ void OptimizeImpl(const double& obj,
+ const std::vector<double>& g,
+ std::vector<double>* x);
+ void SaveImpl(std::ostream* out) const;
+ void LoadImpl(std::istream* in);
+ private:
+ std::vector<double> prev_g_;
+ std::vector<double> delta_ij_;
+ const double eta_plus_;
+ const double eta_minus_;
+ const double delta_max_;
+ const double delta_min_;
+};
+
+class LBFGSOptimizer : public BatchOptimizer {
+ public:
+ explicit LBFGSOptimizer(int num_vars, int memory_buffers = 10);
+ std::string Name() const;
+ void SaveImpl(std::ostream* out) const;
+ void LoadImpl(std::istream* in);
+ void OptimizeImpl(const double& obj,
+ const std::vector<double>& g,
+ std::vector<double>* x);
+ private:
+ scitbx::lbfgs::minimizer<double> opt_;
+};
+
+#endif
diff --git a/training/utils/optimize_test.cc b/training/utils/optimize_test.cc
new file mode 100644
index 00000000..bff2ca03
--- /dev/null
+++ b/training/utils/optimize_test.cc
@@ -0,0 +1,118 @@
+#include <cassert>
+#include <iostream>
+#include <sstream>
+#include <boost/program_options/variables_map.hpp>
+#include "optimize.h"
+#include "online_optimizer.h"
+#include "sparse_vector.h"
+#include "fdict.h"
+
+using namespace std;
+
+double TestOptimizer(BatchOptimizer* opt) {
+ cerr << "TESTING NON-PERSISTENT OPTIMIZER\n";
+
+ // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5
+ // df/dx1 = 8*x1 + x2
+ // df/dx2 = 2*x2 + x1
+ // df/dx3 = 2*x3 + 6
+ vector<double> x(3);
+ vector<double> g(3);
+ x[0] = 8;
+ x[1] = 8;
+ x[2] = 8;
+ double obj = 0;
+ do {
+ g[0] = 8 * x[0] + x[1];
+ g[1] = 2 * x[1] + x[0];
+ g[2] = 2 * x[2] + 6;
+ obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5;
+ opt->Optimize(obj, g, &x);
+
+ cerr << x[0] << " " << x[1] << " " << x[2] << endl;
+ cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl;
+ } while (!opt->HasConverged());
+ return obj;
+}
+
+double TestPersistentOptimizer(BatchOptimizer* opt) {
+ cerr << "\nTESTING PERSISTENT OPTIMIZER\n";
+ // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5
+ // df/dx1 = 8*x1 + x2
+ // df/dx2 = 2*x2 + x1
+ // df/dx3 = 2*x3 + 6
+ vector<double> x(3);
+ vector<double> g(3);
+ x[0] = 8;
+ x[1] = 8;
+ x[2] = 8;
+ double obj = 0;
+ string state;
+ bool converged = false;
+ while (!converged) {
+ g[0] = 8 * x[0] + x[1];
+ g[1] = 2 * x[1] + x[0];
+ g[2] = 2 * x[2] + 6;
+ obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5;
+
+ {
+ if (state.size() > 0) {
+ istringstream is(state, ios::binary);
+ opt->Load(&is);
+ }
+ opt->Optimize(obj, g, &x);
+ ostringstream os(ios::binary); opt->Save(&os); state = os.str();
+
+ }
+
+ cerr << x[0] << " " << x[1] << " " << x[2] << endl;
+ cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl;
+ converged = opt->HasConverged();
+ if (!converged) {
+ // now screw up the state (should be undone by Load)
+ obj += 2.0;
+ g[1] = -g[2];
+ vector<double> x2 = x;
+ try {
+ opt->Optimize(obj, g, &x2);
+ } catch (...) { }
+ }
+ }
+ return obj;
+}
+
+template <class O>
+void TestOptimizerVariants(int num_vars) {
+ O oa(num_vars);
+ cerr << "-------------------------------------------------------------------------\n";
+ cerr << "TESTING: " << oa.Name() << endl;
+ double o1 = TestOptimizer(&oa);
+ O ob(num_vars);
+ double o2 = TestPersistentOptimizer(&ob);
+ if (o1 != o2) {
+ cerr << oa.Name() << " VARIANTS PERFORMED DIFFERENTLY!\n" << o1 << " vs. " << o2 << endl;
+ exit(1);
+ }
+ cerr << oa.Name() << " SUCCESS\n";
+}
+
+using namespace std::tr1;
+
+void TestOnline() {
+ size_t N = 20;
+ double C = 1.0;
+ double eta0 = 0.2;
+ std::tr1::shared_ptr<LearningRateSchedule> r(new ExponentialDecayLearningRate(N, eta0, 0.85));
+ //shared_ptr<LearningRateSchedule> r(new StandardLearningRate(N, eta0));
+ CumulativeL1OnlineOptimizer opt(r, N, C, std::vector<int>());
+ assert(r->eta(10) < r->eta(1));
+}
+
+int main() {
+ int n = 3;
+ TestOptimizerVariants<LBFGSOptimizer>(n);
+ TestOptimizerVariants<RPropOptimizer>(n);
+ TestOnline();
+ return 0;
+}
+
diff --git a/training/utils/parallelize.pl b/training/utils/parallelize.pl
new file mode 100755
index 00000000..4197e0e5
--- /dev/null
+++ b/training/utils/parallelize.pl
@@ -0,0 +1,423 @@
+#!/usr/bin/env perl
+
+# Author: Adam Lopez
+#
+# This script takes a command that processes input
+# from stdin one-line-at-time, and parallelizes it
+# on the cluster using David Chiang's sentserver/
+# sentclient architecture.
+#
+# Prerequisites: the command *must* read each line
+# without waiting for subsequent lines of input
+# (for instance, a command which must read all lines
+# of input before processing will not work) and
+# return it to the output *without* buffering
+# multiple lines.
+
+#TODO: if -j 1, run immediately, not via sentserver? possible differences in environment might make debugging harder
+
+#ANNOYANCE: if input is shorter than -j n lines, or at the very last few lines, repeatedly sleeps. time cut down to 15s from 60s
+
+my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR, "$SCRIPT_DIR/../../environment"; }
+use LocalConfig;
+
+use Cwd qw/ abs_path cwd getcwd /;
+use File::Temp qw/ tempfile /;
+use Getopt::Long;
+use IPC::Open2;
+use strict;
+use POSIX ":sys_wait_h";
+
+use File::Basename;
+my $myDir = dirname(__FILE__);
+print STDERR __FILE__." -> $myDir\n";
+push(@INC, $myDir);
+require "libcall.pl";
+
+my $tailn=5; # +0 = concatenate all the client logs. 5 = last 5 lines
+my $recycle_clients; # spawn new clients when previous ones terminate
+my $stay_alive; # dont let server die when having zero clients
+my $joblist = "";
+my $errordir="";
+my $multiline;
+my $workdir = '.';
+my $numnodes = 8;
+my $user = $ENV{"USER"};
+my $pmem = "9g";
+my $basep=50300;
+my $randp=300;
+my $tryp=50;
+my $no_which;
+my $no_cd;
+
+my $DEBUG=$ENV{DEBUG};
+print STDERR "DEBUG=$DEBUG output enabled.\n" if $DEBUG;
+my $verbose = 1;
+sub verbose {
+ if ($verbose) {
+ print STDERR @_,"\n";
+ }
+}
+sub debug {
+ if ($DEBUG) {
+ my ($package, $filename, $line) = caller;
+ print STDERR "DEBUG: $filename($line): ",join(' ',@_),"\n";
+ }
+}
+my $is_shell_special=qr.[ \t\n\\><|&;"'`~*?{}$!()].;
+my $shell_escape_in_quote=qr.[\\"\$`!].;
+sub escape_shell {
+ my ($arg)=@_;
+ return undef unless defined $arg;
+ return '""' unless $arg;
+ if ($arg =~ /$is_shell_special/) {
+ $arg =~ s/($shell_escape_in_quote)/\\$1/g;
+ return "\"$arg\"";
+ }
+ return $arg;
+}
+sub preview_files {
+ my ($l,$skipempty,$footer,$n)=@_;
+ $n=$tailn unless defined $n;
+ my @f=grep { ! ($skipempty && -z $_) } @$l;
+ my $fn=join(' ',map {escape_shell($_)} @f);
+ my $cmd="tail -n $n $fn";
+ unchecked_output("$cmd").($footer?"\nNONEMPTY FILES:\n$fn\n":"");
+}
+sub prefix_dirname($) {
+ #like `dirname but if ends in / then return the whole thing
+ local ($_)=@_;
+ if (/\/$/) {
+ $_;
+ } else {
+ s#/[^/]$##;
+ $_ ? $_ : '';
+ }
+}
+sub ensure_final_slash($) {
+ local ($_)=@_;
+ m#/$# ? $_ : ($_."/");
+}
+sub extend_path($$;$$) {
+ my ($base,$ext,$mkdir,$baseisdir)=@_;
+ if (-d $base) {
+ $base.="/";
+ } else {
+ my $dir;
+ if ($baseisdir) {
+ $dir=$base;
+ $base.='/' unless $base =~ /\/$/;
+ } else {
+ $dir=prefix_dirname($base);
+ }
+ my @cmd=("/bin/mkdir","-p",$dir);
+ check_call(@cmd) if $mkdir;
+ }
+ return $base.$ext;
+}
+
+my $abscwd=abs_path(&getcwd);
+sub print_help;
+
+my $use_fork;
+my @pids;
+
+# Process command-line options
+unless (GetOptions(
+ "stay-alive" => \$stay_alive,
+ "recycle-clients" => \$recycle_clients,
+ "error-dir=s" => \$errordir,
+ "multi-line" => \$multiline,
+ "workdir=s" => \$workdir,
+ "use-fork" => \$use_fork,
+ "verbose" => \$verbose,
+ "jobs=i" => \$numnodes,
+ "pmem=s" => \$pmem,
+ "baseport=i" => \$basep,
+# "iport=i" => \$randp, #for short name -i
+ "no-which!" => \$no_which,
+ "no-cd!" => \$no_cd,
+ "tailn=s" => \$tailn,
+) && scalar @ARGV){
+ print_help();
+ die "bad options.";
+}
+
+my $cmd = "";
+my $prog=shift;
+if ($no_which) {
+ $cmd=$prog;
+} else {
+ $cmd=check_output("which $prog");
+ chomp $cmd;
+ die "$prog not found - $cmd" unless $cmd;
+}
+#$cmd=abs_path($cmd);
+for my $arg (@ARGV) {
+ $cmd .= " ".escape_shell($arg);
+}
+die "Please specify a command to parallelize\n" if $cmd eq '';
+
+my $cdcmd=$no_cd ? '' : ("cd ".escape_shell($abscwd)."\n");
+
+my $executable = $cmd;
+$executable =~ s/^\s*(\S+)($|\s.*)/$1/;
+$executable=check_output("basename $executable");
+chomp $executable;
+
+
+print STDERR "Parallelizing ($numnodes ways): $cmd\n\n";
+
+# create -e dir and save .sh
+use File::Temp qw/tempdir/;
+unless ($errordir) {
+ $errordir=tempdir("$executable.XXXXXX",CLEANUP=>1);
+}
+if ($errordir) {
+ my $scriptfile=extend_path("$errordir/","$executable.sh",1,1);
+ -d $errordir || die "should have created -e dir $errordir";
+ open SF,">",$scriptfile || die;
+ print SF "$cdcmd$cmd\n";
+ close SF;
+ chmod 0755,$scriptfile;
+ $errordir=abs_path($errordir);
+ &verbose("-e dir: $errordir");
+}
+
+# set cleanup handler
+my @cleanup_cmds;
+sub cleanup;
+sub cleanup_and_die;
+$SIG{INT} = "cleanup_and_die";
+$SIG{TERM} = "cleanup_and_die";
+$SIG{HUP} = "cleanup_and_die";
+
+# other subs:
+sub numof_live_jobs;
+sub launch_job_on_node;
+
+
+# vars
+my $mydir = check_output("dirname $0"); chomp $mydir;
+my $sentserver = "$mydir/sentserver";
+my $sentclient = "$mydir/sentclient";
+my $host = check_output("hostname");
+chomp $host;
+
+
+# find open port
+srand;
+my $port = 50300+int(rand($randp));
+my $endp=$port+$tryp;
+sub listening_port_lines {
+ my $quiet=$verbose?'':'2>/dev/null';
+ return unchecked_output("netstat -a -n $quiet | grep LISTENING | grep -i tcp");
+}
+my $netstat=&listening_port_lines;
+
+if ($verbose){ print STDERR "Testing port $port...";}
+
+while ($netstat=~/$port/ || &listening_port_lines=~/$port/){
+ if ($verbose){ print STDERR "port is busy\n";}
+ $port++;
+ if ($port > $endp){
+ die "Unable to find open port\n";
+ }
+ if ($verbose){ print STDERR "Testing port $port... "; }
+}
+if ($verbose){
+ print STDERR "port $port is available\n";
+}
+
+my $key = int(rand()*1000000);
+
+my $multiflag = "";
+if ($multiline){ $multiflag = "-m"; print STDERR "expecting multiline output.\n"; }
+my $stay_alive_flag = "";
+if ($stay_alive){ $stay_alive_flag = "--stay-alive"; print STDERR "staying alive while no clients are connected.\n"; }
+
+my $node_count = 0;
+my $script = "";
+# fork == one thread runs the sentserver, while the
+# other spawns the sentclient commands.
+my $pid = fork;
+if ($pid == 0) { # child
+ sleep 8; # give other thread time to start sentserver
+ $script = "$cdcmd$sentclient $host:$port:$key $cmd";
+
+ if ($verbose){
+ print STDERR "Client script:\n====\n";
+ print STDERR $script;
+ print STDERR "====\n";
+ }
+ for (my $jobn=0; $jobn<$numnodes; $jobn++){
+ launch_job();
+ }
+ if ($recycle_clients) {
+ my $ret;
+ my $livejobs;
+ while (1) {
+ $ret = waitpid($pid, WNOHANG);
+ #print STDERR "waitpid $pid ret = $ret \n";
+ last if ($ret != 0);
+ $livejobs = numof_live_jobs();
+ if ($numnodes >= $livejobs ) { # a client terminated, OR # lines of input was less than -j
+ print STDERR "num of requested nodes = $numnodes; num of currently live jobs = $livejobs; Client terminated - launching another.\n";
+ launch_job();
+ } else {
+ sleep 15;
+ }
+ }
+ }
+ print STDERR "CHILD PROCESSES SPAWNED ... WAITING\n";
+ for my $p (@pids) {
+ waitpid($p, 0);
+ }
+} else {
+# my $todo = "$sentserver -k $key $multiflag $port ";
+ my $todo = "$sentserver -k $key $multiflag $port $stay_alive_flag ";
+ if ($verbose){ print STDERR "Running: $todo\n"; }
+ check_call($todo);
+ print STDERR "Call to $sentserver returned.\n";
+ cleanup();
+ exit(0);
+}
+
+sub numof_live_jobs {
+ if ($use_fork) {
+ die "not implemented";
+ } else {
+ # We can probably continue decoding if the qstat error is only temporary
+ my @livejobs = grep(/$joblist/, split(/\n/, unchecked_output("qstat")));
+ return ($#livejobs + 1);
+ }
+}
+my (@errors,@outs,@cmds);
+
+sub launch_job {
+ if ($use_fork) { return launch_job_fork(); }
+ my $errorfile = "/dev/null";
+ my $outfile = "/dev/null";
+ $node_count++;
+ my $clientname = $executable;
+ $clientname =~ s/^(.{4}).*$/$1/;
+ $clientname = "$clientname.$node_count";
+ if ($errordir){
+ $errorfile = "$errordir/$clientname.ER";
+ $outfile = "$errordir/$clientname.OU";
+ push @errors,$errorfile;
+ push @outs,$outfile;
+ }
+ my $todo = qsub_args($pmem) . " -N $clientname -o $outfile -e $errorfile";
+ push @cmds,$todo;
+
+ print STDERR "Running: $todo\n";
+ local(*QOUT, *QIN);
+ open2(\*QOUT, \*QIN, $todo) or die "Failed to open2: $!";
+ print QIN $script;
+ close QIN;
+ while (my $jobid=<QOUT>){
+ chomp $jobid;
+ if ($verbose){ print STDERR "Launched client job: $jobid"; }
+ $jobid =~ s/^(\d+)(.*?)$/\1/g;
+ $jobid =~ s/^Your job (\d+) .*$/\1/;
+ print STDERR " short job id $jobid\n";
+ if ($verbose){
+ print STDERR "cd: $abscwd\n";
+ print STDERR "cmd: $cmd\n";
+ }
+ if ($joblist == "") { $joblist = $jobid; }
+ else {$joblist = $joblist . "\|" . $jobid; }
+ my $cleanfn="qdel $jobid 2> /dev/null";
+ push(@cleanup_cmds, $cleanfn);
+ }
+ close QOUT;
+}
+
+sub launch_job_fork {
+ my $errorfile = "/dev/null";
+ my $outfile = "/dev/null";
+ $node_count++;
+ my $clientname = $executable;
+ $clientname =~ s/^(.{4}).*$/$1/;
+ $clientname = "$clientname.$node_count";
+ if ($errordir){
+ $errorfile = "$errordir/$clientname.ER";
+ $outfile = "$errordir/$clientname.OU";
+ push @errors,$errorfile;
+ push @outs,$outfile;
+ }
+ my $pid = fork;
+ if ($pid == 0) {
+ my ($fh, $scr_name) = get_temp_script();
+ print $fh $script;
+ close $fh;
+ my $todo = "/bin/bash -xeo pipefail $scr_name 1> $outfile 2> $errorfile";
+ print STDERR "EXEC: $todo\n";
+ my $out = check_output("$todo");
+ unlink $scr_name or warn "Failed to remove $scr_name";
+ exit 0;
+ } else {
+ push @pids, $pid;
+ }
+}
+
+sub get_temp_script {
+ my ($fh, $filename) = tempfile( "$workdir/workXXXX", SUFFIX => '.sh');
+ return ($fh, $filename);
+}
+
+sub cleanup_and_die {
+ cleanup();
+ die "\n";
+}
+
+sub cleanup {
+ print STDERR "Cleaning up...\n";
+ for $cmd (@cleanup_cmds){
+ print STDERR " Cleanup command: $cmd\n";
+ eval $cmd;
+ }
+ print STDERR "outputs:\n",preview_files(\@outs,1),"\n";
+ print STDERR "errors:\n",preview_files(\@errors,1),"\n";
+ print STDERR "cmd:\n",$cmd,"\n";
+ print STDERR " cat $errordir/*.ER\nfor logs.\n";
+ print STDERR "Cleanup finished.\n";
+}
+
+sub print_help
+{
+ my $name = check_output("basename $0"); chomp $name;
+ print << "Help";
+
+usage: $name [options]
+
+ Automatic black-box parallelization of commands.
+
+options:
+
+ --use-fork
+ Instead of using qsub, use fork.
+
+ -e, --error-dir <dir>
+ Retain output files from jobs in <dir>, rather
+ than silently deleting them.
+
+ -m, --multi-line
+ Expect that command may produce multiple output
+ lines for a single input line. $name makes a
+ reasonable attempt to obtain all output before
+ processing additional inputs. However, use of this
+ option is inherently unsafe.
+
+ -v, --verbose
+ Print diagnostic informatoin on stderr.
+
+ -j, --jobs
+ Number of jobs to use.
+
+ -p, --pmem
+ pmem setting for each job.
+
+Help
+}
diff --git a/training/utils/risk.cc b/training/utils/risk.cc
new file mode 100644
index 00000000..d5a12cfd
--- /dev/null
+++ b/training/utils/risk.cc
@@ -0,0 +1,45 @@
+#include "risk.h"
+
+#include "prob.h"
+#include "candidate_set.h"
+#include "ns.h"
+
+using namespace std;
+
+namespace training {
+
+// g = \sum_e p(e|f) * loss(e) * (phi(e,f) - E[phi(e,f)])
+double CandidateSetRisk::operator()(const vector<double>& params,
+ SparseVector<double>* g) const {
+ prob_t z;
+ for (unsigned i = 0; i < cands_.size(); ++i) {
+ const prob_t u(cands_[i].fmap.dot(params), init_lnx());
+ z += u;
+ }
+ const double log_z = log(z);
+
+ SparseVector<double> exp_feats;
+ if (g) {
+ for (unsigned i = 0; i < cands_.size(); ++i) {
+ const double log_prob = cands_[i].fmap.dot(params) - log_z;
+ const double prob = exp(log_prob);
+ exp_feats += cands_[i].fmap * prob;
+ }
+ }
+
+ double risk = 0;
+ for (unsigned i = 0; i < cands_.size(); ++i) {
+ const double log_prob = cands_[i].fmap.dot(params) - log_z;
+ const double prob = exp(log_prob);
+ const double cost = metric_.IsErrorMetric() ? metric_.ComputeScore(cands_[i].eval_feats)
+ : 1.0 - metric_.ComputeScore(cands_[i].eval_feats);
+ const double r = prob * cost;
+ risk += r;
+ if (g) (*g) += (cands_[i].fmap - exp_feats) * r;
+ }
+ return risk;
+}
+
+}
+
+
diff --git a/training/utils/risk.h b/training/utils/risk.h
new file mode 100644
index 00000000..2e8db0fb
--- /dev/null
+++ b/training/utils/risk.h
@@ -0,0 +1,26 @@
+#ifndef _RISK_H_
+#define _RISK_H_
+
+#include <vector>
+#include "sparse_vector.h"
+class EvaluationMetric;
+
+namespace training {
+ class CandidateSet;
+
+ class CandidateSetRisk {
+ public:
+ explicit CandidateSetRisk(const CandidateSet& cs, const EvaluationMetric& metric) :
+ cands_(cs),
+ metric_(metric) {}
+ // compute the risk (expected loss) of a CandidateSet
+ // (optional) the gradient of the risk with respect to params
+ double operator()(const std::vector<double>& params,
+ SparseVector<double>* g = NULL) const;
+ private:
+ const CandidateSet& cands_;
+ const EvaluationMetric& metric_;
+ };
+};
+
+#endif
diff --git a/training/utils/sentclient.c b/training/utils/sentclient.c
new file mode 100644
index 00000000..91d994ab
--- /dev/null
+++ b/training/utils/sentclient.c
@@ -0,0 +1,76 @@
+/* Copyright (c) 2001 by David Chiang. All rights reserved.*/
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <unistd.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <netinet/in.h>
+#include <netdb.h>
+#include <string.h>
+
+#include "sentserver.h"
+
+int main (int argc, char *argv[]) {
+ int sock, port;
+ char *s, *key;
+ struct hostent *hp;
+ struct sockaddr_in server;
+ int errors = 0;
+
+ if (argc < 3) {
+ fprintf(stderr, "Usage: sentclient host[:port[:key]] command [args ...]\n");
+ exit(1);
+ }
+
+ s = strchr(argv[1], ':');
+ key = NULL;
+
+ if (s == NULL) {
+ port = DEFAULT_PORT;
+ } else {
+ *s = '\0';
+ s+=1;
+ /* dumb hack */
+ key = strchr(s, ':');
+ if (key != NULL){
+ *key = '\0';
+ key += 1;
+ }
+ port = atoi(s);
+ }
+
+ sock = socket(AF_INET, SOCK_STREAM, 0);
+
+ hp = gethostbyname(argv[1]);
+ if (hp == NULL) {
+ fprintf(stderr, "unknown host %s\n", argv[1]);
+ exit(1);
+ }
+
+ bzero((char *)&server, sizeof(server));
+ bcopy(hp->h_addr, (char *)&server.sin_addr, hp->h_length);
+ server.sin_family = hp->h_addrtype;
+ server.sin_port = htons(port);
+
+ while (connect(sock, (struct sockaddr *)&server, sizeof(server)) < 0) {
+ perror("connect()");
+ sleep(1);
+ errors++;
+ if (errors > 5)
+ exit(1);
+ }
+
+ close(0);
+ close(1);
+ dup2(sock, 0);
+ dup2(sock, 1);
+
+ if (key != NULL){
+ write(1, key, strlen(key));
+ write(1, "\n", 1);
+ }
+
+ execvp(argv[2], argv+2);
+ return 0;
+}
diff --git a/training/utils/sentserver.c b/training/utils/sentserver.c
new file mode 100644
index 00000000..c20b4fa6
--- /dev/null
+++ b/training/utils/sentserver.c
@@ -0,0 +1,515 @@
+/* Copyright (c) 2001 by David Chiang. All rights reserved.*/
+
+#include <string.h>
+#include <stdlib.h>
+#include <unistd.h>
+#include <fcntl.h>
+#include <stdio.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/time.h>
+#include <netinet/in.h>
+#include <sched.h>
+#include <pthread.h>
+#include <errno.h>
+
+#include "sentserver.h"
+
+#define MAX_CLIENTS 64
+
+struct clientinfo {
+ int s;
+ struct sockaddr_in sin;
+};
+
+struct line {
+ int id;
+ char *s;
+ int status;
+ struct line *next;
+} *head, **ptail;
+
+int n_sent = 0, n_received=0, n_flushed=0;
+
+#define STATUS_RUNNING 0
+#define STATUS_ABORTED 1
+#define STATUS_FINISHED 2
+
+pthread_mutex_t queue_mutex = PTHREAD_MUTEX_INITIALIZER;
+pthread_mutex_t clients_mutex = PTHREAD_MUTEX_INITIALIZER;
+pthread_mutex_t input_mutex = PTHREAD_MUTEX_INITIALIZER;
+
+int n_clients = 0;
+int s;
+int expect_multiline_output = 0;
+int log_mutex = 0;
+int stay_alive = 0; /* dont panic and die with zero clients */
+
+void queue_finish(struct line *node, char *s, int fid);
+char * read_line(int fd, int multiline);
+void done (int code);
+
+struct line * queue_get(int fid) {
+ struct line *cur;
+ char *s, *synch;
+
+ if (log_mutex) fprintf(stderr, "Getting for data for fid %d\n", fid);
+ if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid);
+ pthread_mutex_lock(&queue_mutex);
+
+ /* First, check for aborted sentences. */
+
+ if (log_mutex) fprintf(stderr, " Checking queue for aborted jobs (fid %d)\n", fid);
+ for (cur = head; cur != NULL; cur = cur->next) {
+ if (cur->status == STATUS_ABORTED) {
+ cur->status = STATUS_RUNNING;
+
+ if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid);
+ pthread_mutex_unlock(&queue_mutex);
+
+ return cur;
+ }
+ }
+ if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid);
+ pthread_mutex_unlock(&queue_mutex);
+
+ /* Otherwise, read a new one. */
+ if (log_mutex) fprintf(stderr, "Locking input mutex (%d)\n", fid);
+ if (log_mutex) fprintf(stderr, " Reading input for new data (fid %d)\n", fid);
+ pthread_mutex_lock(&input_mutex);
+ s = read_line(0,0);
+
+ while (s) {
+ if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid);
+ pthread_mutex_lock(&queue_mutex);
+ if (log_mutex) fprintf(stderr, "Unlocking input mutex (%d)\n", fid);
+ pthread_mutex_unlock(&input_mutex);
+
+ cur = malloc(sizeof (struct line));
+ cur->id = n_sent;
+ cur->s = s;
+ cur->next = NULL;
+
+ *ptail = cur;
+ ptail = &cur->next;
+
+ n_sent++;
+
+ if (strcmp(s,"===SYNCH===\n")==0){
+ fprintf(stderr, "Received ===SYNCH=== signal (fid %d)\n", fid);
+ // Note: queue_finish calls free(cur->s).
+ // Therefore we need to create a new string here.
+ synch = malloc((strlen("===SYNCH===\n")+2) * sizeof (char));
+ synch = strcpy(synch, s);
+
+ if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid);
+ pthread_mutex_unlock(&queue_mutex);
+ queue_finish(cur, synch, fid); /* handles its own lock */
+
+ if (log_mutex) fprintf(stderr, "Locking input mutex (%d)\n", fid);
+ if (log_mutex) fprintf(stderr, " Reading input for new data (fid %d)\n", fid);
+ pthread_mutex_lock(&input_mutex);
+
+ s = read_line(0,0);
+ } else {
+ if (log_mutex) fprintf(stderr, " Received new data %d (fid %d)\n", cur->id, fid);
+ cur->status = STATUS_RUNNING;
+ if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid);
+ pthread_mutex_unlock(&queue_mutex);
+ return cur;
+ }
+ }
+
+ if (log_mutex) fprintf(stderr, "Unlocking input mutex (%d)\n", fid);
+ pthread_mutex_unlock(&input_mutex);
+ /* Only way to reach this point: no more output */
+
+ if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid);
+ pthread_mutex_lock(&queue_mutex);
+ if (head == NULL) {
+ fprintf(stderr, "Reached end of file. Exiting.\n");
+ done(0);
+ } else
+ ptail = NULL; /* This serves as a signal that there is no more input */
+ if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid);
+ pthread_mutex_unlock(&queue_mutex);
+
+ return NULL;
+}
+
+void queue_panic() {
+ struct line *next;
+ while (head && head->status == STATUS_FINISHED) {
+ /* Write out finished sentences */
+ if (head->status == STATUS_FINISHED) {
+ fputs(head->s, stdout);
+ fflush(stdout);
+ }
+ /* Write out blank line for unfinished sentences */
+ if (head->status == STATUS_ABORTED) {
+ fputs("\n", stdout);
+ fflush(stdout);
+ }
+ /* By defition, there cannot be any RUNNING sentences, since
+ function is only called when n_clients == 0 */
+ free(head->s);
+ next = head->next;
+ free(head);
+ head = next;
+ n_flushed++;
+ }
+ fclose(stdout);
+ fprintf(stderr, "All clients died. Panicking, flushing completed sentences and exiting.\n");
+ done(1);
+}
+
+void queue_abort(struct line *node, int fid) {
+ if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid);
+ pthread_mutex_lock(&queue_mutex);
+ node->status = STATUS_ABORTED;
+ if (n_clients == 0) {
+ if (stay_alive) {
+ fprintf(stderr, "Warning! No live clients detected! Staying alive, will retry soon.\n");
+ } else {
+ queue_panic();
+ }
+ }
+ if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid);
+ pthread_mutex_unlock(&queue_mutex);
+}
+
+
+void queue_print() {
+ struct line *cur;
+
+ fprintf(stderr, " Queue\n");
+
+ for (cur = head; cur != NULL; cur = cur->next) {
+ switch(cur->status) {
+ case STATUS_RUNNING:
+ fprintf(stderr, " %d running ", cur->id); break;
+ case STATUS_ABORTED:
+ fprintf(stderr, " %d aborted ", cur->id); break;
+ case STATUS_FINISHED:
+ fprintf(stderr, " %d finished ", cur->id); break;
+
+ }
+ fprintf(stderr, "\n");
+ //fprintf(stderr, cur->s);
+ }
+}
+
+void queue_finish(struct line *node, char *s, int fid) {
+ struct line *next;
+ if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid);
+ pthread_mutex_lock(&queue_mutex);
+
+ free(node->s);
+ node->s = s;
+ node->status = STATUS_FINISHED;
+ n_received++;
+
+ /* Flush out finished nodes */
+ while (head && head->status == STATUS_FINISHED) {
+
+ if (log_mutex) fprintf(stderr, " Flushing finished node %d\n", head->id);
+
+ fputs(head->s, stdout);
+ fflush(stdout);
+ if (log_mutex) fprintf(stderr, " Flushed node %d\n", head->id);
+ free(head->s);
+
+ next = head->next;
+ free(head);
+
+ head = next;
+
+ n_flushed++;
+
+ if (head == NULL) { /* empty queue */
+ if (ptail == NULL) { /* This can only happen if set in queue_get as signal that there is no more input. */
+ fprintf(stderr, "All sentences finished. Exiting.\n");
+ done(0);
+ } else /* ptail pointed at something which was just popped off the stack -- reset to head*/
+ ptail = &head;
+ }
+ }
+
+ if (log_mutex) fprintf(stderr, " Flushing output %d\n", head->id);
+ fflush(stdout);
+ fprintf(stderr, "%d sentences sent, %d sentences finished, %d sentences flushed\n", n_sent, n_received, n_flushed);
+
+ if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid);
+ pthread_mutex_unlock(&queue_mutex);
+
+}
+
+char * read_line(int fd, int multiline) {
+ int size = 80;
+ char errorbuf[100];
+ char *s = malloc(size+2);
+ int result, errors=0;
+ int i = 0;
+
+ result = read(fd, s+i, 1);
+
+ while (1) {
+ if (result < 0) {
+ perror("read()");
+ sprintf(errorbuf, "Error code: %d\n", errno);
+ fprintf(stderr, errorbuf);
+ errors++;
+ if (errors > 5) {
+ free(s);
+ return NULL;
+ } else {
+ sleep(1); /* retry after delay */
+ }
+ } else if (result == 0) {
+ break;
+ } else if (multiline==0 && s[i] == '\n') {
+ break;
+ } else {
+ if (s[i] == '\n'){
+ /* if we've reached this point,
+ then multiline must be 1, and we're
+ going to poll the fd for an additional
+ line of data. The basic design is to
+ run a select on the filedescriptor fd.
+ Select will return under two conditions:
+ if there is data on the fd, or if a
+ timeout is reached. We'll select on this
+ fd. If select returns because there's data
+ ready, keep going; else assume there's no
+ more and return the data we already have.
+ */
+
+ fd_set set;
+ FD_ZERO(&set);
+ FD_SET(fd, &set);
+
+ struct timeval timeout;
+ timeout.tv_sec = 3; // number of seconds for timeout
+ timeout.tv_usec = 0;
+
+ int ready = select(FD_SETSIZE, &set, NULL, NULL, &timeout);
+ if (ready<1){
+ break; // no more data, stop looping
+ }
+ }
+ i++;
+
+ if (i == size) {
+ size = size*2;
+ s = realloc(s, size+2);
+ }
+ }
+
+ result = read(fd, s+i, 1);
+ }
+
+ if (result == 0 && i == 0) { /* end of file */
+ free(s);
+ return NULL;
+ }
+
+ s[i] = '\n';
+ s[i+1] = '\0';
+
+ return s;
+}
+
+void * new_client(void *arg) {
+ struct clientinfo *client = (struct clientinfo *)arg;
+ struct line *cur;
+ int result;
+ char *s;
+ char errorbuf[100];
+
+ pthread_mutex_lock(&clients_mutex);
+ n_clients++;
+ pthread_mutex_unlock(&clients_mutex);
+
+ fprintf(stderr, "Client connected (%d connected)\n", n_clients);
+
+ for (;;) {
+
+ cur = queue_get(client->s);
+
+ if (cur) {
+ /* fprintf(stderr, "Sending to client: %s", cur->s); */
+ fprintf(stderr, "Sending data %d to client (fid %d)\n", cur->id, client->s);
+ result = write(client->s, cur->s, strlen(cur->s));
+ if (result < strlen(cur->s)){
+ perror("write()");
+ sprintf(errorbuf, "Error code: %d\n", errno);
+ fprintf(stderr, errorbuf);
+
+ pthread_mutex_lock(&clients_mutex);
+ n_clients--;
+ pthread_mutex_unlock(&clients_mutex);
+
+ fprintf(stderr, "Client died (%d connected)\n", n_clients);
+ queue_abort(cur, client->s);
+
+ close(client->s);
+ free(client);
+
+ pthread_exit(NULL);
+ }
+ } else {
+ close(client->s);
+ pthread_mutex_lock(&clients_mutex);
+ n_clients--;
+ pthread_mutex_unlock(&clients_mutex);
+ fprintf(stderr, "Client dismissed (%d connected)\n", n_clients);
+ pthread_exit(NULL);
+ }
+
+ s = read_line(client->s,expect_multiline_output);
+ if (s) {
+ /* fprintf(stderr, "Client (fid %d) returned: %s", client->s, s); */
+ fprintf(stderr, "Client (fid %d) returned data %d\n", client->s, cur->id);
+// queue_print();
+ queue_finish(cur, s, client->s);
+ } else {
+ pthread_mutex_lock(&clients_mutex);
+ n_clients--;
+ pthread_mutex_unlock(&clients_mutex);
+
+ fprintf(stderr, "Client died (%d connected)\n", n_clients);
+ queue_abort(cur, client->s);
+
+ close(client->s);
+ free(client);
+
+ pthread_exit(NULL);
+ }
+
+ }
+ return 0;
+}
+
+void done (int code) {
+ close(s);
+ exit(code);
+}
+
+
+
+int main (int argc, char *argv[]) {
+ struct sockaddr_in sin, from;
+ int g;
+ socklen_t len;
+ struct clientinfo *client;
+ int port;
+ int opt;
+ int errors = 0;
+ int argi;
+ char *key = NULL, *client_key;
+ int use_key = 0;
+ /* the key stuff here doesn't provide any
+ real measure of security, it's mainly to keep
+ jobs from bumping into each other. */
+
+ pthread_t tid;
+ port = DEFAULT_PORT;
+
+ for (argi=1; argi < argc; argi++){
+ if (strcmp(argv[argi], "-m")==0){
+ expect_multiline_output = 1;
+ } else if (strcmp(argv[argi], "-k")==0){
+ argi++;
+ if (argi == argc){
+ fprintf(stderr, "Key must be specified after -k\n");
+ exit(1);
+ }
+ key = argv[argi];
+ use_key = 1;
+ } else if (strcmp(argv[argi], "--stay-alive")==0){
+ stay_alive = 1; /* dont panic and die with zero clients */
+ } else {
+ port = atoi(argv[argi]);
+ }
+ }
+
+ /* Initialize data structures */
+ head = NULL;
+ ptail = &head;
+
+ /* Set up listener */
+ s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+ opt = 1;
+ setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));
+
+ sin.sin_family = AF_INET;
+ sin.sin_addr.s_addr = htonl(INADDR_ANY);
+ sin.sin_port = htons(port);
+ while (bind(s, (struct sockaddr *) &sin, sizeof(sin)) < 0) {
+ perror("bind()");
+ sleep(1);
+ errors++;
+ if (errors > 100)
+ exit(1);
+ }
+
+ len = sizeof(sin);
+ getsockname(s, (struct sockaddr *) &sin, &len);
+
+ fprintf(stderr, "Listening on port %hu\n", ntohs(sin.sin_port));
+
+ while (listen(s, MAX_CLIENTS) < 0) {
+ perror("listen()");
+ sleep(1);
+ errors++;
+ if (errors > 100)
+ exit(1);
+ }
+
+ for (;;) {
+ len = sizeof(from);
+ g = accept(s, (struct sockaddr *)&from, &len);
+ if (g < 0) {
+ perror("accept()");
+ sleep(1);
+ continue;
+ }
+ client = malloc(sizeof(struct clientinfo));
+ client->s = g;
+ bcopy(&from, &client->sin, len);
+
+ if (use_key){
+ fd_set set;
+ FD_ZERO(&set);
+ FD_SET(client->s, &set);
+
+ struct timeval timeout;
+ timeout.tv_sec = 3; // number of seconds for timeout
+ timeout.tv_usec = 0;
+
+ int ready = select(FD_SETSIZE, &set, NULL, NULL, &timeout);
+ if (ready<1){
+ fprintf(stderr, "Prospective client failed to respond with correct key.\n");
+ close(client->s);
+ free(client);
+ } else {
+ client_key = read_line(client->s,0);
+ client_key[strlen(client_key)-1]='\0'; /* chop trailing newline */
+ if (strcmp(key, client_key)==0){
+ pthread_create(&tid, NULL, new_client, client);
+ } else {
+ fprintf(stderr, "Prospective client failed to respond with correct key.\n");
+ close(client->s);
+ free(client);
+ }
+ free(client_key);
+ }
+ } else {
+ pthread_create(&tid, NULL, new_client, client);
+ }
+ }
+
+}
+
+
+
diff --git a/training/utils/sentserver.h b/training/utils/sentserver.h
new file mode 100644
index 00000000..cd17a546
--- /dev/null
+++ b/training/utils/sentserver.h
@@ -0,0 +1,6 @@
+#ifndef SENTSERVER_H
+#define SENTSERVER_H
+
+#define DEFAULT_PORT 50000
+
+#endif