From aa5f96417ff81408b15b54aab35a3c16b845adf8 Mon Sep 17 00:00:00 2001
From: Patrick Simianer
Date: Mon, 29 Aug 2011 22:02:45 +0200
Subject: big update: working iterating, pretty output, test scripts and more
---
dtrain/Makefile.am | 8 +-
dtrain/common.h | 15 +-
dtrain/dtrain.cc | 412 +-
dtrain/kbestget.h | 8 +-
dtrain/learner.h | 96 -
dtrain/run.sh | 6 +
dtrain/score.cc | 2 +-
dtrain/score.h | 9 +
dtrain/scripts/run.sh | 4 -
dtrain/scripts/test.sh | 6 -
dtrain/test/cdec.ini | 9 +
dtrain/test/compression-test.cc | 49 +
dtrain/test/test.in | 2 +
dtrain/test/toy.dtrain.ini | 9 +
dtrain/test/toy.in | 2 +
dtrain/test/wc_pipes/bible.txt | 30383 ++++++++++++++++++++++++++++++++++++
dtrain/test/wc_pipes/jobconf.xml | 16 +
dtrain/test/wc_pipes/run.sh | 11 +
dtrain/test/wc_pipes/wordcount.cc | 38 +
dtrain/test/wc_pipes/wordcount.hh | 34 +
dtrain/updater.h | 106 +
21 files changed, 31024 insertions(+), 201 deletions(-)
delete mode 100644 dtrain/learner.h
create mode 100755 dtrain/run.sh
delete mode 100755 dtrain/scripts/run.sh
delete mode 100755 dtrain/scripts/test.sh
create mode 100644 dtrain/test/cdec.ini
create mode 100644 dtrain/test/compression-test.cc
create mode 100644 dtrain/test/test.in
create mode 100644 dtrain/test/toy.dtrain.ini
create mode 100644 dtrain/test/toy.in
create mode 100644 dtrain/test/wc_pipes/bible.txt
create mode 100644 dtrain/test/wc_pipes/jobconf.xml
create mode 100755 dtrain/test/wc_pipes/run.sh
create mode 100644 dtrain/test/wc_pipes/wordcount.cc
create mode 100644 dtrain/test/wc_pipes/wordcount.hh
create mode 100644 dtrain/updater.h
(limited to 'dtrain')
diff --git a/dtrain/Makefile.am b/dtrain/Makefile.am
index 03e3ccf7..f218a8ed 100644
--- a/dtrain/Makefile.am
+++ b/dtrain/Makefile.am
@@ -1,11 +1,11 @@
# TODO I'm sure I can leave something out.
-bin_PROGRAMS = dtrain dtest
+bin_PROGRAMS = dtrain #dtest
dtrain_SOURCES = dtrain.cc score.cc tests.cc util.cc
-dtrain_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
+dtrain_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz -lboost_filesystem -lboost_iostreams
-dtest_SOURCES = dtest.cc score.cc util.cc
-dtest_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
+#dtest_SOURCES = dtest.cc score.cc util.cc
+#dtest_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval
diff --git a/dtrain/common.h b/dtrain/common.h
index cf365d48..706f51c2 100644
--- a/dtrain/common.h
+++ b/dtrain/common.h
@@ -7,7 +7,9 @@
#include
#include
#include
+#include
+// cdec includes
#include "sentence_metadata.h"
#include "verbose.h"
#include "viterbi.h"
@@ -16,16 +18,19 @@
#include "decoder.h"
#include "weights.h"
+// boost includes
#include
#include
+// own headers
#include "score.h"
-#define DTRAIN_DEFAULT_K 100
-#define DTRAIN_DEFAULT_N 4
-#define DTRAIN_DEFAULT_T 1
-
-#define DTRAIN_DOTOUT 100
+#define DTRAIN_DEFAULT_K 100 // k for kbest lists
+#define DTRAIN_DEFAULT_N 4 // N for ngrams (e.g. BLEU)
+#define DTRAIN_DEFAULT_T 1 // iterations
+#define DTRAIN_DEFAULT_SCORER "stupid_bleu" // scorer
+#define DTRAIN_DOTS 100 // when to display a '.'
+#define DTRAIN_TMP_DIR "/tmp" // put this on a SSD?
using namespace std;
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 6023638a..a141a576 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -1,46 +1,72 @@
#include "common.h"
#include "kbestget.h"
-#include "learner.h"
+#include "updater.h"
#include "util.h"
+// boost compression
+#include
+#include
+#include
+//#include
+//#include
+using namespace boost::iostreams;
+
#ifdef DTRAIN_DEBUG
#include "tests.h"
#endif
-
/*
* init
*
*/
bool
-init(int argc, char** argv, po::variables_map* conf)
+init(int argc, char** argv, po::variables_map* cfg)
{
- po::options_description opts( "Options" );
- size_t k, N, T;
- // TODO scoring metric as parameter/in config
- opts.add_options()
- ( "decoder-config,c", po::value(), "configuration file for cdec" )
- ( "kbest,k", po::value(&k)->default_value(DTRAIN_DEFAULT_K), "k for kbest" )
- ( "ngrams,n", po::value(&N)->default_value(DTRAIN_DEFAULT_N), "n for Ngrams" )
- ( "filter,f", po::value(), "filter kbest list" ) // FIXME
- ( "epochs,t", po::value(&T)->default_value(DTRAIN_DEFAULT_T), "# of iterations T" )
- ( "input,i", po::value(), "input file" )
+ po::options_description conff( "Configuration File Options" );
+ size_t k, N, T, stop;
+ string s;
+ conff.add_options()
+ ( "decoder_config", po::value(), "configuration file for cdec" )
+ ( "kbest", po::value(&k)->default_value(DTRAIN_DEFAULT_K), "k for kbest" )
+ ( "ngrams", po::value(&N)->default_value(DTRAIN_DEFAULT_N), "n for Ngrams" )
+ ( "filter", po::value(), "filter kbest list" ) // FIXME
+ ( "epochs", po::value(&T)->default_value(DTRAIN_DEFAULT_T), "# of iterations T" )
+ ( "input", po::value(), "input file" )
+ ( "scorer", po::value(&s)->default_value(DTRAIN_DEFAULT_SCORER), "scoring metric" )
+ ( "output", po::value(), "output weights file" )
+ ( "stop_after", po::value(&stop)->default_value(0), "stop after X input sentences" )
+ ( "weights_file", po::value(), "input weights file (e.g. from previous iteration" );
+
+ po::options_description clo("Command Line Options");
+ clo.add_options()
+ ( "config,c", po::value(), "dtrain config file" )
+ ( "quiet,q", po::value()->zero_tokens(), "be quiet" )
+ ( "verbose,v", po::value()->zero_tokens(), "be verbose" )
#ifndef DTRAIN_DEBUG
;
#else
( "test", "run tests and exit");
#endif
- po::options_description cmdline_options;
- cmdline_options.add(opts);
- po::store( parse_command_line(argc, argv, cmdline_options), *conf );
- po::notify( *conf );
- if ( ! conf->count("decoder-config") || ! conf->count("input") ) {
+ po::options_description config_options, cmdline_options;
+
+ config_options.add(conff);
+ cmdline_options.add(clo);
+ cmdline_options.add(conff);
+
+ po::store( parse_command_line(argc, argv, cmdline_options), *cfg );
+ if ( cfg->count("config") ) {
+ ifstream config( (*cfg)["config"].as().c_str() );
+ po::store( po::parse_config_file(config, config_options), *cfg );
+ }
+ po::notify(*cfg);
+
+ if ( !cfg->count("decoder_config") || !cfg->count("input") ) {
cerr << cmdline_options << endl;
return false;
}
#ifdef DTRAIN_DEBUG
- if ( ! conf->count("test") ) {
+ if ( !cfg->count("test") ) {
cerr << cmdline_options << endl;
return false;
}
@@ -49,6 +75,12 @@ init(int argc, char** argv, po::variables_map* conf)
}
+ostream& _nopos( ostream& out ) { return out << resetiosflags( ios::showpos ); }
+ostream& _pos( ostream& out ) { return out << setiosflags( ios::showpos ); }
+ostream& _prec2( ostream& out ) { return out << setprecision(2); }
+ostream& _prec5( ostream& out ) { return out << setprecision(5); }
+
+
/*
* main
*
@@ -56,104 +88,320 @@ init(int argc, char** argv, po::variables_map* conf)
int
main(int argc, char** argv)
{
- SetSilent(true);
- po::variables_map conf;
- if (!init(argc, argv, &conf)) return 1;
+ // handle most parameters
+ po::variables_map cfg;
+ if ( ! init(argc, argv, &cfg) ) exit(1); // something is wrong
#ifdef DTRAIN_DEBUG
- if ( conf.count("test") ) run_tests();
+ if ( cfg.count("test") ) run_tests(); // run tests and exit
#endif
+ bool quiet = false;
+ if ( cfg.count("quiet") ) quiet = true;
+ bool verbose = false;
+ if ( cfg.count("verbose") ) verbose = true;
+ const size_t k = cfg["kbest"].as();
+ const size_t N = cfg["ngrams"].as();
+ const size_t T = cfg["epochs"].as();
+ const size_t stop_after = cfg["stop_after"].as();
+ if ( !quiet ) {
+ cout << endl << "dtrain" << endl << "Parameters:" << endl;
+ cout << setw(16) << "k " << k << endl;
+ cout << setw(16) << "N " << N << endl;
+ cout << setw(16) << "T " << T << endl;
+ if ( cfg.count("stop-after") )
+ cout << setw(16) << "stop_after " << stop_after << endl;
+ if ( cfg.count("weights") )
+ cout << setw(16) << "weights " << cfg["weights"].as() << endl;
+ cout << setw(16) << "input " << "'" << cfg["input"].as() << "'" << endl;
+ }
+
+ // setup decoder, observer
register_feature_functions();
- size_t k = conf["kbest"].as();
- ReadFile ini_rf( conf["decoder-config"].as() );
+ SetSilent(true);
+ ReadFile ini_rf( cfg["decoder_config"].as() );
+ if ( !quiet )
+ cout << setw(16) << "cdec cfg " << "'" << cfg["decoder_config"].as() << "'" << endl;
Decoder decoder(ini_rf.stream());
KBestGetter observer( k );
- size_t N = conf["ngrams"].as();
- size_t T = conf["epochs"].as();
- // for approx. bleu
- //NgramCounts global_counts( N );
- //size_t global_hyp_len = 0;
- //size_t global_ref_len = 0;
+ // scoring metric/scorer
+ string scorer_str = cfg["scorer"].as();
+ double (*scorer)( NgramCounts&, const size_t, const size_t, size_t, vector );
+ if ( scorer_str == "bleu" ) {
+ scorer = &bleu;
+ } else if ( scorer_str == "stupid_bleu" ) {
+ scorer = &stupid_bleu;
+ } else if ( scorer_str == "smooth_bleu" ) {
+ scorer = &smooth_bleu;
+ } else if ( scorer_str == "approx_bleu" ) {
+ scorer = &approx_bleu;
+ } else {
+ cerr << "Don't know scoring metric: '" << scorer_str << "', exiting." << endl;
+ exit(1);
+ }
+ // for approx_bleu
+ NgramCounts global_counts( N ); // counts for 1 best translations
+ size_t global_hyp_len = 0; // sum hypothesis lengths
+ size_t global_ref_len = 0; // sum reference lengths
+ // this is all BLEU implmentations
+ vector bleu_weights; // we leave this empty -> 1/N; TODO?
+ if ( !quiet ) cout << setw(16) << "scorer '" << scorer_str << "'" << endl << endl;
+ // init weights
Weights weights;
+ if ( cfg.count("weights") ) weights.InitFromFile( cfg["weights"].as() );
SparseVector lambdas;
weights.InitSparseVector(&lambdas);
vector dense_weights;
- vector strs, ref_strs;
- vector ref_ids;
- string in, psg;
- size_t sn = 0;
- cerr << "(A dot represents " << DTRAIN_DOTOUT << " lines of input.)" << endl;
-
- string fname = conf["input"].as();
+ // input
+ if ( !quiet && !verbose )
+ cout << "(a dot represents " << DTRAIN_DOTS << " lines of input)" << endl;
+ string input_fn = cfg["input"].as();
ifstream input;
- input.open( fname.c_str() );
+ if ( input_fn != "-" ) input.open( input_fn.c_str() );
+ string in;
+ vector in_split; // input: src\tref\tpsg
+ vector ref_tok; // tokenized reference
+ vector ref_ids; // reference as vector of WordID
+ string grammar_str;
+
+ // buffer input for t > 0
+ vector src_str_buf; // source strings, TODO? memory
+ vector > ref_ids_buf; // references as WordID vecs
+ filtering_ostream grammar_buf; // written to compressed file in /tmp
+ // this is for writing the grammar buffer file
+ grammar_buf.push( gzip_compressor() );
+ char grammar_buf_tmp_fn[] = DTRAIN_TMP_DIR"/dtrain-grammars-XXXXXX";
+ mkstemp( grammar_buf_tmp_fn );
+ grammar_buf.push( file_sink(grammar_buf_tmp_fn, ios::binary | ios::trunc) );
+
+ size_t sid = 0, in_sz = 99999999; // sentence id, input size
+ double acc_1best_score = 0., acc_1best_model = 0.;
+ vector > scores_per_iter;
+ double max_score = 0.;
+ size_t best_t = 0;
+ bool next = false, stop = false;
+ double score = 0.;
+ size_t cand_len = 0;
+ Scores scores;
+ double overall_time = 0.;
+
+ cout << setprecision( 5 );
+
- for ( size_t t = 0; t < T; t++ )
+ for ( size_t t = 0; t < T; t++ ) // T epochs
{
- input.seekg(0);
- cerr << "Iteration #" << t+1 << " of " << T << "." << endl;
- while( getline(input, in) ) {
- if ( (sn+1) % DTRAIN_DOTOUT == 0 ) {
- cerr << ".";
- if ( (sn+1) % (20*DTRAIN_DOTOUT) == 0 ) cerr << " " << sn+1 << endl;
+ time_t start, end;
+ time( &start );
+
+ // actually, we need only need this if t > 0 FIXME
+ ifstream grammar_file( grammar_buf_tmp_fn, ios_base::in | ios_base::binary );
+ filtering_istream grammar_buf_in;
+ grammar_buf_in.push( gzip_decompressor() );
+ grammar_buf_in.push( grammar_file );
+
+ // reset average scores
+ acc_1best_score = acc_1best_model = 0.;
+
+ sid = 0; // reset sentence counter
+
+ if ( !quiet ) cout << "Iteration #" << t+1 << " of " << T << "." << endl;
+
+ while( true ) {
+
+ // get input from stdin or file
+ in.clear();
+ next = stop = false; // next iteration, premature stop
+ if ( t == 0 ) {
+ if ( input_fn == "-" ) {
+ if ( !getline(cin, in) ) next = true;
+ } else {
+ if ( !getline(input, in) ) next = true;
+ }
+ } else {
+ if ( sid == in_sz ) next = true; // stop if we reach the end of our input
+ }
+ // stop after X sentences (but still iterate for those)
+ if ( stop_after > 0 && stop_after == sid && !next ) stop = true;
+
+ // produce some pretty output
+ if ( !quiet && !verbose ) {
+ if ( sid == 0 ) cout << " ";
+ if ( (sid+1) % (DTRAIN_DOTS) == 0 ) {
+ cout << ".";
+ cout.flush();
+ }
+ if ( (sid+1) % (20*DTRAIN_DOTS) == 0) {
+ cout << " " << sid+1 << endl;
+ if ( !next && !stop ) cout << " ";
+ }
+ if ( stop ) {
+ if ( sid % (20*DTRAIN_DOTS) != 0 ) cout << " " << sid << endl;
+ cout << "Stopping after " << stop_after << " input sentences." << endl;
+ } else {
+ if ( next ) {
+ if ( sid % (20*DTRAIN_DOTS) != 0 ) {
+ cout << " " << sid << endl;
+ }
+ }
+ }
}
- //if ( sn > 5000 ) break;
+
+ // next iteration
+ if ( next || stop ) break;
+
// weights
dense_weights.clear();
weights.InitFromVector( lambdas );
weights.InitVector( &dense_weights );
decoder.SetWeights( dense_weights );
- // handling input
- strs.clear();
- boost::split( strs, in, boost::is_any_of("\t") );
- // grammar
- psg = boost::replace_all_copy( strs[2], " __NEXT_RULE__ ", "\n" ); psg += "\n";
- decoder.SetSentenceGrammar( psg );
- decoder.Decode( strs[0], &observer );
+
+ switch ( t ) {
+ case 0:
+ // handling input
+ in_split.clear();
+ boost::split( in_split, in, boost::is_any_of("\t") );
+ // getting reference
+ ref_tok.clear(); ref_ids.clear();
+ boost::split( ref_tok, in_split[1], boost::is_any_of(" ") );
+ register_and_convert( ref_tok, ref_ids );
+ ref_ids_buf.push_back( ref_ids );
+ // process and set grammar
+ grammar_buf << in_split[2] << endl;
+ grammar_str = boost::replace_all_copy( in_split[2], " __NEXT_RULE__ ", "\n" );
+ grammar_str += "\n";
+ decoder.SetSentenceGrammarFromString( grammar_str );
+ // decode, kbest
+ src_str_buf.push_back( in_split[0] );
+ decoder.Decode( in_split[0], &observer );
+ break;
+ default:
+ // get buffered grammar
+ string g;
+ getline(grammar_buf_in, g);
+ grammar_str = boost::replace_all_copy( g, " __NEXT_RULE__ ", "\n" );
+ grammar_str += "\n";
+ decoder.SetSentenceGrammarFromString( grammar_str );
+ // decode, kbest
+ decoder.Decode( src_str_buf[sid], &observer );
+ break;
+ }
+
+ // get kbest list
KBestList* kb = observer.GetKBest();
- // reference
- ref_strs.clear(); ref_ids.clear();
- boost::split( ref_strs, strs[1], boost::is_any_of(" ") );
- register_and_convert( ref_strs, ref_ids );
+
// scoring kbest
- double score = 0;
- //size_t cand_len = 0;
- Scores scores;
+ scores.clear();
+ if ( t > 0 ) ref_ids = ref_ids_buf[sid];
for ( size_t i = 0; i < kb->sents.size(); i++ ) {
NgramCounts counts = make_ngram_counts( ref_ids, kb->sents[i], N );
- /*if ( i == 0 ) {
- global_counts += counts;
- global_hyp_len += kb->sents[i].size();
- global_ref_len += ref_ids.size();
- cand_len = 0;
+ // for approx bleu
+ if ( scorer_str == "approx_bleu" ) {
+ if ( i == 0 ) { // 'context of 1best translations'
+ global_counts += counts;
+ global_hyp_len += kb->sents[i].size();
+ global_ref_len += ref_ids.size();
+ counts.reset();
+ cand_len = 0;
+ } else {
+ cand_len = kb->sents[i].size();
+ }
+ NgramCounts counts_tmp = global_counts + counts;
+ score = scorer( counts_tmp,
+ global_ref_len,
+ global_hyp_len + cand_len, N, bleu_weights );
} else {
+ // other scorers
cand_len = kb->sents[i].size();
+ score = scorer( counts,
+ ref_ids.size(),
+ kb->sents[i].size(), N, bleu_weights );
+ }
+
+ if ( i == 0 ) {
+ acc_1best_score += score;
+ acc_1best_model += kb->scores[i];
}
- score = bleu( global_counts,
- global_ref_len,
- global_hyp_len + cand_len, N );*/
- score = bleu ( counts, ref_ids.size(), kb->sents[i].size(), N );
+
+ // scorer score and model score
ScorePair sp( kb->scores[i], score );
scores.push_back( sp );
- //cout << "'" << TD::GetString( ref_ids ) << "' vs '";
- //cout << TD::GetString( kb->sents[i] ) << "' SCORE=" << score << endl;
- //cout << kb->feats[i] << endl;
- }
- // learner
- SofiaLearner learner;
- learner.Init( sn, kb->feats, scores );
- learner.Update(lambdas);
- //print_FD();
- sn += 1;
+
+ if ( verbose ) {
+ cout << "k=" << i+1 << " '" << TD::GetString( ref_ids ) << "'[ref] vs '";
+ cout << _prec5 << _nopos << TD::GetString( kb->sents[i] ) << "'[hyp]";
+ cout << " [SCORE=" << score << ",model="<< kb->scores[i] << "]" << endl;
+ //cout << kb->feats[i] << endl; this is maybe too verbose
+ }
+ } // Nbest loop
+ if ( verbose ) cout << endl;
+
+ // update weights; FIXME others
+ SofiaUpdater updater;
+ updater.Init( sid, kb->feats, scores );
+ updater.Update( lambdas );
+
+ ++sid;
+
+ } // input loop
+
+ if ( t == 0 ) in_sz = sid; // remember size (lines) of input
+
+ // print some stats
+ double avg_1best_score = acc_1best_score/(double)in_sz;
+ double avg_1best_model = acc_1best_model/(double)in_sz;
+ double avg_1best_score_diff, avg_1best_model_diff;
+ if ( t > 0 ) {
+ avg_1best_score_diff = avg_1best_score - scores_per_iter[t-1][0];
+ avg_1best_model_diff = avg_1best_model - scores_per_iter[t-1][1];
+ } else {
+ avg_1best_score_diff = avg_1best_score;
+ avg_1best_model_diff = avg_1best_model;
}
+ cout << _prec5 << _nopos << "(sanity weights Glue=" << dense_weights[FD::Convert( "Glue" )];
+ cout << " LexEF=" << dense_weights[FD::Convert( "LexEF" )];
+ cout << " LexFE=" << dense_weights[FD::Convert( "LexFE" )] << ")" << endl;
+ cout << " avg score: " << avg_1best_score;
+ cout << _pos << " (" << avg_1best_score_diff << ")" << endl;
+ cout << _nopos << "avg modelscore: " << avg_1best_model;
+ cout << _pos << " (" << avg_1best_model_diff << ")" << endl;
+ vector remember_scores;
+ remember_scores.push_back( avg_1best_score );
+ remember_scores.push_back( avg_1best_model );
+ scores_per_iter.push_back( remember_scores );
+ if ( avg_1best_score > max_score ) {
+ max_score = avg_1best_score;
+ best_t = t;
+ }
+
+ // close open files
+ if ( input_fn != "-" ) input.close();
+ close( grammar_buf );
+ grammar_file.close();
+
+ time ( &end );
+ double time_dif = difftime( end, start );
+ overall_time += time_dif;
+ if ( !quiet ) {
+ cout << _prec2 << _nopos << "(time " << time_dif/60. << " min, ";
+ cout << time_dif/(double)in_sz<< " s/S)" << endl;
+ }
+
+ if ( t+1 != T ) cout << endl;
} // outer loop
- cerr << endl;
- weights.WriteToFile( "output/weights-vanilla", true );
+ unlink( grammar_buf_tmp_fn );
+ if ( !quiet ) cout << endl << "writing weights file '" << cfg["output"].as() << "' ...";
+ weights.WriteToFile( cfg["output"].as(), true );
+ if ( !quiet ) cout << "done" << endl;
+
+ if ( !quiet ) {
+ cout << _prec5 << _nopos << endl << "---" << endl << "Best iteration: ";
+ cout << best_t+1 << " [SCORE '" << scorer_str << "'=" << max_score << "]." << endl;
+ cout << _prec2 << "This took " << overall_time/60. << " min." << endl;
+ }
return 0;
}
diff --git a/dtrain/kbestget.h b/dtrain/kbestget.h
index 6d93d3b7..5247a2be 100644
--- a/dtrain/kbestget.h
+++ b/dtrain/kbestget.h
@@ -1,6 +1,7 @@
#ifndef _DTRAIN_KBESTGET_H_
#define _DTRAIN_KBESTGET_H_
+#include "kbest.h"
namespace dtrain
{
@@ -36,14 +37,15 @@ struct KBestGetter : public DecoderObserver
KBestList* GetKBest() { return &kb; }
void
- GetKBest(int sent_id, const Hypergraph& forest)
+ GetKBest(int sid, const Hypergraph& forest)
{
kb.scores.clear();
kb.sents.clear();
kb.feats.clear();
- KBest::KBestDerivations, ESentenceTraversal> kbest( forest, k_ );
+ // FIXME TODO FIXME TODO
+ KBest::KBestDerivations, ESentenceTraversal, KBest::FilterUnique, prob_t, EdgeProb> kbest( forest, k_ );
for ( size_t i = 0; i < k_; ++i ) {
- const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d =
+ const KBest::KBestDerivations, ESentenceTraversal, KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d =
kbest.LazyKthBest( forest.nodes_.size() - 1, i );
if (!d) break;
kb.sents.push_back( d->yield);
diff --git a/dtrain/learner.h b/dtrain/learner.h
deleted file mode 100644
index 038749e2..00000000
--- a/dtrain/learner.h
+++ /dev/null
@@ -1,96 +0,0 @@
-#ifndef _DTRAIN_LEARNER_H_
-#define _DTRAIN_LEARNER_H_
-
-#include
-#include
-#include