diff options
author | graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-15 03:50:05 +0000 |
---|---|---|
committer | graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-15 03:50:05 +0000 |
commit | f819992b0b22b4fec88c15fe13118aa6b484b91b (patch) | |
tree | 1bf835e4b29ca926a4ca33a2a57743559c9ba58f /decoder | |
parent | c61c0f2f664eebcc434ce76e6767fccdbdf6fae2 (diff) |
oracle bleu refactor
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@259 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'decoder')
-rw-r--r-- | decoder/cdec.cc | 167 | ||||
-rw-r--r-- | decoder/ff_bleu.cc | 61 | ||||
-rw-r--r-- | decoder/ff_bleu.h | 1 | ||||
-rw-r--r-- | decoder/ff_lm.cc | 4 | ||||
-rwxr-xr-x | decoder/oracle_bleu.h | 154 | ||||
-rw-r--r-- | decoder/scfg_translator.cc | 42 |
6 files changed, 259 insertions, 170 deletions
diff --git a/decoder/cdec.cc b/decoder/cdec.cc index c15408b5..bec342ef 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -7,6 +7,7 @@ #include <boost/program_options.hpp> #include <boost/program_options/variables_map.hpp> +#include "oracle_bleu.h" #include "timing_stats.h" #include "translator.h" #include "phrasebased_translator.h" @@ -146,11 +147,11 @@ void InitCommandLine(int argc, char** argv, po::variables_map* confp) { ("crf_uniform_empirical", "If there are multple references use (i.e., lattice) a uniform distribution rather than posterior weighting a la EM") ("get_oracle_forest,o", "Calculate rescored hypregraph using approximate BLEU scoring of rules") ("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)") - ("references,R", po::value<vector<string> >(), "Translation reference files") ("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)") ("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)") ("forest_output,O",po::value<string>(),"Directory to write forests to") ("minimal_forests,m","Write minimal forests (excludes Rule information). Such forests can be used for ML/MAP training, but not rescoring, etc."); + OracleBleu::AddOptions(&opts); po::options_description clo("Command line options"); clo.add_options() ("config,c", po::value<string>(), "Configuration file") @@ -260,14 +261,14 @@ void MaxTranslationSample(Hypergraph* hg, const int samples, const int k) { } } -// TODO decoder output should probably be moved to another file -void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, const char *kbest_out_filename_, float doc_src_length, float tmp_src_length, const DocScorer &ds, Score* doc_score) { +// TODO decoder output should probably be moved to another file - how about oracle_bleu.h +void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const&kbest_out_filename_, float doc_src_length, float tmp_src_length, const DocScorer &ds, shared_ptr<Score> doc_score) { cerr << "In kbest\n"; ofstream kbest_out; - kbest_out.open(kbest_out_filename_); + kbest_out.open(kbest_out_filename_.c_str()); cerr << "Output kbest to " << kbest_out_filename_; - + //add length (f side) src length of this sentence to the psuedo-doc src length count float curr_src_length = doc_src_length + tmp_src_length; @@ -298,6 +299,15 @@ cerr << "In kbest\n"; } } +void DumpKBest(po::variables_map const& conf,string const& suffix,const int sent_id, const Hypergraph& forest, const int k, const bool unique, float doc_src_length, float tmp_src_length, const DocScorer &ds, shared_ptr<Score> doc_score) +{ + ostringstream kbest_string_stream; + kbest_string_stream << conf["forest_output"].as<string>() << "/kbest_"<<suffix<< "." << sent_id; + DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str(), doc_src_length, tmp_src_length, ds, doc_score); + +} + + struct ELengthWeightFunction { double operator()(const Hypergraph::Edge& e) const { return e.rule_->ELength() - e.rule_->Arity(); @@ -517,45 +527,9 @@ int main(int argc, char** argv) { const bool crf_uniform_empirical = conf.count("crf_uniform_empirical"); const bool get_oracle_forest = conf.count("get_oracle_forest"); - /*Oracle Extraction Prep*/ - vector<const FeatureFunction*> oracle_model_ffs; - vector<double> oracle_feature_weights; - shared_ptr<FeatureFunction> oracle_pff; - if(get_oracle_forest) { - - /*Add feature for oracle rescoring */ - string ff, param; - ff="BLEUModel"; - //pass the location of the references file via param to BLEUModel - for(int kk=0;kk < conf["references"].as<vector<string> >().size();kk++) - { - param = param + " " + conf["references"].as<vector<string> >()[kk]; - } - cerr << "Feature: " << ff << "->" << param << endl; - oracle_pff = global_ff_registry->Create(ff,param); - if (!oracle_pff) { exit(1); } - oracle_model_ffs.push_back(oracle_pff.get()); - oracle_feature_weights.push_back(1.0); - - } - - ModelSet oracle_models(oracle_feature_weights, oracle_model_ffs); - - const string loss_function3 = "IBM_BLEU_3"; - ScoreType type3 = ScoreTypeFromString(loss_function3); - const DocScorer ds(type3, conf["references"].as<vector<string> >(), ""); - cerr << "Loaded " << ds.size() << " references for scoring with " << loss_function3 << endl; - - - std::ostringstream kbest_string_stream; - Score* doc_score=NULL; - float doc_src_length=0; - float tmp_src_length=0; - int oracle_doc_size= 10; //used for scaling/weighting oracle doc - float scale_oracle= 1-float(1)/oracle_doc_size; - - /*End Oracle Extraction Prep*/ - + OracleBleu oracle; + if (get_oracle_forest) + oracle.UseConf(conf); shared_ptr<WriteFile> extract_file; if (conf.count("extract_rules")) @@ -671,83 +645,37 @@ int main(int argc, char** argv) { vector<WordID> trans; ViterbiESentence(forest, &trans); - + /*Oracle Rescoring*/ - if(get_oracle_forest) + if(get_oracle_forest) { + Timer t("Forest Oracle rescoring:"); + + DumpKBest(conf,"model",sent_id, forest, 10, true, oracle.doc_src_length, oracle.tmp_src_length, oracle.ds, oracle.doc_score); + + Translation best(forest); { - Timer t("Forest Oracle rescoring:"); - vector<WordID> model_trans; - model_trans = trans; - - trans=model_trans; - Score* sentscore = ds[sent_id]->ScoreCandidate(model_trans); - //initilize psuedo-doc vector to 1 counts - if (!doc_score) { doc_score = sentscore->GetOne(); } - double bleu_scale_ = doc_src_length * doc_score->ComputeScore(); - tmp_src_length = smeta.GetSourceLength(); - smeta.SetScore(doc_score); - smeta.SetDocLen(doc_src_length); - smeta.SetDocScorer(&ds); - - feature_weights[0]=1.0; - - kbest_string_stream << conf["forest_output"].as<string>() << "/kbest_model" << "." << sent_id; - DumpKBest(sent_id, forest, 10, true, kbest_string_stream.str().c_str(), doc_src_length, tmp_src_length, ds, doc_score); - kbest_string_stream.str(""); - - - forest.SortInEdgesByEdgeWeights(); - Hypergraph lm_forest; - const IntersectionConfiguration inter_conf_oracle(0, 0); - cerr << "Going to call Apply Model " << endl; - ApplyModelSet(forest, - smeta, - oracle_models, - inter_conf_oracle, - &lm_forest); - - forest.swap(lm_forest); - forest.Reweight(feature_weights); - forest.SortInEdgesByEdgeWeights(); - vector<WordID> oracle_trans; - - ViterbiESentence(forest, &oracle_trans); - cerr << " +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; - cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl; - cerr << " +Oracle BLEU Viterbi: " << TD::GetString(oracle_trans) << endl; - - //compute kbest for oracle - kbest_string_stream << conf["forest_output"].as<string>() <<"/kbest_oracle" << "." << sent_id; - DumpKBest(sent_id, forest, 10, true, kbest_string_stream.str().c_str(), doc_src_length, tmp_src_length, ds, doc_score); - kbest_string_stream.str(""); - - - //reweight the model with -1 for the BLEU feature to compute k-best list for negative examples - feature_weights[0]=-1.0; - forest.Reweight(feature_weights); - forest.SortInEdgesByEdgeWeights(); - vector<WordID> neg_trans; - ViterbiESentence(forest, &neg_trans); - cerr << " -Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; - cerr << " -Oracle BLEU (paths): " << forest.NumberOfPaths() << endl; - cerr << " -Oracle BLEU Viterbi: " << TD::GetString(neg_trans) << endl; - - //compute kbest for negative - kbest_string_stream << conf["forest_output"].as<string>() << "/kbest_negative" << "." << sent_id; - DumpKBest(sent_id, forest, 10, true, kbest_string_stream.str().c_str(), doc_src_length, tmp_src_length,ds, doc_score); - kbest_string_stream.str(""); - - //Add 1-best translation (trans) to psuedo-doc vectors - doc_score->PlusEquals(*sentscore, scale_oracle); - delete sentscore; - - doc_src_length = (doc_src_length + tmp_src_length) * scale_oracle; - - - string details; - doc_score->ScoreDetails(&details); - cerr << "SCALED SCORE: " << bleu_scale_ << "DOC BLEU " << doc_score->ComputeScore() << " " <<details << endl; + Hypergraph oracle_forest; + oracle.Rescore(smeta,forest,&oracle_forest,feature_weights,1.0); + forest.swap(oracle_forest); } + Translation oracle_trans(forest); + + cerr << " +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; + cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl; + oracle_trans.Print(cerr," +Oracle BLEU"); + //compute kbest for oracle + DumpKBest(conf,"oracle",sent_id, forest, 10, true, oracle.doc_src_length, oracle.tmp_src_length, oracle.ds, oracle.doc_score); + + //reweight the model with -1 for the BLEU feature to compute k-best list for negative examples + oracle.ReweightBleu(&forest,-1.0); + Translation neg_trans(forest); + neg_trans.Print(cerr," -Oracle BLEU"); + //compute kbest for negative + DumpKBest(conf,"negative",sent_id, forest, 10, true, oracle.doc_src_length, oracle.tmp_src_length,oracle.ds,oracle.doc_score); + + //Add 1-best translation (trans) to psuedo-doc vectors + oracle.IncludeLastScore(&cerr); + } if (conf.count("forest_output") && !has_ref) { @@ -772,9 +700,8 @@ int main(int argc, char** argv) { if (sample_max_trans) { MaxTranslationSample(&forest, sample_max_trans, conf.count("k_best") ? conf["k_best"].as<int>() : 0); } else { - if (kbest) { - DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"", doc_src_length, tmp_src_length, ds, doc_score); + DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"", oracle.doc_src_length,oracle.tmp_src_length, oracle.ds,oracle.doc_score); } else if (csplit_output_plf) { cout << HypergraphIO::AsPLF(forest, false) << endl; } else { diff --git a/decoder/ff_bleu.cc b/decoder/ff_bleu.cc index 4a13f89e..ab61ed10 100644 --- a/decoder/ff_bleu.cc +++ b/decoder/ff_bleu.cc @@ -1,10 +1,17 @@ -#include "ff_bleu.h" +namespace { +char const* bleu_usage_name="BLEUModel"; +char const* bleu_usage_short="[-o 3|4]"; +char const* bleu_usage_verbose="Uses feature id 0! Make sure there are no other features whose weights aren't specified or there may be conflicts. Computes oracle with weighted combination of BLEU and model score (from previous model set, using weights on edges?). Performs ngram context expansion; expect reference translation info in sentence metadata; if document scorer is IBM_BLEU_3, then use order 3; otherwise use order 4."; +} + #include <sstream> #include <unistd.h> +#include <boost/lexical_cast.hpp> #include <boost/shared_ptr.hpp> +#include "ff_bleu.h" #include "tdict.h" #include "Vocab.h" #include "Ngram.h" @@ -26,16 +33,6 @@ class BLEUModelImpl { kNONE(-1), kSTAR(TD::Convert("<{STAR}>")) {} - BLEUModelImpl(int order, const string& f) : - ngram_(*TD::dict_, order), buffer_(), order_(order), state_size_(OrderToStateSize(order) - 1), - floor_(-100.0), - kSTART(TD::Convert("<s>")), - kSTOP(TD::Convert("</s>")), - kUNKNOWN(TD::Convert("<unk>")), - kNONE(-1), - kSTAR(TD::Convert("<{STAR}>")) {} - - virtual ~BLEUModelImpl() { } @@ -49,7 +46,7 @@ class BLEUModelImpl { void GetRefToNgram() {} - + string DebugStateToString(const void* state) const { int len = StateSize(state); const int* astate = reinterpret_cast<const int*>(state); @@ -118,15 +115,15 @@ class BLEUModelImpl { for ( rit=vs.rbegin() ; rit != vs.rend(); ++rit ) cerr << " " << TD::Convert(*rit); cerr << ")\n";} - + return vs; } double LookupWords(const TRule& rule, const vector<const void*>& ant_states, void* vstate, const SentenceMetadata& smeta) { - + int len = rule.ELength() - rule.Arity(); - + for (int i = 0; i < ant_states.size(); ++i) len += StateSize(ant_states[i]); buffer_.resize(len + 1); @@ -168,9 +165,9 @@ class BLEUModelImpl { if (buffer_[i] == kSTAR) { edge = i; } else if (edge-i >= order_) { - + vs = CalcPhrase(buffer_[i],&buffer_[i+1]); - + } else if (edge == len && remnant) { remnant[j++] = buffer_[i]; } @@ -182,7 +179,7 @@ class BLEUModelImpl { vector<WordID>::reverse_iterator rit; for ( rit=vs.rbegin() ; rit != vs.rend(); ++rit ) cerr << " " << TD::Convert(*rit); - cerr << ")\n"; + cerr << ")\n"; */ Score *node_score = smeta.GetDocScorer()[smeta.GetSentenceID()]->ScoreCCandidate(vs); @@ -191,7 +188,7 @@ class BLEUModelImpl { const Score *base_score= &smeta.GetScore(); //cerr << "SWBASE : " << base_score->ComputeScore() << details << " "; - int src_length = smeta.GetSourceLength(); + int src_length = smeta.GetSourceLength(); node_score->PlusPartialEquals(*base_score, rule.EWords(), rule.FWords(), src_length ); float oracledoc_factor = (src_length + smeta.GetDocLen())/ src_length; @@ -234,19 +231,27 @@ class BLEUModelImpl { const WordID kSTAR; }; +string BLEUModel::usage(bool param,bool verbose) { + return usage_helper(bleu_usage_name,bleu_usage_short,bleu_usage_verbose,param,verbose); +} + BLEUModel::BLEUModel(const string& param) : fid_(0) { //The partial BLEU score is kept in feature id=0 vector<string> argv; int argc = SplitOnWhitespace(param, &argv); int order = 3; - string filename; - - //loop over argv and load all references into vector of NgramMaps - if (argc < 1) { cerr << "BLEUModel requires a filename, minimally!\n"; abort(); } - - + + //loop over argv and load all references into vector of NgramMaps + if (argc >= 1) { + if (argv[1] != "-o" || argc<2) { + cerr<<bleu_usage_name<<" specification should be: "<<bleu_usage_short<<"; you provided: "<<param<<endl<<bleu_usage_verbose<<endl; + abort(); + } else + order=boost::lexical_cast<int>(argv[1]); + } + SetStateSize(BLEUModelImpl::OrderToStateSize(order)); - pimpl_ = new BLEUModelImpl(order, filename); + pimpl_ = new BLEUModelImpl(order); } BLEUModel::~BLEUModel() { @@ -261,11 +266,11 @@ void BLEUModel::TraversalFeaturesImpl(const SentenceMetadata& smeta, const Hypergraph::Edge& edge, const vector<const void*>& ant_states, SparseVector<double>* features, - SparseVector<double>* estimated_features, + SparseVector<double>* /* estimated_features */, void* state) const { (void) smeta; - /*cerr << "In BM calling set " << endl; + /*cerr << "In BM calling set " << endl; const Score *s= &smeta.GetScore(); const int dl = smeta.GetDocLen(); cerr << "SCO " << s->ComputeScore() << endl; diff --git a/decoder/ff_bleu.h b/decoder/ff_bleu.h index fb127241..e93731c3 100644 --- a/decoder/ff_bleu.h +++ b/decoder/ff_bleu.h @@ -18,6 +18,7 @@ class BLEUModel : public FeatureFunction { virtual void FinalTraversalFeatures(const void* context, SparseVector<double>* features) const; std::string DebugStateToString(const void* state) const; + static std::string usage(bool param,bool verbose); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, const Hypergraph::Edge& edge, diff --git a/decoder/ff_lm.cc b/decoder/ff_lm.cc index 5de9c321..bbf63338 100644 --- a/decoder/ff_lm.cc +++ b/decoder/ff_lm.cc @@ -1,6 +1,8 @@ +namespace { char const* usage_name="LanguageModel"; char const* usage_short="srilm.gz [-n FeatureName] [-o StateOrder] [-m LimitLoadOrder]"; char const* usage_verbose="-n determines the name of the feature (and its weight). -o defaults to 3. -m defaults to effectively infinite, otherwise says what order lm probs to use (up to). you could use -o > -m but that would be wasteful. -o < -m means some ngrams are scored longer (whenever a word is inserted by a rule next to a variable) than the state would ordinarily allow. NOTE: multiple LanguageModel features are allowed, but they will wastefully duplicate state, except in the special case of -o 1 (which uses no state). subsequent references to the same a.lm.gz. unless they specify -m, will reuse the same SRI LM in memory; this means that the -m used in the first load of a.lm.gz will take effect."; +} //TODO: backoff wordclasses for named entity xltns, esp. numbers. e.g. digits -> @. idealy rule features would specify replacement lm tokens/classes @@ -513,7 +515,7 @@ bool parse_lmspec(std::string const& in, int &order, string &featurename, string if (order > 0 && !filename.empty()) return true; usage: - cerr<<usage_name<<" specification should be: "<<usage_short<<"; you provided: "<<in<<usage_verbose<<endl; + cerr<<usage_name<<" specification should be: "<<usage_short<<"; you provided: "<<in<<endl<<usage_verbose<<endl; return false; } diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h new file mode 100755 index 00000000..273e14b8 --- /dev/null +++ b/decoder/oracle_bleu.h @@ -0,0 +1,154 @@ +#ifndef ORACLE_BLEU_H +#define ORACLE_BLEU_H + +#include <iostream> +#include <string> +#include <vector> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> +#include "../vest/scorer.h" +#include "hg.h" +#include "ff_factory.h" +#include "ff_bleu.h" +#include "sparse_vector.h" +#include "viterbi.h" +#include "sentence_metadata.h" +#include "apply_models.h" + +//TODO: put function impls into .cc +//TODO: disentangle +struct Translation { + typedef std::vector<WordID> Sentence; + Sentence sentence; + FeatureVector features; + Translation() { } + Translation(Hypergraph const& hg,WeightVector *feature_weights=0) + { + Viterbi(hg,feature_weights); + } + void Viterbi(Hypergraph const& hg,WeightVector *feature_weights=0) // weights are only for checking that scoring is correct + { + ViterbiESentence(hg,&sentence); + features=ViterbiFeatures(hg,feature_weights,true); + } + void Print(std::ostream &out,std::string pre=" +Oracle BLEU ") { + out<<pre<<"Viterbi: "<<TD::GetString(sentence)<<"\n"; + out<<pre<<"features: "<<features<<std::endl; + } + +}; + + +struct OracleBleu { + typedef std::vector<std::string> Refs; + Refs refs; + WeightVector feature_weights_; + DocScorer ds; + + static void AddOptions(boost::program_options::options_description *opts) { + using namespace boost::program_options; + using namespace std; + opts->add_options() + ("references,R", value<Refs >(), "Translation reference files") + ("oracle_loss", value<string>(), "IBM_BLEU_3 (default), IBM_BLEU etc") + ; + } + int order; + + //TODO: move cdec.cc kbest output files function here + + //TODO: provide for loading most recent translation for every sentence (no more scale.. etc below? it's possible i messed the below up; i assume it's supposed to gracefully figure out the document 1bests as you go, then keep them up to date as you make multiple MIRA passes. provide alternative loading for MERT + double scale_oracle; + int oracle_doc_size; + double tmp_src_length; + double doc_src_length; + void set_oracle_doc_size(int size) { + oracle_doc_size=size; + scale_oracle= 1-1./oracle_doc_size;\ + doc_src_length=0; + } + OracleBleu(int doc_size=10) { + set_oracle_doc_size(doc_size); + } + + boost::shared_ptr<Score> doc_score,sentscore; // made from factory, so we delete them + + void UseConf(boost::program_options::variables_map const& conf) { + using namespace std; + set_loss(conf["oracle_loss"].as<string>()); + set_refs(conf["references"].as<Refs>()); + } + + ScoreType loss; +// std::string loss_name; + boost::shared_ptr<FeatureFunction> pff; + + void set_loss(std::string const& lossd="IBM_BLEU_3") { +// loss_name=lossd; + loss=ScoreTypeFromString(lossd); + order=(loss==IBM_BLEU_3)?3:4; + std::ostringstream param; + param<<"-o "<<order; + pff=global_ff_registry->Create("BLEUModel",param.str()); + } + + void set_refs(Refs const& r) { + refs=r; + assert(refs.size()); + ds=DocScorer(loss,refs); + doc_score.reset(); +// doc_score=sentscore + std::cerr << "Loaded " << ds.size() << " references for scoring with " << StringFromScoreType(loss) << std::endl; + } + + SentenceMetadata MakeMetadata(Hypergraph const& forest,int sent_id) { + std::vector<WordID> srcsent; + ViterbiFSentence(forest,&srcsent); + SentenceMetadata sm(sent_id,Lattice()); //TODO: make reference from refs? + sm.SetSourceLength(srcsent.size()); + return sm; + } + + void Rescore(SentenceMetadata & smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0) { + Translation model_trans(forest); + sentscore.reset(ds[smeta.GetSentenceID()]->ScoreCandidate(model_trans.sentence)); + if (!doc_score) { doc_score.reset(sentscore->GetOne()); } + tmp_src_length = smeta.GetSourceLength(); //TODO: where does this come from? + smeta.SetScore(doc_score.get()); + smeta.SetDocLen(doc_src_length); + smeta.SetDocScorer(&ds); + using namespace std; + ModelSet oracle_models(FeatureWeights(bleu_weight,1),vector<FeatureFunction const*>(1,pff.get())); + const IntersectionConfiguration inter_conf_oracle(0, 0); + cerr << "Going to call Apply Model " << endl; + ApplyModelSet(forest, + smeta, + oracle_models, + inter_conf_oracle, + dest_forest); + feature_weights_=feature_weights; + ReweightBleu(dest_forest,bleu_weight); + } + + void IncludeLastScore(std::ostream *out=0) { + double bleu_scale_ = doc_src_length * doc_score->ComputeScore(); + doc_score->PlusEquals(*sentscore, scale_oracle); + sentscore.reset(); + doc_src_length = (doc_src_length + tmp_src_length) * scale_oracle; + if (out) { + std::string d; + doc_score->ScoreDetails(&d); + *out << "SCALED SCORE: " << bleu_scale_ << "DOC BLEU " << doc_score->ComputeScore() << " " <<d << std::endl; + } + } + + void ReweightBleu(Hypergraph *dest_forest,double bleu_weight=-1.) { + feature_weights_[0]=bleu_weight; + dest_forest->Reweight(feature_weights_); +// dest_forest->SortInEdgesByEdgeWeights(); + } + +}; + + +#endif diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc index 32acfd65..bfbe44ee 100644 --- a/decoder/scfg_translator.cc +++ b/decoder/scfg_translator.cc @@ -27,7 +27,7 @@ struct SCFGTranslatorImpl { max_span_limit(conf["scfg_max_span_limit"].as<int>()), add_pass_through_rules(conf.count("add_pass_through_rules")), goal(conf["goal"].as<string>()), - default_nt(conf["scfg_default_nt"].as<string>()), + default_nt(conf["scfg_default_nt"].as<string>()), use_ctf_(conf.count("coarse_to_fine_beam_prune")) { if(conf.count("grammar")){ @@ -43,7 +43,7 @@ struct SCFGTranslatorImpl { cerr << std::endl; if (conf.count("scfg_extra_glue_grammar")) { GlueGrammar* g = new GlueGrammar(conf["scfg_extra_glue_grammar"].as<string>()); - g->SetGrammarName("ExtraGlueGrammar"); + g->SetGrammarName("ExtraGlueGrammar"); grammars.push_back(GrammarPtr(g)); cerr << "Adding glue grammar from file " << conf["scfg_extra_glue_grammar"].as<string>() << endl; } @@ -69,11 +69,11 @@ struct SCFGTranslatorImpl { cerr << "Using coarse-to-fine pruning with " << ctf_iterations_ << " grammar projection(s) and alpha=" << ctf_alpha_ << endl; cerr << " Coarse beam will be widened " << ctf_num_widenings_ << " times by a factor of " << ctf_wide_alpha_ << " if fine parse fails" << endl; } - if (!conf.count("scfg_no_hiero_glue_grammar")){ + if (!conf.count("scfg_no_hiero_glue_grammar")){ GlueGrammar* g = new GlueGrammar(goal, default_nt, ctf_iterations_); g->SetGrammarName("GlueGrammar"); grammars.push_back(GrammarPtr(g)); - cerr << "Adding glue grammar for default nonterminal " << default_nt << + cerr << "Adding glue grammar for default nonterminal " << default_nt << " and goal nonterminal " << goal << endl; } } @@ -123,9 +123,9 @@ struct SCFGTranslatorImpl { foreach(int edge_id, goal_node.in_edges_) RefineRule(forest->edges_[edge_id].rule_, ctf_iterations_); double alpha = ctf_alpha_; - bool found_parse; + bool found_parse=false; for (int i=-1; i < ctf_num_widenings_; ++i) { - cerr << "Coarse-to-fine source parse, alpha=" << alpha << endl; + cerr << "Coarse-to-fine source parse, alpha=" << alpha << endl; found_parse = true; Hypergraph refined_forest = *forest; for (int j=0; j < ctf_iterations_; ++j) { @@ -136,7 +136,7 @@ struct SCFGTranslatorImpl { if (RefineForest(&refined_forest)) { cerr << " Refinement succeeded." << endl; refined_forest.Reweight(weights); - } else { + } else { cerr << " Refinement failed. Widening beam." << endl; found_parse = false; break; @@ -152,21 +152,21 @@ struct SCFGTranslatorImpl { if (ctf_exhaustive_){ cerr << "Last resort: refining coarse forest without pruning..."; for (int j=0; j < ctf_iterations_; ++j) { - if (RefineForest(forest)){ + if (RefineForest(forest)){ cerr << " Refinement succeeded." << endl; forest->Reweight(weights); } else { cerr << " Refinement failed. No parse found for this sentence." << endl; return false; } - } - } else + } + } else return false; } } return true; } - + typedef std::pair<int, WordID> StateSplit; typedef std::pair<StateSplit, int> StateSplitPair; typedef std::tr1::unordered_map<StateSplit, int, boost::hash<StateSplit> > Split2Node; @@ -179,17 +179,17 @@ struct SCFGTranslatorImpl { Hypergraph::Node& coarse_goal_node = *(forest->nodes_.end()-1); bool refined_goal_node = false; foreach(Hypergraph::Node& node, forest->nodes_){ - cerr << "."; + cerr << "."; foreach(int edge_id, node.in_edges_) { Hypergraph::Edge& edge = forest->edges_[edge_id]; std::vector<int> nt_positions; TRulePtr& coarse_rule_ptr = edge.rule_; for(int i=0; i< coarse_rule_ptr->f_.size(); ++i){ - if (coarse_rule_ptr->f_[i] < 0) - nt_positions.push_back(i); + if (coarse_rule_ptr->f_[i] < 0) + nt_positions.push_back(i); } if (coarse_rule_ptr->fine_rules_ == 0) { - cerr << "Parsing with mixed levels of coarse-to-fine granularity is currently unsupported." << + cerr << "Parsing with mixed levels of coarse-to-fine granularity is currently unsupported." << endl << "Could not find refinement for: " << coarse_rule_ptr->AsString() << " on edge " << edge_id << " spanning " << edge.i_ << "," << edge.j_ << endl; abort(); } @@ -198,20 +198,20 @@ struct SCFGTranslatorImpl { Hypergraph::TailNodeVector tail; for (int pos_i=0; pos_i<nt_positions.size(); ++pos_i){ WordID fine_cat = fine_rule_ptr->f_[nt_positions[pos_i]]; - Split2Node::iterator it = + Split2Node::iterator it = s2n.find(StateSplit(edge.tail_nodes_[pos_i], fine_cat)); - if (it == s2n.end()) + if (it == s2n.end()) break; - else + else tail.push_back(it->second); } if (tail.size() == nt_positions.size()) { WordID cat = fine_rule_ptr->lhs_; - Hypergraph::Edge* new_edge = refined_forest.AddEdge(fine_rule_ptr, tail); + Hypergraph::Edge* new_edge = refined_forest.AddEdge(fine_rule_ptr, tail); new_edge->i_ = edge.i_; new_edge->j_ = edge.j_; new_edge->feature_values_ = fine_rule_ptr->GetFeatureValues(); - new_edge->feature_values_.set_value(FD::Convert("LatticeCost"), + new_edge->feature_values_.set_value(FD::Convert("LatticeCost"), edge.feature_values_[FD::Convert("LatticeCost")]); Hypergraph::Node* head_node; Split2Node::iterator it = s2n.find(StateSplit(node.id_, cat)); @@ -221,7 +221,7 @@ struct SCFGTranslatorImpl { splits[node.id_].push_back(cat); if (&node == &coarse_goal_node) refined_goal_node = true; - } else + } else head_node = &(refined_forest.nodes_[it->second]); refined_forest.ConnectEdgeToHeadNode(new_edge, head_node); } |