diff options
Diffstat (limited to 'decoder')
-rw-r--r-- | decoder/apply_models.h | 2 | ||||
-rw-r--r-- | decoder/cdec.cc | 18 | ||||
-rwxr-xr-x | decoder/oracle_bleu.h | 28 | ||||
-rwxr-xr-x | decoder/sentences.h | 53 | ||||
-rw-r--r-- | decoder/sparse_vector.h | 20 | ||||
-rw-r--r-- | decoder/stringlib.h | 5 |
6 files changed, 103 insertions, 23 deletions
diff --git a/decoder/apply_models.h b/decoder/apply_models.h index 5c220afd..61a5b8f7 100644 --- a/decoder/apply_models.h +++ b/decoder/apply_models.h @@ -11,7 +11,7 @@ struct IntersectionConfiguration { const int algorithm; // 0 = full intersection, 1 = cube pruning const int pop_limit; // max number of pops off the heap at each node IntersectionConfiguration(int alg, int k) : algorithm(alg), pop_limit(k) {} - IntersectionConfiguration(exhaustive_t t) : algorithm(0), pop_limit() {(void)t;} + IntersectionConfiguration(exhaustive_t /* t */) : algorithm(0), pop_limit() {} }; void ApplyModelSet(const Hypergraph& in, diff --git a/decoder/cdec.cc b/decoder/cdec.cc index 77179948..8827cce3 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -323,6 +323,12 @@ void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_featur } } +void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_features,DenseWeightVector const& feature_weights) { + WeightVector fw(feature_weights); + forest_stats(forest,name,show_tree,show_features,&fw); +} + + void maybe_prune(Hypergraph &forest,po::variables_map const& conf,string nbeam,string ndensity,string forestname,double srclen) { double beam_prune=0,density_prune=0; bool use_beam_prune=beam_param(conf,nbeam,&beam_prune,conf.count("scale_prune_srclen"),srclen); @@ -390,9 +396,9 @@ int main(int argc, char** argv) { prelm_w.InitFromFile(plmw); prelm_feature_weights.resize(FD::NumFeats()); prelm_w.InitVector(&prelm_feature_weights); -// cerr << "prelm_weights: " << FeatureVector(prelm_feature_weights)<<endl; +// cerr << "prelm_weights: " << WeightVector(prelm_feature_weights)<<endl; } -// cerr << "+LM weights: " << FeatureVector(feature_weights)<<endl; +// cerr << "+LM weights: " << WeightVector(feature_weights)<<endl; } bool warn0=conf.count("warn_0_weight"); bool freeze=!conf.count("no_freeze_feature_set"); @@ -548,7 +554,7 @@ int main(int argc, char** argv) { } const bool show_tree_structure=conf.count("show_tree_structure"); const bool show_features=conf.count("show_features"); - forest_stats(forest," -LM forest",show_tree_structure,show_features,&feature_weights); + forest_stats(forest," -LM forest",show_tree_structure,show_features,feature_weights); if (conf.count("show_expected_length")) { const PRPair<double, double> res = Inside<PRPair<double, double>, @@ -574,7 +580,7 @@ int main(int argc, char** argv) { &prelm_forest); forest.swap(prelm_forest); forest.Reweight(prelm_feature_weights); - forest_stats(forest," prelm forest",show_tree_structure,show_features,&prelm_feature_weights); + forest_stats(forest," prelm forest",show_tree_structure,show_features,prelm_feature_weights); } maybe_prune(forest,conf,"prelm_beam_prune","prelm_density_prune","-LM",srclen); @@ -593,7 +599,7 @@ int main(int argc, char** argv) { &lm_forest); forest.swap(lm_forest); forest.Reweight(feature_weights); - forest_stats(forest," +LM forest",show_tree_structure,show_features,&feature_weights); + forest_stats(forest," +LM forest",show_tree_structure,show_features,feature_weights); } maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen); @@ -604,7 +610,7 @@ int main(int argc, char** argv) { /*Oracle Rescoring*/ if(get_oracle_forest) { - Oracles o=oracles.ComputeOracles(smeta,&forest,feature_weights,&cerr,10,conf["forest_output"].as<std::string>()); + Oracle o=oracle.ComputeOracle(smeta,&forest,FeatureVector(feature_weights),&cerr,10,conf["forest_output"].as<std::string>()); cerr << " +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl; o.hope.Print(cerr," +Oracle BLEU"); diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index b58117c1..cc19fbca 100755 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -17,6 +17,7 @@ #include "apply_models.h" #include "kbest.h" #include "timing_stats.h" +#include "sentences.h" //TODO: put function impls into .cc //TODO: disentangle @@ -44,7 +45,7 @@ struct Translation { }; -struct Oracles { +struct Oracle { bool is_null() { return model.is_null() /* && fear.is_null() && hope.is_null() */; } @@ -52,13 +53,13 @@ struct Oracles { Translation model,fear,hope; // feature 0 will be the error rate in fear and hope // move toward hope - FeatureVector ModelHopeGradient() { + FeatureVector ModelHopeGradient() const { FeatureVector r=hope.features-model.features; r.set_value(0,0); return r; } // move toward hope from fear - FeatureVector FearHopeGradient() { + FeatureVector FearHopeGradient() const { FeatureVector r=hope.features-fear.features; r.set_value(0,0); return r; @@ -150,9 +151,9 @@ struct OracleBleu { } // destroys forest (replaces it w/ rescored oracle one) - Oracles ComputeOracles(SentenceMetadata & smeta,Hypergraph *forest_in_out,WeightVector const& feature_weights,std::ostream *log=0,unsigned kbest=0,std::string const& forest_output="") { + Oracle ComputeOracle(SentenceMetadata const& smeta,Hypergraph *forest_in_out,WeightVector const& feature_weights,std::ostream *log=0,unsigned kbest=0,std::string const& forest_output="") { Hypergraph &forest=*forest_in_out; - Oracles r; + Oracle r; int sent_id=smeta.GetSentenceID(); r.model=Translation(forest); if (kbest) DumpKBest("model",sent_id, forest, kbest, true, forest_output); @@ -169,23 +170,24 @@ struct OracleBleu { if (kbest) DumpKBest("negative",sent_id, forest, kbest, true, forest_output); return r; } - typedef std::vector<WordID> Sentence; - void Rescore(SentenceMetadata & smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0,std::ostream *log=&std::cerr) { + void Rescore(SentenceMetadata const& smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0,std::ostream *log=&std::cerr) { // the sentence bleu stats will get added to doc only if you call IncludeLastScore sentscore=GetScore(forest,smeta.GetSentenceID()); if (!doc_score) { doc_score.reset(sentscore->GetOne()); } tmp_src_length = smeta.GetSourceLength(); //TODO: where does this come from? using namespace std; - ModelSet oracle_models(WeightVector(bleu_weight,1),vector<FeatureFunction const*>(1,pff.get())); - const IntersectionConfiguration inter_conf_oracle(0, 0); + DenseWeightVector w; + feature_weights_=feature_weights; + feature_weights_.set_value(0,bleu_weight); + feature_weights.init_vector(&w); + ModelSet oracle_models(w,vector<FeatureFunction const*>(1,pff.get())); if (log) *log << "Going to call Apply Model " << endl; ApplyModelSet(forest, smeta, oracle_models, - inter_conf_oracle, + IntersectionConfiguration(exhaustive_t()), dest_forest); - feature_weights_=feature_weights; ReweightBleu(dest_forest,bleu_weight); } @@ -202,7 +204,7 @@ struct OracleBleu { } void ReweightBleu(Hypergraph *dest_forest,double bleu_weight=-1.) { - feature_weights_[0]=bleu_weight; + feature_weights_.set_value(0,bleu_weight); dest_forest->Reweight(feature_weights_); // dest_forest->SortInEdgesByEdgeWeights(); } @@ -227,7 +229,7 @@ struct OracleBleu { kbest.LazyKthBest(forest.nodes_.size() - 1, i); if (!d) break; //calculate score in context of psuedo-doc - Score* sentscore = GetScore(d->yield,sent_id); + ScoreP sentscore = GetScore(d->yield,sent_id); sentscore->PlusEquals(*doc_score,float(1)); float bleu = curr_src_length * sentscore->ComputeScore(); kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " diff --git a/decoder/sentences.h b/decoder/sentences.h new file mode 100755 index 00000000..842072b9 --- /dev/null +++ b/decoder/sentences.h @@ -0,0 +1,53 @@ +#ifndef CDEC_SENTENCES_H +#define CDEC_SENTENCES_H + +#include <algorithm> +#include <vector> +#include <iostream> +#include "filelib.h" +#include "tdict.h" +#include "stringlib.h" +typedef std::vector<WordID> Sentence; + +inline void StringToSentence(std::string const& str,Sentence &s) { + using namespace std; + vector<string> ss=SplitOnWhitespace(str); + s.clear(); + transform(ss.begin(),ss.end(),back_inserter(s),ToTD()); +} + +inline Sentence StringToSentence(std::string const& str) { + Sentence s; + StringToSentence(str,s); + return s; +} + +inline std::istream& operator >> (std::istream &in,Sentence &s) { + using namespace std; + string str; + if (getline(in,str)) { + StringToSentence(str,s); + } + return in; +} + + +class Sentences : public std::vector<Sentence> { + typedef std::vector<Sentence> VS; +public: + Sentences() { } + Sentences(unsigned n,Sentence const& sentence) : VS(n,sentence) { } + Sentences(unsigned n,std::string const& sentence) : VS(n,StringToSentence(sentence)) { } + void Load(std::string file) { + ReadFile r(file); + Load(*r.stream()); + } + void Load(std::istream &in) { + this->push_back(Sentence()); + while(in>>this->back()) ; + this->pop_back(); + } +}; + + +#endif diff --git a/decoder/sparse_vector.h b/decoder/sparse_vector.h index 9c7c9c79..43880014 100644 --- a/decoder/sparse_vector.h +++ b/decoder/sparse_vector.h @@ -12,6 +12,13 @@ #include "fdict.h" +template <class T> +inline T & extend_vector(std::vector<T> &v,int i) { + if (i>=v.size()) + v.resize(i+1); + return v[i]; +} + template <typename T> class SparseVector { public: @@ -29,6 +36,17 @@ public: } + void init_vector(std::vector<T> *vp) const { + init_vector(*vp); + } + + void init_vector(std::vector<T> &v) const { + v.clear(); + for (const_iterator i=values_.begin(),e=values_.end();i!=e;++i) + extend_vector(v,i->first)=i->second; + } + + void set_new_value(int index, T const& val) { assert(values_.find(index)==values_.end()); values_[index]=val; @@ -312,7 +330,7 @@ private: typedef SparseVector<double> FeatureVector; typedef SparseVector<double> WeightVector; - +typedef std::vector<double> DenseWeightVector; template <typename T> SparseVector<T> operator+(const SparseVector<T>& a, const SparseVector<T>& b) { SparseVector<T> result = a; diff --git a/decoder/stringlib.h b/decoder/stringlib.h index eac1dce6..6bb8cff0 100644 --- a/decoder/stringlib.h +++ b/decoder/stringlib.h @@ -1,4 +1,5 @@ -#ifndef _STRINGLIB_H_ +#ifndef CDEC_STRINGLIB_H_ +#define CDEC_STRINGLIB_H_ #include <map> #include <vector> @@ -14,7 +15,7 @@ void ParseTranslatorInput(const std::string& line, std::string* input, std::stri struct Lattice; void ParseTranslatorInputLattice(const std::string& line, std::string* input, Lattice* ref); -inline const std::string Trim(const std::string& str, const std::string& dropChars = " \t") { +inline std::string Trim(const std::string& str, const std::string& dropChars = " \t") { std::string res = str; res.erase(str.find_last_not_of(dropChars)+1); return res.erase(0, res.find_first_not_of(dropChars)); |