| author    | Patrick Simianer <p@simianer.de>                           | 2011-08-03 01:29:52 +0200 |
|-----------|------------------------------------------------------------|---------------------------|
| committer | Patrick Simianer <p@simianer.de>                           | 2011-09-23 19:13:57 +0200 |
| commit    | 2e605eb2745e56619b16fdbcb8095e0a6543ab27 (patch)           |                           |
| tree      | 03c122c3add26365eb8f3f84aec2a533d7222cab /dtrain/dtrain.cc |                           |
| parent    | b7568a8dad2720d5ea0418171e9b85229adbbcc5 (diff)            |                           |
refactoring, cleaning up
Diffstat (limited to 'dtrain/dtrain.cc')
| -rw-r--r-- | dtrain/dtrain.cc | 86 | 
1 file changed, 57 insertions, 29 deletions
```diff
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 373458e8..16b83a70 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -1,6 +1,11 @@
-#include "dcommon.h"
+#include "common.h"
+#include "kbestget.h"
 #include "learner.h"
+#include "util.h"
+#ifdef DTRAIN_DEBUG
+#include "tests.h"
+#endif
@@ -12,20 +17,33 @@ bool
 init(int argc, char** argv, po::variables_map* conf)
 {
   po::options_description opts( "Options" );
+  size_t k, N, T;
+  // TODO scoring metric as parameter/in config
   opts.add_options()
-    ( "decoder-config,c", po::value<string>(), "configuration file for cdec" )
-    ( "kbest,k",          po::value<size_t>(), "k for kbest" )
-    ( "ngrams,n",         po::value<int>(),    "n for Ngrams" )
-    ( "filter,f",         po::value<string>(), "filter kbest list" )
-    ( "test",                                       "run tests and exit");
+    ( "decoder-config,c", po::value<string>(),                          "configuration file for cdec" )
+    ( "kbest,k",          po::value<size_t>(&k)->default_value(DTRAIN_DEFAULT_K),       "k for kbest" )
+    ( "ngrams,n",         po::value<size_t>(&N)->default_value(DTRAIN_DEFAULT_N),      "n for Ngrams" )
+    ( "filter,f",         po::value<string>(),                                    "filter kbest list" ) // FIXME
+    ( "epochs,t",         po::value<size_t>(&T)->default_value(DTRAIN_DEFAULT_T), "# of iterations T" )
+#ifndef DTRAIN_DEBUG
+    ;
+#else
+    ( "test",                                  "run tests and exit");
+#endif
   po::options_description cmdline_options;
   cmdline_options.add(opts);
   po::store( parse_command_line(argc, argv, cmdline_options), *conf );
   po::notify( *conf );
-  if ( ! (conf->count("decoder-config") || conf->count("test")) ) {
+  if ( ! conf->count("decoder-config") ) {
     cerr << cmdline_options << endl;
     return false;
   }
+  #ifdef DTRAIN_DEBUG
+  if ( ! conf->count("test") ) {
+    cerr << cmdline_options << endl;
+    return false;
+  }
+  #endif
   return true;
 }
@@ -40,19 +58,21 @@ main(int argc, char** argv)
   SetSilent(true);
   po::variables_map conf;
   if (!init(argc, argv, &conf)) return 1;
+#ifdef DTRAIN_DEBUG
   if ( conf.count("test") ) run_tests();
+#endif
   register_feature_functions();
   size_t k = conf["kbest"].as<size_t>();
-  ReadFile ini_rf(conf["decoder-config"].as<string>());
+  ReadFile ini_rf( conf["decoder-config"].as<string>() );
   Decoder decoder(ini_rf.stream());
-  KBestGetter observer(k);
-  size_t N = 3; // TODO as parameter/in config
+  KBestGetter observer( k );
+  size_t N = conf["ngrams"].as<size_t>();
+  size_t T = conf["epochs"].as<size_t>();
-  // TODO scoring metric as parameter/in config
   // for approx. bleu
-  NgramCounts global_counts(N);
-  size_t global_hyp_len = 0;
-  size_t global_ref_len = 0;
+  //NgramCounts global_counts( N );
+  //size_t global_hyp_len = 0;
+  //size_t global_ref_len = 0;
   Weights weights;
   SparseVector<double> lambdas;
@@ -62,20 +82,24 @@ main(int argc, char** argv)
   vector<string> strs, ref_strs;
   vector<WordID> ref_ids;
   string in, psg;
-  size_t sid = 0;
-  cerr << "(1 dot equals 100 lines of input)" << endl;
+  size_t sn = 0;
+  cerr << "(A dot equals " << DTRAIN_DOTOUT << " lines of input.)" << endl;
+
+  for ( size_t t = 0; t < T; t++ )
+  {
+
   while( getline(cin, in) ) {
-    if ( (sid+1) % 100 == 0 ) {
+    if ( (sn+1) % DTRAIN_DOTOUT == 0 ) {
        cerr << ".";
-        if ( (sid+1)%1000 == 0 ) cerr << endl;
+        if ( (sn+1) % (20*DTRAIN_DOTOUT) == 0 ) cerr << endl;
     }
-    //if ( sid > 5000 ) break;
+    //if ( sn > 5000 ) break;
     // weights
     dense_weights.clear();
     weights.InitFromVector( lambdas );
     weights.InitVector( &dense_weights );
     decoder.SetWeights( dense_weights );
-    // handling input..
+    // handling input
     strs.clear();
     boost::split( strs, in, boost::is_any_of("\t") );
     // grammar
@@ -89,11 +113,11 @@ main(int argc, char** argv)
     register_and_convert( ref_strs, ref_ids );
     // scoring kbest
     double score = 0;
-    size_t cand_len = 0;
+    //size_t cand_len = 0;
     Scores scores;
     for ( size_t i = 0; i < kb->sents.size(); i++ ) {
       NgramCounts counts = make_ngram_counts( ref_ids, kb->sents[i], N );
-      if ( i == 0) {
+      /*if ( i == 0 ) {
         global_counts += counts;
         global_hyp_len += kb->sents[i].size();
         global_ref_len += ref_ids.size();
@@ -101,24 +125,28 @@ main(int argc, char** argv)
       } else {
         cand_len = kb->sents[i].size();
       }
-      //score = bleu( global_counts,
-      //                     global_ref_len,
-       //                    global_hyp_len + cand_len, N );
+      score = bleu( global_counts,
+                    global_ref_len,
+                     global_hyp_len + cand_len, N );*/
       score = bleu ( counts, ref_ids.size(), kb->sents[i].size(), N );
       ScorePair sp( kb->scores[i], score );
       scores.push_back( sp );
-      //cout << "'" << TD::GetString( ref_ids ) << "' vs '" << TD::GetString( kb->sents[i] ) << "' SCORE=" << score << endl;
+      //cout << "'" << TD::GetString( ref_ids ) << "' vs '";
+      //cout << TD::GetString( kb->sents[i] ) << "' SCORE=" << score << endl;
       //cout << kb->feats[i] << endl;
     }
     // learner
     SofiaLearner learner;
-    learner.Init( sid, kb->feats, scores );
+    learner.Init( sn, kb->feats, scores );
     learner.Update(lambdas);
     //print_FD();
-    sid += 1; // TODO does cdec count this already?
+    sn += 1;
   }
+
+  } // outer loop
+
   cerr << endl;
-  weights.WriteToFile( "data/weights-final-normalx", true );
+  weights.WriteToFile( "data/weights-vanilla", false );
   return 0;
 }
```
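The main mechanical change in `init()` is that each option is now bound to a variable and given a default via `po::value<T>(&var)->default_value(...)`, so `main()` no longer needs hand-rolled fallbacks like `size_t N = 3; // TODO`. A minimal sketch of that boost::program_options pattern, with placeholder defaults (100, 3, 1) standing in for the patch's `DTRAIN_DEFAULT_*` macros:

```cpp
// Sketch of option binding with defaults; build with -lboost_program_options.
// The default values below are placeholders, not the patch's actual macros.
#include <cstddef>
#include <iostream>
#include <boost/program_options.hpp>

namespace po = boost::program_options;
using namespace std;

int main(int argc, char** argv)
{
  size_t k, N, T;
  po::options_description opts("Options");
  opts.add_options()
    ("kbest,k",  po::value<size_t>(&k)->default_value(100), "k for kbest")
    ("ngrams,n", po::value<size_t>(&N)->default_value(3),   "n for Ngrams")
    ("epochs,t", po::value<size_t>(&T)->default_value(1),   "# of iterations T");
  po::variables_map conf;
  po::store(po::parse_command_line(argc, argv, opts), conf);
  po::notify(conf); // writes each parsed (or default) value into k, N, T
  cout << "k=" << k << " N=" << N << " T=" << T << endl;
  return 0;
}
```

After `po::notify()`, each bound variable holds either the command-line value or its default, which is why the patch can read `k`, `N`, and `T` straight out of `conf` in `main()`.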
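The scoring loop now calls `bleu( counts, ref_ids.size(), kb->sents[i].size(), N )` once per k-best candidate, with the running corpus-level ("approx. bleu") variant commented out. The real implementation lives elsewhere in dtrain; the following is only a sketch of the standard per-sentence BLEU such a call could compute, with a hypothetical `NgramCountsSketch` standing in for the real `NgramCounts`:

```cpp
// Hypothetical sketch of per-sentence BLEU: geometric mean of modified
// n-gram precisions times the brevity penalty. Assumes the counts object
// exposes clipped and total counts per n-gram order, as NgramCounts would.
#include <cmath>
#include <cstddef>
#include <vector>

struct NgramCountsSketch {
  // clipped[i] / sum[i]: clipped and total counts for (i+1)-grams
  std::vector<double> clipped, sum;
};

double bleu_sketch(const NgramCountsSketch& counts,
                   size_t ref_len, size_t hyp_len, size_t N)
{
  if (hyp_len == 0) return 0.;
  double log_prec = 0.;
  for (size_t i = 0; i < N; i++) {
    // unsmoothed: a single missing n-gram order zeroes the whole score
    if (counts.sum[i] == 0 || counts.clipped[i] == 0) return 0.;
    log_prec += log(counts.clipped[i] / counts.sum[i]);
  }
  // brevity penalty: penalize hypotheses shorter than the reference
  double bp = (hyp_len < ref_len) ? exp(1. - (double)ref_len / hyp_len) : 1.;
  return bp * exp(log_prec / N);
}
```

Unsmoothed sentence-level BLEU returns 0 whenever a candidate lacks any n-gram order entirely, which is one known motivation for the corpus-running approximation kept (commented out) in the loop above.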
