From 38a5bee71f6b49515cd105a9467ff602ff9dee64 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 13 Sep 2011 13:25:46 +0100 Subject: optional support for doing perfect hashing of feature strings to save lots of memory --- utils/Makefile.am | 9 +++- utils/fdict.cc | 4 ++ utils/fdict.h | 36 ++++++++++++++ utils/perfect_hash.cc | 37 ++++++++++++++ utils/perfect_hash.h | 24 +++++++++ utils/phmt.cc | 44 +++++++++++++++++ utils/weights.cc | 132 ++++++++++++++++++++++++++++++++++---------------- utils/weights.h | 14 +++--- 8 files changed, 249 insertions(+), 51 deletions(-) create mode 100644 utils/perfect_hash.cc create mode 100644 utils/perfect_hash.h create mode 100644 utils/phmt.cc (limited to 'utils') diff --git a/utils/Makefile.am b/utils/Makefile.am index 94f9be30..c50747bf 100644 --- a/utils/Makefile.am +++ b/utils/Makefile.am @@ -1,5 +1,5 @@ -noinst_PROGRAMS = ts -TESTS = ts +noinst_PROGRAMS = ts phmt +TESTS = ts phmt if HAVE_GTEST noinst_PROGRAMS += \ @@ -27,6 +27,11 @@ libutils_a_SOURCES = \ verbose.cc \ weights.cc +if HAVE_CMPH + libutils_a_SOURCES += perfect_hash.cc +endif + +phmt_SOURCES = phmt.cc ts_SOURCES = ts.cc dict_test_SOURCES = dict_test.cc dict_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) diff --git a/utils/fdict.cc b/utils/fdict.cc index baa0b552..676c951c 100644 --- a/utils/fdict.cc +++ b/utils/fdict.cc @@ -9,6 +9,10 @@ using namespace std; Dict FD::dict_; bool FD::frozen_ = false; +#ifdef HAVE_CMPH +PerfectHashFunction* FD::hash_ = NULL; +#endif + std::string FD::Convert(std::vector const& v) { return Convert(&*v.begin(),&*v.end()); } diff --git a/utils/fdict.h b/utils/fdict.h index f9673023..771e8b91 100644 --- a/utils/fdict.h +++ b/utils/fdict.h @@ -1,23 +1,56 @@ #ifndef _FDICT_H_ #define _FDICT_H_ +#include "config.h" + +#include #include #include #include "dict.h" +#ifdef HAVE_CMPH +#include "perfect_hash.h" +#include "string_to.h" +#endif + struct FD { // once the FD is frozen, new features not already in the // dictionary will return 0 static void Freeze() { frozen_ = true; } + static bool UsingPerfectHashFunction() { +#ifdef HAVE_CMPH + return hash_; +#else + return false; +#endif + } + static void EnableHash(const std::string& cmph_file) { +#ifdef HAVE_CMPH + hash_ = new PerfectHashFunction(cmph_file); +#endif + } static inline int NumFeats() { +#ifdef HAVE_CMPH + if (hash_) return hash_->number_of_keys(); +#endif return dict_.max() + 1; } static inline WordID Convert(const std::string& s) { +#ifdef HAVE_CMPH + if (hash_) return (*hash_)(s); +#endif return dict_.Convert(s, frozen_); } static inline const std::string& Convert(const WordID& w) { +#ifdef HAVE_CMPH + if (hash_) { + static std::string tls; + tls = to_string(w); + return tls; + } +#endif return dict_.Convert(w); } static std::string Convert(WordID const *i,WordID const* e); @@ -29,6 +62,9 @@ struct FD { static Dict dict_; private: static bool frozen_; +#ifdef HAVE_CMPH + static PerfectHashFunction* hash_; +#endif }; #endif diff --git a/utils/perfect_hash.cc b/utils/perfect_hash.cc new file mode 100644 index 00000000..706e2741 --- /dev/null +++ b/utils/perfect_hash.cc @@ -0,0 +1,37 @@ +#include "config.h" + +#ifdef HAVE_CMPH + +#include "perfect_hash.h" + +#include +#include + +using namespace std; + +PerfectHashFunction::~PerfectHashFunction() { + cmph_destroy(mphf_); +} + +PerfectHashFunction::PerfectHashFunction(const string& fname) { + FILE* f = fopen(fname.c_str(), "r"); + if (!f) { + cerr << "Failed to open file " << fname << " for reading: cannot load hash function.\n"; + abort(); + } + mphf_ = cmph_load(f); + if (!mphf_) { + cerr << "cmph_load failed on " << fname << "!\n"; + abort(); + } +} + +size_t PerfectHashFunction::operator()(const string& key) const { + return cmph_search(mphf_, &key[0], key.size()); +} + +size_t PerfectHashFunction::number_of_keys() const { + return cmph_size(mphf_); +} + +#endif diff --git a/utils/perfect_hash.h b/utils/perfect_hash.h new file mode 100644 index 00000000..8ac11f18 --- /dev/null +++ b/utils/perfect_hash.h @@ -0,0 +1,24 @@ +#ifndef _PERFECT_HASH_MAP_H_ +#define _PERFECT_HASH_MAP_H_ + +#include "config.h" + +#ifndef HAVE_CMPH +#error libcmph is required to use PerfectHashFunction +#endif + +#include +#include +#include "cmph.h" + +class PerfectHashFunction : boost::noncopyable { + public: + explicit PerfectHashFunction(const std::string& fname); + ~PerfectHashFunction(); + size_t operator()(const std::string& key) const; + size_t number_of_keys() const; + private: + cmph_t *mphf_; +}; + +#endif diff --git a/utils/phmt.cc b/utils/phmt.cc new file mode 100644 index 00000000..1f59afaf --- /dev/null +++ b/utils/phmt.cc @@ -0,0 +1,44 @@ +#include "config.h" + +#ifndef HAVE_CMPH +int main() { + return 0; +} +#else + +#include +#include "weights.h" +#include "fdict.h" + +using namespace std; + +int main(int argc, char** argv) { + if (argc != 2) { cerr << "Usage: " << argv[0] << " file.mphf\n"; return 1; } + FD::EnableHash(argv[1]); + cerr << "Number of keys: " << FD::NumFeats() << endl; + cerr << "LexFE = " << FD::Convert("LexFE") << endl; + cerr << "LexEF = " << FD::Convert("LexEF") << endl; + { + Weights w; + vector v(FD::NumFeats()); + v[FD::Convert("LexFE")] = 1.0; + v[FD::Convert("LexEF")] = 0.5; + w.InitFromVector(v); + cerr << "Writing...\n"; + w.WriteToFile("weights.bin"); + cerr << "Done.\n"; + } + { + Weights w; + vector v(FD::NumFeats()); + cerr << "Reading...\n"; + w.InitFromFile("weights.bin"); + cerr << "Done.\n"; + w.InitVector(&v); + assert(v[FD::Convert("LexFE")] == 1.0); + assert(v[FD::Convert("LexEF")] == 0.5); + } +} + +#endif + diff --git a/utils/weights.cc b/utils/weights.cc index b994a2fe..0916b72a 100644 --- a/utils/weights.cc +++ b/utils/weights.cc @@ -13,40 +13,75 @@ void Weights::InitFromFile(const std::string& filename, vector* feature_ ReadFile in_file(filename); istream& in = *in_file.stream(); assert(in); - int weight_count = 0; - bool fl = false; - string buf; - double val = 0; - while (in) { - getline(in, buf); - if (buf.size() == 0) continue; - if (buf[0] == '#') continue; - for (int i = 0; i < buf.size(); ++i) - if (buf[i] == '=') buf[i] = ' '; - int start = 0; - while(start < buf.size() && buf[start] == ' ') ++start; - int end = 0; - while(end < buf.size() && buf[end] != ' ') ++end; - const int fid = FD::Convert(buf.substr(start, end - start)); - while(end < buf.size() && buf[end] == ' ') ++end; - val = strtod(&buf.c_str()[end], NULL); - if (isnan(val)) { - cerr << FD::Convert(fid) << " has weight NaN!\n"; - abort(); + + bool read_text = true; + if (1) { + ReadFile hdrrf(filename); + istream& hi = *hdrrf.stream(); + assert(hi); + char buf[10]; + hi.get(buf, 6); + assert(hi.good()); + if (strncmp(buf, "_PHWf", 5) == 0) { + read_text = false; + } + } + + if (read_text) { + int weight_count = 0; + bool fl = false; + string buf; + weight_t val = 0; + while (in) { + getline(in, buf); + if (buf.size() == 0) continue; + if (buf[0] == '#') continue; + if (buf[0] == ' ') { + cerr << "Weights file lines may not start with whitespace.\n" << buf << endl; + abort(); + } + for (int i = buf.size() - 1; i > 0; --i) + if (buf[i] == '=' || buf[i] == '\t') { buf[i] = ' '; break; } + int start = 0; + while(start < buf.size() && buf[start] == ' ') ++start; + int end = 0; + while(end < buf.size() && buf[end] != ' ') ++end; + const int fid = FD::Convert(buf.substr(start, end - start)); + while(end < buf.size() && buf[end] == ' ') ++end; + val = strtod(&buf.c_str()[end], NULL); + if (isnan(val)) { + cerr << FD::Convert(fid) << " has weight NaN!\n"; + abort(); + } + if (wv_.size() <= fid) + wv_.resize(fid + 1); + wv_[fid] = val; + if (feature_list) { feature_list->push_back(FD::Convert(fid)); } + ++weight_count; + if (!SILENT) { + if (weight_count % 50000 == 0) { cerr << '.' << flush; fl = true; } + if (weight_count % 2000000 == 0) { cerr << " [" << weight_count << "]\n"; fl = false; } + } } - if (wv_.size() <= fid) - wv_.resize(fid + 1); - wv_[fid] = val; - if (feature_list) { feature_list->push_back(FD::Convert(fid)); } - ++weight_count; if (!SILENT) { - if (weight_count % 50000 == 0) { cerr << '.' << flush; fl = true; } - if (weight_count % 2000000 == 0) { cerr << " [" << weight_count << "]\n"; fl = false; } + if (fl) { cerr << endl; } + cerr << "Loaded " << weight_count << " feature weights\n"; + } + } else { // !read_text + char buf[6]; + in.get(buf, 6); + size_t num_keys[2]; + in.get(reinterpret_cast(&num_keys[0]), sizeof(size_t) + 1); + if (num_keys[0] != FD::NumFeats()) { + cerr << "Hash function reports " << FD::NumFeats() << " keys but weights file contains " << num_keys[0] << endl; + abort(); + } + wv_.resize(num_keys[0]); + in.get(reinterpret_cast(&wv_[0]), num_keys[0] * sizeof(weight_t)); + if (!in.good()) { + cerr << "Error loading weights!\n"; + abort(); } - } - if (!SILENT) { - if (fl) { cerr << endl; } - cerr << "Loaded " << weight_count << " feature weights\n"; } } @@ -54,37 +89,48 @@ void Weights::WriteToFile(const std::string& fname, bool hide_zero_value_feature WriteFile out(fname); ostream& o = *out.stream(); assert(o); - if (extra) { o << "# " << *extra << endl; } - o.precision(17); - const int num_feats = FD::NumFeats(); - for (int i = 1; i < num_feats; ++i) { - const double val = (i < wv_.size() ? wv_[i] : 0.0); - if (hide_zero_value_features && val == 0.0) continue; - o << FD::Convert(i) << ' ' << val << endl; + bool write_text = !FD::UsingPerfectHashFunction(); + + if (write_text) { + if (extra) { o << "# " << *extra << endl; } + o.precision(17); + const int num_feats = FD::NumFeats(); + for (int i = 1; i < num_feats; ++i) { + const weight_t val = (i < wv_.size() ? wv_[i] : 0.0); + if (hide_zero_value_features && val == 0.0) continue; + o << FD::Convert(i) << ' ' << val << endl; + } + } else { + o.write("_PHWf", 5); + const size_t keys = FD::NumFeats(); + assert(keys <= wv_.size()); + o.write(reinterpret_cast(&keys), sizeof(keys)); + o.write(reinterpret_cast(&wv_[0]), keys * sizeof(weight_t)); } } -void Weights::InitVector(std::vector* w) const { +void Weights::InitVector(std::vector* w) const { *w = wv_; } -void Weights::InitSparseVector(SparseVector* w) const { +void Weights::InitSparseVector(SparseVector* w) const { for (int i = 1; i < wv_.size(); ++i) { - const double& weight = wv_[i]; + const weight_t& weight = wv_[i]; if (weight) w->set_value(i, weight); } } -void Weights::InitFromVector(const std::vector& w) { +void Weights::InitFromVector(const std::vector& w) { wv_ = w; if (wv_.size() > FD::NumFeats()) cerr << "WARNING: initializing weight vector has more features than the global feature dictionary!\n"; wv_.resize(FD::NumFeats(), 0); } -void Weights::InitFromVector(const SparseVector& w) { +void Weights::InitFromVector(const SparseVector& w) { wv_.clear(); wv_.resize(FD::NumFeats(), 0.0); for (int i = 1; i < FD::NumFeats(); ++i) wv_[i] = w.value(i); } + diff --git a/utils/weights.h b/utils/weights.h index cc20283c..7664810b 100644 --- a/utils/weights.h +++ b/utils/weights.h @@ -2,21 +2,23 @@ #define _WEIGHTS_H_ #include -#include #include #include "sparse_vector.h" +// warning: in the future this will become float +typedef double weight_t; + class Weights { public: Weights() {} void InitFromFile(const std::string& fname, std::vector* feature_list = NULL); void WriteToFile(const std::string& fname, bool hide_zero_value_features = true, const std::string* extra = NULL) const; - void InitVector(std::vector* w) const; - void InitSparseVector(SparseVector* w) const; - void InitFromVector(const std::vector& w); - void InitFromVector(const SparseVector& w); + void InitVector(std::vector* w) const; + void InitSparseVector(SparseVector* w) const; + void InitFromVector(const std::vector& w); + void InitFromVector(const SparseVector& w); private: - std::vector wv_; + std::vector wv_; }; #endif -- cgit v1.2.3