diff options
-rw-r--r-- | mira/kbest_mira.cc | 32 |
1 files changed, 27 insertions, 5 deletions
diff --git a/mira/kbest_mira.cc b/mira/kbest_mira.cc index ae54c807..6918a9a1 100644 --- a/mira/kbest_mira.cc +++ b/mira/kbest_mira.cc @@ -23,12 +23,14 @@ #include "fdict.h" #include "weights.h" #include "sparse_vector.h" +#include "sampler.h" using namespace std; using boost::shared_ptr; namespace po = boost::program_options; bool invert_score; +boost::shared_ptr<MT19937> rng; void SanityCheck(const vector<double>& w) { for (int i = 0; i < w.size(); ++i) { @@ -45,6 +47,17 @@ struct FComp { } }; +void RandomPermutation(int len, vector<int>* p_ids) { + vector<int>& ids = *p_ids; + ids.resize(len); + for (int i = 0; i < len; ++i) ids[i] = i; + for (int i = len; i > 0; --i) { + int j = rng->next() * i; + if (j == i) i--; + swap(ids[i-1], ids[j]); + } +} + void ShowLargestFeatures(const vector<double>& w) { vector<int> fnums(w.size()); for (int i = 0; i < w.size(); ++i) @@ -71,6 +84,7 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("max_step_size,C", po::value<double>()->default_value(0.01), "regularization strength (C)") ("mt_metric_scale,s", po::value<double>()->default_value(1.0), "Amount to scale MT loss function by") ("k_best_size,k", po::value<int>()->default_value(250), "Size of hypothesis list to search for oracles") + ("random_seed,S", po::value<uint32_t>(), "Random seed (if not specified, /dev/random will be used)") ("decoder_config,c",po::value<string>(),"Decoder configuration file"); po::options_description clo("Command line options"); clo.add_options() @@ -176,6 +190,10 @@ int main(int argc, char** argv) { po::variables_map conf; if (!InitCommandLine(argc, argv, &conf)) return 1; + if (conf.count("random_seed")) + rng.reset(new MT19937(conf["random_seed"].as<uint32_t>())); + else + rng.reset(new MT19937); vector<string> corpus; ReadTrainingCorpus(conf["source"].as<string>(), &corpus); const string metric_name = conf["mt_metric"].as<string>(); @@ -219,6 +237,8 @@ int main(int argc, char** argv) { int max_iteration = conf["passes"].as<int>() * corpus.size(); string msg = "# MIRA tuned weights"; string msga = "# MIRA tuned weights AVERAGED"; + vector<int> order; + RandomPermutation(corpus.size(), &order); while (lcount <= max_iteration) { dense_weights.clear(); weights.InitFromVector(lambdas); @@ -242,14 +262,16 @@ int main(int argc, char** argv) { ww.InitFromVector(x); ww.WriteToFile(sa.str(), true, &msga); ++cur_pass; - } else if (cur_sent == 0) { + RandomPermutation(corpus.size(), &order); + } + if (cur_sent == 0) { cerr << "PASS " << (lcount / corpus.size() + 1) << endl; } - decoder.SetId(cur_sent); - decoder.Decode(corpus[cur_sent], &observer); // update oracles + decoder.SetId(order[cur_sent]); + decoder.Decode(corpus[order[cur_sent]], &observer); // update oracles const HypothesisInfo& cur_hyp = observer.GetCurrentBestHypothesis(); - const HypothesisInfo& cur_good = *oracles[cur_sent].good; - const HypothesisInfo& cur_bad = *oracles[cur_sent].bad; + const HypothesisInfo& cur_good = *oracles[order[cur_sent]].good; + const HypothesisInfo& cur_bad = *oracles[order[cur_sent]].bad; tot_loss += cur_hyp.mt_metric; if (!ApproxEqual(cur_hyp.mt_metric, cur_good.mt_metric)) { const double loss = cur_bad.features.dot(dense_weights) - cur_good.features.dot(dense_weights) + |