summary | refs | log | tree | commit | diff
path: root/dtrain/dtrain.cc
diff options
context:
space:
mode:
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r-- dtrain/dtrain.cc 140
1 files changed, 83 insertions, 57 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 4554e417..d58478a8 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -3,6 +3,8 @@
#include "util.h"
#include "sample.h"
+#include "ksampler.h"
+
// boost compression
#include <boost/iostreams/device/file.hpp>
#include <boost/iostreams/filtering_stream.hpp>
@@ -11,6 +13,7 @@
//#include <boost/iostreams/filter/bzip2.hpp>
using namespace boost::iostreams;
+
#ifdef DTRAIN_DEBUG
#include "tests.h"
#endif
@@ -101,6 +104,7 @@ ostream& _prec5( ostream& out ) { return out << setprecision(5); }
int
main( int argc, char** argv )
{
+ cout << setprecision( 5 );
// handle most parameters
po::variables_map cfg;
if ( ! init(argc, argv, &cfg) ) exit(1); // something is wrong
@@ -143,7 +147,9 @@ main( int argc, char** argv )
if ( !quiet )
cout << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl;
Decoder decoder( ini_rf.stream() );
- KBestGetter observer( k, filter_type );
+ //KBestGetter observer( k, filter_type );
+ MT19937 rng;
+ KSampler observer( k, &rng );
// scoring metric/scorer
string scorer_str = cfg["scorer"].as<string>();
@@ -207,11 +213,13 @@ main( int argc, char** argv )
size_t cand_len = 0;
double overall_time = 0.;
- cout << setprecision( 5 );
-
- // for the perceptron
- double eta = 0.5; // TODO as parameter
+ // for the perceptron/SVM; TODO as params
+ double eta = 0.0005;
+ double gamma = 0.01; // -> SVM
lambdas.add_value( FD::Convert("__bias"), 0 );
+
+ // for random sampling
+ srand ( time(NULL) );
for ( size_t t = 0; t < T; t++ ) // T epochs
@@ -284,44 +292,44 @@ main( int argc, char** argv )
weights.InitVector( &dense_weights );
decoder.SetWeights( dense_weights );
- srand ( time(NULL) );
-
- switch ( t ) {
- case 0:
- // handling input
- in_split.clear();
- boost::split( in_split, in, boost::is_any_of("\t") );
- // in_split[0] is id
- //cout << in_split[0] << endl;
- // getting reference
- ref_tok.clear(); ref_ids.clear();
- boost::split( ref_tok, in_split[2], boost::is_any_of(" ") );
- register_and_convert( ref_tok, ref_ids );
- ref_ids_buf.push_back( ref_ids );
- // process and set grammar
- //grammar_buf << in_split[3] << endl;
- grammar_str = boost::replace_all_copy( in_split[3], " __NEXT__RULE__ ", "\n" ) + "\n"; // FIXME copy, __
- grammar_buf << grammar_str << DTRAIN_GRAMMAR_DELIM << endl;
- decoder.SetSentenceGrammarFromString( grammar_str );
- // decode, kbest
- src_str_buf.push_back( in_split[1] );
- decoder.Decode( in_split[1], &observer );
- break;
- default:
- // get buffered grammar
- grammar_str.clear();
- int i = 1;
- while ( true ) {
- string g;
- getline( grammar_buf_in, g );
- if ( g == DTRAIN_GRAMMAR_DELIM ) break;
- grammar_str += g+"\n";
- i += 1;
+ if ( t == 0 ) {
+ // handling input
+ in_split.clear();
+ boost::split( in_split, in, boost::is_any_of("\t") ); // in_split[0] is id
+ // getting reference
+ ref_tok.clear(); ref_ids.clear();
+ boost::split( ref_tok, in_split[2], boost::is_any_of(" ") );
+ register_and_convert( ref_tok, ref_ids );
+ ref_ids_buf.push_back( ref_ids );
+ // process and set grammar
+ bool broken_grammar = true;
+ for ( string::iterator ti = in_split[3].begin(); ti != in_split[3].end(); ti++ ) {
+ if ( !isspace(*ti) ) {
+ broken_grammar = false;
+ break;
}
- decoder.SetSentenceGrammarFromString( grammar_str );
- // decode, kbest
- decoder.Decode( src_str_buf[sid], &observer );
- break;
+ }
+ if ( broken_grammar ) continue;
+ grammar_str = boost::replace_all_copy( in_split[3], " __NEXT__RULE__ ", "\n" ) + "\n"; // FIXME copy, __
+ grammar_buf << grammar_str << DTRAIN_GRAMMAR_DELIM << endl;
+ decoder.SetSentenceGrammarFromString( grammar_str );
+ // decode, kbest
+ src_str_buf.push_back( in_split[1] );
+ decoder.Decode( in_split[1], &observer );
+ } else {
+ // get buffered grammar
+ grammar_str.clear();
+ int i = 1;
+ while ( true ) {
+ string g;
+ getline( grammar_buf_in, g );
+ if ( g == DTRAIN_GRAMMAR_DELIM ) break;
+ grammar_str += g+"\n";
+ i += 1;
+ }
+ decoder.SetSentenceGrammarFromString( grammar_str );
+ // decode, kbest
+ decoder.Decode( src_str_buf[sid], &observer );
}
// get kbest list
@@ -346,6 +354,7 @@ main( int argc, char** argv )
cand_len = kb->sents[i].size();
}
NgramCounts counts_tmp = global_counts + counts;
+ // TODO as param
score = 0.9 * scorer( counts_tmp,
global_ref_len,
global_hyp_len + cand_len, N, bleu_weights );
@@ -380,31 +389,48 @@ main( int argc, char** argv )
TrainingInstances pairs;
- sample_all(kb, pairs);
+ sample_all_rand(kb, pairs);
+ cout << pairs.size() << endl;
for ( TrainingInstances::iterator ti = pairs.begin();
ti != pairs.end(); ti++ ) {
- // perceptron
+
SparseVector<double> dv;
- if ( ti->type == -1 ) {
+ if ( ti->first_score - ti->second_score < 0 ) {
dv = ti->second - ti->first;
- } else {
- dv = ti->first - ti->second;
- }
- dv.add_value(FD::Convert("__bias"), -1);
+ //} else {
+ //dv = ti->first - ti->second;
+ //}
+ dv.add_value( FD::Convert("__bias"), -1 );
+
+ SparseVector<double> reg;
+ reg = lambdas * ( 2 * gamma );
+ dv -= reg;
lambdas += dv * eta;
- /*if ( verbose ) {
- cout << "{{ f(i) > f(j) but g(i) < g(j), so update" << endl;
- cout << " i " << TD::GetString(kb->sents[ii]) << endl;
- cout << " " << kb->feats[ii] << endl;
- cout << " j " << TD::GetString(kb->sents[jj]) << endl;
- cout << " " << kb->feats[jj] << endl;
- cout << " dv " << dv << endl;
+ if ( verbose ) {
+ cout << "{{ f("<< ti->first_rank <<") > f(" << ti->second_rank << ") but g(i)="<< ti->first_score <<" < g(j)="<< ti->second_score << " so update" << endl;
+ cout << " i " << TD::GetString(kb->sents[ti->first_rank]) << endl;
+ cout << " " << kb->feats[ti->first_rank] << endl;
+ cout << " j " << TD::GetString(kb->sents[ti->second_rank]) << endl;
+ cout << " " << kb->feats[ti->second_rank] << endl;
+ cout << " diff vec: " << dv << endl;
+ cout << " lambdas after update: " << lambdas << endl;
cout << "}}" << endl;
- }*/
+ }
+
+ } else {
+ //if ( 0 ) {
+ SparseVector<double> reg;
+ reg = lambdas * ( gamma * 2 );
+ lambdas += reg * ( -eta );
+ //}
+ }
}
+ //double l2 = lambdas.l2norm();
+ //if ( l2 ) lambdas /= lambdas.l2norm();
+
}
++sid;