From 8aa29810bb77611cc20b7a384897ff6703783ea1 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 18 Nov 2012 13:35:42 -0500 Subject: major restructure of the training code --- training/utils/candidate_set.cc | 169 ++++ training/utils/candidate_set.h | 60 ++ training/utils/decode-and-evaluate.pl | 246 ++++++ training/utils/entropy.cc | 41 + training/utils/entropy.h | 22 + training/utils/grammar_convert.cc | 348 ++++++++ training/utils/lbfgs.h | 1459 +++++++++++++++++++++++++++++++++ training/utils/lbfgs_test.cc | 117 +++ training/utils/libcall.pl | 71 ++ training/utils/online_optimizer.cc | 16 + training/utils/online_optimizer.h | 129 +++ training/utils/optimize.cc | 102 +++ training/utils/optimize.h | 92 +++ training/utils/optimize_test.cc | 118 +++ training/utils/parallelize.pl | 423 ++++++++++ training/utils/risk.cc | 45 + training/utils/risk.h | 26 + training/utils/sentclient.c | 76 ++ training/utils/sentserver.c | 515 ++++++++++++ training/utils/sentserver.h | 6 + 20 files changed, 4081 insertions(+) create mode 100644 training/utils/candidate_set.cc create mode 100644 training/utils/candidate_set.h create mode 100755 training/utils/decode-and-evaluate.pl create mode 100644 training/utils/entropy.cc create mode 100644 training/utils/entropy.h create mode 100644 training/utils/grammar_convert.cc create mode 100644 training/utils/lbfgs.h create mode 100644 training/utils/lbfgs_test.cc create mode 100644 training/utils/libcall.pl create mode 100644 training/utils/online_optimizer.cc create mode 100644 training/utils/online_optimizer.h create mode 100644 training/utils/optimize.cc create mode 100644 training/utils/optimize.h create mode 100644 training/utils/optimize_test.cc create mode 100755 training/utils/parallelize.pl create mode 100644 training/utils/risk.cc create mode 100644 training/utils/risk.h create mode 100644 training/utils/sentclient.c create mode 100644 training/utils/sentserver.c create mode 100644 training/utils/sentserver.h (limited to 'training/utils') diff --git a/training/utils/candidate_set.cc b/training/utils/candidate_set.cc new file mode 100644 index 00000000..087efec3 --- /dev/null +++ b/training/utils/candidate_set.cc @@ -0,0 +1,169 @@ +#include "candidate_set.h" + +#include + +#include + +#include "verbose.h" +#include "ns.h" +#include "filelib.h" +#include "wordid.h" +#include "tdict.h" +#include "hg.h" +#include "kbest.h" +#include "viterbi.h" + +using namespace std; + +namespace training { + +struct ApproxVectorHasher { + static const size_t MASK = 0xFFFFFFFFull; + union UType { + double f; // leave as double + size_t i; + }; + static inline double round(const double x) { + UType t; + t.f = x; + size_t r = t.i & MASK; + if ((r << 1) > MASK) + t.i += MASK - r + 1; + else + t.i &= (1ull - MASK); + return t.f; + } + size_t operator()(const SparseVector& x) const { + size_t h = 0x573915839; + for (SparseVector::const_iterator it = x.begin(); it != x.end(); ++it) { + UType t; + t.f = it->second; + if (t.f) { + size_t z = (t.i >> 32); + boost::hash_combine(h, it->first); + boost::hash_combine(h, z); + } + } + return h; + } +}; + +struct ApproxVectorEquals { + bool operator()(const SparseVector& a, const SparseVector& b) const { + SparseVector::const_iterator bit = b.begin(); + for (SparseVector::const_iterator ait = a.begin(); ait != a.end(); ++ait) { + if (bit == b.end() || + ait->first != bit->first || + ApproxVectorHasher::round(ait->second) != ApproxVectorHasher::round(bit->second)) + return false; + ++bit; + } + if (bit != b.end()) return false; + return true; + } +}; + +struct CandidateCompare { + bool operator()(const Candidate& a, const Candidate& b) const { + ApproxVectorEquals eq; + return (a.ewords == b.ewords && eq(a.fmap,b.fmap)); + } +}; + +struct CandidateHasher { + size_t operator()(const Candidate& x) const { + boost::hash > hhasher; + ApproxVectorHasher vhasher; + size_t ha = hhasher(x.ewords); + boost::hash_combine(ha, vhasher(x.fmap)); + return ha; + } +}; + +static void ParseSparseVector(string& line, size_t cur, SparseVector* out) { + SparseVector& x = *out; + size_t last_start = cur; + size_t last_comma = string::npos; + while(cur <= line.size()) { + if (line[cur] == ' ' || cur == line.size()) { + if (!(cur > last_start && last_comma != string::npos && cur > last_comma)) { + cerr << "[ERROR] " << line << endl << " position = " << cur << endl; + exit(1); + } + const int fid = FD::Convert(line.substr(last_start, last_comma - last_start)); + if (cur < line.size()) line[cur] = 0; + const double val = strtod(&line[last_comma + 1], NULL); + x.set_value(fid, val); + + last_comma = string::npos; + last_start = cur+1; + } else { + if (line[cur] == '=') + last_comma = cur; + } + ++cur; + } +} + +void CandidateSet::WriteToFile(const string& file) const { + WriteFile wf(file); + ostream& out = *wf.stream(); + out.precision(10); + string ss; + for (unsigned i = 0; i < cs.size(); ++i) { + out << TD::GetString(cs[i].ewords) << endl; + out << cs[i].fmap << endl; + cs[i].eval_feats.Encode(&ss); + out << ss << endl; + } +} + +void CandidateSet::ReadFromFile(const string& file) { + if(!SILENT) cerr << "Reading candidates from " << file << endl; + ReadFile rf(file); + istream& in = *rf.stream(); + string cand; + string feats; + string ss; + while(getline(in, cand)) { + getline(in, feats); + getline(in, ss); + assert(in); + cs.push_back(Candidate()); + TD::ConvertSentence(cand, &cs.back().ewords); + ParseSparseVector(feats, 0, &cs.back().fmap); + cs.back().eval_feats = SufficientStats(ss); + } + if(!SILENT) cerr << " read " << cs.size() << " candidates\n"; +} + +void CandidateSet::Dedup() { + if(!SILENT) cerr << "Dedup in=" << cs.size(); + tr1::unordered_set u; + while(cs.size() > 0) { + u.insert(cs.back()); + cs.pop_back(); + } + tr1::unordered_set::iterator it = u.begin(); + while (it != u.end()) { + cs.push_back(*it); + it = u.erase(it); + } + if(!SILENT) cerr << " out=" << cs.size() << endl; +} + +void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer) { + KBest::KBestDerivations, ESentenceTraversal> kbest(hg, kbest_size); + + for (unsigned i = 0; i < kbest_size; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + cs.push_back(Candidate(d->yield, d->feature_values)); + if (scorer) + scorer->Evaluate(d->yield, &cs.back().eval_feats); + } + Dedup(); +} + +} diff --git a/training/utils/candidate_set.h b/training/utils/candidate_set.h new file mode 100644 index 00000000..9d326ed0 --- /dev/null +++ b/training/utils/candidate_set.h @@ -0,0 +1,60 @@ +#ifndef _CANDIDATE_SET_H_ +#define _CANDIDATE_SET_H_ + +#include +#include + +#include "ns.h" +#include "wordid.h" +#include "sparse_vector.h" + +class Hypergraph; + +namespace training { + +struct Candidate { + Candidate() {} + Candidate(const std::vector& e, const SparseVector& fm) : + ewords(e), + fmap(fm) {} + Candidate(const std::vector& e, + const SparseVector& fm, + const SegmentEvaluator& se) : + ewords(e), + fmap(fm) { + se.Evaluate(ewords, &eval_feats); + } + + void swap(Candidate& other) { + eval_feats.swap(other.eval_feats); + ewords.swap(other.ewords); + fmap.swap(other.fmap); + } + + std::vector ewords; + SparseVector fmap; + SufficientStats eval_feats; +}; + +// represents some kind of collection of translation candidates, e.g. +// aggregated k-best lists, sample lists, etc. +class CandidateSet { + public: + CandidateSet() {} + inline size_t size() const { return cs.size(); } + const Candidate& operator[](size_t i) const { return cs[i]; } + + void ReadFromFile(const std::string& file); + void WriteToFile(const std::string& file) const; + void AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer = NULL); + // TODO add code to do unique k-best + // TODO add code to draw k samples + + private: + void Dedup(); + std::vector cs; +}; + +} + +#endif diff --git a/training/utils/decode-and-evaluate.pl b/training/utils/decode-and-evaluate.pl new file mode 100755 index 00000000..1a332c08 --- /dev/null +++ b/training/utils/decode-and-evaluate.pl @@ -0,0 +1,246 @@ +#!/usr/bin/env perl +use strict; +my @ORIG_ARGV=@ARGV; +use Cwd qw(getcwd); +my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR, "$SCRIPT_DIR/../../environment"; } + +# Skip local config (used for distributing jobs) if we're running in local-only mode +use LocalConfig; +use Getopt::Long; +use File::Basename qw(basename); +my $QSUB_CMD = qsub_args(mert_memory()); + +require "libcall.pl"; + +# Default settings +my $default_jobs = env_default_jobs(); +my $bin_dir = $SCRIPT_DIR; +die "Bin directory $bin_dir missing/inaccessible" unless -d $bin_dir; +my $FAST_SCORE="$bin_dir/../../mteval/fast_score"; +die "Can't execute $FAST_SCORE" unless -x $FAST_SCORE; +my $parallelize = "$bin_dir/parallelize.pl"; +my $libcall = "$bin_dir/libcall.pl"; +my $sentserver = "$bin_dir/sentserver"; +my $sentclient = "$bin_dir/sentclient"; +my $LocalConfig = "$SCRIPT_DIR/../../environment/LocalConfig.pm"; + +my $SCORER = $FAST_SCORE; +my $cdec = "$bin_dir/../../decoder/cdec"; +die "Can't find decoder in $cdec" unless -x $cdec; +die "Can't find $parallelize" unless -x $parallelize; +die "Can't find $libcall" unless -e $libcall; +my $decoder = $cdec; +my $jobs = $default_jobs; # number of decode nodes +my $pmem = "9g"; +my $help = 0; +my $config; +my $test_set; +my $weights; +my $use_make = 1; +my $useqsub; +my $cpbin=1; +# Process command-line options +if (GetOptions( + "jobs=i" => \$jobs, + "help" => \$help, + "qsub" => \$useqsub, + "input=s" => \$test_set, + "config=s" => \$config, + "weights=s" => \$weights, +) == 0 || @ARGV!=0 || $help) { + print_help(); + exit; +} + +if ($useqsub) { + $use_make = 0; + die "LocalEnvironment.pm does not have qsub configuration for this host. Cannot run with --qsub!\n" unless has_qsub(); +} + +my @missing_args = (); + +if (!defined $test_set) { push @missing_args, "--input"; } +if (!defined $config) { push @missing_args, "--config"; } +if (!defined $weights) { push @missing_args, "--weights"; } +die "Please specify missing arguments: " . join (', ', @missing_args) . "\nUse --help for more information.\n" if (@missing_args); + +my @tf = localtime(time); +my $tname = basename($test_set); +$tname =~ s/\.(sgm|sgml|xml)$//i; +my $dir = "eval.$tname." . sprintf('%d%02d%02d-%02d%02d%02d', 1900+$tf[5], $tf[4], $tf[3], $tf[2], $tf[1], $tf[0]); + +my $time = unchecked_output("date"); + +check_call("mkdir -p $dir"); + +split_devset($test_set, "$dir/test.input.raw", "$dir/test.refs"); +my $refs = "-r $dir/test.refs"; +my $newsrc = "$dir/test.input"; +enseg("$dir/test.input.raw", $newsrc); +my $src_file = $newsrc; +open F, "<$src_file" or die "Can't read $src_file: $!"; close F; + +my $test_trans="$dir/test.trans"; +my $logdir="$dir/logs"; +my $decoderLog="$logdir/decoder.sentserver.log"; +check_call("mkdir -p $logdir"); + +#decode +print STDERR "RUNNING DECODER AT "; +print STDERR unchecked_output("date"); +my $decoder_cmd = "$decoder -c $config --weights $weights"; +my $pcmd; +if ($use_make) { + $pcmd = "cat $src_file | $parallelize --workdir $dir --use-fork -p $pmem -e $logdir -j $jobs --"; +} else { + $pcmd = "cat $src_file | $parallelize --workdir $dir -p $pmem -e $logdir -j $jobs --"; +} +my $cmd = "$pcmd $decoder_cmd 2> $decoderLog 1> $test_trans"; +check_bash_call($cmd); +print STDERR "DECODER COMPLETED AT "; +print STDERR unchecked_output("date"); +print STDERR "\nOUTPUT: $test_trans\n\n"; +my $bleu = check_output("cat $test_trans | $SCORER $refs -m ibm_bleu"); +chomp $bleu; +print STDERR "BLEU: $bleu\n"; +my $ter = check_output("cat $test_trans | $SCORER $refs -m ter"); +chomp $ter; +print STDERR " TER: $ter\n"; +open TR, ">$dir/test.scores" or die "Can't write $dir/test.scores: $!"; +print TR <$newsrc"); + my $i=0; + while (my $line=){ + chomp $line; + if ($line =~ /^\s* tags, you must include a zero-based id attribute"; + } + } else { + print NEWSRC "$line\n"; + } + $i++; + } + close SRC; + close NEWSRC; +} + +sub print_help { + my $executable = basename($0); chomp $executable; + print << "Help"; + +Usage: $executable [options] + + $executable --config cdec.ini --weights weights.txt [--jobs N] [--qsub] + +Options: + + --help + Print this message and exit. + + --config + A path to the cdec.ini file. + + --weights + A file specifying feature weights. + + --dir + Directory for intermediate and output files. + +Job control options: + + --jobs + Number of decoder processes to run in parallel. [default=$default_jobs] + + --qsub + Use qsub to run jobs in parallel (qsub must be configured in + environment/LocalEnvironment.pm) + + --pmem + Amount of physical memory requested for parallel decoding jobs + (used with qsub requests only) + +Help +} + +sub convert { + my ($str) = @_; + my @ps = split /;/, $str; + my %dict = (); + for my $p (@ps) { + my ($k, $v) = split /=/, $p; + $dict{$k} = $v; + } + return %dict; +} + + + +sub cmdline { + return join ' ',($0,@ORIG_ARGV); +} + +#buggy: last arg gets quoted sometimes? +my $is_shell_special=qr{[ \t\n\\><|&;"'`~*?{}$!()]}; +my $shell_escape_in_quote=qr{[\\"\$`!]}; + +sub escape_shell { + my ($arg)=@_; + return undef unless defined $arg; + if ($arg =~ /$is_shell_special/) { + $arg =~ s/($shell_escape_in_quote)/\\$1/g; + return "\"$arg\""; + } + return $arg; +} + +sub escaped_shell_args { + return map {local $_=$_;chomp;escape_shell($_)} @_; +} + +sub escaped_shell_args_str { + return join ' ',&escaped_shell_args(@_); +} + +sub escaped_cmdline { + return "$0 ".&escaped_shell_args_str(@ORIG_ARGV); +} + +sub split_devset { + my ($infile, $outsrc, $outref) = @_; + open F, "<$infile" or die "Can't read $infile: $!"; + open S, ">$outsrc" or die "Can't write $outsrc: $!"; + open R, ">$outref" or die "Can't write $outref: $!"; + while() { + chomp; + my ($src, @refs) = split /\s*\|\|\|\s*/; + die "Malformed devset line: $_\n" unless scalar @refs > 0; + print S "$src\n"; + print R join(' ||| ', @refs) . "\n"; + } + close R; + close S; + close F; +} + diff --git a/training/utils/entropy.cc b/training/utils/entropy.cc new file mode 100644 index 00000000..4fdbe2be --- /dev/null +++ b/training/utils/entropy.cc @@ -0,0 +1,41 @@ +#include "entropy.h" + +#include "prob.h" +#include "candidate_set.h" + +using namespace std; + +namespace training { + +// see Mann and McCallum "Efficient Computation of Entropy Gradient ..." for +// a mostly clear derivation of: +// g = E[ F(x,y) * log p(y|x) ] + H(y | x) * E[ F(x,y) ] +double CandidateSetEntropy::operator()(const vector& params, + SparseVector* g) const { + prob_t z; + vector dps(cands_.size()); + for (unsigned i = 0; i < cands_.size(); ++i) { + dps[i] = cands_[i].fmap.dot(params); + const prob_t u(dps[i], init_lnx()); + z += u; + } + const double log_z = log(z); + + SparseVector exp_feats; + double entropy = 0; + for (unsigned i = 0; i < cands_.size(); ++i) { + const double log_prob = cands_[i].fmap.dot(params) - log_z; + const double prob = exp(log_prob); + const double e_logprob = prob * log_prob; + entropy -= e_logprob; + if (g) { + (*g) += cands_[i].fmap * e_logprob; + exp_feats += cands_[i].fmap * prob; + } + } + if (g) (*g) += exp_feats * entropy; + return entropy; +} + +} + diff --git a/training/utils/entropy.h b/training/utils/entropy.h new file mode 100644 index 00000000..796589ca --- /dev/null +++ b/training/utils/entropy.h @@ -0,0 +1,22 @@ +#ifndef _CSENTROPY_H_ +#define _CSENTROPY_H_ + +#include +#include "sparse_vector.h" + +namespace training { + class CandidateSet; + + class CandidateSetEntropy { + public: + explicit CandidateSetEntropy(const CandidateSet& cs) : cands_(cs) {} + // compute the entropy (expected log likelihood) of a CandidateSet + // (optional) the gradient of the entropy with respect to params + double operator()(const std::vector& params, + SparseVector* g = NULL) const; + private: + const CandidateSet& cands_; + }; +}; + +#endif diff --git a/training/utils/grammar_convert.cc b/training/utils/grammar_convert.cc new file mode 100644 index 00000000..607a7cb9 --- /dev/null +++ b/training/utils/grammar_convert.cc @@ -0,0 +1,348 @@ +/* + this program modifies cfg hypergraphs (forests) and extracts kbests? + what are: json, split ? + */ +#include +#include +#include + +#include +#include + +#include "inside_outside.h" +#include "tdict.h" +#include "filelib.h" +#include "hg.h" +#include "hg_io.h" +#include "kbest.h" +#include "viterbi.h" +#include "weights.h" + +namespace po = boost::program_options; +using namespace std; + +WordID kSTART; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("input,i", po::value()->default_value("-"), "Input file") + ("format,f", po::value()->default_value("cfg"), "Input format. Values: cfg, json, split") + ("output,o", po::value()->default_value("json"), "Output command. Values: json, 1best") + ("reorder,r", "Add Yamada & Knight (2002) reorderings") + ("weights,w", po::value(), "Feature weights for k-best derivations [optional]") + ("collapse_weights,C", "Collapse order features into a single feature whose value is all of the locally applying feature weights") + ("k_derivations,k", po::value(), "Show k derivations and their features") + ("max_reorder,m", po::value()->default_value(999), "Move a constituent at most this far") + ("help,h", "Print this help message and exit"); + po::options_description clo("Command line options"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + po::notify(*conf); + + if (conf->count("help") || conf->count("input") == 0) { + cerr << "\nUsage: grammar_convert [-options]\n\nConverts a grammar file (in Hiero format) into JSON hypergraph.\n"; + cerr << dcmdline_options << endl; + exit(1); + } +} + +int GetOrCreateNode(const WordID& lhs, map* lhs2node, Hypergraph* hg) { + int& node_id = (*lhs2node)[lhs]; + if (!node_id) + node_id = hg->AddNode(lhs)->id_ + 1; + return node_id - 1; +} + +void FilterAndCheckCorrectness(int goal, Hypergraph* hg) { + if (goal < 0) { + cerr << "Error! [S] not found in grammar!\n"; + exit(1); + } + if (hg->nodes_[goal].in_edges_.size() != 1) { + cerr << "Error! [S] has more than one rewrite!\n"; + exit(1); + } + int old_size = hg->nodes_.size(); + hg->TopologicallySortNodesAndEdges(goal); + if (hg->nodes_.size() != old_size) { + cerr << "Warning! During sorting " << (old_size - hg->nodes_.size()) << " disappeared!\n"; + } + vector inside; // inside score at each node + double p = Inside(*hg, &inside); + if (!p) { + cerr << "Warning! Grammar defines the empty language!\n"; + hg->clear(); + return; + } + vector prune(hg->edges_.size(), false); + int bad_edges = 0; + for (unsigned i = 0; i < hg->edges_.size(); ++i) { + Hypergraph::Edge& edge = hg->edges_[i]; + bool bad = false; + for (unsigned j = 0; j < edge.tail_nodes_.size(); ++j) { + if (!inside[edge.tail_nodes_[j]]) { + bad = true; + ++bad_edges; + } + } + prune[i] = bad; + } + cerr << "Removing " << bad_edges << " bad edges from the grammar.\n"; + for (unsigned i = 0; i < hg->edges_.size(); ++i) { + if (prune[i]) + cerr << " " << hg->edges_[i].rule_->AsString() << endl; + } + hg->PruneEdges(prune); +} + +void CreateEdge(const TRulePtr& r, const Hypergraph::TailNodeVector& tail, Hypergraph::Node* head_node, Hypergraph* hg) { + Hypergraph::Edge* new_edge = hg->AddEdge(r, tail); + hg->ConnectEdgeToHeadNode(new_edge, head_node); + new_edge->feature_values_ = r->scores_; +} + +// from a category label like "NP_2", return "NP" +string PureCategory(WordID cat) { + assert(cat < 0); + string c = TD::Convert(cat*-1); + size_t p = c.find("_"); + if (p == string::npos) return c; + return c.substr(0, p); +}; + +string ConstituentOrderFeature(const TRule& rule, const vector& pi) { + const static string kTERM_VAR = "x"; + const vector& f = rule.f(); + map used; + vector terms(f.size()); + for (int i = 0; i < f.size(); ++i) { + const string term = (f[i] < 0 ? PureCategory(f[i]) : kTERM_VAR); + int& count = used[term]; + if (!count) { + terms[i] = term; + } else { + ostringstream os; + os << term << count; + terms[i] = os.str(); + } + ++count; + } + ostringstream os; + os << PureCategory(rule.GetLHS()) << ':'; + for (int i = 0; i < f.size(); ++i) { + if (i > 0) os << '_'; + os << terms[pi[i]]; + } + return os.str(); +} + +bool CheckPermutationMask(const vector& mask, const vector& pi) { + assert(mask.size() == pi.size()); + + int req_min = -1; + int cur_max = 0; + int cur_mask = -1; + for (int i = 0; i < mask.size(); ++i) { + if (mask[i] != cur_mask) { + cur_mask = mask[i]; + req_min = cur_max - 1; + } + if (pi[i] > req_min) { + if (pi[i] > cur_max) cur_max = pi[i]; + } else { + return false; + } + } + + return true; +} + +void PermuteYKRecursive(int nodeid, const WordID& parent, const int max_reorder, Hypergraph* hg) { + // Hypergraph tmp = *hg; + Hypergraph::Node* node = &hg->nodes_[nodeid]; + if (node->in_edges_.size() != 1) { + cerr << "Multiple rewrites of [" << TD::Convert(node->cat_ * -1) << "] (parent is [" << TD::Convert(parent*-1) << "])\n"; + cerr << " not recursing!\n"; + return; + } +// for (int eii = 0; eii < node->in_edges_.size(); ++eii) { + const int oe_index = node->in_edges_.front(); + const TRule& rule = *hg->edges_[oe_index].rule_; + const Hypergraph::TailNodeVector orig_tail = hg->edges_[oe_index].tail_nodes_; + const int tail_size = orig_tail.size(); + for (int i = 0; i < tail_size; ++i) { + PermuteYKRecursive(hg->edges_[oe_index].tail_nodes_[i], node->cat_, max_reorder, hg); + } + const vector& of = rule.f_; + if (of.size() == 1) return; + // cerr << "Permuting [" << TD::Convert(node->cat_ * -1) << "]\n"; + // cerr << "ORIG: " << rule.AsString() << endl; + vector pi(of.size(), 0); + for (int i = 0; i < pi.size(); ++i) pi[i] = i; + + vector permutation_mask(of.size(), 0); + const bool dont_reorder_across_PU = true; // TODO add configuration + if (dont_reorder_across_PU) { + int cur = 0; + for (int i = 0; i < pi.size(); ++i) { + if (of[i] >= 0) continue; + const string cat = PureCategory(of[i]); + if (cat == "PU" || cat == "PU!H" || cat == "PUNC" || cat == "PUNC!H" || cat == "CC") { + ++cur; + permutation_mask[i] = cur; + ++cur; + } else { + permutation_mask[i] = cur; + } + } + } + int fid = FD::Convert(ConstituentOrderFeature(rule, pi)); + hg->edges_[oe_index].feature_values_.set_value(fid, 1.0); + while (next_permutation(pi.begin(), pi.end())) { + if (!CheckPermutationMask(permutation_mask, pi)) + continue; + vector nf(pi.size(), 0); + Hypergraph::TailNodeVector tail(pi.size(), 0); + bool skip = false; + for (int i = 0; i < pi.size(); ++i) { + int dist = pi[i] - i; if (dist < 0) dist *= -1; + if (dist > max_reorder) { skip = true; break; } + nf[i] = of[pi[i]]; + tail[i] = orig_tail[pi[i]]; + } + if (skip) continue; + TRulePtr nr(new TRule(rule)); + nr->f_ = nf; + int fid = FD::Convert(ConstituentOrderFeature(rule, pi)); + nr->scores_.set_value(fid, 1.0); + // cerr << "PERM: " << nr->AsString() << endl; + CreateEdge(nr, tail, node, hg); + } + // } +} + +void PermuteYamadaAndKnight(Hypergraph* hg, int max_reorder) { + assert(hg->nodes_.back().cat_ == kSTART); + assert(hg->nodes_.back().in_edges_.size() == 1); + PermuteYKRecursive(hg->nodes_.size() - 1, kSTART, max_reorder, hg); +} + +void CollapseWeights(Hypergraph* hg) { + int fid = FD::Convert("Reordering"); + for (int i = 0; i < hg->edges_.size(); ++i) { + Hypergraph::Edge& edge = hg->edges_[i]; + edge.feature_values_.clear(); + if (edge.edge_prob_ != prob_t::Zero()) { + edge.feature_values_.set_value(fid, log(edge.edge_prob_)); + } + } +} + +void ProcessHypergraph(const vector& w, const po::variables_map& conf, const string& ref, Hypergraph* hg) { + if (conf.count("reorder")) + PermuteYamadaAndKnight(hg, conf["max_reorder"].as()); + if (w.size() > 0) { hg->Reweight(w); } + if (conf.count("collapse_weights")) CollapseWeights(hg); + if (conf["output"].as() == "json") { + HypergraphIO::WriteToJSON(*hg, false, &cout); + if (!ref.empty()) { cerr << "REF: " << ref << endl; } + } else { + vector onebest; + ViterbiESentence(*hg, &onebest); + if (ref.empty()) { + cout << TD::GetString(onebest) << endl; + } else { + cout << TD::GetString(onebest) << " ||| " << ref << endl; + } + } + if (conf.count("k_derivations")) { + const int k = conf["k_derivations"].as(); + KBest::KBestDerivations, ESentenceTraversal> kbest(*hg, k); + for (int i = 0; i < k; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg->nodes_.size() - 1, i); + if (!d) break; + cerr << log(d->score) << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl; + } + } +} + +int main(int argc, char **argv) { + kSTART = TD::Convert("S") * -1; + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + string infile = conf["input"].as(); + const bool is_split_input = (conf["format"].as() == "split"); + const bool is_json_input = is_split_input || (conf["format"].as() == "json"); + const bool collapse_weights = conf.count("collapse_weights"); + vector w; + if (conf.count("weights")) + Weights::InitFromFile(conf["weights"].as(), &w); + + if (collapse_weights && !w.size()) { + cerr << "--collapse_weights requires a weights file to be specified!\n"; + exit(1); + } + ReadFile rf(infile); + istream* in = rf.stream(); + assert(*in); + int lc = 0; + Hypergraph hg; + map lhs2node; + while(*in) { + string line; + ++lc; + getline(*in, line); + if (is_json_input) { + if (line.empty() || line[0] == '#') continue; + string ref; + if (is_split_input) { + size_t pos = line.rfind("}}"); + assert(pos != string::npos); + size_t rstart = line.find("||| ", pos); + assert(rstart != string::npos); + ref = line.substr(rstart + 4); + line = line.substr(0, pos + 2); + } + istringstream is(line); + if (HypergraphIO::ReadFromJSON(&is, &hg)) { + ProcessHypergraph(w, conf, ref, &hg); + hg.clear(); + } else { + cerr << "Error reading grammar from JSON: line " << lc << endl; + exit(1); + } + } else { + if (line.empty()) { + int goal = lhs2node[kSTART] - 1; + FilterAndCheckCorrectness(goal, &hg); + ProcessHypergraph(w, conf, "", &hg); + hg.clear(); + lhs2node.clear(); + continue; + } + if (line[0] == '#') continue; + if (line[0] != '[') { + cerr << "Line " << lc << ": bad format\n"; + exit(1); + } + TRulePtr tr(TRule::CreateRuleMonolingual(line)); + Hypergraph::TailNodeVector tail; + for (int i = 0; i < tr->f_.size(); ++i) { + WordID var_cat = tr->f_[i]; + if (var_cat < 0) + tail.push_back(GetOrCreateNode(var_cat, &lhs2node, &hg)); + } + const WordID lhs = tr->GetLHS(); + int head = GetOrCreateNode(lhs, &lhs2node, &hg); + Hypergraph::Edge* edge = hg.AddEdge(tr, tail); + edge->feature_values_ = tr->scores_; + Hypergraph::Node* node = &hg.nodes_[head]; + hg.ConnectEdgeToHeadNode(edge, node); + } + } +} + diff --git a/training/utils/lbfgs.h b/training/utils/lbfgs.h new file mode 100644 index 00000000..e8baecab --- /dev/null +++ b/training/utils/lbfgs.h @@ -0,0 +1,1459 @@ +#ifndef SCITBX_LBFGS_H +#define SCITBX_LBFGS_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace scitbx { + +//! Limited-memory Broyden-Fletcher-Goldfarb-Shanno (LBFGS) %minimizer. +/*! Implementation of the + Limited-memory Broyden-Fletcher-Goldfarb-Shanno (LBFGS) + algorithm for large-scale multidimensional minimization + problems. + + This code was manually derived from Java code which was + in turn derived from the Fortran program + lbfgs.f. The Java translation was + effected mostly mechanically, with some manual + clean-up; in particular, array indices start at 0 + instead of 1. Most of the comments from the Fortran + code have been pasted in. + + Information on the original LBFGS Fortran source code is + available at + http://www.netlib.org/opt/lbfgs_um.shar . The following + information is taken verbatim from the Netlib documentation + for the Fortran source. + +
+    file    opt/lbfgs_um.shar
+    for     unconstrained optimization problems
+    alg     limited memory BFGS method
+    by      J. Nocedal
+    contact nocedal@eecs.nwu.edu
+    ref     D. C. Liu and J. Nocedal, ``On the limited memory BFGS method for
+    ,       large scale optimization methods'' Mathematical Programming 45
+    ,       (1989), pp. 503-528.
+    ,       (Postscript file of this paper is available via anonymous ftp
+    ,       to eecs.nwu.edu in the directory pub/%lbfgs/lbfgs_um.)
+    
+ + @author Jorge Nocedal: original Fortran version, including comments + (July 1990).
+ Robert Dodier: Java translation, August 1997.
+ Ralf W. Grosse-Kunstleve: C++ port, March 2002.
+ Chris Dyer: serialize/deserialize functionality + */ +namespace lbfgs { + + //! Generic exception class for %lbfgs %error messages. + /*! All exceptions thrown by the minimizer are derived from this class. + */ + class error : public std::exception { + public: + //! Constructor. + error(std::string const& msg) throw() + : msg_("lbfgs error: " + msg) + {} + //! Access to error message. + virtual const char* what() const throw() { return msg_.c_str(); } + protected: + virtual ~error() throw() {} + std::string msg_; + public: + static std::string itoa(unsigned long i) { + std::ostringstream os; + os << i; + return os.str(); + } + }; + + //! Specific exception class. + class error_internal_error : public error { + public: + //! Constructor. + error_internal_error(const char* file, unsigned long line) throw() + : error( + "Internal Error: " + std::string(file) + "(" + itoa(line) + ")") + {} + }; + + //! Specific exception class. + class error_improper_input_parameter : public error { + public: + //! Constructor. + error_improper_input_parameter(std::string const& msg) throw() + : error("Improper input parameter: " + msg) + {} + }; + + //! Specific exception class. + class error_improper_input_data : public error { + public: + //! Constructor. + error_improper_input_data(std::string const& msg) throw() + : error("Improper input data: " + msg) + {} + }; + + //! Specific exception class. + class error_search_direction_not_descent : public error { + public: + //! Constructor. + error_search_direction_not_descent() throw() + : error("The search direction is not a descent direction.") + {} + }; + + //! Specific exception class. + class error_line_search_failed : public error { + public: + //! Constructor. + error_line_search_failed(std::string const& msg) throw() + : error("Line search failed: " + msg) + {} + }; + + //! Specific exception class. + class error_line_search_failed_rounding_errors + : public error_line_search_failed { + public: + //! Constructor. + error_line_search_failed_rounding_errors(std::string const& msg) throw() + : error_line_search_failed(msg) + {} + }; + + namespace detail { + + template + inline + NumType + pow2(NumType const& x) { return x * x; } + + template + inline + NumType + abs(NumType const& x) { + if (x < NumType(0)) return -x; + return x; + } + + // This class implements an algorithm for multi-dimensional line search. + template + class mcsrch + { + protected: + int infoc; + FloatType dginit; + bool brackt; + bool stage1; + FloatType finit; + FloatType dgtest; + FloatType width; + FloatType width1; + FloatType stx; + FloatType fx; + FloatType dgx; + FloatType sty; + FloatType fy; + FloatType dgy; + FloatType stmin; + FloatType stmax; + + static FloatType const& max3( + FloatType const& x, + FloatType const& y, + FloatType const& z) + { + return x < y ? (y < z ? z : y ) : (x < z ? z : x ); + } + + public: + /* Minimize a function along a search direction. This code is + a Java translation of the function MCSRCH from + lbfgs.f, which in turn is a slight modification + of the subroutine CSRCH of More' and Thuente. + The changes are to allow reverse communication, and do not + affect the performance of the routine. This function, in turn, + calls mcstep.

+ + The Java translation was effected mostly mechanically, with + some manual clean-up; in particular, array indices start at 0 + instead of 1. Most of the comments from the Fortran code have + been pasted in here as well.

+ + The purpose of mcsrch is to find a step which + satisfies a sufficient decrease condition and a curvature + condition.

+ + At each stage this function updates an interval of uncertainty + with endpoints stx and sty. The + interval of uncertainty is initially chosen so that it + contains a minimizer of the modified function +

+                f(x+stp*s) - f(x) - ftol*stp*(gradf(x)'s).
+           
+ If a step is obtained for which the modified function has a + nonpositive function value and nonnegative derivative, then + the interval of uncertainty is chosen so that it contains a + minimizer of f(x+stp*s).

+ + The algorithm is designed to find a step which satisfies + the sufficient decrease condition +

+                 f(x+stp*s) <= f(X) + ftol*stp*(gradf(x)'s),
+           
+ and the curvature condition +
+                 abs(gradf(x+stp*s)'s)) <= gtol*abs(gradf(x)'s).
+           
+ If ftol is less than gtol and if, + for example, the function is bounded below, then there is + always a step which satisfies both conditions. If no step can + be found which satisfies both conditions, then the algorithm + usually stops when rounding errors prevent further progress. + In this case stp only satisfies the sufficient + decrease condition.

+ + @author Original Fortran version by Jorge J. More' and + David J. Thuente as part of the Minpack project, June 1983, + Argonne National Laboratory. Java translation by Robert + Dodier, August 1997. + + @param n The number of variables. + + @param x On entry this contains the base point for the line + search. On exit it contains x + stp*s. + + @param f On entry this contains the value of the objective + function at x. On exit it contains the value + of the objective function at x + stp*s. + + @param g On entry this contains the gradient of the objective + function at x. On exit it contains the gradient + at x + stp*s. + + @param s The search direction. + + @param stp On entry this contains an initial estimate of a + satifactory step length. On exit stp contains + the final estimate. + + @param ftol Tolerance for the sufficient decrease condition. + + @param xtol Termination occurs when the relative width of the + interval of uncertainty is at most xtol. + + @param maxfev Termination occurs when the number of evaluations + of the objective function is at least maxfev by + the end of an iteration. + + @param info This is an output variable, which can have these + values: +

    +
  • info = -1 A return is made to compute + the function and gradient. +
  • info = 1 The sufficient decrease condition + and the directional derivative condition hold. +
+ + @param nfev On exit, this is set to the number of function + evaluations. + + @param wa Temporary storage array, of length n. + */ + void run( + FloatType const& gtol, + FloatType const& stpmin, + FloatType const& stpmax, + SizeType n, + FloatType* x, + FloatType f, + const FloatType* g, + FloatType* s, + SizeType is0, + FloatType& stp, + FloatType ftol, + FloatType xtol, + SizeType maxfev, + int& info, + SizeType& nfev, + FloatType* wa); + + /* The purpose of this function is to compute a safeguarded step + for a linesearch and to update an interval of uncertainty for + a minimizer of the function.

+ + The parameter stx contains the step with the + least function value. The parameter stp contains + the current step. It is assumed that the derivative at + stx is negative in the direction of the step. If + brackt is true when + mcstep returns then a minimizer has been + bracketed in an interval of uncertainty with endpoints + stx and sty.

+ + Variables that must be modified by mcstep are + implemented as 1-element arrays. + + @param stx Step at the best step obtained so far. + This variable is modified by mcstep. + @param fx Function value at the best step obtained so far. + This variable is modified by mcstep. + @param dx Derivative at the best step obtained so far. + The derivative must be negative in the direction of the + step, that is, dx and stp-stx must + have opposite signs. This variable is modified by + mcstep. + + @param sty Step at the other endpoint of the interval of + uncertainty. This variable is modified by mcstep. + @param fy Function value at the other endpoint of the interval + of uncertainty. This variable is modified by + mcstep. + + @param dy Derivative at the other endpoint of the interval of + uncertainty. This variable is modified by mcstep. + + @param stp Step at the current step. If brackt is set + then on input stp must be between stx + and sty. On output stp is set to the + new step. + @param fp Function value at the current step. + @param dp Derivative at the current step. + + @param brackt Tells whether a minimizer has been bracketed. + If the minimizer has not been bracketed, then on input this + variable must be set false. If the minimizer has + been bracketed, then on output this variable is + true. + + @param stpmin Lower bound for the step. + @param stpmax Upper bound for the step. + + If the return value is 1, 2, 3, or 4, then the step has + been computed successfully. A return value of 0 indicates + improper input parameters. + + @author Jorge J. More, David J. Thuente: original Fortran version, + as part of Minpack project. Argonne Nat'l Laboratory, June 1983. + Robert Dodier: Java translation, August 1997. + */ + static int mcstep( + FloatType& stx, + FloatType& fx, + FloatType& dx, + FloatType& sty, + FloatType& fy, + FloatType& dy, + FloatType& stp, + FloatType fp, + FloatType dp, + bool& brackt, + FloatType stpmin, + FloatType stpmax); + + void serialize(std::ostream* out) const { + out->write((const char*)&infoc,sizeof(infoc)); + out->write((const char*)&dginit,sizeof(dginit)); + out->write((const char*)&brackt,sizeof(brackt)); + out->write((const char*)&stage1,sizeof(stage1)); + out->write((const char*)&finit,sizeof(finit)); + out->write((const char*)&dgtest,sizeof(dgtest)); + out->write((const char*)&width,sizeof(width)); + out->write((const char*)&width1,sizeof(width1)); + out->write((const char*)&stx,sizeof(stx)); + out->write((const char*)&fx,sizeof(fx)); + out->write((const char*)&dgx,sizeof(dgx)); + out->write((const char*)&sty,sizeof(sty)); + out->write((const char*)&fy,sizeof(fy)); + out->write((const char*)&dgy,sizeof(dgy)); + out->write((const char*)&stmin,sizeof(stmin)); + out->write((const char*)&stmax,sizeof(stmax)); + } + + void deserialize(std::istream* in) const { + in->read((char*)&infoc, sizeof(infoc)); + in->read((char*)&dginit, sizeof(dginit)); + in->read((char*)&brackt, sizeof(brackt)); + in->read((char*)&stage1, sizeof(stage1)); + in->read((char*)&finit, sizeof(finit)); + in->read((char*)&dgtest, sizeof(dgtest)); + in->read((char*)&width, sizeof(width)); + in->read((char*)&width1, sizeof(width1)); + in->read((char*)&stx, sizeof(stx)); + in->read((char*)&fx, sizeof(fx)); + in->read((char*)&dgx, sizeof(dgx)); + in->read((char*)&sty, sizeof(sty)); + in->read((char*)&fy, sizeof(fy)); + in->read((char*)&dgy, sizeof(dgy)); + in->read((char*)&stmin, sizeof(stmin)); + in->read((char*)&stmax, sizeof(stmax)); + } + }; + + template + void mcsrch::run( + FloatType const& gtol, + FloatType const& stpmin, + FloatType const& stpmax, + SizeType n, + FloatType* x, + FloatType f, + const FloatType* g, + FloatType* s, + SizeType is0, + FloatType& stp, + FloatType ftol, + FloatType xtol, + SizeType maxfev, + int& info, + SizeType& nfev, + FloatType* wa) + { + if (info != -1) { + infoc = 1; + if ( n == 0 + || maxfev == 0 + || gtol < FloatType(0) + || xtol < FloatType(0) + || stpmin < FloatType(0) + || stpmax < stpmin) { + throw error_internal_error(__FILE__, __LINE__); + } + if (stp <= FloatType(0) || ftol < FloatType(0)) { + throw error_internal_error(__FILE__, __LINE__); + } + // Compute the initial gradient in the search direction + // and check that s is a descent direction. + dginit = FloatType(0); + for (SizeType j = 0; j < n; j++) { + dginit += g[j] * s[is0+j]; + } + if (dginit >= FloatType(0)) { + throw error_search_direction_not_descent(); + } + brackt = false; + stage1 = true; + nfev = 0; + finit = f; + dgtest = ftol*dginit; + width = stpmax - stpmin; + width1 = FloatType(2) * width; + std::copy(x, x+n, wa); + // The variables stx, fx, dgx contain the values of the step, + // function, and directional derivative at the best step. + // The variables sty, fy, dgy contain the value of the step, + // function, and derivative at the other endpoint of + // the interval of uncertainty. + // The variables stp, f, dg contain the values of the step, + // function, and derivative at the current step. + stx = FloatType(0); + fx = finit; + dgx = dginit; + sty = FloatType(0); + fy = finit; + dgy = dginit; + } + for (;;) { + if (info != -1) { + // Set the minimum and maximum steps to correspond + // to the present interval of uncertainty. + if (brackt) { + stmin = std::min(stx, sty); + stmax = std::max(stx, sty); + } + else { + stmin = stx; + stmax = stp + FloatType(4) * (stp - stx); + } + // Force the step to be within the bounds stpmax and stpmin. + stp = std::max(stp, stpmin); + stp = std::min(stp, stpmax); + // If an unusual termination is to occur then let + // stp be the lowest point obtained so far. + if ( (brackt && (stp <= stmin || stp >= stmax)) + || nfev >= maxfev - 1 || infoc == 0 + || (brackt && stmax - stmin <= xtol * stmax)) { + stp = stx; + } + // Evaluate the function and gradient at stp + // and compute the directional derivative. + // We return to main program to obtain F and G. + for (SizeType j = 0; j < n; j++) { + x[j] = wa[j] + stp * s[is0+j]; + } + info=-1; + break; + } + info = 0; + nfev++; + FloatType dg(0); + for (SizeType j = 0; j < n; j++) { + dg += g[j] * s[is0+j]; + } + FloatType ftest1 = finit + stp*dgtest; + // Test for convergence. + if ((brackt && (stp <= stmin || stp >= stmax)) || infoc == 0) { + throw error_line_search_failed_rounding_errors( + "Rounding errors prevent further progress." + " There may not be a step which satisfies the" + " sufficient decrease and curvature conditions." + " Tolerances may be too small."); + } + if (stp == stpmax && f <= ftest1 && dg <= dgtest) { + throw error_line_search_failed( + "The step is at the upper bound stpmax()."); + } + if (stp == stpmin && (f > ftest1 || dg >= dgtest)) { + throw error_line_search_failed( + "The step is at the lower bound stpmin()."); + } + if (nfev >= maxfev) { + throw error_line_search_failed( + "Number of function evaluations has reached maxfev()."); + } + if (brackt && stmax - stmin <= xtol * stmax) { + throw error_line_search_failed( + "Relative width of the interval of uncertainty" + " is at most xtol()."); + } + // Check for termination. + if (f <= ftest1 && abs(dg) <= gtol * (-dginit)) { + info = 1; + break; + } + // In the first stage we seek a step for which the modified + // function has a nonpositive value and nonnegative derivative. + if ( stage1 && f <= ftest1 + && dg >= std::min(ftol, gtol) * dginit) { + stage1 = false; + } + // A modified function is used to predict the step only if + // we have not obtained a step for which the modified + // function has a nonpositive function value and nonnegative + // derivative, and if a lower function value has been + // obtained but the decrease is not sufficient. + if (stage1 && f <= fx && f > ftest1) { + // Define the modified function and derivative values. + FloatType fm = f - stp*dgtest; + FloatType fxm = fx - stx*dgtest; + FloatType fym = fy - sty*dgtest; + FloatType dgm = dg - dgtest; + FloatType dgxm = dgx - dgtest; + FloatType dgym = dgy - dgtest; + // Call cstep to update the interval of uncertainty + // and to compute the new step. + infoc = mcstep(stx, fxm, dgxm, sty, fym, dgym, stp, fm, dgm, + brackt, stmin, stmax); + // Reset the function and gradient values for f. + fx = fxm + stx*dgtest; + fy = fym + sty*dgtest; + dgx = dgxm + dgtest; + dgy = dgym + dgtest; + } + else { + // Call mcstep to update the interval of uncertainty + // and to compute the new step. + infoc = mcstep(stx, fx, dgx, sty, fy, dgy, stp, f, dg, + brackt, stmin, stmax); + } + // Force a sufficient decrease in the size of the + // interval of uncertainty. + if (brackt) { + if (abs(sty - stx) >= FloatType(0.66) * width1) { + stp = stx + FloatType(0.5) * (sty - stx); + } + width1 = width; + width = abs(sty - stx); + } + } + } + + template + int mcsrch::mcstep( + FloatType& stx, + FloatType& fx, + FloatType& dx, + FloatType& sty, + FloatType& fy, + FloatType& dy, + FloatType& stp, + FloatType fp, + FloatType dp, + bool& brackt, + FloatType stpmin, + FloatType stpmax) + { + bool bound; + FloatType gamma, p, q, r, s, sgnd, stpc, stpf, stpq, theta; + int info = 0; + if ( ( brackt && (stp <= std::min(stx, sty) + || stp >= std::max(stx, sty))) + || dx * (stp - stx) >= FloatType(0) || stpmax < stpmin) { + return 0; + } + // Determine if the derivatives have opposite sign. + sgnd = dp * (dx / abs(dx)); + if (fp > fx) { + // First case. A higher function value. + // The minimum is bracketed. If the cubic step is closer + // to stx than the quadratic step, the cubic step is taken, + // else the average of the cubic and quadratic steps is taken. + info = 1; + bound = true; + theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp; + s = max3(abs(theta), abs(dx), abs(dp)); + gamma = s * std::sqrt(pow2(theta / s) - (dx / s) * (dp / s)); + if (stp < stx) gamma = - gamma; + p = (gamma - dx) + theta; + q = ((gamma - dx) + gamma) + dp; + r = p/q; + stpc = stx + r * (stp - stx); + stpq = stx + + ((dx / ((fx - fp) / (stp - stx) + dx)) / FloatType(2)) + * (stp - stx); + if (abs(stpc - stx) < abs(stpq - stx)) { + stpf = stpc; + } + else { + stpf = stpc + (stpq - stpc) / FloatType(2); + } + brackt = true; + } + else if (sgnd < FloatType(0)) { + // Second case. A lower function value and derivatives of + // opposite sign. The minimum is bracketed. If the cubic + // step is closer to stx than the quadratic (secant) step, + // the cubic step is taken, else the quadratic step is taken. + info = 2; + bound = false; + theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp; + s = max3(abs(theta), abs(dx), abs(dp)); + gamma = s * std::sqrt(pow2(theta / s) - (dx / s) * (dp / s)); + if (stp > stx) gamma = - gamma; + p = (gamma - dp) + theta; + q = ((gamma - dp) + gamma) + dx; + r = p/q; + stpc = stp + r * (stx - stp); + stpq = stp + (dp / (dp - dx)) * (stx - stp); + if (abs(stpc - stp) > abs(stpq - stp)) { + stpf = stpc; + } + else { + stpf = stpq; + } + brackt = true; + } + else if (abs(dp) < abs(dx)) { + // Third case. A lower function value, derivatives of the + // same sign, and the magnitude of the derivative decreases. + // The cubic step is only used if the cubic tends to infinity + // in the direction of the step or if the minimum of the cubic + // is beyond stp. Otherwise the cubic step is defined to be + // either stpmin or stpmax. The quadratic (secant) step is also + // computed and if the minimum is bracketed then the the step + // closest to stx is taken, else the step farthest away is taken. + info = 3; + bound = true; + theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp; + s = max3(abs(theta), abs(dx), abs(dp)); + gamma = s * std::sqrt( + std::max(FloatType(0), pow2(theta / s) - (dx / s) * (dp / s))); + if (stp > stx) gamma = -gamma; + p = (gamma - dp) + theta; + q = (gamma + (dx - dp)) + gamma; + r = p/q; + if (r < FloatType(0) && gamma != FloatType(0)) { + stpc = stp + r * (stx - stp); + } + else if (stp > stx) { + stpc = stpmax; + } + else { + stpc = stpmin; + } + stpq = stp + (dp / (dp - dx)) * (stx - stp); + if (brackt) { + if (abs(stp - stpc) < abs(stp - stpq)) { + stpf = stpc; + } + else { + stpf = stpq; + } + } + else { + if (abs(stp - stpc) > abs(stp - stpq)) { + stpf = stpc; + } + else { + stpf = stpq; + } + } + } + else { + // Fourth case. A lower function value, derivatives of the + // same sign, and the magnitude of the derivative does + // not decrease. If the minimum is not bracketed, the step + // is either stpmin or stpmax, else the cubic step is taken. + info = 4; + bound = false; + if (brackt) { + theta = FloatType(3) * (fp - fy) / (sty - stp) + dy + dp; + s = max3(abs(theta), abs(dy), abs(dp)); + gamma = s * std::sqrt(pow2(theta / s) - (dy / s) * (dp / s)); + if (stp > sty) gamma = -gamma; + p = (gamma - dp) + theta; + q = ((gamma - dp) + gamma) + dy; + r = p/q; + stpc = stp + r * (sty - stp); + stpf = stpc; + } + else if (stp > stx) { + stpf = stpmax; + } + else { + stpf = stpmin; + } + } + // Update the interval of uncertainty. This update does not + // depend on the new step or the case analysis above. + if (fp > fx) { + sty = stp; + fy = fp; + dy = dp; + } + else { + if (sgnd < FloatType(0)) { + sty = stx; + fy = fx; + dy = dx; + } + stx = stp; + fx = fp; + dx = dp; + } + // Compute the new step and safeguard it. + stpf = std::min(stpmax, stpf); + stpf = std::max(stpmin, stpf); + stp = stpf; + if (brackt && bound) { + if (sty > stx) { + stp = std::min(stx + FloatType(0.66) * (sty - stx), stp); + } + else { + stp = std::max(stx + FloatType(0.66) * (sty - stx), stp); + } + } + return info; + } + + /* Compute the sum of a vector times a scalar plus another vector. + Adapted from the subroutine daxpy in + lbfgs.f. + */ + template + void daxpy( + SizeType n, + FloatType da, + const FloatType* dx, + SizeType ix0, + SizeType incx, + FloatType* dy, + SizeType iy0, + SizeType incy) + { + SizeType i, ix, iy, m; + if (n == 0) return; + if (da == FloatType(0)) return; + if (!(incx == 1 && incy == 1)) { + ix = 0; + iy = 0; + for (i = 0; i < n; i++) { + dy[iy0+iy] += da * dx[ix0+ix]; + ix += incx; + iy += incy; + } + return; + } + m = n % 4; + for (i = 0; i < m; i++) { + dy[iy0+i] += da * dx[ix0+i]; + } + for (; i < n;) { + dy[iy0+i] += da * dx[ix0+i]; i++; + dy[iy0+i] += da * dx[ix0+i]; i++; + dy[iy0+i] += da * dx[ix0+i]; i++; + dy[iy0+i] += da * dx[ix0+i]; i++; + } + } + + template + inline + void daxpy( + SizeType n, + FloatType da, + const FloatType* dx, + SizeType ix0, + FloatType* dy) + { + daxpy(n, da, dx, ix0, SizeType(1), dy, SizeType(0), SizeType(1)); + } + + /* Compute the dot product of two vectors. + Adapted from the subroutine ddot + in lbfgs.f. + */ + template + FloatType ddot( + SizeType n, + const FloatType* dx, + SizeType ix0, + SizeType incx, + const FloatType* dy, + SizeType iy0, + SizeType incy) + { + SizeType i, ix, iy, m; + FloatType dtemp(0); + if (n == 0) return FloatType(0); + if (!(incx == 1 && incy == 1)) { + ix = 0; + iy = 0; + for (i = 0; i < n; i++) { + dtemp += dx[ix0+ix] * dy[iy0+iy]; + ix += incx; + iy += incy; + } + return dtemp; + } + m = n % 5; + for (i = 0; i < m; i++) { + dtemp += dx[ix0+i] * dy[iy0+i]; + } + for (; i < n;) { + dtemp += dx[ix0+i] * dy[iy0+i]; i++; + dtemp += dx[ix0+i] * dy[iy0+i]; i++; + dtemp += dx[ix0+i] * dy[iy0+i]; i++; + dtemp += dx[ix0+i] * dy[iy0+i]; i++; + dtemp += dx[ix0+i] * dy[iy0+i]; i++; + } + return dtemp; + } + + template + inline + FloatType ddot( + SizeType n, + const FloatType* dx, + const FloatType* dy) + { + return ddot( + n, dx, SizeType(0), SizeType(1), dy, SizeType(0), SizeType(1)); + } + + } // namespace detail + + //! Interface to the LBFGS %minimizer. + /*! This class solves the unconstrained minimization problem +

+          min f(x),  x = (x1,x2,...,x_n),
+      
+ using the limited-memory BFGS method. The routine is + especially effective on problems involving a large number of + variables. In a typical iteration of this method an + approximation Hk to the inverse of the Hessian + is obtained by applying m BFGS updates to a + diagonal matrix Hk0, using information from the + previous m steps. The user specifies the number + m, which determines the amount of storage + required by the routine. The user may also provide the + diagonal matrices Hk0 (parameter diag in the run() + function) if not satisfied with the default choice. The + algorithm is described in "On the limited memory BFGS method for + large scale optimization", by D. Liu and J. Nocedal, Mathematical + Programming B 45 (1989) 503-528. + + The user is required to calculate the function value + f and its gradient g. In order to + allow the user complete control over these computations, + reverse communication is used. The routine must be called + repeatedly under the control of the member functions + requests_f_and_g(), + requests_diag(). + If neither requests_f_and_g() nor requests_diag() is + true the user should check for convergence + (using class traditional_convergence_test or any + other custom test). If the convergence test is negative, + the minimizer may be called again for the next iteration. + + The steplength (stp()) is determined at each iteration + by means of the line search routine mcsrch, which is + a slight modification of the routine CSRCH written + by More' and Thuente. + + The only variables that are machine-dependent are + xtol, + stpmin and + stpmax. + + Fatal errors cause error exceptions to be thrown. + The generic class error is sub-classed (e.g. + class error_line_search_failed) to facilitate + granular %error handling. + + A note on performance: Using Compaq Fortran V5.4 and + Compaq C++ V6.5, the C++ implementation is about 15% slower + than the Fortran implementation. + */ + template + class minimizer + { + public: + //! Default constructor. Some members are not initialized! + minimizer() + : n_(0), m_(0), maxfev_(0), + gtol_(0), xtol_(0), + stpmin_(0), stpmax_(0), + ispt(0), iypt(0) + {} + + //! Constructor. + /*! @param n The number of variables in the minimization problem. + Restriction: n > 0. + + @param m The number of corrections used in the BFGS update. + Values of m less than 3 are not recommended; + large values of m will result in excessive + computing time. 3 <= m <= 7 is + recommended. + Restriction: m > 0. + + @param maxfev Maximum number of function evaluations + per line search. + Termination occurs when the number of evaluations + of the objective function is at least maxfev by + the end of an iteration. + + @param gtol Controls the accuracy of the line search. + If the function and gradient evaluations are inexpensive with + respect to the cost of the iteration (which is sometimes the + case when solving very large problems) it may be advantageous + to set gtol to a small value. A typical small + value is 0.1. + Restriction: gtol should be greater than 1e-4. + + @param xtol An estimate of the machine precision (e.g. 10e-16 + on a SUN station 3/60). The line search routine will + terminate if the relative width of the interval of + uncertainty is less than xtol. + + @param stpmin Specifies the lower bound for the step + in the line search. + The default value is 1e-20. This value need not be modified + unless the exponent is too large for the machine being used, + or unless the problem is extremely badly scaled (in which + case the exponent should be increased). + + @param stpmax specifies the upper bound for the step + in the line search. + The default value is 1e20. This value need not be modified + unless the exponent is too large for the machine being used, + or unless the problem is extremely badly scaled (in which + case the exponent should be increased). + */ + explicit + minimizer( + SizeType n, + SizeType m = 5, + SizeType maxfev = 20, + FloatType gtol = FloatType(0.9), + FloatType xtol = FloatType(1.e-16), + FloatType stpmin = FloatType(1.e-20), + FloatType stpmax = FloatType(1.e20)) + : n_(n), m_(m), maxfev_(maxfev), + gtol_(gtol), xtol_(xtol), + stpmin_(stpmin), stpmax_(stpmax), + iflag_(0), requests_f_and_g_(false), requests_diag_(false), + iter_(0), nfun_(0), stp_(0), + stp1(0), ftol(0.0001), ys(0), point(0), npt(0), + ispt(n+2*m), iypt((n+2*m)+n*m), + info(0), bound(0), nfev(0) + { + if (n_ == 0) { + throw error_improper_input_parameter("n = 0."); + } + if (m_ == 0) { + throw error_improper_input_parameter("m = 0."); + } + if (maxfev_ == 0) { + throw error_improper_input_parameter("maxfev = 0."); + } + if (gtol_ <= FloatType(1.e-4)) { + throw error_improper_input_parameter("gtol <= 1.e-4."); + } + if (xtol_ < FloatType(0)) { + throw error_improper_input_parameter("xtol < 0."); + } + if (stpmin_ < FloatType(0)) { + throw error_improper_input_parameter("stpmin < 0."); + } + if (stpmax_ < stpmin) { + throw error_improper_input_parameter("stpmax < stpmin"); + } + w_.resize(n_*(2*m_+1)+2*m_); + scratch_array_.resize(n_); + } + + //! Number of free parameters (as passed to the constructor). + SizeType n() const { return n_; } + + //! Number of corrections kept (as passed to the constructor). + SizeType m() const { return m_; } + + /*! \brief Maximum number of evaluations of the objective function + per line search (as passed to the constructor). + */ + SizeType maxfev() const { return maxfev_; } + + /*! \brief Control of the accuracy of the line search. + (as passed to the constructor). + */ + FloatType gtol() const { return gtol_; } + + //! Estimate of the machine precision (as passed to the constructor). + FloatType xtol() const { return xtol_; } + + /*! \brief Lower bound for the step in the line search. + (as passed to the constructor). + */ + FloatType stpmin() const { return stpmin_; } + + /*! \brief Upper bound for the step in the line search. + (as passed to the constructor). + */ + FloatType stpmax() const { return stpmax_; } + + //! Status indicator for reverse communication. + /*! true if the run() function returns to request + evaluation of the objective function (f) and + gradients (g) for the current point + (x). To continue the minimization the + run() function is called again with the updated values for + f and g. +

+ See also: requests_diag() + */ + bool requests_f_and_g() const { return requests_f_and_g_; } + + //! Status indicator for reverse communication. + /*! true if the run() function returns to request + evaluation of the diagonal matrix (diag) + for the current point (x). + To continue the minimization the run() function is called + again with the updated values for diag. +

+ See also: requests_f_and_g() + */ + bool requests_diag() const { return requests_diag_; } + + //! Number of iterations so far. + /*! Note that one iteration may involve multiple evaluations + of the objective function. +

+ See also: nfun() + */ + SizeType iter() const { return iter_; } + + //! Total number of evaluations of the objective function so far. + /*! The total number of function evaluations increases by the + number of evaluations required for the line search. The total + is only increased after a successful line search. +

+ See also: iter() + */ + SizeType nfun() const { return nfun_; } + + //! Norm of gradient given gradient array of length n(). + FloatType euclidean_norm(const FloatType* a) const { + return std::sqrt(detail::ddot(n_, a, a)); + } + + //! Current stepsize. + FloatType stp() const { return stp_; } + + //! Execution of one step of the minimization. + /*! @param x On initial entry this must be set by the user to + the values of the initial estimate of the solution vector. + + @param f Before initial entry or on re-entry under the + control of requests_f_and_g(), f must be set + by the user to contain the value of the objective function + at the current point x. + + @param g Before initial entry or on re-entry under the + control of requests_f_and_g(), g must be set + by the user to contain the components of the gradient at + the current point x. + + The return value is true if either + requests_f_and_g() or requests_diag() is true. + Otherwise the user should check for convergence + (e.g. using class traditional_convergence_test) and + call the run() function again to continue the minimization. + If the return value is false the user + should not update f, g or + diag (other overload) before calling + the run() function again. + + Note that x is always modified by the run() + function. Depending on the situation it can therefore be + necessary to evaluate the objective function one more time + after the minimization is terminated. + */ + bool run( + FloatType* x, + FloatType f, + const FloatType* g) + { + return generic_run(x, f, g, false, 0); + } + + //! Execution of one step of the minimization. + /*! @param x See other overload. + + @param f See other overload. + + @param g See other overload. + + @param diag On initial entry or on re-entry under the + control of requests_diag(), diag must be set by + the user to contain the values of the diagonal matrix Hk0. + The routine will return at each iteration of the algorithm + with requests_diag() set to true. +

+ Restriction: all elements of diag must be + positive. + */ + bool run( + FloatType* x, + FloatType f, + const FloatType* g, + const FloatType* diag) + { + return generic_run(x, f, g, true, diag); + } + + void serialize(std::ostream* out) const { + out->write((const char*)&n_, sizeof(n_)); // sanity check + out->write((const char*)&m_, sizeof(m_)); // sanity check + SizeType fs = sizeof(FloatType); + out->write((const char*)&fs, sizeof(fs)); // sanity check + + mcsrch_instance.serialize(out); + out->write((const char*)&iflag_, sizeof(iflag_)); + out->write((const char*)&requests_f_and_g_, sizeof(requests_f_and_g_)); + out->write((const char*)&requests_diag_, sizeof(requests_diag_)); + out->write((const char*)&iter_, sizeof(iter_)); + out->write((const char*)&nfun_, sizeof(nfun_)); + out->write((const char*)&stp_, sizeof(stp_)); + out->write((const char*)&stp1, sizeof(stp1)); + out->write((const char*)&ftol, sizeof(ftol)); + out->write((const char*)&ys, sizeof(ys)); + out->write((const char*)&point, sizeof(point)); + out->write((const char*)&npt, sizeof(npt)); + out->write((const char*)&info, sizeof(info)); + out->write((const char*)&bound, sizeof(bound)); + out->write((const char*)&nfev, sizeof(nfev)); + out->write((const char*)&w_[0], sizeof(FloatType) * w_.size()); + out->write((const char*)&scratch_array_[0], sizeof(FloatType) * scratch_array_.size()); + } + + void deserialize(std::istream* in) { + SizeType n, m, fs; + in->read((char*)&n, sizeof(n)); + in->read((char*)&m, sizeof(m)); + in->read((char*)&fs, sizeof(fs)); + assert(n == n_); + assert(m == m_); + assert(fs == sizeof(FloatType)); + + mcsrch_instance.deserialize(in); + in->read((char*)&iflag_, sizeof(iflag_)); + in->read((char*)&requests_f_and_g_, sizeof(requests_f_and_g_)); + in->read((char*)&requests_diag_, sizeof(requests_diag_)); + in->read((char*)&iter_, sizeof(iter_)); + in->read((char*)&nfun_, sizeof(nfun_)); + in->read((char*)&stp_, sizeof(stp_)); + in->read((char*)&stp1, sizeof(stp1)); + in->read((char*)&ftol, sizeof(ftol)); + in->read((char*)&ys, sizeof(ys)); + in->read((char*)&point, sizeof(point)); + in->read((char*)&npt, sizeof(npt)); + in->read((char*)&info, sizeof(info)); + in->read((char*)&bound, sizeof(bound)); + in->read((char*)&nfev, sizeof(nfev)); + in->read((char*)&w_[0], sizeof(FloatType) * w_.size()); + in->read((char*)&scratch_array_[0], sizeof(FloatType) * scratch_array_.size()); + } + + protected: + static void throw_diagonal_element_not_positive(SizeType i) { + throw error_improper_input_data( + "The " + error::itoa(i) + ". diagonal element of the" + " inverse Hessian approximation is not positive."); + } + + bool generic_run( + FloatType* x, + FloatType f, + const FloatType* g, + bool diagco, + const FloatType* diag); + + detail::mcsrch mcsrch_instance; + const SizeType n_; + const SizeType m_; + const SizeType maxfev_; + const FloatType gtol_; + const FloatType xtol_; + const FloatType stpmin_; + const FloatType stpmax_; + int iflag_; + bool requests_f_and_g_; + bool requests_diag_; + SizeType iter_; + SizeType nfun_; + FloatType stp_; + FloatType stp1; + FloatType ftol; + FloatType ys; + SizeType point; + SizeType npt; + const SizeType ispt; + const SizeType iypt; + int info; + SizeType bound; + SizeType nfev; + std::vector w_; + std::vector scratch_array_; + }; + + template + bool minimizer::generic_run( + FloatType* x, + FloatType f, + const FloatType* g, + bool diagco, + const FloatType* diag) + { + bool execute_entire_while_loop = false; + if (!(requests_f_and_g_ || requests_diag_)) { + execute_entire_while_loop = true; + } + requests_f_and_g_ = false; + requests_diag_ = false; + FloatType* w = &(*(w_.begin())); + if (iflag_ == 0) { // Initialize. + nfun_ = 1; + if (diagco) { + for (SizeType i = 0; i < n_; i++) { + if (diag[i] <= FloatType(0)) { + throw_diagonal_element_not_positive(i); + } + } + } + else { + std::fill_n(scratch_array_.begin(), n_, FloatType(1)); + diag = &(*(scratch_array_.begin())); + } + for (SizeType i = 0; i < n_; i++) { + w[ispt + i] = -g[i] * diag[i]; + } + FloatType gnorm = std::sqrt(detail::ddot(n_, g, g)); + if (gnorm == FloatType(0)) return false; + stp1 = FloatType(1) / gnorm; + execute_entire_while_loop = true; + } + if (execute_entire_while_loop) { + bound = iter_; + iter_++; + info = 0; + if (iter_ != 1) { + if (iter_ > m_) bound = m_; + ys = detail::ddot( + n_, w, iypt + npt, SizeType(1), w, ispt + npt, SizeType(1)); + if (!diagco) { + FloatType yy = detail::ddot( + n_, w, iypt + npt, SizeType(1), w, iypt + npt, SizeType(1)); + std::fill_n(scratch_array_.begin(), n_, ys / yy); + diag = &(*(scratch_array_.begin())); + } + else { + iflag_ = 2; + requests_diag_ = true; + return true; + } + } + } + if (execute_entire_while_loop || iflag_ == 2) { + if (iter_ != 1) { + if (diag == 0) { + throw error_internal_error(__FILE__, __LINE__); + } + if (diagco) { + for (SizeType i = 0; i < n_; i++) { + if (diag[i] <= FloatType(0)) { + throw_diagonal_element_not_positive(i); + } + } + } + SizeType cp = point; + if (point == 0) cp = m_; + w[n_ + cp -1] = 1 / ys; + SizeType i; + for (i = 0; i < n_; i++) { + w[i] = -g[i]; + } + cp = point; + for (i = 0; i < bound; i++) { + if (cp == 0) cp = m_; + cp--; + FloatType sq = detail::ddot( + n_, w, ispt + cp * n_, SizeType(1), w, SizeType(0), SizeType(1)); + SizeType inmc=n_+m_+cp; + SizeType iycn=iypt+cp*n_; + w[inmc] = w[n_ + cp] * sq; + detail::daxpy(n_, -w[inmc], w, iycn, w); + } + for (i = 0; i < n_; i++) { + w[i] *= diag[i]; + } + for (i = 0; i < bound; i++) { + FloatType yr = detail::ddot( + n_, w, iypt + cp * n_, SizeType(1), w, SizeType(0), SizeType(1)); + FloatType beta = w[n_ + cp] * yr; + SizeType inmc=n_+m_+cp; + beta = w[inmc] - beta; + SizeType iscn=ispt+cp*n_; + detail::daxpy(n_, beta, w, iscn, w); + cp++; + if (cp == m_) cp = 0; + } + std::copy(w, w+n_, w+(ispt + point * n_)); + } + stp_ = FloatType(1); + if (iter_ == 1) stp_ = stp1; + std::copy(g, g+n_, w); + } + mcsrch_instance.run( + gtol_, stpmin_, stpmax_, n_, x, f, g, w, ispt + point * n_, + stp_, ftol, xtol_, maxfev_, info, nfev, &(*(scratch_array_.begin()))); + if (info == -1) { + iflag_ = 1; + requests_f_and_g_ = true; + return true; + } + if (info != 1) { + throw error_internal_error(__FILE__, __LINE__); + } + nfun_ += nfev; + npt = point*n_; + for (SizeType i = 0; i < n_; i++) { + w[ispt + npt + i] = stp_ * w[ispt + npt + i]; + w[iypt + npt + i] = g[i] - w[i]; + } + point++; + if (point == m_) point = 0; + return false; + } + + //! Traditional LBFGS convergence test. + /*! This convergence test is equivalent to the test embedded + in the lbfgs.f Fortran code. The test assumes that + there is a meaningful relation between the Euclidean norm of the + parameter vector x and the norm of the gradient + vector g. Therefore this test should not be used if + this assumption is not correct for a given problem. + */ + template + class traditional_convergence_test + { + public: + //! Default constructor. + traditional_convergence_test() + : n_(0), eps_(0) + {} + + //! Constructor. + /*! @param n The number of variables in the minimization problem. + Restriction: n > 0. + + @param eps Determines the accuracy with which the solution + is to be found. + */ + explicit + traditional_convergence_test( + SizeType n, + FloatType eps = FloatType(1.e-5)) + : n_(n), eps_(eps) + { + if (n_ == 0) { + throw error_improper_input_parameter("n = 0."); + } + if (eps_ < FloatType(0)) { + throw error_improper_input_parameter("eps < 0."); + } + } + + //! Number of free parameters (as passed to the constructor). + SizeType n() const { return n_; } + + /*! \brief Accuracy with which the solution is to be found + (as passed to the constructor). + */ + FloatType eps() const { return eps_; } + + //! Execution of the convergence test for the given parameters. + /*! Returns true if +

+            ||g|| < eps * max(1,||x||),
+          
+ where ||.|| denotes the Euclidean norm. + + @param x Current solution vector. + + @param g Components of the gradient at the current + point x. + */ + bool + operator()(const FloatType* x, const FloatType* g) const + { + FloatType xnorm = std::sqrt(detail::ddot(n_, x, x)); + FloatType gnorm = std::sqrt(detail::ddot(n_, g, g)); + if (gnorm <= eps_ * std::max(FloatType(1), xnorm)) return true; + return false; + } + protected: + const SizeType n_; + const FloatType eps_; + }; + +}} // namespace scitbx::lbfgs + +template +std::ostream& operator<<(std::ostream& os, const scitbx::lbfgs::minimizer& min) { + return os << "ITER=" << min.iter() << "\tNFUN=" << min.nfun() << "\tSTP=" << min.stp() << "\tDIAG=" << min.requests_diag() << "\tF&G=" << min.requests_f_and_g(); +} + + +#endif // SCITBX_LBFGS_H diff --git a/training/utils/lbfgs_test.cc b/training/utils/lbfgs_test.cc new file mode 100644 index 00000000..9678e788 --- /dev/null +++ b/training/utils/lbfgs_test.cc @@ -0,0 +1,117 @@ +#include +#include +#include +#include +#include "lbfgs.h" +#include "sparse_vector.h" +#include "fdict.h" + +using namespace std; + +double TestOptimizer() { + cerr << "TESTING NON-PERSISTENT OPTIMIZER\n"; + + // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 + // df/dx1 = 8*x1 + x2 + // df/dx2 = 2*x2 + x1 + // df/dx3 = 2*x3 + 6 + double x[3]; + double g[3]; + scitbx::lbfgs::minimizer opt(3); + scitbx::lbfgs::traditional_convergence_test converged(3); + x[0] = 8; + x[1] = 8; + x[2] = 8; + double obj = 0; + do { + g[0] = 8 * x[0] + x[1]; + g[1] = 2 * x[1] + x[0]; + g[2] = 2 * x[2] + 6; + obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; + opt.run(x, obj, g); + if (!opt.requests_f_and_g()) { + if (converged(x,g)) break; + opt.run(x, obj, g); + } + cerr << x[0] << " " << x[1] << " " << x[2] << endl; + cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; + cerr << opt << endl; + } while (true); + return obj; +} + +double TestPersistentOptimizer() { + cerr << "\nTESTING PERSISTENT OPTIMIZER\n"; + // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 + // df/dx1 = 8*x1 + x2 + // df/dx2 = 2*x2 + x1 + // df/dx3 = 2*x3 + 6 + double x[3]; + double g[3]; + scitbx::lbfgs::traditional_convergence_test converged(3); + x[0] = 8; + x[1] = 8; + x[2] = 8; + double obj = 0; + string state; + do { + g[0] = 8 * x[0] + x[1]; + g[1] = 2 * x[1] + x[0]; + g[2] = 2 * x[2] + 6; + obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; + + { + scitbx::lbfgs::minimizer opt(3); + if (state.size() > 0) { + istringstream is(state, ios::binary); + opt.deserialize(&is); + } + opt.run(x, obj, g); + ostringstream os(ios::binary); opt.serialize(&os); state = os.str(); + } + + cerr << x[0] << " " << x[1] << " " << x[2] << endl; + cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; + } while (!converged(x, g)); + return obj; +} + +void TestSparseVector() { + cerr << "Testing SparseVector serialization.\n"; + int f1 = FD::Convert("Feature_1"); + int f2 = FD::Convert("Feature_2"); + FD::Convert("LanguageModel"); + int f4 = FD::Convert("SomeFeature"); + int f5 = FD::Convert("SomeOtherFeature"); + SparseVector g; + g.set_value(f2, log(0.5)); + g.set_value(f4, log(0.125)); + g.set_value(f1, 0); + g.set_value(f5, 23.777); + ostringstream os; + double iobj = 1.5; + B64::Encode(iobj, g, &os); + cerr << iobj << "\t" << g << endl; + string data = os.str(); + cout << data << endl; + SparseVector v; + double obj; + bool decode_b64 = B64::Decode(&obj, &v, &data[0], data.size()); + cerr << obj << "\t" << v << endl; + assert(decode_b64); + assert(obj == iobj); + assert(g.size() == v.size()); +} + +int main() { + double o1 = TestOptimizer(); + double o2 = TestPersistentOptimizer(); + if (fabs(o1 - o2) > 1e-5) { + cerr << "OPTIMIZERS PERFORMED DIFFERENTLY!\n" << o1 << " vs. " << o2 << endl; + return 1; + } + TestSparseVector(); + cerr << "SUCCESS\n"; + return 0; +} + diff --git a/training/utils/libcall.pl b/training/utils/libcall.pl new file mode 100644 index 00000000..c7d0f128 --- /dev/null +++ b/training/utils/libcall.pl @@ -0,0 +1,71 @@ +use IPC::Open3; +use Symbol qw(gensym); + +$DUMMY_STDERR = gensym(); +$DUMMY_STDIN = gensym(); + +# Run the command and ignore failures +sub unchecked_call { + system("@_") +} + +# Run the command and return its output, if any ignoring failures +sub unchecked_output { + return `@_` +} + +# WARNING: Do not use this for commands that will return large amounts +# of stdout or stderr -- they might block indefinitely +sub check_output { + print STDERR "Executing and gathering output: @_\n"; + + my $pid = open3($DUMMY_STDIN, \*PH, $DUMMY_STDERR, @_); + my $proc_output = ""; + while( ) { + $proc_output .= $_; + } + waitpid($pid, 0); + # TODO: Grab signal that the process died from + my $child_exit_status = $? >> 8; + if($child_exit_status == 0) { + return $proc_output; + } else { + print STDERR "ERROR: Execution of @_ failed.\n"; + exit(1); + } +} + +# Based on Moses' safesystem sub +sub check_call { + print STDERR "Executing: @_\n"; + system(@_); + my $exitcode = $? >> 8; + if($exitcode == 0) { + return 0; + } elsif ($? == -1) { + print STDERR "ERROR: Failed to execute: @_\n $!\n"; + exit(1); + + } elsif ($? & 127) { + printf STDERR "ERROR: Execution of: @_\n died with signal %d, %s coredump\n", + ($? & 127), ($? & 128) ? 'with' : 'without'; + exit(1); + + } else { + print STDERR "Failed with exit code: $exitcode\n" if $exitcode; + exit($exitcode); + } +} + +sub check_bash_call { + my @args = ( "bash", "-auxeo", "pipefail", "-c", "@_"); + check_call(@args); +} + +sub check_bash_output { + my @args = ( "bash", "-auxeo", "pipefail", "-c", "@_"); + return check_output(@args); +} + +# perl module weirdness... +return 1; diff --git a/training/utils/online_optimizer.cc b/training/utils/online_optimizer.cc new file mode 100644 index 00000000..3ed95452 --- /dev/null +++ b/training/utils/online_optimizer.cc @@ -0,0 +1,16 @@ +#include "online_optimizer.h" + +LearningRateSchedule::~LearningRateSchedule() {} + +double StandardLearningRate::eta(int k) const { + return eta_0_ / (1.0 + k / N_); +} + +double ExponentialDecayLearningRate::eta(int k) const { + return eta_0_ * pow(alpha_, k / N_); +} + +OnlineOptimizer::~OnlineOptimizer() {} + +void OnlineOptimizer::ResetEpochImpl() {} + diff --git a/training/utils/online_optimizer.h b/training/utils/online_optimizer.h new file mode 100644 index 00000000..28d89344 --- /dev/null +++ b/training/utils/online_optimizer.h @@ -0,0 +1,129 @@ +#ifndef _ONL_OPTIMIZE_H_ +#define _ONL_OPTIMIZE_H_ + +#include +#include +#include +#include +#include "sparse_vector.h" + +struct LearningRateSchedule { + virtual ~LearningRateSchedule(); + // returns the learning rate for the kth iteration + virtual double eta(int k) const = 0; +}; + +// TODO in the Tsoruoaka et al. (ACL 2009) paper, they use N +// to mean the batch size in most places, but it doesn't completely +// make sense to me in the learning rate schedules-- this needs +// to be worked out to make sure they didn't mean corpus size +// in some places and batch size in others (since in the paper they +// only ever work with batch sizes of 1) +struct StandardLearningRate : public LearningRateSchedule { + StandardLearningRate( + size_t batch_size, // batch size, not corpus size! + double eta_0 = 0.2) : + eta_0_(eta_0), + N_(static_cast(batch_size)) {} + + virtual double eta(int k) const; + + private: + const double eta_0_; + const double N_; +}; + +struct ExponentialDecayLearningRate : public LearningRateSchedule { + ExponentialDecayLearningRate( + size_t batch_size, // batch size, not corpus size! + double eta_0 = 0.2, + double alpha = 0.85 // recommended by Tsuruoka et al. (ACL 2009) + ) : eta_0_(eta_0), + N_(static_cast(batch_size)), + alpha_(alpha) { + assert(alpha > 0); + assert(alpha < 1.0); + } + + virtual double eta(int k) const; + + private: + const double eta_0_; + const double N_; + const double alpha_; +}; + +class OnlineOptimizer { + public: + virtual ~OnlineOptimizer(); + OnlineOptimizer(const std::tr1::shared_ptr& s, + size_t batch_size, + const std::vector& frozen_feats = std::vector()) + : N_(batch_size),schedule_(s),k_() { + for (int i = 0; i < frozen_feats.size(); ++i) + frozen_.insert(frozen_feats[i]); + } + void ResetEpoch() { k_ = 0; ResetEpochImpl(); } + void UpdateWeights(const SparseVector& approx_g, int max_feat, SparseVector* weights) { + ++k_; + const double eta = schedule_->eta(k_); + UpdateWeightsImpl(eta, approx_g, max_feat, weights); + } + + protected: + virtual void ResetEpochImpl(); + virtual void UpdateWeightsImpl(const double& eta, const SparseVector& approx_g, int max_feat, SparseVector* weights) = 0; + const size_t N_; // number of training instances per batch + std::set frozen_; // frozen (non-optimizing) features + + private: + std::tr1::shared_ptr schedule_; + int k_; // iteration count +}; + +class CumulativeL1OnlineOptimizer : public OnlineOptimizer { + public: + CumulativeL1OnlineOptimizer(const std::tr1::shared_ptr& s, + size_t training_instances, double C, + const std::vector& frozen) : + OnlineOptimizer(s, training_instances, frozen), C_(C), u_() {} + + protected: + void ResetEpochImpl() { u_ = 0; } + void UpdateWeightsImpl(const double& eta, const SparseVector& approx_g, int max_feat, SparseVector* weights) { + u_ += eta * C_ / N_; + for (SparseVector::const_iterator it = approx_g.begin(); + it != approx_g.end(); ++it) { + if (frozen_.count(it->first) == 0) + weights->add_value(it->first, eta * it->second); + } + for (int i = 1; i < max_feat; ++i) + if (frozen_.count(i) == 0) ApplyPenalty(i, weights); + } + + private: + void ApplyPenalty(int i, SparseVector* w) { + const double z = w->value(i); + double w_i = z; + double q_i = q_.value(i); + if (w_i > 0.0) + w_i = std::max(0.0, w_i - (u_ + q_i)); + else if (w_i < 0.0) + w_i = std::min(0.0, w_i + (u_ - q_i)); + q_i += w_i - z; + if (q_i == 0.0) + q_.erase(i); + else + q_.set_value(i, q_i); + if (w_i == 0.0) + w->erase(i); + else + w->set_value(i, w_i); + } + + const double C_; // reguarlization strength + double u_; + SparseVector q_; +}; + +#endif diff --git a/training/utils/optimize.cc b/training/utils/optimize.cc new file mode 100644 index 00000000..41ac90d8 --- /dev/null +++ b/training/utils/optimize.cc @@ -0,0 +1,102 @@ +#include "optimize.h" + +#include +#include + +#include "lbfgs.h" + +using namespace std; + +BatchOptimizer::~BatchOptimizer() {} + +void BatchOptimizer::Save(ostream* out) const { + out->write((const char*)&eval_, sizeof(eval_)); + out->write((const char*)&has_converged_, sizeof(has_converged_)); + SaveImpl(out); + unsigned int magic = 0xABCDDCBA; // should be uint32_t + out->write((const char*)&magic, sizeof(magic)); +} + +void BatchOptimizer::Load(istream* in) { + in->read((char*)&eval_, sizeof(eval_)); + in->read((char*)&has_converged_, sizeof(has_converged_)); + LoadImpl(in); + unsigned int magic = 0; // should be uint32_t + in->read((char*)&magic, sizeof(magic)); + assert(magic == 0xABCDDCBA); + cerr << Name() << " EVALUATION #" << eval_ << endl; +} + +void BatchOptimizer::SaveImpl(ostream* out) const { + (void)out; +} + +void BatchOptimizer::LoadImpl(istream* in) { + (void)in; +} + +string RPropOptimizer::Name() const { + return "RPropOptimizer"; +} + +void RPropOptimizer::OptimizeImpl(const double& obj, + const vector& g, + vector* x) { + for (int i = 0; i < g.size(); ++i) { + const double g_i = g[i]; + const double sign_i = (signbit(g_i) ? -1.0 : 1.0); + const double prod = g_i * prev_g_[i]; + if (prod > 0.0) { + const double dij = min(delta_ij_[i] * eta_plus_, delta_max_); + (*x)[i] -= dij * sign_i; + delta_ij_[i] = dij; + prev_g_[i] = g_i; + } else if (prod < 0.0) { + delta_ij_[i] = max(delta_ij_[i] * eta_minus_, delta_min_); + prev_g_[i] = 0.0; + } else { + (*x)[i] -= delta_ij_[i] * sign_i; + prev_g_[i] = g_i; + } + } +} + +void RPropOptimizer::SaveImpl(ostream* out) const { + const size_t n = prev_g_.size(); + out->write((const char*)&n, sizeof(n)); + out->write((const char*)&prev_g_[0], sizeof(double) * n); + out->write((const char*)&delta_ij_[0], sizeof(double) * n); +} + +void RPropOptimizer::LoadImpl(istream* in) { + size_t n; + in->read((char*)&n, sizeof(n)); + assert(n == prev_g_.size()); + assert(n == delta_ij_.size()); + in->read((char*)&prev_g_[0], sizeof(double) * n); + in->read((char*)&delta_ij_[0], sizeof(double) * n); +} + +string LBFGSOptimizer::Name() const { + return "LBFGSOptimizer"; +} + +LBFGSOptimizer::LBFGSOptimizer(int num_feats, int memory_buffers) : + opt_(num_feats, memory_buffers) {} + +void LBFGSOptimizer::SaveImpl(ostream* out) const { + opt_.serialize(out); +} + +void LBFGSOptimizer::LoadImpl(istream* in) { + opt_.deserialize(in); +} + +void LBFGSOptimizer::OptimizeImpl(const double& obj, + const vector& g, + vector* x) { + opt_.run(&(*x)[0], obj, &g[0]); + if (!opt_.requests_f_and_g()) opt_.run(&(*x)[0], obj, &g[0]); + // cerr << opt_ << endl; +} + diff --git a/training/utils/optimize.h b/training/utils/optimize.h new file mode 100644 index 00000000..07943b44 --- /dev/null +++ b/training/utils/optimize.h @@ -0,0 +1,92 @@ +#ifndef _OPTIMIZE_H_ +#define _OPTIMIZE_H_ + +#include +#include +#include +#include + +#include "lbfgs.h" + +// abstract base class for first order optimizers +// order of invocation: new, Load(), Optimize(), Save(), delete +class BatchOptimizer { + public: + BatchOptimizer() : eval_(1), has_converged_(false) {} + virtual ~BatchOptimizer(); + virtual std::string Name() const = 0; + int EvaluationCount() const { return eval_; } + bool HasConverged() const { return has_converged_; } + + void Optimize(const double& obj, + const std::vector& g, + std::vector* x) { + assert(g.size() == x->size()); + ++eval_; + OptimizeImpl(obj, g, x); + scitbx::lbfgs::traditional_convergence_test converged(g.size()); + has_converged_ = converged(&(*x)[0], &g[0]); + } + + void Save(std::ostream* out) const; + void Load(std::istream* in); + protected: + virtual void SaveImpl(std::ostream* out) const; + virtual void LoadImpl(std::istream* in); + virtual void OptimizeImpl(const double& obj, + const std::vector& g, + std::vector* x) = 0; + + int eval_; + private: + bool has_converged_; +}; + +class RPropOptimizer : public BatchOptimizer { + public: + explicit RPropOptimizer(int num_vars, + double eta_plus = 1.2, + double eta_minus = 0.5, + double delta_0 = 0.1, + double delta_max = 50.0, + double delta_min = 1e-6) : + prev_g_(num_vars, 0.0), + delta_ij_(num_vars, delta_0), + eta_plus_(eta_plus), + eta_minus_(eta_minus), + delta_max_(delta_max), + delta_min_(delta_min) { + assert(eta_plus > 1.0); + assert(eta_minus > 0.0 && eta_minus < 1.0); + assert(delta_max > 0.0); + assert(delta_min > 0.0); + } + std::string Name() const; + void OptimizeImpl(const double& obj, + const std::vector& g, + std::vector* x); + void SaveImpl(std::ostream* out) const; + void LoadImpl(std::istream* in); + private: + std::vector prev_g_; + std::vector delta_ij_; + const double eta_plus_; + const double eta_minus_; + const double delta_max_; + const double delta_min_; +}; + +class LBFGSOptimizer : public BatchOptimizer { + public: + explicit LBFGSOptimizer(int num_vars, int memory_buffers = 10); + std::string Name() const; + void SaveImpl(std::ostream* out) const; + void LoadImpl(std::istream* in); + void OptimizeImpl(const double& obj, + const std::vector& g, + std::vector* x); + private: + scitbx::lbfgs::minimizer opt_; +}; + +#endif diff --git a/training/utils/optimize_test.cc b/training/utils/optimize_test.cc new file mode 100644 index 00000000..bff2ca03 --- /dev/null +++ b/training/utils/optimize_test.cc @@ -0,0 +1,118 @@ +#include +#include +#include +#include +#include "optimize.h" +#include "online_optimizer.h" +#include "sparse_vector.h" +#include "fdict.h" + +using namespace std; + +double TestOptimizer(BatchOptimizer* opt) { + cerr << "TESTING NON-PERSISTENT OPTIMIZER\n"; + + // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 + // df/dx1 = 8*x1 + x2 + // df/dx2 = 2*x2 + x1 + // df/dx3 = 2*x3 + 6 + vector x(3); + vector g(3); + x[0] = 8; + x[1] = 8; + x[2] = 8; + double obj = 0; + do { + g[0] = 8 * x[0] + x[1]; + g[1] = 2 * x[1] + x[0]; + g[2] = 2 * x[2] + 6; + obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; + opt->Optimize(obj, g, &x); + + cerr << x[0] << " " << x[1] << " " << x[2] << endl; + cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; + } while (!opt->HasConverged()); + return obj; +} + +double TestPersistentOptimizer(BatchOptimizer* opt) { + cerr << "\nTESTING PERSISTENT OPTIMIZER\n"; + // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 + // df/dx1 = 8*x1 + x2 + // df/dx2 = 2*x2 + x1 + // df/dx3 = 2*x3 + 6 + vector x(3); + vector g(3); + x[0] = 8; + x[1] = 8; + x[2] = 8; + double obj = 0; + string state; + bool converged = false; + while (!converged) { + g[0] = 8 * x[0] + x[1]; + g[1] = 2 * x[1] + x[0]; + g[2] = 2 * x[2] + 6; + obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; + + { + if (state.size() > 0) { + istringstream is(state, ios::binary); + opt->Load(&is); + } + opt->Optimize(obj, g, &x); + ostringstream os(ios::binary); opt->Save(&os); state = os.str(); + + } + + cerr << x[0] << " " << x[1] << " " << x[2] << endl; + cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; + converged = opt->HasConverged(); + if (!converged) { + // now screw up the state (should be undone by Load) + obj += 2.0; + g[1] = -g[2]; + vector x2 = x; + try { + opt->Optimize(obj, g, &x2); + } catch (...) { } + } + } + return obj; +} + +template +void TestOptimizerVariants(int num_vars) { + O oa(num_vars); + cerr << "-------------------------------------------------------------------------\n"; + cerr << "TESTING: " << oa.Name() << endl; + double o1 = TestOptimizer(&oa); + O ob(num_vars); + double o2 = TestPersistentOptimizer(&ob); + if (o1 != o2) { + cerr << oa.Name() << " VARIANTS PERFORMED DIFFERENTLY!\n" << o1 << " vs. " << o2 << endl; + exit(1); + } + cerr << oa.Name() << " SUCCESS\n"; +} + +using namespace std::tr1; + +void TestOnline() { + size_t N = 20; + double C = 1.0; + double eta0 = 0.2; + std::tr1::shared_ptr r(new ExponentialDecayLearningRate(N, eta0, 0.85)); + //shared_ptr r(new StandardLearningRate(N, eta0)); + CumulativeL1OnlineOptimizer opt(r, N, C, std::vector()); + assert(r->eta(10) < r->eta(1)); +} + +int main() { + int n = 3; + TestOptimizerVariants(n); + TestOptimizerVariants(n); + TestOnline(); + return 0; +} + diff --git a/training/utils/parallelize.pl b/training/utils/parallelize.pl new file mode 100755 index 00000000..4197e0e5 --- /dev/null +++ b/training/utils/parallelize.pl @@ -0,0 +1,423 @@ +#!/usr/bin/env perl + +# Author: Adam Lopez +# +# This script takes a command that processes input +# from stdin one-line-at-time, and parallelizes it +# on the cluster using David Chiang's sentserver/ +# sentclient architecture. +# +# Prerequisites: the command *must* read each line +# without waiting for subsequent lines of input +# (for instance, a command which must read all lines +# of input before processing will not work) and +# return it to the output *without* buffering +# multiple lines. + +#TODO: if -j 1, run immediately, not via sentserver? possible differences in environment might make debugging harder + +#ANNOYANCE: if input is shorter than -j n lines, or at the very last few lines, repeatedly sleeps. time cut down to 15s from 60s + +my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR, "$SCRIPT_DIR/../../environment"; } +use LocalConfig; + +use Cwd qw/ abs_path cwd getcwd /; +use File::Temp qw/ tempfile /; +use Getopt::Long; +use IPC::Open2; +use strict; +use POSIX ":sys_wait_h"; + +use File::Basename; +my $myDir = dirname(__FILE__); +print STDERR __FILE__." -> $myDir\n"; +push(@INC, $myDir); +require "libcall.pl"; + +my $tailn=5; # +0 = concatenate all the client logs. 5 = last 5 lines +my $recycle_clients; # spawn new clients when previous ones terminate +my $stay_alive; # dont let server die when having zero clients +my $joblist = ""; +my $errordir=""; +my $multiline; +my $workdir = '.'; +my $numnodes = 8; +my $user = $ENV{"USER"}; +my $pmem = "9g"; +my $basep=50300; +my $randp=300; +my $tryp=50; +my $no_which; +my $no_cd; + +my $DEBUG=$ENV{DEBUG}; +print STDERR "DEBUG=$DEBUG output enabled.\n" if $DEBUG; +my $verbose = 1; +sub verbose { + if ($verbose) { + print STDERR @_,"\n"; + } +} +sub debug { + if ($DEBUG) { + my ($package, $filename, $line) = caller; + print STDERR "DEBUG: $filename($line): ",join(' ',@_),"\n"; + } +} +my $is_shell_special=qr.[ \t\n\\><|&;"'`~*?{}$!()].; +my $shell_escape_in_quote=qr.[\\"\$`!].; +sub escape_shell { + my ($arg)=@_; + return undef unless defined $arg; + return '""' unless $arg; + if ($arg =~ /$is_shell_special/) { + $arg =~ s/($shell_escape_in_quote)/\\$1/g; + return "\"$arg\""; + } + return $arg; +} +sub preview_files { + my ($l,$skipempty,$footer,$n)=@_; + $n=$tailn unless defined $n; + my @f=grep { ! ($skipempty && -z $_) } @$l; + my $fn=join(' ',map {escape_shell($_)} @f); + my $cmd="tail -n $n $fn"; + unchecked_output("$cmd").($footer?"\nNONEMPTY FILES:\n$fn\n":""); +} +sub prefix_dirname($) { + #like `dirname but if ends in / then return the whole thing + local ($_)=@_; + if (/\/$/) { + $_; + } else { + s#/[^/]$##; + $_ ? $_ : ''; + } +} +sub ensure_final_slash($) { + local ($_)=@_; + m#/$# ? $_ : ($_."/"); +} +sub extend_path($$;$$) { + my ($base,$ext,$mkdir,$baseisdir)=@_; + if (-d $base) { + $base.="/"; + } else { + my $dir; + if ($baseisdir) { + $dir=$base; + $base.='/' unless $base =~ /\/$/; + } else { + $dir=prefix_dirname($base); + } + my @cmd=("/bin/mkdir","-p",$dir); + check_call(@cmd) if $mkdir; + } + return $base.$ext; +} + +my $abscwd=abs_path(&getcwd); +sub print_help; + +my $use_fork; +my @pids; + +# Process command-line options +unless (GetOptions( + "stay-alive" => \$stay_alive, + "recycle-clients" => \$recycle_clients, + "error-dir=s" => \$errordir, + "multi-line" => \$multiline, + "workdir=s" => \$workdir, + "use-fork" => \$use_fork, + "verbose" => \$verbose, + "jobs=i" => \$numnodes, + "pmem=s" => \$pmem, + "baseport=i" => \$basep, +# "iport=i" => \$randp, #for short name -i + "no-which!" => \$no_which, + "no-cd!" => \$no_cd, + "tailn=s" => \$tailn, +) && scalar @ARGV){ + print_help(); + die "bad options."; +} + +my $cmd = ""; +my $prog=shift; +if ($no_which) { + $cmd=$prog; +} else { + $cmd=check_output("which $prog"); + chomp $cmd; + die "$prog not found - $cmd" unless $cmd; +} +#$cmd=abs_path($cmd); +for my $arg (@ARGV) { + $cmd .= " ".escape_shell($arg); +} +die "Please specify a command to parallelize\n" if $cmd eq ''; + +my $cdcmd=$no_cd ? '' : ("cd ".escape_shell($abscwd)."\n"); + +my $executable = $cmd; +$executable =~ s/^\s*(\S+)($|\s.*)/$1/; +$executable=check_output("basename $executable"); +chomp $executable; + + +print STDERR "Parallelizing ($numnodes ways): $cmd\n\n"; + +# create -e dir and save .sh +use File::Temp qw/tempdir/; +unless ($errordir) { + $errordir=tempdir("$executable.XXXXXX",CLEANUP=>1); +} +if ($errordir) { + my $scriptfile=extend_path("$errordir/","$executable.sh",1,1); + -d $errordir || die "should have created -e dir $errordir"; + open SF,">",$scriptfile || die; + print SF "$cdcmd$cmd\n"; + close SF; + chmod 0755,$scriptfile; + $errordir=abs_path($errordir); + &verbose("-e dir: $errordir"); +} + +# set cleanup handler +my @cleanup_cmds; +sub cleanup; +sub cleanup_and_die; +$SIG{INT} = "cleanup_and_die"; +$SIG{TERM} = "cleanup_and_die"; +$SIG{HUP} = "cleanup_and_die"; + +# other subs: +sub numof_live_jobs; +sub launch_job_on_node; + + +# vars +my $mydir = check_output("dirname $0"); chomp $mydir; +my $sentserver = "$mydir/sentserver"; +my $sentclient = "$mydir/sentclient"; +my $host = check_output("hostname"); +chomp $host; + + +# find open port +srand; +my $port = 50300+int(rand($randp)); +my $endp=$port+$tryp; +sub listening_port_lines { + my $quiet=$verbose?'':'2>/dev/null'; + return unchecked_output("netstat -a -n $quiet | grep LISTENING | grep -i tcp"); +} +my $netstat=&listening_port_lines; + +if ($verbose){ print STDERR "Testing port $port...";} + +while ($netstat=~/$port/ || &listening_port_lines=~/$port/){ + if ($verbose){ print STDERR "port is busy\n";} + $port++; + if ($port > $endp){ + die "Unable to find open port\n"; + } + if ($verbose){ print STDERR "Testing port $port... "; } +} +if ($verbose){ + print STDERR "port $port is available\n"; +} + +my $key = int(rand()*1000000); + +my $multiflag = ""; +if ($multiline){ $multiflag = "-m"; print STDERR "expecting multiline output.\n"; } +my $stay_alive_flag = ""; +if ($stay_alive){ $stay_alive_flag = "--stay-alive"; print STDERR "staying alive while no clients are connected.\n"; } + +my $node_count = 0; +my $script = ""; +# fork == one thread runs the sentserver, while the +# other spawns the sentclient commands. +my $pid = fork; +if ($pid == 0) { # child + sleep 8; # give other thread time to start sentserver + $script = "$cdcmd$sentclient $host:$port:$key $cmd"; + + if ($verbose){ + print STDERR "Client script:\n====\n"; + print STDERR $script; + print STDERR "====\n"; + } + for (my $jobn=0; $jobn<$numnodes; $jobn++){ + launch_job(); + } + if ($recycle_clients) { + my $ret; + my $livejobs; + while (1) { + $ret = waitpid($pid, WNOHANG); + #print STDERR "waitpid $pid ret = $ret \n"; + last if ($ret != 0); + $livejobs = numof_live_jobs(); + if ($numnodes >= $livejobs ) { # a client terminated, OR # lines of input was less than -j + print STDERR "num of requested nodes = $numnodes; num of currently live jobs = $livejobs; Client terminated - launching another.\n"; + launch_job(); + } else { + sleep 15; + } + } + } + print STDERR "CHILD PROCESSES SPAWNED ... WAITING\n"; + for my $p (@pids) { + waitpid($p, 0); + } +} else { +# my $todo = "$sentserver -k $key $multiflag $port "; + my $todo = "$sentserver -k $key $multiflag $port $stay_alive_flag "; + if ($verbose){ print STDERR "Running: $todo\n"; } + check_call($todo); + print STDERR "Call to $sentserver returned.\n"; + cleanup(); + exit(0); +} + +sub numof_live_jobs { + if ($use_fork) { + die "not implemented"; + } else { + # We can probably continue decoding if the qstat error is only temporary + my @livejobs = grep(/$joblist/, split(/\n/, unchecked_output("qstat"))); + return ($#livejobs + 1); + } +} +my (@errors,@outs,@cmds); + +sub launch_job { + if ($use_fork) { return launch_job_fork(); } + my $errorfile = "/dev/null"; + my $outfile = "/dev/null"; + $node_count++; + my $clientname = $executable; + $clientname =~ s/^(.{4}).*$/$1/; + $clientname = "$clientname.$node_count"; + if ($errordir){ + $errorfile = "$errordir/$clientname.ER"; + $outfile = "$errordir/$clientname.OU"; + push @errors,$errorfile; + push @outs,$outfile; + } + my $todo = qsub_args($pmem) . " -N $clientname -o $outfile -e $errorfile"; + push @cmds,$todo; + + print STDERR "Running: $todo\n"; + local(*QOUT, *QIN); + open2(\*QOUT, \*QIN, $todo) or die "Failed to open2: $!"; + print QIN $script; + close QIN; + while (my $jobid=){ + chomp $jobid; + if ($verbose){ print STDERR "Launched client job: $jobid"; } + $jobid =~ s/^(\d+)(.*?)$/\1/g; + $jobid =~ s/^Your job (\d+) .*$/\1/; + print STDERR " short job id $jobid\n"; + if ($verbose){ + print STDERR "cd: $abscwd\n"; + print STDERR "cmd: $cmd\n"; + } + if ($joblist == "") { $joblist = $jobid; } + else {$joblist = $joblist . "\|" . $jobid; } + my $cleanfn="qdel $jobid 2> /dev/null"; + push(@cleanup_cmds, $cleanfn); + } + close QOUT; +} + +sub launch_job_fork { + my $errorfile = "/dev/null"; + my $outfile = "/dev/null"; + $node_count++; + my $clientname = $executable; + $clientname =~ s/^(.{4}).*$/$1/; + $clientname = "$clientname.$node_count"; + if ($errordir){ + $errorfile = "$errordir/$clientname.ER"; + $outfile = "$errordir/$clientname.OU"; + push @errors,$errorfile; + push @outs,$outfile; + } + my $pid = fork; + if ($pid == 0) { + my ($fh, $scr_name) = get_temp_script(); + print $fh $script; + close $fh; + my $todo = "/bin/bash -xeo pipefail $scr_name 1> $outfile 2> $errorfile"; + print STDERR "EXEC: $todo\n"; + my $out = check_output("$todo"); + unlink $scr_name or warn "Failed to remove $scr_name"; + exit 0; + } else { + push @pids, $pid; + } +} + +sub get_temp_script { + my ($fh, $filename) = tempfile( "$workdir/workXXXX", SUFFIX => '.sh'); + return ($fh, $filename); +} + +sub cleanup_and_die { + cleanup(); + die "\n"; +} + +sub cleanup { + print STDERR "Cleaning up...\n"; + for $cmd (@cleanup_cmds){ + print STDERR " Cleanup command: $cmd\n"; + eval $cmd; + } + print STDERR "outputs:\n",preview_files(\@outs,1),"\n"; + print STDERR "errors:\n",preview_files(\@errors,1),"\n"; + print STDERR "cmd:\n",$cmd,"\n"; + print STDERR " cat $errordir/*.ER\nfor logs.\n"; + print STDERR "Cleanup finished.\n"; +} + +sub print_help +{ + my $name = check_output("basename $0"); chomp $name; + print << "Help"; + +usage: $name [options] + + Automatic black-box parallelization of commands. + +options: + + --use-fork + Instead of using qsub, use fork. + + -e, --error-dir + Retain output files from jobs in , rather + than silently deleting them. + + -m, --multi-line + Expect that command may produce multiple output + lines for a single input line. $name makes a + reasonable attempt to obtain all output before + processing additional inputs. However, use of this + option is inherently unsafe. + + -v, --verbose + Print diagnostic informatoin on stderr. + + -j, --jobs + Number of jobs to use. + + -p, --pmem + pmem setting for each job. + +Help +} diff --git a/training/utils/risk.cc b/training/utils/risk.cc new file mode 100644 index 00000000..d5a12cfd --- /dev/null +++ b/training/utils/risk.cc @@ -0,0 +1,45 @@ +#include "risk.h" + +#include "prob.h" +#include "candidate_set.h" +#include "ns.h" + +using namespace std; + +namespace training { + +// g = \sum_e p(e|f) * loss(e) * (phi(e,f) - E[phi(e,f)]) +double CandidateSetRisk::operator()(const vector& params, + SparseVector* g) const { + prob_t z; + for (unsigned i = 0; i < cands_.size(); ++i) { + const prob_t u(cands_[i].fmap.dot(params), init_lnx()); + z += u; + } + const double log_z = log(z); + + SparseVector exp_feats; + if (g) { + for (unsigned i = 0; i < cands_.size(); ++i) { + const double log_prob = cands_[i].fmap.dot(params) - log_z; + const double prob = exp(log_prob); + exp_feats += cands_[i].fmap * prob; + } + } + + double risk = 0; + for (unsigned i = 0; i < cands_.size(); ++i) { + const double log_prob = cands_[i].fmap.dot(params) - log_z; + const double prob = exp(log_prob); + const double cost = metric_.IsErrorMetric() ? metric_.ComputeScore(cands_[i].eval_feats) + : 1.0 - metric_.ComputeScore(cands_[i].eval_feats); + const double r = prob * cost; + risk += r; + if (g) (*g) += (cands_[i].fmap - exp_feats) * r; + } + return risk; +} + +} + + diff --git a/training/utils/risk.h b/training/utils/risk.h new file mode 100644 index 00000000..2e8db0fb --- /dev/null +++ b/training/utils/risk.h @@ -0,0 +1,26 @@ +#ifndef _RISK_H_ +#define _RISK_H_ + +#include +#include "sparse_vector.h" +class EvaluationMetric; + +namespace training { + class CandidateSet; + + class CandidateSetRisk { + public: + explicit CandidateSetRisk(const CandidateSet& cs, const EvaluationMetric& metric) : + cands_(cs), + metric_(metric) {} + // compute the risk (expected loss) of a CandidateSet + // (optional) the gradient of the risk with respect to params + double operator()(const std::vector& params, + SparseVector* g = NULL) const; + private: + const CandidateSet& cands_; + const EvaluationMetric& metric_; + }; +}; + +#endif diff --git a/training/utils/sentclient.c b/training/utils/sentclient.c new file mode 100644 index 00000000..91d994ab --- /dev/null +++ b/training/utils/sentclient.c @@ -0,0 +1,76 @@ +/* Copyright (c) 2001 by David Chiang. All rights reserved.*/ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "sentserver.h" + +int main (int argc, char *argv[]) { + int sock, port; + char *s, *key; + struct hostent *hp; + struct sockaddr_in server; + int errors = 0; + + if (argc < 3) { + fprintf(stderr, "Usage: sentclient host[:port[:key]] command [args ...]\n"); + exit(1); + } + + s = strchr(argv[1], ':'); + key = NULL; + + if (s == NULL) { + port = DEFAULT_PORT; + } else { + *s = '\0'; + s+=1; + /* dumb hack */ + key = strchr(s, ':'); + if (key != NULL){ + *key = '\0'; + key += 1; + } + port = atoi(s); + } + + sock = socket(AF_INET, SOCK_STREAM, 0); + + hp = gethostbyname(argv[1]); + if (hp == NULL) { + fprintf(stderr, "unknown host %s\n", argv[1]); + exit(1); + } + + bzero((char *)&server, sizeof(server)); + bcopy(hp->h_addr, (char *)&server.sin_addr, hp->h_length); + server.sin_family = hp->h_addrtype; + server.sin_port = htons(port); + + while (connect(sock, (struct sockaddr *)&server, sizeof(server)) < 0) { + perror("connect()"); + sleep(1); + errors++; + if (errors > 5) + exit(1); + } + + close(0); + close(1); + dup2(sock, 0); + dup2(sock, 1); + + if (key != NULL){ + write(1, key, strlen(key)); + write(1, "\n", 1); + } + + execvp(argv[2], argv+2); + return 0; +} diff --git a/training/utils/sentserver.c b/training/utils/sentserver.c new file mode 100644 index 00000000..c20b4fa6 --- /dev/null +++ b/training/utils/sentserver.c @@ -0,0 +1,515 @@ +/* Copyright (c) 2001 by David Chiang. All rights reserved.*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "sentserver.h" + +#define MAX_CLIENTS 64 + +struct clientinfo { + int s; + struct sockaddr_in sin; +}; + +struct line { + int id; + char *s; + int status; + struct line *next; +} *head, **ptail; + +int n_sent = 0, n_received=0, n_flushed=0; + +#define STATUS_RUNNING 0 +#define STATUS_ABORTED 1 +#define STATUS_FINISHED 2 + +pthread_mutex_t queue_mutex = PTHREAD_MUTEX_INITIALIZER; +pthread_mutex_t clients_mutex = PTHREAD_MUTEX_INITIALIZER; +pthread_mutex_t input_mutex = PTHREAD_MUTEX_INITIALIZER; + +int n_clients = 0; +int s; +int expect_multiline_output = 0; +int log_mutex = 0; +int stay_alive = 0; /* dont panic and die with zero clients */ + +void queue_finish(struct line *node, char *s, int fid); +char * read_line(int fd, int multiline); +void done (int code); + +struct line * queue_get(int fid) { + struct line *cur; + char *s, *synch; + + if (log_mutex) fprintf(stderr, "Getting for data for fid %d\n", fid); + if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid); + pthread_mutex_lock(&queue_mutex); + + /* First, check for aborted sentences. */ + + if (log_mutex) fprintf(stderr, " Checking queue for aborted jobs (fid %d)\n", fid); + for (cur = head; cur != NULL; cur = cur->next) { + if (cur->status == STATUS_ABORTED) { + cur->status = STATUS_RUNNING; + + if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); + pthread_mutex_unlock(&queue_mutex); + + return cur; + } + } + if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); + pthread_mutex_unlock(&queue_mutex); + + /* Otherwise, read a new one. */ + if (log_mutex) fprintf(stderr, "Locking input mutex (%d)\n", fid); + if (log_mutex) fprintf(stderr, " Reading input for new data (fid %d)\n", fid); + pthread_mutex_lock(&input_mutex); + s = read_line(0,0); + + while (s) { + if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid); + pthread_mutex_lock(&queue_mutex); + if (log_mutex) fprintf(stderr, "Unlocking input mutex (%d)\n", fid); + pthread_mutex_unlock(&input_mutex); + + cur = malloc(sizeof (struct line)); + cur->id = n_sent; + cur->s = s; + cur->next = NULL; + + *ptail = cur; + ptail = &cur->next; + + n_sent++; + + if (strcmp(s,"===SYNCH===\n")==0){ + fprintf(stderr, "Received ===SYNCH=== signal (fid %d)\n", fid); + // Note: queue_finish calls free(cur->s). + // Therefore we need to create a new string here. + synch = malloc((strlen("===SYNCH===\n")+2) * sizeof (char)); + synch = strcpy(synch, s); + + if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); + pthread_mutex_unlock(&queue_mutex); + queue_finish(cur, synch, fid); /* handles its own lock */ + + if (log_mutex) fprintf(stderr, "Locking input mutex (%d)\n", fid); + if (log_mutex) fprintf(stderr, " Reading input for new data (fid %d)\n", fid); + pthread_mutex_lock(&input_mutex); + + s = read_line(0,0); + } else { + if (log_mutex) fprintf(stderr, " Received new data %d (fid %d)\n", cur->id, fid); + cur->status = STATUS_RUNNING; + if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); + pthread_mutex_unlock(&queue_mutex); + return cur; + } + } + + if (log_mutex) fprintf(stderr, "Unlocking input mutex (%d)\n", fid); + pthread_mutex_unlock(&input_mutex); + /* Only way to reach this point: no more output */ + + if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid); + pthread_mutex_lock(&queue_mutex); + if (head == NULL) { + fprintf(stderr, "Reached end of file. Exiting.\n"); + done(0); + } else + ptail = NULL; /* This serves as a signal that there is no more input */ + if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); + pthread_mutex_unlock(&queue_mutex); + + return NULL; +} + +void queue_panic() { + struct line *next; + while (head && head->status == STATUS_FINISHED) { + /* Write out finished sentences */ + if (head->status == STATUS_FINISHED) { + fputs(head->s, stdout); + fflush(stdout); + } + /* Write out blank line for unfinished sentences */ + if (head->status == STATUS_ABORTED) { + fputs("\n", stdout); + fflush(stdout); + } + /* By defition, there cannot be any RUNNING sentences, since + function is only called when n_clients == 0 */ + free(head->s); + next = head->next; + free(head); + head = next; + n_flushed++; + } + fclose(stdout); + fprintf(stderr, "All clients died. Panicking, flushing completed sentences and exiting.\n"); + done(1); +} + +void queue_abort(struct line *node, int fid) { + if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid); + pthread_mutex_lock(&queue_mutex); + node->status = STATUS_ABORTED; + if (n_clients == 0) { + if (stay_alive) { + fprintf(stderr, "Warning! No live clients detected! Staying alive, will retry soon.\n"); + } else { + queue_panic(); + } + } + if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); + pthread_mutex_unlock(&queue_mutex); +} + + +void queue_print() { + struct line *cur; + + fprintf(stderr, " Queue\n"); + + for (cur = head; cur != NULL; cur = cur->next) { + switch(cur->status) { + case STATUS_RUNNING: + fprintf(stderr, " %d running ", cur->id); break; + case STATUS_ABORTED: + fprintf(stderr, " %d aborted ", cur->id); break; + case STATUS_FINISHED: + fprintf(stderr, " %d finished ", cur->id); break; + + } + fprintf(stderr, "\n"); + //fprintf(stderr, cur->s); + } +} + +void queue_finish(struct line *node, char *s, int fid) { + struct line *next; + if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid); + pthread_mutex_lock(&queue_mutex); + + free(node->s); + node->s = s; + node->status = STATUS_FINISHED; + n_received++; + + /* Flush out finished nodes */ + while (head && head->status == STATUS_FINISHED) { + + if (log_mutex) fprintf(stderr, " Flushing finished node %d\n", head->id); + + fputs(head->s, stdout); + fflush(stdout); + if (log_mutex) fprintf(stderr, " Flushed node %d\n", head->id); + free(head->s); + + next = head->next; + free(head); + + head = next; + + n_flushed++; + + if (head == NULL) { /* empty queue */ + if (ptail == NULL) { /* This can only happen if set in queue_get as signal that there is no more input. */ + fprintf(stderr, "All sentences finished. Exiting.\n"); + done(0); + } else /* ptail pointed at something which was just popped off the stack -- reset to head*/ + ptail = &head; + } + } + + if (log_mutex) fprintf(stderr, " Flushing output %d\n", head->id); + fflush(stdout); + fprintf(stderr, "%d sentences sent, %d sentences finished, %d sentences flushed\n", n_sent, n_received, n_flushed); + + if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); + pthread_mutex_unlock(&queue_mutex); + +} + +char * read_line(int fd, int multiline) { + int size = 80; + char errorbuf[100]; + char *s = malloc(size+2); + int result, errors=0; + int i = 0; + + result = read(fd, s+i, 1); + + while (1) { + if (result < 0) { + perror("read()"); + sprintf(errorbuf, "Error code: %d\n", errno); + fprintf(stderr, errorbuf); + errors++; + if (errors > 5) { + free(s); + return NULL; + } else { + sleep(1); /* retry after delay */ + } + } else if (result == 0) { + break; + } else if (multiline==0 && s[i] == '\n') { + break; + } else { + if (s[i] == '\n'){ + /* if we've reached this point, + then multiline must be 1, and we're + going to poll the fd for an additional + line of data. The basic design is to + run a select on the filedescriptor fd. + Select will return under two conditions: + if there is data on the fd, or if a + timeout is reached. We'll select on this + fd. If select returns because there's data + ready, keep going; else assume there's no + more and return the data we already have. + */ + + fd_set set; + FD_ZERO(&set); + FD_SET(fd, &set); + + struct timeval timeout; + timeout.tv_sec = 3; // number of seconds for timeout + timeout.tv_usec = 0; + + int ready = select(FD_SETSIZE, &set, NULL, NULL, &timeout); + if (ready<1){ + break; // no more data, stop looping + } + } + i++; + + if (i == size) { + size = size*2; + s = realloc(s, size+2); + } + } + + result = read(fd, s+i, 1); + } + + if (result == 0 && i == 0) { /* end of file */ + free(s); + return NULL; + } + + s[i] = '\n'; + s[i+1] = '\0'; + + return s; +} + +void * new_client(void *arg) { + struct clientinfo *client = (struct clientinfo *)arg; + struct line *cur; + int result; + char *s; + char errorbuf[100]; + + pthread_mutex_lock(&clients_mutex); + n_clients++; + pthread_mutex_unlock(&clients_mutex); + + fprintf(stderr, "Client connected (%d connected)\n", n_clients); + + for (;;) { + + cur = queue_get(client->s); + + if (cur) { + /* fprintf(stderr, "Sending to client: %s", cur->s); */ + fprintf(stderr, "Sending data %d to client (fid %d)\n", cur->id, client->s); + result = write(client->s, cur->s, strlen(cur->s)); + if (result < strlen(cur->s)){ + perror("write()"); + sprintf(errorbuf, "Error code: %d\n", errno); + fprintf(stderr, errorbuf); + + pthread_mutex_lock(&clients_mutex); + n_clients--; + pthread_mutex_unlock(&clients_mutex); + + fprintf(stderr, "Client died (%d connected)\n", n_clients); + queue_abort(cur, client->s); + + close(client->s); + free(client); + + pthread_exit(NULL); + } + } else { + close(client->s); + pthread_mutex_lock(&clients_mutex); + n_clients--; + pthread_mutex_unlock(&clients_mutex); + fprintf(stderr, "Client dismissed (%d connected)\n", n_clients); + pthread_exit(NULL); + } + + s = read_line(client->s,expect_multiline_output); + if (s) { + /* fprintf(stderr, "Client (fid %d) returned: %s", client->s, s); */ + fprintf(stderr, "Client (fid %d) returned data %d\n", client->s, cur->id); +// queue_print(); + queue_finish(cur, s, client->s); + } else { + pthread_mutex_lock(&clients_mutex); + n_clients--; + pthread_mutex_unlock(&clients_mutex); + + fprintf(stderr, "Client died (%d connected)\n", n_clients); + queue_abort(cur, client->s); + + close(client->s); + free(client); + + pthread_exit(NULL); + } + + } + return 0; +} + +void done (int code) { + close(s); + exit(code); +} + + + +int main (int argc, char *argv[]) { + struct sockaddr_in sin, from; + int g; + socklen_t len; + struct clientinfo *client; + int port; + int opt; + int errors = 0; + int argi; + char *key = NULL, *client_key; + int use_key = 0; + /* the key stuff here doesn't provide any + real measure of security, it's mainly to keep + jobs from bumping into each other. */ + + pthread_t tid; + port = DEFAULT_PORT; + + for (argi=1; argi < argc; argi++){ + if (strcmp(argv[argi], "-m")==0){ + expect_multiline_output = 1; + } else if (strcmp(argv[argi], "-k")==0){ + argi++; + if (argi == argc){ + fprintf(stderr, "Key must be specified after -k\n"); + exit(1); + } + key = argv[argi]; + use_key = 1; + } else if (strcmp(argv[argi], "--stay-alive")==0){ + stay_alive = 1; /* dont panic and die with zero clients */ + } else { + port = atoi(argv[argi]); + } + } + + /* Initialize data structures */ + head = NULL; + ptail = &head; + + /* Set up listener */ + s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + opt = 1; + setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); + + sin.sin_family = AF_INET; + sin.sin_addr.s_addr = htonl(INADDR_ANY); + sin.sin_port = htons(port); + while (bind(s, (struct sockaddr *) &sin, sizeof(sin)) < 0) { + perror("bind()"); + sleep(1); + errors++; + if (errors > 100) + exit(1); + } + + len = sizeof(sin); + getsockname(s, (struct sockaddr *) &sin, &len); + + fprintf(stderr, "Listening on port %hu\n", ntohs(sin.sin_port)); + + while (listen(s, MAX_CLIENTS) < 0) { + perror("listen()"); + sleep(1); + errors++; + if (errors > 100) + exit(1); + } + + for (;;) { + len = sizeof(from); + g = accept(s, (struct sockaddr *)&from, &len); + if (g < 0) { + perror("accept()"); + sleep(1); + continue; + } + client = malloc(sizeof(struct clientinfo)); + client->s = g; + bcopy(&from, &client->sin, len); + + if (use_key){ + fd_set set; + FD_ZERO(&set); + FD_SET(client->s, &set); + + struct timeval timeout; + timeout.tv_sec = 3; // number of seconds for timeout + timeout.tv_usec = 0; + + int ready = select(FD_SETSIZE, &set, NULL, NULL, &timeout); + if (ready<1){ + fprintf(stderr, "Prospective client failed to respond with correct key.\n"); + close(client->s); + free(client); + } else { + client_key = read_line(client->s,0); + client_key[strlen(client_key)-1]='\0'; /* chop trailing newline */ + if (strcmp(key, client_key)==0){ + pthread_create(&tid, NULL, new_client, client); + } else { + fprintf(stderr, "Prospective client failed to respond with correct key.\n"); + close(client->s); + free(client); + } + free(client_key); + } + } else { + pthread_create(&tid, NULL, new_client, client); + } + } + +} + + + diff --git a/training/utils/sentserver.h b/training/utils/sentserver.h new file mode 100644 index 00000000..cd17a546 --- /dev/null +++ b/training/utils/sentserver.h @@ -0,0 +1,6 @@ +#ifndef SENTSERVER_H +#define SENTSERVER_H + +#define DEFAULT_PORT 50000 + +#endif -- cgit v1.2.3