diff options
-rw-r--r-- | vest/mr_vest_generate_mapper_input.cc | 17 |
1 files changed, 15 insertions, 2 deletions
diff --git a/vest/mr_vest_generate_mapper_input.cc b/vest/mr_vest_generate_mapper_input.cc index 9e702e2f..5c3e8181 100644 --- a/vest/mr_vest_generate_mapper_input.cc +++ b/vest/mr_vest_generate_mapper_input.cc @@ -5,6 +5,7 @@ #include <boost/program_options.hpp> #include <boost/program_options/variables_map.hpp> +#include "sampler.h" #include "filelib.h" #include "weights.h" #include "line_optimizer.h" @@ -16,8 +17,7 @@ namespace po = boost::program_options; typedef SparseVector<double> Dir; -typedef RandomNumberGenerator<boost::mt19937> RNG; -RNG rng; +MT19937 rng; struct oracle_directions { string forest_repository; @@ -45,6 +45,18 @@ struct oracle_directions { } return dir; } + // if start_random is true, immediately sample w/ replacement from src sentences; otherwise, consume them sequentially until exhausted, then random. oracle vectors are summed + void add_directions(vector<Dir> &dirs,unsigned n,unsigned batchsz=20,bool start_random=false) { + MT19937::IntRNG rsg=rng.inclusive(0,dev_set_size-1); + unsigned b=0; + for(unsigned i=0;i<n;++i) { + dirs.push_back(Dir()); + Dir &d=dirs.back(); + for (unsigned j=0;j<batchsz;++j,++b) + d+=(*this)[(start_random || b>=dev_set_size)?rsg():b]; + d/=(double)batchsz; + } + } }; @@ -135,6 +147,7 @@ int main(int argc, char** argv) { &axes, !conf.count("no_primary") ); + od.add_directions(axes,conf["oracle_directions"].as<unsigned>(),conf["oracle_batch"].as<unsigned>()); compress_similar(axes,conf["max_similarity"].as<double>()); for (int i = 0; i < od.dev_set_size; ++i) for (int j = 0; j < axes.size(); ++j) |