summaryrefslogtreecommitdiff
path: root/decoder/oracle_bleu.h
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/oracle_bleu.h')
-rwxr-xr-xdecoder/oracle_bleu.h28
1 files changed, 15 insertions, 13 deletions
diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h
index b58117c1..cc19fbca 100755
--- a/decoder/oracle_bleu.h
+++ b/decoder/oracle_bleu.h
@@ -17,6 +17,7 @@
#include "apply_models.h"
#include "kbest.h"
#include "timing_stats.h"
+#include "sentences.h"
//TODO: put function impls into .cc
//TODO: disentangle
@@ -44,7 +45,7 @@ struct Translation {
};
-struct Oracles {
+struct Oracle {
bool is_null() {
return model.is_null() /* && fear.is_null() && hope.is_null() */;
}
@@ -52,13 +53,13 @@ struct Oracles {
Translation model,fear,hope;
// feature 0 will be the error rate in fear and hope
// move toward hope
- FeatureVector ModelHopeGradient() {
+ FeatureVector ModelHopeGradient() const {
FeatureVector r=hope.features-model.features;
r.set_value(0,0);
return r;
}
// move toward hope from fear
- FeatureVector FearHopeGradient() {
+ FeatureVector FearHopeGradient() const {
FeatureVector r=hope.features-fear.features;
r.set_value(0,0);
return r;
@@ -150,9 +151,9 @@ struct OracleBleu {
}
// destroys forest (replaces it w/ rescored oracle one)
- Oracles ComputeOracles(SentenceMetadata & smeta,Hypergraph *forest_in_out,WeightVector const& feature_weights,std::ostream *log=0,unsigned kbest=0,std::string const& forest_output="") {
+ Oracle ComputeOracle(SentenceMetadata const& smeta,Hypergraph *forest_in_out,WeightVector const& feature_weights,std::ostream *log=0,unsigned kbest=0,std::string const& forest_output="") {
Hypergraph &forest=*forest_in_out;
- Oracles r;
+ Oracle r;
int sent_id=smeta.GetSentenceID();
r.model=Translation(forest);
if (kbest) DumpKBest("model",sent_id, forest, kbest, true, forest_output);
@@ -169,23 +170,24 @@ struct OracleBleu {
if (kbest) DumpKBest("negative",sent_id, forest, kbest, true, forest_output);
return r;
}
- typedef std::vector<WordID> Sentence;
- void Rescore(SentenceMetadata & smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0,std::ostream *log=&std::cerr) {
+ void Rescore(SentenceMetadata const& smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0,std::ostream *log=&std::cerr) {
// the sentence bleu stats will get added to doc only if you call IncludeLastScore
sentscore=GetScore(forest,smeta.GetSentenceID());
if (!doc_score) { doc_score.reset(sentscore->GetOne()); }
tmp_src_length = smeta.GetSourceLength(); //TODO: where does this come from?
using namespace std;
- ModelSet oracle_models(WeightVector(bleu_weight,1),vector<FeatureFunction const*>(1,pff.get()));
- const IntersectionConfiguration inter_conf_oracle(0, 0);
+ DenseWeightVector w;
+ feature_weights_=feature_weights;
+ feature_weights_.set_value(0,bleu_weight);
+ feature_weights.init_vector(&w);
+ ModelSet oracle_models(w,vector<FeatureFunction const*>(1,pff.get()));
if (log) *log << "Going to call Apply Model " << endl;
ApplyModelSet(forest,
smeta,
oracle_models,
- inter_conf_oracle,
+ IntersectionConfiguration(exhaustive_t()),
dest_forest);
- feature_weights_=feature_weights;
ReweightBleu(dest_forest,bleu_weight);
}
@@ -202,7 +204,7 @@ struct OracleBleu {
}
void ReweightBleu(Hypergraph *dest_forest,double bleu_weight=-1.) {
- feature_weights_[0]=bleu_weight;
+ feature_weights_.set_value(0,bleu_weight);
dest_forest->Reweight(feature_weights_);
// dest_forest->SortInEdgesByEdgeWeights();
}
@@ -227,7 +229,7 @@ struct OracleBleu {
kbest.LazyKthBest(forest.nodes_.size() - 1, i);
if (!d) break;
//calculate score in context of psuedo-doc
- Score* sentscore = GetScore(d->yield,sent_id);
+ ScoreP sentscore = GetScore(d->yield,sent_id);
sentscore->PlusEquals(*doc_score,float(1));
float bleu = curr_src_length * sentscore->ComputeScore();
kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "