path: root/dtrain/dtrain.cc
blob: 373458e84a011da081846d62987e72fb3f939f02
#include "dcommon.h"
#include "learner.h"


/*
 * init
 *
 * Parse command-line options into *conf; returns false (and prints
 * usage) if required options are missing.
 */
bool
init(int argc, char** argv, po::variables_map* conf)
{
  po::options_description opts( "Options" );
  opts.add_options()
    ( "decoder-config,c", po::value<string>(), "configuration file for cdec" )
    ( "kbest,k",          po::value<size_t>(), "k for k-best lists" )
    ( "ngrams,n",         po::value<int>(),    "n for n-grams (BLEU order)" )
    ( "filter,f",         po::value<string>(), "filter k-best list (currently unused)" )
    ( "test",                                  "run tests and exit" );
  po::options_description cmdline_options;
  cmdline_options.add(opts);
  po::store( parse_command_line(argc, argv, cmdline_options), *conf );
  po::notify( *conf );
  // --decoder-config and --kbest are both needed unless we only run tests
  if ( ! conf->count("test")
       && ! (conf->count("decoder-config") && conf->count("kbest")) ) {
    cerr << cmdline_options << endl;
    return false;
  }
  return true;
}
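
/*
 * Example invocation (file names are hypothetical):
 *
 *   ./dtrain -c cdec.ini -k 100 -n 3 < train.in
 *
 * Each line of train.in must hold three tab-separated fields
 * (see the input handling in main below):
 *
 *   <source> \t <reference> \t <rule __NEXT_RULE__ rule ...>
 */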


/*
 * main
 *
 * Reads tab-separated training lines from stdin, decodes each source
 * sentence, scores the k-best list with per-sentence BLEU and updates
 * the weights with SofiaLearner.
 */
int
main(int argc, char** argv)
{
  SetSilent(true); // suppress verbose decoder output
  po::variables_map conf;
  if (!init(argc, argv, &conf)) return 1;
  if ( conf.count("test") ) { run_tests(); return 0; } // --test: run tests and exit
  register_feature_functions();
  size_t k = conf["kbest"].as<size_t>(); // size of the k-best lists
  ReadFile ini_rf(conf["decoder-config"].as<string>());
  Decoder decoder(ini_rf.stream());
  KBestGetter observer(k); // collects the k best translations of each Decode() call
  // n-gram order for BLEU: default 3, overridable via --ngrams
  size_t N = conf.count("ngrams") ? (size_t)conf["ngrams"].as<int>() : 3;

  // TODO scoring metric as parameter/in config
  // accumulators for approximate BLEU (the actual call is disabled below)
  NgramCounts global_counts(N);
  size_t global_hyp_len = 0;
  size_t global_ref_len = 0;
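  // Sketch of the approximate-BLEU idea behind these accumulators: instead
  // of scoring a candidate in isolation, its n-gram counts are added to the
  // counts accumulated from the 1-best hypotheses of all previous sentences,
  // which smooths BLEU enough to be usable as a per-sentence score.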

  Weights weights;
  SparseVector<double> lambdas;
  weights.InitSparseVector(&lambdas);
  vector<double> dense_weights;
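  // lambdas is the sparse weight vector the learner updates; each iteration
  // it is flattened into dense_weights for the decoder (see the loop below)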

  vector<string> strs, ref_strs;
  vector<WordID> ref_ids;
  string in, psg;
  size_t sid = 0;
  cerr << "(1 dot equals 100 lines of input)" << endl;
  while( getline(cin, in) ) {
    if ( (sid+1) % 100 == 0 ) {
      cerr << ".";
      if ( (sid+1) % 1000 == 0 ) cerr << endl;
    }
    //if ( sid > 5000 ) break;
    // push the current weights into the decoder
    dense_weights.clear();
    weights.InitFromVector( lambdas );
    weights.InitVector( &dense_weights );
    decoder.SetWeights( dense_weights );
    // input: <source> \t <reference> \t <per-sentence grammar>
    strs.clear();
    boost::split( strs, in, boost::is_any_of("\t") );
    // grammar: rules arrive joined by " __NEXT_RULE__ ", restore one rule per line
    psg = boost::replace_all_copy( strs[2], " __NEXT_RULE__ ", "\n" ); psg += "\n";
    decoder.SetSentenceGrammar( psg );
    decoder.Decode( strs[0], &observer );
    KBestList* kb = observer.GetKBest();
    // reference
    ref_strs.clear(); ref_ids.clear();
    boost::split( ref_strs, strs[1], boost::is_any_of(" ") );
    register_and_convert( ref_strs, ref_ids );
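    // (register_and_convert presumably interns each reference token via
    // cdec's term dictionary TD, yielding the WordIDs that
    // make_ngram_counts compares against the k-best hypotheses)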
    // score the k-best list against the reference
    double score = 0;
    size_t cand_len = 0;
    Scores scores;
    for ( size_t i = 0; i < kb->sents.size(); i++ ) {
      NgramCounts counts = make_ngram_counts( ref_ids, kb->sents[i], N );
      // only the 1-best hypothesis feeds the approximate-BLEU accumulators
      if ( i == 0 ) {
        global_counts += counts;
        global_hyp_len += kb->sents[i].size();
        global_ref_len += ref_ids.size();
        cand_len = 0;
      } else {
        cand_len = kb->sents[i].size();
      }
      // approximate BLEU (disabled):
      //score = bleu( global_counts, global_ref_len,
      //              global_hyp_len + cand_len, N );
      score = bleu( counts, ref_ids.size(), kb->sents[i].size(), N );
      ScorePair sp( kb->scores[i], score );
      scores.push_back( sp );
      //cout << "'" << TD::GetString( ref_ids ) << "' vs '" << TD::GetString( kb->sents[i] ) << "' SCORE=" << score << endl;
      //cout << kb->feats[i] << endl;
    }
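    // For reference, assuming bleu(.) implements standard BLEU
    // (Papineni et al., 2002), the score computed above is
    //   BLEU = BP * exp( (1/N) * sum_{n=1..N} log p_n ),
    // with p_n the clipped n-gram precision and the brevity penalty
    //   BP = min( 1, exp(1 - ref_len/hyp_len) ).
    // E.g. ref "a b c d" vs. hyp "a b c" with N=3: p1 = 3/3, p2 = 2/2,
    // p3 = 1/1, BP = exp(1 - 4/3), hence BLEU = exp(-1/3) ~ 0.72.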
    // learner
    SofiaLearner learner;
    learner.Init( sid, kb->feats, scores );
    learner.Update(lambdas);
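    // (assumption, not verified here: SofiaLearner fits a ranking model on
    // the k-best entries from their (model score, BLEU) ScorePairs, and
    // Update() moves lambdas toward ranking higher-BLEU hypotheses first)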
    //print_FD();
    sid += 1; // TODO does cdec already track the sentence id?
  }
  cerr << endl;
  weights.WriteToFile( "data/weights-final-normalx", true ); // TODO make the output file configurable

  return 0;
}