// NOTE(review): this file appears to be a damaged copy and cannot compile as-is.
// Its newlines were lost, so whole definitions are run together on a few physical
// lines. Text that sat between '<' and '>' was also stripped:
//   - the #include directives have no targets;
//   - template arguments are missing (e.g. `boost::shared_ptr global_ff_registry`,
//     `vector &dirs`, `po::value(&dev_set_size)`);
//   - some spans vanished entirely (e.g. compress_similar's body runs straight
//     into `add_options()`).
// Two comments now swallow code:
//   - the first physical line starts with `//`, so everything on it is a comment;
//   - the `// one per line` comment in the third line hides the rest of that line.
// Restore the file from version control before editing. The review notes below sit
// above the physical line they refer to.
//TODO: debug segfault when references supplied, null shared_ptr when oracle #include #include #include #include #include #include "sampler.h" #include "filelib.h" #include "weights.h" #include "line_optimizer.h" #include "hg.h" #include "hg_io.h" #include "scorer.h" #include "oracle_bleu.h" #include "ff_bleu.h" const bool DEBUG_ORACLE=true; boost::shared_ptr global_ff_registry; namespace { struct init_ff { init_ff() { global_ff_registry.reset(new FFRegistry); global_ff_registry->Register(new FFFactory); } }; init_ff reg; } using namespace std; namespace po = boost::program_options; typedef SparseVector Dir; typedef Dir Point; void compress_similar(vector &dirs,double min_dist,ostream *log=&cerr,bool avg=true,bool verbose=true) { // return; //TODO: debug if (min_dist<=0) return; double max_s=1.-min_dist; if (log&&verbose) *log<<"max allowed S="< "<add_options() ("dev_set_size,s",po::value(&dev_set_size),"[REQD] Development set size (# of parallel sentences)") ("forest_repository,r",po::value(&forest_repository),"[REQD] Path to forest repository") ("weights,w",po::value(&weights_file),"[REQD] Current feature weights file") ("optimize_feature,o",po::value >(), "Feature to optimize (if none specified, all weights listed in the weights file will be optimized)") ("random_directions,d",po::value(&n_random)->default_value(10),"Number of random directions to run the line optimizer in") ("no_primary,n","don't use the primary (orthogonal each feature alone) directions") ("oracle_directions,O",po::value(&n_oracle)->default_value(0),"read the forests and choose this many directions based on heading toward a hope max (bleu+modelscore) translation.") ("oracle_start_random",po::bool_switch(&start_random),"sample random subsets of dev set for ALL oracle directions, not just those after a sequential run through it") ("oracle_batch,b",po::value(&oracle_batch)->default_value(10),"to produce each oracle direction, sum the 'gradient' over this many sentences") 
// NOTE(review): the next physical line finishes the option list and holds
// InitCommandLine, the member main(), Run() and the `origin` member.
// - InitCommandLine tells the user to give the dev-set size with "-d N". The option
//   is actually declared as "dev_set_size,s"; "-d" is "random_directions".
// - The "help" check comes after the required-option checks. dev_set_size has no
//   default, so `-h` on its own takes the missing-dev_set_size branch and ends in
//   exit(1) instead of exit(0).
// - `oracle.refs.empty()` is tested before UseConf(*conf) calls
//   oracle.UseConf(conf). If oracle.UseConf is what fills `refs` (not visible
//   here — confirm), then every run with -O > 0 is rejected.
// - The "-w " and "-r " hints appear to have lost their stripped <...> argument
//   placeholder.
("max_similarity,m",po::value(&max_similarity)->default_value(0),"remove directions that are too similar (Tanimoto coeff. less than (1-this)). 0 means don't filter, 1 means only 1 direction allowed?") ("fear_to_hope,f",po::bool_switch(&fear_to_hope),"for each of the oracle_directions, also include a direction from fear to hope (as well as origin to hope)") ("no_old_to_hope","don't emit the usual old -> hope oracle") ("decoder_translations",po::value(&decoder_translations_file)->default_value(""),"one per line decoder 1best translations for computing document BLEU vs. sentences-seen-so-far BLEU") ; } void InitCommandLine(int argc, char *argv[], po::variables_map *conf) { po::options_description opts("Configuration options"); AddOptions(&opts); opts.add_options()("help,h", "Help"); po::options_description dcmdline_options; dcmdline_options.add(opts); po::store(parse_command_line(argc, argv, dcmdline_options), *conf); po::notify(*conf); if (conf->count("dev_set_size") == 0) { cerr << "Please specify the size of the development set using -d N\n"; goto bad_cmdline; } if (conf->count("weights") == 0) { cerr << "Please specify the starting-point weights using -w \n"; goto bad_cmdline; } if (conf->count("forest_repository") == 0) { cerr << "Please specify the forest repository location using -r \n"; goto bad_cmdline; } if (n_oracle && oracle.refs.empty()) { cerr<<"Specify references when using oracle directions\n"; goto bad_cmdline; } if (conf->count("help")) { cout << dcmdline_options << endl; exit(0); } UseConf(*conf); verbose=oracle.verbose; return; bad_cmdline: cerr << dcmdline_options << endl; exit(1); } int main(int argc, char *argv[]) { po::variables_map conf; InitCommandLine(argc,argv,&conf); Run(); return 0; } bool verbose; void Run() { AddPrimaryAndRandomDirections(); AddOracleDirections(); compress_similar(directions,max_similarity,&cerr,true,verbose); Print(); } Point origin; // old weights that gave model 1best. 
// NOTE(review): the next physical line holds UseConf, the data members, forest_file,
// the constructor, Init, AddFeatureIds, adjust_doc, ds() and the start of
// ComputeOracle.
// - Init's model_hyps loop is cut off by the stripping (`for (int i=0;iScoreCandidate`).
// - AddFeatureIds is only correct when fids is empty on entry, as it is from Init,
//   which calls fids.clear() first. With k existing entries:
//     * `i` starts at k, but the loop indexes features[i] and stops at
//       features.size();
//     * so features[0..k) are never converted;
//     * and the last k new slots of fids are left default-valued.
//   It also compares a signed int with an unsigned size().
// - adjust_doc() and ds() dereference oracle.doc_score without a null check. If
//   doc_score can be an empty shared_ptr, this may be the "null shared_ptr" in the
//   header TODO — confirm against the oracle type.
vector optimize_features; void UseConf(po::variables_map const& conf) { oracle.UseConf(conf); include_primary=!conf.count("no_primary"); old_to_hope=!conf.count("no_old_to_hope"); if (conf.count("optimize_feature") > 0) optimize_features=conf["optimize_feature"].as >(); Init(); } string weights_file; double max_similarity; unsigned n_oracle, oracle_batch; string forest_repository; unsigned dev_set_size; vector oracles; vector fids; string forest_file(unsigned i) const { ostringstream o; o << forest_repository << '/' << i << ".json.gz"; return o.str(); } oracle_directions() { } Sentences model_hyps; vector model_scores; bool have_doc; void Init() { have_doc=!decoder_translations_file.empty(); if (have_doc) { model_hyps.Load(decoder_translations_file); model_scores.resize(model_hyps.size()); for (int i=0;iScoreCandidate(model_hyps[i]); if (verbose) cerr<<"Before model["<ScoreDetails()<PlusEquals(*model_scores[i]); if (verbose) cerr<<"After model["< features; weights.InitFromFile(weights_file, &features); if (optimize_features.size()) features=optimize_features; weights.InitSparseVector(&origin); fids.clear(); AddFeatureIds(features); oracles.resize(dev_set_size); } Weights weights; void AddFeatureIds(vector const& features) { int i = fids.size(); fids.resize(fids.size()+features.size()); for (; i < features.size(); ++i) fids[i] = FD::Convert(features[i]); } std::string decoder_translations_file; // one per line //TODO: is it worthwhile to get a complete document bleu first? would take a list of 1best translations one per line from the decoders, rather than loading all the forests (expensive). translations are in run.raw.N.gz - new arg void adjust_doc(unsigned i,double scale=1.) { oracle.doc_score->PlusEquals(*model_scores[i],scale); } Score &ds() { return *oracle.doc_score; } Oracle const& ComputeOracle(unsigned i) { Oracle &o=oracles[i]; if (o.is_null()) { if (have_doc) { if (verbose) cerr<<"Before removing i="<ScoreDetails()<=dev_set_size) ? 
// NOTE(review): the next physical line begins mid-expression; the text before
// `rsg() : b)` was stripped. It then:
//   - finishes a loop that sums o.ModelHopeGradient() / o.FearHopeGradient();
//   - divides each sum by oracle_batch and pushes it onto `directions`;
//   - closes the class;
//   - defines the real main(), which forwards to oracle_directions::main.
// Nothing visible rejects `-b 0`. In that case N is 0.0 and the /=N is a divide
// by zero (the loop header is not visible — confirm).
rsg() : b); if (old_to_hope) o2hope+=o.ModelHopeGradient(); if (fear_to_hope) fear2hope+=o.FearHopeGradient(); } double N=(double)oracle_batch; if (old_to_hope) { o2hope/=N; directions.push_back(o2hope); } if (fear_to_hope) { fear2hope/=N; directions.push_back(fear2hope); } } } }; int main(int argc, char** argv) { oracle_directions od; return od.main(argc,argv); }