#include #include #include #include #include #include #include "filelib.h" #include "stringlib.h" #include "weights.h" #include "hg_io.h" #include "kbest.h" #include "viterbi.h" #include "ns.h" #include "ns_docscorer.h" using namespace std; namespace po = boost::program_options; void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() ("reference,r",po::value >(), "[REQD] Reference translation (tokenized text)") ("weights,w",po::value(), "[REQD] Weights files from current iterations") ("input,i",po::value()->default_value("-"), "Input file to map (- is STDIN)") ("evaluation_metric,m",po::value()->default_value("IBM_BLEU"), "Evaluation metric (ibm_bleu, koehn_bleu, nist_bleu, ter, meteor, etc.)") ("kbest_size,k",po::value()->default_value(500u), "Top k-hypotheses to extract") ("cccp_iterations,I", po::value()->default_value(10u), "CCCP iterations (T')") ("ssd_iterations,J", po::value()->default_value(5u), "Stochastic subgradient iterations (T'')") ("eta", po::value()->default_value(1e-4), "Step size") ("regularization_strength,C", po::value()->default_value(1.0), "L2 regularization strength") ("alpha,a", po::value()->default_value(10.0), "Cost scale (alpha); alpha * [1-metric(y,y')]") ("help,h", "Help"); po::options_description dcmdline_options; dcmdline_options.add(opts); po::store(parse_command_line(argc, argv, dcmdline_options), *conf); bool flag = false; if (!conf->count("reference")) { cerr << "Please specify one or more references using -r \n"; flag = true; } if (!conf->count("weights")) { cerr << "Please specify weights using -w \n"; flag = true; } if (flag || conf->count("help")) { cerr << dcmdline_options << endl; exit(1); } } struct HypInfo { HypInfo() : g(-100.0f) {} HypInfo(const vector& h, const SparseVector& feats, const SegmentEvaluator& scorer, const EvaluationMetric* metric) : hyp(h), x(feats) { SufficientStats ss; scorer.Evaluate(hyp, &ss); g = metric->ComputeScore(ss); if (metric->IsErrorMetric()) g = 1 - g; } vector hyp; float g; SparseVector x; }; void CostAugmentedSearch(const vector& kbest, const SparseVector& w, double alpha, SparseVector* fmap) { unsigned best_i = 0; double best = -numeric_limits::infinity(); for (unsigned i = 0; i < kbest.size(); ++i) { double s = kbest[i].x.dot(w) + alpha * kbest[i].g; if (s > best) { best = s; best_i = i; } } *fmap = kbest[best_i].x; } // runs lines 4--15 of rampion algorithm int main(int argc, char** argv) { po::variables_map conf; InitCommandLine(argc, argv, &conf); const string evaluation_metric = conf["evaluation_metric"].as(); EvaluationMetric* metric = EvaluationMetric::Instance(evaluation_metric); DocumentScorer ds(metric, conf["reference"].as >()); cerr << "Loaded " << ds.size() << " references for scoring with " << evaluation_metric << endl; double goodsign = 1; double badsign = -goodsign; Hypergraph hg; string last_file; ReadFile in_read(conf["input"].as()); istream &in=*in_read.stream(); const unsigned kbest_size = conf["kbest_size"].as(); const unsigned tp = conf["cccp_iterations"].as(); const unsigned tpp = conf["ssd_iterations"].as(); const double eta = conf["eta"].as(); const double reg = conf["regularization_strength"].as(); const double alpha = conf["alpha"].as(); SparseVector weights; { vector vweights; const string weightsf = conf["weights"].as(); Weights::InitFromFile(weightsf, &vweights); Weights::InitSparseVector(vweights, &weights); } string line, file; vector > kis; cerr << "Loading hypergraphs...\n"; while(getline(in, line)) { istringstream is(line); int sent_id; kis.resize(kis.size() + 1); vector& curkbest = kis.back(); is >> file >> sent_id; ReadFile rf(file); if (kis.size() % 5 == 0) { cerr << '.'; } if (kis.size() % 200 == 0) { cerr << " [" << kis.size() << "]\n"; } HypergraphIO::ReadFromJSON(rf.stream(), &hg); hg.Reweight(weights); KBest::KBestDerivations, ESentenceTraversal> kbest(hg, kbest_size); for (int i = 0; i < kbest_size; ++i) { const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = kbest.LazyKthBest(hg.nodes_.size() - 1, i); if (!d) break; curkbest.push_back(HypInfo(d->yield, d->feature_values, *ds[sent_id], metric)); } } cerr << "\nHypergraphs loaded.\n"; vector > goals(kis.size()); // f(x_i,y+,h+) SparseVector fear; // f(x,y-,h-) for (unsigned iterp = 1; iterp <= tp; ++iterp) { cerr << "CCCP Iteration " << iterp << endl; for (int i = 0; i < goals.size(); ++i) CostAugmentedSearch(kis[i], weights, goodsign * alpha, &goals[i]); for (unsigned iterpp = 1; iterpp <= tpp; ++iterpp) { cerr << " SSD Iteration " << iterpp << endl; for (int i = 0; i < goals.size(); ++i) { CostAugmentedSearch(kis[i], weights, badsign * alpha, &fear); weights -= weights * (eta * reg / goals.size()); weights += (goals[i] - fear) * eta; } } } vector w; weights.init_vector(&w); Weights::WriteToFile("-", w); return 0; }