summaryrefslogtreecommitdiff
path: root/decoder/oracle_bleu.h
blob: 273e14b86a88fcb19f4f6b7f16ad16ce58fa5210 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#ifndef ORACLE_BLEU_H
#define ORACLE_BLEU_H

#include <cassert>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
#include <boost/shared_ptr.hpp>
#include "../vest/scorer.h"
#include "hg.h"
#include "ff_factory.h"
#include "ff_bleu.h"
#include "sparse_vector.h"
#include "viterbi.h"
#include "sentence_metadata.h"
#include "apply_models.h"

//TODO: put function impls into .cc
//TODO: disentangle
// The Viterbi (1-best) output of a hypergraph: the target-side word sequence
// of the best derivation together with that derivation's feature vector.
struct Translation {
  typedef std::vector<WordID> Sentence;
  Sentence sentence;       // target words of the Viterbi derivation
  FeatureVector features;  // feature vector of the Viterbi derivation
  Translation() {  }
  // Extracts the Viterbi translation of hg immediately; see Viterbi().
  Translation(Hypergraph const& hg,WeightVector *feature_weights=0)
  {
    Viterbi(hg,feature_weights);
  }
  // (Re)computes `sentence` and `features` from hg's Viterbi derivation.
  // feature_weights is optional; it is only forwarded to ViterbiFeatures so
  // that the scoring can be checked (the third argument, `true`, presumably
  // makes a disagreement fatal -- confirm against viterbi.h).
  void Viterbi(Hypergraph const& hg,WeightVector *feature_weights=0)
  {
    ViterbiESentence(hg,&sentence);
    features=ViterbiFeatures(hg,feature_weights,true);
  }
  // Writes the sentence and its features to `out`, one line each, each line
  // prefixed with `pre`. Does not modify the translation.
  void Print(std::ostream &out,std::string const& pre="   +Oracle BLEU ") const {
    out<<pre<<"Viterbi: "<<TD::GetString(sentence)<<"\n";
    out<<pre<<"features: "<<features<<std::endl;
  }

};


// Rescores translation forests with a BLEU feature (BLEUModel) so that an
// "oracle" (high-BLEU) derivation can be extracted, e.g. for MIRA training.
// Each sentence is scored in the context of a running pseudo-document
// (doc_score / doc_src_length). IncludeLastScore() folds the last sentence
// into it with the decay factor scale_oracle = 1 - 1/oracle_doc_size.
//
// Call sequence implied by the members below:
//   AddOptions() -> UseConf()   (or set_loss() followed by set_refs())
//   per sentence: MakeMetadata() -> Rescore() -> IncludeLastScore()
struct OracleBleu {
  typedef std::vector<std::string> Refs;
  Refs refs;                      // reference file names handed to DocScorer
  WeightVector feature_weights_;  // copy of the model weights taken in Rescore()
  DocScorer ds;                   // reference scorers, indexed by sentence id

  // Registers --references/-R and --oracle_loss on *opts.
  static void AddOptions(boost::program_options::options_description *opts) {
    using namespace boost::program_options;
    using namespace std;
    opts->add_options()
      ("references,R", value<Refs >(), "Translation reference files")
      ("oracle_loss", value<string>(), "IBM_BLEU_3 (default), IBM_BLEU etc")
      ;
  }
  int order;  // BLEU n-gram order passed to BLEUModel; set by set_loss()

  //TODO: move cdec.cc kbest output files function here

  //TODO: provide for loading the most recent translation for every sentence
  // (no more scale.. etc below?). It's possible I messed the below up; I assume
  // it's supposed to gracefully figure out the document 1-bests as you go, then
  // keep them up to date as you make multiple MIRA passes. Provide alternative
  // loading for MERT.
  double scale_oracle;    // decay factor, 1 - 1/oracle_doc_size
  int oracle_doc_size;    // pseudo-document size; controls the decay
  double tmp_src_length;  // source length of the sentence last given to Rescore()
  double doc_src_length;  // decayed accumulated source length of the pseudo-document

  // Sets the pseudo-document size and resets the accumulated source length.
  // `size` must be positive: 0 would divide by zero when computing the decay.
  void set_oracle_doc_size(int size) {
    assert(size > 0);
    oracle_doc_size = size;
    scale_oracle = 1 - 1. / oracle_doc_size;
    doc_src_length = 0;
  }

  // Note: loss, pff, refs and ds stay unset until set_loss()/set_refs()
  // (or UseConf()) are called.
  OracleBleu(int doc_size=10) : order(0), tmp_src_length(0) {
    set_oracle_doc_size(doc_size);
  }

  // Created by Score factory methods (ScoreCandidate / GetOne) and owned here.
  boost::shared_ptr<Score> doc_score,sentscore;

  // Configures the loss and the references from parsed options.
  // If --oracle_loss is absent, falls back to set_loss()'s default
  // (IBM_BLEU_3), as advertised by the option's help text.
  // --references is required: as<Refs>() throws boost::bad_any_cast if absent.
  void UseConf(boost::program_options::variables_map const& conf) {
    using namespace std;
    if (conf.count("oracle_loss"))
      set_loss(conf["oracle_loss"].as<string>());
    else
      set_loss();
    set_refs(conf["references"].as<Refs>());
  }

  ScoreType loss;  // metric chosen by set_loss()
  boost::shared_ptr<FeatureFunction> pff;  // BLEUModel feature built by set_loss()

  // Selects the metric (e.g. "IBM_BLEU_3", "IBM_BLEU") and creates the matching
  // BLEUModel feature: n-gram order 3 for IBM_BLEU_3, otherwise 4.
  // Call before set_refs(), which builds the DocScorer from `loss`.
  void set_loss(std::string const& lossd="IBM_BLEU_3") {
    loss = ScoreTypeFromString(lossd);
    order = (loss == IBM_BLEU_3) ? 3 : 4;
    std::ostringstream param;
    param << "-o " << order;
    pff = global_ff_registry->Create("BLEUModel", param.str());
  }

  // Stores the reference file list, (re)builds the DocScorer with the current
  // loss and discards the accumulated document score. Requires at least one
  // reference.
  void set_refs(Refs const& r) {
    refs = r;
    assert(!refs.empty());
    ds = DocScorer(loss, refs);
    doc_score.reset();
    std::cerr << "Loaded " << ds.size() << " references for scoring with " << StringFromScoreType(loss) << std::endl;
  }

  // Builds sentence metadata for sent_id whose source length is the length of
  // the forest's Viterbi source-side sentence.
  SentenceMetadata MakeMetadata(Hypergraph const& forest,int sent_id) {
    std::vector<WordID> srcsent;
    ViterbiFSentence(forest,&srcsent);
    SentenceMetadata sm(sent_id,Lattice()); //TODO: make reference from refs?
    sm.SetSourceLength(srcsent.size());
    return sm;
  }

  // Intersects `forest` with the BLEU feature and writes the result to
  // *dest_forest. It then reweights that forest with `feature_weights`, with
  // feature 0 replaced by bleu_weight (see ReweightBleu).
  // Also scores the forest's current 1-best against the references of sentence
  // smeta.GetSentenceID() and keeps the result in `sentscore` for
  // IncludeLastScore().
  void Rescore(SentenceMetadata & smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0) {
    using namespace std;
    Translation model_trans(forest);
    sentscore.reset(ds[smeta.GetSentenceID()]->ScoreCandidate(model_trans.sentence));
    // No document context yet (first sentence, or set_refs() just ran).
    if (!doc_score) { doc_score.reset(sentscore->GetOne()); }
    // Pending length, added to doc_src_length by IncludeLastScore().
    // NOTE(review): when smeta comes from MakeMetadata() this is the Viterbi
    // source sentence length -- confirm for other callers.
    tmp_src_length = smeta.GetSourceLength();
    // Document context for the BLEU feature (presumably read by BLEUModel).
    smeta.SetScore(doc_score.get());
    smeta.SetDocLen(doc_src_length);
    smeta.SetDocScorer(&ds);
    // NOTE(review): if FeatureWeights is std::vector<double>, this builds
    // size_t(bleu_weight) copies of 1.0, not one weight equal to bleu_weight --
    // confirm the intended argument order.
    ModelSet oracle_models(FeatureWeights(bleu_weight,1),vector<FeatureFunction const*>(1,pff.get()));
    const IntersectionConfiguration inter_conf_oracle(0, 0);
    cerr << "Going to call Apply Model " << endl;
    ApplyModelSet(forest,
                  smeta,
                  oracle_models,
                  inter_conf_oracle,
                  dest_forest);
    feature_weights_=feature_weights;
    ReweightBleu(dest_forest,bleu_weight);
  }

  // Folds the sentence score from the last Rescore() into the decayed document
  // score and document source length, then releases it.
  // Call exactly once per Rescore(). If `out` is non-null, prints the
  // pre-update scaled score and the updated document score details.
  void IncludeLastScore(std::ostream *out=0) {
    assert(doc_score && sentscore);  // Rescore() must have run since the last call
    const double prev_scaled_score = doc_src_length * doc_score->ComputeScore();
    doc_score->PlusEquals(*sentscore, scale_oracle);
    sentscore.reset();
    doc_src_length = (doc_src_length + tmp_src_length) * scale_oracle;
    if (out) {
      std::string d;
      doc_score->ScoreDetails(&d);
      *out << "SCALED SCORE: " << prev_scaled_score << " DOC BLEU " << doc_score->ComputeScore() << " " << d << std::endl;
    }
  }

  // Sets feature 0 of the weights saved by Rescore() to bleu_weight, then
  // reweights *dest_forest's edges with them.
  // NOTE(review): assumes feature id 0 is the BLEU feature and that
  // feature_weights_ was populated by Rescore() -- confirm.
  void ReweightBleu(Hypergraph *dest_forest,double bleu_weight=-1.) {
    feature_weights_[0]=bleu_weight;
    dest_forest->Reweight(feature_weights_);
    // dest_forest->SortInEdgesByEdgeWeights();
  }

};


#endif