diff options
-rw-r--r-- | training/em_utils.h | 24 | ||||
-rw-r--r-- | training/model1.cc | 52 | ||||
-rw-r--r-- | training/mr_em_adapted_reduce.cc | 23 | ||||
-rw-r--r-- | training/ttables.h | 2 |
4 files changed, 69 insertions, 32 deletions
diff --git a/training/em_utils.h b/training/em_utils.h new file mode 100644 index 00000000..37762978 --- /dev/null +++ b/training/em_utils.h @@ -0,0 +1,24 @@ +#ifndef _EM_UTILS_H_ +#define _EM_UTILS_H_ + +#include "config.h" +#ifdef HAVE_BOOST_DIGAMMA +#include <boost/math/special_functions/digamma.hpp> +using boost::math::digamma; +#else +#warning Using Mark Johnsons digamma() +#include <cmath> +inline double digamma(double x) { + double result = 0, xx, xx2, xx4; + assert(x > 0); + for ( ; x < 7; ++x) + result -= 1/x; + x -= 1.0/2.0; + xx = 1.0/x; + xx2 = xx*xx; + xx4 = xx2*xx2; + result += log(x)+(1./24.)*xx2-(7.0/960.0)*xx4+(31.0/8064.0)*xx4*xx2-(127.0/30720.0)*xx4*xx4; + return result; +} +#endif +#endif diff --git a/training/model1.cc b/training/model1.cc index 487ddb5f..83dacd63 100644 --- a/training/model1.cc +++ b/training/model1.cc @@ -1,29 +1,63 @@ #include <iostream> #include <cmath> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + #include "lattice.h" #include "stringlib.h" #include "filelib.h" #include "ttables.h" #include "tdict.h" +#include "em_utils.h" +namespace po = boost::program_options; using namespace std; -int main(int argc, char** argv) { - if (argc != 2) { - cerr << "Usage: " << argv[0] << " corpus.fr-en\n"; - return 1; +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("iterations,i",po::value<unsigned>()->default_value(5),"Number of iterations of EM training") + ("beam_threshold,t",po::value<double>()->default_value(-4),"log_10 of beam threshold (-10000 to include everything, 0 max)") + ("no_null_word,N","Do not generate from the null token"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value<string>(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as<string>().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); } - const int ITERATIONS = 5; - const double BEAM_THRESHOLD = 0.0001; - TTable tt; + po::notify(*conf); + + if (argc < 2 || conf->count("help")) { + cerr << "Usage " << argv[0] << " [OPTIONS] corpus.fr-en\n"; + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +int main(int argc, char** argv) { + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) return 1; + const string fname = argv[argc - 1]; + const int ITERATIONS = conf["iterations"].as<unsigned>(); + const double BEAM_THRESHOLD = pow(10.0, conf["beam_threshold"].as<double>()); + const bool use_null = (conf.count("no_null_word") == 0); const WordID kNULL = TD::Convert("<eps>"); - bool use_null = true; + + TTable tt; TTable::Word2Word2Double was_viterbi; for (int iter = 0; iter < ITERATIONS; ++iter) { const bool final_iteration = (iter == (ITERATIONS - 1)); cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl; - ReadFile rf(argv[1]); + ReadFile rf(fname); istream& in = *rf.stream(); double likelihood = 0; double denom = 0.0; diff --git a/training/mr_em_adapted_reduce.cc b/training/mr_em_adapted_reduce.cc index 52387e7f..29416348 100644 --- a/training/mr_em_adapted_reduce.cc +++ b/training/mr_em_adapted_reduce.cc @@ -6,36 +6,15 @@ #include <boost/program_options.hpp> #include <boost/program_options/variables_map.hpp> -#include "config.h" -#ifdef HAVE_BOOST_DIGAMMA -#include <boost/math/special_functions/digamma.hpp> -using boost::math::digamma; -#endif - #include "filelib.h" #include "fdict.h" #include "weights.h" #include "sparse_vector.h" +#include "em_utils.h" using namespace std; namespace po = boost::program_options; -#ifndef HAVE_BOOST_DIGAMMA -#warning Using Mark Johnsons digamma() -double digamma(double x) { - double result = 0, xx, xx2, xx4; - assert(x > 0); - for ( ; x < 7; ++x) - result -= 1/x; - x -= 1.0/2.0; - xx = 1.0/x; - xx2 = xx*xx; - xx4 = xx2*xx2; - result += log(x)+(1./24.)*xx2-(7.0/960.0)*xx4+(31.0/8064.0)*xx4*xx2-(127.0/30720.0)*xx4*xx4; - return result; -} -#endif - void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() diff --git a/training/ttables.h b/training/ttables.h index 04e54f9d..53f5f2ab 100644 --- a/training/ttables.h +++ b/training/ttables.h @@ -12,7 +12,7 @@ class TTable { TTable() {} typedef std::tr1::unordered_map<WordID, double> Word2Double; typedef std::tr1::unordered_map<WordID, Word2Double> Word2Word2Double; - inline const double prob(const int& e, const int& f) const { + inline double prob(const int& e, const int& f) const { const Word2Word2Double::const_iterator cit = ttable.find(e); if (cit != ttable.end()) { const Word2Double& cpd = cit->second; |