summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- vest/mr_vest_generate_mapper_input.cc 17
1 files changed, 15 insertions, 2 deletions
diff --git a/vest/mr_vest_generate_mapper_input.cc b/vest/mr_vest_generate_mapper_input.cc
index 9e702e2f..5c3e8181 100644
--- a/vest/mr_vest_generate_mapper_input.cc
+++ b/vest/mr_vest_generate_mapper_input.cc
@@ -5,6 +5,7 @@
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
+#include "sampler.h"
#include "filelib.h"
#include "weights.h"
#include "line_optimizer.h"
@@ -16,8 +17,7 @@ namespace po = boost::program_options;
typedef SparseVector<double> Dir;
-typedef RandomNumberGenerator<boost::mt19937> RNG;
-RNG rng;
+MT19937 rng;
struct oracle_directions {
string forest_repository;
@@ -45,6 +45,18 @@ struct oracle_directions {
}
return dir;
}
+ // Appends n directions to dirs; each one is the mean (sum divided by batchsz) of batchsz vectors (*this)[k]. If start_random is true, every k is drawn from rsg (presumably uniform over 0..dev_set_size-1, i.e. sampling src sentences w/ replacement -- confirm against sampler.h); otherwise k runs sequentially 0,1,2,... across all batches until dev_set_size is reached, and is drawn from rsg after that.
+ void add_directions(vector<Dir> &dirs,unsigned n,unsigned batchsz=20,bool start_random=false) {
+ MT19937::IntRNG rsg=rng.inclusive(0,dev_set_size-1); // NOTE(review): dev_set_size-1 is out of range when dev_set_size==0 -- confirm callers guarantee a non-empty dev set
+ unsigned b=0; // running count of vectors consumed so far; NOT reset between directions
+ for(unsigned i=0;i<n;++i) {
+ dirs.push_back(Dir());
+ Dir &d=dirs.back(); // accumulate in place in the element just appended
+ for (unsigned j=0;j<batchsz;++j,++b)
+ d+=(*this)[(start_random || b>=dev_set_size)?rsg():b]; // sequential index while b is still below dev_set_size, random index afterwards (or always, if start_random)
+ d/=(double)batchsz; // turn the sum into an average; NOTE(review): batchsz==0 divides by zero here -- confirm it is never passed
+ }
+ }
};
@@ -135,6 +147,7 @@ int main(int argc, char** argv) {
&axes,
!conf.count("no_primary")
);
+ od.add_directions(axes,conf["oracle_directions"].as<unsigned>(),conf["oracle_batch"].as<unsigned>());
compress_similar(axes,conf["max_similarity"].as<double>());
for (int i = 0; i < od.dev_set_size; ++i)
for (int j = 0; j < axes.size(); ++j)