diff options
Diffstat (limited to 'decoder/ff_bleu.cc')
-rw-r--r-- | decoder/ff_bleu.cc | 285 |
1 files changed, 285 insertions, 0 deletions
diff --git a/decoder/ff_bleu.cc b/decoder/ff_bleu.cc new file mode 100644 index 00000000..4a13f89e --- /dev/null +++ b/decoder/ff_bleu.cc @@ -0,0 +1,285 @@ +#include "ff_bleu.h" + +#include <sstream> +#include <unistd.h> + +#include <boost/shared_ptr.hpp> + +#include "tdict.h" +#include "Vocab.h" +#include "Ngram.h" +#include "hg.h" +#include "stringlib.h" +#include "sentence_metadata.h" +#include "../vest/scorer.h" + +using namespace std; + +class BLEUModelImpl { + public: + explicit BLEUModelImpl(int order) : + ngram_(*TD::dict_, order), buffer_(), order_(order), state_size_(OrderToStateSize(order) - 1), + floor_(-100.0), + kSTART(TD::Convert("<s>")), + kSTOP(TD::Convert("</s>")), + kUNKNOWN(TD::Convert("<unk>")), + kNONE(-1), + kSTAR(TD::Convert("<{STAR}>")) {} + + BLEUModelImpl(int order, const string& f) : + ngram_(*TD::dict_, order), buffer_(), order_(order), state_size_(OrderToStateSize(order) - 1), + floor_(-100.0), + kSTART(TD::Convert("<s>")), + kSTOP(TD::Convert("</s>")), + kUNKNOWN(TD::Convert("<unk>")), + kNONE(-1), + kSTAR(TD::Convert("<{STAR}>")) {} + + + virtual ~BLEUModelImpl() { + } + + inline int StateSize(const void* state) const { + return *(static_cast<const char*>(state) + state_size_); + } + + inline void SetStateSize(int size, void* state) const { + *(static_cast<char*>(state) + state_size_) = size; + } + + void GetRefToNgram() + {} + + string DebugStateToString(const void* state) const { + int len = StateSize(state); + const int* astate = reinterpret_cast<const int*>(state); + string res = "["; + for (int i = 0; i < len; ++i) { + res += " "; + res += TD::Convert(astate[i]); + } + res += " ]"; + return res; + } + + inline double ProbNoRemnant(int i, int len) { + int edge = len; + bool flag = true; + double sum = 0.0; + while (i >= 0) { + if (buffer_[i] == kSTAR) { + edge = i; + flag = false; + } else if (buffer_[i] <= 0) { + edge = i; + flag = true; + } else { + if ((edge-i >= order_) || (flag && !(i == (len-1) && buffer_[i] == kSTART))) + { //sum += LookupProbForBufferContents(i); + //cerr << "FT"; + CalcPhrase(buffer_[i], &buffer_[i+1]); + } + } + --i; + } + return sum; + } + + double FinalTraversalCost(const void* state) { + int slen = StateSize(state); + int len = slen + 2; + // cerr << "residual len: " << len << endl; + buffer_.resize(len + 1); + buffer_[len] = kNONE; + buffer_[len-1] = kSTART; + const int* astate = reinterpret_cast<const int*>(state); + int i = len - 2; + for (int j = 0; j < slen; ++j,--i) + buffer_[i] = astate[j]; + buffer_[i] = kSTOP; + assert(i == 0); + return ProbNoRemnant(len - 1, len); + } + + vector<WordID> CalcPhrase(int word, int* context) { + int i = order_; + vector<WordID> vs; + int c = 1; + vs.push_back(word); + // while (i > 1 && *context > 0) { + while (*context > 0) { + --i; + vs.push_back(*context); + ++context; + ++c; + } + if(false){ cerr << "VS1( "; + vector<WordID>::reverse_iterator rit; + for ( rit=vs.rbegin() ; rit != vs.rend(); ++rit ) + cerr << " " << TD::Convert(*rit); + cerr << ")\n";} + + return vs; + } + + + double LookupWords(const TRule& rule, const vector<const void*>& ant_states, void* vstate, const SentenceMetadata& smeta) { + + int len = rule.ELength() - rule.Arity(); + + for (int i = 0; i < ant_states.size(); ++i) + len += StateSize(ant_states[i]); + buffer_.resize(len + 1); + buffer_[len] = kNONE; + int i = len - 1; + const vector<WordID>& e = rule.e(); + + /*cerr << "RULE::" << rule.ELength() << " "; + for (vector<WordID>::const_iterator i = e.begin(); i != e.end(); ++i) + { + const WordID& c = *i; + if(c > 0) cerr << TD::Convert(c) << "--"; + else cerr <<"N--"; + } + cerr << endl; + */ + + for (int j = 0; j < e.size(); ++j) { + if (e[j] < 1) { + const int* astate = reinterpret_cast<const int*>(ant_states[-e[j]]); + int slen = StateSize(astate); + for (int k = 0; k < slen; ++k) + buffer_[i--] = astate[k]; + } else { + buffer_[i--] = e[j]; + } + } + + double approx_bleu = 0.0; + int* remnant = reinterpret_cast<int*>(vstate); + int j = 0; + i = len - 1; + int edge = len; + + + vector<WordID> vs; + while (i >= 0) { + vs = CalcPhrase(buffer_[i],&buffer_[i+1]); + if (buffer_[i] == kSTAR) { + edge = i; + } else if (edge-i >= order_) { + + vs = CalcPhrase(buffer_[i],&buffer_[i+1]); + + } else if (edge == len && remnant) { + remnant[j++] = buffer_[i]; + } + --i; + } + + //calculate Bvector here + /* cerr << "VS1( "; + vector<WordID>::reverse_iterator rit; + for ( rit=vs.rbegin() ; rit != vs.rend(); ++rit ) + cerr << " " << TD::Convert(*rit); + cerr << ")\n"; + */ + + Score *node_score = smeta.GetDocScorer()[smeta.GetSentenceID()]->ScoreCCandidate(vs); + string details; + node_score->ScoreDetails(&details); + const Score *base_score= &smeta.GetScore(); + //cerr << "SWBASE : " << base_score->ComputeScore() << details << " "; + + int src_length = smeta.GetSourceLength(); + node_score->PlusPartialEquals(*base_score, rule.EWords(), rule.FWords(), src_length ); + float oracledoc_factor = (src_length + smeta.GetDocLen())/ src_length; + + //how it seems to be done in code + //TODO: might need to reverse the -1/+1 of the oracle/neg examples + approx_bleu = ( rule.FWords() * oracledoc_factor ) * node_score->ComputeScore(); + //how I thought it was done from the paper + //approx_bleu = ( rule.FWords()+ smeta.GetDocLen() ) * node_score->ComputeScore(); + + if (!remnant){ return approx_bleu;} + + if (edge != len || len >= order_) { + remnant[j++] = kSTAR; + if (order_-1 < edge) edge = order_-1; + for (int i = edge-1; i >= 0; --i) + remnant[j++] = buffer_[i]; + } + + SetStateSize(j, vstate); + //cerr << "Return APPROX_BLEU: " << approx_bleu << " "<< DebugStateToString(vstate) << endl; + return approx_bleu; + } + + static int OrderToStateSize(int order) { + return ((order-1) * 2 + 1) * sizeof(WordID) + 1; + } + + protected: + Ngram ngram_; + vector<WordID> buffer_; + const int order_; + const int state_size_; + const double floor_; + + public: + const WordID kSTART; + const WordID kSTOP; + const WordID kUNKNOWN; + const WordID kNONE; + const WordID kSTAR; +}; + +BLEUModel::BLEUModel(const string& param) : + fid_(0) { //The partial BLEU score is kept in feature id=0 + vector<string> argv; + int argc = SplitOnWhitespace(param, &argv); + int order = 3; + string filename; + + //loop over argv and load all references into vector of NgramMaps + if (argc < 1) { cerr << "BLEUModel requires a filename, minimally!\n"; abort(); } + + + SetStateSize(BLEUModelImpl::OrderToStateSize(order)); + pimpl_ = new BLEUModelImpl(order, filename); +} + +BLEUModel::~BLEUModel() { + delete pimpl_; +} + +string BLEUModel::DebugStateToString(const void* state) const{ + return pimpl_->DebugStateToString(state); +} + +void BLEUModel::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector<const void*>& ant_states, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* state) const { + + (void) smeta; + /*cerr << "In BM calling set " << endl; + const Score *s= &smeta.GetScore(); + const int dl = smeta.GetDocLen(); + cerr << "SCO " << s->ComputeScore() << endl; + const DocScorer *ds = &smeta.GetDocScorer(); + */ + + cerr<< "Loading sentence " << smeta.GetSentenceID() << endl; + //} + features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, state, smeta)); + //cerr << "FID" << fid_ << " " << DebugStateToString(state) << endl; +} + +void BLEUModel::FinalTraversalFeatures(const void* ant_state, + SparseVector<double>* features) const { + + features->set_value(fid_, pimpl_->FinalTraversalCost(ant_state)); +} |