From 19840a3f47d64c38753c1fac46cb4f39212fc99f Mon Sep 17 00:00:00 2001 From: redpony Date: Mon, 15 Nov 2010 17:07:04 +0000 Subject: model 1 options git-svn-id: https://ws10smt.googlecode.com/svn/trunk@723 ec762483-ff6d-05da-a07a-a48fb63a330f --- training/em_utils.h | 24 +++++++++++++++++++ training/model1.cc | 52 +++++++++++++++++++++++++++++++++------- training/mr_em_adapted_reduce.cc | 23 +----------------- training/ttables.h | 2 +- 4 files changed, 69 insertions(+), 32 deletions(-) create mode 100644 training/em_utils.h (limited to 'training') diff --git a/training/em_utils.h b/training/em_utils.h new file mode 100644 index 00000000..37762978 --- /dev/null +++ b/training/em_utils.h @@ -0,0 +1,24 @@ +#ifndef _EM_UTILS_H_ +#define _EM_UTILS_H_ + +#include "config.h" +#ifdef HAVE_BOOST_DIGAMMA +#include +using boost::math::digamma; +#else +#warning Using Mark Johnsons digamma() +#include +inline double digamma(double x) { + double result = 0, xx, xx2, xx4; + assert(x > 0); + for ( ; x < 7; ++x) + result -= 1/x; + x -= 1.0/2.0; + xx = 1.0/x; + xx2 = xx*xx; + xx4 = xx2*xx2; + result += log(x)+(1./24.)*xx2-(7.0/960.0)*xx4+(31.0/8064.0)*xx4*xx2-(127.0/30720.0)*xx4*xx4; + return result; +} +#endif +#endif diff --git a/training/model1.cc b/training/model1.cc index 487ddb5f..83dacd63 100644 --- a/training/model1.cc +++ b/training/model1.cc @@ -1,29 +1,63 @@ #include #include +#include +#include + #include "lattice.h" #include "stringlib.h" #include "filelib.h" #include "ttables.h" #include "tdict.h" +#include "em_utils.h" +namespace po = boost::program_options; using namespace std; -int main(int argc, char** argv) { - if (argc != 2) { - cerr << "Usage: " << argv[0] << " corpus.fr-en\n"; - return 1; +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("iterations,i",po::value()->default_value(5),"Number of iterations of EM training") + ("beam_threshold,t",po::value()->default_value(-4),"log_10 of beam threshold (-10000 to include everything, 0 max)") + ("no_null_word,N","Do not generate from the null token"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); } - const int ITERATIONS = 5; - const double BEAM_THRESHOLD = 0.0001; - TTable tt; + po::notify(*conf); + + if (argc < 2 || conf->count("help")) { + cerr << "Usage " << argv[0] << " [OPTIONS] corpus.fr-en\n"; + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +int main(int argc, char** argv) { + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) return 1; + const string fname = argv[argc - 1]; + const int ITERATIONS = conf["iterations"].as(); + const double BEAM_THRESHOLD = pow(10.0, conf["beam_threshold"].as()); + const bool use_null = (conf.count("no_null_word") == 0); const WordID kNULL = TD::Convert(""); - bool use_null = true; + + TTable tt; TTable::Word2Word2Double was_viterbi; for (int iter = 0; iter < ITERATIONS; ++iter) { const bool final_iteration = (iter == (ITERATIONS - 1)); cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl; - ReadFile rf(argv[1]); + ReadFile rf(fname); istream& in = *rf.stream(); double likelihood = 0; double denom = 0.0; diff --git a/training/mr_em_adapted_reduce.cc b/training/mr_em_adapted_reduce.cc index 52387e7f..29416348 100644 --- a/training/mr_em_adapted_reduce.cc +++ b/training/mr_em_adapted_reduce.cc @@ -6,36 +6,15 @@ #include #include -#include "config.h" -#ifdef HAVE_BOOST_DIGAMMA -#include -using boost::math::digamma; -#endif - #include "filelib.h" #include "fdict.h" #include "weights.h" #include "sparse_vector.h" +#include "em_utils.h" using namespace std; namespace po = boost::program_options; -#ifndef HAVE_BOOST_DIGAMMA -#warning Using Mark Johnsons digamma() -double digamma(double x) { - double result = 0, xx, xx2, xx4; - assert(x > 0); - for ( ; x < 7; ++x) - result -= 1/x; - x -= 1.0/2.0; - xx = 1.0/x; - xx2 = xx*xx; - xx4 = xx2*xx2; - result += log(x)+(1./24.)*xx2-(7.0/960.0)*xx4+(31.0/8064.0)*xx4*xx2-(127.0/30720.0)*xx4*xx4; - return result; -} -#endif - void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() diff --git a/training/ttables.h b/training/ttables.h index 04e54f9d..53f5f2ab 100644 --- a/training/ttables.h +++ b/training/ttables.h @@ -12,7 +12,7 @@ class TTable { TTable() {} typedef std::tr1::unordered_map Word2Double; typedef std::tr1::unordered_map Word2Word2Double; - inline const double prob(const int& e, const int& f) const { + inline double prob(const int& e, const int& f) const { const Word2Word2Double::const_iterator cit = ttable.find(e); if (cit != ttable.end()) { const Word2Double& cpd = cit->second; -- cgit v1.2.3