From 27ed3c0fecde089a761ccf718748413bb572a3a4 Mon Sep 17 00:00:00 2001 From: graehl Date: Thu, 15 Jul 2010 03:50:05 +0000 Subject: oracle bleu refactor git-svn-id: https://ws10smt.googlecode.com/svn/trunk@259 ec762483-ff6d-05da-a07a-a48fb63a330f --- decoder/cdec.cc | 167 ++++++++++++++++---------------------------------------- 1 file changed, 47 insertions(+), 120 deletions(-) (limited to 'decoder/cdec.cc') diff --git a/decoder/cdec.cc b/decoder/cdec.cc index c15408b5..bec342ef 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -7,6 +7,7 @@ #include #include +#include "oracle_bleu.h" #include "timing_stats.h" #include "translator.h" #include "phrasebased_translator.h" @@ -146,11 +147,11 @@ void InitCommandLine(int argc, char** argv, po::variables_map* confp) { ("crf_uniform_empirical", "If there are multple references use (i.e., lattice) a uniform distribution rather than posterior weighting a la EM") ("get_oracle_forest,o", "Calculate rescored hypregraph using approximate BLEU scoring of rules") ("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)") - ("references,R", po::value >(), "Translation reference files") ("vector_format",po::value()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)") ("combine_size,C",po::value()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)") ("forest_output,O",po::value(),"Directory to write forests to") ("minimal_forests,m","Write minimal forests (excludes Rule information). Such forests can be used for ML/MAP training, but not rescoring, etc."); + OracleBleu::AddOptions(&opts); po::options_description clo("Command line options"); clo.add_options() ("config,c", po::value(), "Configuration file") @@ -260,14 +261,14 @@ void MaxTranslationSample(Hypergraph* hg, const int samples, const int k) { } } -// TODO decoder output should probably be moved to another file -void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, const char *kbest_out_filename_, float doc_src_length, float tmp_src_length, const DocScorer &ds, Score* doc_score) { +// TODO decoder output should probably be moved to another file - how about oracle_bleu.h +void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const&kbest_out_filename_, float doc_src_length, float tmp_src_length, const DocScorer &ds, shared_ptr doc_score) { cerr << "In kbest\n"; ofstream kbest_out; - kbest_out.open(kbest_out_filename_); + kbest_out.open(kbest_out_filename_.c_str()); cerr << "Output kbest to " << kbest_out_filename_; - + //add length (f side) src length of this sentence to the psuedo-doc src length count float curr_src_length = doc_src_length + tmp_src_length; @@ -298,6 +299,15 @@ cerr << "In kbest\n"; } } +void DumpKBest(po::variables_map const& conf,string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique, float doc_src_length, float tmp_src_length, const DocScorer &ds, shared_ptr doc_score) +{ + ostringstream kbest_string_stream; + kbest_string_stream << conf["forest_output"].as() << "/kbest_"<ELength() - e.rule_->Arity(); @@ -517,45 +527,9 @@ int main(int argc, char** argv) { const bool crf_uniform_empirical = conf.count("crf_uniform_empirical"); const bool get_oracle_forest = conf.count("get_oracle_forest"); - /*Oracle Extraction Prep*/ - vector oracle_model_ffs; - vector oracle_feature_weights; - shared_ptr oracle_pff; - if(get_oracle_forest) { - - /*Add feature for oracle rescoring */ - string ff, param; - ff="BLEUModel"; - //pass the location of the references file via param to BLEUModel - for(int kk=0;kk < conf["references"].as >().size();kk++) - { - param = param + " " + conf["references"].as >()[kk]; - } - cerr << "Feature: " << ff << "->" << param << endl; - oracle_pff = global_ff_registry->Create(ff,param); - if (!oracle_pff) { exit(1); } - oracle_model_ffs.push_back(oracle_pff.get()); - oracle_feature_weights.push_back(1.0); - - } - - ModelSet oracle_models(oracle_feature_weights, oracle_model_ffs); - - const string loss_function3 = "IBM_BLEU_3"; - ScoreType type3 = ScoreTypeFromString(loss_function3); - const DocScorer ds(type3, conf["references"].as >(), ""); - cerr << "Loaded " << ds.size() << " references for scoring with " << loss_function3 << endl; - - - std::ostringstream kbest_string_stream; - Score* doc_score=NULL; - float doc_src_length=0; - float tmp_src_length=0; - int oracle_doc_size= 10; //used for scaling/weighting oracle doc - float scale_oracle= 1-float(1)/oracle_doc_size; - - /*End Oracle Extraction Prep*/ - + OracleBleu oracle; + if (get_oracle_forest) + oracle.UseConf(conf); shared_ptr extract_file; if (conf.count("extract_rules")) @@ -671,83 +645,37 @@ int main(int argc, char** argv) { vector trans; ViterbiESentence(forest, &trans); - + /*Oracle Rescoring*/ - if(get_oracle_forest) + if(get_oracle_forest) { + Timer t("Forest Oracle rescoring:"); + + DumpKBest(conf,"model",sent_id, forest, 10, true, oracle.doc_src_length, oracle.tmp_src_length, oracle.ds, oracle.doc_score); + + Translation best(forest); { - Timer t("Forest Oracle rescoring:"); - vector model_trans; - model_trans = trans; - - trans=model_trans; - Score* sentscore = ds[sent_id]->ScoreCandidate(model_trans); - //initilize psuedo-doc vector to 1 counts - if (!doc_score) { doc_score = sentscore->GetOne(); } - double bleu_scale_ = doc_src_length * doc_score->ComputeScore(); - tmp_src_length = smeta.GetSourceLength(); - smeta.SetScore(doc_score); - smeta.SetDocLen(doc_src_length); - smeta.SetDocScorer(&ds); - - feature_weights[0]=1.0; - - kbest_string_stream << conf["forest_output"].as() << "/kbest_model" << "." << sent_id; - DumpKBest(sent_id, forest, 10, true, kbest_string_stream.str().c_str(), doc_src_length, tmp_src_length, ds, doc_score); - kbest_string_stream.str(""); - - - forest.SortInEdgesByEdgeWeights(); - Hypergraph lm_forest; - const IntersectionConfiguration inter_conf_oracle(0, 0); - cerr << "Going to call Apply Model " << endl; - ApplyModelSet(forest, - smeta, - oracle_models, - inter_conf_oracle, - &lm_forest); - - forest.swap(lm_forest); - forest.Reweight(feature_weights); - forest.SortInEdgesByEdgeWeights(); - vector oracle_trans; - - ViterbiESentence(forest, &oracle_trans); - cerr << " +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; - cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl; - cerr << " +Oracle BLEU Viterbi: " << TD::GetString(oracle_trans) << endl; - - //compute kbest for oracle - kbest_string_stream << conf["forest_output"].as() <<"/kbest_oracle" << "." << sent_id; - DumpKBest(sent_id, forest, 10, true, kbest_string_stream.str().c_str(), doc_src_length, tmp_src_length, ds, doc_score); - kbest_string_stream.str(""); - - - //reweight the model with -1 for the BLEU feature to compute k-best list for negative examples - feature_weights[0]=-1.0; - forest.Reweight(feature_weights); - forest.SortInEdgesByEdgeWeights(); - vector neg_trans; - ViterbiESentence(forest, &neg_trans); - cerr << " -Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; - cerr << " -Oracle BLEU (paths): " << forest.NumberOfPaths() << endl; - cerr << " -Oracle BLEU Viterbi: " << TD::GetString(neg_trans) << endl; - - //compute kbest for negative - kbest_string_stream << conf["forest_output"].as() << "/kbest_negative" << "." << sent_id; - DumpKBest(sent_id, forest, 10, true, kbest_string_stream.str().c_str(), doc_src_length, tmp_src_length,ds, doc_score); - kbest_string_stream.str(""); - - //Add 1-best translation (trans) to psuedo-doc vectors - doc_score->PlusEquals(*sentscore, scale_oracle); - delete sentscore; - - doc_src_length = (doc_src_length + tmp_src_length) * scale_oracle; - - - string details; - doc_score->ScoreDetails(&details); - cerr << "SCALED SCORE: " << bleu_scale_ << "DOC BLEU " << doc_score->ComputeScore() << " " <
() : 0); } else { - if (kbest) { - DumpKBest(sent_id, forest, conf["k_best"].as(), unique_kbest,"", doc_src_length, tmp_src_length, ds, doc_score); + DumpKBest(sent_id, forest, conf["k_best"].as(), unique_kbest,"", oracle.doc_src_length,oracle.tmp_src_length, oracle.ds,oracle.doc_score); } else if (csplit_output_plf) { cout << HypergraphIO::AsPLF(forest, false) << endl; } else { -- cgit v1.2.3