summary | refs | log | tree | commit | diff
path: root/vest/mr_vest_generate_mapper_input.cc
diff options
context:
space:
mode:
Diffstat (limited to 'vest/mr_vest_generate_mapper_input.cc')
-rw-r--r-- vest/mr_vest_generate_mapper_input.cc | 20
1 files changed, 14 insertions, 6 deletions
diff --git a/vest/mr_vest_generate_mapper_input.cc b/vest/mr_vest_generate_mapper_input.cc
index cbda78c5..01e93f61 100644
--- a/vest/mr_vest_generate_mapper_input.cc
+++ b/vest/mr_vest_generate_mapper_input.cc
@@ -62,6 +62,7 @@ struct oracle_directions {
bool start_random;
bool include_primary;
+ bool old_to_hope;
bool fear_to_hope;
unsigned n_random;
void AddPrimaryAndRandomDirections() {
@@ -87,14 +88,15 @@ struct oracle_directions {
("forest_repository,r",po::value<string>(&forest_repository),"[REQD] Path to forest repository")
("weights,w",po::value<string>(&weights_file),"[REQD] Current feature weights file")
("optimize_feature,o",po::value<vector<string> >(), "Feature to optimize (if none specified, all weights listed in the weights file will be optimized)")
- ("random_directions,d",po::value<unsigned>(&random_directions)->default_value(10),"Number of random directions to run the line optimizer in")
+ ("random_directions,d",po::value<unsigned>(&n_random)->default_value(10),"Number of random directions to run the line optimizer in")
("no_primary,n","don't use the primary (orthogonal each feature alone) directions")
("oracle_directions,O",po::value<unsigned>(&n_oracle)->default_value(0),"read the forests and choose this many directions based on heading toward a hope max (bleu+modelscore) translation.")
("oracle_start_random",po::bool_switch(&start_random),"sample random subsets of dev set for ALL oracle directions, not just those after a sequential run through it")
("oracle_batch,b",po::value<unsigned>(&oracle_batch)->default_value(10),"to produce each oracle direction, sum the 'gradient' over this many sentences")
("max_similarity,m",po::value<double>(&max_similarity)->default_value(0),"remove directions that are too similar (Tanimoto coeff. less than (1-this)). 0 means don't filter, 1 means only 1 direction allowed?")
("fear_to_hope,f",po::bool_switch(&fear_to_hope),"for each of the oracle_directions, also include a direction from fear to hope (as well as origin to hope)")
- ("decoder_translations",po::value<string>(&decoder_translations)->default_value(""),"one per line decoder 1best translations for computing document BLEU vs. sentences-seen-so-far BLEU")
+ ("no_old_to_hope,n","don't emit the usual old -> hope oracle")
+ ("decoder_translations",po::value<string>(&decoder_translations_file)->default_value(""),"one per line decoder 1best translations for computing document BLEU vs. sentences-seen-so-far BLEU")
("help,h", "Help");
po::options_description dcmdline_options;
dcmdline_options.add(opts);
@@ -173,7 +175,10 @@ struct oracle_directions {
oracle_directions() { }
+ Sentences model_hyps;
void Init() {
+ if (!decoder_translations_file.empty())
+ model_hyps.Load(decoder_translations_file);
start_random=false;
assert(DirectoryExists(forest_repository));
vector<string> features;
@@ -206,7 +211,7 @@ struct oracle_directions {
Timer t("Loading forest from JSON "+forest_file(i));
HypergraphIO::ReadFromJSON(rf.stream(), &hg);
}
- o=oracle.ComputeOracles(MakeMetadata(hg,i),&hg,origin,&cerr);
+ o=oracle.ComputeOracle(oracle.MakeMetadata(hg,i),&hg,origin,&cerr);
}
return o;
}
@@ -221,13 +226,16 @@ struct oracle_directions {
for (unsigned j=0;j<oracle_batch;++j,++b) {
Oracle const& o=ComputeOracle((start_random||b>=dev_set_size) ? rsg() : b);
- o2hope+=o.ModelHopeGradient();
+ if (old_to_hope)
+ o2hope+=o.ModelHopeGradient();
if (fear_to_hope)
fear2hope+=o.FearHopeGradient();
}
double N=(double)oracle_batch;
- o2hope/=N;
- directions.push_back(o2hope);
+ if (old_to_hope) {
+ o2hope/=N;
+ directions.push_back(o2hope);
+ }
if (fear_to_hope) {
fear2hope/=N;
directions.push_back(fear2hope);