diff options
author | graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-16 01:56:34 +0000 |
---|---|---|
committer | graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-16 01:56:34 +0000 |
commit | d7d59c4bb81262f1dfece384ec68fa2c25096843 (patch) | |
tree | 5521dc624dc23adeb3bc9d9c8f8fecc7feb57724 /decoder/oracle_bleu.h | |
parent | ff323448416bbfa691a9697ddf3b30a0398fa08a (diff) |
oracle directions
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@276 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'decoder/oracle_bleu.h')
-rwxr-xr-x | decoder/oracle_bleu.h | 79 |
1 files changed, 67 insertions, 12 deletions
diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index 5fef53fd..550f438f 100755 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -37,7 +37,31 @@ struct Translation { out<<pre<<"Viterbi: "<<TD::GetString(sentence)<<"\n"; out<<pre<<"features: "<<features<<std::endl; } + bool is_null() { + return features.size()==0 /* && sentence.size()==0 */; + } + +}; + +struct Oracles { + bool is_null() { + return model.is_null() /* && fear.is_null() && hope.is_null() */; + } + Translation model,fear,hope; + // feature 0 will be the error rate in fear and hope + // move toward hope + WeightVector ModelHopeGradient() { + WeightVector r=hope-model; + r[0]=0; + return r; + } + // move toward hope from fear + WeightVector FearHopeGradient() { + Weightvector r=hope-fear; + r[0]=0; + return r; + } }; @@ -53,6 +77,7 @@ struct OracleBleu { opts->add_options() ("references,R", value<Refs >(), "Translation reference files") ("oracle_loss", value<string>(), "IBM_BLEU_3 (default), IBM_BLEU etc") + ("bleu_weight", value<double>()->default_value(1.), "weight to give the hope/fear loss function vs. model score") ; } int order; @@ -66,17 +91,20 @@ struct OracleBleu { double doc_src_length; void set_oracle_doc_size(int size) { oracle_doc_size=size; - scale_oracle= 1-1./oracle_doc_size;\ + scale_oracle= 1-1./oracle_doc_size; doc_src_length=0; } OracleBleu(int doc_size=10) { set_oracle_doc_size(doc_size); } - boost::shared_ptr<Score> doc_score,sentscore; // made from factory, so we delete them + typedef boost::shared_ptr<Score> ScoreP; + ScoreP doc_score,sentscore; // made from factory, so we delete them + double bleu_weight; void UseConf(boost::program_options::variables_map const& conf) { using namespace std; + bleu_weight=conf["bleu_weight"].as<double>(); set_loss(conf["oracle_loss"].as<string>()); set_refs(conf["references"].as<Refs>()); } @@ -108,21 +136,48 @@ struct OracleBleu { ViterbiFSentence(forest,&srcsent); SentenceMetadata sm(sent_id,Lattice()); //TODO: make reference from refs? sm.SetSourceLength(srcsent.size()); + smeta.SetScore(doc_score.get()); + smeta.SetDocScorer(&ds); + smeta.SetDocLen(doc_src_length); return sm; } - void Rescore(SentenceMetadata & smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0) { - Translation model_trans(forest); - sentscore.reset(ds[smeta.GetSentenceID()]->ScoreCandidate(model_trans.sentence)); + Oracles ComputeOracles(SentenceMetadata & smeta,Hypergraph const& forest,WeightVector const& feature_weights,std::ostream *log=0,unsigned kbest=0,std::string const& forest_output="") { + Oracles r; + int sent_id=smeta.GetSentenceID(); + r.model=Translation(forest); + + if (kbest) DumpKBest("model",sent_id, forest, kbest, true, forest_output); + { + Timer t("Forest Oracle rescoring:"); + Hypergraph oracle_forest; + Rescore(smeta,forest,&oracle_forest,feature_weights,blue_weight,log); + forest.swap(oracle_forest); + } + r.hope=Translation(forest); + if (kbest) DumpKBest("oracle",sent_id, forest, kbest, true, forest_output); + oracle.ReweightBleu(&forest,-blue_weight); + r.fear=Translation(forest); + if (kbest) DumpKBest("negative",sent_id, forest, kbest, true, forest_output); + return r; + } + + ScoreP Score(Sentence const& sentence,int sent_id) { + return ds[sent_id]->ScoreCandidate(sentence); + } + ScoreP Score(Hypergraph const& forest,int sent_id) { + return Score(model_trans(forest).translation,sent_id); + } + + void Rescore(SentenceMetadata & smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0,std::ostream *log=&std::cerr) { + // the sentence bleu stats will get added to doc only if you call IncludeLastScore + sentscore=Score(forest,smeta.GetSentenceID()); if (!doc_score) { doc_score.reset(sentscore->GetOne()); } tmp_src_length = smeta.GetSourceLength(); //TODO: where does this come from? - smeta.SetScore(doc_score.get()); - smeta.SetDocLen(doc_src_length); - smeta.SetDocScorer(&ds); using namespace std; - ModelSet oracle_models(FeatureWeights(bleu_weight,1),vector<FeatureFunction const*>(1,pff.get())); + ModelSet oracle_models(WeightVector(bleu_weight,1),vector<FeatureFunction const*>(1,pff.get())); const IntersectionConfiguration inter_conf_oracle(0, 0); - cerr << "Going to call Apply Model " << endl; + if (log) *log << "Going to call Apply Model " << endl; ApplyModelSet(forest, smeta, oracle_models, @@ -190,10 +245,10 @@ struct OracleBleu { } } - void DumpKBest(boost::program_options::variables_map const& conf,std::string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique) +void DumpKBest(boost::program_options::variables_map const& conf,std::string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const& forest_output) { std::ostringstream kbest_string_stream; - kbest_string_stream << conf["forest_output"].as<std::string>() << "/kbest_"<<suffix<< "." << sent_id; + kbest_string_stream << forest_output << "/kbest_"<<suffix<< "." << sent_id; DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str()); } |