summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mira/kbest_mira.cc23
1 files changed, 18 insertions, 5 deletions
diff --git a/mira/kbest_mira.cc b/mira/kbest_mira.cc
index 60703273..ae54c807 100644
--- a/mira/kbest_mira.cc
+++ b/mira/kbest_mira.cc
@@ -53,6 +53,7 @@ void ShowLargestFeatures(const vector<double>& w) {
mid += (w.size() > 10 ? 10 : w.size());
partial_sort(fnums.begin(), mid, fnums.end(), FComp(w));
cerr << "TOP FEATURES:";
+ --mid;
for (vector<int>::iterator i = fnums.begin(); i != mid; ++i) {
cerr << ' ' << FD::Convert(*i) << '=' << w[*i];
}
@@ -207,15 +208,17 @@ int main(int argc, char** argv) {
TrainingObserver observer(conf["k_best_size"].as<int>(), ds, &oracles);
int cur_sent = 0;
int lcount = 0;
+ int normalizer = 0;
double tot_loss = 0;
int dots = 0;
int cur_pass = 0;
vector<double> dense_weights;
SparseVector<double> tot;
tot += lambdas; // initial weights
- lcount++; // count for initial weights
+ normalizer++; // count for initial weights
int max_iteration = conf["passes"].as<int>() * corpus.size();
string msg = "# MIRA tuned weights";
+ string msga = "# MIRA tuned weights AVERAGED";
while (lcount <= max_iteration) {
dense_weights.clear();
weights.InitFromVector(lambdas);
@@ -223,16 +226,25 @@ int main(int argc, char** argv) {
decoder.SetWeights(dense_weights);
if ((cur_sent * 40 / corpus.size()) > dots) { ++dots; cerr << '.'; }
if (corpus.size() == cur_sent) {
- cur_sent = 0;
cerr << " [AVG METRIC LAST PASS=" << (tot_loss / corpus.size()) << "]\n";
+ ShowLargestFeatures(dense_weights);
+ cur_sent = 0;
tot_loss = 0;
dots = 0;
ostringstream os;
os << "weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << ".gz";
weights.WriteToFile(os.str(), true, &msg);
+ SparseVector<double> x = tot;
+ x /= normalizer;
+ ostringstream sa;
+ sa << "weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << "-avg.gz";
+ Weights ww;
+ ww.InitFromVector(x);
+ ww.WriteToFile(sa.str(), true, &msga);
++cur_pass;
+ } else if (cur_sent == 0) {
+ cerr << "PASS " << (lcount / corpus.size() + 1) << endl;
}
- if (cur_sent == 0) { cerr << "PASS " << (lcount / corpus.size() + 1) << endl << lambdas << endl; }
decoder.SetId(cur_sent);
decoder.Decode(corpus[cur_sent], &observer); // update oracles
const HypothesisInfo& cur_hyp = observer.GetCurrentBestHypothesis();
@@ -255,16 +267,17 @@ int main(int argc, char** argv) {
}
}
tot += lambdas;
+ ++normalizer;
++lcount;
++cur_sent;
}
cerr << endl;
weights.WriteToFile("weights.mira-final.gz", true, &msg);
- tot /= lcount;
+ tot /= normalizer;
weights.InitFromVector(tot);
msg = "# MIRA tuned weights (averaged vector)";
weights.WriteToFile("weights.mira-final-avg.gz", true, &msg);
- cerr << "Optimization complete.\\AVERAGED WEIGHTS: weights.mira-final-avg.gz\n";
+ cerr << "Optimization complete.\nAVERAGED WEIGHTS: weights.mira-final-avg.gz\n";
return 0;
}