From 671c21451542e2dd20e45b4033d44d8e8735f87b Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 3 Dec 2009 16:33:55 -0500 Subject: initial check in --- training/atools.cc | 207 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 training/atools.cc (limited to 'training/atools.cc') diff --git a/training/atools.cc b/training/atools.cc new file mode 100644 index 00000000..bac73859 --- /dev/null +++ b/training/atools.cc @@ -0,0 +1,207 @@ +#include +#include +#include + +#include +#include +#include + +#include "filelib.h" +#include "aligner.h" + +namespace po = boost::program_options; +using namespace std; +using boost::shared_ptr; + +struct Command { + virtual ~Command() {} + virtual string Name() const = 0; + + // returns 1 for alignment grid output [default] + // returns 2 if Summary() should be called [for AER, etc] + virtual int Result() const { return 1; } + + virtual bool RequiresTwoOperands() const { return true; } + virtual void Apply(const Array2D& a, const Array2D& b, Array2D* x) = 0; + void EnsureSize(const Array2D& a, const Array2D& b, Array2D* x) { + x->resize(max(a.width(), b.width()), max(a.height(), b.width())); + } + bool Safe(const Array2D& a, int i, int j) const { + if (i < a.width() && j < a.height()) + return a(i,j); + else + return false; + } + virtual void Summary() { assert(!"Summary should have been overridden"); } +}; + +// compute fmeasure, second alignment is reference, first is hyp +struct FMeasureCommand : public Command { + FMeasureCommand() : matches(), num_predicted(), num_in_ref() {} + int Result() const { return 2; } + string Name() const { return "f"; } + bool RequiresTwoOperands() const { return true; } + void Apply(const Array2D& hyp, const Array2D& ref, Array2D* x) { + int i_len = ref.width(); + int j_len = ref.height(); + for (int i = 0; i < i_len; ++i) { + for (int j = 0; j < j_len; ++j) { + if (ref(i,j)) { + ++num_in_ref; + if (Safe(hyp, i, j)) ++matches; + } + } + } + for (int i = 0; i < hyp.width(); ++i) + for (int j = 0; j < hyp.height(); ++j) + if (hyp(i,j)) ++num_predicted; + } + void Summary() { + if (num_predicted == 0 || num_in_ref == 0) { + cerr << "Insufficient statistics to compute f-measure!\n"; + abort(); + } + const double prec = static_cast(matches) / num_predicted; + const double rec = static_cast(matches) / num_in_ref; + cout << "P: " << prec << endl; + cout << "R: " << rec << endl; + const double f = (2.0 * prec * rec) / (rec + prec); + cout << "F: " << f << endl; + } + int matches; + int num_predicted; + int num_in_ref; +}; + +struct ConvertCommand : public Command { + string Name() const { return "convert"; } + bool RequiresTwoOperands() const { return false; } + void Apply(const Array2D& in, const Array2D¬_used, Array2D* x) { + *x = in; + } +}; + +struct InvertCommand : public Command { + string Name() const { return "invert"; } + bool RequiresTwoOperands() const { return false; } + void Apply(const Array2D& in, const Array2D¬_used, Array2D* x) { + Array2D& res = *x; + res.resize(in.height(), in.width()); + for (int i = 0; i < in.height(); ++i) + for (int j = 0; j < in.width(); ++j) + res(i, j) = in(j, i); + } +}; + +struct IntersectCommand : public Command { + string Name() const { return "intersect"; } + bool RequiresTwoOperands() const { return true; } + void Apply(const Array2D& a, const Array2D& b, Array2D* x) { + EnsureSize(a, b, x); + Array2D& res = *x; + for (int i = 0; i < a.width(); ++i) + for (int j = 0; j < a.height(); ++j) + res(i, j) = Safe(a, i, j) && Safe(b, i, j); + } +}; + +map > commands; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + ostringstream os; + os << "[REQ] Operation to perform:"; + for (map >::iterator it = commands.begin(); + it != commands.end(); ++it) { + os << ' ' << it->first; + } + string cstr = os.str(); + opts.add_options() + ("input_1,i", po::value(), "[REQ] Alignment 1 file, - for STDIN") + ("input_2,j", po::value(), "[OPT] Alignment 2 file, - for STDIN") + ("command,c", po::value()->default_value("convert"), cstr.c_str()) + ("help,h", "Print this help message and exit"); + po::options_description clo("Command line options"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + po::notify(*conf); + + if (conf->count("help") || conf->count("input_1") == 0 || conf->count("command") == 0) { + cerr << dcmdline_options << endl; + exit(1); + } + const string cmd = (*conf)["command"].as(); + if (commands.count(cmd) == 0) { + cerr << "Don't understand command: " << cmd << endl; + exit(1); + } + if (commands[cmd]->RequiresTwoOperands()) { + if (conf->count("input_2") == 0) { + cerr << "Command '" << cmd << "' requires two alignment files\n"; + exit(1); + } + if ((*conf)["input_1"].as() == "-" && (*conf)["input_2"].as() == "-") { + cerr << "Both inputs cannot be STDIN\n"; + exit(1); + } + } else { + if (conf->count("input_2") != 0) { + cerr << "Command '" << cmd << "' requires only one alignment file\n"; + exit(1); + } + } +} + +template static void AddCommand() { + C* c = new C; + commands[c->Name()].reset(c); +} + +int main(int argc, char **argv) { + AddCommand(); + AddCommand(); + AddCommand(); + AddCommand(); + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + Command& cmd = *commands[conf["command"].as()]; + boost::shared_ptr rf1(new ReadFile(conf["input_1"].as())); + boost::shared_ptr rf2; + if (cmd.RequiresTwoOperands()) + rf2.reset(new ReadFile(conf["input_2"].as())); + istream* in1 = rf1->stream(); + istream* in2 = NULL; + if (rf2) in2 = rf2->stream(); + while(*in1) { + string line1; + string line2; + getline(*in1, line1); + if (in2) { + getline(*in2, line2); + if ((*in1 && !*in2) || (*in2 && !*in1)) { + cerr << "Mismatched number of lines!\n"; + exit(1); + } + } + if (line1.empty() && !*in1) break; + shared_ptr > out(new Array2D); + shared_ptr > a1 = AlignerTools::ReadPharaohAlignmentGrid(line1); + if (in2) { + shared_ptr > a2 = AlignerTools::ReadPharaohAlignmentGrid(line2); + cmd.Apply(*a1, *a2, out.get()); + } else { + Array2D dummy; + cmd.Apply(*a1, dummy, out.get()); + } + + if (cmd.Result() == 1) { + AlignerTools::SerializePharaohFormat(*out, &cout); + } + } + if (cmd.Result() == 2) + cmd.Summary(); + return 0; +} + -- cgit v1.2.3