From f819992b0b22b4fec88c15fe13118aa6b484b91b Mon Sep 17 00:00:00 2001 From: graehl Date: Thu, 15 Jul 2010 03:50:05 +0000 Subject: oracle bleu refactor git-svn-id: https://ws10smt.googlecode.com/svn/trunk@259 ec762483-ff6d-05da-a07a-a48fb63a330f --- decoder/oracle_bleu.h | 154 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) create mode 100755 decoder/oracle_bleu.h (limited to 'decoder/oracle_bleu.h') diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h new file mode 100755 index 00000000..273e14b8 --- /dev/null +++ b/decoder/oracle_bleu.h @@ -0,0 +1,154 @@ +#ifndef ORACLE_BLEU_H +#define ORACLE_BLEU_H + +#include +#include +#include +#include +#include +#include "../vest/scorer.h" +#include "hg.h" +#include "ff_factory.h" +#include "ff_bleu.h" +#include "sparse_vector.h" +#include "viterbi.h" +#include "sentence_metadata.h" +#include "apply_models.h" + +//TODO: put function impls into .cc +//TODO: disentangle +struct Translation { + typedef std::vector Sentence; + Sentence sentence; + FeatureVector features; + Translation() { } + Translation(Hypergraph const& hg,WeightVector *feature_weights=0) + { + Viterbi(hg,feature_weights); + } + void Viterbi(Hypergraph const& hg,WeightVector *feature_weights=0) // weights are only for checking that scoring is correct + { + ViterbiESentence(hg,&sentence); + features=ViterbiFeatures(hg,feature_weights,true); + } + void Print(std::ostream &out,std::string pre=" +Oracle BLEU ") { + out< Refs; + Refs refs; + WeightVector feature_weights_; + DocScorer ds; + + static void AddOptions(boost::program_options::options_description *opts) { + using namespace boost::program_options; + using namespace std; + opts->add_options() + ("references,R", value(), "Translation reference files") + ("oracle_loss", value(), "IBM_BLEU_3 (default), IBM_BLEU etc") + ; + } + int order; + + //TODO: move cdec.cc kbest output files function here + + //TODO: provide for loading most recent translation for every sentence (no more scale.. etc below? it's possible i messed the below up; i assume it's supposed to gracefully figure out the document 1bests as you go, then keep them up to date as you make multiple MIRA passes. provide alternative loading for MERT + double scale_oracle; + int oracle_doc_size; + double tmp_src_length; + double doc_src_length; + void set_oracle_doc_size(int size) { + oracle_doc_size=size; + scale_oracle= 1-1./oracle_doc_size;\ + doc_src_length=0; + } + OracleBleu(int doc_size=10) { + set_oracle_doc_size(doc_size); + } + + boost::shared_ptr doc_score,sentscore; // made from factory, so we delete them + + void UseConf(boost::program_options::variables_map const& conf) { + using namespace std; + set_loss(conf["oracle_loss"].as()); + set_refs(conf["references"].as()); + } + + ScoreType loss; +// std::string loss_name; + boost::shared_ptr pff; + + void set_loss(std::string const& lossd="IBM_BLEU_3") { +// loss_name=lossd; + loss=ScoreTypeFromString(lossd); + order=(loss==IBM_BLEU_3)?3:4; + std::ostringstream param; + param<<"-o "<Create("BLEUModel",param.str()); + } + + void set_refs(Refs const& r) { + refs=r; + assert(refs.size()); + ds=DocScorer(loss,refs); + doc_score.reset(); +// doc_score=sentscore + std::cerr << "Loaded " << ds.size() << " references for scoring with " << StringFromScoreType(loss) << std::endl; + } + + SentenceMetadata MakeMetadata(Hypergraph const& forest,int sent_id) { + std::vector srcsent; + ViterbiFSentence(forest,&srcsent); + SentenceMetadata sm(sent_id,Lattice()); //TODO: make reference from refs? + sm.SetSourceLength(srcsent.size()); + return sm; + } + + void Rescore(SentenceMetadata & smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0) { + Translation model_trans(forest); + sentscore.reset(ds[smeta.GetSentenceID()]->ScoreCandidate(model_trans.sentence)); + if (!doc_score) { doc_score.reset(sentscore->GetOne()); } + tmp_src_length = smeta.GetSourceLength(); //TODO: where does this come from? + smeta.SetScore(doc_score.get()); + smeta.SetDocLen(doc_src_length); + smeta.SetDocScorer(&ds); + using namespace std; + ModelSet oracle_models(FeatureWeights(bleu_weight,1),vector(1,pff.get())); + const IntersectionConfiguration inter_conf_oracle(0, 0); + cerr << "Going to call Apply Model " << endl; + ApplyModelSet(forest, + smeta, + oracle_models, + inter_conf_oracle, + dest_forest); + feature_weights_=feature_weights; + ReweightBleu(dest_forest,bleu_weight); + } + + void IncludeLastScore(std::ostream *out=0) { + double bleu_scale_ = doc_src_length * doc_score->ComputeScore(); + doc_score->PlusEquals(*sentscore, scale_oracle); + sentscore.reset(); + doc_src_length = (doc_src_length + tmp_src_length) * scale_oracle; + if (out) { + std::string d; + doc_score->ScoreDetails(&d); + *out << "SCALED SCORE: " << bleu_scale_ << "DOC BLEU " << doc_score->ComputeScore() << " " <Reweight(feature_weights_); +// dest_forest->SortInEdgesByEdgeWeights(); + } + +}; + + +#endif -- cgit v1.2.3