summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
Diffstat (limited to 'decoder')
-rw-r--r--decoder/filelib.h62
-rwxr-xr-xdecoder/oracle_bleu.h24
-rw-r--r--decoder/small_vector.h2
3 files changed, 62 insertions, 26 deletions
diff --git a/decoder/filelib.h b/decoder/filelib.h
index 03c22b0d..1630481d 100644
--- a/decoder/filelib.h
+++ b/decoder/filelib.h
@@ -5,6 +5,7 @@
#include <string>
#include <iostream>
#include <cstdlib>
+#include <boost/shared_ptr.hpp>
#include "gzstream.h"
bool FileExists(const std::string& file_name);
@@ -13,35 +14,57 @@ bool DirectoryExists(const std::string& dir_name);
// reads from standard in if filename is -
// uncompresses if file ends with .gz
// otherwise, reads from a normal file
+struct file_null_deleter {
+ void operator()(void*) const {}
+};
+
class ReadFile {
public:
- ReadFile(const std::string& filename) :
- no_delete_on_exit_(filename == "-"),
- in_(no_delete_on_exit_ ? static_cast<std::istream*>(&std::cin) :
- (EndsWith(filename, ".gz") ?
- static_cast<std::istream*>(new igzstream(filename.c_str())) :
- static_cast<std::istream*>(new std::ifstream(filename.c_str())))) {
- if (!no_delete_on_exit_ && !FileExists(filename)) {
- std::cerr << "File does not exist: " << filename << std::endl;
- abort();
- }
- if (!*in_) {
- std::cerr << "Failed to open " << filename << std::endl;
- abort();
+ typedef boost::shared_ptr<std::istream> PS;
+ ReadFile() { }
+ std::string filename_;
+ void Init(const std::string& filename) {
+ bool stdin=(filename == "-");
+ if (stdin) {
+ in_=PS(&std::cin,file_null_deleter());
+ } else {
+ if (!FileExists(filename)) {
+ std::cerr << "File does not exist: " << filename << std::endl;
+ abort();
+ }
+ filename_=filename;
+ in_.reset(EndsWith(filename, ".gz") ?
+ static_cast<std::istream*>(new igzstream(filename.c_str())) :
+ static_cast<std::istream*>(new std::ifstream(filename.c_str())));
+ if (!*in_) {
+ std::cerr << "Failed to open " << filename << std::endl;
+ abort();
+ }
}
}
+ void Reset() {
+ in_.reset();
+ }
+ bool is_null() const { return !in_; }
+ operator bool() const {
+ return in_;
+ }
+
+ explicit ReadFile(const std::string& filename) {
+ Init(filename);
+ }
~ReadFile() {
- if (!no_delete_on_exit_) delete in_;
}
- inline std::istream* stream() { return in_; }
-
+ std::istream* stream() { return in_.get(); }
+ std::istream* operator->() { return in_.get(); } // compat with old ReadFile * -> new Readfile. remove?
+ std::istream &get() const { return *in_; }
+
private:
static bool EndsWith(const std::string& f, const std::string& suf) {
return (f.size() > suf.size()) && (f.rfind(suf) == f.size() - suf.size());
}
- const bool no_delete_on_exit_;
- std::istream* const in_;
+ PS in_;
};
class WriteFile {
@@ -58,7 +81,8 @@ class WriteFile {
}
inline std::ostream* stream() { return out_; }
-
+ std::ostream &get() const { return *out_; }
+
private:
static bool EndsWith(const std::string& f, const std::string& suf) {
return (f.size() > suf.size()) && (f.rfind(suf) == f.size() - suf.size());
diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h
index 66d155d3..4800e9c1 100755
--- a/decoder/oracle_bleu.h
+++ b/decoder/oracle_bleu.h
@@ -1,6 +1,8 @@
#ifndef ORACLE_BLEU_H
#define ORACLE_BLEU_H
+#define DEBUG_ORACLE_BLEU
+
#include <sstream>
#include <iostream>
#include <string>
@@ -20,7 +22,7 @@
#include "sentences.h"
//TODO: put function impls into .cc
-//TODO: disentangle
+//TODO: move Translation into its own .h and use in cdec
struct Translation {
typedef std::vector<WordID> Sentence;
Sentence sentence;
@@ -153,11 +155,16 @@ struct OracleBleu {
init_refs();
}
void init_refs() {
- if (is_null()) return;
+ if (is_null()) {
+#ifdef DEBUG_ORACLE_BLEU
+ std::cerr<<"No references for oracle BLEU.\n";
+#endif
+ return;
+ }
assert(refs.size());
- ds=DocScorer(loss,refs);
- doc_score.reset();
-// doc_score=sentscore
+ ds.Init(loss,refs);
+ ensure_doc_score();
+// doc_score.reset();
std::cerr << "Loaded " << ds.size() << " references for scoring with " << StringFromScoreType(loss) << std::endl;
}
@@ -193,10 +200,15 @@ struct OracleBleu {
return r;
}
+ // if doc_score wasn't init, add 1 counts to ngram acc.
+ void ensure_doc_score() {
+ if (!doc_score) { doc_score.reset(Score::GetOne(loss)); }
+ }
+
void Rescore(SentenceMetadata const& smeta,Hypergraph const& forest,Hypergraph *dest_forest,WeightVector const& feature_weights,double bleu_weight=1.0,std::ostream *log=&std::cerr) {
// the sentence bleu stats will get added to doc only if you call IncludeLastScore
+ ensure_doc_score();
sentscore=GetScore(forest,smeta.GetSentenceID());
- if (!doc_score) { doc_score.reset(sentscore->GetOne()); }
tmp_src_length = smeta.GetSourceLength(); //TODO: where does this come from?
using namespace std;
DenseWeightVector w;
diff --git a/decoder/small_vector.h b/decoder/small_vector.h
index c3090a8b..202b72c9 100644
--- a/decoder/small_vector.h
+++ b/decoder/small_vector.h
@@ -11,7 +11,7 @@
#include <streambuf> // std::max - where to get this?
#include <cstring>
#include <cassert>
-#include <limits.h>
+#include <stdint.h>
#include <new>
#include <stdint.h>
//sizeof(T)/sizeof(T*)>1?sizeof(T)/sizeof(T*):1