| author    | vladimir.eidelman <vladimir.eidelman@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-14 23:00:08 +0000 |
|-----------|------------------------------------------------------------------------------|---------------------------|
| committer | vladimir.eidelman <vladimir.eidelman@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-14 23:00:08 +0000 |
| commit    | 1350b8e8e465acc9d4d8d43d807cc6093e8f37b9                                     |                           |
| tree      | ddbf972363b1d51ecca6d27e1ef226391a4e7151 /decoder                            |                           |
| parent    | dc6e2c9c453a76f0bb3dfbca4471e763cc8af1e7                                     |                           |
Added oracle forest rescoring
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@254 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/Makefile.am         |   7 |
|------------|-----------------------------|-----|
| -rw-r--r-- | decoder/cdec.cc             | 154 |
| -rw-r--r-- | decoder/cdec_ff.cc          |   2 |
| -rw-r--r-- | decoder/ff_bleu.cc          | 285 |
| -rw-r--r-- | decoder/ff_bleu.h           |  32 |
| -rw-r--r-- | decoder/sentence_metadata.h |  13 |
6 files changed, 485 insertions, 8 deletions
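
The core of the change is a pseudo-document approximation of corpus BLEU: the rescoring pass keeps exponentially decayed BLEU sufficient statistics from earlier 1-best outputs and scores every new candidate against that running context. Below is a minimal sketch of the per-sentence update, assuming `Score::PlusEquals(other, scale)` decays the running totals the same way the `doc_src_length` update in the cdec.cc hunk does; `BleuStats` and `UpdatePseudoDoc` are hypothetical stand-ins for the `Score` objects from `vest/scorer.h`, not part of the patch:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the BLEU sufficient statistics behind
// vest/scorer.h's Score: per-order n-gram hit/guess counts plus lengths.
struct BleuStats {
  std::vector<double> correct, total;  // indexed by n-gram order
  double ref_len, hyp_len;
  explicit BleuStats(int order)
    : correct(order), total(order), ref_len(0), hyp_len(0) {}

  // this <- (this + other) * scale: fold in one sentence and decay the
  // totals, the assumed semantics of Score::PlusEquals(*sentscore, scale).
  void PlusEquals(const BleuStats& other, double scale) {
    for (std::size_t i = 0; i < correct.size(); ++i) {
      correct[i] = (correct[i] + other.correct[i]) * scale;
      total[i]   = (total[i]   + other.total[i])   * scale;
    }
    ref_len = (ref_len + other.ref_len) * scale;
    hyp_len = (hyp_len + other.hyp_len) * scale;
  }
};

// One per-sentence pseudo-document update, mirroring the end of the
// oracle-rescoring block in the cdec.cc hunk: scale_oracle = 1 - 1/10,
// so the "document" behaves like an exponentially weighted window of
// roughly oracle_doc_size recent sentences.
void UpdatePseudoDoc(BleuStats* doc_score, const BleuStats& sent_1best,
                     float* doc_src_length, float sent_src_length) {
  const int oracle_doc_size = 10;  // matches the constant in the patch
  const float scale_oracle = 1 - 1.0f / oracle_doc_size;
  doc_score->PlusEquals(sent_1best, scale_oracle);
  *doc_src_length = (*doc_src_length + sent_src_length) * scale_oracle;
}
```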
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index 49aa45d0..e7b6abd8 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -74,6 +74,13 @@ libcdec_a_SOURCES = \
   ff_wordalign.cc \
   ff_csplit.cc \
   ff_tagger.cc \
+  ff_bleu.cc \
+  ../vest/scorer.cc \
+  ../vest/ter.cc \
+  ../vest/aer_scorer.cc \
+  ../vest/comb_scorer.cc \
+  ../vest/error_surface.cc \
+  ../vest/viterbi_envelope.cc \
   tromble_loss.cc \
   freqdict.cc \
   lexalign.cc \
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index b6cc6f66..5f06b0c8 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -32,6 +32,7 @@
 #include "inside_outside.h"
 #include "exp_semiring.h"
 #include "sentence_metadata.h"
+#include "../vest/scorer.h"
 
 using namespace std;
 using namespace std::tr1;
@@ -143,7 +144,9 @@ void InitCommandLine(int argc, char** argv, po::variables_map* confp) {
         ("pb_max_distortion,D", po::value<int>()->default_value(4), "Phrase-based decoder: maximum distortion")
         ("cll_gradient,G","Compute conditional log-likelihood gradient and write to STDOUT (src & ref required)")
         ("crf_uniform_empirical", "If there are multple references use (i.e., lattice) a uniform distribution rather than posterior weighting a la EM")
-        ("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)")
+        ("get_oracle_forest,OO", "Calculate rescored hypregraph using approximate BLEU scoring of rules")
+        ("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)")
+        ("references,R", po::value<vector<string> >(), "Translation reference files")
         ("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")
         ("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)")
         ("forest_output,O",po::value<string>(),"Directory to write forests to")
@@ -258,16 +261,30 @@ void MaxTranslationSample(Hypergraph* hg, const int samples, const int k) {
 }
 
 // TODO decoder output should probably be moved to another file
-void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique) {
+void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, const char *kbest_out_filename_, float doc_src_length, float tmp_src_length, const DocScorer &ds, Score* doc_score) {
   cerr << "In kbest\n";
+
+  ofstream kbest_out;
+  kbest_out.open(kbest_out_filename_);
+  cerr << "Output kbest to " << kbest_out_filename_;
+
+  //add length (f side) src length of this sentence to the psuedo-doc src length count
+  float curr_src_length = doc_src_length + tmp_src_length;
+
   if (unique) {
     KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique> kbest(forest, k);
     for (int i = 0; i < k; ++i) {
       const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique>::Derivation* d =
         kbest.LazyKthBest(forest.nodes_.size() - 1, i);
       if (!d) break;
-      cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
-           << d->feature_values << " ||| " << log(d->score) << endl;
+      //calculate score in context of psuedo-doc
+      Score* sentscore = ds[sent_id]->ScoreCandidate(d->yield);
+      sentscore->PlusEquals(*doc_score,float(1));
+      float bleu = curr_src_length * sentscore->ComputeScore();
+      kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
+                << d->feature_values << " ||| " << log(d->score) << " ||| " << bleu << endl;
+      // cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
+      //      << d->feature_values << " ||| " << log(d->score) << endl;
     }
   } else {
     KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k);
@@ -498,6 +515,48 @@ int main(int argc, char** argv) {
   const bool kbest = conf.count("k_best");
   const bool unique_kbest = conf.count("unique_k_best");
   const bool crf_uniform_empirical = conf.count("crf_uniform_empirical");
+  const bool get_oracle_forest = conf.count("get_oracle_forest");
+
+  /*Oracle Extraction Prep*/
+  vector<const FeatureFunction*> oracle_model_ffs;
+  vector<double> oracle_feature_weights;
+  shared_ptr<FeatureFunction> oracle_pff;
+  if(get_oracle_forest) {
+
+    /*Add feature for oracle rescoring */
+    string ff, param;
+    ff="BLEUModel";
+    //pass the location of the references file via param to BLEUModel
+    for(int kk=0;kk < conf["references"].as<vector<string> >().size();kk++)
+    {
+      param = param + " " + conf["references"].as<vector<string> >()[kk];
+    }
+    cerr << "Feature: " << ff << "->" << param << endl;
+    oracle_pff = global_ff_registry->Create(ff,param);
+    if (!oracle_pff) { exit(1); }
+    oracle_model_ffs.push_back(oracle_pff.get());
+    oracle_feature_weights.push_back(1.0);
+
+  }
+
+  ModelSet oracle_models(oracle_feature_weights, oracle_model_ffs);
+
+  const string loss_function3 = "IBM_BLEU_3";
+  ScoreType type3 = ScoreTypeFromString(loss_function3);
+  const DocScorer ds(type3, conf["references"].as<vector<string> >(), "");
+  cerr << "Loaded " << ds.size() << " references for scoring with " << loss_function3 << endl;
+
+
+  std::ostringstream kbest_string_stream;
+  Score* doc_score=NULL;
+  float doc_src_length=0;
+  float tmp_src_length=0;
+  int oracle_doc_size= 10; //used for scaling/weighting oracle doc
+  float scale_oracle= 1-float(1)/oracle_doc_size;
+
+  /*End Oracle Extraction Prep*/
+
   shared_ptr<WriteFile> extract_file;
   if (conf.count("extract_rules"))
     extract_file.reset(new WriteFile(str("extract_rules",conf)));
@@ -610,6 +669,87 @@ int main(int argc, char** argv) {
     maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen);
 
+    vector<WordID> trans;
+    ViterbiESentence(forest, &trans);
+
+    /*Oracle Rescoring*/
+    if(get_oracle_forest)
+    {
+      Timer t("Forest Oracle rescoring:");
+      vector<WordID> model_trans;
+      model_trans = trans;
+
+      trans=model_trans;
+      Score* sentscore = ds[sent_id]->ScoreCandidate(model_trans);
+      //initilize psuedo-doc vector to 1 counts
+      if (!doc_score) { doc_score = sentscore->GetOne(); }
+      double bleu_scale_ = doc_src_length * doc_score->ComputeScore();
+      tmp_src_length = smeta.GetSourceLength();
+      smeta.SetScore(doc_score);
+      smeta.SetDocLen(doc_src_length);
+      smeta.SetDocScorer(&ds);
+
+      feature_weights[0]=1.0;
+
+      kbest_string_stream << conf["forest_output"].as<string>() << "/kbest_model" << "." << sent_id;
+      DumpKBest(sent_id, forest, 10, true, kbest_string_stream.str().c_str(), doc_src_length, tmp_src_length, ds, doc_score);
+      kbest_string_stream.str("");
+
+
+      forest.SortInEdgesByEdgeWeights();
+      Hypergraph lm_forest;
+      const IntersectionConfiguration inter_conf_oracle(0, 0);
+      cerr << "Going to call Apply Model " << endl;
+      ApplyModelSet(forest,
+                    smeta,
+                    oracle_models,
+                    inter_conf_oracle,
+                    &lm_forest);
+
+      forest.swap(lm_forest);
+      forest.Reweight(feature_weights);
+      forest.SortInEdgesByEdgeWeights();
+      vector<WordID> oracle_trans;
+
+      ViterbiESentence(forest, &oracle_trans);
+      cerr << " +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
+      cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl;
+      cerr << " +Oracle BLEU Viterbi: " << TD::GetString(oracle_trans) << endl;
+
+      //compute kbest for oracle
+      kbest_string_stream << conf["forest_output"].as<string>() <<"/kbest_oracle" << "." << sent_id;
+      DumpKBest(sent_id, forest, 10, true, kbest_string_stream.str().c_str(), doc_src_length, tmp_src_length, ds, doc_score);
+      kbest_string_stream.str("");
+
+
+      //reweight the model with -1 for the BLEU feature to compute k-best list for negative examples
+      feature_weights[0]=-1.0;
+      forest.Reweight(feature_weights);
+      forest.SortInEdgesByEdgeWeights();
+      vector<WordID> neg_trans;
+      ViterbiESentence(forest, &neg_trans);
+      cerr << " -Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
+      cerr << " -Oracle BLEU (paths): " << forest.NumberOfPaths() << endl;
+      cerr << " -Oracle BLEU Viterbi: " << TD::GetString(neg_trans) << endl;
+
+      //compute kbest for negative
+      kbest_string_stream << conf["forest_output"].as<string>() << "/kbest_negative" << "." << sent_id;
+      DumpKBest(sent_id, forest, 10, true, kbest_string_stream.str().c_str(), doc_src_length, tmp_src_length,ds, doc_score);
+      kbest_string_stream.str("");
+
+      //Add 1-best translation (trans) to psuedo-doc vectors
+      doc_score->PlusEquals(*sentscore, scale_oracle);
+      delete sentscore;
+
+      doc_src_length = (doc_src_length + tmp_src_length) * scale_oracle;
+
+
+      string details;
+      doc_score->ScoreDetails(&details);
+      cerr << "SCALED SCORE: " << bleu_scale_ << "DOC BLEU " << doc_score->ComputeScore() << " " <<details << endl;
+    }
+
+
     if (conf.count("forest_output") && !has_ref) {
       ForestWriter writer(str("forest_output",conf), sent_id);
       if (FileExists(writer.fname_)) {
@@ -632,11 +772,9 @@ int main(int argc, char** argv) {
       if (sample_max_trans) {
        MaxTranslationSample(&forest, sample_max_trans, conf.count("k_best") ? conf["k_best"].as<int>() : 0);
      } else {
-       vector<WordID> trans;
-       ViterbiESentence(forest, &trans);
-
+
        if (kbest) {
-         DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest);
+         DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"", doc_src_length, tmp_src_length, ds, doc_score);
        } else if (csplit_output_plf) {
          cout << HypergraphIO::AsPLF(forest, false) << endl;
        } else {
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index 077956a8..c91780e2 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -7,6 +7,7 @@
 #include "ff_tagger.h"
 #include "ff_factory.h"
 #include "ff_ruleshape.h"
+#include "ff_bleu.h"
 
 boost::shared_ptr<FFRegistry> global_ff_registry;
@@ -20,6 +21,7 @@ void register_feature_functions() {
   global_ff_registry->Register(new FFFactory<WordPenalty>);
   global_ff_registry->Register(new FFFactory<SourceWordPenalty>);
   global_ff_registry->Register(new FFFactory<ArityPenalty>);
+  global_ff_registry->Register("BLEUModel", new FFFactory<BLEUModel>);
   global_ff_registry->Register("RuleShape", new FFFactory<RuleShapeFeatures>);
   global_ff_registry->Register("RelativeSentencePosition", new FFFactory<RelativeSentencePosition>);
   global_ff_registry->Register("Model2BinaryFeatures", new FFFactory<Model2BinaryFeatures>);
diff --git a/decoder/ff_bleu.cc b/decoder/ff_bleu.cc
new file mode 100644
index 00000000..4a13f89e
--- /dev/null
+++ b/decoder/ff_bleu.cc
@@ -0,0 +1,285 @@
+#include "ff_bleu.h"
+
+#include <sstream>
+#include <unistd.h>
+
+#include <boost/shared_ptr.hpp>
+
+#include "tdict.h"
+#include "Vocab.h"
+#include "Ngram.h"
+#include "hg.h"
+#include "stringlib.h"
+#include "sentence_metadata.h"
+#include "../vest/scorer.h"
+
+using namespace std;
+
+class BLEUModelImpl {
+ public:
+  explicit BLEUModelImpl(int order) :
+      ngram_(*TD::dict_, order), buffer_(), order_(order), state_size_(OrderToStateSize(order) - 1),
+      floor_(-100.0),
+      kSTART(TD::Convert("<s>")),
+      kSTOP(TD::Convert("</s>")),
+      kUNKNOWN(TD::Convert("<unk>")),
+      kNONE(-1),
+      kSTAR(TD::Convert("<{STAR}>")) {}
+
+  BLEUModelImpl(int order, const string& f) :
+      ngram_(*TD::dict_, order), buffer_(), order_(order), state_size_(OrderToStateSize(order) - 1),
+      floor_(-100.0),
+      kSTART(TD::Convert("<s>")),
+      kSTOP(TD::Convert("</s>")),
+      kUNKNOWN(TD::Convert("<unk>")),
+      kNONE(-1),
+      kSTAR(TD::Convert("<{STAR}>")) {}
+
+
+  virtual ~BLEUModelImpl() {
+  }
+
+  inline int StateSize(const void* state) const {
+    return *(static_cast<const char*>(state) + state_size_);
+  }
+
+  inline void SetStateSize(int size, void* state) const {
+    *(static_cast<char*>(state) + state_size_) = size;
+  }
+
+  void GetRefToNgram()
+  {}
+
+  string DebugStateToString(const void* state) const {
+    int len = StateSize(state);
+    const int* astate = reinterpret_cast<const int*>(state);
+    string res = "[";
+    for (int i = 0; i < len; ++i) {
+      res += " ";
+      res += TD::Convert(astate[i]);
+    }
+    res += " ]";
+    return res;
+  }
+
+  inline double ProbNoRemnant(int i, int len) {
+    int edge = len;
+    bool flag = true;
+    double sum = 0.0;
+    while (i >= 0) {
+      if (buffer_[i] == kSTAR) {
+        edge = i;
+        flag = false;
+      } else if (buffer_[i] <= 0) {
+        edge = i;
+        flag = true;
+      } else {
+        if ((edge-i >= order_) || (flag && !(i == (len-1) && buffer_[i] == kSTART)))
+          { //sum += LookupProbForBufferContents(i);
+            //cerr << "FT";
+            CalcPhrase(buffer_[i], &buffer_[i+1]);
+          }
+      }
+      --i;
+    }
+    return sum;
+  }
+
+  double FinalTraversalCost(const void* state) {
+    int slen = StateSize(state);
+    int len = slen + 2;
+    // cerr << "residual len: " << len << endl;
+    buffer_.resize(len + 1);
+    buffer_[len] = kNONE;
+    buffer_[len-1] = kSTART;
+    const int* astate = reinterpret_cast<const int*>(state);
+    int i = len - 2;
+    for (int j = 0; j < slen; ++j,--i)
+      buffer_[i] = astate[j];
+    buffer_[i] = kSTOP;
+    assert(i == 0);
+    return ProbNoRemnant(len - 1, len);
+  }
+
+  vector<WordID> CalcPhrase(int word, int* context) {
+    int i = order_;
+    vector<WordID> vs;
+    int c = 1;
+    vs.push_back(word);
+    // while (i > 1 && *context > 0) {
+    while (*context > 0) {
+      --i;
+      vs.push_back(*context);
+      ++context;
+      ++c;
+    }
+    if(false){ cerr << "VS1( ";
+      vector<WordID>::reverse_iterator rit;
+      for ( rit=vs.rbegin() ; rit != vs.rend(); ++rit )
+        cerr << " " << TD::Convert(*rit);
+      cerr << ")\n";}
+
+    return vs;
+  }
+
+
+  double LookupWords(const TRule& rule, const vector<const void*>& ant_states, void* vstate, const SentenceMetadata& smeta) {
+
+    int len = rule.ELength() - rule.Arity();
+
+    for (int i = 0; i < ant_states.size(); ++i)
+      len += StateSize(ant_states[i]);
+    buffer_.resize(len + 1);
+    buffer_[len] = kNONE;
+    int i = len - 1;
+    const vector<WordID>& e = rule.e();
+
+    /*cerr << "RULE::" << rule.ELength() << " ";
+    for (vector<WordID>::const_iterator i = e.begin(); i != e.end(); ++i)
+      {
+        const WordID& c = *i;
+        if(c > 0) cerr << TD::Convert(c) << "--";
+        else cerr <<"N--";
+      }
+    cerr << endl;
+    */
+
+    for (int j = 0; j < e.size(); ++j) {
+      if (e[j] < 1) {
+        const int* astate = reinterpret_cast<const int*>(ant_states[-e[j]]);
+        int slen = StateSize(astate);
+        for (int k = 0; k < slen; ++k)
+          buffer_[i--] = astate[k];
+      } else {
+        buffer_[i--] = e[j];
+      }
+    }
+
+    double approx_bleu = 0.0;
+    int* remnant = reinterpret_cast<int*>(vstate);
+    int j = 0;
+    i = len - 1;
+    int edge = len;
+
+
+    vector<WordID> vs;
+    while (i >= 0) {
+      vs = CalcPhrase(buffer_[i],&buffer_[i+1]);
+      if (buffer_[i] == kSTAR) {
+        edge = i;
+      } else if (edge-i >= order_) {
+
+        vs = CalcPhrase(buffer_[i],&buffer_[i+1]);
+
+      } else if (edge == len && remnant) {
+        remnant[j++] = buffer_[i];
+      }
+      --i;
+    }
+
+    //calculate Bvector here
+    /* cerr << "VS1( ";
+       vector<WordID>::reverse_iterator rit;
+       for ( rit=vs.rbegin() ; rit != vs.rend(); ++rit )
+         cerr << " " << TD::Convert(*rit);
+       cerr << ")\n";
+    */
+
+    Score *node_score = smeta.GetDocScorer()[smeta.GetSentenceID()]->ScoreCCandidate(vs);
+    string details;
+    node_score->ScoreDetails(&details);
+    const Score *base_score= &smeta.GetScore();
+    //cerr << "SWBASE : " << base_score->ComputeScore() << details << " ";
+
+    int src_length = smeta.GetSourceLength();
+    node_score->PlusPartialEquals(*base_score, rule.EWords(), rule.FWords(), src_length );
+    float oracledoc_factor = (src_length + smeta.GetDocLen())/ src_length;
+
+    //how it seems to be done in code
+    //TODO: might need to reverse the -1/+1 of the oracle/neg examples
+    approx_bleu = ( rule.FWords() * oracledoc_factor ) * node_score->ComputeScore();
+    //how I thought it was done from the paper
+    //approx_bleu = ( rule.FWords()+ smeta.GetDocLen() ) * node_score->ComputeScore();
+
+    if (!remnant){ return approx_bleu;}
+
+    if (edge != len || len >= order_) {
+      remnant[j++] = kSTAR;
+      if (order_-1 < edge) edge = order_-1;
+      for (int i = edge-1; i >= 0; --i)
+        remnant[j++] = buffer_[i];
+    }
+
+    SetStateSize(j, vstate);
+    //cerr << "Return APPROX_BLEU: " << approx_bleu << " "<< DebugStateToString(vstate) << endl;
+    return approx_bleu;
+  }
+
+  static int OrderToStateSize(int order) {
+    return ((order-1) * 2 + 1) * sizeof(WordID) + 1;
+  }
+
+ protected:
+  Ngram ngram_;
+  vector<WordID> buffer_;
+  const int order_;
+  const int state_size_;
+  const double floor_;
+
+ public:
+  const WordID kSTART;
+  const WordID kSTOP;
+  const WordID kUNKNOWN;
+  const WordID kNONE;
+  const WordID kSTAR;
+};
+
+BLEUModel::BLEUModel(const string& param) :
+  fid_(0) { //The partial BLEU score is kept in feature id=0
+  vector<string> argv;
+  int argc = SplitOnWhitespace(param, &argv);
+  int order = 3;
+  string filename;
+
+  //loop over argv and load all references into vector of NgramMaps
+  if (argc < 1) { cerr << "BLEUModel requires a filename, minimally!\n"; abort(); }
+
+
+  SetStateSize(BLEUModelImpl::OrderToStateSize(order));
+  pimpl_ = new BLEUModelImpl(order, filename);
+}
+
+BLEUModel::~BLEUModel() {
+  delete pimpl_;
+}
+
+string BLEUModel::DebugStateToString(const void* state) const{
+  return pimpl_->DebugStateToString(state);
+}
+
+void BLEUModel::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+                                      const Hypergraph::Edge& edge,
+                                      const vector<const void*>& ant_states,
+                                      SparseVector<double>* features,
+                                      SparseVector<double>* estimated_features,
+                                      void* state) const {
+
+  (void) smeta;
+  /*cerr << "In BM calling set " << endl;
+    const Score *s= &smeta.GetScore();
+    const int dl = smeta.GetDocLen();
+    cerr << "SCO " << s->ComputeScore() << endl;
+    const DocScorer *ds = &smeta.GetDocScorer();
+  */
+
+  cerr<< "Loading sentence " << smeta.GetSentenceID() << endl;
+  //}
+  features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, state, smeta));
+  //cerr << "FID" << fid_ << " " << DebugStateToString(state) << endl;
+}
+
+void BLEUModel::FinalTraversalFeatures(const void* ant_state,
+                                       SparseVector<double>* features) const {
+
+  features->set_value(fid_, pimpl_->FinalTraversalCost(ant_state));
+}
diff --git a/decoder/ff_bleu.h b/decoder/ff_bleu.h
new file mode 100644
index 00000000..fb127241
--- /dev/null
+++ b/decoder/ff_bleu.h
@@ -0,0 +1,32 @@
+#ifndef _BLEU_FF_H_
+#define _BLEU_FF_H_
+
+#include <vector>
+#include <string>
+
+#include "hg.h"
+#include "ff.h"
+#include "config.h"
+
+class BLEUModelImpl;
+
+class BLEUModel : public FeatureFunction {
+ public:
+  // param = "filename.lm [-o n]"
+  BLEUModel(const std::string& param);
+  ~BLEUModel();
+  virtual void FinalTraversalFeatures(const void* context,
+                                      SparseVector<double>* features) const;
+  std::string DebugStateToString(const void* state) const;
+ protected:
+  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+                                     const Hypergraph::Edge& edge,
+                                     const std::vector<const void*>& ant_contexts,
+                                     SparseVector<double>* features,
+                                     SparseVector<double>* estimated_features,
+                                     void* out_context) const;
+ private:
+  const int fid_;
+  mutable BLEUModelImpl* pimpl_;
+};
+#endif
diff --git a/decoder/sentence_metadata.h b/decoder/sentence_metadata.h
index ef9eb388..21be9b21 100644
--- a/decoder/sentence_metadata.h
+++ b/decoder/sentence_metadata.h
@@ -3,6 +3,7 @@
 #include <cassert>
 #include "lattice.h"
+#include "../vest/scorer.h"
 
 struct SentenceMetadata {
   SentenceMetadata(int id, const Lattice& ref) :
@@ -30,10 +31,22 @@ struct SentenceMetadata {
   // this will be empty if the translator accepts non FS input!
   const Lattice& GetSourceLattice() const { return src_lattice_; }
 
+  // access to document level scores for MIRA vector computation
+  void SetScore(Score *s){app_score=s;}
+  void SetDocScorer (const DocScorer *d){ds = d;}
+  void SetDocLen(double dl){doc_len = dl;}
+
+  const Score& GetScore() const { return *app_score; }
+  const DocScorer& GetDocScorer() const { return *ds; }
+  double GetDocLen() const {return doc_len;}
+
  private:
   const int sent_id_;
   // the following should be set, if possible, by the Translator
   int src_len_;
+  double doc_len;
+  const DocScorer* ds;
+  const Score* app_score;
  public:
   Lattice src_lattice_;  // this will only be set if inputs are finite state!
  private:
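
For reference, the per-edge feature value computed in `BLEUModel::LookupWords` reduces to the arithmetic sketched below. The function name and parameters are illustrative, not part of the patch; `partial_bleu` stands for `node_score->ComputeScore()` after `PlusPartialEquals` has mixed the pseudo-document statistics into the rule's partial counts:

```cpp
// Sketch of the approximate per-edge BLEU from BLEUModel::LookupWords:
// the partial BLEU of the rule's yield is weighted by the rule's
// source-side length, inflated by a pseudo-document factor so that edge
// scores stay comparable against the decayed document context.
double ApproxEdgeBleu(double rule_f_words,    // rule.FWords()
                      double src_length,      // smeta.GetSourceLength()
                      double doc_len,         // smeta.GetDocLen()
                      double partial_bleu) {  // node_score->ComputeScore()
  const double oracledoc_factor = (src_length + doc_len) / src_length;
  return rule_f_words * oracledoc_factor * partial_bleu;
}
```

Because this feature lands in id 0, flipping `feature_weights[0]` between +1.0 and -1.0, as the cdec.cc hunk does, extracts both a highest-approximate-BLEU ("oracle") and a lowest-approximate-BLEU ("negative") k-best list from the same rescored forest; per the "MIRA vector computation" comment in sentence_metadata.h, these appear intended as the positive and negative examples for a MIRA-style trainer.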