summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorgraehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-19 21:33:17 +0000
committergraehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-19 21:33:17 +0000
commit9e35239dd1b4393a320da6c745749500dba8f2b6 (patch)
treef7ebb65dfe3a87b1515e4746dd20a4a789e5e49b
parent96869a482482a4ef0ee8b101ab32cc10219cc3d4 (diff)
shared_ptr for ReadFile and doc_scorer; init ds to GetOne() in oracle
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@322 ec762483-ff6d-05da-a07a-a48fb63a330f
-rw-r--r--decoder/filelib.h62
-rwxr-xr-xdecoder/oracle_bleu.h24
-rw-r--r--decoder/small_vector.h2
-rwxr-xr-xvest/line_mediator.pl1
-rw-r--r--vest/mr_vest_generate_mapper_input.cc31
-rw-r--r--vest/scorer.cc85
-rw-r--r--vest/scorer.h33
7 files changed, 166 insertions, 72 deletions
diff --git a/decoder/filelib.h b/decoder/filelib.h
index 03c22b0d..1630481d 100644
--- a/decoder/filelib.h
+++ b/decoder/filelib.h
@@ -5,6 +5,7 @@
#include <string>
#include <iostream>
#include <cstdlib>
+#include <boost/shared_ptr.hpp>
#include "gzstream.h"
bool FileExists(const std::string& file_name);
@@ -13,35 +14,57 @@ bool DirectoryExists(const std::string& dir_name);
// reads from standard in if filename is -
// uncompresses if file ends with .gz
// otherwise, reads from a normal file
+struct file_null_deleter {
+ void operator()(void*) const {}
+};
+
class ReadFile {
public:
- ReadFile(const std::string& filename) :
- no_delete_on_exit_(filename == "-"),
- in_(no_delete_on_exit_ ? static_cast<std::istream*>(&std::cin) :
- (EndsWith(filename, ".gz") ?
- static_cast<std::istream*>(new igzstream(filename.c_str())) :
- static_cast<std::istream*>(new std::ifstream(filename.c_str())))) {
- if (!no_delete_on_exit_ && !FileExists(filename)) {
- std::cerr << "File does not exist: " << filename << std::endl;
- abort();
- }
- if (!*in_) {
- std::cerr << "Failed to open " << filename << std::endl;
- abort();
+ typedef boost::shared_ptr<std::istream> PS;
+ ReadFile() { }
+ std::string filename_;
+ void Init(const std::string& filename) {
+ bool stdin=(filename == "-");
+ if (stdin) {
+ in_=PS(&std::cin,file_null_deleter());
+ } else {
+ if (!FileExists(filename)) {
+ std::cerr << "File does not exist: " << filename << std::endl;
+ abort();
+ }
+ filename_=filename;
+ in_.reset(EndsWith(filename, ".gz") ?
+ static_cast<std::istream*>(new igzstream(filename.c_str())) :
+ static_cast<std::istream*>(new std::ifstream(filename.c_str())));
+ if (!*in_) {
+ std::cerr << "Failed to open " << filename << std::endl;
+ abort();
+ }
}
}
+ void Reset() {
+ in_.reset();
+ }
+ bool is_null() const { return !in_; }
+ operator bool() const {
+ return in_;
+ }
+
+ explicit ReadFile(const std::string& filename) {
+ Init(filename);
+ }
~ReadFile() {
- if (!no_delete_on_exit_) delete in_;
}
- inline std::istream* stream() { return in_; }
-
+ std::istream* stream() { return in_.get(); }
+ std::istream* operator->() { return in_.get(); } // compat with old ReadFile * -> new Readfile. remove?
+ std::istream &get() const { return *in_; }
+
private:
static bool EndsWith(const std::string& f, const std::string& suf) {
return (f.size() > suf.size()) && (f.rfind(suf) == f.size() - suf.size());
}
- const bool no_delete_on_exit_;
- std::istream* const in_;
+ PS in_;
};
class WriteFile {
@@ -58,7 +81,8 @@ class WriteFile {
}
inline std::ostream* stream() { return out_; }
-
+ std::ostream &get() const { return *out_; }
+
private:
static bool EndsWith(const std::string& f, const std::string& suf) {
return (f.size() > suf.size()) && (f.rfind(suf) == f.size() - suf.size());
diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h
index 66d155d3..4800e9c1 100755
--- a/decoder/oracle_bleu.h
+++ b/decoder/oracle_bleu.h
@@ -1,6 +1,8 @@
#ifndef ORACLE_BLEU_H
#define ORACLE_BLEU_H
+#define DEBUG_ORACLE_BLEU
+
#include <sstream>
#include <iostream>
#include <string>
@@ -20,7 +22,7 @@
#include "sentences.h"
//TODO: put function impls into .cc
-//TODO: disentangle
+//TODO: move Translation into its own .h and use in cdec
struct Translation {
typedef std::vector<WordID> Sentence;
Sentence sentence;
@@ -153,11 +155,16 @@ struct OracleBleu {
init_refs();
}
void init_refs() {
- if (is_null()) return;
+ if (is_null()) {
+#ifdef DEBUG_ORACLE_BLEU
+ std::cerr<<"No references for oracle BLEU.\n";
+#endif
+ return;
+ }
assert(refs.size());
- ds=DocScorer(loss,refs);
- doc_score.reset();
-// doc_score=sentscore
+ ds.Init(loss,refs);
+ ensure_doc_score();
+// doc_score.reset();
std::cerr << "Loaded " << ds.size() << " references for scoring with " << StringFromScoreType(loss) << std::endl;
}
@@ -193,10 +200,15 @@ struct OracleBleu {
return r;
}
+ // if doc_score wasn't init, add 1 counts to ngram acc.
+ void ensure_doc_score() {
+ if (!doc_score) { doc_score.reset(Score::GetOne(loss)); }
+ }
+
void Rescore(SentenceMetadata const& smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0,std::ostream *log=&std::cerr) {
// the sentence bleu stats will get added to doc only if you call IncludeLastScore
+ ensure_doc_score();
sentscore=GetScore(forest,smeta.GetSentenceID());
- if (!doc_score) { doc_score.reset(sentscore->GetOne()); }
tmp_src_length = smeta.GetSourceLength(); //TODO: where does this come from?
using namespace std;
DenseWeightVector w;
diff --git a/decoder/small_vector.h b/decoder/small_vector.h
index c3090a8b..202b72c9 100644
--- a/decoder/small_vector.h
+++ b/decoder/small_vector.h
@@ -11,7 +11,7 @@
#include <streambuf> // std::max - where to get this?
#include <cstring>
#include <cassert>
-#include <limits.h>
+#include <stdint.h>
#include <new>
#include <stdint.h>
//sizeof(T)/sizeof(T*)>1?sizeof(T)/sizeof(T*):1
diff --git a/vest/line_mediator.pl b/vest/line_mediator.pl
index bc2bb24c..a47c5d1d 100755
--- a/vest/line_mediator.pl
+++ b/vest/line_mediator.pl
@@ -89,6 +89,7 @@ if ($ser eq 'SERIAL') {
lineto(*STDOUT,$_);
}
}
+ }
} else {
info("DIRECT mode\n");
my @rw1=POSIX::pipe();
diff --git a/vest/mr_vest_generate_mapper_input.cc b/vest/mr_vest_generate_mapper_input.cc
index 4da5326f..5b513f9b 100644
--- a/vest/mr_vest_generate_mapper_input.cc
+++ b/vest/mr_vest_generate_mapper_input.cc
@@ -110,37 +110,38 @@ struct oracle_directions {
po::options_description dcmdline_options;
dcmdline_options.add(opts);
po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
- bool flag = false;
+ po::notify(*conf);
if (conf->count("dev_set_size") == 0) {
cerr << "Please specify the size of the development set using -d N\n";
- flag = true;
+ goto bad_cmdline;
}
if (conf->count("weights") == 0) {
cerr << "Please specify the starting-point weights using -w <weightfile.txt>\n";
- flag = true;
+ goto bad_cmdline;
}
if (conf->count("forest_repository") == 0) {
cerr << "Please specify the forest repository location using -r <DIR>\n";
- flag = true;
+ goto bad_cmdline;
}
- if (flag || conf->count("help")) {
- cerr << dcmdline_options << endl;
- exit(1);
+ if (n_oracle && oracle.refs.empty()) {
+ cerr<<"Specify references when using oracle directions\n";
+ goto bad_cmdline;
}
- po::notify(*conf);
-
- if (0) {
- dev_set_size = (*conf)["dev_set_size"].as<unsigned>();
- forest_repository = (*conf)["forest_repository"].as<string>();
- weights_file = (*conf)["weights"].as<string>();
- n_random = (*conf)["random_directions"].as<unsigned>();
+ if (conf->count("help")) {
+ cout << dcmdline_options << endl;
+ exit(0);
}
+
+ UseConf(*conf);
+ return;
+ bad_cmdline:
+ cerr << dcmdline_options << endl;
+ exit(1);
}
int main(int argc, char *argv[]) {
po::variables_map conf;
InitCommandLine(argc,argv,&conf);
- UseConf(conf);
Run();
return 0;
}
diff --git a/vest/scorer.cc b/vest/scorer.cc
index e8e9608a..d8628418 100644
--- a/vest/scorer.cc
+++ b/vest/scorer.cc
@@ -1,5 +1,8 @@
#include "scorer.h"
+#define DEBUG_SCORER
+
+#include <boost/lexical_cast.hpp>
#include <map>
#include <sstream>
#include <iostream>
@@ -121,9 +124,13 @@ class SERScore : public Score {
int correct, total;
};
+std::string SentenceScorer::verbose_desc() const {
+ return desc+",ref0={ "+TD::GetString(refs[0])+" }";
+}
+
class SERScorer : public SentenceScorer {
public:
- SERScorer(const vector<vector<WordID> >& references) : refs_(references) {}
+ SERScorer(const vector<vector<WordID> >& references) : SentenceScorer("SERScorer",references),refs_(references) {}
Score* ScoreCCandidate(const vector<WordID>& /* hyp */) const {
Score* a = NULL;
return a;
@@ -180,7 +187,7 @@ class BLEUScore : public Score {
class BLEUScorerBase : public SentenceScorer {
public:
BLEUScorerBase(const vector<vector<WordID> >& references,
- int n
+ int n
);
Score* ScoreCandidate(const vector<WordID>& hyp) const;
Score* ScoreCCandidate(const vector<WordID>& hyp) const;
@@ -353,6 +360,33 @@ SentenceScorer* SentenceScorer::CreateSentenceScorer(const ScoreType type,
}
}
+Score* SentenceScorer::GetOne() const {
+ Sentence s;
+ return ScoreCCandidate(s)->GetOne();
+}
+
+Score* SentenceScorer::GetZero() const {
+ Sentence s;
+ return ScoreCCandidate(s)->GetZero();
+}
+
+Score* Score::GetOne(ScoreType type) {
+ std::vector<SentenceScorer::Sentence > refs;
+ SentenceScorer *ps=SentenceScorer::CreateSentenceScorer(type,refs);
+ Score *s=ps->GetOne();
+ delete ps;
+ return s;
+}
+
+Score* Score::GetZero(ScoreType type) {
+ std::vector<SentenceScorer::Sentence > refs;
+ SentenceScorer *ps=SentenceScorer::CreateSentenceScorer(type,refs);
+ Score *s=ps->GetZero();
+ delete ps;
+ return s;
+}
+
+
Score* SentenceScorer::CreateScoreFromString(const ScoreType type, const string& in) {
switch (type) {
case IBM_BLEU:
@@ -562,6 +596,7 @@ Score* BLEUScore::GetOne() const {
return new BLEUScore(hyp_ngram_counts.size(),1);
}
+
void BLEUScore::Encode(string* out) const {
ostringstream os;
const int n = correct_ngram_hit_counts.size();
@@ -572,7 +607,7 @@ void BLEUScore::Encode(string* out) const {
}
BLEUScorerBase::BLEUScorerBase(const vector<vector<WordID> >& references,
- int n) : n_(n) {
+ int n) : SentenceScorer("BLEU"+boost::lexical_cast<string>(n),references),n_(n) {
for (vector<vector<WordID> >::const_iterator ci = references.begin();
ci != references.end(); ++ci) {
lengths_.push_back(ci->size());
@@ -603,42 +638,40 @@ Score* BLEUScorerBase::ScoreCCandidate(const vector<WordID>& hyp) const {
DocScorer::~DocScorer() {
- for (int i=0; i < scorers_.size(); ++i)
- delete scorers_[i];
}
-DocScorer::DocScorer(
+void DocScorer::Init(
const ScoreType type,
const vector<string>& ref_files,
const string& src_file) {
+ scorers_.clear();
// TODO stop using valarray, start using ReadFile
cerr << "Loading references (" << ref_files.size() << " files)\n";
- shared_ptr<ReadFile> srcrf;
+ ReadFile srcrf;
if (type == AER && src_file.size() > 0) {
cerr << " (source=" << src_file << ")\n";
- srcrf.reset(new ReadFile(src_file));
- }
- valarray<ifstream> ifs(ref_files.size());
- for (int i=0; i < ref_files.size(); ++i) {
- ifs[i].open(ref_files[i].c_str());
- assert(ifs[i].good());
+ srcrf.Init(src_file);
}
+ std::vector<ReadFile> ifs(ref_files.begin(),ref_files.end());
+ for (int i=0; i < ref_files.size(); ++i) ifs[i].Init(ref_files[i]);
char buf[64000];
bool expect_eof = false;
- while (!ifs[0].eof()) {
+ int line=0;
+ while (ifs[0].get()) {
vector<vector<WordID> > refs(ref_files.size());
for (int i=0; i < ref_files.size(); ++i) {
- if (ifs[i].eof()) break;
- ifs[i].getline(buf, 64000);
+ istream &in=ifs[i].get();
+ if (in.eof()) break;
+ in.getline(buf, 64000);
refs[i].clear();
if (strlen(buf) == 0) {
- if (ifs[i].eof()) {
- if (!expect_eof) {
- assert(i == 0);
- expect_eof = true;
- }
+ if (in.eof()) {
+ if (!expect_eof) {
+ assert(i == 0);
+ expect_eof = true;
+ }
break;
- }
+ }
} else {
TD::ConvertSentence(buf, &refs[i]);
assert(!refs[i].empty());
@@ -648,11 +681,15 @@ DocScorer::DocScorer(
if (!expect_eof) {
string src_line;
if (srcrf) {
- getline(*srcrf->stream(), src_line);
+ getline(srcrf.get(), src_line);
map<string,string> dummy;
ProcessAndStripSGML(&src_line, &dummy);
}
- scorers_.push_back(SentenceScorer::CreateSentenceScorer(type, refs, src_line));
+ scorers_.push_back(ScorerP(SentenceScorer::CreateSentenceScorer(type, refs, src_line)));
+#ifdef DEBUG_SCORER
+ cerr<<"doc_scorer["<<line<<"] = "<<scorers_.back()->verbose_desc()<<endl;
+#endif
+ ++line;
}
}
cerr << "Loaded reference translations for " << scorers_.size() << " sentences.\n";
diff --git a/vest/scorer.h b/vest/scorer.h
index 9c8aebcc..cc6b7335 100644
--- a/vest/scorer.h
+++ b/vest/scorer.h
@@ -1,8 +1,8 @@
#ifndef SCORER_H_
#define SCORER_H_
-
#include <vector>
#include <string>
+#include <boost/shared_ptr.hpp>
#include "wordid.h"
@@ -16,6 +16,7 @@ std::string StringFromScoreType(ScoreType st);
class Score {
public:
+ typedef boost::shared_ptr<Score> ScoreP;
virtual ~Score();
virtual float ComputeScore() const = 0;
virtual float ComputePartialScore() const =0;
@@ -35,13 +36,24 @@ class Score {
// to another score results in no score change
// under any circumstances
virtual void Encode(std::string* out) const = 0;
+ static Score* GetZero(ScoreType type);
+ static Score* GetOne(ScoreType type);
};
class SentenceScorer {
public:
+ typedef boost::shared_ptr<Score> ScoreP;
+ typedef boost::shared_ptr<SentenceScorer> ScorerP;
typedef std::vector<WordID> Sentence;
+ typedef std::vector<Sentence> Sentences;
+ std::string desc;
+ Sentences refs;
+ SentenceScorer(std::string desc="SentenceScorer_unknown", Sentences const& refs=Sentences()) : desc(desc),refs(refs) { }
+ std::string verbose_desc() const;
virtual float ComputeRefLength(const Sentence& hyp) const; // default: avg of refs.length
virtual ~SentenceScorer();
+ virtual Score* GetOne() const;
+ virtual Score* GetZero() const;
void ComputeErrorSurface(const ViterbiEnvelope& ve, ErrorSurface* es, const ScoreType type, const Hypergraph& hg) const;
virtual Score* ScoreCandidate(const Sentence& hyp) const = 0;
virtual Score* ScoreCCandidate(const Sentence& hyp) const =0;
@@ -57,14 +69,21 @@ class DocScorer {
public:
~DocScorer();
DocScorer() { }
- DocScorer(
- const ScoreType type,
- const std::vector<std::string>& ref_files,
- const std::string& src_file = "");
+ void Init(const ScoreType type,
+ const std::vector<std::string>& ref_files,
+ const std::string& src_file = "");
+ DocScorer(const ScoreType type,
+ const std::vector<std::string>& ref_files,
+ const std::string& src_file = "")
+ {
+ Init(type,ref_files,src_file);
+ }
+
int size() const { return scorers_.size(); }
- const SentenceScorer* operator[](size_t i) const { return scorers_[i]; }
+ typedef boost::shared_ptr<SentenceScorer> ScorerP;
+ ScorerP operator[](size_t i) const { return scorers_[i]; }
private:
- std::vector<SentenceScorer*> scorers_;
+ std::vector<ScorerP> scorers_;
};
#endif