Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r--    dtrain/dtrain.cc    86
1 file changed, 57 insertions(+), 29 deletions(-)
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 373458e8..16b83a70 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -1,6 +1,11 @@
-#include "dcommon.h"
+#include "common.h"
+#include "kbestget.h"
#include "learner.h"
+#include "util.h"
+#ifdef DTRAIN_DEBUG
+#include "tests.h"
+#endif
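
Note: the constants used below (DTRAIN_DEFAULT_K, DTRAIN_DEFAULT_N, DTRAIN_DEFAULT_T, DTRAIN_DOTOUT) are not part of this diff and presumably come from the new common.h. A minimal sketch of what that header might define; the concrete values are placeholders, not taken from the repository:

    // common.h (hypothetical excerpt; values are assumptions)
    #define DTRAIN_DEFAULT_K 100  // size of the k-best list (assumed)
    #define DTRAIN_DEFAULT_N 4    // n-gram order for BLEU (assumed)
    #define DTRAIN_DEFAULT_T 1    // number of training epochs (assumed)
    #define DTRAIN_DOTOUT    100  // lines of input per progress dot (assumed)
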
@@ -12,20 +17,33 @@ bool
init(int argc, char** argv, po::variables_map* conf)
{
po::options_description opts( "Options" );
+ size_t k, N, T;
+ // TODO scoring metric as parameter/in config
opts.add_options()
- ( "decoder-config,c", po::value<string>(), "configuration file for cdec" )
- ( "kbest,k", po::value<size_t>(), "k for kbest" )
- ( "ngrams,n", po::value<int>(), "n for Ngrams" )
- ( "filter,f", po::value<string>(), "filter kbest list" )
- ( "test", "run tests and exit");
+ ( "decoder-config,c", po::value<string>(), "configuration file for cdec" )
+ ( "kbest,k", po::value<size_t>(&k)->default_value(DTRAIN_DEFAULT_K), "k for kbest" )
+ ( "ngrams,n", po::value<size_t>(&N)->default_value(DTRAIN_DEFAULT_N), "n for Ngrams" )
+ ( "filter,f", po::value<string>(), "filter kbest list" ) // FIXME
+ ( "epochs,t", po::value<size_t>(&T)->default_value(DTRAIN_DEFAULT_T), "# of iterations T" )
+#ifndef DTRAIN_DEBUG
+ ;
+#else
+ ( "test", "run tests and exit");
+#endif
po::options_description cmdline_options;
cmdline_options.add(opts);
po::store( parse_command_line(argc, argv, cmdline_options), *conf );
po::notify( *conf );
- if ( ! (conf->count("decoder-config") || conf->count("test")) ) {
+ if ( ! conf->count("decoder-config") ) {
cerr << cmdline_options << endl;
return false;
}
+ #ifdef DTRAIN_DEBUG
+ if ( ! conf->count("test") ) {
+ cerr << cmdline_options << endl;
+ return false;
+ }
+ #endif
return true;
}
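
Note: binding the options via po::value<size_t>(&k)->default_value(...) means po::notify() fills the target variables whether or not the flag is given on the command line, so the later conf["kbest"].as<size_t>() lookups cannot fail on a missing key. A self-contained sketch of that mechanism (not dtrain code):

    #include <boost/program_options.hpp>
    #include <iostream>
    namespace po = boost::program_options;

    int main(int argc, char** argv)
    {
      size_t k;
      po::options_description opts("Options");
      opts.add_options()
        ("kbest,k", po::value<size_t>(&k)->default_value(100), "k for kbest");
      po::variables_map conf;
      po::store(po::parse_command_line(argc, argv, opts), conf);
      po::notify(conf);                      // writes the parsed or default value into k
      std::cout << "k = " << k << std::endl; // prints 100 unless -k/--kbest is given
      return 0;
    }
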
@@ -40,19 +58,21 @@ main(int argc, char** argv)
SetSilent(true);
po::variables_map conf;
if (!init(argc, argv, &conf)) return 1;
+#ifdef DTRAIN_DEBUG
if ( conf.count("test") ) run_tests();
+#endif
register_feature_functions();
size_t k = conf["kbest"].as<size_t>();
- ReadFile ini_rf(conf["decoder-config"].as<string>());
+ ReadFile ini_rf( conf["decoder-config"].as<string>() );
Decoder decoder(ini_rf.stream());
- KBestGetter observer(k);
- size_t N = 3; // TODO as parameter/in config
+ KBestGetter observer( k );
+ size_t N = conf["ngrams"].as<size_t>();
+ size_t T = conf["epochs"].as<size_t>();
- // TODO scoring metric as parameter/in config
// for approx. bleu
- NgramCounts global_counts(N);
- size_t global_hyp_len = 0;
- size_t global_ref_len = 0;
+ //NgramCounts global_counts( N );
+ //size_t global_hyp_len = 0;
+ //size_t global_ref_len = 0;
Weights weights;
SparseVector<double> lambdas;
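
Note: the KBestGetter observer fills a k-best container that the loop below reads as kb->sents, kb->scores and kb->feats. Only those member names are visible in this diff; a hedged reconstruction with assumed types:

    // hypothetical shape of the k-best list consumed below;
    // member names from this diff, types are assumptions
    struct KBestList {
      std::vector<std::vector<WordID> >  sents;  // k tokenized hypotheses
      std::vector<double>                scores; // model score per hypothesis
      std::vector<SparseVector<double> > feats;  // feature vector per hypothesis
    };
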
@@ -62,20 +82,24 @@ main(int argc, char** argv)
vector<string> strs, ref_strs;
vector<WordID> ref_ids;
string in, psg;
- size_t sid = 0;
- cerr << "(1 dot equals 100 lines of input)" << endl;
+ size_t sn = 0;
+ cerr << "(A dot equals " << DTRAIN_DOTOUT << " lines of input.)" << endl;
+
+ for ( size_t t = 0; t < T; t++ )
+ {
+
while( getline(cin, in) ) {
- if ( (sid+1) % 100 == 0 ) {
+ if ( (sn+1) % DTRAIN_DOTOUT == 0 ) {
cerr << ".";
- if ( (sid+1)%1000 == 0 ) cerr << endl;
+ if ( (sn+1) % (20*DTRAIN_DOTOUT) == 0 ) cerr << endl;
}
- //if ( sid > 5000 ) break;
+ //if ( sn > 5000 ) break;
// weights
dense_weights.clear();
weights.InitFromVector( lambdas );
weights.InitVector( &dense_weights );
decoder.SetWeights( dense_weights );
- // handling input..
+ // handling input
strs.clear();
boost::split( strs, in, boost::is_any_of("\t") );
// grammar
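
Note: each input line carries several tab-separated fields that boost::split pulls apart; the exact field layout (source, per-sentence grammar, ...) is not visible in this hunk, so the data in this runnable sketch is only illustrative:

    #include <boost/algorithm/string.hpp>
    #include <iostream>
    #include <string>
    #include <vector>
    using namespace std;

    int main()
    {
      string in = "das ist ein test\t<per-sentence grammar>"; // assumed layout
      vector<string> strs;
      boost::split(strs, in, boost::is_any_of("\t"));
      cout << strs[0] << endl; // "das ist ein test"
      cout << strs[1] << endl; // "<per-sentence grammar>"
      return 0;
    }
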
@@ -89,11 +113,11 @@ main(int argc, char** argv)
register_and_convert( ref_strs, ref_ids );
// scoring kbest
double score = 0;
- size_t cand_len = 0;
+ //size_t cand_len = 0;
Scores scores;
for ( size_t i = 0; i < kb->sents.size(); i++ ) {
NgramCounts counts = make_ngram_counts( ref_ids, kb->sents[i], N );
- if ( i == 0) {
+ /*if ( i == 0 ) {
global_counts += counts;
global_hyp_len += kb->sents[i].size();
global_ref_len += ref_ids.size();
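
Note: the next hunk completes the move from corpus-level approx. BLEU (now commented out) to plain per-sentence BLEU, score = bleu(counts, ref_ids.size(), kb->sents[i].size(), N). A sketch of that computation under the usual definition BLEU = BP * exp(sum_n (1/N) log p_n); reducing NgramCounts to clipped/total counts per order is an assumption, the real bleu() lives elsewhere in dtrain:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // hedged per-sentence BLEU sketch with uniform weights 1/N
    double bleu_sketch(const std::vector<double>& clipped,
                       const std::vector<double>& total,
                       size_t ref_len, size_t hyp_len, size_t N)
    {
      double sum = 0;
      for (size_t n = 0; n < N; n++) {
        if (clipped[n] == 0 || total[n] == 0) return 0.; // unsmoothed: any zero kills the score
        sum += (1.0/N) * log(clipped[n]/total[n]);
      }
      double bp = (hyp_len >= ref_len)
                  ? 1.0
                  : exp(1.0 - (double)ref_len/(double)hyp_len); // brevity penalty
      return bp * exp(sum);
    }
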
@@ -101,24 +125,28 @@ main(int argc, char** argv)
} else {
cand_len = kb->sents[i].size();
}
- //score = bleu( global_counts,
- // global_ref_len,
- // global_hyp_len + cand_len, N );
+ score = bleu( global_counts,
+ global_ref_len,
+ global_hyp_len + cand_len, N );*/
score = bleu ( counts, ref_ids.size(), kb->sents[i].size(), N );
ScorePair sp( kb->scores[i], score );
scores.push_back( sp );
- //cout << "'" << TD::GetString( ref_ids ) << "' vs '" << TD::GetString( kb->sents[i] ) << "' SCORE=" << score << endl;
+ //cout << "'" << TD::GetString( ref_ids ) << "' vs '";
+ //cout << TD::GetString( kb->sents[i] ) << "' SCORE=" << score << endl;
//cout << kb->feats[i] << endl;
}
// learner
SofiaLearner learner;
- learner.Init( sid, kb->feats, scores );
+ learner.Init( sn, kb->feats, scores );
learner.Update(lambdas);
//print_FD();
- sid += 1; // TODO does cdec count this already?
+ sn += 1;
}
+
+ } // outer loop
+
cerr << endl;
- weights.WriteToFile( "data/weights-final-normalx", true );
+ weights.WriteToFile( "data/weights-vanilla", false );
return 0;
}
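
Note: one caveat with the new epoch loop is that getline(cin, in) drains stdin during the first pass, so for T > 1 the later epochs see no input unless the caller re-feeds it. A sketch of buffering the corpus once, assuming it fits in memory:

    #include <iostream>
    #include <string>
    #include <vector>
    using namespace std;

    int main()
    {
      size_t T = 3; // epochs, hypothetical value
      vector<string> buffered;
      string in;
      while (getline(cin, in)) buffered.push_back(in); // read stdin once
      for (size_t t = 0; t < T; t++)
        for (size_t sn = 0; sn < buffered.size(); sn++)
          ; // decode, score, update as in the hunks above
      return 0;
    }
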