summaryrefslogtreecommitdiff
path: root/dtrain/dtest.cc
blob: 7674a3ca6f75684436f78e6ccc7779843e585a18 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#include "common.h"
#include "kbestget.h"
#include "util.h"


/*
 * init
 *
 */
bool
init(int argc, char** argv, po::variables_map* conf)
{
  int N;
  bool q;
  po::options_description opts( "Options" );
  opts.add_options()
    ( "decoder-config,c", po::value<string>(),                              "configuration file for cdec" )
    ( "weights,w",        po::value<string>(),                                             "weights file" )
    ( "ngrams,n",         po::value<int>(&N)->default_value(DTRAIN_DEFAULT_N), "N for Ngrams (default 5)" )
    ( "quiet,q",          po::value<bool>(&q)->default_value(true),          "do not output translations" );
  po::options_description cmdline_options;
  cmdline_options.add(opts);
  po::store( parse_command_line(argc, argv, cmdline_options), *conf );
  po::notify( *conf );
  if ( ! (conf->count("decoder-config") || conf->count("weights")) ) {
    cerr << cmdline_options << endl;
    return false;
  }
  return true;
}


/*
 * main
 *
 */
int
main(int argc, char** argv)
{
  SetSilent(true);
  po::variables_map conf;
  if (!init(argc, argv, &conf)) return 1;
  register_feature_functions();
  size_t k = 1;
  ReadFile ini_rf(conf["decoder-config"].as<string>());
  Decoder decoder(ini_rf.stream());
  KBestGetter observer(k);
  size_t N = conf["ngrams"].as<int>();
  bool quiet = conf["quiet"].as<bool>();

  Weights weights;
  weights.InitFromFile(conf["weights"].as<string>());
  vector<double> w;
  weights.InitVector(&w);
  decoder.SetWeights(w);
 
  vector<string> strs, ref_strs;
  vector<WordID> ref_ids;
  string in, psg;
  size_t sn = 0;
  double overall  = 0.0;
  double overall1 = 0.0;
  double overall2 = 0.0;
  cerr << "(A dot represents " << DTRAIN_DOTOUT << " lines of input.)" << endl;
  while( getline(cin, in) ) {
    if ( (sn+1) % DTRAIN_DOTOUT == 0 ) {
        cerr << ".";
        if ( (sn+1) % (20*DTRAIN_DOTOUT) == 0 ) cerr << " " << sn+1 << endl;
    }
    //if ( sn > 5000 ) break;
    strs.clear();
    boost::split( strs, in, boost::is_any_of("\t") );
    // grammar
    psg = boost::replace_all_copy( strs[2], " __NEXT_RULE__ ", "\n" ); psg += "\n";
    decoder.SetSentenceGrammar( psg );
    decoder.Decode( strs[0], &observer );
    KBestList* kb = observer.GetKBest();
    // reference
    ref_strs.clear(); ref_ids.clear();
    boost::split( ref_strs, strs[1], boost::is_any_of(" ") );
    register_and_convert( ref_strs, ref_ids );
    // scoring kbest
    double score  = 0.0;
    double score1 = 0.0;
    double score2 = 0.0;
    NgramCounts counts = make_ngram_counts( ref_ids, kb->sents[0], 4 );
    score =  smooth_bleu( counts, ref_ids.size(), kb->sents[0].size(), N );
    score1 = stupid_bleu( counts, ref_ids.size(), kb->sents[0].size(), N );
    score2 =        bleu( counts, ref_ids.size(), kb->sents[0].size(), N );
    if ( ! quiet ) cout << TD::GetString( kb->sents[0] ) << endl;
    overall += score;
    overall1 += score1;
    overall2 += score2;
    sn += 1;
  }
  cerr << "Average score (smooth) : " << overall/(double)(sn+1) << endl;
  cerr << "Average score (stupid) : " << overall1/(double)(sn+1) << endl;
  cerr << "Average score (vanilla): " << overall2/(double)(sn+1) << endl;
  cerr << endl;

  return 0;
}