diff options
-rw-r--r-- | decoder/filelib.h | 62 | ||||
-rwxr-xr-x | decoder/oracle_bleu.h | 24 | ||||
-rw-r--r-- | decoder/small_vector.h | 2 | ||||
-rwxr-xr-x | vest/line_mediator.pl | 1 | ||||
-rw-r--r-- | vest/mr_vest_generate_mapper_input.cc | 31 | ||||
-rw-r--r-- | vest/scorer.cc | 85 | ||||
-rw-r--r-- | vest/scorer.h | 33 |
7 files changed, 166 insertions, 72 deletions
diff --git a/decoder/filelib.h b/decoder/filelib.h index 03c22b0d..1630481d 100644 --- a/decoder/filelib.h +++ b/decoder/filelib.h @@ -5,6 +5,7 @@ #include <string> #include <iostream> #include <cstdlib> +#include <boost/shared_ptr.hpp> #include "gzstream.h" bool FileExists(const std::string& file_name); @@ -13,35 +14,57 @@ bool DirectoryExists(const std::string& dir_name); // reads from standard in if filename is - // uncompresses if file ends with .gz // otherwise, reads from a normal file +struct file_null_deleter { + void operator()(void*) const {} +}; + class ReadFile { public: - ReadFile(const std::string& filename) : - no_delete_on_exit_(filename == "-"), - in_(no_delete_on_exit_ ? static_cast<std::istream*>(&std::cin) : - (EndsWith(filename, ".gz") ? - static_cast<std::istream*>(new igzstream(filename.c_str())) : - static_cast<std::istream*>(new std::ifstream(filename.c_str())))) { - if (!no_delete_on_exit_ && !FileExists(filename)) { - std::cerr << "File does not exist: " << filename << std::endl; - abort(); - } - if (!*in_) { - std::cerr << "Failed to open " << filename << std::endl; - abort(); + typedef boost::shared_ptr<std::istream> PS; + ReadFile() { } + std::string filename_; + void Init(const std::string& filename) { + bool stdin=(filename == "-"); + if (stdin) { + in_=PS(&std::cin,file_null_deleter()); + } else { + if (!FileExists(filename)) { + std::cerr << "File does not exist: " << filename << std::endl; + abort(); + } + filename_=filename; + in_.reset(EndsWith(filename, ".gz") ? + static_cast<std::istream*>(new igzstream(filename.c_str())) : + static_cast<std::istream*>(new std::ifstream(filename.c_str()))); + if (!*in_) { + std::cerr << "Failed to open " << filename << std::endl; + abort(); + } } } + void Reset() { + in_.reset(); + } + bool is_null() const { return !in_; } + operator bool() const { + return in_; + } + + explicit ReadFile(const std::string& filename) { + Init(filename); + } ~ReadFile() { - if (!no_delete_on_exit_) delete in_; } - inline std::istream* stream() { return in_; } - + std::istream* stream() { return in_.get(); } + std::istream* operator->() { return in_.get(); } // compat with old ReadFile * -> new Readfile. remove? + std::istream &get() const { return *in_; } + private: static bool EndsWith(const std::string& f, const std::string& suf) { return (f.size() > suf.size()) && (f.rfind(suf) == f.size() - suf.size()); } - const bool no_delete_on_exit_; - std::istream* const in_; + PS in_; }; class WriteFile { @@ -58,7 +81,8 @@ class WriteFile { } inline std::ostream* stream() { return out_; } - + std::ostream &get() const { return *out_; } + private: static bool EndsWith(const std::string& f, const std::string& suf) { return (f.size() > suf.size()) && (f.rfind(suf) == f.size() - suf.size()); diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index 66d155d3..4800e9c1 100755 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -1,6 +1,8 @@ #ifndef ORACLE_BLEU_H #define ORACLE_BLEU_H +#define DEBUG_ORACLE_BLEU + #include <sstream> #include <iostream> #include <string> @@ -20,7 +22,7 @@ #include "sentences.h" //TODO: put function impls into .cc -//TODO: disentangle +//TODO: move Translation into its own .h and use in cdec struct Translation { typedef std::vector<WordID> Sentence; Sentence sentence; @@ -153,11 +155,16 @@ struct OracleBleu { init_refs(); } void init_refs() { - if (is_null()) return; + if (is_null()) { +#ifdef DEBUG_ORACLE_BLEU + std::cerr<<"No references for oracle BLEU.\n"; +#endif + return; + } assert(refs.size()); - ds=DocScorer(loss,refs); - doc_score.reset(); -// doc_score=sentscore + ds.Init(loss,refs); + ensure_doc_score(); +// doc_score.reset(); std::cerr << "Loaded " << ds.size() << " references for scoring with " << StringFromScoreType(loss) << std::endl; } @@ -193,10 +200,15 @@ struct OracleBleu { return r; } + // if doc_score wasn't init, add 1 counts to ngram acc. + void ensure_doc_score() { + if (!doc_score) { doc_score.reset(Score::GetOne(loss)); } + } + void Rescore(SentenceMetadata const& smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0,std::ostream *log=&std::cerr) { // the sentence bleu stats will get added to doc only if you call IncludeLastScore + ensure_doc_score(); sentscore=GetScore(forest,smeta.GetSentenceID()); - if (!doc_score) { doc_score.reset(sentscore->GetOne()); } tmp_src_length = smeta.GetSourceLength(); //TODO: where does this come from? using namespace std; DenseWeightVector w; diff --git a/decoder/small_vector.h b/decoder/small_vector.h index c3090a8b..202b72c9 100644 --- a/decoder/small_vector.h +++ b/decoder/small_vector.h @@ -11,7 +11,7 @@ #include <streambuf> // std::max - where to get this? #include <cstring> #include <cassert> -#include <limits.h> +#include <stdint.h> #include <new> #include <stdint.h> //sizeof(T)/sizeof(T*)>1?sizeof(T)/sizeof(T*):1 diff --git a/vest/line_mediator.pl b/vest/line_mediator.pl index bc2bb24c..a47c5d1d 100755 --- a/vest/line_mediator.pl +++ b/vest/line_mediator.pl @@ -89,6 +89,7 @@ if ($ser eq 'SERIAL') { lineto(*STDOUT,$_); } } + } } else { info("DIRECT mode\n"); my @rw1=POSIX::pipe(); diff --git a/vest/mr_vest_generate_mapper_input.cc b/vest/mr_vest_generate_mapper_input.cc index 4da5326f..5b513f9b 100644 --- a/vest/mr_vest_generate_mapper_input.cc +++ b/vest/mr_vest_generate_mapper_input.cc @@ -110,37 +110,38 @@ struct oracle_directions { po::options_description dcmdline_options; dcmdline_options.add(opts); po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - bool flag = false; + po::notify(*conf); if (conf->count("dev_set_size") == 0) { cerr << "Please specify the size of the development set using -d N\n"; - flag = true; + goto bad_cmdline; } if (conf->count("weights") == 0) { cerr << "Please specify the starting-point weights using -w <weightfile.txt>\n"; - flag = true; + goto bad_cmdline; } if (conf->count("forest_repository") == 0) { cerr << "Please specify the forest repository location using -r <DIR>\n"; - flag = true; + goto bad_cmdline; } - if (flag || conf->count("help")) { - cerr << dcmdline_options << endl; - exit(1); + if (n_oracle && oracle.refs.empty()) { + cerr<<"Specify references when using oracle directions\n"; + goto bad_cmdline; } - po::notify(*conf); - - if (0) { - dev_set_size = (*conf)["dev_set_size"].as<unsigned>(); - forest_repository = (*conf)["forest_repository"].as<string>(); - weights_file = (*conf)["weights"].as<string>(); - n_random = (*conf)["random_directions"].as<unsigned>(); + if (conf->count("help")) { + cout << dcmdline_options << endl; + exit(0); } + + UseConf(*conf); + return; + bad_cmdline: + cerr << dcmdline_options << endl; + exit(1); } int main(int argc, char *argv[]) { po::variables_map conf; InitCommandLine(argc,argv,&conf); - UseConf(conf); Run(); return 0; } diff --git a/vest/scorer.cc b/vest/scorer.cc index e8e9608a..d8628418 100644 --- a/vest/scorer.cc +++ b/vest/scorer.cc @@ -1,5 +1,8 @@ #include "scorer.h" +#define DEBUG_SCORER + +#include <boost/lexical_cast.hpp> #include <map> #include <sstream> #include <iostream> @@ -121,9 +124,13 @@ class SERScore : public Score { int correct, total; }; +std::string SentenceScorer::verbose_desc() const { + return desc+",ref0={ "+TD::GetString(refs[0])+" }"; +} + class SERScorer : public SentenceScorer { public: - SERScorer(const vector<vector<WordID> >& references) : refs_(references) {} + SERScorer(const vector<vector<WordID> >& references) : SentenceScorer("SERScorer",references),refs_(references) {} Score* ScoreCCandidate(const vector<WordID>& /* hyp */) const { Score* a = NULL; return a; @@ -180,7 +187,7 @@ class BLEUScore : public Score { class BLEUScorerBase : public SentenceScorer { public: BLEUScorerBase(const vector<vector<WordID> >& references, - int n + int n ); Score* ScoreCandidate(const vector<WordID>& hyp) const; Score* ScoreCCandidate(const vector<WordID>& hyp) const; @@ -353,6 +360,33 @@ SentenceScorer* SentenceScorer::CreateSentenceScorer(const ScoreType type, } } +Score* SentenceScorer::GetOne() const { + Sentence s; + return ScoreCCandidate(s)->GetOne(); +} + +Score* SentenceScorer::GetZero() const { + Sentence s; + return ScoreCCandidate(s)->GetZero(); +} + +Score* Score::GetOne(ScoreType type) { + std::vector<SentenceScorer::Sentence > refs; + SentenceScorer *ps=SentenceScorer::CreateSentenceScorer(type,refs); + Score *s=ps->GetOne(); + delete ps; + return s; +} + +Score* Score::GetZero(ScoreType type) { + std::vector<SentenceScorer::Sentence > refs; + SentenceScorer *ps=SentenceScorer::CreateSentenceScorer(type,refs); + Score *s=ps->GetZero(); + delete ps; + return s; +} + + Score* SentenceScorer::CreateScoreFromString(const ScoreType type, const string& in) { switch (type) { case IBM_BLEU: @@ -562,6 +596,7 @@ Score* BLEUScore::GetOne() const { return new BLEUScore(hyp_ngram_counts.size(),1); } + void BLEUScore::Encode(string* out) const { ostringstream os; const int n = correct_ngram_hit_counts.size(); @@ -572,7 +607,7 @@ void BLEUScore::Encode(string* out) const { } BLEUScorerBase::BLEUScorerBase(const vector<vector<WordID> >& references, - int n) : n_(n) { + int n) : SentenceScorer("BLEU"+boost::lexical_cast<string>(n),references),n_(n) { for (vector<vector<WordID> >::const_iterator ci = references.begin(); ci != references.end(); ++ci) { lengths_.push_back(ci->size()); @@ -603,42 +638,40 @@ Score* BLEUScorerBase::ScoreCCandidate(const vector<WordID>& hyp) const { DocScorer::~DocScorer() { - for (int i=0; i < scorers_.size(); ++i) - delete scorers_[i]; } -DocScorer::DocScorer( +void DocScorer::Init( const ScoreType type, const vector<string>& ref_files, const string& src_file) { + scorers_.clear(); // TODO stop using valarray, start using ReadFile cerr << "Loading references (" << ref_files.size() << " files)\n"; - shared_ptr<ReadFile> srcrf; + ReadFile srcrf; if (type == AER && src_file.size() > 0) { cerr << " (source=" << src_file << ")\n"; - srcrf.reset(new ReadFile(src_file)); - } - valarray<ifstream> ifs(ref_files.size()); - for (int i=0; i < ref_files.size(); ++i) { - ifs[i].open(ref_files[i].c_str()); - assert(ifs[i].good()); + srcrf.Init(src_file); } + std::vector<ReadFile> ifs(ref_files.begin(),ref_files.end()); + for (int i=0; i < ref_files.size(); ++i) ifs[i].Init(ref_files[i]); char buf[64000]; bool expect_eof = false; - while (!ifs[0].eof()) { + int line=0; + while (ifs[0].get()) { vector<vector<WordID> > refs(ref_files.size()); for (int i=0; i < ref_files.size(); ++i) { - if (ifs[i].eof()) break; - ifs[i].getline(buf, 64000); + istream &in=ifs[i].get(); + if (in.eof()) break; + in.getline(buf, 64000); refs[i].clear(); if (strlen(buf) == 0) { - if (ifs[i].eof()) { - if (!expect_eof) { - assert(i == 0); - expect_eof = true; - } + if (in.eof()) { + if (!expect_eof) { + assert(i == 0); + expect_eof = true; + } break; - } + } } else { TD::ConvertSentence(buf, &refs[i]); assert(!refs[i].empty()); @@ -648,11 +681,15 @@ DocScorer::DocScorer( if (!expect_eof) { string src_line; if (srcrf) { - getline(*srcrf->stream(), src_line); + getline(srcrf.get(), src_line); map<string,string> dummy; ProcessAndStripSGML(&src_line, &dummy); } - scorers_.push_back(SentenceScorer::CreateSentenceScorer(type, refs, src_line)); + scorers_.push_back(ScorerP(SentenceScorer::CreateSentenceScorer(type, refs, src_line))); +#ifdef DEBUG_SCORER + cerr<<"doc_scorer["<<line<<"] = "<<scorers_.back()->verbose_desc()<<endl; +#endif + ++line; } } cerr << "Loaded reference translations for " << scorers_.size() << " sentences.\n"; diff --git a/vest/scorer.h b/vest/scorer.h index 9c8aebcc..cc6b7335 100644 --- a/vest/scorer.h +++ b/vest/scorer.h @@ -1,8 +1,8 @@ #ifndef SCORER_H_ #define SCORER_H_ - #include <vector> #include <string> +#include <boost/shared_ptr.hpp> #include "wordid.h" @@ -16,6 +16,7 @@ std::string StringFromScoreType(ScoreType st); class Score { public: + typedef boost::shared_ptr<Score> ScoreP; virtual ~Score(); virtual float ComputeScore() const = 0; virtual float ComputePartialScore() const =0; @@ -35,13 +36,24 @@ class Score { // to another score results in no score change // under any circumstances virtual void Encode(std::string* out) const = 0; + static Score* GetZero(ScoreType type); + static Score* GetOne(ScoreType type); }; class SentenceScorer { public: + typedef boost::shared_ptr<Score> ScoreP; + typedef boost::shared_ptr<SentenceScorer> ScorerP; typedef std::vector<WordID> Sentence; + typedef std::vector<Sentence> Sentences; + std::string desc; + Sentences refs; + SentenceScorer(std::string desc="SentenceScorer_unknown", Sentences const& refs=Sentences()) : desc(desc),refs(refs) { } + std::string verbose_desc() const; virtual float ComputeRefLength(const Sentence& hyp) const; // default: avg of refs.length virtual ~SentenceScorer(); + virtual Score* GetOne() const; + virtual Score* GetZero() const; void ComputeErrorSurface(const ViterbiEnvelope& ve, ErrorSurface* es, const ScoreType type, const Hypergraph& hg) const; virtual Score* ScoreCandidate(const Sentence& hyp) const = 0; virtual Score* ScoreCCandidate(const Sentence& hyp) const =0; @@ -57,14 +69,21 @@ class DocScorer { public: ~DocScorer(); DocScorer() { } - DocScorer( - const ScoreType type, - const std::vector<std::string>& ref_files, - const std::string& src_file = ""); + void Init(const ScoreType type, + const std::vector<std::string>& ref_files, + const std::string& src_file = ""); + DocScorer(const ScoreType type, + const std::vector<std::string>& ref_files, + const std::string& src_file = "") + { + Init(type,ref_files,src_file); + } + int size() const { return scorers_.size(); } - const SentenceScorer* operator[](size_t i) const { return scorers_[i]; } + typedef boost::shared_ptr<SentenceScorer> ScorerP; + ScorerP operator[](size_t i) const { return scorers_[i]; } private: - std::vector<SentenceScorer*> scorers_; + std::vector<ScorerP> scorers_; }; #endif |