diff options
Diffstat (limited to 'vest/mr_vest_generate_mapper_input.cc')
-rw-r--r-- | vest/mr_vest_generate_mapper_input.cc | 75 |
1 files changed, 48 insertions, 27 deletions
diff --git a/vest/mr_vest_generate_mapper_input.cc b/vest/mr_vest_generate_mapper_input.cc index e9a5650b..677c0497 100644 --- a/vest/mr_vest_generate_mapper_input.cc +++ b/vest/mr_vest_generate_mapper_input.cc @@ -84,16 +84,16 @@ struct oracle_directions { OracleBleu::AddOptions(&opts); opts.add_options() ("dev_set_size,s",po::value<unsigned>(&dev_set_size),"[REQD] Development set size (# of parallel sentences)") - ("forest_repository,r",po::value<string>(),"[REQD] Path to forest repository") - ("weights,w",po::value<string>(),"[REQD] Current feature weights file") + ("forest_repository,r",po::value<string>(&forest_repository),"[REQD] Path to forest repository") + ("weights,w",po::value<string>(&weights_file),"[REQD] Current feature weights file") ("optimize_feature,o",po::value<vector<string> >(), "Feature to optimize (if none specified, all weights listed in the weights file will be optimized)") - ("random_directions,d",po::value<unsigned int>()->default_value(20),"Number of random directions to run the line optimizer in") + ("random_directions,d",po::value<unsigned>(&random_directions)->default_value(10),"Number of random directions to run the line optimizer in") ("no_primary,n","don't use the primary (orthogonal each feature alone) directions") - ("oracle_directions,O",po::value<unsigned>()->default_value(0),"read the forests and choose this many directions based on heading toward a hope max (bleu+modelscore) translation.") + ("oracle_directions,O",po::value<unsigned>(&n_oracle)->default_value(0),"read the forests and choose this many directions based on heading toward a hope max (bleu+modelscore) translation.") ("oracle_start_random",po::bool_switch(&start_random),"sample random subsets of dev set for ALL oracle directions, not just those after a sequential run through it") - ("oracle_batch,b",po::value<unsigned>()->default_value(10),"to produce each oracle direction, sum the 'gradient' over this many sentences") - ("max_similarity,m",po::value<double>()->default_value(0),"remove directions that are too similar (Tanimoto coeff. less than (1-this)). 0 means don't filter, 1 means only 1 direction allowed?") - ("fear_to_hope,f","for each of the oracle_directions, also include a direction from fear to hope (as well as origin to hope)") + ("oracle_batch,b",po::value<unsigned>(&oracle_batch)->default_value(10),"to produce each oracle direction, sum the 'gradient' over this many sentences") + ("max_similarity,m",po::value<double>(&max_similarity)->default_value(0),"remove directions that are too similar (Tanimoto coeff. less than (1-this)). 0 means don't filter, 1 means only 1 direction allowed?") + ("fear_to_hope,f",po::bool_switch(&fear_to_hope),"for each of the oracle_directions, also include a direction from fear to hope (as well as origin to hope)") ("help,h", "Help"); po::options_description dcmdline_options; dcmdline_options.add(opts); @@ -139,16 +139,20 @@ struct oracle_directions { oracle.UseConf(conf); include_primary=!conf.count("no_primary"); + old_to_hope=!conf.count("no_old_to_hope"); + if (conf.count("optimize_feature") > 0) optimize_features=conf["optimize_feature"].as<vector<string> >(); - fear_to_hope=conf.count("fear_to_hope"); - n_random=conf["random_directions"].as<unsigned int>(); - forest_repository=conf["forest_repository"].as<string>(); + + // po::value<X>(&var) takes care of below: +// fear_to_hope=conf.count("fear_to_hope"); +// n_random=conf["random_directions"].as<unsigned int>(); +// forest_repository=conf["forest_repository"].as<string>(); // dev_set_size=conf["dev_set_size"].as<unsigned int>(); - n_oracle=conf["oracle_directions"].as<unsigned>(); - oracle_batch=conf["oracle_batch"].as<unsigned>(); - max_similarity=conf["max_similarity"].as<double>(); - weights_file=conf["weights"].as<string>(); +// n_oracle=conf["oracle_directions"].as<unsigned>(); +// oracle_batch=conf["oracle_batch"].as<unsigned>(); +// max_similarity=conf["max_similarity"].as<double>(); +// weights_file=conf["weights"].as<string>(); Init(); } @@ -158,7 +162,7 @@ struct oracle_directions { unsigned n_oracle, oracle_batch; string forest_repository; unsigned dev_set_size; - vector<Dir> dirs; //best_to_hope_dirs + vector<Oracle> oracles; vector<int> fids; string forest_file(unsigned i) const { ostringstream o; @@ -178,6 +182,7 @@ struct oracle_directions { weights.InitSparseVector(&origin); fids.clear(); AddFeatureIds(features); + oracles.resize(dev_set_size); } Weights weights; @@ -189,26 +194,42 @@ struct oracle_directions { } - Dir const& operator[](unsigned i) { - Dir &dir=dirs[i]; - if (dir.empty()) { + //TODO: is it worthwhile to get a complete document bleu first? would take a list of 1best translations one per line from the decoders, rather than loading all the forests (expensive) + Oracle const& ComputeOracle(unsigned i) { + Oracle &o=oracles[i]; + if (o.is_null()) { ReadFile rf(forest_file(i)); - FeatureVector fear,hope,best; - //TODO: get hope/oracle from vlad. random for now. - LineOptimizer::RandomUnitVector(fids,&dir,&rng); + Hypergraph hg; + { + Timer t("Loading forest from JSON "+forest_file(i)); + HypergraphIO::ReadFromJSON(rf.stream(), &hg); + } + o=oracle.ComputeOracles(MakeMetadata(hg,i),hg,origin,&cerr); } - return dir; + return o; } + // if start_random is true, immediately sample w/ replacement from src sentences; otherwise, consume them sequentially until exhausted, then random. oracle vectors are summed void AddOracleDirections() { MT19937::IntRNG rsg=rng.inclusive(0,dev_set_size-1); unsigned b=0; for(unsigned i=0;i<n_oracle;++i) { - directions.push_back(Dir()); - Dir &d=directions.back(); - for (unsigned j=0;j<oracle_batch;++j,++b) - d+=(*this)[(start_random || b>=dev_set_size)?rsg():b]; - d/=(double)oracle_batch; + Dir o2hope; + Dir fear2hope; + for (unsigned j=0;j<oracle_batch;++j,++b) { + Oracle const& o=ComputeOracle((start_random||b>=dev_set_size) ? rsg() : b); + + o2hope+=o.ModelHopeGradient(); + if (fear_to_hope) + fear2hope+=o.FearHopeGradient(); + } + double N=(double)oracle_batch; + o2hope/=N; + directions.push_back(o2hope); + if (fear_to_hope) { + fear2hope/=N; + directions.push_back(fear2hope); + } } } }; |