summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mira/kbest_mira.cc32
1 files changed, 27 insertions, 5 deletions
diff --git a/mira/kbest_mira.cc b/mira/kbest_mira.cc
index ae54c807..6918a9a1 100644
--- a/mira/kbest_mira.cc
+++ b/mira/kbest_mira.cc
@@ -23,12 +23,14 @@
#include "fdict.h"
#include "weights.h"
#include "sparse_vector.h"
+#include "sampler.h"
using namespace std;
using boost::shared_ptr;
namespace po = boost::program_options;
bool invert_score;
+boost::shared_ptr<MT19937> rng;
void SanityCheck(const vector<double>& w) {
for (int i = 0; i < w.size(); ++i) {
@@ -45,6 +47,17 @@ struct FComp {
}
};
+void RandomPermutation(int len, vector<int>* p_ids) {
+ vector<int>& ids = *p_ids;
+ ids.resize(len);
+ for (int i = 0; i < len; ++i) ids[i] = i;
+ for (int i = len; i > 0; --i) {
+ int j = rng->next() * i;
+ if (j == i) i--;
+ swap(ids[i-1], ids[j]);
+ }
+}
+
void ShowLargestFeatures(const vector<double>& w) {
vector<int> fnums(w.size());
for (int i = 0; i < w.size(); ++i)
@@ -71,6 +84,7 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("max_step_size,C", po::value<double>()->default_value(0.01), "regularization strength (C)")
("mt_metric_scale,s", po::value<double>()->default_value(1.0), "Amount to scale MT loss function by")
("k_best_size,k", po::value<int>()->default_value(250), "Size of hypothesis list to search for oracles")
+ ("random_seed,S", po::value<uint32_t>(), "Random seed (if not specified, /dev/random will be used)")
("decoder_config,c",po::value<string>(),"Decoder configuration file");
po::options_description clo("Command line options");
clo.add_options()
@@ -176,6 +190,10 @@ int main(int argc, char** argv) {
po::variables_map conf;
if (!InitCommandLine(argc, argv, &conf)) return 1;
+ if (conf.count("random_seed"))
+ rng.reset(new MT19937(conf["random_seed"].as<uint32_t>()));
+ else
+ rng.reset(new MT19937);
vector<string> corpus;
ReadTrainingCorpus(conf["source"].as<string>(), &corpus);
const string metric_name = conf["mt_metric"].as<string>();
@@ -219,6 +237,8 @@ int main(int argc, char** argv) {
int max_iteration = conf["passes"].as<int>() * corpus.size();
string msg = "# MIRA tuned weights";
string msga = "# MIRA tuned weights AVERAGED";
+ vector<int> order;
+ RandomPermutation(corpus.size(), &order);
while (lcount <= max_iteration) {
dense_weights.clear();
weights.InitFromVector(lambdas);
@@ -242,14 +262,16 @@ int main(int argc, char** argv) {
ww.InitFromVector(x);
ww.WriteToFile(sa.str(), true, &msga);
++cur_pass;
- } else if (cur_sent == 0) {
+ RandomPermutation(corpus.size(), &order);
+ }
+ if (cur_sent == 0) {
cerr << "PASS " << (lcount / corpus.size() + 1) << endl;
}
- decoder.SetId(cur_sent);
- decoder.Decode(corpus[cur_sent], &observer); // update oracles
+ decoder.SetId(order[cur_sent]);
+ decoder.Decode(corpus[order[cur_sent]], &observer); // update oracles
const HypothesisInfo& cur_hyp = observer.GetCurrentBestHypothesis();
- const HypothesisInfo& cur_good = *oracles[cur_sent].good;
- const HypothesisInfo& cur_bad = *oracles[cur_sent].bad;
+ const HypothesisInfo& cur_good = *oracles[order[cur_sent]].good;
+ const HypothesisInfo& cur_bad = *oracles[order[cur_sent]].bad;
tot_loss += cur_hyp.mt_metric;
if (!ApproxEqual(cur_hyp.mt_metric, cur_good.mt_metric)) {
const double loss = cur_bad.features.dot(dense_weights) - cur_good.features.dot(dense_weights) +