summaryrefslogtreecommitdiff
path: root/mira
diff options
context:
space:
mode:
Diffstat (limited to 'mira')
-rw-r--r--mira/kbest_mira.cc62
1 files changed, 14 insertions, 48 deletions
diff --git a/mira/kbest_mira.cc b/mira/kbest_mira.cc
index 6918a9a1..459a5e6f 100644
--- a/mira/kbest_mira.cc
+++ b/mira/kbest_mira.cc
@@ -32,21 +32,6 @@ namespace po = boost::program_options;
bool invert_score;
boost::shared_ptr<MT19937> rng;
-void SanityCheck(const vector<double>& w) {
- for (int i = 0; i < w.size(); ++i) {
- assert(!isnan(w[i]));
- assert(!isinf(w[i]));
- }
-}
-
-struct FComp {
- const vector<double>& w_;
- FComp(const vector<double>& w) : w_(w) {}
- bool operator()(int a, int b) const {
- return fabs(w_[a]) > fabs(w_[b]);
- }
-};
-
void RandomPermutation(int len, vector<int>* p_ids) {
vector<int>& ids = *p_ids;
ids.resize(len);
@@ -58,21 +43,6 @@ void RandomPermutation(int len, vector<int>* p_ids) {
}
}
-void ShowLargestFeatures(const vector<double>& w) {
- vector<int> fnums(w.size());
- for (int i = 0; i < w.size(); ++i)
- fnums[i] = i;
- vector<int>::iterator mid = fnums.begin();
- mid += (w.size() > 10 ? 10 : w.size());
- partial_sort(fnums.begin(), mid, fnums.end(), FComp(w));
- cerr << "TOP FEATURES:";
- --mid;
- for (vector<int>::iterator i = fnums.begin(); i != mid; ++i) {
- cerr << ' ' << FD::Convert(*i) << '=' << w[*i];
- }
- cerr << endl;
-}
-
bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
po::options_description opts("Configuration options");
opts.add_options()
@@ -209,14 +179,16 @@ int main(int argc, char** argv) {
cerr << "Mismatched number of references (" << ds.size() << ") and sources (" << corpus.size() << ")\n";
return 1;
}
- // load initial weights
- Weights weights;
- weights.InitFromFile(conf["input_weights"].as<string>());
- SparseVector<double> lambdas;
- weights.InitSparseVector(&lambdas);
ReadFile ini_rf(conf["decoder_config"].as<string>());
Decoder decoder(ini_rf.stream());
+
+ // load initial weights
+ vector<weight_t>& dense_weights = decoder.CurrentWeightVector();
+ SparseVector<weight_t> lambdas;
+ Weights::InitFromFile(conf["input_weights"].as<string>(), &dense_weights);
+ Weights::InitSparseVector(dense_weights, &lambdas);
+
const double max_step_size = conf["max_step_size"].as<double>();
const double mt_metric_scale = conf["mt_metric_scale"].as<double>();
@@ -230,7 +202,6 @@ int main(int argc, char** argv) {
double tot_loss = 0;
int dots = 0;
int cur_pass = 0;
- vector<double> dense_weights;
SparseVector<double> tot;
tot += lambdas; // initial weights
normalizer++; // count for initial weights
@@ -240,27 +211,22 @@ int main(int argc, char** argv) {
vector<int> order;
RandomPermutation(corpus.size(), &order);
while (lcount <= max_iteration) {
- dense_weights.clear();
- weights.InitFromVector(lambdas);
- weights.InitVector(&dense_weights);
- decoder.SetWeights(dense_weights);
+ lambdas.init_vector(&dense_weights);
if ((cur_sent * 40 / corpus.size()) > dots) { ++dots; cerr << '.'; }
if (corpus.size() == cur_sent) {
cerr << " [AVG METRIC LAST PASS=" << (tot_loss / corpus.size()) << "]\n";
- ShowLargestFeatures(dense_weights);
+ Weights::ShowLargestFeatures(dense_weights);
cur_sent = 0;
tot_loss = 0;
dots = 0;
ostringstream os;
os << "weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << ".gz";
- weights.WriteToFile(os.str(), true, &msg);
SparseVector<double> x = tot;
x /= normalizer;
ostringstream sa;
sa << "weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << "-avg.gz";
- Weights ww;
- ww.InitFromVector(x);
- ww.WriteToFile(sa.str(), true, &msga);
+ x.init_vector(&dense_weights);
+ Weights::WriteToFile(os.str(), dense_weights, true, &msg);
++cur_pass;
RandomPermutation(corpus.size(), &order);
}
@@ -294,11 +260,11 @@ int main(int argc, char** argv) {
++cur_sent;
}
cerr << endl;
- weights.WriteToFile("weights.mira-final.gz", true, &msg);
+ Weights::WriteToFile("weights.mira-final.gz", dense_weights, true, &msg);
tot /= normalizer;
- weights.InitFromVector(tot);
+ tot.init_vector(dense_weights);
msg = "# MIRA tuned weights (averaged vector)";
- weights.WriteToFile("weights.mira-final-avg.gz", true, &msg);
+ Weights::WriteToFile("weights.mira-final-avg.gz", dense_weights, true, &msg);
cerr << "Optimization complete.\nAVERAGED WEIGHTS: weights.mira-final-avg.gz\n";
return 0;
}