diff options
Diffstat (limited to 'decoder/oracle_bleu.h')
| -rwxr-xr-x | decoder/oracle_bleu.h | 28 | 
1 files changed, 15 insertions, 13 deletions
| diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index b58117c1..cc19fbca 100755 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -17,6 +17,7 @@  #include "apply_models.h"  #include "kbest.h"  #include "timing_stats.h" +#include "sentences.h"  //TODO: put function impls into .cc  //TODO: disentangle @@ -44,7 +45,7 @@ struct Translation {  }; -struct Oracles { +struct Oracle {    bool is_null() {      return model.is_null() /* && fear.is_null() && hope.is_null() */;    } @@ -52,13 +53,13 @@ struct Oracles {    Translation model,fear,hope;    // feature 0 will be the error rate in fear and hope    // move toward hope -  FeatureVector ModelHopeGradient() { +  FeatureVector ModelHopeGradient() const {      FeatureVector r=hope.features-model.features;      r.set_value(0,0);      return r;    }    // move toward hope from fear -  FeatureVector FearHopeGradient() { +  FeatureVector FearHopeGradient() const {      FeatureVector r=hope.features-fear.features;      r.set_value(0,0);      return r; @@ -150,9 +151,9 @@ struct OracleBleu {    }    // destroys forest (replaces it w/ rescored oracle one) -  Oracles ComputeOracles(SentenceMetadata & smeta,Hypergraph *forest_in_out,WeightVector const& feature_weights,std::ostream *log=0,unsigned kbest=0,std::string const& forest_output="") { +  Oracle ComputeOracle(SentenceMetadata const& smeta,Hypergraph *forest_in_out,WeightVector const& feature_weights,std::ostream *log=0,unsigned kbest=0,std::string const& forest_output="") {      Hypergraph &forest=*forest_in_out; -    Oracles r; +    Oracle r;      int sent_id=smeta.GetSentenceID();      r.model=Translation(forest);      if (kbest) DumpKBest("model",sent_id, forest, kbest, true, forest_output); @@ -169,23 +170,24 @@ struct OracleBleu {      if (kbest) DumpKBest("negative",sent_id, forest, kbest, true, forest_output);      return r;    } -  typedef std::vector<WordID> Sentence; -  void Rescore(SentenceMetadata & smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0,std::ostream *log=&std::cerr) { +  void Rescore(SentenceMetadata const& smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0,std::ostream *log=&std::cerr) {      // the sentence bleu stats will get added to doc only if you call IncludeLastScore      sentscore=GetScore(forest,smeta.GetSentenceID());  	if (!doc_score) { doc_score.reset(sentscore->GetOne()); }  	tmp_src_length = smeta.GetSourceLength(); //TODO: where does this come from?      using namespace std; -    ModelSet oracle_models(WeightVector(bleu_weight,1),vector<FeatureFunction const*>(1,pff.get())); -    const IntersectionConfiguration inter_conf_oracle(0, 0); +    DenseWeightVector w; +    feature_weights_=feature_weights; +    feature_weights_.set_value(0,bleu_weight); +    feature_weights.init_vector(&w); +    ModelSet oracle_models(w,vector<FeatureFunction const*>(1,pff.get()));  	if (log) *log << "Going to call Apply Model " << endl;  	ApplyModelSet(forest,                    smeta,                    oracle_models, -                  inter_conf_oracle, +                  IntersectionConfiguration(exhaustive_t()),                    dest_forest); -    feature_weights_=feature_weights;      ReweightBleu(dest_forest,bleu_weight);    } @@ -202,7 +204,7 @@ struct OracleBleu {    }    void ReweightBleu(Hypergraph *dest_forest,double bleu_weight=-1.) { -    feature_weights_[0]=bleu_weight; +    feature_weights_.set_value(0,bleu_weight);  	dest_forest->Reweight(feature_weights_);  //	dest_forest->SortInEdgesByEdgeWeights();    } @@ -227,7 +229,7 @@ struct OracleBleu {            kbest.LazyKthBest(forest.nodes_.size() - 1, i);          if (!d) break;          //calculate score in context of psuedo-doc -        Score* sentscore = GetScore(d->yield,sent_id); +        ScoreP sentscore = GetScore(d->yield,sent_id);          sentscore->PlusEquals(*doc_score,float(1));          float bleu = curr_src_length * sentscore->ComputeScore();          kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " | 
