diff options
Diffstat (limited to 'training/utils')
| -rw-r--r-- | training/utils/Makefile.am | 37 | ||||
| -rw-r--r-- | training/utils/candidate_set.cc | 169 | ||||
| -rw-r--r-- | training/utils/candidate_set.h | 60 | ||||
| -rwxr-xr-x | training/utils/decode-and-evaluate.pl | 246 | ||||
| -rw-r--r-- | training/utils/entropy.cc | 41 | ||||
| -rw-r--r-- | training/utils/entropy.h | 22 | ||||
| -rw-r--r-- | training/utils/grammar_convert.cc | 348 | ||||
| -rw-r--r-- | training/utils/lbfgs.h | 1459 | ||||
| -rw-r--r-- | training/utils/lbfgs_test.cc | 117 | ||||
| -rw-r--r-- | training/utils/libcall.pl | 71 | ||||
| -rw-r--r-- | training/utils/online_optimizer.cc | 16 | ||||
| -rw-r--r-- | training/utils/online_optimizer.h | 129 | ||||
| -rw-r--r-- | training/utils/optimize.cc | 102 | ||||
| -rw-r--r-- | training/utils/optimize.h | 92 | ||||
| -rw-r--r-- | training/utils/optimize_test.cc | 118 | ||||
| -rwxr-xr-x | training/utils/parallelize.pl | 423 | ||||
| -rw-r--r-- | training/utils/risk.cc | 45 | ||||
| -rw-r--r-- | training/utils/risk.h | 26 | ||||
| -rw-r--r-- | training/utils/sentclient.c | 76 | ||||
| -rw-r--r-- | training/utils/sentserver.c | 515 | ||||
| -rw-r--r-- | training/utils/sentserver.h | 6 | 
21 files changed, 4118 insertions, 0 deletions
| diff --git a/training/utils/Makefile.am b/training/utils/Makefile.am new file mode 100644 index 00000000..d708a9f5 --- /dev/null +++ b/training/utils/Makefile.am @@ -0,0 +1,37 @@ +noinst_LIBRARIES = libtraining_utils.a + +bin_PROGRAMS = \ +  sentserver \ +  sentclient \ +  grammar_convert + +noinst_PROGRAMS = \ +  lbfgs_test \ +  optimize_test + +sentserver_SOURCES = sentserver.c +sentserver_LDFLAGS = -pthread + +sentclient_SOURCES = sentclient.c +sentclient_LDFLAGS = -pthread + +TESTS = lbfgs_test optimize_test + +libtraining_utils_a_SOURCES = \ +  candidate_set.cc \ +  entropy.cc \ +  optimize.cc \ +  online_optimizer.cc \ +  risk.cc + +optimize_test_SOURCES = optimize_test.cc +optimize_test_LDADD = libtraining_utils.a $(top_srcdir)/utils/libutils.a + +grammar_convert_SOURCES = grammar_convert.cc +grammar_convert_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a + +lbfgs_test_SOURCES = lbfgs_test.cc +lbfgs_test_LDADD = $(top_srcdir)/utils/libutils.a + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/decoder -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I$(top_srcdir)/klm + diff --git a/training/utils/candidate_set.cc b/training/utils/candidate_set.cc new file mode 100644 index 00000000..087efec3 --- /dev/null +++ b/training/utils/candidate_set.cc @@ -0,0 +1,169 @@ +#include "candidate_set.h" + +#include <tr1/unordered_set> + +#include <boost/functional/hash.hpp> + +#include "verbose.h" +#include "ns.h" +#include "filelib.h" +#include "wordid.h" +#include "tdict.h" +#include "hg.h" +#include "kbest.h" +#include "viterbi.h" + +using namespace std; + +namespace training { + +struct ApproxVectorHasher { +  static const size_t MASK = 0xFFFFFFFFull; +  union UType { +    double f;   // leave as double +    size_t i; +  }; +  static inline double round(const double x) { +    UType t; +    t.f = x; +    size_t r = t.i & MASK; +    if ((r << 1) > MASK) +      t.i += MASK - r + 1; +    else +      t.i &= (1ull - MASK); +    return t.f; +  } +  size_t operator()(const SparseVector<double>& x) const { +    size_t h = 0x573915839; +    for (SparseVector<double>::const_iterator it = x.begin(); it != x.end(); ++it) { +      UType t; +      t.f = it->second; +      if (t.f) { +        size_t z = (t.i >> 32); +        boost::hash_combine(h, it->first); +        boost::hash_combine(h, z); +      } +    } +    return h; +  } +}; + +struct ApproxVectorEquals { +  bool operator()(const SparseVector<double>& a, const SparseVector<double>& b) const { +    SparseVector<double>::const_iterator bit = b.begin(); +    for (SparseVector<double>::const_iterator ait = a.begin(); ait != a.end(); ++ait) { +      if (bit == b.end() || +          ait->first != bit->first || +          ApproxVectorHasher::round(ait->second) != ApproxVectorHasher::round(bit->second)) +        return false; +      ++bit; +    } +    if (bit != b.end()) return false; +    return true; +  } +}; + +struct CandidateCompare { +  bool operator()(const Candidate& a, const Candidate& b) const { +    ApproxVectorEquals eq; +    return (a.ewords == b.ewords && eq(a.fmap,b.fmap)); +  } +}; + +struct CandidateHasher { +  size_t operator()(const Candidate& x) const { +    boost::hash<vector<WordID> > hhasher; +    ApproxVectorHasher vhasher; +    size_t ha = hhasher(x.ewords); +    boost::hash_combine(ha, vhasher(x.fmap)); +    return ha; +  } +}; + +static void ParseSparseVector(string& line, size_t cur, SparseVector<double>* out) { +  SparseVector<double>& x = *out; +  size_t last_start = cur; +  size_t last_comma = string::npos; +  while(cur <= line.size()) { +    if (line[cur] == ' ' || cur == line.size()) { +      if (!(cur > last_start && last_comma != string::npos && cur > last_comma)) { +        cerr << "[ERROR] " << line << endl << "  position = " << cur << endl; +        exit(1); +      } +      const int fid = FD::Convert(line.substr(last_start, last_comma - last_start)); +      if (cur < line.size()) line[cur] = 0; +      const double val = strtod(&line[last_comma + 1], NULL); +      x.set_value(fid, val); + +      last_comma = string::npos; +      last_start = cur+1; +    } else { +      if (line[cur] == '=') +        last_comma = cur; +    } +    ++cur; +  } +} + +void CandidateSet::WriteToFile(const string& file) const { +  WriteFile wf(file); +  ostream& out = *wf.stream(); +  out.precision(10); +  string ss; +  for (unsigned i = 0; i < cs.size(); ++i) { +    out << TD::GetString(cs[i].ewords) << endl; +    out << cs[i].fmap << endl; +    cs[i].eval_feats.Encode(&ss); +    out << ss << endl; +  } +} + +void CandidateSet::ReadFromFile(const string& file) { +  if(!SILENT) cerr << "Reading candidates from " << file << endl; +  ReadFile rf(file); +  istream& in = *rf.stream(); +  string cand; +  string feats; +  string ss; +  while(getline(in, cand)) { +    getline(in, feats); +    getline(in, ss); +    assert(in); +    cs.push_back(Candidate()); +    TD::ConvertSentence(cand, &cs.back().ewords); +    ParseSparseVector(feats, 0, &cs.back().fmap); +    cs.back().eval_feats = SufficientStats(ss); +  } +  if(!SILENT) cerr << "  read " << cs.size() << " candidates\n"; +} + +void CandidateSet::Dedup() { +  if(!SILENT) cerr << "Dedup in=" << cs.size(); +  tr1::unordered_set<Candidate, CandidateHasher, CandidateCompare> u; +  while(cs.size() > 0) { +    u.insert(cs.back()); +    cs.pop_back(); +  } +  tr1::unordered_set<Candidate, CandidateHasher, CandidateCompare>::iterator it = u.begin(); +  while (it != u.end()) { +    cs.push_back(*it); +    it = u.erase(it); +  } +  if(!SILENT) cerr << "  out=" << cs.size() << endl; +} + +void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer) { +  KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg, kbest_size); + +  for (unsigned i = 0; i < kbest_size; ++i) { +    const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d = +      kbest.LazyKthBest(hg.nodes_.size() - 1, i); +    if (!d) break; +    cs.push_back(Candidate(d->yield, d->feature_values)); +    if (scorer) +      scorer->Evaluate(d->yield, &cs.back().eval_feats); +  } +  Dedup(); +} + +} diff --git a/training/utils/candidate_set.h b/training/utils/candidate_set.h new file mode 100644 index 00000000..9d326ed0 --- /dev/null +++ b/training/utils/candidate_set.h @@ -0,0 +1,60 @@ +#ifndef _CANDIDATE_SET_H_ +#define _CANDIDATE_SET_H_ + +#include <vector> +#include <algorithm> + +#include "ns.h" +#include "wordid.h" +#include "sparse_vector.h" + +class Hypergraph; + +namespace training { + +struct Candidate { +  Candidate() {} +  Candidate(const std::vector<WordID>& e, const SparseVector<double>& fm) : +      ewords(e), +      fmap(fm) {} +  Candidate(const std::vector<WordID>& e, +            const SparseVector<double>& fm, +            const SegmentEvaluator& se) : +      ewords(e), +      fmap(fm) { +    se.Evaluate(ewords, &eval_feats); +  } + +  void swap(Candidate& other) { +    eval_feats.swap(other.eval_feats); +    ewords.swap(other.ewords); +    fmap.swap(other.fmap); +  } + +  std::vector<WordID> ewords; +  SparseVector<double> fmap; +  SufficientStats eval_feats; +}; + +// represents some kind of collection of translation candidates, e.g. +// aggregated k-best lists, sample lists, etc. +class CandidateSet { + public: +  CandidateSet() {} +  inline size_t size() const { return cs.size(); } +  const Candidate& operator[](size_t i) const { return cs[i]; } + +  void ReadFromFile(const std::string& file); +  void WriteToFile(const std::string& file) const; +  void AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer = NULL); +  // TODO add code to do unique k-best +  // TODO add code to draw k samples + + private: +  void Dedup(); +  std::vector<Candidate> cs; +}; + +} + +#endif diff --git a/training/utils/decode-and-evaluate.pl b/training/utils/decode-and-evaluate.pl new file mode 100755 index 00000000..1a332c08 --- /dev/null +++ b/training/utils/decode-and-evaluate.pl @@ -0,0 +1,246 @@ +#!/usr/bin/env perl +use strict; +my @ORIG_ARGV=@ARGV; +use Cwd qw(getcwd); +my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR, "$SCRIPT_DIR/../../environment"; } + +# Skip local config (used for distributing jobs) if we're running in local-only mode +use LocalConfig; +use Getopt::Long; +use File::Basename qw(basename); +my $QSUB_CMD = qsub_args(mert_memory()); + +require "libcall.pl"; + +# Default settings +my $default_jobs = env_default_jobs(); +my $bin_dir = $SCRIPT_DIR; +die "Bin directory $bin_dir missing/inaccessible" unless -d $bin_dir; +my $FAST_SCORE="$bin_dir/../../mteval/fast_score"; +die "Can't execute $FAST_SCORE" unless -x $FAST_SCORE; +my $parallelize = "$bin_dir/parallelize.pl"; +my $libcall = "$bin_dir/libcall.pl"; +my $sentserver = "$bin_dir/sentserver"; +my $sentclient = "$bin_dir/sentclient"; +my $LocalConfig = "$SCRIPT_DIR/../../environment/LocalConfig.pm"; + +my $SCORER = $FAST_SCORE; +my $cdec = "$bin_dir/../../decoder/cdec"; +die "Can't find decoder in $cdec" unless -x $cdec; +die "Can't find $parallelize" unless -x $parallelize; +die "Can't find $libcall" unless -e $libcall; +my $decoder = $cdec; +my $jobs = $default_jobs;   # number of decode nodes +my $pmem = "9g"; +my $help = 0; +my $config; +my $test_set; +my $weights; +my $use_make = 1; +my $useqsub; +my $cpbin=1; +# Process command-line options +if (GetOptions( +	"jobs=i" => \$jobs, +	"help" => \$help, +	"qsub" => \$useqsub, +	"input=s" => \$test_set, +        "config=s" => \$config, +	"weights=s" => \$weights, +) == 0 || @ARGV!=0 || $help) { +	print_help(); +	exit; +} + +if ($useqsub) { +  $use_make = 0; +  die "LocalEnvironment.pm does not have qsub configuration for this host. Cannot run with --qsub!\n" unless has_qsub(); +} + +my @missing_args = (); + +if (!defined $test_set) { push @missing_args, "--input"; } +if (!defined $config) { push @missing_args, "--config"; } +if (!defined $weights) { push @missing_args, "--weights"; } +die "Please specify missing arguments: " . join (', ', @missing_args) . "\nUse --help for more information.\n" if (@missing_args); + +my @tf = localtime(time); +my $tname = basename($test_set); +$tname =~ s/\.(sgm|sgml|xml)$//i; +my $dir = "eval.$tname." . sprintf('%d%02d%02d-%02d%02d%02d', 1900+$tf[5], $tf[4], $tf[3], $tf[2], $tf[1], $tf[0]); + +my $time = unchecked_output("date"); + +check_call("mkdir -p $dir"); + +split_devset($test_set, "$dir/test.input.raw", "$dir/test.refs"); +my $refs = "-r $dir/test.refs"; +my $newsrc = "$dir/test.input"; +enseg("$dir/test.input.raw", $newsrc); +my $src_file = $newsrc; +open F, "<$src_file" or die "Can't read $src_file: $!"; close F; + +my $test_trans="$dir/test.trans"; +my $logdir="$dir/logs"; +my $decoderLog="$logdir/decoder.sentserver.log"; +check_call("mkdir -p $logdir"); + +#decode +print STDERR "RUNNING DECODER AT "; +print STDERR unchecked_output("date"); +my $decoder_cmd = "$decoder -c $config --weights $weights"; +my $pcmd; +if ($use_make) { +	$pcmd = "cat $src_file | $parallelize --workdir $dir --use-fork -p $pmem -e $logdir -j $jobs --"; +} else { +	$pcmd = "cat $src_file | $parallelize --workdir $dir -p $pmem -e $logdir -j $jobs --"; +} +my $cmd = "$pcmd $decoder_cmd 2> $decoderLog 1> $test_trans"; +check_bash_call($cmd); +print STDERR "DECODER COMPLETED AT "; +print STDERR unchecked_output("date"); +print STDERR "\nOUTPUT: $test_trans\n\n"; +my $bleu = check_output("cat $test_trans | $SCORER $refs -m ibm_bleu"); +chomp $bleu; +print STDERR "BLEU: $bleu\n"; +my $ter = check_output("cat $test_trans | $SCORER $refs -m ter"); +chomp $ter; +print STDERR " TER: $ter\n"; +open TR, ">$dir/test.scores" or die "Can't write $dir/test.scores: $!"; +print TR <<EOT; +### SCORE REPORT ############################################################# +        OUTPUT=$test_trans +  SCRIPT INPUT=$test_set + DECODER INPUT=$src_file +    REFERENCES=$dir/test.refs +------------------------------------------------------------------------------ +          BLEU=$bleu +           TER=$ter +############################################################################## +EOT +close TR; +my $sr = unchecked_output("cat $dir/test.scores"); +print STDERR "\n\n$sr\n(A copy of this report can be found in $dir/test.scores)\n\n"; +exit 0; + +sub enseg { +	my $src = shift; +	my $newsrc = shift; +	open(SRC, $src); +	open(NEWSRC, ">$newsrc"); +	my $i=0; +	while (my $line=<SRC>){ +		chomp $line; +		if ($line =~ /^\s*<seg/i) { +		    if($line =~ /id="[0-9]+"/) { +			print NEWSRC "$line\n"; +		    } else { +			die "When using segments with pre-generated <seg> tags, you must include a zero-based id attribute"; +		    } +		} else { +			print NEWSRC "<seg id=\"$i\">$line</seg>\n"; +		} +		$i++; +	} +	close SRC; +	close NEWSRC; +} + +sub print_help { +	my $executable = basename($0); chomp $executable; +	print << "Help"; + +Usage: $executable [options] <ini file> + +	$executable --config cdec.ini --weights weights.txt [--jobs N] [--qsub] <testset.in-ref> + +Options: + +	--help +		Print this message and exit. + +	--config <file> +		A path to the cdec.ini file. + +	--weights <file> +		A file specifying feature weights. + +	--dir <dir> +		Directory for intermediate and output files. + +Job control options: + +	--jobs <I> +		Number of decoder processes to run in parallel. [default=$default_jobs] + +	--qsub +		Use qsub to run jobs in parallel (qsub must be configured in +		environment/LocalEnvironment.pm) + +	--pmem <N> +		Amount of physical memory requested for parallel decoding jobs +		(used with qsub requests only) + +Help +} + +sub convert { +  my ($str) = @_; +  my @ps = split /;/, $str; +  my %dict = (); +  for my $p (@ps) { +    my ($k, $v) = split /=/, $p; +    $dict{$k} = $v; +  } +  return %dict; +} + + + +sub cmdline { +    return join ' ',($0,@ORIG_ARGV); +} + +#buggy: last arg gets quoted sometimes? +my $is_shell_special=qr{[ \t\n\\><|&;"'`~*?{}$!()]}; +my $shell_escape_in_quote=qr{[\\"\$`!]}; + +sub escape_shell { +    my ($arg)=@_; +    return undef unless defined $arg; +    if ($arg =~ /$is_shell_special/) { +        $arg =~ s/($shell_escape_in_quote)/\\$1/g; +        return "\"$arg\""; +    } +    return $arg; +} + +sub escaped_shell_args { +    return map {local $_=$_;chomp;escape_shell($_)} @_; +} + +sub escaped_shell_args_str { +    return join ' ',&escaped_shell_args(@_); +} + +sub escaped_cmdline { +    return "$0 ".&escaped_shell_args_str(@ORIG_ARGV); +} + +sub split_devset { +  my ($infile, $outsrc, $outref) = @_; +  open F, "<$infile" or die "Can't read $infile: $!"; +  open S, ">$outsrc" or die "Can't write $outsrc: $!"; +  open R, ">$outref" or die "Can't write $outref: $!"; +  while(<F>) { +    chomp; +    my ($src, @refs) = split /\s*\|\|\|\s*/; +    die "Malformed devset line: $_\n" unless scalar @refs > 0; +    print S "$src\n"; +    print R join(' ||| ', @refs) . "\n"; +  } +  close R; +  close S; +  close F; +} + diff --git a/training/utils/entropy.cc b/training/utils/entropy.cc new file mode 100644 index 00000000..4fdbe2be --- /dev/null +++ b/training/utils/entropy.cc @@ -0,0 +1,41 @@ +#include "entropy.h" + +#include "prob.h" +#include "candidate_set.h" + +using namespace std; + +namespace training { + +// see Mann and McCallum "Efficient Computation of Entropy Gradient ..." for +// a mostly clear derivation of: +//   g = E[ F(x,y) * log p(y|x) ] + H(y | x) * E[ F(x,y) ] +double CandidateSetEntropy::operator()(const vector<double>& params, +                                       SparseVector<double>* g) const { +  prob_t z; +  vector<double> dps(cands_.size()); +  for (unsigned i = 0; i < cands_.size(); ++i) { +    dps[i] = cands_[i].fmap.dot(params); +    const prob_t u(dps[i], init_lnx()); +    z += u; +  } +  const double log_z = log(z); + +  SparseVector<double> exp_feats; +  double entropy = 0; +  for (unsigned i = 0; i < cands_.size(); ++i) { +    const double log_prob = cands_[i].fmap.dot(params) - log_z; +    const double prob = exp(log_prob); +    const double e_logprob = prob * log_prob; +    entropy -= e_logprob; +    if (g) { +      (*g) += cands_[i].fmap * e_logprob; +      exp_feats += cands_[i].fmap * prob; +    } +  } +  if (g) (*g) += exp_feats * entropy; +  return entropy; +} + +} + diff --git a/training/utils/entropy.h b/training/utils/entropy.h new file mode 100644 index 00000000..796589ca --- /dev/null +++ b/training/utils/entropy.h @@ -0,0 +1,22 @@ +#ifndef _CSENTROPY_H_ +#define _CSENTROPY_H_ + +#include <vector> +#include "sparse_vector.h" + +namespace training { +  class CandidateSet; + +  class CandidateSetEntropy { +   public: +    explicit CandidateSetEntropy(const CandidateSet& cs) : cands_(cs) {} +    // compute the entropy (expected log likelihood) of a CandidateSet +    // (optional) the gradient of the entropy with respect to params +    double operator()(const std::vector<double>& params, +                      SparseVector<double>* g = NULL) const; +   private: +    const CandidateSet& cands_; +  }; +}; + +#endif diff --git a/training/utils/grammar_convert.cc b/training/utils/grammar_convert.cc new file mode 100644 index 00000000..607a7cb9 --- /dev/null +++ b/training/utils/grammar_convert.cc @@ -0,0 +1,348 @@ +/* +  this program modifies cfg hypergraphs (forests) and extracts kbests? +  what are: json, split ? + */ +#include <iostream> +#include <algorithm> +#include <sstream> + +#include <boost/lexical_cast.hpp> +#include <boost/program_options.hpp> + +#include "inside_outside.h" +#include "tdict.h" +#include "filelib.h" +#include "hg.h" +#include "hg_io.h" +#include "kbest.h" +#include "viterbi.h" +#include "weights.h" + +namespace po = boost::program_options; +using namespace std; + +WordID kSTART; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { +  po::options_description opts("Configuration options"); +  opts.add_options() +        ("input,i", po::value<string>()->default_value("-"), "Input file") +        ("format,f", po::value<string>()->default_value("cfg"), "Input format. Values: cfg, json, split") +        ("output,o", po::value<string>()->default_value("json"), "Output command. Values: json, 1best") +        ("reorder,r", "Add Yamada & Knight (2002) reorderings") +        ("weights,w", po::value<string>(), "Feature weights for k-best derivations [optional]") +        ("collapse_weights,C", "Collapse order features into a single feature whose value is all of the locally applying feature weights") +        ("k_derivations,k", po::value<int>(), "Show k derivations and their features") +        ("max_reorder,m", po::value<int>()->default_value(999), "Move a constituent at most this far") +        ("help,h", "Print this help message and exit"); +  po::options_description clo("Command line options"); +  po::options_description dcmdline_options; +  dcmdline_options.add(opts); + +  po::store(parse_command_line(argc, argv, dcmdline_options), *conf); +  po::notify(*conf); + +  if (conf->count("help") || conf->count("input") == 0) { +    cerr << "\nUsage: grammar_convert [-options]\n\nConverts a grammar file (in Hiero format) into JSON hypergraph.\n"; +    cerr << dcmdline_options << endl; +    exit(1); +  } +} + +int GetOrCreateNode(const WordID& lhs, map<WordID, int>* lhs2node, Hypergraph* hg) { +  int& node_id = (*lhs2node)[lhs]; +  if (!node_id) +    node_id = hg->AddNode(lhs)->id_ + 1; +  return node_id - 1; +} + +void FilterAndCheckCorrectness(int goal, Hypergraph* hg) { +  if (goal < 0) { +    cerr << "Error! [S] not found in grammar!\n"; +    exit(1); +  } +  if (hg->nodes_[goal].in_edges_.size() != 1) { +    cerr << "Error! [S] has more than one rewrite!\n"; +    exit(1); +  } +  int old_size = hg->nodes_.size(); +  hg->TopologicallySortNodesAndEdges(goal); +  if (hg->nodes_.size() != old_size) { +    cerr << "Warning! During sorting " << (old_size - hg->nodes_.size()) << " disappeared!\n"; +  } +  vector<double> inside; // inside score at each node +  double p = Inside<double, TransitionCountWeightFunction>(*hg, &inside); +  if (!p) { +    cerr << "Warning! Grammar defines the empty language!\n"; +    hg->clear(); +    return; +  } +  vector<bool> prune(hg->edges_.size(), false); +  int bad_edges = 0; +  for (unsigned i = 0; i < hg->edges_.size(); ++i) { +    Hypergraph::Edge& edge = hg->edges_[i]; +    bool bad = false; +    for (unsigned j = 0; j < edge.tail_nodes_.size(); ++j) { +      if (!inside[edge.tail_nodes_[j]]) { +        bad = true; +        ++bad_edges; +      } +    } +    prune[i] = bad; +  } +  cerr << "Removing " << bad_edges << " bad edges from the grammar.\n"; +  for (unsigned i = 0; i < hg->edges_.size(); ++i) { +    if (prune[i]) +      cerr << "   " << hg->edges_[i].rule_->AsString() << endl; +  } +  hg->PruneEdges(prune); +} + +void CreateEdge(const TRulePtr& r, const Hypergraph::TailNodeVector& tail, Hypergraph::Node* head_node, Hypergraph* hg) { +  Hypergraph::Edge* new_edge = hg->AddEdge(r, tail); +  hg->ConnectEdgeToHeadNode(new_edge, head_node); +  new_edge->feature_values_ = r->scores_; +} + +// from a category label like "NP_2", return "NP" +string PureCategory(WordID cat) { +  assert(cat < 0); +  string c = TD::Convert(cat*-1); +  size_t p = c.find("_"); +  if (p == string::npos) return c; +  return c.substr(0, p); +}; + +string ConstituentOrderFeature(const TRule& rule, const vector<int>& pi) { +  const static string kTERM_VAR = "x"; +  const vector<WordID>& f = rule.f(); +  map<string, int> used; +  vector<string> terms(f.size()); +  for (int i = 0; i < f.size(); ++i) { +    const string term = (f[i] < 0 ? PureCategory(f[i]) : kTERM_VAR); +    int& count = used[term]; +    if (!count) { +      terms[i] = term; +    } else { +      ostringstream os; +      os << term << count; +      terms[i] = os.str(); +    } +    ++count; +  } +  ostringstream os; +  os << PureCategory(rule.GetLHS()) << ':'; +  for (int i = 0; i < f.size(); ++i) { +    if (i > 0) os << '_'; +    os << terms[pi[i]]; +  } +  return os.str(); +} + +bool CheckPermutationMask(const vector<int>& mask, const vector<int>& pi) { +  assert(mask.size() == pi.size()); + +  int req_min = -1; +  int cur_max = 0; +  int cur_mask = -1; +  for (int i = 0; i < mask.size(); ++i) { +    if (mask[i] != cur_mask) { +      cur_mask = mask[i]; +      req_min = cur_max - 1; +    } +    if (pi[i] > req_min) { +      if (pi[i] > cur_max) cur_max = pi[i]; +    } else { +      return false; +    } +  } + +  return true; +} + +void PermuteYKRecursive(int nodeid, const WordID& parent, const int max_reorder, Hypergraph* hg) { +  // Hypergraph tmp = *hg; +  Hypergraph::Node* node = &hg->nodes_[nodeid]; +  if (node->in_edges_.size() != 1) { +    cerr << "Multiple rewrites of [" << TD::Convert(node->cat_ * -1) << "] (parent is [" << TD::Convert(parent*-1) << "])\n"; +    cerr << "  not recursing!\n"; +    return; +  } +//  for (int eii = 0; eii < node->in_edges_.size(); ++eii) { +    const int oe_index = node->in_edges_.front(); +    const TRule& rule = *hg->edges_[oe_index].rule_; +    const Hypergraph::TailNodeVector orig_tail = hg->edges_[oe_index].tail_nodes_; +    const int tail_size = orig_tail.size(); +    for (int i = 0; i < tail_size; ++i) { +      PermuteYKRecursive(hg->edges_[oe_index].tail_nodes_[i], node->cat_, max_reorder, hg); +    } +    const vector<WordID>& of = rule.f_; +    if (of.size() == 1) return; +  //  cerr << "Permuting [" << TD::Convert(node->cat_ * -1) << "]\n"; +  //  cerr << "ORIG: " << rule.AsString() << endl; +    vector<WordID> pi(of.size(), 0); +    for (int i = 0; i < pi.size(); ++i) pi[i] = i; + +    vector<int> permutation_mask(of.size(), 0); +    const bool dont_reorder_across_PU = true;  // TODO add configuration +    if (dont_reorder_across_PU) { +      int cur = 0; +      for (int i = 0; i < pi.size(); ++i) { +        if (of[i] >= 0) continue; +        const string cat = PureCategory(of[i]); +        if (cat == "PU" || cat == "PU!H" || cat == "PUNC" || cat == "PUNC!H" || cat == "CC") { +          ++cur; +          permutation_mask[i] = cur; +          ++cur; +        } else { +          permutation_mask[i] = cur; +        } +      } +    } +    int fid = FD::Convert(ConstituentOrderFeature(rule, pi)); +    hg->edges_[oe_index].feature_values_.set_value(fid, 1.0); +    while (next_permutation(pi.begin(), pi.end())) { +      if (!CheckPermutationMask(permutation_mask, pi)) +        continue; +      vector<WordID> nf(pi.size(), 0); +      Hypergraph::TailNodeVector tail(pi.size(), 0); +      bool skip = false; +      for (int i = 0; i < pi.size(); ++i) { +        int dist = pi[i] - i; if (dist < 0) dist *= -1; +        if (dist > max_reorder) { skip = true; break; } +        nf[i] = of[pi[i]]; +        tail[i] = orig_tail[pi[i]]; +      } +      if (skip) continue; +      TRulePtr nr(new TRule(rule)); +      nr->f_ = nf; +      int fid = FD::Convert(ConstituentOrderFeature(rule, pi)); +      nr->scores_.set_value(fid, 1.0); +  //    cerr << "PERM: " << nr->AsString() << endl; +      CreateEdge(nr, tail, node, hg); +    } + // } +} + +void PermuteYamadaAndKnight(Hypergraph* hg, int max_reorder) { +  assert(hg->nodes_.back().cat_ == kSTART); +  assert(hg->nodes_.back().in_edges_.size() == 1); +  PermuteYKRecursive(hg->nodes_.size() - 1, kSTART, max_reorder, hg); +} + +void CollapseWeights(Hypergraph* hg) { +  int fid = FD::Convert("Reordering"); +  for (int i = 0; i < hg->edges_.size(); ++i) { +    Hypergraph::Edge& edge = hg->edges_[i]; +    edge.feature_values_.clear(); +    if (edge.edge_prob_ != prob_t::Zero()) { +      edge.feature_values_.set_value(fid, log(edge.edge_prob_)); +    } +  } +} + +void ProcessHypergraph(const vector<double>& w, const po::variables_map& conf, const string& ref, Hypergraph* hg) { +  if (conf.count("reorder")) +    PermuteYamadaAndKnight(hg, conf["max_reorder"].as<int>()); +  if (w.size() > 0) { hg->Reweight(w); } +  if (conf.count("collapse_weights")) CollapseWeights(hg); +  if (conf["output"].as<string>() == "json") { +    HypergraphIO::WriteToJSON(*hg, false, &cout); +    if (!ref.empty()) { cerr << "REF: " << ref << endl; } +  } else { +    vector<WordID> onebest; +    ViterbiESentence(*hg, &onebest); +    if (ref.empty()) { +      cout << TD::GetString(onebest) << endl; +    } else { +      cout << TD::GetString(onebest) << " ||| " << ref << endl; +    } +  } +  if (conf.count("k_derivations")) { +    const int k = conf["k_derivations"].as<int>(); +    KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(*hg, k); +    for (int i = 0; i < k; ++i) { +      const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d = +        kbest.LazyKthBest(hg->nodes_.size() - 1, i); +      if (!d) break; +      cerr << log(d->score) << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl; +    } +  } +} + +int main(int argc, char **argv) { +  kSTART = TD::Convert("S") * -1; +  po::variables_map conf; +  InitCommandLine(argc, argv, &conf); +  string infile = conf["input"].as<string>(); +  const bool is_split_input = (conf["format"].as<string>() == "split"); +  const bool is_json_input = is_split_input || (conf["format"].as<string>() == "json"); +  const bool collapse_weights = conf.count("collapse_weights"); +  vector<double> w; +  if (conf.count("weights")) +    Weights::InitFromFile(conf["weights"].as<string>(), &w); + +  if (collapse_weights && !w.size()) { +    cerr << "--collapse_weights requires a weights file to be specified!\n"; +    exit(1); +  } +  ReadFile rf(infile); +  istream* in = rf.stream(); +  assert(*in); +  int lc = 0; +  Hypergraph hg; +  map<WordID, int> lhs2node; +  while(*in) { +    string line; +    ++lc; +    getline(*in, line); +    if (is_json_input) { +      if (line.empty() || line[0] == '#') continue; +      string ref; +      if (is_split_input) { +        size_t pos = line.rfind("}}"); +        assert(pos != string::npos); +        size_t rstart = line.find("||| ", pos); +        assert(rstart != string::npos); +        ref = line.substr(rstart + 4); +        line = line.substr(0, pos + 2); +      } +      istringstream is(line); +      if (HypergraphIO::ReadFromJSON(&is, &hg)) { +        ProcessHypergraph(w, conf, ref, &hg); +        hg.clear(); +      } else { +        cerr << "Error reading grammar from JSON: line " << lc << endl; +        exit(1); +      } +    } else { +      if (line.empty()) { +        int goal = lhs2node[kSTART] - 1; +        FilterAndCheckCorrectness(goal, &hg); +        ProcessHypergraph(w, conf, "", &hg); +        hg.clear(); +        lhs2node.clear(); +        continue; +      } +      if (line[0] == '#') continue; +      if (line[0] != '[') { +        cerr << "Line " << lc << ": bad format\n"; +        exit(1); +      } +      TRulePtr tr(TRule::CreateRuleMonolingual(line)); +      Hypergraph::TailNodeVector tail; +      for (int i = 0; i < tr->f_.size(); ++i) { +        WordID var_cat = tr->f_[i]; +        if (var_cat < 0) +          tail.push_back(GetOrCreateNode(var_cat, &lhs2node, &hg)); +      } +      const WordID lhs = tr->GetLHS(); +      int head = GetOrCreateNode(lhs, &lhs2node, &hg); +      Hypergraph::Edge* edge = hg.AddEdge(tr, tail); +      edge->feature_values_ = tr->scores_; +      Hypergraph::Node* node = &hg.nodes_[head]; +      hg.ConnectEdgeToHeadNode(edge, node); +    } +  } +} + diff --git a/training/utils/lbfgs.h b/training/utils/lbfgs.h new file mode 100644 index 00000000..e8baecab --- /dev/null +++ b/training/utils/lbfgs.h @@ -0,0 +1,1459 @@ +#ifndef SCITBX_LBFGS_H +#define SCITBX_LBFGS_H + +#include <cstdio> +#include <cstddef> +#include <cmath> +#include <stdexcept> +#include <algorithm> +#include <vector> +#include <string> +#include <iostream> +#include <sstream> + +namespace scitbx { + +//! Limited-memory Broyden-Fletcher-Goldfarb-Shanno (LBFGS) %minimizer. +/*! Implementation of the +    Limited-memory Broyden-Fletcher-Goldfarb-Shanno (LBFGS) +    algorithm for large-scale multidimensional minimization +    problems. + +    This code was manually derived from Java code which was +    in turn derived from the Fortran program +    <code>lbfgs.f</code>.  The Java translation was +    effected mostly mechanically, with some manual +    clean-up; in particular, array indices start at 0 +    instead of 1.  Most of the comments from the Fortran +    code have been pasted in. + +    Information on the original LBFGS Fortran source code is +    available at +    http://www.netlib.org/opt/lbfgs_um.shar . The following +    information is taken verbatim from the Netlib documentation +    for the Fortran source. + +    <pre> +    file    opt/lbfgs_um.shar +    for     unconstrained optimization problems +    alg     limited memory BFGS method +    by      J. Nocedal +    contact nocedal@eecs.nwu.edu +    ref     D. C. Liu and J. Nocedal, ``On the limited memory BFGS method for +    ,       large scale optimization methods'' Mathematical Programming 45 +    ,       (1989), pp. 503-528. +    ,       (Postscript file of this paper is available via anonymous ftp +    ,       to eecs.nwu.edu in the directory pub/%lbfgs/lbfgs_um.) +    </pre> + +    @author Jorge Nocedal: original Fortran version, including comments +    (July 1990).<br> +    Robert Dodier: Java translation, August 1997.<br> +    Ralf W. Grosse-Kunstleve: C++ port, March 2002.<br> +    Chris Dyer: serialize/deserialize functionality + */ +namespace lbfgs { + +  //! Generic exception class for %lbfgs %error messages. +  /*! All exceptions thrown by the minimizer are derived from this class. +   */ +  class error : public std::exception { +    public: +      //! Constructor. +      error(std::string const& msg) throw() +        : msg_("lbfgs error: " + msg) +      {} +      //! Access to error message. +      virtual const char* what() const throw() { return msg_.c_str(); } +    protected: +      virtual ~error() throw() {} +      std::string msg_; +    public: +      static std::string itoa(unsigned long i) { +        std::ostringstream os; +        os << i; +        return os.str(); +      } +  }; + +  //! Specific exception class. +  class error_internal_error : public error { +    public: +      //! Constructor. +      error_internal_error(const char* file, unsigned long line) throw() +        : error( +            "Internal Error: " + std::string(file) + "(" + itoa(line) + ")") +      {} +  }; + +  //! Specific exception class. +  class error_improper_input_parameter : public error { +    public: +      //! Constructor. +      error_improper_input_parameter(std::string const& msg) throw() +        : error("Improper input parameter: " + msg) +      {} +  }; + +  //! Specific exception class. +  class error_improper_input_data : public error { +    public: +      //! Constructor. +      error_improper_input_data(std::string const& msg) throw() +        : error("Improper input data: " + msg) +      {} +  }; + +  //! Specific exception class. +  class error_search_direction_not_descent : public error { +    public: +      //! Constructor. +      error_search_direction_not_descent() throw() +        : error("The search direction is not a descent direction.") +      {} +  }; + +  //! Specific exception class. +  class error_line_search_failed : public error { +    public: +      //! Constructor. +      error_line_search_failed(std::string const& msg) throw() +        : error("Line search failed: " + msg) +      {} +  }; + +  //! Specific exception class. +  class error_line_search_failed_rounding_errors +  : public error_line_search_failed { +    public: +      //! Constructor. +      error_line_search_failed_rounding_errors(std::string const& msg) throw() +        : error_line_search_failed(msg) +      {} +  }; + +  namespace detail { + +    template <typename NumType> +    inline +    NumType +    pow2(NumType const& x) { return x * x; } + +    template <typename NumType> +    inline +    NumType +    abs(NumType const& x) { +      if (x < NumType(0)) return -x; +      return x; +    } + +    // This class implements an algorithm for multi-dimensional line search. +    template <typename FloatType, typename SizeType = std::size_t> +    class mcsrch +    { +      protected: +        int infoc; +        FloatType dginit; +        bool brackt; +        bool stage1; +        FloatType finit; +        FloatType dgtest; +        FloatType width; +        FloatType width1; +        FloatType stx; +        FloatType fx; +        FloatType dgx; +        FloatType sty; +        FloatType fy; +        FloatType dgy; +        FloatType stmin; +        FloatType stmax; + +        static FloatType const& max3( +          FloatType const& x, +          FloatType const& y, +          FloatType const& z) +        { +          return x < y ? (y < z ? z : y ) : (x < z ? z : x ); +        } + +      public: +        /* Minimize a function along a search direction. This code is +           a Java translation of the function <code>MCSRCH</code> from +           <code>lbfgs.f</code>, which in turn is a slight modification +           of the subroutine <code>CSRCH</code> of More' and Thuente. +           The changes are to allow reverse communication, and do not +           affect the performance of the routine. This function, in turn, +           calls <code>mcstep</code>.<p> + +           The Java translation was effected mostly mechanically, with +           some manual clean-up; in particular, array indices start at 0 +           instead of 1.  Most of the comments from the Fortran code have +           been pasted in here as well.<p> + +           The purpose of <code>mcsrch</code> is to find a step which +           satisfies a sufficient decrease condition and a curvature +           condition.<p> + +           At each stage this function updates an interval of uncertainty +           with endpoints <code>stx</code> and <code>sty</code>. The +           interval of uncertainty is initially chosen so that it +           contains a minimizer of the modified function +           <pre> +                f(x+stp*s) - f(x) - ftol*stp*(gradf(x)'s). +           </pre> +           If a step is obtained for which the modified function has a +           nonpositive function value and nonnegative derivative, then +           the interval of uncertainty is chosen so that it contains a +           minimizer of <code>f(x+stp*s)</code>.<p> + +           The algorithm is designed to find a step which satisfies +           the sufficient decrease condition +           <pre> +                 f(x+stp*s) <= f(X) + ftol*stp*(gradf(x)'s), +           </pre> +           and the curvature condition +           <pre> +                 abs(gradf(x+stp*s)'s)) <= gtol*abs(gradf(x)'s). +           </pre> +           If <code>ftol</code> is less than <code>gtol</code> and if, +           for example, the function is bounded below, then there is +           always a step which satisfies both conditions. If no step can +           be found which satisfies both conditions, then the algorithm +           usually stops when rounding errors prevent further progress. +           In this case <code>stp</code> only satisfies the sufficient +           decrease condition.<p> + +           @author Original Fortran version by Jorge J. More' and +             David J. Thuente as part of the Minpack project, June 1983, +             Argonne National Laboratory. Java translation by Robert +             Dodier, August 1997. + +           @param n The number of variables. + +           @param x On entry this contains the base point for the line +             search. On exit it contains <code>x + stp*s</code>. + +           @param f On entry this contains the value of the objective +             function at <code>x</code>. On exit it contains the value +             of the objective function at <code>x + stp*s</code>. + +           @param g On entry this contains the gradient of the objective +             function at <code>x</code>. On exit it contains the gradient +             at <code>x + stp*s</code>. + +           @param s The search direction. + +           @param stp On entry this contains an initial estimate of a +             satifactory step length. On exit <code>stp</code> contains +             the final estimate. + +           @param ftol Tolerance for the sufficient decrease condition. + +           @param xtol Termination occurs when the relative width of the +             interval of uncertainty is at most <code>xtol</code>. + +           @param maxfev Termination occurs when the number of evaluations +             of the objective function is at least <code>maxfev</code> by +             the end of an iteration. + +           @param info This is an output variable, which can have these +             values: +             <ul> +             <li><code>info = -1</code> A return is made to compute +                 the function and gradient. +             <li><code>info = 1</code> The sufficient decrease condition +                 and the directional derivative condition hold. +             </ul> + +           @param nfev On exit, this is set to the number of function +             evaluations. + +           @param wa Temporary storage array, of length <code>n</code>. +         */ +        void run( +          FloatType const& gtol, +          FloatType const& stpmin, +          FloatType const& stpmax, +          SizeType n, +          FloatType* x, +          FloatType f, +          const FloatType* g, +          FloatType* s, +          SizeType is0, +          FloatType& stp, +          FloatType ftol, +          FloatType xtol, +          SizeType maxfev, +          int& info, +          SizeType& nfev, +          FloatType* wa); + +        /* The purpose of this function is to compute a safeguarded step +           for a linesearch and to update an interval of uncertainty for +           a minimizer of the function.<p> + +           The parameter <code>stx</code> contains the step with the +           least function value. The parameter <code>stp</code> contains +           the current step. It is assumed that the derivative at +           <code>stx</code> is negative in the direction of the step. If +           <code>brackt</code> is <code>true</code> when +           <code>mcstep</code> returns then a minimizer has been +           bracketed in an interval of uncertainty with endpoints +           <code>stx</code> and <code>sty</code>.<p> + +           Variables that must be modified by <code>mcstep</code> are +           implemented as 1-element arrays. + +           @param stx Step at the best step obtained so far. +             This variable is modified by <code>mcstep</code>. +           @param fx Function value at the best step obtained so far. +             This variable is modified by <code>mcstep</code>. +           @param dx Derivative at the best step obtained so far. +             The derivative must be negative in the direction of the +             step, that is, <code>dx</code> and <code>stp-stx</code> must +             have opposite signs.  This variable is modified by +             <code>mcstep</code>. + +           @param sty Step at the other endpoint of the interval of +             uncertainty. This variable is modified by <code>mcstep</code>. +           @param fy Function value at the other endpoint of the interval +             of uncertainty. This variable is modified by +             <code>mcstep</code>. + +           @param dy Derivative at the other endpoint of the interval of +             uncertainty. This variable is modified by <code>mcstep</code>. + +           @param stp Step at the current step. If <code>brackt</code> is set +             then on input <code>stp</code> must be between <code>stx</code> +             and <code>sty</code>. On output <code>stp</code> is set to the +             new step. +           @param fp Function value at the current step. +           @param dp Derivative at the current step. + +           @param brackt Tells whether a minimizer has been bracketed. +             If the minimizer has not been bracketed, then on input this +             variable must be set <code>false</code>. If the minimizer has +             been bracketed, then on output this variable is +             <code>true</code>. + +           @param stpmin Lower bound for the step. +           @param stpmax Upper bound for the step. + +           If the return value is 1, 2, 3, or 4, then the step has +           been computed successfully. A return value of 0 indicates +           improper input parameters. + +           @author Jorge J. More, David J. Thuente: original Fortran version, +             as part of Minpack project. Argonne Nat'l Laboratory, June 1983. +             Robert Dodier: Java translation, August 1997. +         */ +        static int mcstep( +          FloatType& stx, +          FloatType& fx, +          FloatType& dx, +          FloatType& sty, +          FloatType& fy, +          FloatType& dy, +          FloatType& stp, +          FloatType fp, +          FloatType dp, +          bool& brackt, +          FloatType stpmin, +          FloatType stpmax); + +        void serialize(std::ostream* out) const { +          out->write((const char*)&infoc,sizeof(infoc)); +          out->write((const char*)&dginit,sizeof(dginit)); +          out->write((const char*)&brackt,sizeof(brackt)); +          out->write((const char*)&stage1,sizeof(stage1)); +          out->write((const char*)&finit,sizeof(finit)); +          out->write((const char*)&dgtest,sizeof(dgtest)); +          out->write((const char*)&width,sizeof(width)); +          out->write((const char*)&width1,sizeof(width1)); +          out->write((const char*)&stx,sizeof(stx)); +          out->write((const char*)&fx,sizeof(fx)); +          out->write((const char*)&dgx,sizeof(dgx)); +          out->write((const char*)&sty,sizeof(sty)); +          out->write((const char*)&fy,sizeof(fy)); +          out->write((const char*)&dgy,sizeof(dgy)); +          out->write((const char*)&stmin,sizeof(stmin)); +          out->write((const char*)&stmax,sizeof(stmax)); +        } + +        void deserialize(std::istream* in) const { +          in->read((char*)&infoc, sizeof(infoc)); +          in->read((char*)&dginit, sizeof(dginit)); +          in->read((char*)&brackt, sizeof(brackt)); +          in->read((char*)&stage1, sizeof(stage1)); +          in->read((char*)&finit, sizeof(finit)); +          in->read((char*)&dgtest, sizeof(dgtest)); +          in->read((char*)&width, sizeof(width)); +          in->read((char*)&width1, sizeof(width1)); +          in->read((char*)&stx, sizeof(stx)); +          in->read((char*)&fx, sizeof(fx)); +          in->read((char*)&dgx, sizeof(dgx)); +          in->read((char*)&sty, sizeof(sty)); +          in->read((char*)&fy, sizeof(fy)); +          in->read((char*)&dgy, sizeof(dgy)); +          in->read((char*)&stmin, sizeof(stmin)); +          in->read((char*)&stmax, sizeof(stmax)); +        } +    }; + +    template <typename FloatType, typename SizeType> +    void mcsrch<FloatType, SizeType>::run( +      FloatType const& gtol, +      FloatType const& stpmin, +      FloatType const& stpmax, +      SizeType n, +      FloatType* x, +      FloatType f, +      const FloatType* g, +      FloatType* s, +      SizeType is0, +      FloatType& stp, +      FloatType ftol, +      FloatType xtol, +      SizeType maxfev, +      int& info, +      SizeType& nfev, +      FloatType* wa) +    { +      if (info != -1) { +        infoc = 1; +        if (   n == 0 +            || maxfev == 0 +            || gtol < FloatType(0) +            || xtol < FloatType(0) +            || stpmin < FloatType(0) +            || stpmax < stpmin) { +          throw error_internal_error(__FILE__, __LINE__); +        } +        if (stp <= FloatType(0) || ftol < FloatType(0)) { +          throw error_internal_error(__FILE__, __LINE__); +        } +        // Compute the initial gradient in the search direction +        // and check that s is a descent direction. +        dginit = FloatType(0); +        for (SizeType j = 0; j < n; j++) { +          dginit += g[j] * s[is0+j]; +        } +        if (dginit >= FloatType(0)) { +          throw error_search_direction_not_descent(); +        } +        brackt = false; +        stage1 = true; +        nfev = 0; +        finit = f; +        dgtest = ftol*dginit; +        width = stpmax - stpmin; +        width1 = FloatType(2) * width; +        std::copy(x, x+n, wa); +        // The variables stx, fx, dgx contain the values of the step, +        // function, and directional derivative at the best step. +        // The variables sty, fy, dgy contain the value of the step, +        // function, and derivative at the other endpoint of +        // the interval of uncertainty. +        // The variables stp, f, dg contain the values of the step, +        // function, and derivative at the current step. +        stx = FloatType(0); +        fx = finit; +        dgx = dginit; +        sty = FloatType(0); +        fy = finit; +        dgy = dginit; +      } +      for (;;) { +        if (info != -1) { +          // Set the minimum and maximum steps to correspond +          // to the present interval of uncertainty. +          if (brackt) { +            stmin = std::min(stx, sty); +            stmax = std::max(stx, sty); +          } +          else { +            stmin = stx; +            stmax = stp + FloatType(4) * (stp - stx); +          } +          // Force the step to be within the bounds stpmax and stpmin. +          stp = std::max(stp, stpmin); +          stp = std::min(stp, stpmax); +          // If an unusual termination is to occur then let +          // stp be the lowest point obtained so far. +          if (   (brackt && (stp <= stmin || stp >= stmax)) +              || nfev >= maxfev - 1 || infoc == 0 +              || (brackt && stmax - stmin <= xtol * stmax)) { +            stp = stx; +          } +          // Evaluate the function and gradient at stp +          // and compute the directional derivative. +          // We return to main program to obtain F and G. +          for (SizeType j = 0; j < n; j++) { +            x[j] = wa[j] + stp * s[is0+j]; +          } +          info=-1; +          break; +        } +        info = 0; +        nfev++; +        FloatType dg(0); +        for (SizeType j = 0; j < n; j++) { +          dg += g[j] * s[is0+j]; +        } +        FloatType ftest1 = finit + stp*dgtest; +        // Test for convergence. +        if ((brackt && (stp <= stmin || stp >= stmax)) || infoc == 0) { +          throw error_line_search_failed_rounding_errors( +            "Rounding errors prevent further progress." +            " There may not be a step which satisfies the" +            " sufficient decrease and curvature conditions." +            " Tolerances may be too small."); +        } +        if (stp == stpmax && f <= ftest1 && dg <= dgtest) { +          throw error_line_search_failed( +            "The step is at the upper bound stpmax()."); +        } +        if (stp == stpmin && (f > ftest1 || dg >= dgtest)) { +          throw error_line_search_failed( +            "The step is at the lower bound stpmin()."); +        } +        if (nfev >= maxfev) { +          throw error_line_search_failed( +            "Number of function evaluations has reached maxfev()."); +        } +        if (brackt && stmax - stmin <= xtol * stmax) { +          throw error_line_search_failed( +            "Relative width of the interval of uncertainty" +            " is at most xtol()."); +        } +        // Check for termination. +        if (f <= ftest1 && abs(dg) <= gtol * (-dginit)) { +          info = 1; +          break; +        } +        // In the first stage we seek a step for which the modified +        // function has a nonpositive value and nonnegative derivative. +        if (   stage1 && f <= ftest1 +            && dg >= std::min(ftol, gtol) * dginit) { +          stage1 = false; +        } +        // A modified function is used to predict the step only if +        // we have not obtained a step for which the modified +        // function has a nonpositive function value and nonnegative +        // derivative, and if a lower function value has been +        // obtained but the decrease is not sufficient. +        if (stage1 && f <= fx && f > ftest1) { +          // Define the modified function and derivative values. +          FloatType fm = f - stp*dgtest; +          FloatType fxm = fx - stx*dgtest; +          FloatType fym = fy - sty*dgtest; +          FloatType dgm = dg - dgtest; +          FloatType dgxm = dgx - dgtest; +          FloatType dgym = dgy - dgtest; +          // Call cstep to update the interval of uncertainty +          // and to compute the new step. +          infoc = mcstep(stx, fxm, dgxm, sty, fym, dgym, stp, fm, dgm, +                         brackt, stmin, stmax); +          // Reset the function and gradient values for f. +          fx = fxm + stx*dgtest; +          fy = fym + sty*dgtest; +          dgx = dgxm + dgtest; +          dgy = dgym + dgtest; +        } +        else { +          // Call mcstep to update the interval of uncertainty +          // and to compute the new step. +          infoc = mcstep(stx, fx, dgx, sty, fy, dgy, stp, f, dg, +                         brackt, stmin, stmax); +        } +        // Force a sufficient decrease in the size of the +        // interval of uncertainty. +        if (brackt) { +          if (abs(sty - stx) >= FloatType(0.66) * width1) { +            stp = stx + FloatType(0.5) * (sty - stx); +          } +          width1 = width; +          width = abs(sty - stx); +        } +      } +    } + +    template <typename FloatType, typename SizeType> +    int mcsrch<FloatType, SizeType>::mcstep( +      FloatType& stx, +      FloatType& fx, +      FloatType& dx, +      FloatType& sty, +      FloatType& fy, +      FloatType& dy, +      FloatType& stp, +      FloatType fp, +      FloatType dp, +      bool& brackt, +      FloatType stpmin, +      FloatType stpmax) +    { +      bool bound; +      FloatType gamma, p, q, r, s, sgnd, stpc, stpf, stpq, theta; +      int info = 0; +      if (   (   brackt && (stp <= std::min(stx, sty) +              || stp >= std::max(stx, sty))) +          || dx * (stp - stx) >= FloatType(0) || stpmax < stpmin) { +        return 0; +      } +      // Determine if the derivatives have opposite sign. +      sgnd = dp * (dx / abs(dx)); +      if (fp > fx) { +        // First case. A higher function value. +        // The minimum is bracketed. If the cubic step is closer +        // to stx than the quadratic step, the cubic step is taken, +        // else the average of the cubic and quadratic steps is taken. +        info = 1; +        bound = true; +        theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp; +        s = max3(abs(theta), abs(dx), abs(dp)); +        gamma = s * std::sqrt(pow2(theta / s) - (dx / s) * (dp / s)); +        if (stp < stx) gamma = - gamma; +        p = (gamma - dx) + theta; +        q = ((gamma - dx) + gamma) + dp; +        r = p/q; +        stpc = stx + r * (stp - stx); +        stpq = stx +          + ((dx / ((fx - fp) / (stp - stx) + dx)) / FloatType(2)) +            * (stp - stx); +        if (abs(stpc - stx) < abs(stpq - stx)) { +          stpf = stpc; +        } +        else { +          stpf = stpc + (stpq - stpc) / FloatType(2); +        } +        brackt = true; +      } +      else if (sgnd < FloatType(0)) { +        // Second case. A lower function value and derivatives of +        // opposite sign. The minimum is bracketed. If the cubic +        // step is closer to stx than the quadratic (secant) step, +        // the cubic step is taken, else the quadratic step is taken. +        info = 2; +        bound = false; +        theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp; +        s = max3(abs(theta), abs(dx), abs(dp)); +        gamma = s * std::sqrt(pow2(theta / s) - (dx / s) * (dp / s)); +        if (stp > stx) gamma = - gamma; +        p = (gamma - dp) + theta; +        q = ((gamma - dp) + gamma) + dx; +        r = p/q; +        stpc = stp + r * (stx - stp); +        stpq = stp + (dp / (dp - dx)) * (stx - stp); +        if (abs(stpc - stp) > abs(stpq - stp)) { +          stpf = stpc; +        } +        else { +          stpf = stpq; +        } +        brackt = true; +      } +      else if (abs(dp) < abs(dx)) { +        // Third case. A lower function value, derivatives of the +        // same sign, and the magnitude of the derivative decreases. +        // The cubic step is only used if the cubic tends to infinity +        // in the direction of the step or if the minimum of the cubic +        // is beyond stp. Otherwise the cubic step is defined to be +        // either stpmin or stpmax. The quadratic (secant) step is also +        // computed and if the minimum is bracketed then the the step +        // closest to stx is taken, else the step farthest away is taken. +        info = 3; +        bound = true; +        theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp; +        s = max3(abs(theta), abs(dx), abs(dp)); +        gamma = s * std::sqrt( +          std::max(FloatType(0), pow2(theta / s) - (dx / s) * (dp / s))); +        if (stp > stx) gamma = -gamma; +        p = (gamma - dp) + theta; +        q = (gamma + (dx - dp)) + gamma; +        r = p/q; +        if (r < FloatType(0) && gamma != FloatType(0)) { +          stpc = stp + r * (stx - stp); +        } +        else if (stp > stx) { +          stpc = stpmax; +        } +        else { +          stpc = stpmin; +        } +        stpq = stp + (dp / (dp - dx)) * (stx - stp); +        if (brackt) { +          if (abs(stp - stpc) < abs(stp - stpq)) { +            stpf = stpc; +          } +          else { +            stpf = stpq; +          } +        } +        else { +          if (abs(stp - stpc) > abs(stp - stpq)) { +            stpf = stpc; +          } +          else { +            stpf = stpq; +          } +        } +      } +      else { +        // Fourth case. A lower function value, derivatives of the +        // same sign, and the magnitude of the derivative does +        // not decrease. If the minimum is not bracketed, the step +        // is either stpmin or stpmax, else the cubic step is taken. +        info = 4; +        bound = false; +        if (brackt) { +          theta = FloatType(3) * (fp - fy) / (sty - stp) + dy + dp; +          s = max3(abs(theta), abs(dy), abs(dp)); +          gamma = s * std::sqrt(pow2(theta / s) - (dy / s) * (dp / s)); +          if (stp > sty) gamma = -gamma; +          p = (gamma - dp) + theta; +          q = ((gamma - dp) + gamma) + dy; +          r = p/q; +          stpc = stp + r * (sty - stp); +          stpf = stpc; +        } +        else if (stp > stx) { +          stpf = stpmax; +        } +        else { +          stpf = stpmin; +        } +      } +      // Update the interval of uncertainty. This update does not +      // depend on the new step or the case analysis above. +      if (fp > fx) { +        sty = stp; +        fy = fp; +        dy = dp; +      } +      else { +        if (sgnd < FloatType(0)) { +          sty = stx; +          fy = fx; +          dy = dx; +        } +        stx = stp; +        fx = fp; +        dx = dp; +      } +      // Compute the new step and safeguard it. +      stpf = std::min(stpmax, stpf); +      stpf = std::max(stpmin, stpf); +      stp = stpf; +      if (brackt && bound) { +        if (sty > stx) { +          stp = std::min(stx + FloatType(0.66) * (sty - stx), stp); +        } +        else { +          stp = std::max(stx + FloatType(0.66) * (sty - stx), stp); +        } +      } +      return info; +    } + +    /* Compute the sum of a vector times a scalar plus another vector. +       Adapted from the subroutine <code>daxpy</code> in +       <code>lbfgs.f</code>. +     */ +    template <typename FloatType, typename SizeType> +    void daxpy( +      SizeType n, +      FloatType da, +      const FloatType* dx, +      SizeType ix0, +      SizeType incx, +      FloatType* dy, +      SizeType iy0, +      SizeType incy) +    { +      SizeType i, ix, iy, m; +      if (n == 0) return; +      if (da == FloatType(0)) return; +      if  (!(incx == 1 && incy == 1)) { +        ix = 0; +        iy = 0; +        for (i = 0; i < n; i++) { +          dy[iy0+iy] += da * dx[ix0+ix]; +          ix += incx; +          iy += incy; +        } +        return; +      } +      m = n % 4; +      for (i = 0; i < m; i++) { +        dy[iy0+i] += da * dx[ix0+i]; +      } +      for (; i < n;) { +        dy[iy0+i] += da * dx[ix0+i]; i++; +        dy[iy0+i] += da * dx[ix0+i]; i++; +        dy[iy0+i] += da * dx[ix0+i]; i++; +        dy[iy0+i] += da * dx[ix0+i]; i++; +      } +    } + +    template <typename FloatType, typename SizeType> +    inline +    void daxpy( +      SizeType n, +      FloatType da, +      const FloatType* dx, +      SizeType ix0, +      FloatType* dy) +    { +      daxpy(n, da, dx, ix0, SizeType(1), dy, SizeType(0), SizeType(1)); +    } + +    /* Compute the dot product of two vectors. +       Adapted from the subroutine <code>ddot</code> +       in <code>lbfgs.f</code>. +     */ +    template <typename FloatType, typename SizeType> +    FloatType ddot( +      SizeType n, +      const FloatType* dx, +      SizeType ix0, +      SizeType incx, +      const FloatType* dy, +      SizeType iy0, +      SizeType incy) +    { +      SizeType i, ix, iy, m; +      FloatType dtemp(0); +      if (n == 0) return FloatType(0); +      if (!(incx == 1 && incy == 1)) { +        ix = 0; +        iy = 0; +        for (i = 0; i < n; i++) { +          dtemp += dx[ix0+ix] * dy[iy0+iy]; +          ix += incx; +          iy += incy; +        } +        return dtemp; +      } +      m = n % 5; +      for (i = 0; i < m; i++) { +        dtemp += dx[ix0+i] * dy[iy0+i]; +      } +      for (; i < n;) { +        dtemp += dx[ix0+i] * dy[iy0+i]; i++; +        dtemp += dx[ix0+i] * dy[iy0+i]; i++; +        dtemp += dx[ix0+i] * dy[iy0+i]; i++; +        dtemp += dx[ix0+i] * dy[iy0+i]; i++; +        dtemp += dx[ix0+i] * dy[iy0+i]; i++; +      } +      return dtemp; +    } + +    template <typename FloatType, typename SizeType> +    inline +    FloatType ddot( +      SizeType n, +      const FloatType* dx, +      const FloatType* dy) +    { +      return ddot( +        n, dx, SizeType(0), SizeType(1), dy, SizeType(0), SizeType(1)); +    } + +  } // namespace detail + +  //! Interface to the LBFGS %minimizer. +  /*! This class solves the unconstrained minimization problem +      <pre> +          min f(x),  x = (x1,x2,...,x_n), +      </pre> +      using the limited-memory BFGS method. The routine is +      especially effective on problems involving a large number of +      variables. In a typical iteration of this method an +      approximation Hk to the inverse of the Hessian +      is obtained by applying <code>m</code> BFGS updates to a +      diagonal matrix Hk0, using information from the +      previous <code>m</code> steps.  The user specifies the number +      <code>m</code>, which determines the amount of storage +      required by the routine. The user may also provide the +      diagonal matrices Hk0 (parameter <code>diag</code> in the run() +      function) if not satisfied with the default choice. The +      algorithm is described in "On the limited memory BFGS method for +      large scale optimization", by D. Liu and J. Nocedal, Mathematical +      Programming B 45 (1989) 503-528. + +      The user is required to calculate the function value +      <code>f</code> and its gradient <code>g</code>. In order to +      allow the user complete control over these computations, +      reverse communication is used. The routine must be called +      repeatedly under the control of the member functions +      <code>requests_f_and_g()</code>, +      <code>requests_diag()</code>. +      If neither requests_f_and_g() nor requests_diag() is +      <code>true</code> the user should check for convergence +      (using class traditional_convergence_test or any +      other custom test). If the convergence test is negative, +      the minimizer may be called again for the next iteration. + +      The steplength (stp()) is determined at each iteration +      by means of the line search routine <code>mcsrch</code>, which is +      a slight modification of the routine <code>CSRCH</code> written +      by More' and Thuente. + +      The only variables that are machine-dependent are +      <code>xtol</code>, +      <code>stpmin</code> and +      <code>stpmax</code>. + +      Fatal errors cause <code>error</code> exceptions to be thrown. +      The generic class <code>error</code> is sub-classed (e.g. +      class <code>error_line_search_failed</code>) to facilitate +      granular %error handling. + +      A note on performance: Using Compaq Fortran V5.4 and +      Compaq C++ V6.5, the C++ implementation is about 15% slower +      than the Fortran implementation. +   */ +  template <typename FloatType, typename SizeType = std::size_t> +  class minimizer +  { +    public: +      //! Default constructor. Some members are not initialized! +      minimizer() +      : n_(0), m_(0), maxfev_(0), +        gtol_(0), xtol_(0), +        stpmin_(0), stpmax_(0), +        ispt(0), iypt(0) +      {} + +      //! Constructor. +      /*! @param n The number of variables in the minimization problem. +             Restriction: <code>n > 0</code>. + +          @param m The number of corrections used in the BFGS update. +             Values of <code>m</code> less than 3 are not recommended; +             large values of <code>m</code> will result in excessive +             computing time. <code>3 <= m <= 7</code> is +             recommended. +             Restriction: <code>m > 0</code>. + +          @param maxfev Maximum number of function evaluations +             <b>per line search</b>. +             Termination occurs when the number of evaluations +             of the objective function is at least <code>maxfev</code> by +             the end of an iteration. + +          @param gtol Controls the accuracy of the line search. +            If the function and gradient evaluations are inexpensive with +            respect to the cost of the iteration (which is sometimes the +            case when solving very large problems) it may be advantageous +            to set <code>gtol</code> to a small value. A typical small +            value is 0.1. +            Restriction: <code>gtol</code> should be greater than 1e-4. + +          @param xtol An estimate of the machine precision (e.g. 10e-16 +            on a SUN station 3/60). The line search routine will +            terminate if the relative width of the interval of +            uncertainty is less than <code>xtol</code>. + +          @param stpmin Specifies the lower bound for the step +            in the line search. +            The default value is 1e-20. This value need not be modified +            unless the exponent is too large for the machine being used, +            or unless the problem is extremely badly scaled (in which +            case the exponent should be increased). + +          @param stpmax specifies the upper bound for the step +            in the line search. +            The default value is 1e20. This value need not be modified +            unless the exponent is too large for the machine being used, +            or unless the problem is extremely badly scaled (in which +            case the exponent should be increased). +       */ +      explicit +      minimizer( +        SizeType n, +        SizeType m = 5, +        SizeType maxfev = 20, +        FloatType gtol = FloatType(0.9), +        FloatType xtol = FloatType(1.e-16), +        FloatType stpmin = FloatType(1.e-20), +        FloatType stpmax = FloatType(1.e20)) +        : n_(n), m_(m), maxfev_(maxfev), +          gtol_(gtol), xtol_(xtol), +          stpmin_(stpmin), stpmax_(stpmax), +          iflag_(0), requests_f_and_g_(false), requests_diag_(false), +          iter_(0), nfun_(0), stp_(0), +          stp1(0), ftol(0.0001), ys(0), point(0), npt(0), +          ispt(n+2*m), iypt((n+2*m)+n*m), +          info(0), bound(0), nfev(0) +      { +        if (n_ == 0) { +          throw error_improper_input_parameter("n = 0."); +        } +        if (m_ == 0) { +          throw error_improper_input_parameter("m = 0."); +        } +        if (maxfev_ == 0) { +         throw error_improper_input_parameter("maxfev = 0."); +        } +        if (gtol_ <= FloatType(1.e-4)) { +          throw error_improper_input_parameter("gtol <= 1.e-4."); +        } +        if (xtol_ < FloatType(0)) { +          throw error_improper_input_parameter("xtol < 0."); +        } +        if (stpmin_ < FloatType(0)) { +          throw error_improper_input_parameter("stpmin < 0."); +        } +        if (stpmax_ < stpmin) { +          throw error_improper_input_parameter("stpmax < stpmin"); +        } +        w_.resize(n_*(2*m_+1)+2*m_); +        scratch_array_.resize(n_); +      } + +      //! Number of free parameters (as passed to the constructor). +      SizeType n() const { return n_; } + +      //! Number of corrections kept (as passed to the constructor). +      SizeType m() const { return m_; } + +      /*! \brief Maximum number of evaluations of the objective function +          per line search (as passed to the constructor). +       */ +      SizeType maxfev() const { return maxfev_; } + +      /*! \brief Control of the accuracy of the line search. +          (as passed to the constructor). +       */ +      FloatType gtol() const { return gtol_; } + +      //! Estimate of the machine precision (as passed to the constructor). +      FloatType xtol() const { return xtol_; } + +      /*! \brief Lower bound for the step in the line search. +          (as passed to the constructor). +       */ +      FloatType stpmin() const { return stpmin_; } + +      /*! \brief Upper bound for the step in the line search. +          (as passed to the constructor). +       */ +      FloatType stpmax() const { return stpmax_; } + +      //! Status indicator for reverse communication. +      /*! <code>true</code> if the run() function returns to request +          evaluation of the objective function (<code>f</code>) and +          gradients (<code>g</code>) for the current point +          (<code>x</code>). To continue the minimization the +          run() function is called again with the updated values for +          <code>f</code> and <code>g</code>. +          <p> +          See also: requests_diag() +       */ +      bool requests_f_and_g() const { return requests_f_and_g_; } + +      //! Status indicator for reverse communication. +      /*! <code>true</code> if the run() function returns to request +          evaluation of the diagonal matrix (<code>diag</code>) +          for the current point (<code>x</code>). +          To continue the minimization the run() function is called +          again with the updated values for <code>diag</code>. +          <p> +          See also: requests_f_and_g() +       */ +      bool requests_diag() const { return requests_diag_; } + +      //! Number of iterations so far. +      /*! Note that one iteration may involve multiple evaluations +          of the objective function. +          <p> +          See also: nfun() +       */ +      SizeType iter() const { return iter_; } + +      //! Total number of evaluations of the objective function so far. +      /*! The total number of function evaluations increases by the +          number of evaluations required for the line search. The total +          is only increased after a successful line search. +          <p> +          See also: iter() +       */ +      SizeType nfun() const { return nfun_; } + +      //! Norm of gradient given gradient array of length n(). +      FloatType euclidean_norm(const FloatType* a) const { +        return std::sqrt(detail::ddot(n_, a, a)); +      } + +      //! Current stepsize. +      FloatType stp() const { return stp_; } + +      //! Execution of one step of the minimization. +      /*! @param x On initial entry this must be set by the user to +             the values of the initial estimate of the solution vector. + +          @param f Before initial entry or on re-entry under the +             control of requests_f_and_g(), <code>f</code> must be set +             by the user to contain the value of the objective function +             at the current point <code>x</code>. + +          @param g Before initial entry or on re-entry under the +             control of requests_f_and_g(), <code>g</code> must be set +             by the user to contain the components of the gradient at +             the current point <code>x</code>. + +          The return value is <code>true</code> if either +          requests_f_and_g() or requests_diag() is <code>true</code>. +          Otherwise the user should check for convergence +          (e.g. using class traditional_convergence_test) and +          call the run() function again to continue the minimization. +          If the return value is <code>false</code> the user +          should <b>not</b> update <code>f</code>, <code>g</code> or +          <code>diag</code> (other overload) before calling +          the run() function again. + +          Note that <code>x</code> is always modified by the run() +          function. Depending on the situation it can therefore be +          necessary to evaluate the objective function one more time +          after the minimization is terminated. +       */ +      bool run( +        FloatType* x, +        FloatType f, +        const FloatType* g) +      { +        return generic_run(x, f, g, false, 0); +      } + +      //! Execution of one step of the minimization. +      /*! @param x See other overload. + +          @param f See other overload. + +          @param g See other overload. + +          @param diag On initial entry or on re-entry under the +             control of requests_diag(), <code>diag</code> must be set by +             the user to contain the values of the diagonal matrix Hk0. +             The routine will return at each iteration of the algorithm +             with requests_diag() set to <code>true</code>. +             <p> +             Restriction: all elements of <code>diag</code> must be +             positive. +       */ +      bool run( +        FloatType* x, +        FloatType f, +        const FloatType* g, +        const FloatType* diag) +      { +        return generic_run(x, f, g, true, diag); +      } + +      void serialize(std::ostream* out) const { +        out->write((const char*)&n_, sizeof(n_)); // sanity check +        out->write((const char*)&m_, sizeof(m_)); // sanity check +        SizeType fs = sizeof(FloatType); +        out->write((const char*)&fs, sizeof(fs)); // sanity check + +        mcsrch_instance.serialize(out); +        out->write((const char*)&iflag_, sizeof(iflag_)); +        out->write((const char*)&requests_f_and_g_, sizeof(requests_f_and_g_)); +        out->write((const char*)&requests_diag_, sizeof(requests_diag_)); +        out->write((const char*)&iter_, sizeof(iter_)); +        out->write((const char*)&nfun_, sizeof(nfun_)); +        out->write((const char*)&stp_, sizeof(stp_)); +        out->write((const char*)&stp1, sizeof(stp1)); +        out->write((const char*)&ftol, sizeof(ftol)); +        out->write((const char*)&ys, sizeof(ys)); +        out->write((const char*)&point, sizeof(point)); +        out->write((const char*)&npt, sizeof(npt)); +        out->write((const char*)&info, sizeof(info)); +        out->write((const char*)&bound, sizeof(bound)); +        out->write((const char*)&nfev, sizeof(nfev)); +        out->write((const char*)&w_[0], sizeof(FloatType) * w_.size()); +        out->write((const char*)&scratch_array_[0], sizeof(FloatType) * scratch_array_.size()); +      } + +      void deserialize(std::istream* in) { +        SizeType n, m, fs; +        in->read((char*)&n, sizeof(n)); +        in->read((char*)&m, sizeof(m)); +        in->read((char*)&fs, sizeof(fs)); +        assert(n == n_); +        assert(m == m_); +        assert(fs == sizeof(FloatType)); + +        mcsrch_instance.deserialize(in); +        in->read((char*)&iflag_, sizeof(iflag_)); +        in->read((char*)&requests_f_and_g_, sizeof(requests_f_and_g_)); +        in->read((char*)&requests_diag_, sizeof(requests_diag_)); +        in->read((char*)&iter_, sizeof(iter_)); +        in->read((char*)&nfun_, sizeof(nfun_)); +        in->read((char*)&stp_, sizeof(stp_)); +        in->read((char*)&stp1, sizeof(stp1)); +        in->read((char*)&ftol, sizeof(ftol)); +        in->read((char*)&ys, sizeof(ys)); +        in->read((char*)&point, sizeof(point)); +        in->read((char*)&npt, sizeof(npt)); +        in->read((char*)&info, sizeof(info)); +        in->read((char*)&bound, sizeof(bound)); +        in->read((char*)&nfev, sizeof(nfev)); +        in->read((char*)&w_[0], sizeof(FloatType) * w_.size()); +        in->read((char*)&scratch_array_[0], sizeof(FloatType) * scratch_array_.size()); +      } + +    protected: +      static void throw_diagonal_element_not_positive(SizeType i) { +        throw error_improper_input_data( +          "The " + error::itoa(i) + ". diagonal element of the" +          " inverse Hessian approximation is not positive."); +      } + +      bool generic_run( +        FloatType* x, +        FloatType f, +        const FloatType* g, +        bool diagco, +        const FloatType* diag); + +      detail::mcsrch<FloatType, SizeType> mcsrch_instance; +      const SizeType n_; +      const SizeType m_; +      const SizeType maxfev_; +      const FloatType gtol_; +      const FloatType xtol_; +      const FloatType stpmin_; +      const FloatType stpmax_; +      int iflag_; +      bool requests_f_and_g_; +      bool requests_diag_; +      SizeType iter_; +      SizeType nfun_; +      FloatType stp_; +      FloatType stp1; +      FloatType ftol; +      FloatType ys; +      SizeType point; +      SizeType npt; +      const SizeType ispt; +      const SizeType iypt; +      int info; +      SizeType bound; +      SizeType nfev; +      std::vector<FloatType> w_; +      std::vector<FloatType> scratch_array_; +  }; + +  template <typename FloatType, typename SizeType> +  bool minimizer<FloatType, SizeType>::generic_run( +    FloatType* x, +    FloatType f, +    const FloatType* g, +    bool diagco, +    const FloatType* diag) +  { +    bool execute_entire_while_loop = false; +    if (!(requests_f_and_g_ || requests_diag_)) { +      execute_entire_while_loop = true; +    } +    requests_f_and_g_ = false; +    requests_diag_ = false; +    FloatType* w = &(*(w_.begin())); +    if (iflag_ == 0) { // Initialize. +      nfun_ = 1; +      if (diagco) { +        for (SizeType i = 0; i < n_; i++) { +          if (diag[i] <= FloatType(0)) { +            throw_diagonal_element_not_positive(i); +          } +        } +      } +      else { +        std::fill_n(scratch_array_.begin(), n_, FloatType(1)); +        diag = &(*(scratch_array_.begin())); +      } +      for (SizeType i = 0; i < n_; i++) { +        w[ispt + i] = -g[i] * diag[i]; +      } +      FloatType gnorm = std::sqrt(detail::ddot(n_, g, g)); +      if (gnorm == FloatType(0)) return false; +      stp1 = FloatType(1) / gnorm; +      execute_entire_while_loop = true; +    } +    if (execute_entire_while_loop) { +      bound = iter_; +      iter_++; +      info = 0; +      if (iter_ != 1) { +        if (iter_ > m_) bound = m_; +        ys = detail::ddot( +          n_, w, iypt + npt, SizeType(1), w, ispt + npt, SizeType(1)); +        if (!diagco) { +          FloatType yy = detail::ddot( +            n_, w, iypt + npt, SizeType(1), w, iypt + npt, SizeType(1)); +          std::fill_n(scratch_array_.begin(), n_, ys / yy); +          diag = &(*(scratch_array_.begin())); +        } +        else { +          iflag_ = 2; +          requests_diag_ = true; +          return true; +        } +      } +    } +    if (execute_entire_while_loop || iflag_ == 2) { +      if (iter_ != 1) { +        if (diag == 0) { +          throw error_internal_error(__FILE__, __LINE__); +        } +        if (diagco) { +          for (SizeType i = 0; i < n_; i++) { +            if (diag[i] <= FloatType(0)) { +              throw_diagonal_element_not_positive(i); +            } +          } +        } +        SizeType cp = point; +        if (point == 0) cp = m_; +        w[n_ + cp -1] = 1 / ys; +        SizeType i; +        for (i = 0; i < n_; i++) { +          w[i] = -g[i]; +        } +        cp = point; +        for (i = 0; i < bound; i++) { +          if (cp == 0) cp = m_; +          cp--; +          FloatType sq = detail::ddot( +            n_, w, ispt + cp * n_, SizeType(1), w, SizeType(0), SizeType(1)); +          SizeType inmc=n_+m_+cp; +          SizeType iycn=iypt+cp*n_; +          w[inmc] = w[n_ + cp] * sq; +          detail::daxpy(n_, -w[inmc], w, iycn, w); +        } +        for (i = 0; i < n_; i++) { +          w[i] *= diag[i]; +        } +        for (i = 0; i < bound; i++) { +          FloatType yr = detail::ddot( +            n_, w, iypt + cp * n_, SizeType(1), w, SizeType(0), SizeType(1)); +          FloatType beta = w[n_ + cp] * yr; +          SizeType inmc=n_+m_+cp; +          beta = w[inmc] - beta; +          SizeType iscn=ispt+cp*n_; +          detail::daxpy(n_, beta, w, iscn, w); +          cp++; +          if (cp == m_) cp = 0; +        } +        std::copy(w, w+n_, w+(ispt + point * n_)); +      } +      stp_ = FloatType(1); +      if (iter_ == 1) stp_ = stp1; +      std::copy(g, g+n_, w); +    } +    mcsrch_instance.run( +      gtol_, stpmin_, stpmax_, n_, x, f, g, w, ispt + point * n_, +      stp_, ftol, xtol_, maxfev_, info, nfev, &(*(scratch_array_.begin()))); +    if (info == -1) { +      iflag_ = 1; +      requests_f_and_g_ = true; +      return true; +    } +    if (info != 1) { +      throw error_internal_error(__FILE__, __LINE__); +    } +    nfun_ += nfev; +    npt = point*n_; +    for (SizeType i = 0; i < n_; i++) { +      w[ispt + npt + i] = stp_ * w[ispt + npt + i]; +      w[iypt + npt + i] = g[i] - w[i]; +    } +    point++; +    if (point == m_) point = 0; +    return false; +  } + +  //! Traditional LBFGS convergence test. +  /*! This convergence test is equivalent to the test embedded +      in the <code>lbfgs.f</code> Fortran code. The test assumes that +      there is a meaningful relation between the Euclidean norm of the +      parameter vector <code>x</code> and the norm of the gradient +      vector <code>g</code>. Therefore this test should not be used if +      this assumption is not correct for a given problem. +   */ +  template <typename FloatType, typename SizeType = std::size_t> +  class traditional_convergence_test +  { +    public: +      //! Default constructor. +      traditional_convergence_test() +      : n_(0), eps_(0) +      {} + +      //! Constructor. +      /*! @param n The number of variables in the minimization problem. +             Restriction: <code>n > 0</code>. + +          @param eps Determines the accuracy with which the solution +            is to be found. +       */ +      explicit +      traditional_convergence_test( +        SizeType n, +        FloatType eps = FloatType(1.e-5)) +      : n_(n), eps_(eps) +      { +        if (n_ == 0) { +          throw error_improper_input_parameter("n = 0."); +        } +        if (eps_ < FloatType(0)) { +          throw error_improper_input_parameter("eps < 0."); +        } +      } + +      //! Number of free parameters (as passed to the constructor). +      SizeType n() const { return n_; } + +      /*! \brief Accuracy with which the solution is to be found +          (as passed to the constructor). +       */ +      FloatType eps() const { return eps_; } + +      //! Execution of the convergence test for the given parameters. +      /*! Returns <code>true</code> if +          <pre> +            ||g|| < eps * max(1,||x||), +          </pre> +          where <code>||.||</code> denotes the Euclidean norm. + +          @param x Current solution vector. + +          @param g Components of the gradient at the current +            point <code>x</code>. +       */ +      bool +      operator()(const FloatType* x, const FloatType* g) const +      { +        FloatType xnorm = std::sqrt(detail::ddot(n_, x, x)); +        FloatType gnorm = std::sqrt(detail::ddot(n_, g, g)); +        if (gnorm <= eps_ * std::max(FloatType(1), xnorm)) return true; +        return false; +      } +    protected: +      const SizeType n_; +      const FloatType eps_; +  }; + +}} // namespace scitbx::lbfgs + +template <typename T> +std::ostream& operator<<(std::ostream& os, const scitbx::lbfgs::minimizer<T>& min) { +  return os << "ITER=" << min.iter() << "\tNFUN=" << min.nfun() << "\tSTP=" << min.stp() << "\tDIAG=" << min.requests_diag() << "\tF&G=" << min.requests_f_and_g(); +} + + +#endif // SCITBX_LBFGS_H diff --git a/training/utils/lbfgs_test.cc b/training/utils/lbfgs_test.cc new file mode 100644 index 00000000..9678e788 --- /dev/null +++ b/training/utils/lbfgs_test.cc @@ -0,0 +1,117 @@ +#include <cassert> +#include <iostream> +#include <sstream> +#include <cmath> +#include "lbfgs.h" +#include "sparse_vector.h" +#include "fdict.h" + +using namespace std; + +double TestOptimizer() { +  cerr << "TESTING NON-PERSISTENT OPTIMIZER\n"; + +  // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 +  // df/dx1 = 8*x1 + x2 +  // df/dx2 = 2*x2 + x1 +  // df/dx3 = 2*x3 + 6 +  double x[3]; +  double g[3]; +  scitbx::lbfgs::minimizer<double> opt(3); +  scitbx::lbfgs::traditional_convergence_test<double> converged(3); +  x[0] = 8; +  x[1] = 8; +  x[2] = 8; +  double obj = 0; +  do { +    g[0] = 8 * x[0] + x[1]; +    g[1] = 2 * x[1] + x[0]; +    g[2] = 2 * x[2] + 6; +    obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; +    opt.run(x, obj, g); +    if (!opt.requests_f_and_g()) { +      if (converged(x,g)) break; +      opt.run(x, obj, g); +    } +    cerr << x[0] << " " << x[1] << " " << x[2] << endl; +    cerr << "   obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; +    cerr << opt << endl; +  } while (true); +  return obj; +} + +double TestPersistentOptimizer() { +  cerr << "\nTESTING PERSISTENT OPTIMIZER\n"; +  // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 +  // df/dx1 = 8*x1 + x2 +  // df/dx2 = 2*x2 + x1 +  // df/dx3 = 2*x3 + 6 +  double x[3]; +  double g[3]; +  scitbx::lbfgs::traditional_convergence_test<double> converged(3); +  x[0] = 8; +  x[1] = 8; +  x[2] = 8; +  double obj = 0; +  string state; +  do { +    g[0] = 8 * x[0] + x[1]; +    g[1] = 2 * x[1] + x[0]; +    g[2] = 2 * x[2] + 6; +    obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; + +    { +      scitbx::lbfgs::minimizer<double> opt(3); +      if (state.size() > 0) { +        istringstream is(state, ios::binary); +        opt.deserialize(&is); +      } +      opt.run(x, obj, g); +      ostringstream os(ios::binary); opt.serialize(&os); state = os.str(); +    } + +    cerr << x[0] << " " << x[1] << " " << x[2] << endl; +    cerr << "   obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; +  } while (!converged(x, g)); +  return obj; +} + +void TestSparseVector() { +  cerr << "Testing SparseVector<double> serialization.\n"; +  int f1 = FD::Convert("Feature_1"); +  int f2 = FD::Convert("Feature_2"); +  FD::Convert("LanguageModel"); +  int f4 = FD::Convert("SomeFeature"); +  int f5 = FD::Convert("SomeOtherFeature"); +  SparseVector<double> g; +  g.set_value(f2, log(0.5)); +  g.set_value(f4, log(0.125)); +  g.set_value(f1, 0); +  g.set_value(f5, 23.777); +  ostringstream os; +  double iobj = 1.5; +  B64::Encode(iobj, g, &os); +  cerr << iobj << "\t" << g << endl; +  string data = os.str(); +  cout << data << endl; +  SparseVector<double> v; +  double obj; +  bool decode_b64 = B64::Decode(&obj, &v, &data[0], data.size()); +  cerr << obj << "\t" << v << endl; +  assert(decode_b64); +  assert(obj == iobj); +  assert(g.size() == v.size()); +} + +int main() { +  double o1 = TestOptimizer(); +  double o2 = TestPersistentOptimizer(); +  if (fabs(o1 - o2) > 1e-5) { +    cerr << "OPTIMIZERS PERFORMED DIFFERENTLY!\n" << o1 << " vs. " << o2 << endl; +    return 1; +  } +  TestSparseVector(); +  cerr << "SUCCESS\n"; +  return 0; +} + diff --git a/training/utils/libcall.pl b/training/utils/libcall.pl new file mode 100644 index 00000000..c7d0f128 --- /dev/null +++ b/training/utils/libcall.pl @@ -0,0 +1,71 @@ +use IPC::Open3; +use Symbol qw(gensym); + +$DUMMY_STDERR = gensym(); +$DUMMY_STDIN = gensym(); + +# Run the command and ignore failures +sub unchecked_call { +    system("@_") +} + +# Run the command and return its output, if any ignoring failures +sub unchecked_output { +    return `@_` +} + +# WARNING: Do not use this for commands that will return large amounts +# of stdout or stderr -- they might block indefinitely +sub check_output { +    print STDERR "Executing and gathering output: @_\n"; + +    my $pid = open3($DUMMY_STDIN, \*PH, $DUMMY_STDERR, @_); +    my $proc_output = ""; +    while( <PH> ) { +	$proc_output .= $_; +    } +    waitpid($pid, 0); +    # TODO: Grab signal that the process died from +    my $child_exit_status = $? >> 8; +    if($child_exit_status == 0) { +	return $proc_output; +    } else { +	print STDERR "ERROR: Execution of @_ failed.\n"; +	exit(1); +    } +} + +# Based on Moses' safesystem sub +sub check_call { +    print STDERR "Executing: @_\n"; +    system(@_); +    my $exitcode = $? >> 8; +    if($exitcode == 0) { +	return 0; +    } elsif ($? == -1) { +	print STDERR "ERROR: Failed to execute: @_\n  $!\n"; +	exit(1); + +    } elsif ($? & 127) { +      printf STDERR "ERROR: Execution of: @_\n  died with signal %d, %s coredump\n", +      ($? & 127),  ($? & 128) ? 'with' : 'without'; +      exit(1); + +    } else { +	print STDERR "Failed with exit code: $exitcode\n" if $exitcode; +	exit($exitcode); +    } +} + +sub check_bash_call { +    my @args = ( "bash", "-auxeo", "pipefail", "-c", "@_"); +    check_call(@args); +} + +sub check_bash_output { +    my @args = ( "bash", "-auxeo", "pipefail", "-c", "@_"); +    return check_output(@args); +} + +# perl module weirdness... +return 1; diff --git a/training/utils/online_optimizer.cc b/training/utils/online_optimizer.cc new file mode 100644 index 00000000..3ed95452 --- /dev/null +++ b/training/utils/online_optimizer.cc @@ -0,0 +1,16 @@ +#include "online_optimizer.h" + +LearningRateSchedule::~LearningRateSchedule() {} + +double StandardLearningRate::eta(int k) const { +  return eta_0_ / (1.0 + k / N_); +} + +double ExponentialDecayLearningRate::eta(int k) const { +  return eta_0_ * pow(alpha_, k / N_); +} + +OnlineOptimizer::~OnlineOptimizer() {} + +void OnlineOptimizer::ResetEpochImpl() {} + diff --git a/training/utils/online_optimizer.h b/training/utils/online_optimizer.h new file mode 100644 index 00000000..28d89344 --- /dev/null +++ b/training/utils/online_optimizer.h @@ -0,0 +1,129 @@ +#ifndef _ONL_OPTIMIZE_H_ +#define _ONL_OPTIMIZE_H_ + +#include <tr1/memory> +#include <set> +#include <string> +#include <cmath> +#include "sparse_vector.h" + +struct LearningRateSchedule { +  virtual ~LearningRateSchedule(); +  // returns the learning rate for the kth iteration +  virtual double eta(int k) const = 0; +}; + +// TODO in the Tsoruoaka et al. (ACL 2009) paper, they use N +// to mean the batch size in most places, but it doesn't completely +// make sense to me in the learning rate schedules-- this needs +// to be worked out to make sure they didn't mean corpus size +// in some places and batch size in others (since in the paper they +// only ever work with batch sizes of 1) +struct StandardLearningRate : public LearningRateSchedule { +  StandardLearningRate( +      size_t batch_size,        // batch size, not corpus size! +      double eta_0 = 0.2) : +    eta_0_(eta_0), +    N_(static_cast<double>(batch_size)) {} + +  virtual double eta(int k) const; + + private: +  const double eta_0_; +  const double N_; +}; + +struct ExponentialDecayLearningRate : public LearningRateSchedule { +  ExponentialDecayLearningRate( +      size_t batch_size,        // batch size, not corpus size! +      double eta_0 = 0.2, +      double alpha = 0.85       // recommended by Tsuruoka et al. (ACL 2009) +    ) : eta_0_(eta_0), +        N_(static_cast<double>(batch_size)), +        alpha_(alpha) { +    assert(alpha > 0); +    assert(alpha < 1.0); +  } + +  virtual double eta(int k) const; + + private: +  const double eta_0_; +  const double N_; +  const double alpha_; +}; + +class OnlineOptimizer { + public: +  virtual ~OnlineOptimizer(); +  OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s, +                  size_t batch_size, +                  const std::vector<int>& frozen_feats = std::vector<int>()) +      : N_(batch_size),schedule_(s),k_() { +    for (int i = 0; i < frozen_feats.size(); ++i) +      frozen_.insert(frozen_feats[i]); +  } +  void ResetEpoch() { k_ = 0; ResetEpochImpl(); } +  void UpdateWeights(const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) { +    ++k_; +    const double eta = schedule_->eta(k_); +    UpdateWeightsImpl(eta, approx_g, max_feat, weights); +  } + + protected: +  virtual void ResetEpochImpl(); +  virtual void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) = 0; +  const size_t N_; // number of training instances per batch +  std::set<int> frozen_;  // frozen (non-optimizing) features + + private: +  std::tr1::shared_ptr<LearningRateSchedule> schedule_; +  int k_;  // iteration count +}; + +class CumulativeL1OnlineOptimizer : public OnlineOptimizer { + public: +  CumulativeL1OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s, +                              size_t training_instances, double C, +                              const std::vector<int>& frozen) : +    OnlineOptimizer(s, training_instances, frozen), C_(C), u_() {} + + protected: +  void ResetEpochImpl() { u_ = 0; } +  void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) { +    u_ += eta * C_ / N_; +    for (SparseVector<double>::const_iterator it = approx_g.begin();  +         it != approx_g.end(); ++it) { +      if (frozen_.count(it->first) == 0) +        weights->add_value(it->first, eta * it->second); +    } +    for (int i = 1; i < max_feat; ++i) +      if (frozen_.count(i) == 0) ApplyPenalty(i, weights); +  } + + private: +  void ApplyPenalty(int i, SparseVector<double>* w) { +    const double z = w->value(i); +    double w_i = z; +    double q_i = q_.value(i); +    if (w_i > 0.0) +      w_i = std::max(0.0, w_i - (u_ + q_i)); +    else if (w_i < 0.0) +      w_i = std::min(0.0, w_i + (u_ - q_i)); +    q_i += w_i - z; +    if (q_i == 0.0) +      q_.erase(i); +    else +      q_.set_value(i, q_i); +    if (w_i == 0.0) +      w->erase(i); +    else +      w->set_value(i, w_i); +  } + +  const double C_;  // reguarlization strength +  double u_; +  SparseVector<double> q_; +}; + +#endif diff --git a/training/utils/optimize.cc b/training/utils/optimize.cc new file mode 100644 index 00000000..41ac90d8 --- /dev/null +++ b/training/utils/optimize.cc @@ -0,0 +1,102 @@ +#include "optimize.h" + +#include <iostream> +#include <cassert> + +#include "lbfgs.h" + +using namespace std; + +BatchOptimizer::~BatchOptimizer() {} + +void BatchOptimizer::Save(ostream* out) const { +  out->write((const char*)&eval_, sizeof(eval_)); +  out->write((const char*)&has_converged_, sizeof(has_converged_)); +  SaveImpl(out); +  unsigned int magic = 0xABCDDCBA;  // should be uint32_t +  out->write((const char*)&magic, sizeof(magic)); +} + +void BatchOptimizer::Load(istream* in) { +  in->read((char*)&eval_, sizeof(eval_)); +  in->read((char*)&has_converged_, sizeof(has_converged_)); +  LoadImpl(in); +  unsigned int magic = 0;           // should be uint32_t +  in->read((char*)&magic, sizeof(magic)); +  assert(magic == 0xABCDDCBA); +  cerr << Name() << " EVALUATION #" << eval_ << endl; +} + +void BatchOptimizer::SaveImpl(ostream* out) const { +  (void)out; +} + +void BatchOptimizer::LoadImpl(istream* in) { +  (void)in; +} + +string RPropOptimizer::Name() const { +  return "RPropOptimizer"; +} + +void RPropOptimizer::OptimizeImpl(const double& obj, +                              const vector<double>& g, +                              vector<double>* x) { +  for (int i = 0; i < g.size(); ++i) { +    const double g_i = g[i]; +    const double sign_i = (signbit(g_i) ? -1.0 : 1.0); +    const double prod = g_i * prev_g_[i]; +    if (prod > 0.0) { +      const double dij = min(delta_ij_[i] * eta_plus_, delta_max_); +      (*x)[i] -= dij * sign_i; +      delta_ij_[i] = dij; +      prev_g_[i] = g_i; +    } else if (prod < 0.0) { +      delta_ij_[i] = max(delta_ij_[i] * eta_minus_, delta_min_); +      prev_g_[i] = 0.0; +    } else { +      (*x)[i] -= delta_ij_[i] * sign_i; +      prev_g_[i] = g_i; +    } +  } +} + +void RPropOptimizer::SaveImpl(ostream* out) const { +  const size_t n = prev_g_.size(); +  out->write((const char*)&n, sizeof(n)); +  out->write((const char*)&prev_g_[0], sizeof(double) * n); +  out->write((const char*)&delta_ij_[0], sizeof(double) * n); +} + +void RPropOptimizer::LoadImpl(istream* in) { +  size_t n; +  in->read((char*)&n, sizeof(n)); +  assert(n == prev_g_.size()); +  assert(n == delta_ij_.size()); +  in->read((char*)&prev_g_[0], sizeof(double) * n); +  in->read((char*)&delta_ij_[0], sizeof(double) * n); +} + +string LBFGSOptimizer::Name() const { +  return "LBFGSOptimizer"; +} + +LBFGSOptimizer::LBFGSOptimizer(int num_feats, int memory_buffers) : +  opt_(num_feats, memory_buffers) {} + +void LBFGSOptimizer::SaveImpl(ostream* out) const { +  opt_.serialize(out); +} + +void LBFGSOptimizer::LoadImpl(istream* in) { +  opt_.deserialize(in); +} + +void LBFGSOptimizer::OptimizeImpl(const double& obj, +                                  const vector<double>& g, +                                  vector<double>* x) { +  opt_.run(&(*x)[0], obj, &g[0]); +  if (!opt_.requests_f_and_g()) opt_.run(&(*x)[0], obj, &g[0]); +  // cerr << opt_ << endl; +} + diff --git a/training/utils/optimize.h b/training/utils/optimize.h new file mode 100644 index 00000000..07943b44 --- /dev/null +++ b/training/utils/optimize.h @@ -0,0 +1,92 @@ +#ifndef _OPTIMIZE_H_ +#define _OPTIMIZE_H_ + +#include <iostream> +#include <vector> +#include <string> +#include <cassert> + +#include "lbfgs.h" + +// abstract base class for first order optimizers +// order of invocation: new, Load(), Optimize(), Save(), delete +class BatchOptimizer { + public: +  BatchOptimizer() : eval_(1), has_converged_(false) {} +  virtual ~BatchOptimizer(); +  virtual std::string Name() const = 0; +  int EvaluationCount() const { return eval_; } +  bool HasConverged() const { return has_converged_; } + +  void Optimize(const double& obj, +                const std::vector<double>& g, +                std::vector<double>* x) { +    assert(g.size() == x->size()); +    ++eval_; +    OptimizeImpl(obj, g, x); +    scitbx::lbfgs::traditional_convergence_test<double> converged(g.size()); +    has_converged_ = converged(&(*x)[0], &g[0]); +  } + +  void Save(std::ostream* out) const; +  void Load(std::istream* in); + protected: +  virtual void SaveImpl(std::ostream* out) const; +  virtual void LoadImpl(std::istream* in); +  virtual void OptimizeImpl(const double& obj, +                            const std::vector<double>& g, +                            std::vector<double>* x) = 0; + +  int eval_; + private: +  bool has_converged_; +}; + +class RPropOptimizer : public BatchOptimizer { + public: +  explicit RPropOptimizer(int num_vars, +                          double eta_plus = 1.2, +                          double eta_minus = 0.5, +                          double delta_0 = 0.1, +                          double delta_max = 50.0, +                          double delta_min = 1e-6) : +      prev_g_(num_vars, 0.0), +      delta_ij_(num_vars, delta_0), +      eta_plus_(eta_plus), +      eta_minus_(eta_minus), +      delta_max_(delta_max), +      delta_min_(delta_min) { +    assert(eta_plus > 1.0); +    assert(eta_minus > 0.0 && eta_minus < 1.0); +    assert(delta_max > 0.0); +    assert(delta_min > 0.0); +  } +  std::string Name() const; +  void OptimizeImpl(const double& obj, +                    const std::vector<double>& g, +                    std::vector<double>* x); +  void SaveImpl(std::ostream* out) const; +  void LoadImpl(std::istream* in); + private: +  std::vector<double> prev_g_; +  std::vector<double> delta_ij_; +  const double eta_plus_; +  const double eta_minus_; +  const double delta_max_; +  const double delta_min_; +}; + +class LBFGSOptimizer : public BatchOptimizer { + public: +  explicit LBFGSOptimizer(int num_vars, int memory_buffers = 10); +  std::string Name() const; +  void SaveImpl(std::ostream* out) const; +  void LoadImpl(std::istream* in); +  void OptimizeImpl(const double& obj, +                    const std::vector<double>& g, +                    std::vector<double>* x); + private: +  scitbx::lbfgs::minimizer<double> opt_; +}; + +#endif diff --git a/training/utils/optimize_test.cc b/training/utils/optimize_test.cc new file mode 100644 index 00000000..bff2ca03 --- /dev/null +++ b/training/utils/optimize_test.cc @@ -0,0 +1,118 @@ +#include <cassert> +#include <iostream> +#include <sstream> +#include <boost/program_options/variables_map.hpp> +#include "optimize.h" +#include "online_optimizer.h" +#include "sparse_vector.h" +#include "fdict.h" + +using namespace std; + +double TestOptimizer(BatchOptimizer* opt) { +  cerr << "TESTING NON-PERSISTENT OPTIMIZER\n"; + +  // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 +  // df/dx1 = 8*x1 + x2 +  // df/dx2 = 2*x2 + x1 +  // df/dx3 = 2*x3 + 6 +  vector<double> x(3); +  vector<double> g(3); +  x[0] = 8; +  x[1] = 8; +  x[2] = 8; +  double obj = 0; +  do { +    g[0] = 8 * x[0] + x[1]; +    g[1] = 2 * x[1] + x[0]; +    g[2] = 2 * x[2] + 6; +    obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; +    opt->Optimize(obj, g, &x); + +    cerr << x[0] << " " << x[1] << " " << x[2] << endl; +    cerr << "   obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; +  } while (!opt->HasConverged()); +  return obj; +} + +double TestPersistentOptimizer(BatchOptimizer* opt) { +  cerr << "\nTESTING PERSISTENT OPTIMIZER\n"; +  // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 +  // df/dx1 = 8*x1 + x2 +  // df/dx2 = 2*x2 + x1 +  // df/dx3 = 2*x3 + 6 +  vector<double> x(3); +  vector<double> g(3); +  x[0] = 8; +  x[1] = 8; +  x[2] = 8; +  double obj = 0; +  string state; +  bool converged = false; +  while (!converged) { +    g[0] = 8 * x[0] + x[1]; +    g[1] = 2 * x[1] + x[0]; +    g[2] = 2 * x[2] + 6; +    obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; + +    { +      if (state.size() > 0) { +        istringstream is(state, ios::binary); +        opt->Load(&is); +      } +      opt->Optimize(obj, g, &x); +      ostringstream os(ios::binary); opt->Save(&os); state = os.str(); + +    } + +    cerr << x[0] << " " << x[1] << " " << x[2] << endl; +    cerr << "   obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; +    converged = opt->HasConverged(); +    if (!converged) { +      // now screw up the state (should be undone by Load) +      obj += 2.0; +      g[1] = -g[2]; +      vector<double> x2 = x; +      try { +        opt->Optimize(obj, g, &x2); +      } catch (...) { } +    } +  } +  return obj; +} + +template <class O> +void TestOptimizerVariants(int num_vars) { +  O oa(num_vars); +  cerr << "-------------------------------------------------------------------------\n"; +  cerr << "TESTING: " << oa.Name() << endl; +  double o1 = TestOptimizer(&oa); +  O ob(num_vars); +  double o2 = TestPersistentOptimizer(&ob); +  if (o1 != o2) { +    cerr << oa.Name() << " VARIANTS PERFORMED DIFFERENTLY!\n" << o1 << " vs. " << o2 << endl; +    exit(1); +  } +  cerr << oa.Name() << " SUCCESS\n"; +} + +using namespace std::tr1; + +void TestOnline() { +  size_t N = 20; +  double C = 1.0; +  double eta0 = 0.2; +  std::tr1::shared_ptr<LearningRateSchedule> r(new ExponentialDecayLearningRate(N, eta0, 0.85)); +  //shared_ptr<LearningRateSchedule> r(new StandardLearningRate(N, eta0)); +  CumulativeL1OnlineOptimizer opt(r, N, C, std::vector<int>()); +  assert(r->eta(10) < r->eta(1)); +} + +int main() { +  int n = 3; +  TestOptimizerVariants<LBFGSOptimizer>(n); +  TestOptimizerVariants<RPropOptimizer>(n); +  TestOnline(); +  return 0; +} + diff --git a/training/utils/parallelize.pl b/training/utils/parallelize.pl new file mode 100755 index 00000000..4197e0e5 --- /dev/null +++ b/training/utils/parallelize.pl @@ -0,0 +1,423 @@ +#!/usr/bin/env perl + +# Author: Adam Lopez +# +# This script takes a command that processes input +# from stdin one-line-at-time, and parallelizes it +# on the cluster using David Chiang's sentserver/ +# sentclient architecture. +# +# Prerequisites: the command *must* read each line +# without waiting for subsequent lines of input +# (for instance, a command which must read all lines +# of input before processing will not work) and +# return it to the output *without* buffering +# multiple lines. + +#TODO: if -j 1, run immediately, not via sentserver?  possible differences in environment might make debugging harder + +#ANNOYANCE: if input is shorter than -j n lines, or at the very last few lines, repeatedly sleeps.  time cut down to 15s from 60s + +my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR, "$SCRIPT_DIR/../../environment"; } +use LocalConfig; + +use Cwd qw/ abs_path cwd getcwd /;  +use File::Temp qw/ tempfile /; +use Getopt::Long; +use IPC::Open2; +use strict; +use POSIX ":sys_wait_h"; + +use File::Basename; +my $myDir = dirname(__FILE__); +print STDERR __FILE__." -> $myDir\n"; +push(@INC, $myDir); +require "libcall.pl"; + +my $tailn=5; # +0 = concatenate all the client logs.  5 = last 5 lines +my $recycle_clients;    # spawn new clients when previous ones terminate +my $stay_alive;      # dont let server die when having zero clients +my $joblist = ""; +my $errordir=""; +my $multiline; +my $workdir = '.'; +my $numnodes = 8; +my $user = $ENV{"USER"}; +my $pmem = "9g"; +my $basep=50300; +my $randp=300; +my $tryp=50; +my $no_which; +my $no_cd; + +my $DEBUG=$ENV{DEBUG}; +print STDERR "DEBUG=$DEBUG output enabled.\n" if $DEBUG; +my $verbose = 1; +sub verbose { +    if ($verbose) { +        print STDERR @_,"\n"; +    } +} +sub debug { +    if ($DEBUG) { +        my ($package, $filename, $line) = caller; +        print STDERR "DEBUG: $filename($line): ",join(' ',@_),"\n"; +    } +} +my $is_shell_special=qr.[ \t\n\\><|&;"'`~*?{}$!()].; +my $shell_escape_in_quote=qr.[\\"\$`!].; +sub escape_shell { +    my ($arg)=@_; +    return undef unless defined $arg; +    return '""' unless $arg; +    if ($arg =~ /$is_shell_special/) { +        $arg =~ s/($shell_escape_in_quote)/\\$1/g; +        return "\"$arg\""; +    } +    return $arg; +} +sub preview_files { +    my ($l,$skipempty,$footer,$n)=@_; +    $n=$tailn unless defined $n; +    my @f=grep { ! ($skipempty && -z $_) } @$l; +    my $fn=join(' ',map {escape_shell($_)} @f); +    my $cmd="tail -n $n $fn"; +    unchecked_output("$cmd").($footer?"\nNONEMPTY FILES:\n$fn\n":""); +} +sub prefix_dirname($) { +    #like `dirname but if ends in / then return the whole thing +    local ($_)=@_; +    if (/\/$/) { +        $_; +    } else { +        s#/[^/]$##; +        $_ ? $_ : ''; +    } +} +sub ensure_final_slash($) { +    local ($_)=@_; +    m#/$# ? $_ : ($_."/"); +} +sub extend_path($$;$$) { +    my ($base,$ext,$mkdir,$baseisdir)=@_; +    if (-d $base) { +        $base.="/"; +    } else { +        my $dir; +        if ($baseisdir) { +            $dir=$base; +            $base.='/' unless $base =~ /\/$/; +        } else { +            $dir=prefix_dirname($base); +        } +        my @cmd=("/bin/mkdir","-p",$dir); +        check_call(@cmd) if $mkdir; +    } +    return $base.$ext; +} + +my $abscwd=abs_path(&getcwd); +sub print_help; + +my $use_fork; +my @pids; + +# Process command-line options +unless (GetOptions( +      "stay-alive" => \$stay_alive, +      "recycle-clients" => \$recycle_clients, +      "error-dir=s" => \$errordir, +      "multi-line" => \$multiline, +      "workdir=s" => \$workdir, +      "use-fork" => \$use_fork, +      "verbose" => \$verbose, +      "jobs=i" => \$numnodes, +      "pmem=s" => \$pmem, +        "baseport=i" => \$basep, +#       "iport=i" => \$randp, #for short name -i +        "no-which!" => \$no_which, +            "no-cd!" => \$no_cd, +            "tailn=s" => \$tailn, +) && scalar @ARGV){ +  print_help(); +    die "bad options."; +} + +my $cmd = ""; +my $prog=shift; +if ($no_which) { +    $cmd=$prog; +} else { +    $cmd=check_output("which $prog"); +    chomp $cmd; +    die "$prog not found - $cmd" unless $cmd; +} +#$cmd=abs_path($cmd); +for my $arg (@ARGV) { +    $cmd .= " ".escape_shell($arg); +} +die "Please specify a command to parallelize\n" if $cmd eq ''; + +my $cdcmd=$no_cd ? '' : ("cd ".escape_shell($abscwd)."\n"); + +my $executable = $cmd; +$executable =~ s/^\s*(\S+)($|\s.*)/$1/; +$executable=check_output("basename $executable"); +chomp $executable; + + +print STDERR "Parallelizing ($numnodes ways): $cmd\n\n"; + +# create -e dir and save .sh +use File::Temp qw/tempdir/; +unless ($errordir) { +    $errordir=tempdir("$executable.XXXXXX",CLEANUP=>1); +} +if ($errordir) { +    my $scriptfile=extend_path("$errordir/","$executable.sh",1,1); +    -d $errordir || die "should have created -e dir $errordir"; +    open SF,">",$scriptfile || die; +    print SF "$cdcmd$cmd\n"; +    close SF; +    chmod 0755,$scriptfile; +    $errordir=abs_path($errordir); +    &verbose("-e dir: $errordir"); +} + +# set cleanup handler +my @cleanup_cmds; +sub cleanup; +sub cleanup_and_die; +$SIG{INT} = "cleanup_and_die"; +$SIG{TERM} = "cleanup_and_die"; +$SIG{HUP} = "cleanup_and_die"; + +# other subs: +sub numof_live_jobs; +sub launch_job_on_node; + + +# vars +my $mydir = check_output("dirname $0"); chomp $mydir; +my $sentserver = "$mydir/sentserver"; +my $sentclient = "$mydir/sentclient"; +my $host = check_output("hostname"); +chomp $host; + + +# find open port +srand; +my $port = 50300+int(rand($randp)); +my $endp=$port+$tryp; +sub listening_port_lines { +    my $quiet=$verbose?'':'2>/dev/null'; +    return unchecked_output("netstat -a -n $quiet | grep LISTENING | grep -i tcp"); +} +my $netstat=&listening_port_lines; + +if ($verbose){ print STDERR "Testing port $port...";} + +while ($netstat=~/$port/ || &listening_port_lines=~/$port/){ +  if ($verbose){ print STDERR "port is busy\n";} +  $port++; +  if ($port > $endp){ +    die "Unable to find open port\n"; +  } +  if ($verbose){ print STDERR "Testing port $port... "; } +} +if ($verbose){ +  print STDERR "port $port is available\n"; +} + +my $key = int(rand()*1000000); + +my $multiflag = ""; +if ($multiline){ $multiflag = "-m"; print STDERR "expecting multiline output.\n"; } +my $stay_alive_flag = ""; +if ($stay_alive){ $stay_alive_flag = "--stay-alive"; print STDERR "staying alive while no clients are connected.\n"; } + +my $node_count = 0; +my $script = ""; +# fork == one thread runs the sentserver, while the +# other spawns the sentclient commands. +my $pid = fork; +if ($pid == 0) { # child +  sleep 8; # give other thread time to start sentserver +  $script = "$cdcmd$sentclient $host:$port:$key $cmd"; + +  if ($verbose){ +    print STDERR "Client script:\n====\n"; +    print STDERR $script; +    print STDERR "====\n"; +  } +  for (my $jobn=0; $jobn<$numnodes; $jobn++){ +    launch_job(); +  } +  if ($recycle_clients) { +    my $ret; +    my $livejobs; +    while (1) { +      $ret = waitpid($pid, WNOHANG); +      #print STDERR "waitpid $pid ret = $ret \n"; +      last if ($ret != 0); +      $livejobs = numof_live_jobs(); +      if ($numnodes >= $livejobs ) {  # a client terminated, OR # lines of input was less than -j +        print STDERR "num of requested nodes = $numnodes; num of currently live jobs = $livejobs; Client terminated - launching another.\n"; +        launch_job(); +      } else { +        sleep 15; +      } +    } +  } +  print STDERR "CHILD PROCESSES SPAWNED ... WAITING\n"; +  for my $p (@pids) { +    waitpid($p, 0); +  } +} else { +#  my $todo = "$sentserver -k $key $multiflag $port "; +  my $todo = "$sentserver -k $key $multiflag $port $stay_alive_flag "; +  if ($verbose){ print STDERR "Running: $todo\n"; } +  check_call($todo); +  print STDERR "Call to $sentserver returned.\n"; +  cleanup(); +  exit(0); +} + +sub numof_live_jobs { +  if ($use_fork) { +    die "not implemented"; +  } else { +    # We can probably continue decoding if the qstat error is only temporary +    my @livejobs = grep(/$joblist/, split(/\n/, unchecked_output("qstat"))); +    return ($#livejobs + 1); +  } +} +my (@errors,@outs,@cmds); + +sub launch_job { +    if ($use_fork) { return launch_job_fork(); } +    my $errorfile = "/dev/null"; +    my $outfile = "/dev/null"; +    $node_count++; +    my $clientname = $executable; +    $clientname =~ s/^(.{4}).*$/$1/; +    $clientname = "$clientname.$node_count"; +    if ($errordir){ +      $errorfile = "$errordir/$clientname.ER"; +      $outfile = "$errordir/$clientname.OU"; +      push @errors,$errorfile; +      push @outs,$outfile; +    } +    my $todo = qsub_args($pmem) . " -N $clientname -o $outfile -e $errorfile"; +    push @cmds,$todo; + +    print STDERR "Running: $todo\n"; +    local(*QOUT, *QIN); +    open2(\*QOUT, \*QIN, $todo) or die "Failed to open2: $!"; +    print QIN $script; +    close QIN; +    while (my $jobid=<QOUT>){ +      chomp $jobid; +      if ($verbose){ print STDERR "Launched client job: $jobid"; } +      $jobid =~ s/^(\d+)(.*?)$/\1/g; +            $jobid =~ s/^Your job (\d+) .*$/\1/; +      print STDERR " short job id $jobid\n"; +            if ($verbose){ +                print STDERR "cd: $abscwd\n"; +                print STDERR "cmd: $cmd\n"; +            } +      if ($joblist == "") { $joblist = $jobid; } +      else {$joblist = $joblist . "\|" . $jobid; } +      my $cleanfn="qdel $jobid 2> /dev/null"; +      push(@cleanup_cmds, $cleanfn); +    } +    close QOUT; +} + +sub launch_job_fork { +  my $errorfile = "/dev/null"; +  my $outfile = "/dev/null"; +  $node_count++; +  my $clientname = $executable; +  $clientname =~ s/^(.{4}).*$/$1/; +  $clientname = "$clientname.$node_count"; +  if ($errordir){ +    $errorfile = "$errordir/$clientname.ER"; +    $outfile = "$errordir/$clientname.OU"; +    push @errors,$errorfile; +    push @outs,$outfile; +  } +  my $pid = fork; +  if ($pid == 0) { +    my ($fh, $scr_name) = get_temp_script(); +    print $fh $script; +    close $fh; +    my $todo = "/bin/bash -xeo pipefail $scr_name 1> $outfile 2> $errorfile"; +    print STDERR "EXEC: $todo\n"; +    my $out = check_output("$todo"); +    unlink $scr_name or warn "Failed to remove $scr_name"; +    exit 0; +  } else { +    push @pids, $pid; +  } +} + +sub get_temp_script { +  my ($fh, $filename) = tempfile( "$workdir/workXXXX", SUFFIX => '.sh'); +  return ($fh, $filename); +} + +sub cleanup_and_die { +  cleanup(); +  die "\n"; +} + +sub cleanup { +  print STDERR "Cleaning up...\n"; +  for $cmd (@cleanup_cmds){ +    print STDERR "  Cleanup command: $cmd\n"; +    eval $cmd; +  } +  print STDERR "outputs:\n",preview_files(\@outs,1),"\n"; +  print STDERR "errors:\n",preview_files(\@errors,1),"\n"; +  print STDERR "cmd:\n",$cmd,"\n"; +  print STDERR " cat $errordir/*.ER\nfor logs.\n"; +  print STDERR "Cleanup finished.\n"; +} + +sub print_help +{ +  my $name = check_output("basename $0"); chomp $name; +  print << "Help"; + +usage: $name [options] + +  Automatic black-box parallelization of commands. + +options: + +  --use-fork +    Instead of using qsub, use fork. + +  -e, --error-dir <dir> +    Retain output files from jobs in <dir>, rather +    than silently deleting them. + +  -m, --multi-line +    Expect that command may produce multiple output +    lines for a single input line.  $name makes a +    reasonable attempt to obtain all output before +    processing additional inputs.  However, use of this +    option is inherently unsafe. + +  -v, --verbose +    Print diagnostic informatoin on stderr. + +  -j, --jobs +    Number of jobs to use. + +  -p, --pmem +    pmem setting for each job. + +Help +} diff --git a/training/utils/risk.cc b/training/utils/risk.cc new file mode 100644 index 00000000..d5a12cfd --- /dev/null +++ b/training/utils/risk.cc @@ -0,0 +1,45 @@ +#include "risk.h" + +#include "prob.h" +#include "candidate_set.h" +#include "ns.h" + +using namespace std; + +namespace training { + +// g = \sum_e p(e|f) * loss(e) * (phi(e,f) - E[phi(e,f)]) +double CandidateSetRisk::operator()(const vector<double>& params, +                                    SparseVector<double>* g) const { +  prob_t z; +  for (unsigned i = 0; i < cands_.size(); ++i) { +    const prob_t u(cands_[i].fmap.dot(params), init_lnx()); +    z += u; +  } +  const double log_z = log(z); + +  SparseVector<double> exp_feats; +  if (g) { +    for (unsigned i = 0; i < cands_.size(); ++i) { +      const double log_prob = cands_[i].fmap.dot(params) - log_z; +      const double prob = exp(log_prob); +      exp_feats += cands_[i].fmap * prob; +    } +  } + +  double risk = 0; +  for (unsigned i = 0; i < cands_.size(); ++i) { +    const double log_prob = cands_[i].fmap.dot(params) - log_z; +    const double prob = exp(log_prob); +    const double cost = metric_.IsErrorMetric() ? metric_.ComputeScore(cands_[i].eval_feats) +                                                : 1.0 - metric_.ComputeScore(cands_[i].eval_feats); +    const double r = prob * cost; +    risk += r; +    if (g) (*g) += (cands_[i].fmap - exp_feats) * r; +  } +  return risk; +} + +} + + diff --git a/training/utils/risk.h b/training/utils/risk.h new file mode 100644 index 00000000..2e8db0fb --- /dev/null +++ b/training/utils/risk.h @@ -0,0 +1,26 @@ +#ifndef _RISK_H_ +#define _RISK_H_ + +#include <vector> +#include "sparse_vector.h" +class EvaluationMetric; + +namespace training { +  class CandidateSet; + +  class CandidateSetRisk { +   public: +    explicit CandidateSetRisk(const CandidateSet& cs, const EvaluationMetric& metric) : +       cands_(cs), +       metric_(metric) {} +    // compute the risk (expected loss) of a CandidateSet +    // (optional) the gradient of the risk with respect to params +    double operator()(const std::vector<double>& params, +                      SparseVector<double>* g = NULL) const; +   private: +    const CandidateSet& cands_; +    const EvaluationMetric& metric_; +  }; +}; + +#endif diff --git a/training/utils/sentclient.c b/training/utils/sentclient.c new file mode 100644 index 00000000..91d994ab --- /dev/null +++ b/training/utils/sentclient.c @@ -0,0 +1,76 @@ +/* Copyright (c) 2001 by David Chiang. All rights reserved.*/ + +#include <stdio.h> +#include <stdlib.h> +#include <unistd.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <netinet/in.h> +#include <netdb.h> +#include <string.h> + +#include "sentserver.h" + +int main (int argc, char *argv[]) { +  int sock, port; +  char *s, *key; +  struct hostent *hp; +  struct sockaddr_in server; +  int errors = 0; + +  if (argc < 3) { +    fprintf(stderr, "Usage: sentclient host[:port[:key]] command [args ...]\n"); +    exit(1); +  } + +  s = strchr(argv[1], ':'); +  key = NULL; + +  if (s == NULL) { +    port = DEFAULT_PORT; +  } else { +    *s = '\0'; +    s+=1; +	/* dumb hack */ +	key = strchr(s, ':'); +	if (key != NULL){ +		*key = '\0'; +		key += 1; +	} +    port = atoi(s); +  } + +  sock = socket(AF_INET, SOCK_STREAM, 0); + +  hp = gethostbyname(argv[1]); +  if (hp == NULL) { +    fprintf(stderr, "unknown host %s\n", argv[1]); +    exit(1); +  } + +  bzero((char *)&server, sizeof(server)); +  bcopy(hp->h_addr, (char *)&server.sin_addr, hp->h_length); +  server.sin_family = hp->h_addrtype; +  server.sin_port = htons(port); + +  while (connect(sock, (struct sockaddr *)&server, sizeof(server)) < 0) { +    perror("connect()"); +    sleep(1); +    errors++; +    if (errors > 5) +      exit(1); +  } + +  close(0); +  close(1); +  dup2(sock, 0); +  dup2(sock, 1); + +  if (key != NULL){ +	write(1, key, strlen(key)); +	write(1, "\n", 1); +  } + +  execvp(argv[2], argv+2); +  return 0; +} diff --git a/training/utils/sentserver.c b/training/utils/sentserver.c new file mode 100644 index 00000000..c20b4fa6 --- /dev/null +++ b/training/utils/sentserver.c @@ -0,0 +1,515 @@ +/* Copyright (c) 2001 by David Chiang. All rights reserved.*/ + +#include <string.h> +#include <stdlib.h> +#include <unistd.h> +#include <fcntl.h> +#include <stdio.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/time.h> +#include <netinet/in.h> +#include <sched.h> +#include <pthread.h> +#include <errno.h> + +#include "sentserver.h" + +#define MAX_CLIENTS 64 + +struct clientinfo { +  int s; +  struct sockaddr_in sin; +}; + +struct line { +  int id; +  char *s; +  int status; +  struct line *next; +} *head, **ptail; + +int n_sent = 0, n_received=0, n_flushed=0; + +#define STATUS_RUNNING 0 +#define STATUS_ABORTED 1 +#define STATUS_FINISHED 2 + +pthread_mutex_t queue_mutex = PTHREAD_MUTEX_INITIALIZER; +pthread_mutex_t clients_mutex = PTHREAD_MUTEX_INITIALIZER; +pthread_mutex_t input_mutex = PTHREAD_MUTEX_INITIALIZER; + +int n_clients = 0; +int s; +int expect_multiline_output = 0; +int log_mutex = 0; +int stay_alive = 0;		/* dont panic and die with zero clients */ + +void queue_finish(struct line *node, char *s, int fid); +char * read_line(int fd, int multiline); +void done (int code); + +struct line * queue_get(int fid) { +	struct line *cur; +	char *s, *synch; + +	if (log_mutex) fprintf(stderr, "Getting for data for fid %d\n", fid); +	if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid); +	pthread_mutex_lock(&queue_mutex); + +	/* First, check for aborted sentences. */ + +	if (log_mutex) fprintf(stderr, "  Checking queue for aborted jobs (fid %d)\n", fid); +	for (cur = head; cur != NULL; cur = cur->next) { +		if (cur->status == STATUS_ABORTED) { +			cur->status = STATUS_RUNNING; + +			if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); +			pthread_mutex_unlock(&queue_mutex); + +			return cur; +		} +	} +	if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); +	pthread_mutex_unlock(&queue_mutex); + +	/* Otherwise, read a new one. */ +	if (log_mutex) fprintf(stderr, "Locking input mutex (%d)\n", fid); +	if (log_mutex) fprintf(stderr, "  Reading input for new data (fid %d)\n", fid); +	pthread_mutex_lock(&input_mutex); +	s = read_line(0,0); + +	while (s) { +		if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid); +		pthread_mutex_lock(&queue_mutex); +		if (log_mutex) fprintf(stderr, "Unlocking input mutex (%d)\n", fid); +		pthread_mutex_unlock(&input_mutex); + +		cur = malloc(sizeof (struct line)); +		cur->id = n_sent; +		cur->s = s; +		cur->next = NULL; + +		*ptail = cur; +		ptail = &cur->next; + +		n_sent++; + +		if (strcmp(s,"===SYNCH===\n")==0){ +			fprintf(stderr, "Received ===SYNCH=== signal (fid %d)\n", fid); +			// Note: queue_finish calls free(cur->s). +			// Therefore we need to create a new string here. +			synch = malloc((strlen("===SYNCH===\n")+2) * sizeof (char)); +			synch = strcpy(synch, s); + +			if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); +			pthread_mutex_unlock(&queue_mutex); +			queue_finish(cur, synch, fid); /* handles its own lock */ + +			if (log_mutex) fprintf(stderr, "Locking input mutex (%d)\n", fid); +			if (log_mutex) fprintf(stderr, "  Reading input for new data (fid %d)\n", fid); +			pthread_mutex_lock(&input_mutex); + +			s = read_line(0,0); +		} else { +			if (log_mutex) fprintf(stderr, "  Received new data %d (fid %d)\n", cur->id, fid); +			cur->status = STATUS_RUNNING; +			if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); +			pthread_mutex_unlock(&queue_mutex); +			return cur; +		} +	} + +	if (log_mutex) fprintf(stderr, "Unlocking input mutex (%d)\n", fid); +	pthread_mutex_unlock(&input_mutex); +	/* Only way to reach this point: no more output */ + +	if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid); +	pthread_mutex_lock(&queue_mutex); +	if (head == NULL) { +		fprintf(stderr, "Reached end of file. Exiting.\n"); +		done(0); +	} else +		ptail = NULL; /* This serves as a signal that there is no more input */ +	if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); +	pthread_mutex_unlock(&queue_mutex); + +	return NULL; +} + +void queue_panic() { +	struct line *next; +	while (head && head->status == STATUS_FINISHED) { +		/* Write out finished sentences */ +		if (head->status == STATUS_FINISHED) { +			fputs(head->s, stdout); +			fflush(stdout); +		} +		/* Write out blank line for unfinished sentences */ +		if (head->status == STATUS_ABORTED) { +			fputs("\n", stdout); +			fflush(stdout); +		} +		/* By defition, there cannot be any RUNNING sentences, since +		function is only called when n_clients == 0 */ +		free(head->s); +		next = head->next; +		free(head); +		head = next; +		n_flushed++; +	} +	fclose(stdout); +	fprintf(stderr, "All clients died. Panicking, flushing completed sentences and exiting.\n"); +	done(1); +} + +void queue_abort(struct line *node, int fid) { +	if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid); +	pthread_mutex_lock(&queue_mutex); +	node->status = STATUS_ABORTED; +	if (n_clients == 0) { +		if (stay_alive) { +			fprintf(stderr, "Warning! No live clients detected! Staying alive, will retry soon.\n"); +		} else { +			queue_panic(); +		} +	} +	if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); +	pthread_mutex_unlock(&queue_mutex); +} + + +void queue_print() { +  struct line *cur; + +  fprintf(stderr, "  Queue\n"); + +  for (cur = head; cur != NULL; cur = cur->next) { +    switch(cur->status) { +    case STATUS_RUNNING: +      fprintf(stderr, "    %d running  ", cur->id); break; +    case STATUS_ABORTED: +      fprintf(stderr, "    %d aborted  ", cur->id); break; +    case STATUS_FINISHED: +      fprintf(stderr, "    %d finished ", cur->id); break; + +    } +	fprintf(stderr, "\n"); +    //fprintf(stderr, cur->s); +  } +} + +void queue_finish(struct line *node, char *s, int fid) { +  struct line *next; +  if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid); +  pthread_mutex_lock(&queue_mutex); + +  free(node->s); +  node->s = s; +  node->status = STATUS_FINISHED; +  n_received++; + +  /* Flush out finished nodes */ +  while (head && head->status == STATUS_FINISHED) { + +    if (log_mutex) fprintf(stderr, "  Flushing finished node %d\n", head->id); + +    fputs(head->s, stdout); +    fflush(stdout); +    if (log_mutex) fprintf(stderr, "  Flushed node %d\n", head->id); +    free(head->s); + +    next = head->next; +    free(head); + +    head = next; + +    n_flushed++; + +    if (head == NULL) { /* empty queue */ +      if (ptail == NULL) { /* This can only happen if set in queue_get as signal that there is no more input. */ +        fprintf(stderr, "All sentences finished. Exiting.\n"); +        done(0); +      } else /* ptail pointed at something which was just popped off the stack -- reset to head*/ +        ptail = &head; +    } +  } + +  if (log_mutex) fprintf(stderr, "  Flushing output %d\n", head->id); +  fflush(stdout); +  fprintf(stderr, "%d sentences sent, %d sentences finished, %d sentences flushed\n", n_sent, n_received, n_flushed); + +  if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); +  pthread_mutex_unlock(&queue_mutex); + +} + +char * read_line(int fd, int multiline) { +  int size = 80; +  char errorbuf[100]; +  char *s = malloc(size+2); +  int result, errors=0; +  int i = 0; + +  result = read(fd, s+i, 1); + +  while (1) { +    if (result < 0) { +      perror("read()"); +      sprintf(errorbuf, "Error code: %d\n", errno); +      fprintf(stderr, errorbuf); +      errors++; +      if (errors > 5) { +	free(s); +	return NULL; +      } else { +	sleep(1); /* retry after delay */ +      } +    } else if (result == 0) { +      break; +    } else if (multiline==0 && s[i] == '\n') { +      break; +    } else { +      if (s[i] == '\n'){ +	/* if we've reached this point, +	   then multiline must be 1, and we're +	   going to poll the fd for an additional +	   line of data.  The basic design is to +	   run a select on the filedescriptor fd. +	   Select will return under two conditions: +	   if there is data on the fd, or if a +	   timeout is reached.  We'll select on this +	   fd.  If select returns because there's data +	   ready, keep going; else assume there's no +	   more and return the data we already have. +	*/ + +	fd_set set; +	FD_ZERO(&set); +	FD_SET(fd, &set); + +	struct timeval timeout; +	timeout.tv_sec = 3; // number of seconds for timeout +	timeout.tv_usec = 0; + +	int ready = select(FD_SETSIZE, &set, NULL, NULL, &timeout); +	if (ready<1){ +	  break; // no more data, stop looping +	} +      } +      i++; + +      if (i == size) { +	size = size*2; +	s = realloc(s, size+2); +      } +    } + +    result = read(fd, s+i, 1); +  } + +  if (result == 0 && i == 0) { /* end of file */ +    free(s); +    return NULL; +  } + +  s[i] = '\n'; +  s[i+1] = '\0'; + +  return s; +} + +void * new_client(void *arg) { +  struct clientinfo *client = (struct clientinfo *)arg; +  struct line *cur; +  int result; +  char *s; +  char errorbuf[100]; + +  pthread_mutex_lock(&clients_mutex); +  n_clients++; +  pthread_mutex_unlock(&clients_mutex); + +  fprintf(stderr, "Client connected (%d connected)\n", n_clients); + +  for (;;) { + +    cur = queue_get(client->s); + +    if (cur) { +      /* fprintf(stderr, "Sending to client: %s", cur->s); */ +      fprintf(stderr, "Sending data %d to client (fid %d)\n", cur->id, client->s); +      result = write(client->s, cur->s, strlen(cur->s)); +      if (result < strlen(cur->s)){ +        perror("write()"); +        sprintf(errorbuf, "Error code: %d\n", errno); +        fprintf(stderr, errorbuf); + +        pthread_mutex_lock(&clients_mutex); +        n_clients--; +        pthread_mutex_unlock(&clients_mutex); + +        fprintf(stderr, "Client died (%d connected)\n", n_clients); +        queue_abort(cur, client->s); + +        close(client->s); +        free(client); + +        pthread_exit(NULL); +      } +    } else { +      close(client->s); +      pthread_mutex_lock(&clients_mutex); +      n_clients--; +      pthread_mutex_unlock(&clients_mutex); +      fprintf(stderr, "Client dismissed (%d connected)\n", n_clients); +      pthread_exit(NULL); +    } + +    s = read_line(client->s,expect_multiline_output); +    if (s) { +      /* fprintf(stderr, "Client (fid %d) returned: %s", client->s, s); */ +      fprintf(stderr, "Client (fid %d) returned data %d\n", client->s, cur->id); +//      queue_print(); +      queue_finish(cur, s, client->s); +    } else { +      pthread_mutex_lock(&clients_mutex); +      n_clients--; +      pthread_mutex_unlock(&clients_mutex); + +      fprintf(stderr, "Client died (%d connected)\n", n_clients); +      queue_abort(cur, client->s); + +      close(client->s); +      free(client); + +      pthread_exit(NULL); +    } + +  } +  return 0; +} + +void done (int code) { +  close(s); +  exit(code); +} + + + +int main (int argc, char *argv[]) { +  struct sockaddr_in sin, from; +  int g; +  socklen_t len; +  struct clientinfo *client; +  int port; +  int opt; +  int errors = 0; +  int argi; +  char *key = NULL, *client_key; +  int use_key = 0; +  /* the key stuff here doesn't provide any +  real measure of security, it's mainly to keep +  jobs from bumping into each other.  */ + +  pthread_t tid; +  port = DEFAULT_PORT; + +  for (argi=1; argi < argc; argi++){ +    if (strcmp(argv[argi], "-m")==0){ +      expect_multiline_output = 1; +    } else if (strcmp(argv[argi], "-k")==0){ +      argi++; +      if (argi == argc){ +      	fprintf(stderr, "Key must be specified after -k\n"); +      	exit(1); +      } +      key = argv[argi]; +      use_key = 1; +    } else if (strcmp(argv[argi], "--stay-alive")==0){ +      stay_alive = 1;    /* dont panic and die with zero clients */ +    } else { +      port = atoi(argv[argi]); +    } +  } + +  /* Initialize data structures */ +  head = NULL; +  ptail = &head; + +  /* Set up listener */ +  s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); +  opt = 1; +  setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); + +  sin.sin_family = AF_INET; +  sin.sin_addr.s_addr = htonl(INADDR_ANY); +  sin.sin_port = htons(port); +  while (bind(s, (struct sockaddr *) &sin, sizeof(sin)) < 0) { +	perror("bind()"); +	sleep(1); +	errors++; +	if (errors > 100) +	  exit(1); +  } + +  len = sizeof(sin); +  getsockname(s, (struct sockaddr *) &sin, &len); + +  fprintf(stderr, "Listening on port %hu\n", ntohs(sin.sin_port)); + +  while (listen(s, MAX_CLIENTS) < 0) { +	perror("listen()"); +	sleep(1); +	errors++; +	if (errors > 100) +	  exit(1); +  } + +  for (;;) { +    len = sizeof(from); +    g = accept(s, (struct sockaddr *)&from, &len); +    if (g < 0) { +      perror("accept()"); +      sleep(1); +      continue; +    } +    client = malloc(sizeof(struct clientinfo)); +    client->s = g; +    bcopy(&from, &client->sin, len); + +	if (use_key){ +		fd_set set; +		FD_ZERO(&set); +		FD_SET(client->s, &set); + +		struct timeval timeout; +		timeout.tv_sec = 3; // number of seconds for timeout +		timeout.tv_usec = 0; + +		int ready = select(FD_SETSIZE, &set, NULL, NULL, &timeout); +		if (ready<1){ +			fprintf(stderr, "Prospective client failed to respond with correct key.\n"); +			close(client->s); +			free(client); +		} else { +			client_key = read_line(client->s,0); +			client_key[strlen(client_key)-1]='\0'; /* chop trailing newline */ +			if (strcmp(key, client_key)==0){ +				pthread_create(&tid, NULL, new_client, client); +			} else { +				fprintf(stderr, "Prospective client failed to respond with correct key.\n"); +				close(client->s); +				free(client); +			} +			free(client_key); +		} +	} else { +		pthread_create(&tid, NULL, new_client, client); +	} +  } + +} + + + diff --git a/training/utils/sentserver.h b/training/utils/sentserver.h new file mode 100644 index 00000000..cd17a546 --- /dev/null +++ b/training/utils/sentserver.h @@ -0,0 +1,6 @@ +#ifndef SENTSERVER_H +#define SENTSERVER_H + +#define DEFAULT_PORT 50000 + +#endif | 
