summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-11-15 17:07:04 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-11-15 17:07:04 +0000
commit19840a3f47d64c38753c1fac46cb4f39212fc99f (patch)
treeabc060aa8008b43cc15d331f80fa9ecfa48d778c /training
parent49fb41843a2ad81e3ef2b65e5b9264809aac1847 (diff)
model 1 options
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@723 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'training')
-rw-r--r--training/em_utils.h24
-rw-r--r--training/model1.cc52
-rw-r--r--training/mr_em_adapted_reduce.cc23
-rw-r--r--training/ttables.h2
4 files changed, 69 insertions, 32 deletions
diff --git a/training/em_utils.h b/training/em_utils.h
new file mode 100644
index 00000000..37762978
--- /dev/null
+++ b/training/em_utils.h
@@ -0,0 +1,24 @@
+#ifndef _EM_UTILS_H_
+#define _EM_UTILS_H_
+
+#include "config.h"
+#ifdef HAVE_BOOST_DIGAMMA
+#include <boost/math/special_functions/digamma.hpp>
+using boost::math::digamma;
+#else
+#warning Using Mark Johnsons digamma()
+#include <cmath>
+inline double digamma(double x) {
+ double result = 0, xx, xx2, xx4;
+ assert(x > 0);
+ for ( ; x < 7; ++x)
+ result -= 1/x;
+ x -= 1.0/2.0;
+ xx = 1.0/x;
+ xx2 = xx*xx;
+ xx4 = xx2*xx2;
+ result += log(x)+(1./24.)*xx2-(7.0/960.0)*xx4+(31.0/8064.0)*xx4*xx2-(127.0/30720.0)*xx4*xx4;
+ return result;
+}
+#endif
+#endif
diff --git a/training/model1.cc b/training/model1.cc
index 487ddb5f..83dacd63 100644
--- a/training/model1.cc
+++ b/training/model1.cc
@@ -1,29 +1,63 @@
#include <iostream>
#include <cmath>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
#include "lattice.h"
#include "stringlib.h"
#include "filelib.h"
#include "ttables.h"
#include "tdict.h"
+#include "em_utils.h"
+namespace po = boost::program_options;
using namespace std;
-int main(int argc, char** argv) {
- if (argc != 2) {
- cerr << "Usage: " << argv[0] << " corpus.fr-en\n";
- return 1;
+bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("iterations,i",po::value<unsigned>()->default_value(5),"Number of iterations of EM training")
+ ("beam_threshold,t",po::value<double>()->default_value(-4),"log_10 of beam threshold (-10000 to include everything, 0 max)")
+ ("no_null_word,N","Do not generate from the null token");
+ po::options_description clo("Command line options");
+ clo.add_options()
+ ("config", po::value<string>(), "Configuration file")
+ ("help,h", "Print this help message and exit");
+ po::options_description dconfig_options, dcmdline_options;
+ dconfig_options.add(opts);
+ dcmdline_options.add(opts).add(clo);
+
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ if (conf->count("config")) {
+ ifstream config((*conf)["config"].as<string>().c_str());
+ po::store(po::parse_config_file(config, dconfig_options), *conf);
}
- const int ITERATIONS = 5;
- const double BEAM_THRESHOLD = 0.0001;
- TTable tt;
+ po::notify(*conf);
+
+ if (argc < 2 || conf->count("help")) {
+ cerr << "Usage " << argv[0] << " [OPTIONS] corpus.fr-en\n";
+ cerr << dcmdline_options << endl;
+ return false;
+ }
+ return true;
+}
+
+int main(int argc, char** argv) {
+ po::variables_map conf;
+ if (!InitCommandLine(argc, argv, &conf)) return 1;
+ const string fname = argv[argc - 1];
+ const int ITERATIONS = conf["iterations"].as<unsigned>();
+ const double BEAM_THRESHOLD = pow(10.0, conf["beam_threshold"].as<double>());
+ const bool use_null = (conf.count("no_null_word") == 0);
const WordID kNULL = TD::Convert("<eps>");
- bool use_null = true;
+
+ TTable tt;
TTable::Word2Word2Double was_viterbi;
for (int iter = 0; iter < ITERATIONS; ++iter) {
const bool final_iteration = (iter == (ITERATIONS - 1));
cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl;
- ReadFile rf(argv[1]);
+ ReadFile rf(fname);
istream& in = *rf.stream();
double likelihood = 0;
double denom = 0.0;
diff --git a/training/mr_em_adapted_reduce.cc b/training/mr_em_adapted_reduce.cc
index 52387e7f..29416348 100644
--- a/training/mr_em_adapted_reduce.cc
+++ b/training/mr_em_adapted_reduce.cc
@@ -6,36 +6,15 @@
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
-#include "config.h"
-#ifdef HAVE_BOOST_DIGAMMA
-#include <boost/math/special_functions/digamma.hpp>
-using boost::math::digamma;
-#endif
-
#include "filelib.h"
#include "fdict.h"
#include "weights.h"
#include "sparse_vector.h"
+#include "em_utils.h"
using namespace std;
namespace po = boost::program_options;
-#ifndef HAVE_BOOST_DIGAMMA
-#warning Using Mark Johnsons digamma()
-double digamma(double x) {
- double result = 0, xx, xx2, xx4;
- assert(x > 0);
- for ( ; x < 7; ++x)
- result -= 1/x;
- x -= 1.0/2.0;
- xx = 1.0/x;
- xx2 = xx*xx;
- xx4 = xx2*xx2;
- result += log(x)+(1./24.)*xx2-(7.0/960.0)*xx4+(31.0/8064.0)*xx4*xx2-(127.0/30720.0)*xx4*xx4;
- return result;
-}
-#endif
-
void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
po::options_description opts("Configuration options");
opts.add_options()
diff --git a/training/ttables.h b/training/ttables.h
index 04e54f9d..53f5f2ab 100644
--- a/training/ttables.h
+++ b/training/ttables.h
@@ -12,7 +12,7 @@ class TTable {
TTable() {}
typedef std::tr1::unordered_map<WordID, double> Word2Double;
typedef std::tr1::unordered_map<WordID, Word2Double> Word2Word2Double;
- inline const double prob(const int& e, const int& f) const {
+ inline double prob(const int& e, const int& f) const {
const Word2Word2Double::const_iterator cit = ttable.find(e);
if (cit != ttable.end()) {
const Word2Double& cpd = cit->second;