From 6bacbfcbe191ec898e43f4f03e570283b156a8ca Mon Sep 17 00:00:00 2001 From: graehl Date: Fri, 16 Jul 2010 20:08:35 +0000 Subject: vest: oracle_loss argument bugfix git-svn-id: https://ws10smt.googlecode.com/svn/trunk@287 ec762483-ff6d-05da-a07a-a48fb63a330f --- vest/mr_vest_generate_mapper_input.cc | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) (limited to 'vest/mr_vest_generate_mapper_input.cc') diff --git a/vest/mr_vest_generate_mapper_input.cc b/vest/mr_vest_generate_mapper_input.cc index 01e93f61..5e208aa0 100644 --- a/vest/mr_vest_generate_mapper_input.cc +++ b/vest/mr_vest_generate_mapper_input.cc @@ -78,12 +78,7 @@ struct oracle_directions { void AddOptions(po::options_description *opts) { oracle.AddOptions(opts); - } - - void InitCommandLine(int argc, char *argv[], po::variables_map *conf) { - po::options_description opts("Configuration options"); - OracleBleu::AddOptions(&opts); - opts.add_options() + opts-?add_options() ("dev_set_size,s",po::value(&dev_set_size),"[REQD] Development set size (# of parallel sentences)") ("forest_repository,r",po::value(&forest_repository),"[REQD] Path to forest repository") ("weights,w",po::value(&weights_file),"[REQD] Current feature weights file") @@ -96,8 +91,13 @@ struct oracle_directions { ("max_similarity,m",po::value(&max_similarity)->default_value(0),"remove directions that are too similar (Tanimoto coeff. less than (1-this)). 0 means don't filter, 1 means only 1 direction allowed?") ("fear_to_hope,f",po::bool_switch(&fear_to_hope),"for each of the oracle_directions, also include a direction from fear to hope (as well as origin to hope)") ("no_old_to_hope,n","don't emit the usual old -> hope oracle") - ("decoder_translations",po::value(&decoder_translations_file)->default_value(""),"one per line decoder 1best translations for computing document BLEU vs. sentences-seen-so-far BLEU") - ("help,h", "Help"); + ("decoder_translations",po::value(&decoder_translations_file)->default_value(""),"one per line decoder 1best translations for computing document BLEU vs. sentences-seen-so-far BLEU"); + } + void InitCommandLine(int argc, char *argv[], po::variables_map *conf) { + po::options_description opts("Configuration options"); + AddOptions(&opts); + opts.add_options()("help,h", "Help"); + po::options_description dcmdline_options; dcmdline_options.add(opts); po::store(parse_command_line(argc, argv, dcmdline_options), *conf); @@ -176,9 +176,13 @@ struct oracle_directions { oracle_directions() { } Sentences model_hyps; + bool have_doc; void Init() { - if (!decoder_translations_file.empty()) + have_doc=!decoder_translations_file.empty(); + if (have_doc) { model_hyps.Load(decoder_translations_file); + //TODO: compute doc bleu stats for each sentence, then when getting oracle temporarily exclude stats for that sentence (skip regular score updating) + } start_random=false; assert(DirectoryExists(forest_repository)); vector features; @@ -205,6 +209,9 @@ struct oracle_directions { Oracle const& ComputeOracle(unsigned i) { Oracle &o=oracles[i]; if (o.is_null()) { + if (have_doc) { + //TODO: + } ReadFile rf(forest_file(i)); Hypergraph hg; { @@ -212,6 +219,10 @@ struct oracle_directions { HypergraphIO::ReadFromJSON(rf.stream(), &hg); } o=oracle.ComputeOracle(oracle.MakeMetadata(hg,i),&hg,origin,&cerr); + if (have_doc) { + //TODO: + } else + oracle.IncludeLastScore(); } return o; } -- cgit v1.2.3