summaryrefslogtreecommitdiff
path: root/patches/mira_hack_fix_wp.patch
blob: dfe9965d7ff99cb2b364da41c44702179bbe00a4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
diff --git a/training/mira/kbest_cut_mira.cc b/training/mira/kbest_cut_mira.cc
index 724b185..1224f0c 100644
--- a/training/mira/kbest_cut_mira.cc
+++ b/training/mira/kbest_cut_mira.cc
@@ -750,6 +750,7 @@ int main(int argc, char** argv) {
       }
       // Regular mode or LEARN line from stream mode
       //TODO: allow batch updating
+      lambdas[FD::Convert("WordPenalty")] = -1.0; // HACK WP
       lambdas.init_vector(&dense_weights);
       dense_w_local = dense_weights;
       decoder.SetId(cur_sent);
@@ -781,7 +782,8 @@ int main(int argc, char** argv) {
       acc_f->PlusEquals(*fear_sentscore);
       
       if(optimizer == 4) { //passive-aggresive update (single dual coordinate step)
-      
+
+    dense_weights[FD::Convert("WordPenalty")] = -1.0; // HACK WP
 	  double margin = cur_bad.features.dot(dense_weights) - cur_good.features.dot(dense_weights);
 	  double mt_loss = (cur_good.mt_metric - cur_bad.mt_metric);
 	  const double loss = margin +  mt_loss;
@@ -927,6 +929,7 @@ int main(int argc, char** argv) {
 
 			//reload weights based on update
 			dense_weights.clear();
+      lambdas[FD::Convert("WordPenalty")] = -1.0; // HACK WP
 			lambdas.init_vector(&dense_weights);
                         if (dense_weights.size() < 500)
                           ShowLargestFeatures(dense_weights);
@@ -1001,6 +1004,7 @@ int main(int argc, char** argv) {
 		ostringstream os;
 		os << weights_dir << "/weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << "." << node_id << ".gz";
 		string msg = "# MIRA tuned weights ||| " + boost::lexical_cast<std::string>(node_id) + " ||| " + boost::lexical_cast<std::string>(lcount);
+    lambdas[FD::Convert("WordPenalty")] = -1.0; // HACK WP
 		lambdas.init_vector(&dense_weights);
 		Weights::WriteToFile(os.str(), dense_weights, true, &msg);