summaryrefslogtreecommitdiff
path: root/training/mira/ada_opt_sm.cc
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2014-02-09 20:50:41 -0500
committerChris Dyer <redpony@gmail.com>2014-02-09 20:50:41 -0500
commit9b83a2e82aba73b5ff0e848182a8726481a10485 (patch)
tree81e7ade548bdffaa8534705c2d34beb8c752dc24 /training/mira/ada_opt_sm.cc
parent702591b3296af472cc5c7c4720f1c21b2a6e34b1 (diff)
adaptive hope-fear learner
Diffstat (limited to 'training/mira/ada_opt_sm.cc')
-rw-r--r--training/mira/ada_opt_sm.cc198
1 files changed, 198 insertions, 0 deletions
diff --git a/training/mira/ada_opt_sm.cc b/training/mira/ada_opt_sm.cc
new file mode 100644
index 00000000..18ddbf8f
--- /dev/null
+++ b/training/mira/ada_opt_sm.cc
@@ -0,0 +1,198 @@
+#include "config.h"
+
+#include <boost/container/flat_map.hpp>
+#include <boost/shared_ptr.hpp>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "filelib.h"
+#include "stringlib.h"
+#include "weights.h"
+#include "sparse_vector.h"
+#include "candidate_set.h"
+#include "sentence_metadata.h"
+#include "ns.h"
+#include "ns_docscorer.h"
+#include "verbose.h"
+#include "hg.h"
+#include "ff_register.h"
+#include "decoder.h"
+#include "fdict.h"
+#include "sampler.h"
+
+using namespace std;
+namespace po = boost::program_options;
+
+boost::shared_ptr<MT19937> rng;
+vector<training::CandidateSet> kbests;
+SparseVector<weight_t> G, u, lambdas;
+double pseudo_doc_decay = 0.9;
+
+bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("decoder_config,c",po::value<string>(),"[REQ] Decoder configuration file")
+ ("devset,d",po::value<string>(),"[REQ] Source/reference development set")
+ ("weights,w",po::value<string>(),"Initial feature weights file")
+ ("mt_metric,m",po::value<string>()->default_value("ibm_bleu"), "Scoring metric (ibm_bleu, nist_bleu, koehn_bleu, ter, combi)")
+ ("size",po::value<unsigned>()->default_value(0), "Process rank (for multiprocess mode)")
+ ("rank",po::value<unsigned>()->default_value(1), "Number of processes (for multiprocess mode)")
+ ("optimizer,o",po::value<unsigned>()->default_value(1), "Optimizer (Adaptive MIRA=1)")
+ ("fear,f",po::value<unsigned>()->default_value(1), "Fear selection (model-cost=1, maxcost=2, maxscore=3)")
+ ("hope,h",po::value<unsigned>()->default_value(1), "Hope selection (model+cost=1, mincost=2)")
+ ("eta0", po::value<double>()->default_value(0.1), "Initial step size")
+ ("random_seed,S", po::value<uint32_t>(), "Random seed (if not specified, /dev/random will be used)")
+ ("mt_metric_scale,s", po::value<double>()->default_value(1.0), "Scale MT loss function by this amount")
+ ("pseudo_doc,e", "Use pseudo-documents for approximate scoring")
+ ("k_best_size,k", po::value<unsigned>()->default_value(500), "Size of hypothesis list to search for oracles");
+ po::options_description clo("Command line options");
+ clo.add_options()
+ ("config", po::value<string>(), "Configuration file")
+ ("help,H", "Print this help message and exit");
+ po::options_description dconfig_options, dcmdline_options;
+ dconfig_options.add(opts);
+ dcmdline_options.add(opts).add(clo);
+
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ if (conf->count("config")) {
+ ifstream config((*conf)["config"].as<string>().c_str());
+ po::store(po::parse_config_file(config, dconfig_options), *conf);
+ }
+ po::notify(*conf);
+
+ if (conf->count("help")
+ || !conf->count("decoder_config")
+ || !conf->count("devset")) {
+ cerr << dcmdline_options << endl;
+ return false;
+ }
+ return true;
+}
+
+struct TrainingObserver : public DecoderObserver {
+ explicit TrainingObserver(const EvaluationMetric& m, const int k) : metric(m), kbest_size(k), cur_eval() {}
+
+ const EvaluationMetric& metric;
+ const int kbest_size;
+ const SegmentEvaluator* cur_eval;
+ SufficientStats pdoc;
+ unsigned hi, vi, fi; // hope, viterbi, fear
+
+ void SetSegmentEvaluator(const SegmentEvaluator* eval) {
+ cur_eval = eval;
+ }
+
+ virtual void NotifySourceParseFailure(const SentenceMetadata& smeta) {
+ cerr << "Failed to translate sentence with ID = " << smeta.GetSentenceID() << endl;
+ abort();
+ }
+
+ unsigned CostAugmentedDecode(const training::CandidateSet& cs,
+ const SparseVector<double>& w,
+ double alpha = 0) {
+ unsigned best_i = 0;
+ double best = -numeric_limits<double>::infinity();
+ for (unsigned i = 0; i < cs.size(); ++i) {
+ double s = cs[i].fmap.dot(w);
+ if (alpha)
+ s += alpha * metric.ComputeScore(cs[i].eval_feats + pdoc);
+ if (s > best) {
+ best = s;
+ best_i = i;
+ }
+ }
+ return best_i;
+ }
+
+ virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) {
+ pdoc *= pseudo_doc_decay;
+ const unsigned sent_id = smeta.GetSentenceID();
+ kbests[sent_id].AddUniqueKBestCandidates(*hg, kbest_size, cur_eval);
+ vi = CostAugmentedDecode(kbests[sent_id], lambdas);
+ hi = CostAugmentedDecode(kbests[sent_id], lambdas, 1.0);
+ fi = CostAugmentedDecode(kbests[sent_id], lambdas, -1.0);
+ cerr << sent_id << " ||| " << TD::GetString(kbests[sent_id][vi].ewords) << " ||| " << metric.ComputeScore(kbests[sent_id][vi].eval_feats + pdoc) << endl;
+ pdoc += kbests[sent_id][vi].eval_feats; // update pseudodoc stats
+ }
+};
+
+int main(int argc, char** argv) {
+ SetSilent(true); // turn off verbose decoder output
+ register_feature_functions();
+
+ po::variables_map conf;
+ if (!InitCommandLine(argc, argv, &conf)) return 1;
+
+ if (conf.count("random_seed"))
+ rng.reset(new MT19937(conf["random_seed"].as<uint32_t>()));
+ else
+ rng.reset(new MT19937);
+
+ string metric_name = UppercaseString(conf["mt_metric"].as<string>());
+ if (metric_name == "COMBI") {
+ cerr << "WARNING: 'combi' metric is no longer supported, switching to 'COMB:TER=-0.5;IBM_BLEU=0.5'\n";
+ metric_name = "COMB:TER=-0.5;IBM_BLEU=0.5";
+ } else if (metric_name == "BLEU") {
+ cerr << "WARNING: 'BLEU' is ambiguous, assuming 'IBM_BLEU'\n";
+ metric_name = "IBM_BLEU";
+ }
+ EvaluationMetric* metric = EvaluationMetric::Instance(metric_name);
+ DocumentScorer ds(metric, conf["devset"].as<string>());
+ cerr << "Loaded " << ds.size() << " references for scoring with " << metric_name << endl;
+ kbests.resize(ds.size());
+ double eta = 0.001;
+
+ ReadFile ini_rf(conf["decoder_config"].as<string>());
+ Decoder decoder(ini_rf.stream());
+
+ vector<weight_t>& dense_weights = decoder.CurrentWeightVector();
+ if (conf.count("weights")) {
+ Weights::InitFromFile(conf["weights"].as<string>(), &dense_weights);
+ Weights::InitSparseVector(dense_weights, &lambdas);
+ }
+
+ TrainingObserver observer(*metric, conf["k_best_size"].as<unsigned>());
+
+ unsigned num = 200;
+ for (unsigned iter = 1; iter < num; ++iter) {
+ lambdas.init_vector(&dense_weights);
+ unsigned sent_id = rng->next() * ds.size();
+ cerr << "Learning from sentence id: " << sent_id << endl;
+ observer.SetSegmentEvaluator(ds[sent_id]);
+ decoder.SetId(sent_id);
+ decoder.Decode(ds[sent_id]->src, &observer);
+ if (observer.vi != observer.hi) { // viterbi != hope
+ SparseVector<double> grad = kbests[sent_id][observer.fi].fmap;
+ grad -= kbests[sent_id][observer.hi].fmap;
+ cerr << "GRAD: " << grad << endl;
+ const SparseVector<double>& g = grad;
+#if HAVE_CXX11 && (__GNUC_MINOR__ > 4 || __GNUC__ > 4)
+ for (auto& gi : g) {
+#else
+ for (SparseVector<double>::const_iterator it = g.begin(); it != g.end(); ++it) {
+ const pair<unsigned,double>& gi = *it;
+#endif
+ if (gi.second) {
+ u[gi.first] += gi.second;
+ G[gi.first] += gi.second * gi.second;
+ lambdas.set_value(gi.first, 1.0); // this is a dummy value to trigger recomputation
+ }
+ }
+ for (SparseVector<double>::iterator it = lambdas.begin(); it != lambdas.end(); ++it) {
+ const pair<unsigned,double>& xi = *it;
+ double z = fabs(u[xi.first] / iter) - 0.0;
+ double s = 1;
+ if (u[xi.first] > 0) s = -1;
+ if (z > 0 && G[xi.first]) {
+ lambdas.set_value(xi.first, eta * s * z * iter / sqrt(G[xi.first]));
+ } else {
+ lambdas.set_value(xi.first, 0.0);
+ }
+ }
+ }
+ }
+ cerr << "Optimization complete.\n";
+ Weights::WriteToFile("-", dense_weights, true);
+ return 0;
+}
+