From 0172721855098ca02b207231a654dffa5e4eb1c9 Mon Sep 17 00:00:00 2001 From: redpony Date: Tue, 22 Jun 2010 05:12:27 +0000 Subject: initial checkin git-svn-id: https://ws10smt.googlecode.com/svn/trunk@2 ec762483-ff6d-05da-a07a-a48fb63a330f --- training/plftools.cc | 93 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 training/plftools.cc (limited to 'training/plftools.cc') diff --git a/training/plftools.cc b/training/plftools.cc new file mode 100644 index 00000000..903ec54f --- /dev/null +++ b/training/plftools.cc @@ -0,0 +1,93 @@ +#include +#include +#include + +#include +#include + +#include "filelib.h" +#include "tdict.h" +#include "prob.h" +#include "hg.h" +#include "hg_io.h" +#include "viterbi.h" +#include "kbest.h" + +namespace po = boost::program_options; +using namespace std; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("input,i", po::value(), "REQ. Lattice input file (PLF), - for STDIN") + ("prior_scale,p", po::value()->default_value(1.0), "Scale path probabilities by this amount < 1 flattens, > 1 sharpens") + ("weight,w", po::value >(), "Weight(s) for arc features") + ("output,o", po::value()->default_value("plf"), "Output format (text, plf)") + ("command,c", po::value()->default_value("push"), "Operation to perform: push, graphviz, 1best, 2best ...") + ("help,h", "Print this help message and exit"); + po::options_description clo("Command line options"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + po::notify(*conf); + + if (conf->count("help") || conf->count("input") == 0) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +int main(int argc, char **argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + string infile = conf["input"].as(); + ReadFile rf(infile); + istream* in = rf.stream(); + assert(*in); + SparseVector wts; + vector wv; + if (conf.count("weight") > 0) wv = conf["weight"].as >(); + if (wv.empty()) wv.push_back(1.0); + for (int i = 0; i < wv.size(); ++i) { + const string fname = "Feature_" + boost::lexical_cast(i); + cerr << "[INFO] Arc weight " << (i+1) << " = " << wv[i] << endl; + wts.set_value(FD::Convert(fname), wv[i]); + } + const string cmd = conf["command"].as(); + const bool push_weights = cmd == "push"; + const bool output_plf = cmd == "plf"; + const bool graphviz = cmd == "graphviz"; + const bool kbest = cmd.rfind("best") == (cmd.size() - 4) && cmd.size() > 4; + int k = 1; + if (kbest) { + k = boost::lexical_cast(cmd.substr(0, cmd.size() - 4)); + cerr << "KBEST = " << k << endl; + } + const double scale = conf["prior_scale"].as(); + int lc = 0; + while(*in) { + ++lc; + string plf; + getline(*in, plf); + if (plf.empty()) continue; + Hypergraph hg; + HypergraphIO::ReadFromPLF(plf, &hg); + hg.Reweight(wts); + if (graphviz) hg.PrintGraphviz(); + if (push_weights) hg.PushWeightsToSource(scale); + if (output_plf) { + cout << HypergraphIO::AsPLF(hg) << endl; + } else { + KBest::KBestDerivations, ESentenceTraversal> kbest(hg, k); + for (int i = 0; i < k; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + cout << lc << " ||| " << TD::GetString(d->yield) << " ||| " << d->score << endl; + } + } + } + return 0; +} + -- cgit v1.2.3