// Stochastic hope/fear tuning of decoder feature weights over k-best lists.
//
// For a fixed number of iterations a development sentence is sampled at
// random and decoded; its k-best list is merged into a per-sentence candidate
// pool; "hope", "fear" and model-best candidates are selected from the pool;
// and the weights are recomputed from the running sum of gradients (u) and the
// running sum of squared gradients (G).
// NOTE(review): the update has the shape of an AdaGrad-style dual-averaging
// step with the L1 strength fixed at 0 -- confirm against the intended
// algorithm before relying on that description.

#include "config.h"

#include <stdint.h>

#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include <utility>
#include <vector>

// NOTE(review): three angle-bracket includes stood here but their header
// names were lost when this file was mangled. The Boost headers below are the
// ones the code visibly requires (boost::shared_ptr, boost::program_options);
// verify against version control.
#include <boost/shared_ptr.hpp>
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>

#include "filelib.h"
#include "stringlib.h"
#include "weights.h"
#include "sparse_vector.h"
#include "candidate_set.h"
#include "sentence_metadata.h"
#include "ns.h"
#include "ns_docscorer.h"
#include "verbose.h"
#include "hg.h"
#include "ff_register.h"
#include "decoder.h"
#include "fdict.h"
#include "sampler.h"

using namespace std;
namespace po = boost::program_options;

// Random number generator used to pick the training sentence.
boost::shared_ptr<MT19937> rng;
// One accumulated candidate pool per development sentence, indexed by id.
vector<training::CandidateSet> kbests;
// G: per-feature sum of squared gradients; u: per-feature sum of gradients;
// lambdas: current sparse weight vector.
SparseVector<weight_t> G, u, lambdas;
// Multiplier applied to the pseudo-document statistics before each sentence.
double pseudo_doc_decay = 0.9;

// Parses the command line (and the optional --config file) into *conf.
//
// Returns false -- after writing a message or the usage text to stderr --
// when --help was given, when a required option (decoder_config, devset) is
// missing, or when the file named by --config cannot be opened.
bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
  po::options_description opts("Configuration options");
  opts.add_options()
        ("decoder_config,c",po::value<string>(),"[REQ] Decoder configuration file")
        ("devset,d",po::value<string>(),"[REQ] Source/reference development set")
        ("weights,w",po::value<string>(),"Initial feature weights file")
        ("mt_metric,m",po::value<string>()->default_value("ibm_bleu"), "Scoring metric (ibm_bleu, nist_bleu, koehn_bleu, ter, combi)")
        // The help text and defaults of "size" and "rank" were attached to the
        // wrong option names; neither option is read anywhere in this file.
        ("size",po::value<int>()->default_value(1), "Number of processes (for multiprocess mode)")
        ("rank",po::value<int>()->default_value(0), "Process rank (for multiprocess mode)")
        ("optimizer,o",po::value<int>()->default_value(1), "Optimizer (Adaptive MIRA=1)")
        ("fear,f",po::value<int>()->default_value(1), "Fear selection (model-cost=1, maxcost=2, maxscore=3)")
        ("hope,h",po::value<int>()->default_value(1), "Hope selection (model+cost=1, mincost=2)")
        ("eta0", po::value<double>()->default_value(0.1), "Initial step size")
        ("random_seed,S", po::value<uint32_t>(), "Random seed (if not specified, /dev/random will be used)")
        ("mt_metric_scale,s", po::value<double>()->default_value(1.0), "Scale MT loss function by this amount")
        ("pseudo_doc,e", "Use pseudo-documents for approximate scoring")
        ("k_best_size,k", po::value<int>()->default_value(500), "Size of hypothesis list to search for oracles");
  po::options_description clo("Command line options");
  clo.add_options()
        ("config", po::value<string>(), "Configuration file")
        ("help,H", "Print this help message and exit");
  po::options_description dconfig_options, dcmdline_options;
  dconfig_options.add(opts);
  dcmdline_options.add(opts).add(clo);

  po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
  if (conf->count("config")) {
    const string config_path = (*conf)["config"].as<string>();
    ifstream config(config_path.c_str());
    // An unreadable config file used to be parsed as if it were empty.
    if (!config) {
      cerr << "Failed to open configuration file: " << config_path << endl;
      return false;
    }
    po::store(po::parse_config_file(config, dconfig_options), *conf);
  }
  po::notify(*conf);

  if (conf->count("help")
      || !conf->count("decoder_config")
      || !conf->count("devset")) {
    cerr << dcmdline_options << endl;
    return false;
  }
  return true;
}

// Decoder callback that folds each new k-best list into the sentence's
// candidate pool and records the indices of the hope, fear and model-best
// candidates for the training loop to read.
struct TrainingObserver : public DecoderObserver {
  explicit TrainingObserver(const EvaluationMetric& m, const int k)
      : metric(m), kbest_size(k), cur_eval(), hi(0), vi(0), fi(0) {}

  const EvaluationMetric& metric;
  const int kbest_size;
  const SegmentEvaluator* cur_eval;  // evaluator for the sentence being decoded
  SufficientStats pdoc;              // decayed pseudo-document statistics
  unsigned hi, vi, fi;  // hope, viterbi, fear

  // Sets the evaluator used to score candidates of the next decoded sentence.
  void SetSegmentEvaluator(const SegmentEvaluator* eval) {
    cur_eval = eval;
  }

  // A source sentence that cannot be parsed is treated as fatal.
  virtual void NotifySourceParseFailure(const SentenceMetadata& smeta) {
    cerr << "Failed to translate sentence with ID = " << smeta.GetSentenceID() << endl;
    abort();
  }

  // Returns the index of the candidate in cs that maximises
  //   fmap.dot(w) + alpha * metric.ComputeScore(eval_feats + pdoc).
  // The metric term is skipped entirely when alpha == 0. Ties keep the
  // earliest candidate; an empty set yields 0.
  unsigned CostAugmentedDecode(const training::CandidateSet& cs,
                               const SparseVector<weight_t>& w,
                               double alpha = 0) {
    unsigned best_i = 0;
    double best = -numeric_limits<double>::infinity();
    for (unsigned i = 0; i < cs.size(); ++i) {
      double s = cs[i].fmap.dot(w);
      if (alpha)
        s += alpha * metric.ComputeScore(cs[i].eval_feats + pdoc);
      if (s > best) {
        best = s;
        best_i = i;
      }
    }
    return best_i;
  }

  // Called with the translation forest of the current sentence: decays the
  // pseudo-document, adds the forest's unique k-best entries to the pool,
  // selects vi (alpha = 0), hi (alpha = +1) and fi (alpha = -1), logs the
  // model-best translation with its score, and adds that translation's
  // statistics to the pseudo-document.
  virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) {
    pdoc *= pseudo_doc_decay;
    const unsigned sent_id = smeta.GetSentenceID();
    kbests[sent_id].AddUniqueKBestCandidates(*hg, kbest_size, cur_eval);
    vi = CostAugmentedDecode(kbests[sent_id], lambdas);
    hi = CostAugmentedDecode(kbests[sent_id], lambdas, 1.0);
    fi = CostAugmentedDecode(kbests[sent_id], lambdas, -1.0);
    cerr << sent_id << " ||| " << TD::GetString(kbests[sent_id][vi].ewords) << " ||| " << metric.ComputeScore(kbests[sent_id][vi].eval_feats + pdoc) << endl;
    pdoc += kbests[sent_id][vi].eval_feats;  // update pseudodoc stats
  }
};

// Entry point: loads the dev set, metric and decoder, runs the sampling /
// update loop, and writes the resulting dense weights via
// Weights::WriteToFile("-", ...). Returns 1 on bad arguments or an empty
// development set, 0 otherwise.
int main(int argc, char** argv) {
  SetSilent(true);  // turn off verbose decoder output
  register_feature_functions();

  po::variables_map conf;
  if (!InitCommandLine(argc, argv, &conf)) return 1;

  if (conf.count("random_seed"))
    rng.reset(new MT19937(conf["random_seed"].as<uint32_t>()));
  else
    rng.reset(new MT19937);

  string metric_name = UppercaseString(conf["mt_metric"].as<string>());
  if (metric_name == "COMBI") {
    cerr << "WARNING: 'combi' metric is no longer supported, switching to 'COMB:TER=-0.5;IBM_BLEU=0.5'\n";
    metric_name = "COMB:TER=-0.5;IBM_BLEU=0.5";
  } else if (metric_name == "BLEU") {
    cerr << "WARNING: 'BLEU' is ambiguous, assuming 'IBM_BLEU'\n";
    metric_name = "IBM_BLEU";
  }
  EvaluationMetric* metric = EvaluationMetric::Instance(metric_name);
  DocumentScorer ds(metric, conf["devset"].as<string>());
  cerr << "Loaded " << ds.size() << " references for scoring with " << metric_name << endl;
  if (ds.size() == 0) {
    // Sampling a sentence from an empty set would index ds[] out of range.
    cerr << "Development set is empty; nothing to learn from.\n";
    return 1;
  }
  kbests.resize(ds.size());

  // NOTE(review): the step size is hard-coded; the parsed "eta0" option (like
  // "optimizer", "fear", "hope", "mt_metric_scale" and "pseudo_doc") is never
  // read. Left as-is because wiring it in would change the default (0.1).
  const double eta = 0.001;

  ReadFile ini_rf(conf["decoder_config"].as<string>());
  Decoder decoder(ini_rf.stream());

  vector<weight_t>& dense_weights = decoder.CurrentWeightVector();
  if (conf.count("weights")) {
    Weights::InitFromFile(conf["weights"].as<string>(), &dense_weights);
    Weights::InitSparseVector(dense_weights, &lambdas);
  }

  TrainingObserver observer(*metric, conf["k_best_size"].as<int>());

  // iter starts at 1 because it is used as a divisor below.
  // NOTE(review): this performs num - 1 (= 199) updates, not num.
  const unsigned num = 200;
  for (unsigned iter = 1; iter < num; ++iter) {
    // Push the current sparse weights into the decoder's dense vector.
    lambdas.init_vector(&dense_weights);

    unsigned sent_id = rng->next() * ds.size();
    // Defensive clamp in case next() can return exactly 1.0.
    if (sent_id >= ds.size()) sent_id = ds.size() - 1;
    cerr << "Learning from sentence id: " << sent_id << endl;
    observer.SetSegmentEvaluator(ds[sent_id]);
    decoder.SetId(sent_id);
    decoder.Decode(ds[sent_id]->src, &observer);

    if (observer.vi != observer.hi) {  // viterbi != hope
      // Gradient = fear features - hope features.
      SparseVector<weight_t> grad = kbests[sent_id][observer.fi].fmap;
      grad -= kbests[sent_id][observer.hi].fmap;
      cerr << "GRAD: " << grad << endl;
      const SparseVector<weight_t>& g = grad;
#if HAVE_CXX11 && (__GNUC_MINOR__ > 4 || __GNUC__ > 4)
      for (auto& gi : g) {
#else
      for (SparseVector<weight_t>::const_iterator it = g.begin(); it != g.end(); ++it) {
        const pair<unsigned, weight_t>& gi = *it;
#endif
        if (gi.second) {
          u[gi.first] += gi.second;
          G[gi.first] += gi.second * gi.second;
          lambdas.set_value(gi.first, 1.0);  // this is a dummy value to trigger recomputation
        }
      }
      // Recompute every weight currently present in lambdas from u and G.
      for (SparseVector<weight_t>::iterator it = lambdas.begin(); it != lambdas.end(); ++it) {
        const pair<unsigned, weight_t>& xi = *it;
        // NOTE(review): the subtracted 0.0 looks like a disabled
        // regularisation strength -- confirm.
        double z = fabs(u[xi.first] / iter) - 0.0;
        double s = 1;
        if (u[xi.first] > 0) s = -1;  // step against the accumulated gradient
        if (z > 0 && G[xi.first]) {
          lambdas.set_value(xi.first, eta * s * z * iter / sqrt(G[xi.first]));
        } else {
          lambdas.set_value(xi.first, 0.0);
        }
      }
    }
  }
  // The loop copies lambdas into dense_weights only at the top of each
  // iteration, so without this the final update would not be written out.
  lambdas.init_vector(&dense_weights);
  cerr << "Optimization complete.\n";
  Weights::WriteToFile("-", dense_weights, true);
  return 0;
}