diff options
-rw-r--r-- | decoder/cdec.cc | 2 | ||||
-rwxr-xr-x | decoder/oracle_bleu.h | 56 | ||||
-rw-r--r-- | decoder/tdict.h | 8 | ||||
-rw-r--r-- | vest/mr_vest_generate_mapper_input.cc | 6 |
4 files changed, 42 insertions, 30 deletions
diff --git a/decoder/cdec.cc b/decoder/cdec.cc index 75c907b1..77179948 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -604,7 +604,7 @@ int main(int argc, char** argv) { /*Oracle Rescoring*/ if(get_oracle_forest) { - Oracles o=oracles.ComputeOracles(smeta,forest,feature_weights,&cerr,10,conf["forest_output"].as<std::string>()); + Oracles o=oracles.ComputeOracles(smeta,&forest,feature_weights,&cerr,10,conf["forest_output"].as<std::string>()); cerr << " +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl; o.hope.Print(cerr," +Oracle BLEU"); diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index 6708b02e..b58117c1 100755 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -16,6 +16,7 @@ #include "sentence_metadata.h" #include "apply_models.h" #include "kbest.h" +#include "timing_stats.h" //TODO: put function impls into .cc //TODO: disentangle @@ -51,15 +52,15 @@ struct Oracles { Translation model,fear,hope; // feature 0 will be the error rate in fear and hope // move toward hope - WeightVector ModelHopeGradient() { - WeightVector r=hope-model; - r[0]=0; + FeatureVector ModelHopeGradient() { + FeatureVector r=hope.features-model.features; + r.set_value(0,0); return r; } // move toward hope from fear - WeightVector FearHopeGradient() { - Weightvector r=hope-fear; - r[0]=0; + FeatureVector FearHopeGradient() { + FeatureVector r=hope.features-fear.features; + r.set_value(0,0); return r; } }; @@ -100,6 +101,12 @@ struct OracleBleu { typedef boost::shared_ptr<Score> ScoreP; ScoreP doc_score,sentscore; // made from factory, so we delete them + ScoreP GetScore(Sentence const& sentence,int sent_id) { + return ScoreP(ds[sent_id]->ScoreCandidate(sentence)); + } + ScoreP GetScore(Hypergraph const& forest,int sent_id) { + return GetScore(Translation(forest).sentence,sent_id); + } double bleu_weight; void UseConf(boost::program_options::variables_map const& conf) { @@ -134,44 +141,39 @@ struct OracleBleu { SentenceMetadata MakeMetadata(Hypergraph const& forest,int sent_id) { std::vector<WordID> srcsent; ViterbiFSentence(forest,&srcsent); - SentenceMetadata sm(sent_id,Lattice()); //TODO: make reference from refs? - sm.SetSourceLength(srcsent.size()); + SentenceMetadata smeta(sent_id,Lattice()); //TODO: make reference from refs? + smeta.SetSourceLength(srcsent.size()); smeta.SetScore(doc_score.get()); smeta.SetDocScorer(&ds); smeta.SetDocLen(doc_src_length); - return sm; + return smeta; } - Oracles ComputeOracles(SentenceMetadata & smeta,Hypergraph const& forest,WeightVector const& feature_weights,std::ostream *log=0,unsigned kbest=0,std::string const& forest_output="") { + // destroys forest (replaces it w/ rescored oracle one) + Oracles ComputeOracles(SentenceMetadata & smeta,Hypergraph *forest_in_out,WeightVector const& feature_weights,std::ostream *log=0,unsigned kbest=0,std::string const& forest_output="") { + Hypergraph &forest=*forest_in_out; Oracles r; int sent_id=smeta.GetSentenceID(); r.model=Translation(forest); - if (kbest) DumpKBest("model",sent_id, forest, kbest, true, forest_output); { Timer t("Forest Oracle rescoring:"); Hypergraph oracle_forest; - Rescore(smeta,forest,&oracle_forest,feature_weights,blue_weight,log); + Rescore(smeta,forest,&oracle_forest,feature_weights,bleu_weight,log); forest.swap(oracle_forest); } r.hope=Translation(forest); if (kbest) DumpKBest("oracle",sent_id, forest, kbest, true, forest_output); - oracle.ReweightBleu(&forest,-blue_weight); + ReweightBleu(&forest,-bleu_weight); r.fear=Translation(forest); if (kbest) DumpKBest("negative",sent_id, forest, kbest, true, forest_output); return r; } - - ScoreP Score(Sentence const& sentence,int sent_id) { - return ds[sent_id]->ScoreCandidate(sentence); - } - ScoreP Score(Hypergraph const& forest,int sent_id) { - return Score(model_trans(forest).translation,sent_id); - } + typedef std::vector<WordID> Sentence; void Rescore(SentenceMetadata & smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0,std::ostream *log=&std::cerr) { // the sentence bleu stats will get added to doc only if you call IncludeLastScore - sentscore=Score(forest,smeta.GetSentenceID()); + sentscore=GetScore(forest,smeta.GetSentenceID()); if (!doc_score) { doc_score.reset(sentscore->GetOne()); } tmp_src_length = smeta.GetSourceLength(); //TODO: where does this come from? using namespace std; @@ -219,13 +221,13 @@ struct OracleBleu { float curr_src_length = doc_src_length + tmp_src_length; if (unique) { - KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique> kbest(forest, k); + KBest::KBestDerivations<Sentence, ESentenceTraversal, KBest::FilterUnique> kbest(forest, k); for (int i = 0; i < k; ++i) { - const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique>::Derivation* d = + const KBest::KBestDerivations<Sentence, ESentenceTraversal, KBest::FilterUnique>::Derivation* d = kbest.LazyKthBest(forest.nodes_.size() - 1, i); if (!d) break; //calculate score in context of psuedo-doc - Score* sentscore = ds[sent_id]->ScoreCandidate(d->yield); + Score* sentscore = GetScore(d->yield,sent_id); sentscore->PlusEquals(*doc_score,float(1)); float bleu = curr_src_length * sentscore->ComputeScore(); kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " @@ -234,9 +236,9 @@ struct OracleBleu { // << d->feature_values << " ||| " << log(d->score) << endl; } } else { - KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k); + KBest::KBestDerivations<Sentence, ESentenceTraversal> kbest(forest, k); for (int i = 0; i < k; ++i) { - const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d = + const KBest::KBestDerivations<Sentence, ESentenceTraversal>::Derivation* d = kbest.LazyKthBest(forest.nodes_.size() - 1, i); if (!d) break; cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " @@ -245,7 +247,7 @@ struct OracleBleu { } } -void DumpKBest(boost::program_options::variables_map const& conf,std::string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const& forest_output) +void DumpKBest(std::string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const& forest_output) { std::ostringstream kbest_string_stream; kbest_string_stream << forest_output << "/kbest_"<<suffix<< "." << sent_id; diff --git a/decoder/tdict.h b/decoder/tdict.h index fd77543d..1fba5179 100644 --- a/decoder/tdict.h +++ b/decoder/tdict.h @@ -28,4 +28,12 @@ struct TD { static const char* Convert(const WordID& w); }; +struct ToTD { + typedef WordID result_type; + result_type operator()(std::string const& t) const { + return TD::Convert(t); + } +}; + + #endif diff --git a/vest/mr_vest_generate_mapper_input.cc b/vest/mr_vest_generate_mapper_input.cc index 677c0497..cbda78c5 100644 --- a/vest/mr_vest_generate_mapper_input.cc +++ b/vest/mr_vest_generate_mapper_input.cc @@ -94,6 +94,7 @@ struct oracle_directions { ("oracle_batch,b",po::value<unsigned>(&oracle_batch)->default_value(10),"to produce each oracle direction, sum the 'gradient' over this many sentences") ("max_similarity,m",po::value<double>(&max_similarity)->default_value(0),"remove directions that are too similar (Tanimoto coeff. less than (1-this)). 0 means don't filter, 1 means only 1 direction allowed?") ("fear_to_hope,f",po::bool_switch(&fear_to_hope),"for each of the oracle_directions, also include a direction from fear to hope (as well as origin to hope)") + ("decoder_translations",po::value<string>(&decoder_translations)->default_value(""),"one per line decoder 1best translations for computing document BLEU vs. sentences-seen-so-far BLEU") ("help,h", "Help"); po::options_description dcmdline_options; dcmdline_options.add(opts); @@ -194,7 +195,8 @@ struct oracle_directions { } - //TODO: is it worthwhile to get a complete document bleu first? would take a list of 1best translations one per line from the decoders, rather than loading all the forests (expensive) + std::string decoder_translations_file; // one per line + //TODO: is it worthwhile to get a complete document bleu first? would take a list of 1best translations one per line from the decoders, rather than loading all the forests (expensive). translations are in run.raw.N.gz - new arg Oracle const& ComputeOracle(unsigned i) { Oracle &o=oracles[i]; if (o.is_null()) { @@ -204,7 +206,7 @@ struct oracle_directions { Timer t("Loading forest from JSON "+forest_file(i)); HypergraphIO::ReadFromJSON(rf.stream(), &hg); } - o=oracle.ComputeOracles(MakeMetadata(hg,i),hg,origin,&cerr); + o=oracle.ComputeOracles(MakeMetadata(hg,i),&hg,origin,&cerr); } return o; } |