diff options
| -rw-r--r-- | training/em_utils.h | 24 | ||||
| -rw-r--r-- | training/model1.cc | 52 | ||||
| -rw-r--r-- | training/mr_em_adapted_reduce.cc | 23 | ||||
| -rw-r--r-- | training/ttables.h | 2 | 
4 files changed, 69 insertions, 32 deletions
diff --git a/training/em_utils.h b/training/em_utils.h new file mode 100644 index 00000000..37762978 --- /dev/null +++ b/training/em_utils.h @@ -0,0 +1,24 @@ +#ifndef _EM_UTILS_H_ +#define _EM_UTILS_H_ + +#include "config.h" +#ifdef HAVE_BOOST_DIGAMMA +#include <boost/math/special_functions/digamma.hpp> +using boost::math::digamma; +#else +#warning Using Mark Johnsons digamma() +#include <cmath> +inline double digamma(double x) { +  double result = 0, xx, xx2, xx4; +  assert(x > 0); +  for ( ; x < 7; ++x) +    result -= 1/x; +  x -= 1.0/2.0; +  xx = 1.0/x; +  xx2 = xx*xx; +  xx4 = xx2*xx2; +  result += log(x)+(1./24.)*xx2-(7.0/960.0)*xx4+(31.0/8064.0)*xx4*xx2-(127.0/30720.0)*xx4*xx4; +  return result; +} +#endif +#endif diff --git a/training/model1.cc b/training/model1.cc index 487ddb5f..83dacd63 100644 --- a/training/model1.cc +++ b/training/model1.cc @@ -1,29 +1,63 @@  #include <iostream>  #include <cmath> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> +  #include "lattice.h"  #include "stringlib.h"  #include "filelib.h"  #include "ttables.h"  #include "tdict.h" +#include "em_utils.h" +namespace po = boost::program_options;  using namespace std; -int main(int argc, char** argv) { -  if (argc != 2) { -    cerr << "Usage: " << argv[0] << " corpus.fr-en\n"; -    return 1; +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { +  po::options_description opts("Configuration options"); +  opts.add_options() +        ("iterations,i",po::value<unsigned>()->default_value(5),"Number of iterations of EM training") +        ("beam_threshold,t",po::value<double>()->default_value(-4),"log_10 of beam threshold (-10000 to include everything, 0 max)") +        ("no_null_word,N","Do not generate from the null token"); +  po::options_description clo("Command line options"); +  clo.add_options() +        ("config", po::value<string>(), "Configuration file") +        ("help,h", "Print this help message and exit"); +  po::options_description dconfig_options, dcmdline_options; +  dconfig_options.add(opts); +  dcmdline_options.add(opts).add(clo); +   +  po::store(parse_command_line(argc, argv, dcmdline_options), *conf); +  if (conf->count("config")) { +    ifstream config((*conf)["config"].as<string>().c_str()); +    po::store(po::parse_config_file(config, dconfig_options), *conf);    } -  const int ITERATIONS = 5; -  const double BEAM_THRESHOLD = 0.0001; -  TTable tt; +  po::notify(*conf); + +  if (argc < 2 || conf->count("help")) { +    cerr << "Usage " << argv[0] << " [OPTIONS] corpus.fr-en\n"; +    cerr << dcmdline_options << endl; +    return false; +  } +  return true; +} + +int main(int argc, char** argv) { +  po::variables_map conf; +  if (!InitCommandLine(argc, argv, &conf)) return 1; +  const string fname = argv[argc - 1]; +  const int ITERATIONS = conf["iterations"].as<unsigned>(); +  const double BEAM_THRESHOLD = pow(10.0, conf["beam_threshold"].as<double>()); +  const bool use_null = (conf.count("no_null_word") == 0);    const WordID kNULL = TD::Convert("<eps>"); -  bool use_null = true; + +  TTable tt;    TTable::Word2Word2Double was_viterbi;    for (int iter = 0; iter < ITERATIONS; ++iter) {      const bool final_iteration = (iter == (ITERATIONS - 1));      cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl; -    ReadFile rf(argv[1]); +    ReadFile rf(fname);      istream& in = *rf.stream();      double likelihood = 0;      double denom = 0.0; diff --git a/training/mr_em_adapted_reduce.cc b/training/mr_em_adapted_reduce.cc index 52387e7f..29416348 100644 --- a/training/mr_em_adapted_reduce.cc +++ b/training/mr_em_adapted_reduce.cc @@ -6,36 +6,15 @@  #include <boost/program_options.hpp>  #include <boost/program_options/variables_map.hpp> -#include "config.h" -#ifdef HAVE_BOOST_DIGAMMA -#include <boost/math/special_functions/digamma.hpp> -using boost::math::digamma; -#endif -  #include "filelib.h"  #include "fdict.h"  #include "weights.h"  #include "sparse_vector.h" +#include "em_utils.h"  using namespace std;  namespace po = boost::program_options; -#ifndef HAVE_BOOST_DIGAMMA -#warning Using Mark Johnsons digamma() -double digamma(double x) { -  double result = 0, xx, xx2, xx4; -  assert(x > 0); -  for ( ; x < 7; ++x) -    result -= 1/x; -  x -= 1.0/2.0; -  xx = 1.0/x; -  xx2 = xx*xx; -  xx4 = xx2*xx2; -  result += log(x)+(1./24.)*xx2-(7.0/960.0)*xx4+(31.0/8064.0)*xx4*xx2-(127.0/30720.0)*xx4*xx4; -  return result; -} -#endif -  void InitCommandLine(int argc, char** argv, po::variables_map* conf) {    po::options_description opts("Configuration options");    opts.add_options() diff --git a/training/ttables.h b/training/ttables.h index 04e54f9d..53f5f2ab 100644 --- a/training/ttables.h +++ b/training/ttables.h @@ -12,7 +12,7 @@ class TTable {    TTable() {}    typedef std::tr1::unordered_map<WordID, double> Word2Double;    typedef std::tr1::unordered_map<WordID, Word2Double> Word2Word2Double; -  inline const double prob(const int& e, const int& f) const { +  inline double prob(const int& e, const int& f) const {      const Word2Word2Double::const_iterator cit = ttable.find(e);      if (cit != ttable.end()) {        const Word2Double& cpd = cit->second;  | 
