summary | refs | log | tree | commit | diff
path: root/training/dtrain/dtrain.h
diff options
context:
space:
mode:
author: Patrick Simianer <p@simianer.de> 2015-02-26 13:26:37 +0100
committer: Patrick Simianer <p@simianer.de> 2015-02-26 13:26:37 +0100
commit: 4223261682388944fe1b1cf31b9d51d88f9ad53b (patch)
tree: daf072c310d60b0386587bde5e554312f193b3b2 /training/dtrain/dtrain.h
parent: 2a37a7ad1b21ab54701de3b5b44dc4ea55a75307 (diff)
refactoring
Diffstat (limited to 'training/dtrain/dtrain.h')
-rw-r--r-- training/dtrain/dtrain.h | 57
1 file changed, 42 insertions, 15 deletions
diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h
index 2b466930..728b0698 100644
--- a/training/dtrain/dtrain.h
+++ b/training/dtrain/dtrain.h
@@ -1,9 +1,6 @@
#ifndef _DTRAIN_H_
#define _DTRAIN_H_
-#define DTRAIN_DOTS 10 // after how many inputs to display a '.'
-#define DTRAIN_SCALE 100000
-
#include <iomanip>
#include <climits>
#include <string.h>
@@ -25,20 +22,17 @@ namespace po = boost::program_options;
namespace dtrain
{
-typedef double score_t;
-
struct ScoredHyp
{
- vector<WordID> w;
+ vector<WordID> w;
SparseVector<weight_t> f;
- score_t model, score;
- unsigned rank;
+ weight_t model, gold;
+ size_t rank;
};
inline void
RegisterAndConvert(const vector<string>& strs, vector<WordID>& ids)
{
- vector<string>::const_iterator it;
for (auto s: strs)
ids.push_back(TD::Convert(s));
}
@@ -46,7 +40,7 @@ RegisterAndConvert(const vector<string>& strs, vector<WordID>& ids)
inline void
PrintWordIDVec(vector<WordID>& v, ostream& os=cerr)
{
- for (unsigned i = 0; i < v.size(); i++) {
+ for (size_t i = 0; i < v.size(); i++) {
os << TD::Convert(v[i]);
if (i < v.size()-1) os << " ";
}
@@ -57,12 +51,45 @@ inline ostream& _p(ostream& out) { return out << setiosflags(ios::showpos); }
inline ostream& _p2(ostream& out) { return out << setprecision(2); }
inline ostream& _p5(ostream& out) { return out << setprecision(5); }
-template<typename T>
-inline T
-sign(T z)
+bool
+dtrain_init(int argc, char** argv, po::variables_map* conf)
{
- if (z == 0) return 0;
- return z < 0 ? -1 : +1;
+ po::options_description ini("Configuration File Options");
+ ini.add_options()
+ ("bitext,b", po::value<string>(), "bitext")
+ ("decoder_config,C", po::value<string>(), "configuration file for decoder")
+ ("iterations,T", po::value<size_t>()->default_value(10), "number of iterations T (per shard)")
+ ("k", po::value<size_t>()->default_value(100), "size of kbest list")
+ ("learning_rate,l", po::value<weight_t>()->default_value(1.0), "learning rate")
+ ("l1_reg,r", po::value<weight_t>()->default_value(0.), "l1 regularization strength")
+ ("error_margin,m", po::value<weight_t>()->default_value(0.), "margin for margin perceptron")
+ ("N", po::value<size_t>()->default_value(4), "N for BLEU approximation")
+ ("input_weights,w", po::value<string>(), "input weights file")
+ ("average,a", po::value<bool>()->default_value(false), "output average weights")
+ ("keep,K", po::value<bool>()->default_value(false), "output a weight file per iteration")
+ ("output,o", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT")
+ ("print_weights,P", po::value<string>()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV"),
+ "list of weights to print after each iteration");
+ po::options_description cl("Command Line Options");
+ cl.add_options()
+ ("config,c", po::value<string>(), "dtrain config file");
+ cl.add(ini);
+ po::store(parse_command_line(argc, argv, cl), *conf);
+ if (conf->count("config")) {
+ ifstream f((*conf)["config"].as<string>().c_str());
+ po::store(po::parse_config_file(f, ini), *conf);
+ }
+ po::notify(*conf);
+ if (!conf->count("decoder_config")) {
+ cerr << "Missing decoder configuration." << endl;
+ return false;
+ }
+ if (!conf->count("bitext")) {
+ cerr << "No training data given." << endl;
+ return false;
+ }
+
+ return true;
}
} // namespace