diff options
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/cdec.cc | 2 | ||||
| -rwxr-xr-x | decoder/oracle_bleu.h | 56 | ||||
| -rw-r--r-- | decoder/tdict.h | 8 | 
3 files changed, 38 insertions, 28 deletions
diff --git a/decoder/cdec.cc b/decoder/cdec.cc index 75c907b1..77179948 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -604,7 +604,7 @@ int main(int argc, char** argv) {      /*Oracle Rescoring*/      if(get_oracle_forest) { -      Oracles o=oracles.ComputeOracles(smeta,forest,feature_weights,&cerr,10,conf["forest_output"].as<std::string>()); +      Oracles o=oracles.ComputeOracles(smeta,&forest,feature_weights,&cerr,10,conf["forest_output"].as<std::string>());        cerr << "  +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;        cerr << "  +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl;        o.hope.Print(cerr,"  +Oracle BLEU"); diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index 6708b02e..b58117c1 100755 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -16,6 +16,7 @@  #include "sentence_metadata.h"  #include "apply_models.h"  #include "kbest.h" +#include "timing_stats.h"  //TODO: put function impls into .cc  //TODO: disentangle @@ -51,15 +52,15 @@ struct Oracles {    Translation model,fear,hope;    // feature 0 will be the error rate in fear and hope    // move toward hope -  WeightVector ModelHopeGradient() { -    WeightVector r=hope-model; -    r[0]=0; +  FeatureVector ModelHopeGradient() { +    FeatureVector r=hope.features-model.features; +    r.set_value(0,0);      return r;    }    // move toward hope from fear -  WeightVector FearHopeGradient() { -    Weightvector r=hope-fear; -    r[0]=0; +  FeatureVector FearHopeGradient() { +    FeatureVector r=hope.features-fear.features; +    r.set_value(0,0);      return r;    }  }; @@ -100,6 +101,12 @@ struct OracleBleu {    typedef boost::shared_ptr<Score> ScoreP;    ScoreP doc_score,sentscore; // made from factory, so we delete them +  ScoreP GetScore(Sentence const& sentence,int sent_id) { +    return ScoreP(ds[sent_id]->ScoreCandidate(sentence)); +  } +  ScoreP GetScore(Hypergraph const& forest,int sent_id) { +    return GetScore(Translation(forest).sentence,sent_id); +  }    double bleu_weight;    void UseConf(boost::program_options::variables_map const& conf) { @@ -134,44 +141,39 @@ struct OracleBleu {    SentenceMetadata MakeMetadata(Hypergraph const& forest,int sent_id) {      std::vector<WordID> srcsent;      ViterbiFSentence(forest,&srcsent); -    SentenceMetadata sm(sent_id,Lattice()); //TODO: make reference from refs? -    sm.SetSourceLength(srcsent.size()); +    SentenceMetadata smeta(sent_id,Lattice()); //TODO: make reference from refs? +    smeta.SetSourceLength(srcsent.size());  	smeta.SetScore(doc_score.get());  	smeta.SetDocScorer(&ds);  	smeta.SetDocLen(doc_src_length); -    return sm; +    return smeta;    } -  Oracles ComputeOracles(SentenceMetadata & smeta,Hypergraph const& forest,WeightVector const& feature_weights,std::ostream *log=0,unsigned kbest=0,std::string const& forest_output="") { +  // destroys forest (replaces it w/ rescored oracle one) +  Oracles ComputeOracles(SentenceMetadata & smeta,Hypergraph *forest_in_out,WeightVector const& feature_weights,std::ostream *log=0,unsigned kbest=0,std::string const& forest_output="") { +    Hypergraph &forest=*forest_in_out;      Oracles r;      int sent_id=smeta.GetSentenceID();      r.model=Translation(forest); -      if (kbest) DumpKBest("model",sent_id, forest, kbest, true, forest_output);      {        Timer t("Forest Oracle rescoring:");        Hypergraph oracle_forest; -      Rescore(smeta,forest,&oracle_forest,feature_weights,blue_weight,log); +      Rescore(smeta,forest,&oracle_forest,feature_weights,bleu_weight,log);        forest.swap(oracle_forest);      }      r.hope=Translation(forest);      if (kbest) DumpKBest("oracle",sent_id, forest, kbest, true, forest_output); -    oracle.ReweightBleu(&forest,-blue_weight); +    ReweightBleu(&forest,-bleu_weight);      r.fear=Translation(forest);      if (kbest) DumpKBest("negative",sent_id, forest, kbest, true, forest_output);      return r;    } - -  ScoreP Score(Sentence const& sentence,int sent_id) { -    return ds[sent_id]->ScoreCandidate(sentence); -  } -  ScoreP Score(Hypergraph const& forest,int sent_id) { -    return Score(model_trans(forest).translation,sent_id); -  } +  typedef std::vector<WordID> Sentence;    void Rescore(SentenceMetadata & smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0,std::ostream *log=&std::cerr) {      // the sentence bleu stats will get added to doc only if you call IncludeLastScore -    sentscore=Score(forest,smeta.GetSentenceID()); +    sentscore=GetScore(forest,smeta.GetSentenceID());  	if (!doc_score) { doc_score.reset(sentscore->GetOne()); }  	tmp_src_length = smeta.GetSourceLength(); //TODO: where does this come from?      using namespace std; @@ -219,13 +221,13 @@ struct OracleBleu {      float curr_src_length = doc_src_length + tmp_src_length;      if (unique) { -      KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique> kbest(forest, k); +      KBest::KBestDerivations<Sentence, ESentenceTraversal, KBest::FilterUnique> kbest(forest, k);        for (int i = 0; i < k; ++i) { -        const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique>::Derivation* d = +        const KBest::KBestDerivations<Sentence, ESentenceTraversal, KBest::FilterUnique>::Derivation* d =            kbest.LazyKthBest(forest.nodes_.size() - 1, i);          if (!d) break;          //calculate score in context of psuedo-doc -        Score* sentscore = ds[sent_id]->ScoreCandidate(d->yield); +        Score* sentscore = GetScore(d->yield,sent_id);          sentscore->PlusEquals(*doc_score,float(1));          float bleu = curr_src_length * sentscore->ComputeScore();          kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " @@ -234,9 +236,9 @@ struct OracleBleu {          //     << d->feature_values << " ||| " << log(d->score) << endl;        }      } else { -      KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k); +      KBest::KBestDerivations<Sentence, ESentenceTraversal> kbest(forest, k);        for (int i = 0; i < k; ++i) { -        const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d = +        const KBest::KBestDerivations<Sentence, ESentenceTraversal>::Derivation* d =            kbest.LazyKthBest(forest.nodes_.size() - 1, i);          if (!d) break;          cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " @@ -245,7 +247,7 @@ struct OracleBleu {      }    } -void DumpKBest(boost::program_options::variables_map const& conf,std::string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const& forest_output) +void DumpKBest(std::string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const& forest_output)    {      std::ostringstream kbest_string_stream;      kbest_string_stream << forest_output << "/kbest_"<<suffix<< "." << sent_id; diff --git a/decoder/tdict.h b/decoder/tdict.h index fd77543d..1fba5179 100644 --- a/decoder/tdict.h +++ b/decoder/tdict.h @@ -28,4 +28,12 @@ struct TD {    static const char* Convert(const WordID& w);  }; +struct ToTD { +  typedef WordID result_type; +  result_type operator()(std::string const& t) const { +    return TD::Convert(t); +  } +}; + +  #endif  | 
