summaryrefslogtreecommitdiff
path: root/vest/mr_vest_generate_mapper_input.cc
diff options
context:
space:
mode:
Diffstat (limited to 'vest/mr_vest_generate_mapper_input.cc')
-rw-r--r--vest/mr_vest_generate_mapper_input.cc75
1 files changed, 48 insertions, 27 deletions
diff --git a/vest/mr_vest_generate_mapper_input.cc b/vest/mr_vest_generate_mapper_input.cc
index e9a5650b..677c0497 100644
--- a/vest/mr_vest_generate_mapper_input.cc
+++ b/vest/mr_vest_generate_mapper_input.cc
@@ -84,16 +84,16 @@ struct oracle_directions {
OracleBleu::AddOptions(&opts);
opts.add_options()
("dev_set_size,s",po::value<unsigned>(&dev_set_size),"[REQD] Development set size (# of parallel sentences)")
- ("forest_repository,r",po::value<string>(),"[REQD] Path to forest repository")
- ("weights,w",po::value<string>(),"[REQD] Current feature weights file")
+ ("forest_repository,r",po::value<string>(&forest_repository),"[REQD] Path to forest repository")
+ ("weights,w",po::value<string>(&weights_file),"[REQD] Current feature weights file")
("optimize_feature,o",po::value<vector<string> >(), "Feature to optimize (if none specified, all weights listed in the weights file will be optimized)")
- ("random_directions,d",po::value<unsigned int>()->default_value(20),"Number of random directions to run the line optimizer in")
+ ("random_directions,d",po::value<unsigned>(&random_directions)->default_value(10),"Number of random directions to run the line optimizer in")
("no_primary,n","don't use the primary (orthogonal each feature alone) directions")
- ("oracle_directions,O",po::value<unsigned>()->default_value(0),"read the forests and choose this many directions based on heading toward a hope max (bleu+modelscore) translation.")
+ ("oracle_directions,O",po::value<unsigned>(&n_oracle)->default_value(0),"read the forests and choose this many directions based on heading toward a hope max (bleu+modelscore) translation.")
("oracle_start_random",po::bool_switch(&start_random),"sample random subsets of dev set for ALL oracle directions, not just those after a sequential run through it")
- ("oracle_batch,b",po::value<unsigned>()->default_value(10),"to produce each oracle direction, sum the 'gradient' over this many sentences")
- ("max_similarity,m",po::value<double>()->default_value(0),"remove directions that are too similar (Tanimoto coeff. less than (1-this)). 0 means don't filter, 1 means only 1 direction allowed?")
- ("fear_to_hope,f","for each of the oracle_directions, also include a direction from fear to hope (as well as origin to hope)")
+ ("oracle_batch,b",po::value<unsigned>(&oracle_batch)->default_value(10),"to produce each oracle direction, sum the 'gradient' over this many sentences")
+ ("max_similarity,m",po::value<double>(&max_similarity)->default_value(0),"remove directions that are too similar (Tanimoto coeff. less than (1-this)). 0 means don't filter, 1 means only 1 direction allowed?")
+ ("fear_to_hope,f",po::bool_switch(&fear_to_hope),"for each of the oracle_directions, also include a direction from fear to hope (as well as origin to hope)")
("help,h", "Help");
po::options_description dcmdline_options;
dcmdline_options.add(opts);
@@ -139,16 +139,20 @@ struct oracle_directions {
oracle.UseConf(conf);
include_primary=!conf.count("no_primary");
+ old_to_hope=!conf.count("no_old_to_hope");
+
if (conf.count("optimize_feature") > 0)
optimize_features=conf["optimize_feature"].as<vector<string> >();
- fear_to_hope=conf.count("fear_to_hope");
- n_random=conf["random_directions"].as<unsigned int>();
- forest_repository=conf["forest_repository"].as<string>();
+
+ // po::value<X>(&var) takes care of below:
+// fear_to_hope=conf.count("fear_to_hope");
+// n_random=conf["random_directions"].as<unsigned int>();
+// forest_repository=conf["forest_repository"].as<string>();
// dev_set_size=conf["dev_set_size"].as<unsigned int>();
- n_oracle=conf["oracle_directions"].as<unsigned>();
- oracle_batch=conf["oracle_batch"].as<unsigned>();
- max_similarity=conf["max_similarity"].as<double>();
- weights_file=conf["weights"].as<string>();
+// n_oracle=conf["oracle_directions"].as<unsigned>();
+// oracle_batch=conf["oracle_batch"].as<unsigned>();
+// max_similarity=conf["max_similarity"].as<double>();
+// weights_file=conf["weights"].as<string>();
Init();
}
@@ -158,7 +162,7 @@ struct oracle_directions {
unsigned n_oracle, oracle_batch;
string forest_repository;
unsigned dev_set_size;
- vector<Dir> dirs; //best_to_hope_dirs
+ vector<Oracle> oracles;
vector<int> fids;
string forest_file(unsigned i) const {
ostringstream o;
@@ -178,6 +182,7 @@ struct oracle_directions {
weights.InitSparseVector(&origin);
fids.clear();
AddFeatureIds(features);
+ oracles.resize(dev_set_size);
}
Weights weights;
@@ -189,26 +194,42 @@ struct oracle_directions {
}
- Dir const& operator[](unsigned i) {
- Dir &dir=dirs[i];
- if (dir.empty()) {
+ //TODO: is it worthwhile to get a complete document bleu first? would take a list of 1best translations one per line from the decoders, rather than loading all the forests (expensive)
+ Oracle const& ComputeOracle(unsigned i) {
+ Oracle &o=oracles[i];
+ if (o.is_null()) {
ReadFile rf(forest_file(i));
- FeatureVector fear,hope,best;
- //TODO: get hope/oracle from vlad. random for now.
- LineOptimizer::RandomUnitVector(fids,&dir,&rng);
+ Hypergraph hg;
+ {
+ Timer t("Loading forest from JSON "+forest_file(i));
+ HypergraphIO::ReadFromJSON(rf.stream(), &hg);
+ }
+ o=oracle.ComputeOracles(MakeMetadata(hg,i),hg,origin,&cerr);
}
- return dir;
+ return o;
}
+
// if start_random is true, immediately sample w/ replacement from src sentences; otherwise, consume them sequentially until exhausted, then random. oracle vectors are summed
void AddOracleDirections() {
MT19937::IntRNG rsg=rng.inclusive(0,dev_set_size-1);
unsigned b=0;
for(unsigned i=0;i<n_oracle;++i) {
- directions.push_back(Dir());
- Dir &d=directions.back();
- for (unsigned j=0;j<oracle_batch;++j,++b)
- d+=(*this)[(start_random || b>=dev_set_size)?rsg():b];
- d/=(double)oracle_batch;
+ Dir o2hope;
+ Dir fear2hope;
+ for (unsigned j=0;j<oracle_batch;++j,++b) {
+ Oracle const& o=ComputeOracle((start_random||b>=dev_set_size) ? rsg() : b);
+
+ o2hope+=o.ModelHopeGradient();
+ if (fear_to_hope)
+ fear2hope+=o.FearHopeGradient();
+ }
+ double N=(double)oracle_batch;
+ o2hope/=N;
+ directions.push_back(o2hope);
+ if (fear_to_hope) {
+ fear2hope/=N;
+ directions.push_back(fear2hope);
+ }
}
}
};