summary | refs | log | tree | commit | diff
path: root/dtrain/dtrain.cc
diff options
context:
space:
mode:
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r-- dtrain/dtrain.cc 48
1 files changed, 23 insertions, 25 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 95fc81af..373458e8 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -1,4 +1,5 @@
#include "dcommon.h"
+#include "learner.h"
@@ -45,41 +46,35 @@ main(int argc, char** argv)
ReadFile ini_rf(conf["decoder-config"].as<string>());
Decoder decoder(ini_rf.stream());
KBestGetter observer(k);
- size_t N = 4; // TODO as parameter/in config
+ size_t N = 3; // TODO as parameter/in config
// TODO scoring metric as parameter/in config
// for approx. bleu
- //NgramCounts global_counts;
- //size_t global_hyp_len;
- //size_t global_ref_len;
+ NgramCounts global_counts(N);
+ size_t global_hyp_len = 0;
+ size_t global_ref_len = 0;
Weights weights;
SparseVector<double> lambdas;
weights.InitSparseVector(&lambdas);
vector<double> dense_weights;
- lambdas.set_value(FD::Convert("logp"), 0);
-
-
vector<string> strs, ref_strs;
vector<WordID> ref_ids;
string in, psg;
size_t sid = 0;
cerr << "(1 dot equals 100 lines of input)" << endl;
while( getline(cin, in) ) {
- //if ( !SILENT )
- // cerr << endl << endl << "Getting kbest for sentence #" << sid << endl;
if ( (sid+1) % 100 == 0 ) {
cerr << ".";
if ( (sid+1)%1000 == 0 ) cerr << endl;
}
- if ( sid > 5000 ) break;
+ //if ( sid > 5000 ) break;
// weights
dense_weights.clear();
weights.InitFromVector( lambdas );
weights.InitVector( &dense_weights );
decoder.SetWeights( dense_weights );
- //if ( sid > 100 ) break;
// handling input..
strs.clear();
boost::split( strs, in, boost::is_any_of("\t") );
@@ -94,33 +89,36 @@ main(int argc, char** argv)
register_and_convert( ref_strs, ref_ids );
// scoring kbest
double score = 0;
+ size_t cand_len = 0;
Scores scores;
- for ( size_t i = 0; i < k; i++ ) {
- NgramCounts counts = make_ngram_counts( ref_ids, kb->sents[i], 4 );
- score = smooth_bleu( counts,
- ref_ids.size(),
- kb->sents[i].size(), N );
+ for ( size_t i = 0; i < kb->sents.size(); i++ ) {
+ NgramCounts counts = make_ngram_counts( ref_ids, kb->sents[i], N );
+ if ( i == 0) {
+ global_counts += counts;
+ global_hyp_len += kb->sents[i].size();
+ global_ref_len += ref_ids.size();
+ cand_len = 0;
+ } else {
+ cand_len = kb->sents[i].size();
+ }
+ //score = bleu( global_counts,
+ // global_ref_len,
+ // global_hyp_len + cand_len, N );
+ score = bleu ( counts, ref_ids.size(), kb->sents[i].size(), N );
ScorePair sp( kb->scores[i], score );
scores.push_back( sp );
//cout << "'" << TD::GetString( ref_ids ) << "' vs '" << TD::GetString( kb->sents[i] ) << "' SCORE=" << score << endl;
//cout << kb->feats[i] << endl;
}
- //cout << "###" << endl;
+ // learner
SofiaLearner learner;
learner.Init( sid, kb->feats, scores );
learner.Update(lambdas);
- // initializing learner
- // TODO
- // updating weights
- //lambdas.set_value( FD::Convert("use_shell"), 1 );
- //lambdas.set_value( FD::Convert("use_a"), 1 );
//print_FD();
sid += 1; // TODO does cdec count this already?
}
-
- weights.WriteToFile( "weights-final", true );
-
cerr << endl;
+ weights.WriteToFile( "data/weights-final-normalx", true );
return 0;
}