From 9e35239dd1b4393a320da6c745749500dba8f2b6 Mon Sep 17 00:00:00 2001 From: graehl Date: Mon, 19 Jul 2010 21:33:17 +0000 Subject: shared_ptr for ReadFile and doc_scorer; init ds to GetOne() in oracle git-svn-id: https://ws10smt.googlecode.com/svn/trunk@322 ec762483-ff6d-05da-a07a-a48fb63a330f --- decoder/oracle_bleu.h | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) (limited to 'decoder/oracle_bleu.h') diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index 66d155d3..4800e9c1 100755 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -1,6 +1,8 @@ #ifndef ORACLE_BLEU_H #define ORACLE_BLEU_H +#define DEBUG_ORACLE_BLEU + #include #include #include @@ -20,7 +22,7 @@ #include "sentences.h" //TODO: put function impls into .cc -//TODO: disentangle +//TODO: move Translation into its own .h and use in cdec struct Translation { typedef std::vector Sentence; Sentence sentence; @@ -153,11 +155,16 @@ struct OracleBleu { init_refs(); } void init_refs() { - if (is_null()) return; + if (is_null()) { +#ifdef DEBUG_ORACLE_BLEU + std::cerr<<"No references for oracle BLEU.\n"; +#endif + return; + } assert(refs.size()); - ds=DocScorer(loss,refs); - doc_score.reset(); -// doc_score=sentscore + ds.Init(loss,refs); + ensure_doc_score(); +// doc_score.reset(); std::cerr << "Loaded " << ds.size() << " references for scoring with " << StringFromScoreType(loss) << std::endl; } @@ -193,10 +200,15 @@ struct OracleBleu { return r; } + // if doc_score wasn't init, add 1 counts to ngram acc. + void ensure_doc_score() { + if (!doc_score) { doc_score.reset(Score::GetOne(loss)); } + } + void Rescore(SentenceMetadata const& smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0,std::ostream *log=&std::cerr) { // the sentence bleu stats will get added to doc only if you call IncludeLastScore + ensure_doc_score(); sentscore=GetScore(forest,smeta.GetSentenceID()); - if (!doc_score) { doc_score.reset(sentscore->GetOne()); } tmp_src_length = smeta.GetSourceLength(); //TODO: where does this come from? using namespace std; DenseWeightVector w; -- cgit v1.2.3