summary refs log tree commit diff
path: root/decoder/oracle_bleu.h
diff options
context:
space:
mode:
author graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f> 2010-07-16 01:56:34 +0000
committer graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f> 2010-07-16 01:56:34 +0000
commit d7d59c4bb81262f1dfece384ec68fa2c25096843 (patch)
tree 5521dc624dc23adeb3bc9d9c8f8fecc7feb57724 /decoder/oracle_bleu.h
parent ff323448416bbfa691a9697ddf3b30a0398fa08a (diff)
oracle directions
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@276 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'decoder/oracle_bleu.h')
-rwxr-xr-x decoder/oracle_bleu.h 79
1 files changed, 67 insertions, 12 deletions
diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h
index 5fef53fd..550f438f 100755
--- a/decoder/oracle_bleu.h
+++ b/decoder/oracle_bleu.h
@@ -37,7 +37,31 @@ struct Translation {
out<<pre<<"Viterbi: "<<TD::GetString(sentence)<<"\n";
out<<pre<<"features: "<<features<<std::endl;
}
+ // Reports whether this translation is unset: true exactly when features.size() is 0.
+ // The sentence-length check is intentionally left commented out.
+ // const-qualified because it only reads state, so it is callable on const objects.
+ bool is_null() const {
+   return features.size()==0 /* && sentence.size()==0 */;
+ }
+
+};
+
+// Groups the three translations kept per sentence: model, fear and hope.
+struct Oracles {
+ // True when the model translation is null; fear and hope are deliberately not consulted.
+ bool is_null() {
+   return model.is_null() /* && fear.is_null() && hope.is_null() */;
+ }
+ Translation model,fear,hope;
+ // feature 0 will be the error rate in fear and hope
+ // NOTE(review): both gradients subtract Translation objects directly; no operator- for
+ // Translation is visible in this diff. Presumably the feature vectors
+ // (e.g. hope.features-model.features) are what is meant -- confirm.
+
+ // move toward hope: returns hope-model with component 0 zeroed.
+ WeightVector ModelHopeGradient() {
+   WeightVector r=hope-model;
+   r[0]=0;
+   return r;
+ }
+ // move toward hope from fear: returns hope-fear with component 0 zeroed.
+ WeightVector FearHopeGradient() {
+   // type name spelled WeightVector, matching the return type and ModelHopeGradient
+   WeightVector r=hope-fear;
+   r[0]=0;
+   return r;
+ }
+};
@@ -53,6 +77,7 @@ struct OracleBleu {
opts->add_options()
("references,R", value<Refs >(), "Translation reference files")
("oracle_loss", value<string>(), "IBM_BLEU_3 (default), IBM_BLEU etc")
+ ("bleu_weight", value<double>()->default_value(1.), "weight to give the hope/fear loss function vs. model score")
;
}
int order;
@@ -66,17 +91,20 @@ struct OracleBleu {
double doc_src_length;
// Stores size as oracle_doc_size, recomputes scale_oracle as 1-1/oracle_doc_size,
// and resets doc_src_length to 0.
// NOTE(review): a size of 0 makes the 1./oracle_doc_size term a division by zero --
// confirm callers never pass 0.
void set_oracle_doc_size(int size) {
oracle_doc_size=size;
- scale_oracle= 1-1./oracle_doc_size;\
+ scale_oracle= 1-1./oracle_doc_size;
doc_src_length=0;
}
// Constructor: forwards doc_size (10 when omitted) to set_oracle_doc_size.
OracleBleu(int doc_size=10) {
set_oracle_doc_size(doc_size);
}
- boost::shared_ptr<Score> doc_score,sentscore; // made from factory, so we delete them
+ typedef boost::shared_ptr<Score> ScoreP;
+ ScoreP doc_score,sentscore; // made from factory, so we delete them
+ double bleu_weight;
// Configures this object from parsed program options:
//   "bleu_weight" -> stored in the bleu_weight member,
//   "oracle_loss" -> passed to set_loss,
//   "references"  -> passed to set_refs.
// NOTE(review): in the option declarations shown in this diff only "bleu_weight" has a
// default_value; presumably .as<>() fails if "oracle_loss" or "references" was not
// supplied -- confirm that callers check for them first.
void UseConf(boost::program_options::variables_map const& conf) {
using namespace std;
+ bleu_weight=conf["bleu_weight"].as<double>();
set_loss(conf["oracle_loss"].as<string>());
set_refs(conf["references"].as<Refs>());
}
@@ -108,21 +136,48 @@ struct OracleBleu {
ViterbiFSentence(forest,&srcsent);
SentenceMetadata sm(sent_id,Lattice()); //TODO: make reference from refs?
sm.SetSourceLength(srcsent.size());
+ smeta.SetScore(doc_score.get());
+ smeta.SetDocScorer(&ds);
+ smeta.SetDocLen(doc_src_length);
return sm;
}
- void Rescore(SentenceMetadata & smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0) {
- Translation model_trans(forest);
- sentscore.reset(ds[smeta.GetSentenceID()]->ScoreCandidate(model_trans.sentence));
+ // Builds the model, hope and fear translations for the sentence described by smeta.
+ //   smeta           - sentence metadata; its sentence id selects the k-best dump names
+ //   forest          - input forest (read-only); the model translation is taken from it
+ //   feature_weights - forwarded unchanged to Rescore
+ //   log             - optional stream forwarded to Rescore (may be null)
+ //   kbest           - when non-zero, k-best lists are dumped after each stage
+ //   forest_output   - forwarded to DumpKBest
+ // Returns an Oracles whose model/hope/fear members are filled in that order.
+ Oracles ComputeOracles(SentenceMetadata & smeta,Hypergraph const& forest,WeightVector const& feature_weights,std::ostream *log=0,unsigned kbest=0,std::string const& forest_output="") {
+   Oracles r;
+   int sent_id=smeta.GetSentenceID();
+   r.model=Translation(forest);
+
+   // NOTE(review): the DumpKBest overload visible in this diff takes a variables_map as
+   // its first argument; confirm an overload matching these six arguments exists.
+   if (kbest) DumpKBest("model",sent_id, forest, kbest, true, forest_output);
+   // forest is a const reference and cannot be swapped or reweighted in place, so the
+   // rescored forest lives in this local and is used for the hope and fear stages.
+   Hypergraph oracle_forest;
+   {
+     Timer t("Forest Oracle rescoring:");
+     // the member is spelled bleu_weight
+     Rescore(smeta,forest,&oracle_forest,feature_weights,bleu_weight,log);
+   }
+   r.hope=Translation(oracle_forest);
+   if (kbest) DumpKBest("oracle",sent_id, oracle_forest, kbest, true, forest_output);
+   // Negating the weight turns the hope forest into the fear forest.
+   // NOTE(review): no object named `oracle` is declared in the visible code; presumably
+   // this should call this object's own ReweightBleu -- confirm.
+   oracle.ReweightBleu(&oracle_forest,-bleu_weight);
+   r.fear=Translation(oracle_forest);
+   if (kbest) DumpKBest("negative",sent_id, oracle_forest, kbest, true, forest_output);
+   return r;
+ }
+
+ // Scores a candidate sentence with the scorer stored at ds[sent_id] and returns the
+ // result as a ScoreP.
+ // NOTE(review): this member function shares its name with the Score type used in the
+ // ScoreP typedef; confirm the compiler accepts the redeclared meaning inside the class.
+ ScoreP Score(Sentence const& sentence,int sent_id) {
+   // ScoreCandidate presumably yields a raw Score* (the pre-change code handed its result
+   // to shared_ptr::reset). shared_ptr's raw-pointer constructor is explicit, so the
+   // result is wrapped explicitly instead of relying on an implicit conversion.
+   return ScoreP(ds[sent_id]->ScoreCandidate(sentence));
+ }
+ // Scores the sentence of the Translation built from forest, via the
+ // Score(Sentence const&,int) overload.
+ ScoreP Score(Hypergraph const& forest,int sent_id) {
+   // Translation(forest) builds the translation and .sentence is its word sequence;
+   // the earlier spelling referred to a removed local and a member that does not exist.
+   return Score(Translation(forest).sentence,sent_id);
+ }
+
+ void Rescore(SentenceMetadata & smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0,std::ostream *log=&std::cerr) {
+ // the sentence bleu stats will get added to doc only if you call IncludeLastScore
+ sentscore=Score(forest,smeta.GetSentenceID());
if (!doc_score) { doc_score.reset(sentscore->GetOne()); }
tmp_src_length = smeta.GetSourceLength(); //TODO: where does this come from?
- smeta.SetScore(doc_score.get());
- smeta.SetDocLen(doc_src_length);
- smeta.SetDocScorer(&ds);
using namespace std;
- ModelSet oracle_models(FeatureWeights(bleu_weight,1),vector<FeatureFunction const*>(1,pff.get()));
+ ModelSet oracle_models(WeightVector(bleu_weight,1),vector<FeatureFunction const*>(1,pff.get()));
const IntersectionConfiguration inter_conf_oracle(0, 0);
- cerr << "Going to call Apply Model " << endl;
+ if (log) *log << "Going to call Apply Model " << endl;
ApplyModelSet(forest,
smeta,
oracle_models,
@@ -190,10 +245,10 @@ struct OracleBleu {
}
}
// Builds the output path "<forest_output>/kbest_<suffix>.<sent_id>" and forwards to the
// five-argument DumpKBest overload with that path.
// After this change the directory comes from the forest_output parameter; conf is no
// longer read in the body.
- void DumpKBest(boost::program_options::variables_map const& conf,std::string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique)
+void DumpKBest(boost::program_options::variables_map const& conf,std::string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const& forest_output)
{
std::ostringstream kbest_string_stream;
- kbest_string_stream << conf["forest_output"].as<std::string>() << "/kbest_"<<suffix<< "." << sent_id;
+ kbest_string_stream << forest_output << "/kbest_"<<suffix<< "." << sent_id;
DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str());
}