diff options
Diffstat (limited to 'rampion/rampion_cccp.cc')
-rw-r--r-- | rampion/rampion_cccp.cc | 157 |
1 files changed, 157 insertions, 0 deletions
diff --git a/rampion/rampion_cccp.cc b/rampion/rampion_cccp.cc new file mode 100644 index 00000000..6eb3ccf3 --- /dev/null +++ b/rampion/rampion_cccp.cc @@ -0,0 +1,157 @@ +#include <sstream> +#include <iostream> +#include <vector> +#include <limits> + +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "filelib.h" +#include "stringlib.h" +#include "weights.h" +#include "hg_io.h" +#include "kbest.h" +#include "viterbi.h" +#include "ns.h" +#include "ns_docscorer.h" + +using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("reference,r",po::value<vector<string> >(), "[REQD] Reference translation (tokenized text)") + ("weights,w",po::value<string>(), "[REQD] Weights files from current iterations") + ("input,i",po::value<string>()->default_value("-"), "Input file to map (- is STDIN)") + ("evaluation_metric,m",po::value<string>()->default_value("IBM_BLEU"), "Evaluation metric (ibm_bleu, koehn_bleu, nist_bleu, ter, meteor, etc.)") + ("kbest_size,k",po::value<unsigned>()->default_value(500u), "Top k-hypotheses to extract") + ("cccp_iterations,I", po::value<unsigned>()->default_value(10u), "CCCP iterations (T')") + ("ssd_iterations,J", po::value<unsigned>()->default_value(5u), "Stochastic subgradient iterations (T'')") + ("eta", po::value<double>()->default_value(1e-4), "Step size") + ("regularization_strength,C", po::value<double>()->default_value(1.0), "L2 regularization strength") + ("alpha,a", po::value<double>()->default_value(10.0), "Cost scale (alpha); alpha * [1-metric(y,y')]") + ("help,h", "Help"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + bool flag = false; + if (!conf->count("reference")) { + cerr << "Please specify one or more references using -r <REF.TXT>\n"; + flag = true; + } + if (!conf->count("weights")) { + cerr << "Please specify weights using -w <WEIGHTS.TXT>\n"; + flag = true; + } + if (flag || conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +struct HypInfo { + HypInfo() : g(-100.0f) {} + HypInfo(const vector<WordID>& h, + const SparseVector<weight_t>& feats, + const SegmentEvaluator& scorer, const EvaluationMetric* metric) : hyp(h), x(feats) { + SufficientStats ss; + scorer.Evaluate(hyp, &ss); + g = metric->ComputeScore(ss); + } + + vector<WordID> hyp; + float g; + SparseVector<weight_t> x; +}; + +void CostAugmentedSearch(const vector<HypInfo>& kbest, + const SparseVector<double>& w, + double alpha, + SparseVector<double>* fmap) { + unsigned best_i = 0; + double best = -numeric_limits<double>::infinity(); + for (unsigned i = 0; i < kbest.size(); ++i) { + double s = kbest[i].x.dot(w) + alpha * kbest[i].g; + if (s > best) { + best = s; + best_i = i; + } + } + *fmap = kbest[best_i].x; +} + +// runs lines 4--15 of rampion algorithm +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + const string evaluation_metric = conf["evaluation_metric"].as<string>(); + + EvaluationMetric* metric = EvaluationMetric::Instance(evaluation_metric); + DocumentScorer ds(metric, conf["reference"].as<vector<string> >()); + cerr << "Loaded " << ds.size() << " references for scoring with " << evaluation_metric << endl; + double goodsign = 1; + if (metric->IsErrorMetric()) goodsign = -goodsign; + double badsign = -goodsign; + + Hypergraph hg; + string last_file; + ReadFile in_read(conf["input"].as<string>()); + istream &in=*in_read.stream(); + const unsigned kbest_size = conf["kbest_size"].as<unsigned>(); + const unsigned tp = conf["cccp_iterations"].as<unsigned>(); + const unsigned tpp = conf["ssd_iterations"].as<unsigned>(); + const double eta = conf["eta"].as<double>(); + const double reg = conf["regularization_strength"].as<double>(); + const double alpha = conf["alpha"].as<double>(); + SparseVector<weight_t> weights; + { + vector<weight_t> vweights; + const string weightsf = conf["weights"].as<string>(); + Weights::InitFromFile(weightsf, &vweights); + Weights::InitSparseVector(vweights, &weights); + } + string line, file; + vector<vector<HypInfo> > kis; + cerr << "Loading hypergraphs...\n"; + while(getline(in, line)) { + istringstream is(line); + int sent_id; + kis.resize(kis.size() + 1); + vector<HypInfo>& curkbest = kis.back(); + is >> file >> sent_id; + ReadFile rf(file); + HypergraphIO::ReadFromJSON(rf.stream(), &hg); + hg.Reweight(weights); + KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg, kbest_size); + + for (int i = 0; i < kbest_size; ++i) { + const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + curkbest.push_back(HypInfo(d->yield, d->feature_values, *ds[sent_id], metric)); + } + } + + cerr << "Hypergraphs loaded.\n"; + vector<SparseVector<weight_t> > goals(kis.size()); // f(x_i,y+,h+) + SparseVector<weight_t> fear; // f(x,y-,h-) + for (unsigned iterp = 1; iterp <= tp; ++iterp) { + cerr << "CCCP Iteration " << iterp << endl; + for (int i = 0; i < goals.size(); ++i) + CostAugmentedSearch(kis[i], weights, goodsign * alpha, &goals[i]); + for (unsigned iterpp = 1; iterpp <= tpp; ++iterpp) { + cerr << " SSD Iteration " << iterpp << endl; + for (int i = 0; i < goals.size(); ++i) { + CostAugmentedSearch(kis[i], weights, badsign * alpha, &fear); + weights -= weights * (eta * reg / goals.size()); + weights += (goals[i] - fear) * eta; + } + } + } + vector<weight_t> w; + weights.init_vector(&w); + Weights::WriteToFile("-", w); + return 0; +} + |