Diffstat (limited to 'dtrain/dtrain.cc')
 dtrain/dtrain.cc | 48 +++++++++++++++++++++++-------------------------
 1 file changed, 23 insertions(+), 25 deletions(-)
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 95fc81af..373458e8 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -1,4 +1,5 @@
 #include "dcommon.h"
+#include "learner.h"
@@ -45,41 +46,35 @@ main(int argc, char** argv)
   ReadFile ini_rf(conf["decoder-config"].as<string>());
   Decoder decoder(ini_rf.stream());
   KBestGetter observer(k);
-  size_t N = 4; // TODO as parameter/in config
+  size_t N = 3; // TODO as parameter/in config
   // TODO scoring metric as parameter/in config
   // for approx. bleu
-  //NgramCounts global_counts;
-  //size_t global_hyp_len;
-  //size_t global_ref_len;
+  NgramCounts global_counts(N);
+  size_t global_hyp_len = 0;
+  size_t global_ref_len = 0;
   Weights weights;
   SparseVector<double> lambdas;
   weights.InitSparseVector(&lambdas);
   vector<double> dense_weights;
-  lambdas.set_value(FD::Convert("logp"), 0);
-
-
   vector<string> strs, ref_strs;
   vector<WordID> ref_ids;
   string in, psg;
   size_t sid = 0;
   cerr << "(1 dot equals 100 lines of input)" << endl;
   while( getline(cin, in) ) {
-    //if ( !SILENT )
-    //  cerr << endl << endl << "Getting kbest for sentence #" << sid << endl;
     if ( (sid+1) % 100 == 0 ) {
       cerr << ".";
       if ( (sid+1)%1000 == 0 ) cerr << endl;
     }
-    if ( sid > 5000 ) break;
+    //if ( sid > 5000 ) break;
     // weights
     dense_weights.clear();
     weights.InitFromVector( lambdas );
     weights.InitVector( &dense_weights );
     decoder.SetWeights( dense_weights );
-    //if ( sid > 100 ) break;
     // handling input..
     strs.clear();
     boost::split( strs, in, boost::is_any_of("\t") );
@@ -94,33 +89,36 @@ main(int argc, char** argv)
     register_and_convert( ref_strs, ref_ids );
     // scoring kbest
     double score = 0;
+    size_t cand_len = 0;
     Scores scores;
-    for ( size_t i = 0; i < k; i++ ) {
-      NgramCounts counts = make_ngram_counts( ref_ids, kb->sents[i], 4 );
-      score = smooth_bleu( counts,
-                           ref_ids.size(),
-                           kb->sents[i].size(), N );
+    for ( size_t i = 0; i < kb->sents.size(); i++ ) {
+      NgramCounts counts = make_ngram_counts( ref_ids, kb->sents[i], N );
+      if ( i == 0) {
+        global_counts += counts;
+        global_hyp_len += kb->sents[i].size();
+        global_ref_len += ref_ids.size();
+        cand_len = 0;
+      } else {
+        cand_len = kb->sents[i].size();
+      }
+      //score = bleu( global_counts,
+      //              global_ref_len,
+      //              global_hyp_len + cand_len, N );
+      score = bleu ( counts, ref_ids.size(), kb->sents[i].size(), N );
       ScorePair sp( kb->scores[i], score );
       scores.push_back( sp );
       //cout << "'" << TD::GetString( ref_ids ) << "' vs '" << TD::GetString( kb->sents[i] ) << "' SCORE=" << score << endl;
       //cout << kb->feats[i] << endl;
     }
-    //cout << "###" << endl;
+    // learner
     SofiaLearner learner;
     learner.Init( sid, kb->feats, scores );
     learner.Update(lambdas);
-    // initializing learner
-    // TODO
-    // updating weights
-    //lambdas.set_value( FD::Convert("use_shell"), 1 );
-    //lambdas.set_value( FD::Convert("use_a"), 1 );
     //print_FD();
     sid += 1; // TODO does cdec count this already?
   }
-
-  weights.WriteToFile( "weights-final", true );
-  cerr << endl;
+  weights.WriteToFile( "data/weights-final-normalx", true );
   return 0;
 }
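For context: the global_counts / global_hyp_len / global_ref_len bookkeeping this diff introduces is the usual setup for approximate (corpus-level) BLEU, where each sentence's 1-best hypothesis is folded into running corpus totals and the other k-best candidates contribute only their own length (cand_len), as in the still-commented-out bleu( global_counts, global_ref_len, global_hyp_len + cand_len, N ) call. The standalone sketch below illustrates that accumulation; the NgramCounts layout and the bleu() body here are assumptions for illustration, not dtrain's actual dcommon.h implementations.

// Sketch of the approx.-BLEU bookkeeping the diff sets up: fold the 1-best
// hypothesis's n-gram counts into corpus-level totals, then score with plain
// BLEU over the accumulated counts. NgramCounts' layout and bleu()'s body are
// assumed for illustration; dtrain's real versions live in dcommon.h.
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

struct NgramCounts {
  // clipped[i] / total[i]: matched vs. produced (i+1)-grams
  std::vector<double> clipped, total;
  explicit NgramCounts(size_t N) : clipped(N, 0.0), total(N, 0.0) {}
  NgramCounts& operator+=(const NgramCounts& o) { // as used on global_counts
    for (size_t i = 0; i < clipped.size(); ++i) {
      clipped[i] += o.clipped[i];
      total[i]   += o.total[i];
    }
    return *this;
  }
};

// BLEU = brevity penalty * geometric mean of the N n-gram precisions
double bleu(const NgramCounts& c, size_t ref_len, size_t hyp_len, size_t N) {
  if (hyp_len == 0) return 0.0;
  double log_prec = 0.0;
  for (size_t i = 0; i < N; ++i) {
    if (c.total[i] == 0 || c.clipped[i] == 0) return 0.0; // unsmoothed
    log_prec += std::log(c.clipped[i] / c.total[i]);
  }
  double bp = hyp_len < ref_len
                  ? std::exp(1.0 - static_cast<double>(ref_len) / hyp_len)
                  : 1.0;
  return bp * std::exp(log_prec / N);
}

int main() {
  const size_t N = 3; // matches the N the diff switches to
  NgramCounts global_counts(N), counts(N);
  size_t global_hyp_len = 0, global_ref_len = 0;
  // hypothetical 1-best: 5/6 unigrams, 3/5 bigrams, 2/4 trigrams matched
  counts.clipped = {5.0, 3.0, 2.0};
  counts.total   = {6.0, 5.0, 4.0};
  global_counts += counts; // the i == 0 branch in the diff
  global_hyp_len += 6;
  global_ref_len += 6;
  std::cout << bleu(global_counts, global_ref_len, global_hyp_len, N)
            << std::endl;
  return 0;
}

Scoring against corpus-level totals makes the per-sentence metric less brittle than raw sentence BLEU on short hypotheses, which is the usual motivation for maintaining this accumulation even while the active code still scores each candidate with per-sentence counts.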