diff options
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r-- | dtrain/dtrain.cc | 140 |
1 files changed, 83 insertions, 57 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index 4554e417..d58478a8 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -3,6 +3,8 @@ #include "util.h" #include "sample.h" +#include "ksampler.h" + // boost compression #include <boost/iostreams/device/file.hpp> #include <boost/iostreams/filtering_stream.hpp> @@ -11,6 +13,7 @@ //#include <boost/iostreams/filter/bzip2.hpp> using namespace boost::iostreams; + #ifdef DTRAIN_DEBUG #include "tests.h" #endif @@ -101,6 +104,7 @@ ostream& _prec5( ostream& out ) { return out << setprecision(5); } int main( int argc, char** argv ) { + cout << setprecision( 5 ); // handle most parameters po::variables_map cfg; if ( ! init(argc, argv, &cfg) ) exit(1); // something is wrong @@ -143,7 +147,9 @@ main( int argc, char** argv ) if ( !quiet ) cout << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl; Decoder decoder( ini_rf.stream() ); - KBestGetter observer( k, filter_type ); + //KBestGetter observer( k, filter_type ); + MT19937 rng; + KSampler observer( k, &rng ); // scoring metric/scorer string scorer_str = cfg["scorer"].as<string>(); @@ -207,11 +213,13 @@ main( int argc, char** argv ) size_t cand_len = 0; double overall_time = 0.; - cout << setprecision( 5 ); - - // for the perceptron - double eta = 0.5; // TODO as parameter + // for the perceptron/SVM; TODO as params + double eta = 0.0005; + double gamma = 0.01; // -> SVM lambdas.add_value( FD::Convert("__bias"), 0 ); + + // for random sampling + srand ( time(NULL) ); for ( size_t t = 0; t < T; t++ ) // T epochs @@ -284,44 +292,44 @@ main( int argc, char** argv ) weights.InitVector( &dense_weights ); decoder.SetWeights( dense_weights ); - srand ( time(NULL) ); - - switch ( t ) { - case 0: - // handling input - in_split.clear(); - boost::split( in_split, in, boost::is_any_of("\t") ); - // in_split[0] is id - //cout << in_split[0] << endl; - // getting reference - ref_tok.clear(); ref_ids.clear(); - boost::split( ref_tok, in_split[2], boost::is_any_of(" ") ); - register_and_convert( ref_tok, ref_ids ); - ref_ids_buf.push_back( ref_ids ); - // process and set grammar - //grammar_buf << in_split[3] << endl; - grammar_str = boost::replace_all_copy( in_split[3], " __NEXT__RULE__ ", "\n" ) + "\n"; // FIXME copy, __ - grammar_buf << grammar_str << DTRAIN_GRAMMAR_DELIM << endl; - decoder.SetSentenceGrammarFromString( grammar_str ); - // decode, kbest - src_str_buf.push_back( in_split[1] ); - decoder.Decode( in_split[1], &observer ); - break; - default: - // get buffered grammar - grammar_str.clear(); - int i = 1; - while ( true ) { - string g; - getline( grammar_buf_in, g ); - if ( g == DTRAIN_GRAMMAR_DELIM ) break; - grammar_str += g+"\n"; - i += 1; + if ( t == 0 ) { + // handling input + in_split.clear(); + boost::split( in_split, in, boost::is_any_of("\t") ); // in_split[0] is id + // getting reference + ref_tok.clear(); ref_ids.clear(); + boost::split( ref_tok, in_split[2], boost::is_any_of(" ") ); + register_and_convert( ref_tok, ref_ids ); + ref_ids_buf.push_back( ref_ids ); + // process and set grammar + bool broken_grammar = true; + for ( string::iterator ti = in_split[3].begin(); ti != in_split[3].end(); ti++ ) { + if ( !isspace(*ti) ) { + broken_grammar = false; + break; } - decoder.SetSentenceGrammarFromString( grammar_str ); - // decode, kbest - decoder.Decode( src_str_buf[sid], &observer ); - break; + } + if ( broken_grammar ) continue; + grammar_str = boost::replace_all_copy( in_split[3], " __NEXT__RULE__ ", "\n" ) + "\n"; // FIXME copy, __ + grammar_buf << grammar_str << DTRAIN_GRAMMAR_DELIM << endl; + decoder.SetSentenceGrammarFromString( grammar_str ); + // decode, kbest + src_str_buf.push_back( in_split[1] ); + decoder.Decode( in_split[1], &observer ); + } else { + // get buffered grammar + grammar_str.clear(); + int i = 1; + while ( true ) { + string g; + getline( grammar_buf_in, g ); + if ( g == DTRAIN_GRAMMAR_DELIM ) break; + grammar_str += g+"\n"; + i += 1; + } + decoder.SetSentenceGrammarFromString( grammar_str ); + // decode, kbest + decoder.Decode( src_str_buf[sid], &observer ); } // get kbest list @@ -346,6 +354,7 @@ main( int argc, char** argv ) cand_len = kb->sents[i].size(); } NgramCounts counts_tmp = global_counts + counts; + // TODO as param score = 0.9 * scorer( counts_tmp, global_ref_len, global_hyp_len + cand_len, N, bleu_weights ); @@ -380,31 +389,48 @@ main( int argc, char** argv ) TrainingInstances pairs; - sample_all(kb, pairs); + sample_all_rand(kb, pairs); + cout << pairs.size() << endl; for ( TrainingInstances::iterator ti = pairs.begin(); ti != pairs.end(); ti++ ) { - // perceptron + SparseVector<double> dv; - if ( ti->type == -1 ) { + if ( ti->first_score - ti->second_score < 0 ) { dv = ti->second - ti->first; - } else { - dv = ti->first - ti->second; - } - dv.add_value(FD::Convert("__bias"), -1); + //} else { + //dv = ti->first - ti->second; + //} + dv.add_value( FD::Convert("__bias"), -1 ); + + SparseVector<double> reg; + reg = lambdas * ( 2 * gamma ); + dv -= reg; lambdas += dv * eta; - /*if ( verbose ) { - cout << "{{ f(i) > f(j) but g(i) < g(j), so update" << endl; - cout << " i " << TD::GetString(kb->sents[ii]) << endl; - cout << " " << kb->feats[ii] << endl; - cout << " j " << TD::GetString(kb->sents[jj]) << endl; - cout << " " << kb->feats[jj] << endl; - cout << " dv " << dv << endl; + if ( verbose ) { + cout << "{{ f("<< ti->first_rank <<") > f(" << ti->second_rank << ") but g(i)="<< ti->first_score <<" < g(j)="<< ti->second_score << " so update" << endl; + cout << " i " << TD::GetString(kb->sents[ti->first_rank]) << endl; + cout << " " << kb->feats[ti->first_rank] << endl; + cout << " j " << TD::GetString(kb->sents[ti->second_rank]) << endl; + cout << " " << kb->feats[ti->second_rank] << endl; + cout << " diff vec: " << dv << endl; + cout << " lambdas after update: " << lambdas << endl; cout << "}}" << endl; - }*/ + } + + } else { + //if ( 0 ) { + SparseVector<double> reg; + reg = lambdas * ( gamma * 2 ); + lambdas += reg * ( -eta ); + //} + } } + //double l2 = lambdas.l2norm(); + //if ( l2 ) lambdas /= lambdas.l2norm(); + } ++sid; |