diff options
author | Chris Dyer <redpony@gmail.com> | 2014-02-09 20:50:41 -0500 |
---|---|---|
committer | Chris Dyer <redpony@gmail.com> | 2014-02-09 20:50:41 -0500 |
commit | 9b83a2e82aba73b5ff0e848182a8726481a10485 (patch) | |
tree | 81e7ade548bdffaa8534705c2d34beb8c752dc24 /training/mira/ada_opt_sm.cc | |
parent | 702591b3296af472cc5c7c4720f1c21b2a6e34b1 (diff) |
adaptive hope-fear learner
Diffstat (limited to 'training/mira/ada_opt_sm.cc')
-rw-r--r-- | training/mira/ada_opt_sm.cc | 198 |
1 files changed, 198 insertions, 0 deletions
diff --git a/training/mira/ada_opt_sm.cc b/training/mira/ada_opt_sm.cc new file mode 100644 index 00000000..18ddbf8f --- /dev/null +++ b/training/mira/ada_opt_sm.cc @@ -0,0 +1,198 @@ +#include "config.h" + +#include <boost/container/flat_map.hpp> +#include <boost/shared_ptr.hpp> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "filelib.h" +#include "stringlib.h" +#include "weights.h" +#include "sparse_vector.h" +#include "candidate_set.h" +#include "sentence_metadata.h" +#include "ns.h" +#include "ns_docscorer.h" +#include "verbose.h" +#include "hg.h" +#include "ff_register.h" +#include "decoder.h" +#include "fdict.h" +#include "sampler.h" + +using namespace std; +namespace po = boost::program_options; + +boost::shared_ptr<MT19937> rng; +vector<training::CandidateSet> kbests; +SparseVector<weight_t> G, u, lambdas; +double pseudo_doc_decay = 0.9; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("decoder_config,c",po::value<string>(),"[REQ] Decoder configuration file") + ("devset,d",po::value<string>(),"[REQ] Source/reference development set") + ("weights,w",po::value<string>(),"Initial feature weights file") + ("mt_metric,m",po::value<string>()->default_value("ibm_bleu"), "Scoring metric (ibm_bleu, nist_bleu, koehn_bleu, ter, combi)") + ("size",po::value<unsigned>()->default_value(0), "Process rank (for multiprocess mode)") + ("rank",po::value<unsigned>()->default_value(1), "Number of processes (for multiprocess mode)") + ("optimizer,o",po::value<unsigned>()->default_value(1), "Optimizer (Adaptive MIRA=1)") + ("fear,f",po::value<unsigned>()->default_value(1), "Fear selection (model-cost=1, maxcost=2, maxscore=3)") + ("hope,h",po::value<unsigned>()->default_value(1), "Hope selection (model+cost=1, mincost=2)") + ("eta0", po::value<double>()->default_value(0.1), "Initial step size") + ("random_seed,S", po::value<uint32_t>(), "Random seed (if not specified, /dev/random will be used)") + ("mt_metric_scale,s", po::value<double>()->default_value(1.0), "Scale MT loss function by this amount") + ("pseudo_doc,e", "Use pseudo-documents for approximate scoring") + ("k_best_size,k", po::value<unsigned>()->default_value(500), "Size of hypothesis list to search for oracles"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value<string>(), "Configuration file") + ("help,H", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as<string>().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") + || !conf->count("decoder_config") + || !conf->count("devset")) { + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +struct TrainingObserver : public DecoderObserver { + explicit TrainingObserver(const EvaluationMetric& m, const int k) : metric(m), kbest_size(k), cur_eval() {} + + const EvaluationMetric& metric; + const int kbest_size; + const SegmentEvaluator* cur_eval; + SufficientStats pdoc; + unsigned hi, vi, fi; // hope, viterbi, fear + + void SetSegmentEvaluator(const SegmentEvaluator* eval) { + cur_eval = eval; + } + + virtual void NotifySourceParseFailure(const SentenceMetadata& smeta) { + cerr << "Failed to translate sentence with ID = " << smeta.GetSentenceID() << endl; + abort(); + } + + unsigned CostAugmentedDecode(const training::CandidateSet& cs, + const SparseVector<double>& w, + double alpha = 0) { + unsigned best_i = 0; + double best = -numeric_limits<double>::infinity(); + for (unsigned i = 0; i < cs.size(); ++i) { + double s = cs[i].fmap.dot(w); + if (alpha) + s += alpha * metric.ComputeScore(cs[i].eval_feats + pdoc); + if (s > best) { + best = s; + best_i = i; + } + } + return best_i; + } + + virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { + pdoc *= pseudo_doc_decay; + const unsigned sent_id = smeta.GetSentenceID(); + kbests[sent_id].AddUniqueKBestCandidates(*hg, kbest_size, cur_eval); + vi = CostAugmentedDecode(kbests[sent_id], lambdas); + hi = CostAugmentedDecode(kbests[sent_id], lambdas, 1.0); + fi = CostAugmentedDecode(kbests[sent_id], lambdas, -1.0); + cerr << sent_id << " ||| " << TD::GetString(kbests[sent_id][vi].ewords) << " ||| " << metric.ComputeScore(kbests[sent_id][vi].eval_feats + pdoc) << endl; + pdoc += kbests[sent_id][vi].eval_feats; // update pseudodoc stats + } +}; + +int main(int argc, char** argv) { + SetSilent(true); // turn off verbose decoder output + register_feature_functions(); + + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) return 1; + + if (conf.count("random_seed")) + rng.reset(new MT19937(conf["random_seed"].as<uint32_t>())); + else + rng.reset(new MT19937); + + string metric_name = UppercaseString(conf["mt_metric"].as<string>()); + if (metric_name == "COMBI") { + cerr << "WARNING: 'combi' metric is no longer supported, switching to 'COMB:TER=-0.5;IBM_BLEU=0.5'\n"; + metric_name = "COMB:TER=-0.5;IBM_BLEU=0.5"; + } else if (metric_name == "BLEU") { + cerr << "WARNING: 'BLEU' is ambiguous, assuming 'IBM_BLEU'\n"; + metric_name = "IBM_BLEU"; + } + EvaluationMetric* metric = EvaluationMetric::Instance(metric_name); + DocumentScorer ds(metric, conf["devset"].as<string>()); + cerr << "Loaded " << ds.size() << " references for scoring with " << metric_name << endl; + kbests.resize(ds.size()); + double eta = 0.001; + + ReadFile ini_rf(conf["decoder_config"].as<string>()); + Decoder decoder(ini_rf.stream()); + + vector<weight_t>& dense_weights = decoder.CurrentWeightVector(); + if (conf.count("weights")) { + Weights::InitFromFile(conf["weights"].as<string>(), &dense_weights); + Weights::InitSparseVector(dense_weights, &lambdas); + } + + TrainingObserver observer(*metric, conf["k_best_size"].as<unsigned>()); + + unsigned num = 200; + for (unsigned iter = 1; iter < num; ++iter) { + lambdas.init_vector(&dense_weights); + unsigned sent_id = rng->next() * ds.size(); + cerr << "Learning from sentence id: " << sent_id << endl; + observer.SetSegmentEvaluator(ds[sent_id]); + decoder.SetId(sent_id); + decoder.Decode(ds[sent_id]->src, &observer); + if (observer.vi != observer.hi) { // viterbi != hope + SparseVector<double> grad = kbests[sent_id][observer.fi].fmap; + grad -= kbests[sent_id][observer.hi].fmap; + cerr << "GRAD: " << grad << endl; + const SparseVector<double>& g = grad; +#if HAVE_CXX11 && (__GNUC_MINOR__ > 4 || __GNUC__ > 4) + for (auto& gi : g) { +#else + for (SparseVector<double>::const_iterator it = g.begin(); it != g.end(); ++it) { + const pair<unsigned,double>& gi = *it; +#endif + if (gi.second) { + u[gi.first] += gi.second; + G[gi.first] += gi.second * gi.second; + lambdas.set_value(gi.first, 1.0); // this is a dummy value to trigger recomputation + } + } + for (SparseVector<double>::iterator it = lambdas.begin(); it != lambdas.end(); ++it) { + const pair<unsigned,double>& xi = *it; + double z = fabs(u[xi.first] / iter) - 0.0; + double s = 1; + if (u[xi.first] > 0) s = -1; + if (z > 0 && G[xi.first]) { + lambdas.set_value(xi.first, eta * s * z * iter / sqrt(G[xi.first])); + } else { + lambdas.set_value(xi.first, 0.0); + } + } + } + } + cerr << "Optimization complete.\n"; + Weights::WriteToFile("-", dense_weights, true); + return 0; +} + |