diff options
author | graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-19 21:33:17 +0000 |
---|---|---|
committer | graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-19 21:33:17 +0000 |
commit | a429a0f0f510751b842b7161f5e1c8feab05fc1b (patch) | |
tree | 0f1667737581b3b6dc5b9cff8d904a596e9d2335 /vest | |
parent | 648c350aa8b90bd60ca1448bd3bb702004b9ad26 (diff) |
shared_ptr for ReadFile and doc_scorer; init ds to GetOne() in oracle
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@322 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'vest')
-rwxr-xr-x | vest/line_mediator.pl | 1 | ||||
-rw-r--r-- | vest/mr_vest_generate_mapper_input.cc | 31 | ||||
-rw-r--r-- | vest/scorer.cc | 85 | ||||
-rw-r--r-- | vest/scorer.h | 33 |
4 files changed, 104 insertions, 46 deletions
diff --git a/vest/line_mediator.pl b/vest/line_mediator.pl index bc2bb24c..a47c5d1d 100755 --- a/vest/line_mediator.pl +++ b/vest/line_mediator.pl @@ -89,6 +89,7 @@ if ($ser eq 'SERIAL') { lineto(*STDOUT,$_); } } + } } else { info("DIRECT mode\n"); my @rw1=POSIX::pipe(); diff --git a/vest/mr_vest_generate_mapper_input.cc b/vest/mr_vest_generate_mapper_input.cc index 4da5326f..5b513f9b 100644 --- a/vest/mr_vest_generate_mapper_input.cc +++ b/vest/mr_vest_generate_mapper_input.cc @@ -110,37 +110,38 @@ struct oracle_directions { po::options_description dcmdline_options; dcmdline_options.add(opts); po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - bool flag = false; + po::notify(*conf); if (conf->count("dev_set_size") == 0) { cerr << "Please specify the size of the development set using -d N\n"; - flag = true; + goto bad_cmdline; } if (conf->count("weights") == 0) { cerr << "Please specify the starting-point weights using -w <weightfile.txt>\n"; - flag = true; + goto bad_cmdline; } if (conf->count("forest_repository") == 0) { cerr << "Please specify the forest repository location using -r <DIR>\n"; - flag = true; + goto bad_cmdline; } - if (flag || conf->count("help")) { - cerr << dcmdline_options << endl; - exit(1); + if (n_oracle && oracle.refs.empty()) { + cerr<<"Specify references when using oracle directions\n"; + goto bad_cmdline; } - po::notify(*conf); - - if (0) { - dev_set_size = (*conf)["dev_set_size"].as<unsigned>(); - forest_repository = (*conf)["forest_repository"].as<string>(); - weights_file = (*conf)["weights"].as<string>(); - n_random = (*conf)["random_directions"].as<unsigned>(); + if (conf->count("help")) { + cout << dcmdline_options << endl; + exit(0); } + + UseConf(*conf); + return; + bad_cmdline: + cerr << dcmdline_options << endl; + exit(1); } int main(int argc, char *argv[]) { po::variables_map conf; InitCommandLine(argc,argv,&conf); - UseConf(conf); Run(); return 0; } diff --git a/vest/scorer.cc b/vest/scorer.cc index e8e9608a..d8628418 100644 --- a/vest/scorer.cc +++ b/vest/scorer.cc @@ -1,5 +1,8 @@ #include "scorer.h" +#define DEBUG_SCORER + +#include <boost/lexical_cast.hpp> #include <map> #include <sstream> #include <iostream> @@ -121,9 +124,13 @@ class SERScore : public Score { int correct, total; }; +std::string SentenceScorer::verbose_desc() const { + return desc+",ref0={ "+TD::GetString(refs[0])+" }"; +} + class SERScorer : public SentenceScorer { public: - SERScorer(const vector<vector<WordID> >& references) : refs_(references) {} + SERScorer(const vector<vector<WordID> >& references) : SentenceScorer("SERScorer",references),refs_(references) {} Score* ScoreCCandidate(const vector<WordID>& /* hyp */) const { Score* a = NULL; return a; @@ -180,7 +187,7 @@ class BLEUScore : public Score { class BLEUScorerBase : public SentenceScorer { public: BLEUScorerBase(const vector<vector<WordID> >& references, - int n + int n ); Score* ScoreCandidate(const vector<WordID>& hyp) const; Score* ScoreCCandidate(const vector<WordID>& hyp) const; @@ -353,6 +360,33 @@ SentenceScorer* SentenceScorer::CreateSentenceScorer(const ScoreType type, } } +Score* SentenceScorer::GetOne() const { + Sentence s; + return ScoreCCandidate(s)->GetOne(); +} + +Score* SentenceScorer::GetZero() const { + Sentence s; + return ScoreCCandidate(s)->GetZero(); +} + +Score* Score::GetOne(ScoreType type) { + std::vector<SentenceScorer::Sentence > refs; + SentenceScorer *ps=SentenceScorer::CreateSentenceScorer(type,refs); + Score *s=ps->GetOne(); + delete ps; + return s; +} + +Score* Score::GetZero(ScoreType type) { + std::vector<SentenceScorer::Sentence > refs; + SentenceScorer *ps=SentenceScorer::CreateSentenceScorer(type,refs); + Score *s=ps->GetZero(); + delete ps; + return s; +} + + Score* SentenceScorer::CreateScoreFromString(const ScoreType type, const string& in) { switch (type) { case IBM_BLEU: @@ -562,6 +596,7 @@ Score* BLEUScore::GetOne() const { return new BLEUScore(hyp_ngram_counts.size(),1); } + void BLEUScore::Encode(string* out) const { ostringstream os; const int n = correct_ngram_hit_counts.size(); @@ -572,7 +607,7 @@ void BLEUScore::Encode(string* out) const { } BLEUScorerBase::BLEUScorerBase(const vector<vector<WordID> >& references, - int n) : n_(n) { + int n) : SentenceScorer("BLEU"+boost::lexical_cast<string>(n),references),n_(n) { for (vector<vector<WordID> >::const_iterator ci = references.begin(); ci != references.end(); ++ci) { lengths_.push_back(ci->size()); @@ -603,42 +638,40 @@ Score* BLEUScorerBase::ScoreCCandidate(const vector<WordID>& hyp) const { DocScorer::~DocScorer() { - for (int i=0; i < scorers_.size(); ++i) - delete scorers_[i]; } -DocScorer::DocScorer( +void DocScorer::Init( const ScoreType type, const vector<string>& ref_files, const string& src_file) { + scorers_.clear(); // TODO stop using valarray, start using ReadFile cerr << "Loading references (" << ref_files.size() << " files)\n"; - shared_ptr<ReadFile> srcrf; + ReadFile srcrf; if (type == AER && src_file.size() > 0) { cerr << " (source=" << src_file << ")\n"; - srcrf.reset(new ReadFile(src_file)); - } - valarray<ifstream> ifs(ref_files.size()); - for (int i=0; i < ref_files.size(); ++i) { - ifs[i].open(ref_files[i].c_str()); - assert(ifs[i].good()); + srcrf.Init(src_file); } + std::vector<ReadFile> ifs(ref_files.begin(),ref_files.end()); + for (int i=0; i < ref_files.size(); ++i) ifs[i].Init(ref_files[i]); char buf[64000]; bool expect_eof = false; - while (!ifs[0].eof()) { + int line=0; + while (ifs[0].get()) { vector<vector<WordID> > refs(ref_files.size()); for (int i=0; i < ref_files.size(); ++i) { - if (ifs[i].eof()) break; - ifs[i].getline(buf, 64000); + istream &in=ifs[i].get(); + if (in.eof()) break; + in.getline(buf, 64000); refs[i].clear(); if (strlen(buf) == 0) { - if (ifs[i].eof()) { - if (!expect_eof) { - assert(i == 0); - expect_eof = true; - } + if (in.eof()) { + if (!expect_eof) { + assert(i == 0); + expect_eof = true; + } break; - } + } } else { TD::ConvertSentence(buf, &refs[i]); assert(!refs[i].empty()); @@ -648,11 +681,15 @@ DocScorer::DocScorer( if (!expect_eof) { string src_line; if (srcrf) { - getline(*srcrf->stream(), src_line); + getline(srcrf.get(), src_line); map<string,string> dummy; ProcessAndStripSGML(&src_line, &dummy); } - scorers_.push_back(SentenceScorer::CreateSentenceScorer(type, refs, src_line)); + scorers_.push_back(ScorerP(SentenceScorer::CreateSentenceScorer(type, refs, src_line))); +#ifdef DEBUG_SCORER + cerr<<"doc_scorer["<<line<<"] = "<<scorers_.back()->verbose_desc()<<endl; +#endif + ++line; } } cerr << "Loaded reference translations for " << scorers_.size() << " sentences.\n"; diff --git a/vest/scorer.h b/vest/scorer.h index 9c8aebcc..cc6b7335 100644 --- a/vest/scorer.h +++ b/vest/scorer.h @@ -1,8 +1,8 @@ #ifndef SCORER_H_ #define SCORER_H_ - #include <vector> #include <string> +#include <boost/shared_ptr.hpp> #include "wordid.h" @@ -16,6 +16,7 @@ std::string StringFromScoreType(ScoreType st); class Score { public: + typedef boost::shared_ptr<Score> ScoreP; virtual ~Score(); virtual float ComputeScore() const = 0; virtual float ComputePartialScore() const =0; @@ -35,13 +36,24 @@ class Score { // to another score results in no score change // under any circumstances virtual void Encode(std::string* out) const = 0; + static Score* GetZero(ScoreType type); + static Score* GetOne(ScoreType type); }; class SentenceScorer { public: + typedef boost::shared_ptr<Score> ScoreP; + typedef boost::shared_ptr<SentenceScorer> ScorerP; typedef std::vector<WordID> Sentence; + typedef std::vector<Sentence> Sentences; + std::string desc; + Sentences refs; + SentenceScorer(std::string desc="SentenceScorer_unknown", Sentences const& refs=Sentences()) : desc(desc),refs(refs) { } + std::string verbose_desc() const; virtual float ComputeRefLength(const Sentence& hyp) const; // default: avg of refs.length virtual ~SentenceScorer(); + virtual Score* GetOne() const; + virtual Score* GetZero() const; void ComputeErrorSurface(const ViterbiEnvelope& ve, ErrorSurface* es, const ScoreType type, const Hypergraph& hg) const; virtual Score* ScoreCandidate(const Sentence& hyp) const = 0; virtual Score* ScoreCCandidate(const Sentence& hyp) const =0; @@ -57,14 +69,21 @@ class DocScorer { public: ~DocScorer(); DocScorer() { } - DocScorer( - const ScoreType type, - const std::vector<std::string>& ref_files, - const std::string& src_file = ""); + void Init(const ScoreType type, + const std::vector<std::string>& ref_files, + const std::string& src_file = ""); + DocScorer(const ScoreType type, + const std::vector<std::string>& ref_files, + const std::string& src_file = "") + { + Init(type,ref_files,src_file); + } + int size() const { return scorers_.size(); } - const SentenceScorer* operator[](size_t i) const { return scorers_[i]; } + typedef boost::shared_ptr<SentenceScorer> ScorerP; + ScorerP operator[](size_t i) const { return scorers_[i]; } private: - std::vector<SentenceScorer*> scorers_; + std::vector<ScorerP> scorers_; }; #endif |