From 538bc2149631e989e4806165632c5460c3514670 Mon Sep 17 00:00:00 2001 From: graehl Date: Fri, 16 Jul 2010 01:57:08 +0000 Subject: oracle refactor, oracle vest directions, sparse_vector git-svn-id: https://ws10smt.googlecode.com/svn/trunk@280 ec762483-ff6d-05da-a07a-a48fb63a330f --- decoder/oracle_bleu.h | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) (limited to 'decoder/oracle_bleu.h') diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index b58117c1..cc19fbca 100755 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -17,6 +17,7 @@ #include "apply_models.h" #include "kbest.h" #include "timing_stats.h" +#include "sentences.h" //TODO: put function impls into .cc //TODO: disentangle @@ -44,7 +45,7 @@ struct Translation { }; -struct Oracles { +struct Oracle { bool is_null() { return model.is_null() /* && fear.is_null() && hope.is_null() */; } @@ -52,13 +53,13 @@ struct Oracles { Translation model,fear,hope; // feature 0 will be the error rate in fear and hope // move toward hope - FeatureVector ModelHopeGradient() { + FeatureVector ModelHopeGradient() const { FeatureVector r=hope.features-model.features; r.set_value(0,0); return r; } // move toward hope from fear - FeatureVector FearHopeGradient() { + FeatureVector FearHopeGradient() const { FeatureVector r=hope.features-fear.features; r.set_value(0,0); return r; @@ -150,9 +151,9 @@ struct OracleBleu { } // destroys forest (replaces it w/ rescored oracle one) - Oracles ComputeOracles(SentenceMetadata & smeta,Hypergraph *forest_in_out,WeightVector const& feature_weights,std::ostream *log=0,unsigned kbest=0,std::string const& forest_output="") { + Oracle ComputeOracle(SentenceMetadata const& smeta,Hypergraph *forest_in_out,WeightVector const& feature_weights,std::ostream *log=0,unsigned kbest=0,std::string const& forest_output="") { Hypergraph &forest=*forest_in_out; - Oracles r; + Oracle r; int sent_id=smeta.GetSentenceID(); r.model=Translation(forest); if (kbest) DumpKBest("model",sent_id, forest, kbest, true, forest_output); @@ -169,23 +170,24 @@ struct OracleBleu { if (kbest) DumpKBest("negative",sent_id, forest, kbest, true, forest_output); return r; } - typedef std::vector Sentence; - void Rescore(SentenceMetadata & smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0,std::ostream *log=&std::cerr) { + void Rescore(SentenceMetadata const& smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0,std::ostream *log=&std::cerr) { // the sentence bleu stats will get added to doc only if you call IncludeLastScore sentscore=GetScore(forest,smeta.GetSentenceID()); if (!doc_score) { doc_score.reset(sentscore->GetOne()); } tmp_src_length = smeta.GetSourceLength(); //TODO: where does this come from? using namespace std; - ModelSet oracle_models(WeightVector(bleu_weight,1),vector(1,pff.get())); - const IntersectionConfiguration inter_conf_oracle(0, 0); + DenseWeightVector w; + feature_weights_=feature_weights; + feature_weights_.set_value(0,bleu_weight); + feature_weights.init_vector(&w); + ModelSet oracle_models(w,vector(1,pff.get())); if (log) *log << "Going to call Apply Model " << endl; ApplyModelSet(forest, smeta, oracle_models, - inter_conf_oracle, + IntersectionConfiguration(exhaustive_t()), dest_forest); - feature_weights_=feature_weights; ReweightBleu(dest_forest,bleu_weight); } @@ -202,7 +204,7 @@ struct OracleBleu { } void ReweightBleu(Hypergraph *dest_forest,double bleu_weight=-1.) { - feature_weights_[0]=bleu_weight; + feature_weights_.set_value(0,bleu_weight); dest_forest->Reweight(feature_weights_); // dest_forest->SortInEdgesByEdgeWeights(); } @@ -227,7 +229,7 @@ struct OracleBleu { kbest.LazyKthBest(forest.nodes_.size() - 1, i); if (!d) break; //calculate score in context of psuedo-doc - Score* sentscore = GetScore(d->yield,sent_id); + ScoreP sentscore = GetScore(d->yield,sent_id); sentscore->PlusEquals(*doc_score,float(1)); float bleu = curr_src_length * sentscore->ComputeScore(); kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " -- cgit v1.2.3