From 538bc2149631e989e4806165632c5460c3514670 Mon Sep 17 00:00:00 2001 From: graehl Date: Fri, 16 Jul 2010 01:57:08 +0000 Subject: oracle refactor, oracle vest directions, sparse_vector git-svn-id: https://ws10smt.googlecode.com/svn/trunk@280 ec762483-ff6d-05da-a07a-a48fb63a330f --- vest/Makefile.am | 2 +- vest/mr_vest_generate_mapper_input.cc | 20 ++++++++++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) (limited to 'vest') diff --git a/vest/Makefile.am b/vest/Makefile.am index 99bd6430..1c797d50 100644 --- a/vest/Makefile.am +++ b/vest/Makefile.am @@ -23,7 +23,7 @@ mbr_kbest_LDADD = $(top_srcdir)/decoder/libcdec.a -lz fast_score_SOURCES = fast_score.cc ter.cc comb_scorer.cc aer_scorer.cc scorer.cc viterbi_envelope.cc fast_score_LDADD = $(top_srcdir)/decoder/libcdec.a -lz -mr_vest_generate_mapper_input_SOURCES = mr_vest_generate_mapper_input.cc line_optimizer.cc +mr_vest_generate_mapper_input_SOURCES = mr_vest_generate_mapper_input.cc line_optimizer.cc timing_stats.cc mr_vest_generate_mapper_input_LDADD = $(top_srcdir)/decoder/libcdec.a -lz mr_vest_map_SOURCES = viterbi_envelope.cc error_surface.cc aer_scorer.cc mr_vest_map.cc scorer.cc ter.cc comb_scorer.cc line_optimizer.cc diff --git a/vest/mr_vest_generate_mapper_input.cc b/vest/mr_vest_generate_mapper_input.cc index cbda78c5..01e93f61 100644 --- a/vest/mr_vest_generate_mapper_input.cc +++ b/vest/mr_vest_generate_mapper_input.cc @@ -62,6 +62,7 @@ struct oracle_directions { bool start_random; bool include_primary; + bool old_to_hope; bool fear_to_hope; unsigned n_random; void AddPrimaryAndRandomDirections() { @@ -87,14 +88,15 @@ struct oracle_directions { ("forest_repository,r",po::value(&forest_repository),"[REQD] Path to forest repository") ("weights,w",po::value(&weights_file),"[REQD] Current feature weights file") ("optimize_feature,o",po::value >(), "Feature to optimize (if none specified, all weights listed in the weights file will be optimized)") - ("random_directions,d",po::value(&random_directions)->default_value(10),"Number of random directions to run the line optimizer in") + ("random_directions,d",po::value(&n_random)->default_value(10),"Number of random directions to run the line optimizer in") ("no_primary,n","don't use the primary (orthogonal each feature alone) directions") ("oracle_directions,O",po::value(&n_oracle)->default_value(0),"read the forests and choose this many directions based on heading toward a hope max (bleu+modelscore) translation.") ("oracle_start_random",po::bool_switch(&start_random),"sample random subsets of dev set for ALL oracle directions, not just those after a sequential run through it") ("oracle_batch,b",po::value(&oracle_batch)->default_value(10),"to produce each oracle direction, sum the 'gradient' over this many sentences") ("max_similarity,m",po::value(&max_similarity)->default_value(0),"remove directions that are too similar (Tanimoto coeff. less than (1-this)). 0 means don't filter, 1 means only 1 direction allowed?") ("fear_to_hope,f",po::bool_switch(&fear_to_hope),"for each of the oracle_directions, also include a direction from fear to hope (as well as origin to hope)") - ("decoder_translations",po::value(&decoder_translations)->default_value(""),"one per line decoder 1best translations for computing document BLEU vs. sentences-seen-so-far BLEU") + ("no_old_to_hope,n","don't emit the usual old -> hope oracle") + ("decoder_translations",po::value(&decoder_translations_file)->default_value(""),"one per line decoder 1best translations for computing document BLEU vs. sentences-seen-so-far BLEU") ("help,h", "Help"); po::options_description dcmdline_options; dcmdline_options.add(opts); @@ -173,7 +175,10 @@ struct oracle_directions { oracle_directions() { } + Sentences model_hyps; void Init() { + if (!decoder_translations_file.empty()) + model_hyps.Load(decoder_translations_file); start_random=false; assert(DirectoryExists(forest_repository)); vector features; @@ -206,7 +211,7 @@ struct oracle_directions { Timer t("Loading forest from JSON "+forest_file(i)); HypergraphIO::ReadFromJSON(rf.stream(), &hg); } - o=oracle.ComputeOracles(MakeMetadata(hg,i),&hg,origin,&cerr); + o=oracle.ComputeOracle(oracle.MakeMetadata(hg,i),&hg,origin,&cerr); } return o; } @@ -221,13 +226,16 @@ struct oracle_directions { for (unsigned j=0;j=dev_set_size) ? rsg() : b); - o2hope+=o.ModelHopeGradient(); + if (old_to_hope) + o2hope+=o.ModelHopeGradient(); if (fear_to_hope) fear2hope+=o.FearHopeGradient(); } double N=(double)oracle_batch; - o2hope/=N; - directions.push_back(o2hope); + if (old_to_hope) { + o2hope/=N; + directions.push_back(o2hope); + } if (fear_to_hope) { fear2hope/=N; directions.push_back(fear2hope); -- cgit v1.2.3