-rw-r--r--   decoder/cdec.cc                          2
-rwxr-xr-x   decoder/oracle_bleu.h                   56
-rw-r--r--   decoder/tdict.h                          8
-rw-r--r--   vest/mr_vest_generate_mapper_input.cc    6
4 files changed, 42 insertions, 30 deletions
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index 75c907b1..77179948 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -604,7 +604,7 @@ int main(int argc, char** argv) {
/*Oracle Rescoring*/
if(get_oracle_forest) {
- Oracles o=oracles.ComputeOracles(smeta,forest,feature_weights,&cerr,10,conf["forest_output"].as<std::string>());
+ Oracles o=oracles.ComputeOracles(smeta,&forest,feature_weights,&cerr,10,conf["forest_output"].as<std::string>());
cerr << " +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl;
o.hope.Print(cerr," +Oracle BLEU");
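A minimal sketch of the new calling convention, using only names that appear in the hunk above: ComputeOracles now takes the forest by pointer because it swaps in the rescored oracle forest (see the "destroys forest" note in oracle_bleu.h below), so the caller's forest no longer holds the original model forest afterwards.

    // Sketch: the forest argument is mutated, not merely read.
    Oracles o = oracles.ComputeOracles(smeta, &forest, feature_weights, &cerr,
                                       10, conf["forest_output"].as<std::string>());
    // 'forest' now holds the rescored oracle forest (last reweighted toward fear),
    // which is what the node/edge/path counts printed above report.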
diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h
index 6708b02e..b58117c1 100755
--- a/decoder/oracle_bleu.h
+++ b/decoder/oracle_bleu.h
@@ -16,6 +16,7 @@
#include "sentence_metadata.h"
#include "apply_models.h"
#include "kbest.h"
+#include "timing_stats.h"
//TODO: put function impls into .cc
//TODO: disentangle
@@ -51,15 +52,15 @@ struct Oracles {
Translation model,fear,hope;
// feature 0 will be the error rate in fear and hope
// move toward hope
- WeightVector ModelHopeGradient() {
- WeightVector r=hope-model;
- r[0]=0;
+ FeatureVector ModelHopeGradient() {
+ FeatureVector r=hope.features-model.features;
+ r.set_value(0,0);
return r;
}
// move toward hope from fear
- WeightVector FearHopeGradient() {
- Weightvector r=hope-fear;
- r[0]=0;
+ FeatureVector FearHopeGradient() {
+ FeatureVector r=hope.features-fear.features;
+ r.set_value(0,0);
return r;
}
};
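A short, hedged usage sketch for the corrected gradient methods, assuming only what the hunk shows (FeatureVector supports operator- and set_value): both directions zero feature 0, which the comment above reserves for the error rate in fear and hope, so that pseudo-feature never picks up a learned weight.

    // Given an Oracles result 'o' (see ComputeOracles below):
    FeatureVector toward_hope  = o.ModelHopeGradient();  // hope.features - model.features
    FeatureVector fear_to_hope = o.FearHopeGradient();   // hope.features - fear.features
    // A perceptron-style step would then look roughly like
    //   weights += eta * fear_to_hope;   // sparse-vector arithmetic assumed
    // with feature 0 already zeroed in the gradient.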
@@ -100,6 +101,12 @@ struct OracleBleu {
typedef boost::shared_ptr<Score> ScoreP;
ScoreP doc_score,sentscore; // made from factory, so we delete them
+ ScoreP GetScore(Sentence const& sentence,int sent_id) {
+ return ScoreP(ds[sent_id]->ScoreCandidate(sentence));
+ }
+ ScoreP GetScore(Hypergraph const& forest,int sent_id) {
+ return GetScore(Translation(forest).sentence,sent_id);
+ }
double bleu_weight;
void UseConf(boost::program_options::variables_map const& conf) {
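The new GetScore overloads centralize candidate scoring against the document scorer. A hedged sketch of how they combine with the running document score, mirroring the k-best hunk further down (doc_score, doc_src_length, and tmp_src_length are the members used elsewhere in this header):

    // Score one forest's Viterbi yield in the context of the pseudo-document so far.
    ScoreP s = GetScore(forest, smeta.GetSentenceID());
    s->PlusEquals(*doc_score, 1.0f);  // fold in the document-so-far sufficient statistics
    float bleu = (doc_src_length + tmp_src_length) * s->ComputeScore();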
@@ -134,44 +141,39 @@ struct OracleBleu {
SentenceMetadata MakeMetadata(Hypergraph const& forest,int sent_id) {
std::vector<WordID> srcsent;
ViterbiFSentence(forest,&srcsent);
- SentenceMetadata sm(sent_id,Lattice()); //TODO: make reference from refs?
- sm.SetSourceLength(srcsent.size());
+ SentenceMetadata smeta(sent_id,Lattice()); //TODO: make reference from refs?
+ smeta.SetSourceLength(srcsent.size());
smeta.SetScore(doc_score.get());
smeta.SetDocScorer(&ds);
smeta.SetDocLen(doc_src_length);
- return sm;
+ return smeta;
}
- Oracles ComputeOracles(SentenceMetadata & smeta,Hypergraph const& forest,WeightVector const& feature_weights,std::ostream *log=0,unsigned kbest=0,std::string const& forest_output="") {
+ // destroys forest (replaces it w/ rescored oracle one)
+ Oracles ComputeOracles(SentenceMetadata & smeta,Hypergraph *forest_in_out,WeightVector const& feature_weights,std::ostream *log=0,unsigned kbest=0,std::string const& forest_output="") {
+ Hypergraph &forest=*forest_in_out;
Oracles r;
int sent_id=smeta.GetSentenceID();
r.model=Translation(forest);
-
if (kbest) DumpKBest("model",sent_id, forest, kbest, true, forest_output);
{
Timer t("Forest Oracle rescoring:");
Hypergraph oracle_forest;
- Rescore(smeta,forest,&oracle_forest,feature_weights,blue_weight,log);
+ Rescore(smeta,forest,&oracle_forest,feature_weights,bleu_weight,log);
forest.swap(oracle_forest);
}
r.hope=Translation(forest);
if (kbest) DumpKBest("oracle",sent_id, forest, kbest, true, forest_output);
- oracle.ReweightBleu(&forest,-blue_weight);
+ ReweightBleu(&forest,-bleu_weight);
r.fear=Translation(forest);
if (kbest) DumpKBest("negative",sent_id, forest, kbest, true, forest_output);
return r;
}
-
- ScoreP Score(Sentence const& sentence,int sent_id) {
- return ds[sent_id]->ScoreCandidate(sentence);
- }
- ScoreP Score(Hypergraph const& forest,int sent_id) {
- return Score(model_trans(forest).translation,sent_id);
- }
+ typedef std::vector<WordID> Sentence;
void Rescore(SentenceMetadata & smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0,std::ostream *log=&std::cerr) {
// the sentence bleu stats will get added to doc only if you call IncludeLastScore
- sentscore=Score(forest,smeta.GetSentenceID());
+ sentscore=GetScore(forest,smeta.GetSentenceID());
if (!doc_score) { doc_score.reset(sentscore->GetOne()); }
tmp_src_length = smeta.GetSourceLength(); //TODO: where does this come from?
using namespace std;
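One caller-facing consequence of the in-place swap above, stated as a hedged usage note (it assumes Hypergraph is copyable, which this diff does not show): a caller that still needs the model forest after oracle rescoring must copy it first.

    Hypergraph model_forest = forest;  // keep the original (copyability assumed)
    Oracles r = oracle.ComputeOracles(smeta, &forest, feature_weights, &cerr);
    // r.model, r.hope, and r.fear are the Viterbi translations under the plain
    // feature weights, with BLEU at +bleu_weight, and with BLEU at -bleu_weight,
    // respectively; 'forest' is left holding the fear-reweighted oracle forest.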
@@ -219,13 +221,13 @@ struct OracleBleu {
float curr_src_length = doc_src_length + tmp_src_length;
if (unique) {
- KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique> kbest(forest, k);
+ KBest::KBestDerivations<Sentence, ESentenceTraversal, KBest::FilterUnique> kbest(forest, k);
for (int i = 0; i < k; ++i) {
- const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique>::Derivation* d =
+ const KBest::KBestDerivations<Sentence, ESentenceTraversal, KBest::FilterUnique>::Derivation* d =
kbest.LazyKthBest(forest.nodes_.size() - 1, i);
if (!d) break;
//calculate score in context of psuedo-doc
- Score* sentscore = ds[sent_id]->ScoreCandidate(d->yield);
+ ScoreP sentscore = GetScore(d->yield,sent_id);
sentscore->PlusEquals(*doc_score,float(1));
float bleu = curr_src_length * sentscore->ComputeScore();
kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
@@ -234,9 +236,9 @@ struct OracleBleu {
// << d->feature_values << " ||| " << log(d->score) << endl;
}
} else {
- KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k);
+ KBest::KBestDerivations<Sentence, ESentenceTraversal> kbest(forest, k);
for (int i = 0; i < k; ++i) {
- const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ const KBest::KBestDerivations<Sentence, ESentenceTraversal>::Derivation* d =
kbest.LazyKthBest(forest.nodes_.size() - 1, i);
if (!d) break;
cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
@@ -245,7 +247,7 @@ struct OracleBleu {
}
}
-void DumpKBest(boost::program_options::variables_map const& conf,std::string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const& forest_output)
+void DumpKBest(std::string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const& forest_output)
{
std::ostringstream kbest_string_stream;
kbest_string_stream << forest_output << "/kbest_"<<suffix<< "." << sent_id;
diff --git a/decoder/tdict.h b/decoder/tdict.h
index fd77543d..1fba5179 100644
--- a/decoder/tdict.h
+++ b/decoder/tdict.h
@@ -28,4 +28,12 @@ struct TD {
static const char* Convert(const WordID& w);
};
+struct ToTD {
+ typedef WordID result_type;
+ result_type operator()(std::string const& t) const {
+ return TD::Convert(t);
+ }
+};
+
+
#endif
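The new ToTD functor simply forwards to TD::Convert; the nested result_type follows the adaptable-functor convention so it can be used where a result type is inspected (boost::bind, for instance). A hedged usage sketch (the to_ids helper below is illustrative, not part of this diff):

    #include <algorithm>
    #include <string>
    #include <vector>
    #include "tdict.h"

    // Map a whitespace-tokenized sentence to WordIDs with std::transform.
    std::vector<WordID> to_ids(std::vector<std::string> const& tokens) {
      std::vector<WordID> ids(tokens.size());
      std::transform(tokens.begin(), tokens.end(), ids.begin(), ToTD());
      return ids;
    }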
diff --git a/vest/mr_vest_generate_mapper_input.cc b/vest/mr_vest_generate_mapper_input.cc
index 677c0497..cbda78c5 100644
--- a/vest/mr_vest_generate_mapper_input.cc
+++ b/vest/mr_vest_generate_mapper_input.cc
@@ -94,6 +94,7 @@ struct oracle_directions {
("oracle_batch,b",po::value<unsigned>(&oracle_batch)->default_value(10),"to produce each oracle direction, sum the 'gradient' over this many sentences")
("max_similarity,m",po::value<double>(&max_similarity)->default_value(0),"remove directions that are too similar (Tanimoto coeff. less than (1-this)). 0 means don't filter, 1 means only 1 direction allowed?")
("fear_to_hope,f",po::bool_switch(&fear_to_hope),"for each of the oracle_directions, also include a direction from fear to hope (as well as origin to hope)")
+ ("decoder_translations",po::value<string>(&decoder_translations)->default_value(""),"one per line decoder 1best translations for computing document BLEU vs. sentences-seen-so-far BLEU")
("help,h", "Help");
po::options_description dcmdline_options;
dcmdline_options.add(opts);
@@ -194,7 +195,8 @@ struct oracle_directions {
}
- //TODO: is it worthwhile to get a complete document bleu first? would take a list of 1best translations one per line from the decoders, rather than loading all the forests (expensive)
+ std::string decoder_translations_file; // one per line
+ //TODO: is it worthwhile to get a complete document bleu first? would take a list of 1best translations one per line from the decoders, rather than loading all the forests (expensive). translations are in run.raw.N.gz - new arg
Oracle const& ComputeOracle(unsigned i) {
Oracle &o=oracles[i];
if (o.is_null()) {
@@ -204,7 +206,7 @@ struct oracle_directions {
Timer t("Loading forest from JSON "+forest_file(i));
HypergraphIO::ReadFromJSON(rf.stream(), &hg);
}
- o=oracle.ComputeOracles(MakeMetadata(hg,i),hg,origin,&cerr);
+ o=oracle.ComputeOracles(MakeMetadata(hg,i),&hg,origin,&cerr);
}
return o;
}
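A hedged sketch of how the new --decoder_translations file might be consumed; no loader is part of this diff, the LoadDecoderTranslations helper and the plain-text reading below are illustrative, and a real reader of the gzipped run.raw.N files mentioned in the TODO would go through a gzip-aware stream (possibly the ReadFile helper used above, if it handles .gz). One 1-best translation per line is tokenized on whitespace and mapped to WordIDs with the new ToTD functor, so a complete document BLEU can be seeded without loading any forests.

    #include <algorithm>
    #include <fstream>
    #include <iterator>
    #include <sstream>
    #include <string>
    #include <vector>
    #include "tdict.h"  // assumes the decoder/ include path, as elsewhere in vest/

    // Illustrative only: read one whitespace-tokenized translation per line.
    std::vector<std::vector<WordID> > LoadDecoderTranslations(std::string const& path) {
      std::vector<std::vector<WordID> > sents;
      std::ifstream in(path.c_str());
      std::string line;
      while (std::getline(in, line)) {
        std::istringstream is(line);
        std::vector<std::string> toks((std::istream_iterator<std::string>(is)),
                                      std::istream_iterator<std::string>());
        std::vector<WordID> ids(toks.size());
        std::transform(toks.begin(), toks.end(), ids.begin(), ToTD());
        sents.push_back(ids);
      }
      return sents;
    }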