diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/Makefile.am | 50 | ||||
-rw-r--r-- | training/atools.cc | 369 | ||||
-rwxr-xr-x | training/cluster-em.pl | 114 | ||||
-rwxr-xr-x | training/cluster-ptrain.pl | 206 | ||||
-rw-r--r-- | training/collapse_weights.cc | 102 | ||||
-rwxr-xr-x | training/dep-reorder/conll2reordering-forest.pl | 65 | ||||
-rw-r--r-- | training/dep-reorder/george.conll | 4 | ||||
-rwxr-xr-x | training/dep-reorder/scripts/conll2simplecfg.pl | 57 | ||||
-rw-r--r-- | training/grammar_convert.cc | 319 | ||||
-rw-r--r-- | training/lbfgs.h | 1459 | ||||
-rw-r--r-- | training/lbfgs_test.cc | 112 | ||||
-rwxr-xr-x | training/make-lexcrf-grammar.pl | 285 | ||||
-rw-r--r-- | training/model1.cc | 103 | ||||
-rw-r--r-- | training/mr_em_adapted_reduce.cc | 194 | ||||
-rw-r--r-- | training/mr_em_map_adapter.cc | 160 | ||||
-rw-r--r-- | training/mr_optimize_reduce.cc | 243 | ||||
-rw-r--r-- | training/mr_reduce_to_weights.cc | 109 | ||||
-rw-r--r-- | training/optimize.cc | 114 | ||||
-rw-r--r-- | training/optimize.h | 104 | ||||
-rw-r--r-- | training/optimize_test.cc | 105 | ||||
-rw-r--r-- | training/plftools.cc | 93 |
21 files changed, 4367 insertions, 0 deletions
diff --git a/training/Makefile.am b/training/Makefile.am new file mode 100644 index 00000000..490de774 --- /dev/null +++ b/training/Makefile.am @@ -0,0 +1,50 @@ +bin_PROGRAMS = \ + model1 \ + mr_em_map_adapter \ + mr_em_adapted_reduce \ + mr_reduce_to_weights \ + mr_optimize_reduce \ + grammar_convert \ + atools \ + plftools \ + collapse_weights + +noinst_PROGRAMS = \ + lbfgs_test \ + optimize_test + +atools_SOURCES = atools.cc +atools_LDADD = $(top_srcdir)/decoder/libcdec.a -lz + +model1_SOURCES = model1.cc +model1_LDADD = $(top_srcdir)/decoder/libcdec.a -lz + +grammar_convert_SOURCES = grammar_convert.cc +grammar_convert_LDADD = $(top_srcdir)/decoder/libcdec.a -lz + +optimize_test_SOURCES = optimize_test.cc optimize.cc +optimize_test_LDADD = $(top_srcdir)/decoder/libcdec.a -lz + +collapse_weights_SOURCES = collapse_weights.cc +collapse_weights_LDADD = $(top_srcdir)/decoder/libcdec.a -lz + +lbfgs_test_SOURCES = lbfgs_test.cc +lbfgs_test_LDADD = $(top_srcdir)/decoder/libcdec.a -lz + +mr_optimize_reduce_SOURCES = mr_optimize_reduce.cc optimize.cc +mr_optimize_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a -lz + +mr_em_map_adapter_SOURCES = mr_em_map_adapter.cc +mr_em_map_adapter_LDADD = $(top_srcdir)/decoder/libcdec.a -lz + +mr_reduce_to_weights_SOURCES = mr_reduce_to_weights.cc +mr_reduce_to_weights_LDADD = $(top_srcdir)/decoder/libcdec.a -lz + +mr_em_adapted_reduce_SOURCES = mr_em_adapted_reduce.cc +mr_em_adapted_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a -lz + +plftools_SOURCES = plftools.cc +plftools_LDADD = $(top_srcdir)/decoder/libcdec.a -lz + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder + diff --git a/training/atools.cc b/training/atools.cc new file mode 100644 index 00000000..af62804d --- /dev/null +++ b/training/atools.cc @@ -0,0 +1,369 @@ +#include <iostream> +#include <sstream> +#include <vector> + +#include <queue> +#include <map> +#include <boost/program_options.hpp> +#include <boost/shared_ptr.hpp> + +#include "filelib.h" +#include "aligner.h" + +namespace po = boost::program_options; +using namespace std; +using boost::shared_ptr; + +struct Command { + virtual ~Command() {} + virtual string Name() const = 0; + + // returns 1 for alignment grid output [default] + // returns 2 if Summary() should be called [for AER, etc] + virtual int Result() const { return 1; } + + virtual bool RequiresTwoOperands() const { return true; } + virtual void Apply(const Array2D<bool>& a, const Array2D<bool>& b, Array2D<bool>* x) = 0; + void EnsureSize(const Array2D<bool>& a, const Array2D<bool>& b, Array2D<bool>* x) { + x->resize(max(a.width(), b.width()), max(a.height(), b.width())); + } + static bool Safe(const Array2D<bool>& a, int i, int j) { + if (i >= 0 && j >= 0 && i < a.width() && j < a.height()) + return a(i,j); + else + return false; + } + virtual void Summary() { assert(!"Summary should have been overridden"); } +}; + +// compute fmeasure, second alignment is reference, first is hyp +struct FMeasureCommand : public Command { + FMeasureCommand() : matches(), num_predicted(), num_in_ref() {} + int Result() const { return 2; } + string Name() const { return "fmeasure"; } + bool RequiresTwoOperands() const { return true; } + void Apply(const Array2D<bool>& hyp, const Array2D<bool>& ref, Array2D<bool>* x) { + (void) x; // AER just computes statistics, not an alignment + int i_len = ref.width(); + int j_len = ref.height(); + for (int i = 0; i < i_len; ++i) { + for (int j = 0; j < j_len; ++j) { + if (ref(i,j)) { + ++num_in_ref; + if (Safe(hyp, i, j)) ++matches; + } + } + } + for (int i = 0; i < hyp.width(); ++i) + for (int j = 0; j < hyp.height(); ++j) + if (hyp(i,j)) ++num_predicted; + } + void Summary() { + if (num_predicted == 0 || num_in_ref == 0) { + cerr << "Insufficient statistics to compute f-measure!\n"; + abort(); + } + const double prec = static_cast<double>(matches) / num_predicted; + const double rec = static_cast<double>(matches) / num_in_ref; + cout << "P: " << prec << endl; + cout << "R: " << rec << endl; + const double f = (2.0 * prec * rec) / (rec + prec); + cout << "F: " << f << endl; + } + int matches; + int num_predicted; + int num_in_ref; +}; + +struct DisplayCommand : public Command { + string Name() const { return "display"; } + bool RequiresTwoOperands() const { return false; } + void Apply(const Array2D<bool>& in, const Array2D<bool>¬_used, Array2D<bool>* x) { + *x = in; + cout << *x << endl; + } +}; + +struct ConvertCommand : public Command { + string Name() const { return "convert"; } + bool RequiresTwoOperands() const { return false; } + void Apply(const Array2D<bool>& in, const Array2D<bool>¬_used, Array2D<bool>* x) { + *x = in; + } +}; + +struct InvertCommand : public Command { + string Name() const { return "invert"; } + bool RequiresTwoOperands() const { return false; } + void Apply(const Array2D<bool>& in, const Array2D<bool>¬_used, Array2D<bool>* x) { + Array2D<bool>& res = *x; + res.resize(in.height(), in.width()); + for (int i = 0; i < in.height(); ++i) + for (int j = 0; j < in.width(); ++j) + res(i, j) = in(j, i); + } +}; + +struct IntersectCommand : public Command { + string Name() const { return "intersect"; } + bool RequiresTwoOperands() const { return true; } + void Apply(const Array2D<bool>& a, const Array2D<bool>& b, Array2D<bool>* x) { + EnsureSize(a, b, x); + Array2D<bool>& res = *x; + for (int i = 0; i < a.width(); ++i) + for (int j = 0; j < a.height(); ++j) + res(i, j) = Safe(a, i, j) && Safe(b, i, j); + } +}; + +struct UnionCommand : public Command { + string Name() const { return "union"; } + bool RequiresTwoOperands() const { return true; } + void Apply(const Array2D<bool>& a, const Array2D<bool>& b, Array2D<bool>* x) { + EnsureSize(a, b, x); + Array2D<bool>& res = *x; + for (int i = 0; i < res.width(); ++i) + for (int j = 0; j < res.height(); ++j) + res(i, j) = Safe(a, i, j) || Safe(b, i, j); + } +}; + +struct RefineCommand : public Command { + RefineCommand() { + neighbors_.push_back(make_pair(1,0)); + neighbors_.push_back(make_pair(-1,0)); + neighbors_.push_back(make_pair(0,1)); + neighbors_.push_back(make_pair(0,-1)); + } + bool RequiresTwoOperands() const { return true; } + + void Align(int i, int j) { + res_(i, j) = true; + is_i_aligned_[i] = true; + is_j_aligned_[j] = true; + } + + bool IsNeighborAligned(int i, int j) const { + for (int k = 0; k < neighbors_.size(); ++k) { + const int di = neighbors_[k].first; + const int dj = neighbors_[k].second; + if (Safe(res_, i + di, j + dj)) + return true; + } + return false; + } + + bool IsNeitherAligned(int i, int j) const { + return !(is_i_aligned_[i] || is_j_aligned_[j]); + } + + bool IsOneOrBothUnaligned(int i, int j) const { + return !(is_i_aligned_[i] && is_j_aligned_[j]); + } + + bool KoehnAligned(int i, int j) const { + return IsOneOrBothUnaligned(i, j) && IsNeighborAligned(i, j); + } + + typedef bool (RefineCommand::*Predicate)(int i, int j) const; + + protected: + void InitRefine( + const Array2D<bool>& a, + const Array2D<bool>& b) { + res_.clear(); + EnsureSize(a, b, &res_); + in_.clear(); un_.clear(); is_i_aligned_.clear(); is_j_aligned_.clear(); + EnsureSize(a, b, &in_); + EnsureSize(a, b, &un_); + is_i_aligned_.resize(res_.width(), false); + is_j_aligned_.resize(res_.height(), false); + for (int i = 0; i < in_.width(); ++i) + for (int j = 0; j < in_.height(); ++j) { + un_(i, j) = Safe(a, i, j) || Safe(b, i, j); + in_(i, j) = Safe(a, i, j) && Safe(b, i, j); + if (in_(i, j)) Align(i, j); + } + } + // "grow" the resulting alignment using the points in adds + // if they match the constraints determined by pred + void Grow(Predicate pred, bool idempotent, const Array2D<bool>& adds) { + if (idempotent) { + for (int i = 0; i < adds.width(); ++i) + for (int j = 0; j < adds.height(); ++j) { + if (adds(i, j) && !res_(i, j) && + (this->*pred)(i, j)) Align(i, j); + } + return; + } + set<pair<int, int> > p; + for (int i = 0; i < adds.width(); ++i) + for (int j = 0; j < adds.height(); ++j) + if (adds(i, j) && !res_(i, j)) + p.insert(make_pair(i, j)); + bool keep_going = !p.empty(); + while (keep_going) { + keep_going = false; + for (set<pair<int, int> >::iterator pi = p.begin(); + pi != p.end(); ++pi) { + if ((this->*pred)(pi->first, pi->second)) { + Align(pi->first, pi->second); + p.erase(pi); + keep_going = true; + } + } + } + } + Array2D<bool> res_; // refined alignment + Array2D<bool> in_; // intersection alignment + Array2D<bool> un_; // union alignment + vector<bool> is_i_aligned_; + vector<bool> is_j_aligned_; + vector<pair<int,int> > neighbors_; +}; + +struct DiagCommand : public RefineCommand { + DiagCommand() { + neighbors_.push_back(make_pair(1,1)); + neighbors_.push_back(make_pair(-1,1)); + neighbors_.push_back(make_pair(1,-1)); + neighbors_.push_back(make_pair(-1,-1)); + } +}; + +struct GDCommand : public DiagCommand { + string Name() const { return "grow-diag"; } + void Apply(const Array2D<bool>& a, const Array2D<bool>& b, Array2D<bool>* x) { + InitRefine(a, b); + Grow(&RefineCommand::KoehnAligned, false, un_); + *x = res_; + } +}; + +struct GDFCommand : public DiagCommand { + string Name() const { return "grow-diag-final"; } + void Apply(const Array2D<bool>& a, const Array2D<bool>& b, Array2D<bool>* x) { + InitRefine(a, b); + Grow(&RefineCommand::KoehnAligned, false, un_); + Grow(&RefineCommand::IsOneOrBothUnaligned, true, a); + Grow(&RefineCommand::IsOneOrBothUnaligned, true, b); + *x = res_; + } +}; + +struct GDFACommand : public DiagCommand { + string Name() const { return "grow-diag-final-and"; } + void Apply(const Array2D<bool>& a, const Array2D<bool>& b, Array2D<bool>* x) { + InitRefine(a, b); + Grow(&RefineCommand::KoehnAligned, false, un_); + Grow(&RefineCommand::IsNeitherAligned, true, a); + Grow(&RefineCommand::IsNeitherAligned, true, b); + *x = res_; + } +}; + +map<string, boost::shared_ptr<Command> > commands; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + ostringstream os; + os << "[REQ] Operation to perform:"; + for (map<string, boost::shared_ptr<Command> >::iterator it = commands.begin(); + it != commands.end(); ++it) { + os << ' ' << it->first; + } + string cstr = os.str(); + opts.add_options() + ("input_1,i", po::value<string>(), "[REQ] Alignment 1 file, - for STDIN") + ("input_2,j", po::value<string>(), "[OPT] Alignment 2 file, - for STDIN") + ("command,c", po::value<string>()->default_value("convert"), cstr.c_str()) + ("help,h", "Print this help message and exit"); + po::options_description clo("Command line options"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + po::notify(*conf); + + if (conf->count("help") || conf->count("input_1") == 0 || conf->count("command") == 0) { + cerr << dcmdline_options << endl; + exit(1); + } + const string cmd = (*conf)["command"].as<string>(); + if (commands.count(cmd) == 0) { + cerr << "Don't understand command: " << cmd << endl; + exit(1); + } + if (commands[cmd]->RequiresTwoOperands()) { + if (conf->count("input_2") == 0) { + cerr << "Command '" << cmd << "' requires two alignment files\n"; + exit(1); + } + if ((*conf)["input_1"].as<string>() == "-" && (*conf)["input_2"].as<string>() == "-") { + cerr << "Both inputs cannot be STDIN\n"; + exit(1); + } + } else { + if (conf->count("input_2") != 0) { + cerr << "Command '" << cmd << "' requires only one alignment file\n"; + exit(1); + } + } +} + +template<class C> static void AddCommand() { + C* c = new C; + commands[c->Name()].reset(c); +} + +int main(int argc, char **argv) { + AddCommand<ConvertCommand>(); + AddCommand<DisplayCommand>(); + AddCommand<InvertCommand>(); + AddCommand<IntersectCommand>(); + AddCommand<UnionCommand>(); + AddCommand<GDCommand>(); + AddCommand<GDFCommand>(); + AddCommand<GDFACommand>(); + AddCommand<FMeasureCommand>(); + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + Command& cmd = *commands[conf["command"].as<string>()]; + boost::shared_ptr<ReadFile> rf1(new ReadFile(conf["input_1"].as<string>())); + boost::shared_ptr<ReadFile> rf2; + if (cmd.RequiresTwoOperands()) + rf2.reset(new ReadFile(conf["input_2"].as<string>())); + istream* in1 = rf1->stream(); + istream* in2 = NULL; + if (rf2) in2 = rf2->stream(); + while(*in1) { + string line1; + string line2; + getline(*in1, line1); + if (in2) { + getline(*in2, line2); + if ((*in1 && !*in2) || (*in2 && !*in1)) { + cerr << "Mismatched number of lines!\n"; + exit(1); + } + } + if (line1.empty() && !*in1) break; + shared_ptr<Array2D<bool> > out(new Array2D<bool>); + shared_ptr<Array2D<bool> > a1 = AlignerTools::ReadPharaohAlignmentGrid(line1); + if (in2) { + shared_ptr<Array2D<bool> > a2 = AlignerTools::ReadPharaohAlignmentGrid(line2); + cmd.Apply(*a1, *a2, out.get()); + } else { + Array2D<bool> dummy; + cmd.Apply(*a1, dummy, out.get()); + } + + if (cmd.Result() == 1) { + AlignerTools::SerializePharaohFormat(*out, &cout); + } + } + if (cmd.Result() == 2) + cmd.Summary(); + return 0; +} + diff --git a/training/cluster-em.pl b/training/cluster-em.pl new file mode 100755 index 00000000..267ab642 --- /dev/null +++ b/training/cluster-em.pl @@ -0,0 +1,114 @@ +#!/usr/bin/perl -w + +use strict; +my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR; } +use Getopt::Long; +my $parallel = 0; + +my $CWD=`pwd`; chomp $CWD; +my $BIN_DIR = "$CWD/.."; +my $REDUCER = "$BIN_DIR/training/mr_em_adapted_reduce"; +my $REDUCE2WEIGHTS = "$BIN_DIR/training/mr_reduce_to_weights"; +my $ADAPTER = "$BIN_DIR/training/mr_em_map_adapter"; +my $DECODER = "$BIN_DIR/decoder/cdec"; +my $COMBINER_CACHE_SIZE = 10000000; +my $PARALLEL = "/chomes/redpony/svn-trunk/sa-utils/parallelize.pl"; +die "Can't find $REDUCER" unless -f $REDUCER; +die "Can't execute $REDUCER" unless -x $REDUCER; +die "Can't find $REDUCE2WEIGHTS" unless -f $REDUCE2WEIGHTS; +die "Can't execute $REDUCE2WEIGHTS" unless -x $REDUCE2WEIGHTS; +die "Can't find $ADAPTER" unless -f $ADAPTER; +die "Can't execute $ADAPTER" unless -x $ADAPTER; +die "Can't find $DECODER" unless -f $DECODER; +die "Can't execute $DECODER" unless -x $DECODER; +my $restart = ''; +if ($ARGV[0] && $ARGV[0] eq '--restart') { shift @ARGV; $restart = 1; } + +die "Usage: $0 [--restart] training.corpus cdec.ini\n" unless (scalar @ARGV == 2); + +my $training_corpus = shift @ARGV; +my $config = shift @ARGV; +my $pmem="2500mb"; +my $nodes = 40; +my $max_iteration = 1000; +my $CFLAG = "-C 1"; +if ($parallel) { + die "Can't find $PARALLEL" unless -f $PARALLEL; + die "Can't execute $PARALLEL" unless -x $PARALLEL; +} else { $CFLAG = "-C 500"; } + +my $initial_weights = ''; + +print STDERR <<EOT; +EM TRAIN CONFIGURATION INFORMATION + + Config file: $config + Training corpus: $training_corpus + Initial weights: $initial_weights + Decoder memory: $pmem + Nodes requested: $nodes + Max iterations: $max_iteration + restart: $restart +EOT + +my $nodelist="1"; +for (my $i=1; $i<$nodes; $i++) { $nodelist .= " 1"; } +my $iter = 1; + +my $dir = "$CWD/emtrain"; +if ($restart) { + die "$dir doesn't exist, but --restart specified!\n" unless -d $dir; + my $o = `ls -t $dir/weights.*`; + my ($a, @x) = split /\n/, $o; + if ($a =~ /weights.(\d+)\.gz$/) { + $iter = $1; + } else { + die "Unexpected file: $a!\n"; + } + print STDERR "Restarting at iteration $iter\n"; +} else { + die "$dir already exists!\n" if -e $dir; + mkdir $dir or die "Can't create $dir: $!"; + + if ($initial_weights) { + unless ($initial_weights =~ /\.gz$/) { + `cp $initial_weights $dir/weights.1`; + `gzip -9 $dir/weights.1`; + } else { + `cp $initial_weights $dir/weights.1.gz`; + } + } +} + +while ($iter < $max_iteration) { + my $cur_time = `date`; chomp $cur_time; + print STDERR "\nStarting iteration $iter...\n"; + print STDERR " time: $cur_time\n"; + my $start = time; + my $next_iter = $iter + 1; + my $WSTR = "-w $dir/weights.$iter.gz"; + if ($iter == 1) { $WSTR = ''; } + my $dec_cmd="$DECODER --feature_expectations -c $config $WSTR $CFLAG < $training_corpus 2> $dir/deco.log.$iter"; + my $pcmd = "$PARALLEL -e $dir/err -p $pmem --nodelist \"$nodelist\" -- "; + my $cmd = ""; + if ($parallel) { $cmd = $pcmd; } + $cmd .= "$dec_cmd"; + $cmd .= "| $ADAPTER | sort -k1 | $REDUCER | $REDUCE2WEIGHTS -o $dir/weights.$next_iter.gz"; + print STDERR "EXECUTING: $cmd\n"; + my $result = `$cmd`; + if ($? != 0) { + die "Error running iteration $iter: $!"; + } + chomp $result; + my $end = time; + my $diff = ($end - $start); + print STDERR " ITERATION $iter TOOK $diff SECONDS\n"; + $iter = $next_iter; + if ($result =~ /1$/) { + print STDERR "Training converged.\n"; + last; + } +} + +print "FINAL WEIGHTS: $dir/weights.$iter\n"; + diff --git a/training/cluster-ptrain.pl b/training/cluster-ptrain.pl new file mode 100755 index 00000000..03122df9 --- /dev/null +++ b/training/cluster-ptrain.pl @@ -0,0 +1,206 @@ +#!/usr/bin/perl -w + +use strict; +my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path getcwd /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR; } +use Getopt::Long; + +my $MAX_ITER_ATTEMPTS = 5; # number of times to retry a failed function evaluation +my $CWD=getcwd(); +my $OPTIMIZER = "$SCRIPT_DIR/mr_optimize_reduce"; +my $DECODER = "$SCRIPT_DIR/../decoder/cdec"; +my $COMBINER_CACHE_SIZE = 150; +# This is a hack to run this on a weird cluster, +# eventually, I'll provide Hadoop scripts. +my $PARALLEL = "/chomes/redpony/svn-trunk/sa-utils/parallelize.pl"; +die "Can't find $OPTIMIZER" unless -f $OPTIMIZER; +die "Can't execute $OPTIMIZER" unless -x $OPTIMIZER; +my $restart = ''; +if ($ARGV[0] && $ARGV[0] eq '--restart') { shift @ARGV; $restart = 1; } + +my $pmem="2500mb"; +my $nodes = 1; +my $max_iteration = 1000; +my $PRIOR_FLAG = ""; +my $parallel = 1; +my $CFLAG = "-C 1"; +my $LOCAL; +my $DISTRIBUTED; +my $PRIOR; +my $OALG = "lbfgs"; +my $sigsq = 1; +my $means_file; +my $mem_buffers = 20; +my $RESTART_IF_NECESSARY; +GetOptions("cdec=s" => \$DECODER, + "distributed" => \$DISTRIBUTED, + "sigma_squared=f" => \$sigsq, + "lbfgs_memory_buffers=i" => \$mem_buffers, + "max_iteration=i" => \$max_iteration, + "means=s" => \$means_file, + "optimizer=s" => \$OALG, + "gaussian_prior" => \$PRIOR, + "restart_if_necessary" => \$RESTART_IF_NECESSARY, + "jobs=i" => \$nodes, + "pmem=s" => \$pmem + ) or usage(); +usage() unless scalar @ARGV==3; +my $config_file = shift @ARGV; +my $training_corpus = shift @ARGV; +my $initial_weights = shift @ARGV; +unless ($DISTRIBUTED) { $LOCAL = 1; } +die "Can't find $config_file" unless -f $config_file; +die "Can't find $DECODER" unless -f $DECODER; +die "Can't execute $DECODER" unless -x $DECODER; +if ($LOCAL) { print STDERR "Will run LOCALLY.\n"; $parallel = 0; } +if ($PRIOR) { + $PRIOR_FLAG="-p --sigma_squared $sigsq"; + if ($means_file) { $PRIOR_FLAG .= " -u $means_file"; } +} + +if ($parallel) { + die "Can't find $PARALLEL" unless -f $PARALLEL; + die "Can't execute $PARALLEL" unless -x $PARALLEL; +} +unless ($parallel) { $CFLAG = "-C 500"; } +unless ($config_file =~ /^\//) { $config_file = $CWD . '/' . $config_file; } +my $clines = num_lines($training_corpus); +my $dir = "$CWD/ptrain"; + +if ($RESTART_IF_NECESSARY && -d $dir) { + $restart = 1; +} + +print STDERR <<EOT; +PTRAIN CONFIGURATION INFORMATION + + Config file: $config_file + Training corpus: $training_corpus + Corpus size: $clines + Initial weights: $initial_weights + Decoder memory: $pmem + Max iterations: $max_iteration + Optimizer: $OALG + Jobs requested: $nodes + prior?: $PRIOR_FLAG + restart?: $restart +EOT + +if ($OALG) { $OALG="-m $OALG"; } + +my $nodelist="1"; +for (my $i=1; $i<$nodes; $i++) { $nodelist .= " 1"; } +my $iter = 1; + +if ($restart) { + die "$dir doesn't exist, but --restart specified!\n" unless -d $dir; + my $o = `ls -t $dir/weights.*`; + my ($a, @x) = split /\n/, $o; + if ($a =~ /weights.(\d+)\.gz$/) { + $iter = $1; + } else { + die "Unexpected file: $a!\n"; + } + print STDERR "Restarting at iteration $iter\n"; +} else { + die "$dir already exists!\n" if -e $dir; + mkdir $dir or die "Can't create $dir: $!"; + + unless ($initial_weights =~ /\.gz$/) { + `cp $initial_weights $dir/weights.1`; + `gzip -9 $dir/weights.1`; + } else { + `cp $initial_weights $dir/weights.1.gz`; + } + open T, "<$training_corpus" or die "Can't read $training_corpus: $!"; + open TO, ">$dir/training.in"; + my $lc = 0; + while(<T>) { + chomp; + s/^\s+//; + s/\s+$//; + die "Expected A ||| B in input file" unless / \|\|\| /; + print TO "<seg id=\"$lc\">$_</seg>\n"; + $lc++; + } + close T; + close TO; +} +$training_corpus = "$dir/training.in"; + +my $iter_attempts = 1; +while ($iter < $max_iteration) { + my $cur_time = `date`; chomp $cur_time; + print STDERR "\nStarting iteration $iter...\n"; + print STDERR " time: $cur_time\n"; + my $start = time; + my $next_iter = $iter + 1; + my $dec_cmd="$DECODER -G $CFLAG -c $config_file -w $dir/weights.$iter.gz < $training_corpus 2> $dir/deco.log.$iter"; + my $opt_cmd = "$OPTIMIZER $PRIOR_FLAG -M $mem_buffers $OALG -s $dir/opt.state -i $dir/weights.$iter.gz -o $dir/weights.$next_iter.gz"; + my $pcmd = "$PARALLEL -e $dir/err -p $pmem --nodelist \"$nodelist\" -- "; + my $cmd = ""; + if ($parallel) { $cmd = $pcmd; } + $cmd .= "$dec_cmd | $opt_cmd"; + + print STDERR "EXECUTING: $cmd\n"; + my $result = `$cmd`; + my $exit_code = $? >> 8; + if ($exit_code == 99) { + $iter_attempts++; + if ($iter_attempts > $MAX_ITER_ATTEMPTS) { + die "Received restart request $iter_attempts times from optimizer, giving up\n"; + } + print STDERR "Function evaluation failed, retrying (attempt $iter_attempts)\n"; + next; + } + if ($? != 0) { + die "Error running iteration $iter: $!"; + } + chomp $result; + my $end = time; + my $diff = ($end - $start); + print STDERR " ITERATION $iter TOOK $diff SECONDS\n"; + $iter = $next_iter; + if ($result =~ /1$/) { + print STDERR "Training converged.\n"; + last; + } + $iter_attempts = 1; +} + +print "FINAL WEIGHTS: $dir/weights.$iter\n"; +`mv $dir/weights.$iter.gz $dir/weights.final.gz`; + +sub usage { + die <<EOT; + +Usage: $0 [OPTIONS] cdec.ini training.corpus weights.init + + Options: + + --distributed Parallelize function evaluation + --jobs N Number of jobs to use + --cdec PATH Path to cdec binary + --optimize OPT lbfgs, rprop, sgd + --gaussian_prior add Gaussian prior + --means FILE if you want means other than 0 + --sigma_squared S variance on prior + --pmem MEM Memory required for decoder + --lbfgs_memory_buffers Number of buffers to use + with LBFGS optimizer + +EOT +} + +sub num_lines { + my $file = shift; + my $fh; + if ($file=~ /\.gz$/) { + open $fh, "zcat $file|" or die "Couldn't fork zcat $file: $!"; + } else { + open $fh, "<$file" or die "Couldn't read $file: $!"; + } + my $lines = 0; + while(<$fh>) { $lines++; } + close $fh; + return $lines; +} diff --git a/training/collapse_weights.cc b/training/collapse_weights.cc new file mode 100644 index 00000000..5e0f3f72 --- /dev/null +++ b/training/collapse_weights.cc @@ -0,0 +1,102 @@ +#include <iostream> +#include <fstream> +#include <tr1/unordered_map> + +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> +#include <boost/functional/hash.hpp> + +#include "prob.h" +#include "filelib.h" +#include "trule.h" +#include "weights.h" + +namespace po = boost::program_options; +using namespace std; + +typedef std::tr1::unordered_map<vector<WordID>, prob_t, boost::hash<vector<WordID> > > MarginalMap; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("grammar,g", po::value<string>(), "Grammar file") + ("weights,w", po::value<string>(), "Weights file"); + po::options_description clo("Command line options"); + clo.add_options() + ("config,c", po::value<string>(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + const string cfg = (*conf)["config"].as<string>(); + cerr << "Configuration file: " << cfg << endl; + ifstream config(cfg.c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("grammar") || !conf->count("weights")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + const string wfile = conf["weights"].as<string>(); + const string gfile = conf["grammar"].as<string>(); + Weights wm; + wm.InitFromFile(wfile); + vector<double> w; + wm.InitVector(&w); + MarginalMap e_tots; + MarginalMap f_tots; + prob_t tot; + { + ReadFile rf(gfile); + assert(*rf.stream()); + istream& in = *rf.stream(); + cerr << "Computing marginals...\n"; + int lc = 0; + while(in) { + string line; + getline(in, line); + ++lc; + if (line.empty()) continue; + TRule tr(line, true); + if (tr.GetFeatureValues().empty()) + cerr << "Line " << lc << ": empty features - may introduce bias\n"; + prob_t prob; + prob.logeq(tr.GetFeatureValues().dot(w)); + e_tots[tr.e_] += prob; + f_tots[tr.f_] += prob; + tot += prob; + } + } + bool normalized = (fabs(log(tot)) < 0.001); + cerr << "Total: " << tot << (normalized ? " [normalized]" : " [scaled]") << endl; + ReadFile rf(gfile); + istream&in = *rf.stream(); + while(in) { + string line; + getline(in, line); + if (line.empty()) continue; + TRule tr(line, true); + const double lp = tr.GetFeatureValues().dot(w); + if (isinf(lp)) { continue; } + tr.scores_.clear(); + + cout << tr.AsString() << " ||| F_and_E=" << lp - log(tot); + if (!normalized) { + cout << ";ZF_and_E=" << lp; + } + cout << ";F_given_E=" << lp - log(e_tots[tr.e_]) + << ";E_given_F=" << lp - log(f_tots[tr.f_]) << endl; + } + return 0; +} + diff --git a/training/dep-reorder/conll2reordering-forest.pl b/training/dep-reorder/conll2reordering-forest.pl new file mode 100755 index 00000000..3cd226be --- /dev/null +++ b/training/dep-reorder/conll2reordering-forest.pl @@ -0,0 +1,65 @@ +#!/usr/bin/perl -w +use strict; + +my $script_dir; BEGIN { use Cwd qw/ abs_path cwd /; use File::Basename; $script_dir = dirname(abs_path($0)); push @INC, $script_dir; } +my $FIRST_CONV = "$script_dir/scripts/conll2simplecfg.pl"; +my $CDEC = "$script_dir/../../decoder/cdec"; + +our $tfile1 = "grammar1.$$"; +our $tfile2 = "text.$$"; + +die "Usage: $0 parses.conll\n" unless scalar @ARGV == 1; +open C, "<$ARGV[0]" or die "Can't read $ARGV[0]: $!"; + +END { unlink $tfile1; unlink "$tfile1.cfg"; unlink $tfile2; } + +my $first = 1; +open T, ">$tfile1" or die "Can't write $tfile1: $!"; +my $lc = 0; +my $flag = 0; +my @words = (); +while(<C>) { + print T; + chomp; + if (/^$/) { + if ($first) { $first = undef; } else { if ($flag) { print "\n"; $flag = 0; } } + $first = undef; + close T; + open SO, ">$tfile2" or die "Can't write $tfile2: $!"; + print SO "@words\n"; + close SO; + @words=(); + `$FIRST_CONV < $tfile1 > $tfile1.cfg`; + if ($? != 0) { + die "Error code: $?"; + } + my $cfg = `$CDEC -n -S 10000 -f scfg -g $tfile1.cfg -i $tfile2 --show_cfg_search_space 2>/dev/null`; + if ($? != 0) { + die "Error code: $?"; + } + my @rules = split /\n/, $cfg; + shift @rules; # get rid of output + for my $rule (@rules) { + my ($lhs, $f, $e, $feats) = split / \|\|\| /, $rule; + $f =~ s/,\d\]/\]/g; + $feats = 'TOP=1' unless $feats; + if ($lhs =~ /\[Goal_\d+\]/) { $lhs = '[S]'; } + print "$lhs ||| $f ||| $feats\n"; + if ($e eq '[1] [2]') { + my ($a, $b) = split /\s+/, $f; + $feats =~ s/=1$//; + my ($x, $y) = split /_/, $feats; + print "$lhs ||| $b $a ||| ${y}_$x=1\n"; + } + $flag = 1; + } + open T, ">$tfile1" or die "Can't write $tfile1: $!"; + $lc = -1; + } else { + my ($ind, $word, @dmmy) = split /\s+/; + push @words, $word; + } + $lc++; +} +close T; + diff --git a/training/dep-reorder/george.conll b/training/dep-reorder/george.conll new file mode 100644 index 00000000..7eebb360 --- /dev/null +++ b/training/dep-reorder/george.conll @@ -0,0 +1,4 @@ +1 George _ GEORGE _ _ 2 X _ _ +2 hates _ HATES _ _ 0 X _ _ +3 broccoli _ BROC _ _ 2 X _ _ + diff --git a/training/dep-reorder/scripts/conll2simplecfg.pl b/training/dep-reorder/scripts/conll2simplecfg.pl new file mode 100755 index 00000000..b101347a --- /dev/null +++ b/training/dep-reorder/scripts/conll2simplecfg.pl @@ -0,0 +1,57 @@ +#!/usr/bin/perl -w +use strict; + +# 1 在 _ 10 _ _ 4 X _ _ +# 2 门厅 _ 3 _ _ 1 X _ _ +# 3 下面 _ 23 _ _ 4 X _ _ +# 4 。 _ 45 _ _ 0 X _ _ + +my @ldeps; +my @rdeps; +@ldeps=(); for (my $i =0; $i <1000; $i++) { push @ldeps, []; } +@rdeps=(); for (my $i =0; $i <1000; $i++) { push @rdeps, []; } +my $rootcat = 0; +my @cats = ('S'); +my $len = 0; +my @noposcats = ('S'); +while(<>) { + chomp; + if (/^\s*$/) { + write_cfg($len); + $len = 0; + @cats=('S'); + @noposcats = ('S'); + @ldeps=(); for (my $i =0; $i <1000; $i++) { push @ldeps, []; } + @rdeps=(); for (my $i =0; $i <1000; $i++) { push @rdeps, []; } + next; + } + $len++; + my ($pos, $word, $d1, $xcat, $d2, $d3, $headpos, $deptype) = split /\s+/; + my $cat = "C$xcat"; + my $catpos = $cat . "_$pos"; + push @cats, $catpos; + push @noposcats, $cat; + print "[$catpos] ||| $word ||| $word ||| Word=1\n"; + if ($headpos == 0) { $rootcat = $pos; } + if ($pos < $headpos) { + push @{$ldeps[$headpos]}, $pos; + } else { + push @{$rdeps[$headpos]}, $pos; + } +} + +sub write_cfg { + my $len = shift; + for (my $i = 1; $i <= $len; $i++) { + my @lds = @{$ldeps[$i]}; + for my $ld (@lds) { + print "[$cats[$i]] ||| [$cats[$ld],1] [$cats[$i],2] ||| [1] [2] ||| $noposcats[$ld]_$noposcats[$i]=1\n"; + } + my @rds = @{$rdeps[$i]}; + for my $rd (@rds) { + print "[$cats[$i]] ||| [$cats[$i],1] [$cats[$rd],2] ||| [1] [2] ||| $noposcats[$i]_$noposcats[$rd]=1\n"; + } + } + print "[S] ||| [$cats[$rootcat],1] ||| [1] ||| TOP=1\n"; +} + diff --git a/training/grammar_convert.cc b/training/grammar_convert.cc new file mode 100644 index 00000000..461ff8e4 --- /dev/null +++ b/training/grammar_convert.cc @@ -0,0 +1,319 @@ +#include <iostream> +#include <algorithm> +#include <sstream> + +#include <boost/lexical_cast.hpp> +#include <boost/program_options.hpp> + +#include "tdict.h" +#include "filelib.h" +#include "hg.h" +#include "hg_io.h" +#include "kbest.h" +#include "viterbi.h" +#include "weights.h" + +namespace po = boost::program_options; +using namespace std; + +WordID kSTART; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("input,i", po::value<string>()->default_value("-"), "Input file") + ("format,f", po::value<string>()->default_value("cfg"), "Input format. Values: cfg, json, split") + ("output,o", po::value<string>()->default_value("json"), "Output command. Values: json, 1best") + ("reorder,r", "Add Yamada & Knight (2002) reorderings") + ("weights,w", po::value<string>(), "Feature weights for k-best derivations [optional]") + ("collapse_weights,C", "Collapse order features into a single feature whose value is all of the locally applying feature weights") + ("k_derivations,k", po::value<int>(), "Show k derivations and their features") + ("max_reorder,m", po::value<int>()->default_value(999), "Move a constituent at most this far") + ("help,h", "Print this help message and exit"); + po::options_description clo("Command line options"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + po::notify(*conf); + + if (conf->count("help") || conf->count("input") == 0) { + cerr << "\nUsage: grammar_convert [-options]\n\nConverts a grammar file (in Hiero format) into JSON hypergraph.\n"; + cerr << dcmdline_options << endl; + exit(1); + } +} + +int GetOrCreateNode(const WordID& lhs, map<WordID, int>* lhs2node, Hypergraph* hg) { + int& node_id = (*lhs2node)[lhs]; + if (!node_id) + node_id = hg->AddNode(lhs)->id_ + 1; + return node_id - 1; +} + +void FilterAndCheckCorrectness(int goal, Hypergraph* hg) { + if (goal < 0) { + cerr << "Error! [S] not found in grammar!\n"; + exit(1); + } + if (hg->nodes_[goal].in_edges_.size() != 1) { + cerr << "Error! [S] has more than one rewrite!\n"; + exit(1); + } + int old_size = hg->nodes_.size(); + hg->TopologicallySortNodesAndEdges(goal); + if (hg->nodes_.size() != old_size) { + cerr << "Warning! During sorting " << (old_size - hg->nodes_.size()) << " disappeared!\n"; + } +} + +void CreateEdge(const TRulePtr& r, const Hypergraph::TailNodeVector& tail, Hypergraph::Node* head_node, Hypergraph* hg) { + Hypergraph::Edge* new_edge = hg->AddEdge(r, tail); + hg->ConnectEdgeToHeadNode(new_edge, head_node); + new_edge->feature_values_ = r->scores_; +} + +// from a category label like "NP_2", return "NP" +string PureCategory(WordID cat) { + assert(cat < 0); + string c = TD::Convert(cat*-1); + size_t p = c.find("_"); + if (p == string::npos) return c; + return c.substr(0, p); +}; + +string ConstituentOrderFeature(const TRule& rule, const vector<int>& pi) { + const static string kTERM_VAR = "x"; + const vector<WordID>& f = rule.f(); + map<string, int> used; + vector<string> terms(f.size()); + for (int i = 0; i < f.size(); ++i) { + const string term = (f[i] < 0 ? PureCategory(f[i]) : kTERM_VAR); + int& count = used[term]; + if (!count) { + terms[i] = term; + } else { + ostringstream os; + os << term << count; + terms[i] = os.str(); + } + ++count; + } + ostringstream os; + os << PureCategory(rule.GetLHS()) << ':'; + for (int i = 0; i < f.size(); ++i) { + if (i > 0) os << '_'; + os << terms[pi[i]]; + } + return os.str(); +} + +bool CheckPermutationMask(const vector<int>& mask, const vector<int>& pi) { + assert(mask.size() == pi.size()); + + int req_min = -1; + int cur_max = 0; + int cur_mask = -1; + for (int i = 0; i < mask.size(); ++i) { + if (mask[i] != cur_mask) { + cur_mask = mask[i]; + req_min = cur_max - 1; + } + if (pi[i] > req_min) { + if (pi[i] > cur_max) cur_max = pi[i]; + } else { + return false; + } + } + + return true; +} + +void PermuteYKRecursive(int nodeid, const WordID& parent, const int max_reorder, Hypergraph* hg) { + // Hypergraph tmp = *hg; + Hypergraph::Node* node = &hg->nodes_[nodeid]; + if (node->in_edges_.size() != 1) { + cerr << "Multiple rewrites of [" << TD::Convert(node->cat_ * -1) << "] (parent is [" << TD::Convert(parent*-1) << "])\n"; + cerr << " not recursing!\n"; + return; + } +// for (int eii = 0; eii < node->in_edges_.size(); ++eii) { + const int oe_index = node->in_edges_.front(); + const TRule& rule = *hg->edges_[oe_index].rule_; + const Hypergraph::TailNodeVector orig_tail = hg->edges_[oe_index].tail_nodes_; + const int tail_size = orig_tail.size(); + for (int i = 0; i < tail_size; ++i) { + PermuteYKRecursive(hg->edges_[oe_index].tail_nodes_[i], node->cat_, max_reorder, hg); + } + const vector<WordID>& of = rule.f_; + if (of.size() == 1) return; + // cerr << "Permuting [" << TD::Convert(node->cat_ * -1) << "]\n"; + // cerr << "ORIG: " << rule.AsString() << endl; + vector<WordID> pi(of.size(), 0); + for (int i = 0; i < pi.size(); ++i) pi[i] = i; + + vector<int> permutation_mask(of.size(), 0); + const bool dont_reorder_across_PU = true; // TODO add configuration + if (dont_reorder_across_PU) { + int cur = 0; + for (int i = 0; i < pi.size(); ++i) { + if (of[i] >= 0) continue; + const string cat = PureCategory(of[i]); + if (cat == "PU" || cat == "PU!H" || cat == "PUNC" || cat == "PUNC!H" || cat == "CC") { + ++cur; + permutation_mask[i] = cur; + ++cur; + } else { + permutation_mask[i] = cur; + } + } + } + int fid = FD::Convert(ConstituentOrderFeature(rule, pi)); + hg->edges_[oe_index].feature_values_.set_value(fid, 1.0); + while (next_permutation(pi.begin(), pi.end())) { + if (!CheckPermutationMask(permutation_mask, pi)) + continue; + vector<WordID> nf(pi.size(), 0); + Hypergraph::TailNodeVector tail(pi.size(), 0); + bool skip = false; + for (int i = 0; i < pi.size(); ++i) { + int dist = pi[i] - i; if (dist < 0) dist *= -1; + if (dist > max_reorder) { skip = true; break; } + nf[i] = of[pi[i]]; + tail[i] = orig_tail[pi[i]]; + } + if (skip) continue; + TRulePtr nr(new TRule(rule)); + nr->f_ = nf; + int fid = FD::Convert(ConstituentOrderFeature(rule, pi)); + nr->scores_.set_value(fid, 1.0); + // cerr << "PERM: " << nr->AsString() << endl; + CreateEdge(nr, tail, node, hg); + } + // } +} + +void PermuteYamadaAndKnight(Hypergraph* hg, int max_reorder) { + assert(hg->nodes_.back().cat_ == kSTART); + assert(hg->nodes_.back().in_edges_.size() == 1); + PermuteYKRecursive(hg->nodes_.size() - 1, kSTART, max_reorder, hg); +} + +void CollapseWeights(Hypergraph* hg) { + int fid = FD::Convert("Reordering"); + for (int i = 0; i < hg->edges_.size(); ++i) { + Hypergraph::Edge& edge = hg->edges_[i]; + edge.feature_values_.clear(); + if (edge.edge_prob_ != prob_t::Zero()) { + edge.feature_values_.set_value(fid, log(edge.edge_prob_)); + } + } +} + +void ProcessHypergraph(const vector<double>& w, const po::variables_map& conf, const string& ref, Hypergraph* hg) { + if (conf.count("reorder")) + PermuteYamadaAndKnight(hg, conf["max_reorder"].as<int>()); + if (w.size() > 0) { hg->Reweight(w); } + if (conf.count("collapse_weights")) CollapseWeights(hg); + if (conf["output"].as<string>() == "json") { + HypergraphIO::WriteToJSON(*hg, false, &cout); + if (!ref.empty()) { cerr << "REF: " << ref << endl; } + } else { + vector<WordID> onebest; + ViterbiESentence(*hg, &onebest); + if (ref.empty()) { + cout << TD::GetString(onebest) << endl; + } else { + cout << TD::GetString(onebest) << " ||| " << ref << endl; + } + } + if (conf.count("k_derivations")) { + const int k = conf["k_derivations"].as<int>(); + KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(*hg, k); + for (int i = 0; i < k; ++i) { + const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg->nodes_.size() - 1, i); + if (!d) break; + cerr << log(d->score) << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl; + } + } +} + +int main(int argc, char **argv) { + kSTART = TD::Convert("S") * -1; + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + string infile = conf["input"].as<string>(); + const bool is_split_input = (conf["format"].as<string>() == "split"); + const bool is_json_input = is_split_input || (conf["format"].as<string>() == "json"); + const bool collapse_weights = conf.count("collapse_weights"); + Weights wts; + vector<double> w; + if (conf.count("weights")) { + wts.InitFromFile(conf["weights"].as<string>()); + wts.InitVector(&w); + } + if (collapse_weights && !w.size()) { + cerr << "--collapse_weights requires a weights file to be specified!\n"; + exit(1); + } + ReadFile rf(infile); + istream* in = rf.stream(); + assert(*in); + int lc = 0; + Hypergraph hg; + map<WordID, int> lhs2node; + while(*in) { + string line; + ++lc; + getline(*in, line); + if (is_json_input) { + if (line.empty() || line[0] == '#') continue; + string ref; + if (is_split_input) { + size_t pos = line.rfind("}}"); + assert(pos != string::npos); + size_t rstart = line.find("||| ", pos); + assert(rstart != string::npos); + ref = line.substr(rstart + 4); + line = line.substr(0, pos + 2); + } + istringstream is(line); + if (HypergraphIO::ReadFromJSON(&is, &hg)) { + ProcessHypergraph(w, conf, ref, &hg); + hg.clear(); + } else { + cerr << "Error reading grammar from JSON: line " << lc << endl; + exit(1); + } + } else { + if (line.empty()) { + int goal = lhs2node[kSTART] - 1; + FilterAndCheckCorrectness(goal, &hg); + ProcessHypergraph(w, conf, "", &hg); + hg.clear(); + lhs2node.clear(); + continue; + } + if (line[0] == '#') continue; + if (line[0] != '[') { + cerr << "Line " << lc << ": bad format\n"; + exit(1); + } + TRulePtr tr(TRule::CreateRuleMonolingual(line)); + Hypergraph::TailNodeVector tail; + for (int i = 0; i < tr->f_.size(); ++i) { + WordID var_cat = tr->f_[i]; + if (var_cat < 0) + tail.push_back(GetOrCreateNode(var_cat, &lhs2node, &hg)); + } + const WordID lhs = tr->GetLHS(); + int head = GetOrCreateNode(lhs, &lhs2node, &hg); + Hypergraph::Edge* edge = hg.AddEdge(tr, tail); + edge->feature_values_ = tr->scores_; + Hypergraph::Node* node = &hg.nodes_[head]; + hg.ConnectEdgeToHeadNode(edge, node); + } + } +} + diff --git a/training/lbfgs.h b/training/lbfgs.h new file mode 100644 index 00000000..e8baecab --- /dev/null +++ b/training/lbfgs.h @@ -0,0 +1,1459 @@ +#ifndef SCITBX_LBFGS_H +#define SCITBX_LBFGS_H + +#include <cstdio> +#include <cstddef> +#include <cmath> +#include <stdexcept> +#include <algorithm> +#include <vector> +#include <string> +#include <iostream> +#include <sstream> + +namespace scitbx { + +//! Limited-memory Broyden-Fletcher-Goldfarb-Shanno (LBFGS) %minimizer. +/*! Implementation of the + Limited-memory Broyden-Fletcher-Goldfarb-Shanno (LBFGS) + algorithm for large-scale multidimensional minimization + problems. + + This code was manually derived from Java code which was + in turn derived from the Fortran program + <code>lbfgs.f</code>. The Java translation was + effected mostly mechanically, with some manual + clean-up; in particular, array indices start at 0 + instead of 1. Most of the comments from the Fortran + code have been pasted in. + + Information on the original LBFGS Fortran source code is + available at + http://www.netlib.org/opt/lbfgs_um.shar . The following + information is taken verbatim from the Netlib documentation + for the Fortran source. + + <pre> + file opt/lbfgs_um.shar + for unconstrained optimization problems + alg limited memory BFGS method + by J. Nocedal + contact nocedal@eecs.nwu.edu + ref D. C. Liu and J. Nocedal, ``On the limited memory BFGS method for + , large scale optimization methods'' Mathematical Programming 45 + , (1989), pp. 503-528. + , (Postscript file of this paper is available via anonymous ftp + , to eecs.nwu.edu in the directory pub/%lbfgs/lbfgs_um.) + </pre> + + @author Jorge Nocedal: original Fortran version, including comments + (July 1990).<br> + Robert Dodier: Java translation, August 1997.<br> + Ralf W. Grosse-Kunstleve: C++ port, March 2002.<br> + Chris Dyer: serialize/deserialize functionality + */ +namespace lbfgs { + + //! Generic exception class for %lbfgs %error messages. + /*! All exceptions thrown by the minimizer are derived from this class. + */ + class error : public std::exception { + public: + //! Constructor. + error(std::string const& msg) throw() + : msg_("lbfgs error: " + msg) + {} + //! Access to error message. + virtual const char* what() const throw() { return msg_.c_str(); } + protected: + virtual ~error() throw() {} + std::string msg_; + public: + static std::string itoa(unsigned long i) { + std::ostringstream os; + os << i; + return os.str(); + } + }; + + //! Specific exception class. + class error_internal_error : public error { + public: + //! Constructor. + error_internal_error(const char* file, unsigned long line) throw() + : error( + "Internal Error: " + std::string(file) + "(" + itoa(line) + ")") + {} + }; + + //! Specific exception class. + class error_improper_input_parameter : public error { + public: + //! Constructor. + error_improper_input_parameter(std::string const& msg) throw() + : error("Improper input parameter: " + msg) + {} + }; + + //! Specific exception class. + class error_improper_input_data : public error { + public: + //! Constructor. + error_improper_input_data(std::string const& msg) throw() + : error("Improper input data: " + msg) + {} + }; + + //! Specific exception class. + class error_search_direction_not_descent : public error { + public: + //! Constructor. + error_search_direction_not_descent() throw() + : error("The search direction is not a descent direction.") + {} + }; + + //! Specific exception class. + class error_line_search_failed : public error { + public: + //! Constructor. + error_line_search_failed(std::string const& msg) throw() + : error("Line search failed: " + msg) + {} + }; + + //! Specific exception class. + class error_line_search_failed_rounding_errors + : public error_line_search_failed { + public: + //! Constructor. + error_line_search_failed_rounding_errors(std::string const& msg) throw() + : error_line_search_failed(msg) + {} + }; + + namespace detail { + + template <typename NumType> + inline + NumType + pow2(NumType const& x) { return x * x; } + + template <typename NumType> + inline + NumType + abs(NumType const& x) { + if (x < NumType(0)) return -x; + return x; + } + + // This class implements an algorithm for multi-dimensional line search. + template <typename FloatType, typename SizeType = std::size_t> + class mcsrch + { + protected: + int infoc; + FloatType dginit; + bool brackt; + bool stage1; + FloatType finit; + FloatType dgtest; + FloatType width; + FloatType width1; + FloatType stx; + FloatType fx; + FloatType dgx; + FloatType sty; + FloatType fy; + FloatType dgy; + FloatType stmin; + FloatType stmax; + + static FloatType const& max3( + FloatType const& x, + FloatType const& y, + FloatType const& z) + { + return x < y ? (y < z ? z : y ) : (x < z ? z : x ); + } + + public: + /* Minimize a function along a search direction. This code is + a Java translation of the function <code>MCSRCH</code> from + <code>lbfgs.f</code>, which in turn is a slight modification + of the subroutine <code>CSRCH</code> of More' and Thuente. + The changes are to allow reverse communication, and do not + affect the performance of the routine. This function, in turn, + calls <code>mcstep</code>.<p> + + The Java translation was effected mostly mechanically, with + some manual clean-up; in particular, array indices start at 0 + instead of 1. Most of the comments from the Fortran code have + been pasted in here as well.<p> + + The purpose of <code>mcsrch</code> is to find a step which + satisfies a sufficient decrease condition and a curvature + condition.<p> + + At each stage this function updates an interval of uncertainty + with endpoints <code>stx</code> and <code>sty</code>. The + interval of uncertainty is initially chosen so that it + contains a minimizer of the modified function + <pre> + f(x+stp*s) - f(x) - ftol*stp*(gradf(x)'s). + </pre> + If a step is obtained for which the modified function has a + nonpositive function value and nonnegative derivative, then + the interval of uncertainty is chosen so that it contains a + minimizer of <code>f(x+stp*s)</code>.<p> + + The algorithm is designed to find a step which satisfies + the sufficient decrease condition + <pre> + f(x+stp*s) <= f(X) + ftol*stp*(gradf(x)'s), + </pre> + and the curvature condition + <pre> + abs(gradf(x+stp*s)'s)) <= gtol*abs(gradf(x)'s). + </pre> + If <code>ftol</code> is less than <code>gtol</code> and if, + for example, the function is bounded below, then there is + always a step which satisfies both conditions. If no step can + be found which satisfies both conditions, then the algorithm + usually stops when rounding errors prevent further progress. + In this case <code>stp</code> only satisfies the sufficient + decrease condition.<p> + + @author Original Fortran version by Jorge J. More' and + David J. Thuente as part of the Minpack project, June 1983, + Argonne National Laboratory. Java translation by Robert + Dodier, August 1997. + + @param n The number of variables. + + @param x On entry this contains the base point for the line + search. On exit it contains <code>x + stp*s</code>. + + @param f On entry this contains the value of the objective + function at <code>x</code>. On exit it contains the value + of the objective function at <code>x + stp*s</code>. + + @param g On entry this contains the gradient of the objective + function at <code>x</code>. On exit it contains the gradient + at <code>x + stp*s</code>. + + @param s The search direction. + + @param stp On entry this contains an initial estimate of a + satifactory step length. On exit <code>stp</code> contains + the final estimate. + + @param ftol Tolerance for the sufficient decrease condition. + + @param xtol Termination occurs when the relative width of the + interval of uncertainty is at most <code>xtol</code>. + + @param maxfev Termination occurs when the number of evaluations + of the objective function is at least <code>maxfev</code> by + the end of an iteration. + + @param info This is an output variable, which can have these + values: + <ul> + <li><code>info = -1</code> A return is made to compute + the function and gradient. + <li><code>info = 1</code> The sufficient decrease condition + and the directional derivative condition hold. + </ul> + + @param nfev On exit, this is set to the number of function + evaluations. + + @param wa Temporary storage array, of length <code>n</code>. + */ + void run( + FloatType const& gtol, + FloatType const& stpmin, + FloatType const& stpmax, + SizeType n, + FloatType* x, + FloatType f, + const FloatType* g, + FloatType* s, + SizeType is0, + FloatType& stp, + FloatType ftol, + FloatType xtol, + SizeType maxfev, + int& info, + SizeType& nfev, + FloatType* wa); + + /* The purpose of this function is to compute a safeguarded step + for a linesearch and to update an interval of uncertainty for + a minimizer of the function.<p> + + The parameter <code>stx</code> contains the step with the + least function value. The parameter <code>stp</code> contains + the current step. It is assumed that the derivative at + <code>stx</code> is negative in the direction of the step. If + <code>brackt</code> is <code>true</code> when + <code>mcstep</code> returns then a minimizer has been + bracketed in an interval of uncertainty with endpoints + <code>stx</code> and <code>sty</code>.<p> + + Variables that must be modified by <code>mcstep</code> are + implemented as 1-element arrays. + + @param stx Step at the best step obtained so far. + This variable is modified by <code>mcstep</code>. + @param fx Function value at the best step obtained so far. + This variable is modified by <code>mcstep</code>. + @param dx Derivative at the best step obtained so far. + The derivative must be negative in the direction of the + step, that is, <code>dx</code> and <code>stp-stx</code> must + have opposite signs. This variable is modified by + <code>mcstep</code>. + + @param sty Step at the other endpoint of the interval of + uncertainty. This variable is modified by <code>mcstep</code>. + @param fy Function value at the other endpoint of the interval + of uncertainty. This variable is modified by + <code>mcstep</code>. + + @param dy Derivative at the other endpoint of the interval of + uncertainty. This variable is modified by <code>mcstep</code>. + + @param stp Step at the current step. If <code>brackt</code> is set + then on input <code>stp</code> must be between <code>stx</code> + and <code>sty</code>. On output <code>stp</code> is set to the + new step. + @param fp Function value at the current step. + @param dp Derivative at the current step. + + @param brackt Tells whether a minimizer has been bracketed. + If the minimizer has not been bracketed, then on input this + variable must be set <code>false</code>. If the minimizer has + been bracketed, then on output this variable is + <code>true</code>. + + @param stpmin Lower bound for the step. + @param stpmax Upper bound for the step. + + If the return value is 1, 2, 3, or 4, then the step has + been computed successfully. A return value of 0 indicates + improper input parameters. + + @author Jorge J. More, David J. Thuente: original Fortran version, + as part of Minpack project. Argonne Nat'l Laboratory, June 1983. + Robert Dodier: Java translation, August 1997. + */ + static int mcstep( + FloatType& stx, + FloatType& fx, + FloatType& dx, + FloatType& sty, + FloatType& fy, + FloatType& dy, + FloatType& stp, + FloatType fp, + FloatType dp, + bool& brackt, + FloatType stpmin, + FloatType stpmax); + + void serialize(std::ostream* out) const { + out->write((const char*)&infoc,sizeof(infoc)); + out->write((const char*)&dginit,sizeof(dginit)); + out->write((const char*)&brackt,sizeof(brackt)); + out->write((const char*)&stage1,sizeof(stage1)); + out->write((const char*)&finit,sizeof(finit)); + out->write((const char*)&dgtest,sizeof(dgtest)); + out->write((const char*)&width,sizeof(width)); + out->write((const char*)&width1,sizeof(width1)); + out->write((const char*)&stx,sizeof(stx)); + out->write((const char*)&fx,sizeof(fx)); + out->write((const char*)&dgx,sizeof(dgx)); + out->write((const char*)&sty,sizeof(sty)); + out->write((const char*)&fy,sizeof(fy)); + out->write((const char*)&dgy,sizeof(dgy)); + out->write((const char*)&stmin,sizeof(stmin)); + out->write((const char*)&stmax,sizeof(stmax)); + } + + void deserialize(std::istream* in) const { + in->read((char*)&infoc, sizeof(infoc)); + in->read((char*)&dginit, sizeof(dginit)); + in->read((char*)&brackt, sizeof(brackt)); + in->read((char*)&stage1, sizeof(stage1)); + in->read((char*)&finit, sizeof(finit)); + in->read((char*)&dgtest, sizeof(dgtest)); + in->read((char*)&width, sizeof(width)); + in->read((char*)&width1, sizeof(width1)); + in->read((char*)&stx, sizeof(stx)); + in->read((char*)&fx, sizeof(fx)); + in->read((char*)&dgx, sizeof(dgx)); + in->read((char*)&sty, sizeof(sty)); + in->read((char*)&fy, sizeof(fy)); + in->read((char*)&dgy, sizeof(dgy)); + in->read((char*)&stmin, sizeof(stmin)); + in->read((char*)&stmax, sizeof(stmax)); + } + }; + + template <typename FloatType, typename SizeType> + void mcsrch<FloatType, SizeType>::run( + FloatType const& gtol, + FloatType const& stpmin, + FloatType const& stpmax, + SizeType n, + FloatType* x, + FloatType f, + const FloatType* g, + FloatType* s, + SizeType is0, + FloatType& stp, + FloatType ftol, + FloatType xtol, + SizeType maxfev, + int& info, + SizeType& nfev, + FloatType* wa) + { + if (info != -1) { + infoc = 1; + if ( n == 0 + || maxfev == 0 + || gtol < FloatType(0) + || xtol < FloatType(0) + || stpmin < FloatType(0) + || stpmax < stpmin) { + throw error_internal_error(__FILE__, __LINE__); + } + if (stp <= FloatType(0) || ftol < FloatType(0)) { + throw error_internal_error(__FILE__, __LINE__); + } + // Compute the initial gradient in the search direction + // and check that s is a descent direction. + dginit = FloatType(0); + for (SizeType j = 0; j < n; j++) { + dginit += g[j] * s[is0+j]; + } + if (dginit >= FloatType(0)) { + throw error_search_direction_not_descent(); + } + brackt = false; + stage1 = true; + nfev = 0; + finit = f; + dgtest = ftol*dginit; + width = stpmax - stpmin; + width1 = FloatType(2) * width; + std::copy(x, x+n, wa); + // The variables stx, fx, dgx contain the values of the step, + // function, and directional derivative at the best step. + // The variables sty, fy, dgy contain the value of the step, + // function, and derivative at the other endpoint of + // the interval of uncertainty. + // The variables stp, f, dg contain the values of the step, + // function, and derivative at the current step. + stx = FloatType(0); + fx = finit; + dgx = dginit; + sty = FloatType(0); + fy = finit; + dgy = dginit; + } + for (;;) { + if (info != -1) { + // Set the minimum and maximum steps to correspond + // to the present interval of uncertainty. + if (brackt) { + stmin = std::min(stx, sty); + stmax = std::max(stx, sty); + } + else { + stmin = stx; + stmax = stp + FloatType(4) * (stp - stx); + } + // Force the step to be within the bounds stpmax and stpmin. + stp = std::max(stp, stpmin); + stp = std::min(stp, stpmax); + // If an unusual termination is to occur then let + // stp be the lowest point obtained so far. + if ( (brackt && (stp <= stmin || stp >= stmax)) + || nfev >= maxfev - 1 || infoc == 0 + || (brackt && stmax - stmin <= xtol * stmax)) { + stp = stx; + } + // Evaluate the function and gradient at stp + // and compute the directional derivative. + // We return to main program to obtain F and G. + for (SizeType j = 0; j < n; j++) { + x[j] = wa[j] + stp * s[is0+j]; + } + info=-1; + break; + } + info = 0; + nfev++; + FloatType dg(0); + for (SizeType j = 0; j < n; j++) { + dg += g[j] * s[is0+j]; + } + FloatType ftest1 = finit + stp*dgtest; + // Test for convergence. + if ((brackt && (stp <= stmin || stp >= stmax)) || infoc == 0) { + throw error_line_search_failed_rounding_errors( + "Rounding errors prevent further progress." + " There may not be a step which satisfies the" + " sufficient decrease and curvature conditions." + " Tolerances may be too small."); + } + if (stp == stpmax && f <= ftest1 && dg <= dgtest) { + throw error_line_search_failed( + "The step is at the upper bound stpmax()."); + } + if (stp == stpmin && (f > ftest1 || dg >= dgtest)) { + throw error_line_search_failed( + "The step is at the lower bound stpmin()."); + } + if (nfev >= maxfev) { + throw error_line_search_failed( + "Number of function evaluations has reached maxfev()."); + } + if (brackt && stmax - stmin <= xtol * stmax) { + throw error_line_search_failed( + "Relative width of the interval of uncertainty" + " is at most xtol()."); + } + // Check for termination. + if (f <= ftest1 && abs(dg) <= gtol * (-dginit)) { + info = 1; + break; + } + // In the first stage we seek a step for which the modified + // function has a nonpositive value and nonnegative derivative. + if ( stage1 && f <= ftest1 + && dg >= std::min(ftol, gtol) * dginit) { + stage1 = false; + } + // A modified function is used to predict the step only if + // we have not obtained a step for which the modified + // function has a nonpositive function value and nonnegative + // derivative, and if a lower function value has been + // obtained but the decrease is not sufficient. + if (stage1 && f <= fx && f > ftest1) { + // Define the modified function and derivative values. + FloatType fm = f - stp*dgtest; + FloatType fxm = fx - stx*dgtest; + FloatType fym = fy - sty*dgtest; + FloatType dgm = dg - dgtest; + FloatType dgxm = dgx - dgtest; + FloatType dgym = dgy - dgtest; + // Call cstep to update the interval of uncertainty + // and to compute the new step. + infoc = mcstep(stx, fxm, dgxm, sty, fym, dgym, stp, fm, dgm, + brackt, stmin, stmax); + // Reset the function and gradient values for f. + fx = fxm + stx*dgtest; + fy = fym + sty*dgtest; + dgx = dgxm + dgtest; + dgy = dgym + dgtest; + } + else { + // Call mcstep to update the interval of uncertainty + // and to compute the new step. + infoc = mcstep(stx, fx, dgx, sty, fy, dgy, stp, f, dg, + brackt, stmin, stmax); + } + // Force a sufficient decrease in the size of the + // interval of uncertainty. + if (brackt) { + if (abs(sty - stx) >= FloatType(0.66) * width1) { + stp = stx + FloatType(0.5) * (sty - stx); + } + width1 = width; + width = abs(sty - stx); + } + } + } + + template <typename FloatType, typename SizeType> + int mcsrch<FloatType, SizeType>::mcstep( + FloatType& stx, + FloatType& fx, + FloatType& dx, + FloatType& sty, + FloatType& fy, + FloatType& dy, + FloatType& stp, + FloatType fp, + FloatType dp, + bool& brackt, + FloatType stpmin, + FloatType stpmax) + { + bool bound; + FloatType gamma, p, q, r, s, sgnd, stpc, stpf, stpq, theta; + int info = 0; + if ( ( brackt && (stp <= std::min(stx, sty) + || stp >= std::max(stx, sty))) + || dx * (stp - stx) >= FloatType(0) || stpmax < stpmin) { + return 0; + } + // Determine if the derivatives have opposite sign. + sgnd = dp * (dx / abs(dx)); + if (fp > fx) { + // First case. A higher function value. + // The minimum is bracketed. If the cubic step is closer + // to stx than the quadratic step, the cubic step is taken, + // else the average of the cubic and quadratic steps is taken. + info = 1; + bound = true; + theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp; + s = max3(abs(theta), abs(dx), abs(dp)); + gamma = s * std::sqrt(pow2(theta / s) - (dx / s) * (dp / s)); + if (stp < stx) gamma = - gamma; + p = (gamma - dx) + theta; + q = ((gamma - dx) + gamma) + dp; + r = p/q; + stpc = stx + r * (stp - stx); + stpq = stx + + ((dx / ((fx - fp) / (stp - stx) + dx)) / FloatType(2)) + * (stp - stx); + if (abs(stpc - stx) < abs(stpq - stx)) { + stpf = stpc; + } + else { + stpf = stpc + (stpq - stpc) / FloatType(2); + } + brackt = true; + } + else if (sgnd < FloatType(0)) { + // Second case. A lower function value and derivatives of + // opposite sign. The minimum is bracketed. If the cubic + // step is closer to stx than the quadratic (secant) step, + // the cubic step is taken, else the quadratic step is taken. + info = 2; + bound = false; + theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp; + s = max3(abs(theta), abs(dx), abs(dp)); + gamma = s * std::sqrt(pow2(theta / s) - (dx / s) * (dp / s)); + if (stp > stx) gamma = - gamma; + p = (gamma - dp) + theta; + q = ((gamma - dp) + gamma) + dx; + r = p/q; + stpc = stp + r * (stx - stp); + stpq = stp + (dp / (dp - dx)) * (stx - stp); + if (abs(stpc - stp) > abs(stpq - stp)) { + stpf = stpc; + } + else { + stpf = stpq; + } + brackt = true; + } + else if (abs(dp) < abs(dx)) { + // Third case. A lower function value, derivatives of the + // same sign, and the magnitude of the derivative decreases. + // The cubic step is only used if the cubic tends to infinity + // in the direction of the step or if the minimum of the cubic + // is beyond stp. Otherwise the cubic step is defined to be + // either stpmin or stpmax. The quadratic (secant) step is also + // computed and if the minimum is bracketed then the the step + // closest to stx is taken, else the step farthest away is taken. + info = 3; + bound = true; + theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp; + s = max3(abs(theta), abs(dx), abs(dp)); + gamma = s * std::sqrt( + std::max(FloatType(0), pow2(theta / s) - (dx / s) * (dp / s))); + if (stp > stx) gamma = -gamma; + p = (gamma - dp) + theta; + q = (gamma + (dx - dp)) + gamma; + r = p/q; + if (r < FloatType(0) && gamma != FloatType(0)) { + stpc = stp + r * (stx - stp); + } + else if (stp > stx) { + stpc = stpmax; + } + else { + stpc = stpmin; + } + stpq = stp + (dp / (dp - dx)) * (stx - stp); + if (brackt) { + if (abs(stp - stpc) < abs(stp - stpq)) { + stpf = stpc; + } + else { + stpf = stpq; + } + } + else { + if (abs(stp - stpc) > abs(stp - stpq)) { + stpf = stpc; + } + else { + stpf = stpq; + } + } + } + else { + // Fourth case. A lower function value, derivatives of the + // same sign, and the magnitude of the derivative does + // not decrease. If the minimum is not bracketed, the step + // is either stpmin or stpmax, else the cubic step is taken. + info = 4; + bound = false; + if (brackt) { + theta = FloatType(3) * (fp - fy) / (sty - stp) + dy + dp; + s = max3(abs(theta), abs(dy), abs(dp)); + gamma = s * std::sqrt(pow2(theta / s) - (dy / s) * (dp / s)); + if (stp > sty) gamma = -gamma; + p = (gamma - dp) + theta; + q = ((gamma - dp) + gamma) + dy; + r = p/q; + stpc = stp + r * (sty - stp); + stpf = stpc; + } + else if (stp > stx) { + stpf = stpmax; + } + else { + stpf = stpmin; + } + } + // Update the interval of uncertainty. This update does not + // depend on the new step or the case analysis above. + if (fp > fx) { + sty = stp; + fy = fp; + dy = dp; + } + else { + if (sgnd < FloatType(0)) { + sty = stx; + fy = fx; + dy = dx; + } + stx = stp; + fx = fp; + dx = dp; + } + // Compute the new step and safeguard it. + stpf = std::min(stpmax, stpf); + stpf = std::max(stpmin, stpf); + stp = stpf; + if (brackt && bound) { + if (sty > stx) { + stp = std::min(stx + FloatType(0.66) * (sty - stx), stp); + } + else { + stp = std::max(stx + FloatType(0.66) * (sty - stx), stp); + } + } + return info; + } + + /* Compute the sum of a vector times a scalar plus another vector. + Adapted from the subroutine <code>daxpy</code> in + <code>lbfgs.f</code>. + */ + template <typename FloatType, typename SizeType> + void daxpy( + SizeType n, + FloatType da, + const FloatType* dx, + SizeType ix0, + SizeType incx, + FloatType* dy, + SizeType iy0, + SizeType incy) + { + SizeType i, ix, iy, m; + if (n == 0) return; + if (da == FloatType(0)) return; + if (!(incx == 1 && incy == 1)) { + ix = 0; + iy = 0; + for (i = 0; i < n; i++) { + dy[iy0+iy] += da * dx[ix0+ix]; + ix += incx; + iy += incy; + } + return; + } + m = n % 4; + for (i = 0; i < m; i++) { + dy[iy0+i] += da * dx[ix0+i]; + } + for (; i < n;) { + dy[iy0+i] += da * dx[ix0+i]; i++; + dy[iy0+i] += da * dx[ix0+i]; i++; + dy[iy0+i] += da * dx[ix0+i]; i++; + dy[iy0+i] += da * dx[ix0+i]; i++; + } + } + + template <typename FloatType, typename SizeType> + inline + void daxpy( + SizeType n, + FloatType da, + const FloatType* dx, + SizeType ix0, + FloatType* dy) + { + daxpy(n, da, dx, ix0, SizeType(1), dy, SizeType(0), SizeType(1)); + } + + /* Compute the dot product of two vectors. + Adapted from the subroutine <code>ddot</code> + in <code>lbfgs.f</code>. + */ + template <typename FloatType, typename SizeType> + FloatType ddot( + SizeType n, + const FloatType* dx, + SizeType ix0, + SizeType incx, + const FloatType* dy, + SizeType iy0, + SizeType incy) + { + SizeType i, ix, iy, m; + FloatType dtemp(0); + if (n == 0) return FloatType(0); + if (!(incx == 1 && incy == 1)) { + ix = 0; + iy = 0; + for (i = 0; i < n; i++) { + dtemp += dx[ix0+ix] * dy[iy0+iy]; + ix += incx; + iy += incy; + } + return dtemp; + } + m = n % 5; + for (i = 0; i < m; i++) { + dtemp += dx[ix0+i] * dy[iy0+i]; + } + for (; i < n;) { + dtemp += dx[ix0+i] * dy[iy0+i]; i++; + dtemp += dx[ix0+i] * dy[iy0+i]; i++; + dtemp += dx[ix0+i] * dy[iy0+i]; i++; + dtemp += dx[ix0+i] * dy[iy0+i]; i++; + dtemp += dx[ix0+i] * dy[iy0+i]; i++; + } + return dtemp; + } + + template <typename FloatType, typename SizeType> + inline + FloatType ddot( + SizeType n, + const FloatType* dx, + const FloatType* dy) + { + return ddot( + n, dx, SizeType(0), SizeType(1), dy, SizeType(0), SizeType(1)); + } + + } // namespace detail + + //! Interface to the LBFGS %minimizer. + /*! This class solves the unconstrained minimization problem + <pre> + min f(x), x = (x1,x2,...,x_n), + </pre> + using the limited-memory BFGS method. The routine is + especially effective on problems involving a large number of + variables. In a typical iteration of this method an + approximation Hk to the inverse of the Hessian + is obtained by applying <code>m</code> BFGS updates to a + diagonal matrix Hk0, using information from the + previous <code>m</code> steps. The user specifies the number + <code>m</code>, which determines the amount of storage + required by the routine. The user may also provide the + diagonal matrices Hk0 (parameter <code>diag</code> in the run() + function) if not satisfied with the default choice. The + algorithm is described in "On the limited memory BFGS method for + large scale optimization", by D. Liu and J. Nocedal, Mathematical + Programming B 45 (1989) 503-528. + + The user is required to calculate the function value + <code>f</code> and its gradient <code>g</code>. In order to + allow the user complete control over these computations, + reverse communication is used. The routine must be called + repeatedly under the control of the member functions + <code>requests_f_and_g()</code>, + <code>requests_diag()</code>. + If neither requests_f_and_g() nor requests_diag() is + <code>true</code> the user should check for convergence + (using class traditional_convergence_test or any + other custom test). If the convergence test is negative, + the minimizer may be called again for the next iteration. + + The steplength (stp()) is determined at each iteration + by means of the line search routine <code>mcsrch</code>, which is + a slight modification of the routine <code>CSRCH</code> written + by More' and Thuente. + + The only variables that are machine-dependent are + <code>xtol</code>, + <code>stpmin</code> and + <code>stpmax</code>. + + Fatal errors cause <code>error</code> exceptions to be thrown. + The generic class <code>error</code> is sub-classed (e.g. + class <code>error_line_search_failed</code>) to facilitate + granular %error handling. + + A note on performance: Using Compaq Fortran V5.4 and + Compaq C++ V6.5, the C++ implementation is about 15% slower + than the Fortran implementation. + */ + template <typename FloatType, typename SizeType = std::size_t> + class minimizer + { + public: + //! Default constructor. Some members are not initialized! + minimizer() + : n_(0), m_(0), maxfev_(0), + gtol_(0), xtol_(0), + stpmin_(0), stpmax_(0), + ispt(0), iypt(0) + {} + + //! Constructor. + /*! @param n The number of variables in the minimization problem. + Restriction: <code>n > 0</code>. + + @param m The number of corrections used in the BFGS update. + Values of <code>m</code> less than 3 are not recommended; + large values of <code>m</code> will result in excessive + computing time. <code>3 <= m <= 7</code> is + recommended. + Restriction: <code>m > 0</code>. + + @param maxfev Maximum number of function evaluations + <b>per line search</b>. + Termination occurs when the number of evaluations + of the objective function is at least <code>maxfev</code> by + the end of an iteration. + + @param gtol Controls the accuracy of the line search. + If the function and gradient evaluations are inexpensive with + respect to the cost of the iteration (which is sometimes the + case when solving very large problems) it may be advantageous + to set <code>gtol</code> to a small value. A typical small + value is 0.1. + Restriction: <code>gtol</code> should be greater than 1e-4. + + @param xtol An estimate of the machine precision (e.g. 10e-16 + on a SUN station 3/60). The line search routine will + terminate if the relative width of the interval of + uncertainty is less than <code>xtol</code>. + + @param stpmin Specifies the lower bound for the step + in the line search. + The default value is 1e-20. This value need not be modified + unless the exponent is too large for the machine being used, + or unless the problem is extremely badly scaled (in which + case the exponent should be increased). + + @param stpmax specifies the upper bound for the step + in the line search. + The default value is 1e20. This value need not be modified + unless the exponent is too large for the machine being used, + or unless the problem is extremely badly scaled (in which + case the exponent should be increased). + */ + explicit + minimizer( + SizeType n, + SizeType m = 5, + SizeType maxfev = 20, + FloatType gtol = FloatType(0.9), + FloatType xtol = FloatType(1.e-16), + FloatType stpmin = FloatType(1.e-20), + FloatType stpmax = FloatType(1.e20)) + : n_(n), m_(m), maxfev_(maxfev), + gtol_(gtol), xtol_(xtol), + stpmin_(stpmin), stpmax_(stpmax), + iflag_(0), requests_f_and_g_(false), requests_diag_(false), + iter_(0), nfun_(0), stp_(0), + stp1(0), ftol(0.0001), ys(0), point(0), npt(0), + ispt(n+2*m), iypt((n+2*m)+n*m), + info(0), bound(0), nfev(0) + { + if (n_ == 0) { + throw error_improper_input_parameter("n = 0."); + } + if (m_ == 0) { + throw error_improper_input_parameter("m = 0."); + } + if (maxfev_ == 0) { + throw error_improper_input_parameter("maxfev = 0."); + } + if (gtol_ <= FloatType(1.e-4)) { + throw error_improper_input_parameter("gtol <= 1.e-4."); + } + if (xtol_ < FloatType(0)) { + throw error_improper_input_parameter("xtol < 0."); + } + if (stpmin_ < FloatType(0)) { + throw error_improper_input_parameter("stpmin < 0."); + } + if (stpmax_ < stpmin) { + throw error_improper_input_parameter("stpmax < stpmin"); + } + w_.resize(n_*(2*m_+1)+2*m_); + scratch_array_.resize(n_); + } + + //! Number of free parameters (as passed to the constructor). + SizeType n() const { return n_; } + + //! Number of corrections kept (as passed to the constructor). + SizeType m() const { return m_; } + + /*! \brief Maximum number of evaluations of the objective function + per line search (as passed to the constructor). + */ + SizeType maxfev() const { return maxfev_; } + + /*! \brief Control of the accuracy of the line search. + (as passed to the constructor). + */ + FloatType gtol() const { return gtol_; } + + //! Estimate of the machine precision (as passed to the constructor). + FloatType xtol() const { return xtol_; } + + /*! \brief Lower bound for the step in the line search. + (as passed to the constructor). + */ + FloatType stpmin() const { return stpmin_; } + + /*! \brief Upper bound for the step in the line search. + (as passed to the constructor). + */ + FloatType stpmax() const { return stpmax_; } + + //! Status indicator for reverse communication. + /*! <code>true</code> if the run() function returns to request + evaluation of the objective function (<code>f</code>) and + gradients (<code>g</code>) for the current point + (<code>x</code>). To continue the minimization the + run() function is called again with the updated values for + <code>f</code> and <code>g</code>. + <p> + See also: requests_diag() + */ + bool requests_f_and_g() const { return requests_f_and_g_; } + + //! Status indicator for reverse communication. + /*! <code>true</code> if the run() function returns to request + evaluation of the diagonal matrix (<code>diag</code>) + for the current point (<code>x</code>). + To continue the minimization the run() function is called + again with the updated values for <code>diag</code>. + <p> + See also: requests_f_and_g() + */ + bool requests_diag() const { return requests_diag_; } + + //! Number of iterations so far. + /*! Note that one iteration may involve multiple evaluations + of the objective function. + <p> + See also: nfun() + */ + SizeType iter() const { return iter_; } + + //! Total number of evaluations of the objective function so far. + /*! The total number of function evaluations increases by the + number of evaluations required for the line search. The total + is only increased after a successful line search. + <p> + See also: iter() + */ + SizeType nfun() const { return nfun_; } + + //! Norm of gradient given gradient array of length n(). + FloatType euclidean_norm(const FloatType* a) const { + return std::sqrt(detail::ddot(n_, a, a)); + } + + //! Current stepsize. + FloatType stp() const { return stp_; } + + //! Execution of one step of the minimization. + /*! @param x On initial entry this must be set by the user to + the values of the initial estimate of the solution vector. + + @param f Before initial entry or on re-entry under the + control of requests_f_and_g(), <code>f</code> must be set + by the user to contain the value of the objective function + at the current point <code>x</code>. + + @param g Before initial entry or on re-entry under the + control of requests_f_and_g(), <code>g</code> must be set + by the user to contain the components of the gradient at + the current point <code>x</code>. + + The return value is <code>true</code> if either + requests_f_and_g() or requests_diag() is <code>true</code>. + Otherwise the user should check for convergence + (e.g. using class traditional_convergence_test) and + call the run() function again to continue the minimization. + If the return value is <code>false</code> the user + should <b>not</b> update <code>f</code>, <code>g</code> or + <code>diag</code> (other overload) before calling + the run() function again. + + Note that <code>x</code> is always modified by the run() + function. Depending on the situation it can therefore be + necessary to evaluate the objective function one more time + after the minimization is terminated. + */ + bool run( + FloatType* x, + FloatType f, + const FloatType* g) + { + return generic_run(x, f, g, false, 0); + } + + //! Execution of one step of the minimization. + /*! @param x See other overload. + + @param f See other overload. + + @param g See other overload. + + @param diag On initial entry or on re-entry under the + control of requests_diag(), <code>diag</code> must be set by + the user to contain the values of the diagonal matrix Hk0. + The routine will return at each iteration of the algorithm + with requests_diag() set to <code>true</code>. + <p> + Restriction: all elements of <code>diag</code> must be + positive. + */ + bool run( + FloatType* x, + FloatType f, + const FloatType* g, + const FloatType* diag) + { + return generic_run(x, f, g, true, diag); + } + + void serialize(std::ostream* out) const { + out->write((const char*)&n_, sizeof(n_)); // sanity check + out->write((const char*)&m_, sizeof(m_)); // sanity check + SizeType fs = sizeof(FloatType); + out->write((const char*)&fs, sizeof(fs)); // sanity check + + mcsrch_instance.serialize(out); + out->write((const char*)&iflag_, sizeof(iflag_)); + out->write((const char*)&requests_f_and_g_, sizeof(requests_f_and_g_)); + out->write((const char*)&requests_diag_, sizeof(requests_diag_)); + out->write((const char*)&iter_, sizeof(iter_)); + out->write((const char*)&nfun_, sizeof(nfun_)); + out->write((const char*)&stp_, sizeof(stp_)); + out->write((const char*)&stp1, sizeof(stp1)); + out->write((const char*)&ftol, sizeof(ftol)); + out->write((const char*)&ys, sizeof(ys)); + out->write((const char*)&point, sizeof(point)); + out->write((const char*)&npt, sizeof(npt)); + out->write((const char*)&info, sizeof(info)); + out->write((const char*)&bound, sizeof(bound)); + out->write((const char*)&nfev, sizeof(nfev)); + out->write((const char*)&w_[0], sizeof(FloatType) * w_.size()); + out->write((const char*)&scratch_array_[0], sizeof(FloatType) * scratch_array_.size()); + } + + void deserialize(std::istream* in) { + SizeType n, m, fs; + in->read((char*)&n, sizeof(n)); + in->read((char*)&m, sizeof(m)); + in->read((char*)&fs, sizeof(fs)); + assert(n == n_); + assert(m == m_); + assert(fs == sizeof(FloatType)); + + mcsrch_instance.deserialize(in); + in->read((char*)&iflag_, sizeof(iflag_)); + in->read((char*)&requests_f_and_g_, sizeof(requests_f_and_g_)); + in->read((char*)&requests_diag_, sizeof(requests_diag_)); + in->read((char*)&iter_, sizeof(iter_)); + in->read((char*)&nfun_, sizeof(nfun_)); + in->read((char*)&stp_, sizeof(stp_)); + in->read((char*)&stp1, sizeof(stp1)); + in->read((char*)&ftol, sizeof(ftol)); + in->read((char*)&ys, sizeof(ys)); + in->read((char*)&point, sizeof(point)); + in->read((char*)&npt, sizeof(npt)); + in->read((char*)&info, sizeof(info)); + in->read((char*)&bound, sizeof(bound)); + in->read((char*)&nfev, sizeof(nfev)); + in->read((char*)&w_[0], sizeof(FloatType) * w_.size()); + in->read((char*)&scratch_array_[0], sizeof(FloatType) * scratch_array_.size()); + } + + protected: + static void throw_diagonal_element_not_positive(SizeType i) { + throw error_improper_input_data( + "The " + error::itoa(i) + ". diagonal element of the" + " inverse Hessian approximation is not positive."); + } + + bool generic_run( + FloatType* x, + FloatType f, + const FloatType* g, + bool diagco, + const FloatType* diag); + + detail::mcsrch<FloatType, SizeType> mcsrch_instance; + const SizeType n_; + const SizeType m_; + const SizeType maxfev_; + const FloatType gtol_; + const FloatType xtol_; + const FloatType stpmin_; + const FloatType stpmax_; + int iflag_; + bool requests_f_and_g_; + bool requests_diag_; + SizeType iter_; + SizeType nfun_; + FloatType stp_; + FloatType stp1; + FloatType ftol; + FloatType ys; + SizeType point; + SizeType npt; + const SizeType ispt; + const SizeType iypt; + int info; + SizeType bound; + SizeType nfev; + std::vector<FloatType> w_; + std::vector<FloatType> scratch_array_; + }; + + template <typename FloatType, typename SizeType> + bool minimizer<FloatType, SizeType>::generic_run( + FloatType* x, + FloatType f, + const FloatType* g, + bool diagco, + const FloatType* diag) + { + bool execute_entire_while_loop = false; + if (!(requests_f_and_g_ || requests_diag_)) { + execute_entire_while_loop = true; + } + requests_f_and_g_ = false; + requests_diag_ = false; + FloatType* w = &(*(w_.begin())); + if (iflag_ == 0) { // Initialize. + nfun_ = 1; + if (diagco) { + for (SizeType i = 0; i < n_; i++) { + if (diag[i] <= FloatType(0)) { + throw_diagonal_element_not_positive(i); + } + } + } + else { + std::fill_n(scratch_array_.begin(), n_, FloatType(1)); + diag = &(*(scratch_array_.begin())); + } + for (SizeType i = 0; i < n_; i++) { + w[ispt + i] = -g[i] * diag[i]; + } + FloatType gnorm = std::sqrt(detail::ddot(n_, g, g)); + if (gnorm == FloatType(0)) return false; + stp1 = FloatType(1) / gnorm; + execute_entire_while_loop = true; + } + if (execute_entire_while_loop) { + bound = iter_; + iter_++; + info = 0; + if (iter_ != 1) { + if (iter_ > m_) bound = m_; + ys = detail::ddot( + n_, w, iypt + npt, SizeType(1), w, ispt + npt, SizeType(1)); + if (!diagco) { + FloatType yy = detail::ddot( + n_, w, iypt + npt, SizeType(1), w, iypt + npt, SizeType(1)); + std::fill_n(scratch_array_.begin(), n_, ys / yy); + diag = &(*(scratch_array_.begin())); + } + else { + iflag_ = 2; + requests_diag_ = true; + return true; + } + } + } + if (execute_entire_while_loop || iflag_ == 2) { + if (iter_ != 1) { + if (diag == 0) { + throw error_internal_error(__FILE__, __LINE__); + } + if (diagco) { + for (SizeType i = 0; i < n_; i++) { + if (diag[i] <= FloatType(0)) { + throw_diagonal_element_not_positive(i); + } + } + } + SizeType cp = point; + if (point == 0) cp = m_; + w[n_ + cp -1] = 1 / ys; + SizeType i; + for (i = 0; i < n_; i++) { + w[i] = -g[i]; + } + cp = point; + for (i = 0; i < bound; i++) { + if (cp == 0) cp = m_; + cp--; + FloatType sq = detail::ddot( + n_, w, ispt + cp * n_, SizeType(1), w, SizeType(0), SizeType(1)); + SizeType inmc=n_+m_+cp; + SizeType iycn=iypt+cp*n_; + w[inmc] = w[n_ + cp] * sq; + detail::daxpy(n_, -w[inmc], w, iycn, w); + } + for (i = 0; i < n_; i++) { + w[i] *= diag[i]; + } + for (i = 0; i < bound; i++) { + FloatType yr = detail::ddot( + n_, w, iypt + cp * n_, SizeType(1), w, SizeType(0), SizeType(1)); + FloatType beta = w[n_ + cp] * yr; + SizeType inmc=n_+m_+cp; + beta = w[inmc] - beta; + SizeType iscn=ispt+cp*n_; + detail::daxpy(n_, beta, w, iscn, w); + cp++; + if (cp == m_) cp = 0; + } + std::copy(w, w+n_, w+(ispt + point * n_)); + } + stp_ = FloatType(1); + if (iter_ == 1) stp_ = stp1; + std::copy(g, g+n_, w); + } + mcsrch_instance.run( + gtol_, stpmin_, stpmax_, n_, x, f, g, w, ispt + point * n_, + stp_, ftol, xtol_, maxfev_, info, nfev, &(*(scratch_array_.begin()))); + if (info == -1) { + iflag_ = 1; + requests_f_and_g_ = true; + return true; + } + if (info != 1) { + throw error_internal_error(__FILE__, __LINE__); + } + nfun_ += nfev; + npt = point*n_; + for (SizeType i = 0; i < n_; i++) { + w[ispt + npt + i] = stp_ * w[ispt + npt + i]; + w[iypt + npt + i] = g[i] - w[i]; + } + point++; + if (point == m_) point = 0; + return false; + } + + //! Traditional LBFGS convergence test. + /*! This convergence test is equivalent to the test embedded + in the <code>lbfgs.f</code> Fortran code. The test assumes that + there is a meaningful relation between the Euclidean norm of the + parameter vector <code>x</code> and the norm of the gradient + vector <code>g</code>. Therefore this test should not be used if + this assumption is not correct for a given problem. + */ + template <typename FloatType, typename SizeType = std::size_t> + class traditional_convergence_test + { + public: + //! Default constructor. + traditional_convergence_test() + : n_(0), eps_(0) + {} + + //! Constructor. + /*! @param n The number of variables in the minimization problem. + Restriction: <code>n > 0</code>. + + @param eps Determines the accuracy with which the solution + is to be found. + */ + explicit + traditional_convergence_test( + SizeType n, + FloatType eps = FloatType(1.e-5)) + : n_(n), eps_(eps) + { + if (n_ == 0) { + throw error_improper_input_parameter("n = 0."); + } + if (eps_ < FloatType(0)) { + throw error_improper_input_parameter("eps < 0."); + } + } + + //! Number of free parameters (as passed to the constructor). + SizeType n() const { return n_; } + + /*! \brief Accuracy with which the solution is to be found + (as passed to the constructor). + */ + FloatType eps() const { return eps_; } + + //! Execution of the convergence test for the given parameters. + /*! Returns <code>true</code> if + <pre> + ||g|| < eps * max(1,||x||), + </pre> + where <code>||.||</code> denotes the Euclidean norm. + + @param x Current solution vector. + + @param g Components of the gradient at the current + point <code>x</code>. + */ + bool + operator()(const FloatType* x, const FloatType* g) const + { + FloatType xnorm = std::sqrt(detail::ddot(n_, x, x)); + FloatType gnorm = std::sqrt(detail::ddot(n_, g, g)); + if (gnorm <= eps_ * std::max(FloatType(1), xnorm)) return true; + return false; + } + protected: + const SizeType n_; + const FloatType eps_; + }; + +}} // namespace scitbx::lbfgs + +template <typename T> +std::ostream& operator<<(std::ostream& os, const scitbx::lbfgs::minimizer<T>& min) { + return os << "ITER=" << min.iter() << "\tNFUN=" << min.nfun() << "\tSTP=" << min.stp() << "\tDIAG=" << min.requests_diag() << "\tF&G=" << min.requests_f_and_g(); +} + + +#endif // SCITBX_LBFGS_H diff --git a/training/lbfgs_test.cc b/training/lbfgs_test.cc new file mode 100644 index 00000000..4171c118 --- /dev/null +++ b/training/lbfgs_test.cc @@ -0,0 +1,112 @@ +#include <cassert> +#include <iostream> +#include <sstream> +#include "lbfgs.h" +#include "sparse_vector.h" +#include "fdict.h" + +using namespace std; + +double TestOptimizer() { + cerr << "TESTING NON-PERSISTENT OPTIMIZER\n"; + + // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 + // df/dx1 = 8*x1 + x2 + // df/dx2 = 2*x2 + x1 + // df/dx3 = 2*x3 + 6 + double x[3]; + double g[3]; + scitbx::lbfgs::minimizer<double> opt(3); + scitbx::lbfgs::traditional_convergence_test<double> converged(3); + x[0] = 8; + x[1] = 8; + x[2] = 8; + double obj = 0; + do { + g[0] = 8 * x[0] + x[1]; + g[1] = 2 * x[1] + x[0]; + g[2] = 2 * x[2] + 6; + obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; + opt.run(x, obj, g); + + cerr << x[0] << " " << x[1] << " " << x[2] << endl; + cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; + cerr << opt << endl; + } while (!converged(x, g)); + return obj; +} + +double TestPersistentOptimizer() { + cerr << "\nTESTING PERSISTENT OPTIMIZER\n"; + // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 + // df/dx1 = 8*x1 + x2 + // df/dx2 = 2*x2 + x1 + // df/dx3 = 2*x3 + 6 + double x[3]; + double g[3]; + scitbx::lbfgs::traditional_convergence_test<double> converged(3); + x[0] = 8; + x[1] = 8; + x[2] = 8; + double obj = 0; + string state; + do { + g[0] = 8 * x[0] + x[1]; + g[1] = 2 * x[1] + x[0]; + g[2] = 2 * x[2] + 6; + obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; + + { + scitbx::lbfgs::minimizer<double> opt(3); + if (state.size() > 0) { + istringstream is(state, ios::binary); + opt.deserialize(&is); + } + opt.run(x, obj, g); + ostringstream os(ios::binary); opt.serialize(&os); state = os.str(); + } + + cerr << x[0] << " " << x[1] << " " << x[2] << endl; + cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; + } while (!converged(x, g)); + return obj; +} + +void TestSparseVector() { + cerr << "Testing SparseVector<double> serialization.\n"; + int f1 = FD::Convert("Feature_1"); + int f2 = FD::Convert("Feature_2"); + FD::Convert("LanguageModel"); + int f4 = FD::Convert("SomeFeature"); + int f5 = FD::Convert("SomeOtherFeature"); + SparseVector<double> g; + g.set_value(f2, log(0.5)); + g.set_value(f4, log(0.125)); + g.set_value(f1, 0); + g.set_value(f5, 23.777); + ostringstream os; + double iobj = 1.5; + B64::Encode(iobj, g, &os); + cerr << iobj << "\t" << g << endl; + string data = os.str(); + cout << data << endl; + SparseVector<double> v; + double obj; + assert(B64::Decode(&obj, &v, &data[0], data.size())); + cerr << obj << "\t" << v << endl; + assert(obj == iobj); + assert(g.num_active() == v.num_active()); +} + +int main() { + double o1 = TestOptimizer(); + double o2 = TestPersistentOptimizer(); + if (o1 != o2) { + cerr << "OPTIMIZERS PERFORMED DIFFERENTLY!\n" << o1 << " vs. " << o2 << endl; + return 1; + } + TestSparseVector(); + cerr << "SUCCESS\n"; + return 0; +} + diff --git a/training/make-lexcrf-grammar.pl b/training/make-lexcrf-grammar.pl new file mode 100755 index 00000000..8cdf7718 --- /dev/null +++ b/training/make-lexcrf-grammar.pl @@ -0,0 +1,285 @@ +#!/usr/bin/perl -w +use utf8; +use strict; +my ($effile, $model1) = @ARGV; +die "Usage: $0 corpus.fr-en corpus.model1\n" unless $effile && -f $effile && $model1 && -f $model1; + +open EF, "<$effile" or die; +open M1, "<$model1" or die; +binmode(EF,":utf8"); +binmode(M1,":utf8"); +binmode(STDOUT,":utf8"); +my %model1; +while(<M1>) { + chomp; + my ($f, $e, $lp) = split /\s+/; + $model1{$f}->{$e} = $lp; +} + +my $ADD_MODEL1 = 0; # found that model1 hurts performance +my $IS_FRENCH_F = 1; # indicates that the f language is french +my $IS_ARABIC_F = 0; # indicates that the f language is arabic +my $IS_URDU_F = 0; # indicates that the f language is arabic +my $ADD_PREFIX_ID = 0; +my $ADD_LEN = 1; +my $ADD_SIM = 1; +my $ADD_DICE = 1; +my $ADD_111 = 1; +my $ADD_ID = 1; +my $ADD_PUNC = 1; +my $ADD_NUM_MM = 1; +my $ADD_NULL = 1; +my $ADD_STEM_ID = 1; +my $BEAM_RATIO = 50; + +my %fdict; +my %fcounts; +my %ecounts; + +my %sdict; + +while(<EF>) { + chomp; + my ($f, $e) = split /\s*\|\|\|\s*/; + my @es = split /\s+/, $e; + my @fs = split /\s+/, $f; + for my $ew (@es){ $ecounts{$ew}++; } + push @fs, '<eps>' if $ADD_NULL; + for my $fw (@fs){ $fcounts{$fw}++; } + for my $fw (@fs){ + for my $ew (@es){ + $fdict{$fw}->{$ew}++; + } + } +} + +print STDERR "Dice 0\n" if $ADD_DICE; +print STDERR "OneOneOne 0\nId_OneOneOne 0\n" if $ADD_111; +print STDERR "Identical 0\n" if $ADD_ID; +print STDERR "PuncMiss 0\n" if $ADD_PUNC; +print STDERR "IsNull 0\n" if $ADD_NULL; +print STDERR "Model1 0\n" if $ADD_MODEL1; +print STDERR "DLen 0\n" if $ADD_LEN; +print STDERR "NumMM 0\nNumMatch 0\n" if $ADD_NUM_MM; +print STDERR "OrthoSim 0\n" if $ADD_SIM; +print STDERR "PfxIdentical 0\n" if ($ADD_PREFIX_ID); +my $fc = 1000000; +my $sids = 1000000; +for my $f (sort keys %fdict) { + my $re = $fdict{$f}; + my $max; + for my $e (sort {$re->{$b} <=> $re->{$a}} keys %$re) { + my $efcount = $re->{$e}; + unless (defined $max) { $max = $efcount; } + my $m1 = $model1{$f}->{$e}; + unless (defined $m1) { next; } + $fc++; + my $dice = 2 * $efcount / ($ecounts{$e} + $fcounts{$f}); + my $feats = "F$fc=1"; + my $oe = $e; + my $of = $f; # normalized form + if ($IS_FRENCH_F) { + # see http://en.wikipedia.org/wiki/Use_of_the_circumflex_in_French + $of =~ s/â/as/g; + $of =~ s/ê/es/g; + $of =~ s/î/is/g; + $of =~ s/ô/os/g; + $of =~ s/û/us/g; + } elsif ($IS_ARABIC_F) { + if (length($of) > 1 && !($of =~ /\d/)) { + $of =~ s/\$/sh/g; + } + } elsif ($IS_URDU_F) { + if (length($of) > 1 && !($of =~ /\d/)) { + $of =~ s/\$/sh/g; + } + $oe =~ s/^-e-//; + $oe =~ s/^al-/al/; + $of =~ s/([a-z])\~/$1$1/g; + $of =~ s/E/'/g; + $of =~ s/^Aw/o/g; + $of =~ s/\|/a/g; + $of =~ s/@/h/g; + $of =~ s/c/ch/g; + $of =~ s/x/kh/g; + $of =~ s/\*/dh/g; + $of =~ s/w/o/g; + $of =~ s/Z/dh/g; + $of =~ s/y/i/g; + $of =~ s/Y/a/g; + $of = lc $of; + } + my $len_e = length($oe); + my $len_f = length($of); + $feats .= " Model1=$m1" if ($ADD_MODEL1); + $feats .= " Dice=$dice" if $ADD_DICE; + my $is_null = undef; + if ($ADD_NULL && $f eq '<eps>') { + $feats .= " IsNull=1"; + $is_null = 1; + } + if ($ADD_LEN) { + if (!$is_null) { + my $dlen = abs($len_e - $len_f); + $feats .= " DLen=$dlen"; + } + } + my $f_num = ($of =~ /^-?\d[0-9\.\,]+%?$/ && (length($of) > 3)); + my $e_num = ($oe =~ /^-?\d[0-9\.\,]+%?$/ && (length($oe) > 3)); + my $both_non_numeric = (!$e_num && !$f_num); + if ($ADD_NUM_MM && (($f_num && !$e_num) || ($e_num && !$f_num))) { + $feats .= " NumMM=1"; + } + if ($ADD_NUM_MM && ($f_num && $e_num) && ($oe eq $of)) { + $feats .= " NumMatch=1"; + } + if ($ADD_STEM_ID) { + my $el = 4; + my $fl = 4; + if ($oe =~ /^al|re|co/) { $el++; } + if ($of =~ /^al|re|co/) { $fl++; } + if ($oe =~ /^trans|inter/) { $el+=2; } + if ($of =~ /^trans|inter/) { $fl+=2; } + if ($fl > length($of)) { $fl = length($of); } + if ($el > length($oe)) { $el = length($oe); } + my $sf = substr $of, 0, $fl; + my $se = substr $oe, 0, $el; + my $id = $sdict{$sf}->{$se}; + if (!$id) { + $sids++; + $sdict{$sf}->{$se} = $sids; + $id = $sids; + print STDERR "S$sids 0\n" + } + $feats .= " S$id=1"; + } + if ($ADD_PREFIX_ID) { + if ($len_e > 3 && $len_f > 3 && $both_non_numeric) { + my $pe = substr $oe, 0, 3; + my $pf = substr $of, 0, 3; + if ($pe eq $pf) { $feats .= " PfxIdentical=1"; } + } + } + if ($ADD_SIM) { + my $ld = 0; + my $eff = $len_e; + if ($eff < $len_f) { $eff = $len_f; } + if (!$is_null) { + $ld = ($eff - levenshtein($oe, $of)) / sqrt($eff); + } + $feats .= " OrthoSim=$ld"; + } + my $ident = ($e eq $f); + if ($ident && $ADD_ID) { $feats .= " Identical=1"; } + if ($ADD_111 && ($efcount == 1 && $ecounts{$e} == 1 && $fcounts{$f} == 1)) { + if ($ident && $ADD_ID) { + $feats .= " Id_OneOneOne=1"; + } + $feats .= " OneOneOne=1"; + } + if ($ADD_PUNC) { + if (($f =~ /^[0-9!\$%,\-\/"':;=+?.()«»]+$/ && $e =~ /[a-z]+/) || + ($e =~ /^[0-9!\$%,\-\/"':;=+?.()«»]+$/ && $f =~ /[a-z]+/)) { + $feats .= " PuncMiss=1"; + } + } + my $r = (0.5 - rand)/5; + print STDERR "F$fc $r\n"; + print "$f ||| $e ||| $feats\n"; + } +} + +sub levenshtein +{ + # $s1 and $s2 are the two strings + # $len1 and $len2 are their respective lengths + # + my ($s1, $s2) = @_; + my ($len1, $len2) = (length $s1, length $s2); + + # If one of the strings is empty, the distance is the length + # of the other string + # + return $len2 if ($len1 == 0); + return $len1 if ($len2 == 0); + + my %mat; + + # Init the distance matrix + # + # The first row to 0..$len1 + # The first column to 0..$len2 + # The rest to 0 + # + # The first row and column are initialized so to denote distance + # from the empty string + # + for (my $i = 0; $i <= $len1; ++$i) + { + for (my $j = 0; $j <= $len2; ++$j) + { + $mat{$i}{$j} = 0; + $mat{0}{$j} = $j; + } + + $mat{$i}{0} = $i; + } + + # Some char-by-char processing is ahead, so prepare + # array of chars from the strings + # + my @ar1 = split(//, $s1); + my @ar2 = split(//, $s2); + + for (my $i = 1; $i <= $len1; ++$i) + { + for (my $j = 1; $j <= $len2; ++$j) + { + # Set the cost to 1 iff the ith char of $s1 + # equals the jth of $s2 + # + # Denotes a substitution cost. When the char are equal + # there is no need to substitute, so the cost is 0 + # + my $cost = ($ar1[$i-1] eq $ar2[$j-1]) ? 0 : 1; + + # Cell $mat{$i}{$j} equals the minimum of: + # + # - The cell immediately above plus 1 + # - The cell immediately to the left plus 1 + # - The cell diagonally above and to the left plus the cost + # + # We can either insert a new char, delete a char or + # substitute an existing char (with an associated cost) + # + $mat{$i}{$j} = min([$mat{$i-1}{$j} + 1, + $mat{$i}{$j-1} + 1, + $mat{$i-1}{$j-1} + $cost]); + } + } + + # Finally, the Levenshtein distance equals the rightmost bottom cell + # of the matrix + # + # Note that $mat{$x}{$y} denotes the distance between the substrings + # 1..$x and 1..$y + # + return $mat{$len1}{$len2}; +} + + +# minimal element of a list +# +sub min +{ + my @list = @{$_[0]}; + my $min = $list[0]; + + foreach my $i (@list) + { + $min = $i if ($i < $min); + } + + return $min; +} + diff --git a/training/model1.cc b/training/model1.cc new file mode 100644 index 00000000..f571700f --- /dev/null +++ b/training/model1.cc @@ -0,0 +1,103 @@ +#include <iostream> + +#include "lattice.h" +#include "stringlib.h" +#include "filelib.h" +#include "ttables.h" +#include "tdict.h" + +using namespace std; + +int main(int argc, char** argv) { + if (argc != 2) { + cerr << "Usage: " << argv[0] << " corpus.fr-en\n"; + return 1; + } + const int ITERATIONS = 5; + const prob_t BEAM_THRESHOLD(0.0001); + TTable tt; + const WordID kNULL = TD::Convert("<eps>"); + bool use_null = true; + TTable::Word2Word2Double was_viterbi; + for (int iter = 0; iter < ITERATIONS; ++iter) { + const bool final_iteration = (iter == (ITERATIONS - 1)); + cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl; + ReadFile rf(argv[1]); + istream& in = *rf.stream(); + prob_t likelihood = prob_t::One(); + double denom = 0.0; + int lc = 0; + bool flag = false; + while(true) { + string line; + getline(in, line); + if (!in) break; + ++lc; + if (lc % 1000 == 0) { cerr << '.'; flag = true; } + if (lc %50000 == 0) { cerr << " [" << lc << "]\n" << flush; flag = false; } + string ssrc, strg; + ParseTranslatorInput(line, &ssrc, &strg); + Lattice src, trg; + LatticeTools::ConvertTextToLattice(ssrc, &src); + LatticeTools::ConvertTextToLattice(strg, &trg); + assert(src.size() > 0); + assert(trg.size() > 0); + denom += 1.0; + vector<prob_t> probs(src.size() + 1); + for (int j = 0; j < trg.size(); ++j) { + const WordID& f_j = trg[j][0].label; + prob_t sum = prob_t::Zero(); + if (use_null) { + probs[0] = tt.prob(kNULL, f_j); + sum += probs[0]; + } + for (int i = 1; i <= src.size(); ++i) { + probs[i] = tt.prob(src[i-1][0].label, f_j); + sum += probs[i]; + } + if (final_iteration) { + WordID max_i = 0; + prob_t max_p = prob_t::Zero(); + if (use_null) { + max_i = kNULL; + max_p = probs[0]; + } + for (int i = 1; i <= src.size(); ++i) { + if (probs[i] > max_p) { + max_p = probs[i]; + max_i = src[i-1][0].label; + } + } + was_viterbi[max_i][f_j] = 1.0; + } else { + if (use_null) + tt.Increment(kNULL, f_j, probs[0] / sum); + for (int i = 1; i <= src.size(); ++i) + tt.Increment(src[i-1][0].label, f_j, probs[i] / sum); + } + likelihood *= sum; + } + } + if (flag) { cerr << endl; } + cerr << " log likelihood: " << log(likelihood) << endl; + cerr << " cross entopy: " << (-log(likelihood) / denom) << endl; + cerr << " perplexity: " << pow(2.0, -log(likelihood) / denom) << endl; + if (!final_iteration) tt.Normalize(); + } + for (TTable::Word2Word2Double::iterator ei = tt.ttable.begin(); ei != tt.ttable.end(); ++ei) { + const TTable::Word2Double& cpd = ei->second; + const TTable::Word2Double& vit = was_viterbi[ei->first]; + const string& esym = TD::Convert(ei->first); + prob_t max_p = prob_t::Zero(); + for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) + if (fi->second > max_p) max_p = prob_t(fi->second); + const prob_t threshold = max_p * BEAM_THRESHOLD; + for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) { + if (fi->second > threshold || (vit.count(fi->first) > 0)) { + cout << esym << ' ' << TD::Convert(fi->first) << ' ' << log(fi->second) << endl; + } + } + } + return 0; +} + diff --git a/training/mr_em_adapted_reduce.cc b/training/mr_em_adapted_reduce.cc new file mode 100644 index 00000000..52387e7f --- /dev/null +++ b/training/mr_em_adapted_reduce.cc @@ -0,0 +1,194 @@ +#include <iostream> +#include <vector> +#include <cassert> +#include <cmath> + +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "config.h" +#ifdef HAVE_BOOST_DIGAMMA +#include <boost/math/special_functions/digamma.hpp> +using boost::math::digamma; +#endif + +#include "filelib.h" +#include "fdict.h" +#include "weights.h" +#include "sparse_vector.h" + +using namespace std; +namespace po = boost::program_options; + +#ifndef HAVE_BOOST_DIGAMMA +#warning Using Mark Johnsons digamma() +double digamma(double x) { + double result = 0, xx, xx2, xx4; + assert(x > 0); + for ( ; x < 7; ++x) + result -= 1/x; + x -= 1.0/2.0; + xx = 1.0/x; + xx2 = xx*xx; + xx4 = xx2*xx2; + result += log(x)+(1./24.)*xx2-(7.0/960.0)*xx4+(31.0/8064.0)*xx4*xx2-(127.0/30720.0)*xx4*xx4; + return result; +} +#endif + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("optimization_method,m", po::value<string>()->default_value("em"), "Optimization method (em, vb)") + ("input_format,f",po::value<string>()->default_value("b64"),"Encoding of the input (b64 or text)"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value<string>(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as<string>().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +double NoZero(const double& x) { + if (x) return x; + return 1e-35; +} + +void Maximize(const bool use_vb, + const double& alpha, + const int total_event_types, + SparseVector<double>* pc) { + const SparseVector<double>& counts = *pc; + + if (use_vb) + assert(total_event_types >= counts.num_active()); + + double tot = 0; + for (SparseVector<double>::const_iterator it = counts.begin(); + it != counts.end(); ++it) + tot += it->second; +// cerr << " = " << tot << endl; + assert(tot > 0.0); + double ltot = log(tot); + if (use_vb) + ltot = digamma(tot + total_event_types * alpha); + for (SparseVector<double>::const_iterator it = counts.begin(); + it != counts.end(); ++it) { + if (use_vb) { + pc->set_value(it->first, NoZero(digamma(it->second + alpha) - ltot)); + } else { + pc->set_value(it->first, NoZero(log(it->second) - ltot)); + } + } +#if 0 + if (counts.num_active() < 50) { + for (SparseVector<double>::const_iterator it = counts.begin(); + it != counts.end(); ++it) { + cerr << " p(" << FD::Convert(it->first) << ")=" << exp(it->second); + } + cerr << endl; + } +#endif +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + + const bool use_b64 = conf["input_format"].as<string>() == "b64"; + const bool use_vb = conf["optimization_method"].as<string>() == "vb"; + const double alpha = 1e-09; + if (use_vb) + cerr << "Using variational Bayes, make sure alphas are set\n"; + + const string s_obj = "**OBJ**"; + // E-step + string cur_key = ""; + SparseVector<double> acc; + double logprob = 0; + while(cin) { + string line; + getline(cin, line); + if (line.empty()) continue; + int feat; + double val; + size_t i = line.find("\t"); + const string key = line.substr(0, i); + assert(i != string::npos); + ++i; + if (key != cur_key) { + if (cur_key.size() > 0) { + // TODO shouldn't be num_active, should be total number + // of events + Maximize(use_vb, alpha, acc.num_active(), &acc); + cout << cur_key << '\t'; + if (use_b64) + B64::Encode(0.0, acc, &cout); + else + cout << acc; + cout << endl; + acc.clear(); + } + cur_key = key; + } + if (use_b64) { + SparseVector<double> g; + double obj; + if (!B64::Decode(&obj, &g, &line[i], line.size() - i)) { + cerr << "B64 decoder returned error, skipping!\n"; + continue; + } + logprob += obj; + acc += g; + } else { // text encoding - your counts will not be accurate! + while (i < line.size()) { + size_t start = i; + while (line[i] != '=' && i < line.size()) ++i; + if (i == line.size()) { cerr << "FORMAT ERROR\n"; break; } + string fname = line.substr(start, i - start); + if (fname == s_obj) { + feat = -1; + } else { + feat = FD::Convert(line.substr(start, i - start)); + } + ++i; + start = i; + while (line[i] != ';' && i < line.size()) ++i; + if (i - start == 0) continue; + val = atof(line.substr(start, i - start).c_str()); + ++i; + if (feat == -1) { + logprob += val; + } else { + acc.add_value(feat, val); + } + } + } + } + // TODO shouldn't be num_active, should be total number + // of events + Maximize(use_vb, alpha, acc.num_active(), &acc); + cout << cur_key << '\t'; + if (use_b64) + B64::Encode(0.0, acc, &cout); + else + cout << acc; + cout << endl << flush; + + cerr << "LOGPROB: " << logprob << endl; + + return 0; +} diff --git a/training/mr_em_map_adapter.cc b/training/mr_em_map_adapter.cc new file mode 100644 index 00000000..a98e1b77 --- /dev/null +++ b/training/mr_em_map_adapter.cc @@ -0,0 +1,160 @@ +#include <iostream> +#include <fstream> +#include <cassert> +#include <cmath> + +#include <boost/utility.hpp> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> +#include "boost/tuple/tuple.hpp" + +#include "fdict.h" +#include "sparse_vector.h" + +using namespace std; +namespace po = boost::program_options; + +// useful for EM models parameterized by a bunch of multinomials +// this converts event counts (returned from cdec as feature expectations) +// into different keys and values (which are lists of all the events, +// conditioned on the key) for summing and normalization by a reducer + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("buffer_size,b", po::value<int>()->default_value(1), "Buffer size (in # of counts) before emitting counts") + ("format,f",po::value<string>()->default_value("b64"), "Encoding of the input (b64 or text)"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value<string>(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as<string>().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +struct EventMapper { + int Map(int fid) { + int& cv = map_[fid]; + if (!cv) { + cv = GetConditioningVariable(fid); + } + return cv; + } + void Clear() { map_.clear(); } + protected: + virtual int GetConditioningVariable(int fid) const = 0; + private: + map<int, int> map_; +}; + +struct LexAlignEventMapper : public EventMapper { + protected: + virtual int GetConditioningVariable(int fid) const { + const string& str = FD::Convert(fid); + size_t pos = str.rfind("_"); + if (pos == string::npos || pos == 0 || pos >= str.size() - 1) { + cerr << "Bad feature for EM adapter: " << str << endl; + abort(); + } + return FD::Convert(str.substr(0, pos)); + } +}; + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + + const bool use_b64 = conf["format"].as<string>() == "b64"; + const int buffer_size = conf["buffer_size"].as<int>(); + + const string s_obj = "**OBJ**"; + // 0<TAB>**OBJ**=12.2;Feat1=2.3;Feat2=-0.2; + // 0<TAB>**OBJ**=1.1;Feat1=1.0; + + EventMapper* event_mapper = new LexAlignEventMapper; + map<int, SparseVector<double> > counts; + size_t total = 0; + while(cin) { + string line; + getline(cin, line); + if (line.empty()) continue; + int feat; + double val; + size_t i = line.find("\t"); + assert(i != string::npos); + ++i; + SparseVector<double> g; + double obj = 0; + if (use_b64) { + if (!B64::Decode(&obj, &g, &line[i], line.size() - i)) { + cerr << "B64 decoder returned error, skipping!\n"; + continue; + } + } else { // text encoding - your counts will not be accurate! + while (i < line.size()) { + size_t start = i; + while (line[i] != '=' && i < line.size()) ++i; + if (i == line.size()) { cerr << "FORMAT ERROR\n"; break; } + string fname = line.substr(start, i - start); + if (fname == s_obj) { + feat = -1; + } else { + feat = FD::Convert(line.substr(start, i - start)); + } + ++i; + start = i; + while (line[i] != ';' && i < line.size()) ++i; + if (i - start == 0) continue; + val = atof(line.substr(start, i - start).c_str()); + ++i; + if (feat == -1) { + obj = val; + } else { + g.set_value(feat, val); + } + } + } + //cerr << "OBJ: " << obj << endl; + const SparseVector<double>& cg = g; + for (SparseVector<double>::const_iterator it = cg.begin(); it != cg.end(); ++it) { + const int cond_var = event_mapper->Map(it->first); + SparseVector<double>& cond_counts = counts[cond_var]; + int delta = cond_counts.num_active(); + cond_counts.add_value(it->first, it->second); + delta = cond_counts.num_active() - delta; + total += delta; + } + if (total > buffer_size) { + for (map<int, SparseVector<double> >::iterator it = counts.begin(); + it != counts.end(); ++it) { + const SparseVector<double>& cc = it->second; + cout << FD::Convert(it->first) << '\t'; + if (use_b64) { + B64::Encode(0.0, cc, &cout); + } else { + abort(); + } + cout << endl; + } + cout << flush; + total = 0; + counts.clear(); + } + } + + return 0; +} + diff --git a/training/mr_optimize_reduce.cc b/training/mr_optimize_reduce.cc new file mode 100644 index 00000000..42727ecb --- /dev/null +++ b/training/mr_optimize_reduce.cc @@ -0,0 +1,243 @@ +#include <sstream> +#include <iostream> +#include <fstream> +#include <vector> +#include <cassert> +#include <cmath> + +#include <boost/shared_ptr.hpp> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "optimize.h" +#include "fdict.h" +#include "weights.h" +#include "sparse_vector.h" + +using namespace std; +using boost::shared_ptr; +namespace po = boost::program_options; + +void SanityCheck(const vector<double>& w) { + for (int i = 0; i < w.size(); ++i) { + assert(!isnan(w[i])); + assert(!isinf(w[i])); + } +} + +struct FComp { + const vector<double>& w_; + FComp(const vector<double>& w) : w_(w) {} + bool operator()(int a, int b) const { + return fabs(w_[a]) > fabs(w_[b]); + } +}; + +void ShowLargestFeatures(const vector<double>& w) { + vector<int> fnums(w.size()); + for (int i = 0; i < w.size(); ++i) + fnums[i] = i; + vector<int>::iterator mid = fnums.begin(); + mid += (w.size() > 10 ? 10 : w.size()); + partial_sort(fnums.begin(), mid, fnums.end(), FComp(w)); + cerr << "TOP FEATURES:"; + for (vector<int>::iterator i = fnums.begin(); i != mid; ++i) { + cerr << ' ' << FD::Convert(*i) << '=' << w[*i]; + } + cerr << endl; +} + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("input_weights,i",po::value<string>(),"Input feature weights file") + ("output_weights,o",po::value<string>()->default_value("-"),"Output feature weights file") + ("optimization_method,m", po::value<string>()->default_value("lbfgs"), "Optimization method (sgd, lbfgs, rprop)") + ("state,s",po::value<string>(),"Read (and write if output_state is not set) optimizer state from this state file. In the first iteration, the file should not exist.") + ("input_format,f",po::value<string>()->default_value("b64"),"Encoding of the input (b64 or text)") + ("output_state,S", po::value<string>(), "Output state file (optional override)") + ("correction_buffers,M", po::value<int>()->default_value(10), "Number of gradients for LBFGS to maintain in memory") + ("eta,e", po::value<double>()->default_value(0.1), "Learning rate for SGD (eta)") + ("gaussian_prior,p","Use a Gaussian prior on the weights") + ("means,u", po::value<string>(), "File containing the means for Gaussian prior") + ("sigma_squared", po::value<double>()->default_value(1.0), "Sigma squared term for spherical Gaussian prior"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value<string>(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as<string>().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("input_weights") || !conf->count("state")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + + const bool use_b64 = conf["input_format"].as<string>() == "b64"; + + Weights weights; + weights.InitFromFile(conf["input_weights"].as<string>()); + const string s_obj = "**OBJ**"; + int num_feats = FD::NumFeats(); + cerr << "Number of features: " << num_feats << endl; + const bool gaussian_prior = conf.count("gaussian_prior"); + vector<double> means(num_feats, 0); + if (conf.count("means")) { + if (!gaussian_prior) { + cerr << "Don't use --means without --gaussian_prior!\n"; + exit(1); + } + Weights wm; + wm.InitFromFile(conf["means"].as<string>()); + if (num_feats != FD::NumFeats()) { + cerr << "[ERROR] Means file had unexpected features!\n"; + exit(1); + } + wm.InitVector(&means); + } + shared_ptr<Optimizer> o; + const string omethod = conf["optimization_method"].as<string>(); + if (omethod == "sgd") + o.reset(new SGDOptimizer(conf["eta"].as<double>())); + else if (omethod == "rprop") + o.reset(new RPropOptimizer(num_feats)); // TODO add configuration + else + o.reset(new LBFGSOptimizer(num_feats, conf["correction_buffers"].as<int>())); + cerr << "Optimizer: " << o->Name() << endl; + string state_file = conf["state"].as<string>(); + { + ifstream in(state_file.c_str(), ios::binary); + if (in) + o->Load(&in); + else + cerr << "No state file found, assuming ITERATION 1\n"; + } + + vector<double> lambdas(num_feats, 0); + weights.InitVector(&lambdas); + double objective = 0; + vector<double> gradient(num_feats, 0); + // 0<TAB>**OBJ**=12.2;Feat1=2.3;Feat2=-0.2; + // 0<TAB>**OBJ**=1.1;Feat1=1.0; + int total_lines = 0; // TODO - this should be a count of the + // training instances!! + while(cin) { + string line; + getline(cin, line); + if (line.empty()) continue; + ++total_lines; + int feat; + double val; + size_t i = line.find("\t"); + assert(i != string::npos); + ++i; + if (use_b64) { + SparseVector<double> g; + double obj; + if (!B64::Decode(&obj, &g, &line[i], line.size() - i)) { + cerr << "B64 decoder returned error, skipping gradient!\n"; + cerr << " START: " << line.substr(0,line.size() > 200 ? 200 : line.size()) << endl; + if (line.size() > 200) + cerr << " END: " << line.substr(line.size() - 200, 200) << endl; + cout << "-1\tRESTART\n"; + exit(99); + } + objective += obj; + const SparseVector<double>& cg = g; + for (SparseVector<double>::const_iterator it = cg.begin(); it != cg.end(); ++it) { + if (it->first >= num_feats) { + cerr << "Unexpected feature in gradient: " << FD::Convert(it->first) << endl; + abort(); + } + gradient[it->first] -= it->second; + } + } else { // text encoding - your gradients will not be accurate! + while (i < line.size()) { + size_t start = i; + while (line[i] != '=' && i < line.size()) ++i; + if (i == line.size()) { cerr << "FORMAT ERROR\n"; break; } + string fname = line.substr(start, i - start); + if (fname == s_obj) { + feat = -1; + } else { + feat = FD::Convert(line.substr(start, i - start)); + if (feat >= num_feats) { + cerr << "Unexpected feature in gradient: " << line.substr(start, i - start) << endl; + abort(); + } + } + ++i; + start = i; + while (line[i] != ';' && i < line.size()) ++i; + if (i - start == 0) continue; + val = atof(line.substr(start, i - start).c_str()); + ++i; + if (feat == -1) { + objective += val; + } else { + gradient[feat] -= val; + } + } + } + } + + if (gaussian_prior) { + const double sigsq = conf["sigma_squared"].as<double>(); + double norm = 0; + for (int k = 1; k < lambdas.size(); ++k) { + const double& lambda_k = lambdas[k]; + if (lambda_k) { + const double param = (lambda_k - means[k]); + norm += param * param; + gradient[k] += param / sigsq; + } + } + const double reg = norm / (2.0 * sigsq); + cerr << "REGULARIZATION TERM: " << reg << endl; + objective += reg; + } + cerr << "EVALUATION #" << o->EvaluationCount() << " OBJECTIVE: " << objective << endl; + double gnorm = 0; + for (int i = 0; i < gradient.size(); ++i) + gnorm += gradient[i] * gradient[i]; + cerr << " GNORM=" << sqrt(gnorm) << endl; + vector<double> old = lambdas; + int c = 0; + while (old == lambdas) { + ++c; + if (c > 1) { cerr << "Same lambdas, repeating optimization\n"; } + o->Optimize(objective, gradient, &lambdas); + assert(c < 5); + } + old.clear(); + SanityCheck(lambdas); + ShowLargestFeatures(lambdas); + weights.InitFromVector(lambdas); + weights.WriteToFile(conf["output_weights"].as<string>(), false); + + const bool conv = o->HasConverged(); + if (conv) { cerr << "OPTIMIZER REPORTS CONVERGENCE!\n"; } + + if (conf.count("output_state")) + state_file = conf["output_state"].as<string>(); + ofstream out(state_file.c_str(), ios::binary); + cerr << "Writing state to: " << state_file << endl; + o->Save(&out); + out.close(); + + cout << o->EvaluationCount() << "\t" << conv << endl; + return 0; +} diff --git a/training/mr_reduce_to_weights.cc b/training/mr_reduce_to_weights.cc new file mode 100644 index 00000000..16b47720 --- /dev/null +++ b/training/mr_reduce_to_weights.cc @@ -0,0 +1,109 @@ +#include <iostream> +#include <fstream> +#include <vector> +#include <cassert> + +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "filelib.h" +#include "fdict.h" +#include "weights.h" +#include "sparse_vector.h" + +using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("input_format,f",po::value<string>()->default_value("b64"),"Encoding of the input (b64 or text)") + ("input,i",po::value<string>()->default_value("-"),"Read file from") + ("output,o",po::value<string>()->default_value("-"),"Write weights to"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value<string>(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as<string>().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +void WriteWeights(const SparseVector<double>& weights, ostream* out) { + for (SparseVector<double>::const_iterator it = weights.begin(); + it != weights.end(); ++it) { + (*out) << FD::Convert(it->first) << " " << it->second << endl; + } +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + + const bool use_b64 = conf["input_format"].as<string>() == "b64"; + + const string s_obj = "**OBJ**"; + // E-step + ReadFile rf(conf["input"].as<string>()); + istream* in = rf.stream(); + assert(*in); + WriteFile wf(conf["output"].as<string>()); + ostream* out = wf.stream(); + out->precision(17); + while(*in) { + string line; + getline(*in, line); + if (line.empty()) continue; + int feat; + double val; + size_t i = line.find("\t"); + assert(i != string::npos); + ++i; + if (use_b64) { + SparseVector<double> g; + double obj; + if (!B64::Decode(&obj, &g, &line[i], line.size() - i)) { + cerr << "B64 decoder returned error, skipping!\n"; + continue; + } + WriteWeights(g, out); + } else { // text encoding - your counts will not be accurate! + SparseVector<double> weights; + while (i < line.size()) { + size_t start = i; + while (line[i] != '=' && i < line.size()) ++i; + if (i == line.size()) { cerr << "FORMAT ERROR\n"; break; } + string fname = line.substr(start, i - start); + if (fname == s_obj) { + feat = -1; + } else { + feat = FD::Convert(line.substr(start, i - start)); + } + ++i; + start = i; + while (line[i] != ';' && i < line.size()) ++i; + if (i - start == 0) continue; + val = atof(line.substr(start, i - start).c_str()); + ++i; + if (feat != -1) { + weights.set_value(feat, val); + } + } + WriteWeights(weights, out); + } + } + + return 0; +} diff --git a/training/optimize.cc b/training/optimize.cc new file mode 100644 index 00000000..5194752e --- /dev/null +++ b/training/optimize.cc @@ -0,0 +1,114 @@ +#include "optimize.h" + +#include <iostream> +#include <cassert> + +#include "lbfgs.h" + +using namespace std; + +Optimizer::~Optimizer() {} + +void Optimizer::Save(ostream* out) const { + out->write((const char*)&eval_, sizeof(eval_)); + out->write((const char*)&has_converged_, sizeof(has_converged_)); + SaveImpl(out); + unsigned int magic = 0xABCDDCBA; // should be uint32_t + out->write((const char*)&magic, sizeof(magic)); +} + +void Optimizer::Load(istream* in) { + in->read((char*)&eval_, sizeof(eval_)); + ++eval_; + in->read((char*)&has_converged_, sizeof(has_converged_)); + LoadImpl(in); + unsigned int magic = 0; // should be uint32_t + in->read((char*)&magic, sizeof(magic)); + assert(magic == 0xABCDDCBA); + cerr << Name() << " EVALUATION #" << eval_ << endl; +} + +void Optimizer::SaveImpl(ostream* out) const { + (void)out; +} + +void Optimizer::LoadImpl(istream* in) { + (void)in; +} + +string RPropOptimizer::Name() const { + return "RPropOptimizer"; +} + +void RPropOptimizer::OptimizeImpl(const double& obj, + const vector<double>& g, + vector<double>* x) { + for (int i = 0; i < g.size(); ++i) { + const double g_i = g[i]; + const double sign_i = (signbit(g_i) ? -1.0 : 1.0); + const double prod = g_i * prev_g_[i]; + if (prod > 0.0) { + const double dij = min(delta_ij_[i] * eta_plus_, delta_max_); + (*x)[i] -= dij * sign_i; + delta_ij_[i] = dij; + prev_g_[i] = g_i; + } else if (prod < 0.0) { + delta_ij_[i] = max(delta_ij_[i] * eta_minus_, delta_min_); + prev_g_[i] = 0.0; + } else { + (*x)[i] -= delta_ij_[i] * sign_i; + prev_g_[i] = g_i; + } + } +} + +void RPropOptimizer::SaveImpl(ostream* out) const { + const size_t n = prev_g_.size(); + out->write((const char*)&n, sizeof(n)); + out->write((const char*)&prev_g_[0], sizeof(double) * n); + out->write((const char*)&delta_ij_[0], sizeof(double) * n); +} + +void RPropOptimizer::LoadImpl(istream* in) { + size_t n; + in->read((char*)&n, sizeof(n)); + assert(n == prev_g_.size()); + assert(n == delta_ij_.size()); + in->read((char*)&prev_g_[0], sizeof(double) * n); + in->read((char*)&delta_ij_[0], sizeof(double) * n); +} + +string SGDOptimizer::Name() const { + return "SGDOptimizer"; +} + +void SGDOptimizer::OptimizeImpl(const double& obj, + const vector<double>& g, + vector<double>* x) { + (void)obj; + for (int i = 0; i < g.size(); ++i) + (*x)[i] -= g[i] * eta_; +} + +string LBFGSOptimizer::Name() const { + return "LBFGSOptimizer"; +} + +LBFGSOptimizer::LBFGSOptimizer(int num_feats, int memory_buffers) : + opt_(num_feats, memory_buffers) {} + +void LBFGSOptimizer::SaveImpl(ostream* out) const { + opt_.serialize(out); +} + +void LBFGSOptimizer::LoadImpl(istream* in) { + opt_.deserialize(in); +} + +void LBFGSOptimizer::OptimizeImpl(const double& obj, + const vector<double>& g, + vector<double>* x) { + opt_.run(&(*x)[0], obj, &g[0]); + cerr << opt_ << endl; +} + diff --git a/training/optimize.h b/training/optimize.h new file mode 100644 index 00000000..eddceaad --- /dev/null +++ b/training/optimize.h @@ -0,0 +1,104 @@ +#ifndef _OPTIMIZE_H_ +#define _OPTIMIZE_H_ + +#include <iostream> +#include <vector> +#include <string> +#include <cassert> + +#include "lbfgs.h" + +// abstract base class for first order optimizers +// order of invocation: new, Load(), Optimize(), Save(), delete +class Optimizer { + public: + Optimizer() : eval_(1), has_converged_(false) {} + virtual ~Optimizer(); + virtual std::string Name() const = 0; + int EvaluationCount() const { return eval_; } + bool HasConverged() const { return has_converged_; } + + void Optimize(const double& obj, + const std::vector<double>& g, + std::vector<double>* x) { + assert(g.size() == x->size()); + OptimizeImpl(obj, g, x); + scitbx::lbfgs::traditional_convergence_test<double> converged(g.size()); + has_converged_ = converged(&(*x)[0], &g[0]); + } + + void Save(std::ostream* out) const; + void Load(std::istream* in); + protected: + virtual void SaveImpl(std::ostream* out) const; + virtual void LoadImpl(std::istream* in); + virtual void OptimizeImpl(const double& obj, + const std::vector<double>& g, + std::vector<double>* x) = 0; + + int eval_; + private: + bool has_converged_; +}; + +class RPropOptimizer : public Optimizer { + public: + explicit RPropOptimizer(int num_vars, + double eta_plus = 1.2, + double eta_minus = 0.5, + double delta_0 = 0.1, + double delta_max = 50.0, + double delta_min = 1e-6) : + prev_g_(num_vars, 0.0), + delta_ij_(num_vars, delta_0), + eta_plus_(eta_plus), + eta_minus_(eta_minus), + delta_max_(delta_max), + delta_min_(delta_min) { + assert(eta_plus > 1.0); + assert(eta_minus > 0.0 && eta_minus < 1.0); + assert(delta_max > 0.0); + assert(delta_min > 0.0); + } + std::string Name() const; + void OptimizeImpl(const double& obj, + const std::vector<double>& g, + std::vector<double>* x); + void SaveImpl(std::ostream* out) const; + void LoadImpl(std::istream* in); + private: + std::vector<double> prev_g_; + std::vector<double> delta_ij_; + const double eta_plus_; + const double eta_minus_; + const double delta_max_; + const double delta_min_; +}; + +class SGDOptimizer : public Optimizer { + public: + explicit SGDOptimizer(int num_vars, double eta = 0.1) : eta_(eta) { + (void) num_vars; + } + std::string Name() const; + void OptimizeImpl(const double& obj, + const std::vector<double>& g, + std::vector<double>* x); + private: + const double eta_; +}; + +class LBFGSOptimizer : public Optimizer { + public: + explicit LBFGSOptimizer(int num_vars, int memory_buffers = 10); + std::string Name() const; + void SaveImpl(std::ostream* out) const; + void LoadImpl(std::istream* in); + void OptimizeImpl(const double& obj, + const std::vector<double>& g, + std::vector<double>* x); + private: + scitbx::lbfgs::minimizer<double> opt_; +}; + +#endif diff --git a/training/optimize_test.cc b/training/optimize_test.cc new file mode 100644 index 00000000..0ada7cbb --- /dev/null +++ b/training/optimize_test.cc @@ -0,0 +1,105 @@ +#include <cassert> +#include <iostream> +#include <sstream> +#include <boost/program_options/variables_map.hpp> +#include "optimize.h" +#include "sparse_vector.h" +#include "fdict.h" + +using namespace std; + +double TestOptimizer(Optimizer* opt) { + cerr << "TESTING NON-PERSISTENT OPTIMIZER\n"; + + // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 + // df/dx1 = 8*x1 + x2 + // df/dx2 = 2*x2 + x1 + // df/dx3 = 2*x3 + 6 + vector<double> x(3); + vector<double> g(3); + x[0] = 8; + x[1] = 8; + x[2] = 8; + double obj = 0; + do { + g[0] = 8 * x[0] + x[1]; + g[1] = 2 * x[1] + x[0]; + g[2] = 2 * x[2] + 6; + obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; + opt->Optimize(obj, g, &x); + + cerr << x[0] << " " << x[1] << " " << x[2] << endl; + cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; + } while (!opt->HasConverged()); + return obj; +} + +double TestPersistentOptimizer(Optimizer* opt) { + cerr << "\nTESTING PERSISTENT OPTIMIZER\n"; + // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 + // df/dx1 = 8*x1 + x2 + // df/dx2 = 2*x2 + x1 + // df/dx3 = 2*x3 + 6 + vector<double> x(3); + vector<double> g(3); + x[0] = 8; + x[1] = 8; + x[2] = 8; + double obj = 0; + string state; + bool converged = false; + while (!converged) { + g[0] = 8 * x[0] + x[1]; + g[1] = 2 * x[1] + x[0]; + g[2] = 2 * x[2] + 6; + obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; + + { + if (state.size() > 0) { + istringstream is(state, ios::binary); + opt->Load(&is); + } + opt->Optimize(obj, g, &x); + ostringstream os(ios::binary); opt->Save(&os); state = os.str(); + + } + + cerr << x[0] << " " << x[1] << " " << x[2] << endl; + cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; + converged = opt->HasConverged(); + if (!converged) { + // now screw up the state (should be undone by Load) + obj += 2.0; + g[1] = -g[2]; + vector<double> x2 = x; + try { + opt->Optimize(obj, g, &x2); + } catch (...) { } + } + } + return obj; +} + +template <class O> +void TestOptimizerVariants(int num_vars) { + O oa(num_vars); + cerr << "-------------------------------------------------------------------------\n"; + cerr << "TESTING: " << oa.Name() << endl; + double o1 = TestOptimizer(&oa); + O ob(num_vars); + double o2 = TestPersistentOptimizer(&ob); + if (o1 != o2) { + cerr << oa.Name() << " VARIANTS PERFORMED DIFFERENTLY!\n" << o1 << " vs. " << o2 << endl; + exit(1); + } + cerr << oa.Name() << " SUCCESS\n"; +} + +int main() { + int n = 3; + TestOptimizerVariants<SGDOptimizer>(n); + TestOptimizerVariants<LBFGSOptimizer>(n); + TestOptimizerVariants<RPropOptimizer>(n); + return 0; +} + diff --git a/training/plftools.cc b/training/plftools.cc new file mode 100644 index 00000000..903ec54f --- /dev/null +++ b/training/plftools.cc @@ -0,0 +1,93 @@ +#include <iostream> +#include <fstream> +#include <vector> + +#include <boost/lexical_cast.hpp> +#include <boost/program_options.hpp> + +#include "filelib.h" +#include "tdict.h" +#include "prob.h" +#include "hg.h" +#include "hg_io.h" +#include "viterbi.h" +#include "kbest.h" + +namespace po = boost::program_options; +using namespace std; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("input,i", po::value<string>(), "REQ. Lattice input file (PLF), - for STDIN") + ("prior_scale,p", po::value<double>()->default_value(1.0), "Scale path probabilities by this amount < 1 flattens, > 1 sharpens") + ("weight,w", po::value<vector<double> >(), "Weight(s) for arc features") + ("output,o", po::value<string>()->default_value("plf"), "Output format (text, plf)") + ("command,c", po::value<string>()->default_value("push"), "Operation to perform: push, graphviz, 1best, 2best ...") + ("help,h", "Print this help message and exit"); + po::options_description clo("Command line options"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + po::notify(*conf); + + if (conf->count("help") || conf->count("input") == 0) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +int main(int argc, char **argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + string infile = conf["input"].as<string>(); + ReadFile rf(infile); + istream* in = rf.stream(); + assert(*in); + SparseVector<double> wts; + vector<double> wv; + if (conf.count("weight") > 0) wv = conf["weight"].as<vector<double> >(); + if (wv.empty()) wv.push_back(1.0); + for (int i = 0; i < wv.size(); ++i) { + const string fname = "Feature_" + boost::lexical_cast<string>(i); + cerr << "[INFO] Arc weight " << (i+1) << " = " << wv[i] << endl; + wts.set_value(FD::Convert(fname), wv[i]); + } + const string cmd = conf["command"].as<string>(); + const bool push_weights = cmd == "push"; + const bool output_plf = cmd == "plf"; + const bool graphviz = cmd == "graphviz"; + const bool kbest = cmd.rfind("best") == (cmd.size() - 4) && cmd.size() > 4; + int k = 1; + if (kbest) { + k = boost::lexical_cast<int>(cmd.substr(0, cmd.size() - 4)); + cerr << "KBEST = " << k << endl; + } + const double scale = conf["prior_scale"].as<double>(); + int lc = 0; + while(*in) { + ++lc; + string plf; + getline(*in, plf); + if (plf.empty()) continue; + Hypergraph hg; + HypergraphIO::ReadFromPLF(plf, &hg); + hg.Reweight(wts); + if (graphviz) hg.PrintGraphviz(); + if (push_weights) hg.PushWeightsToSource(scale); + if (output_plf) { + cout << HypergraphIO::AsPLF(hg) << endl; + } else { + KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg, k); + for (int i = 0; i < k; ++i) { + const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + cout << lc << " ||| " << TD::GetString(d->yield) << " ||| " << d->score << endl; + } + } + } + return 0; +} + |