diff options
Diffstat (limited to 'utils')
| -rw-r--r-- | utils/Makefile.am | 9 | ||||
| -rw-r--r-- | utils/fdict.cc | 4 | ||||
| -rw-r--r-- | utils/fdict.h | 36 | ||||
| -rw-r--r-- | utils/perfect_hash.cc | 37 | ||||
| -rw-r--r-- | utils/perfect_hash.h | 24 | ||||
| -rw-r--r-- | utils/phmt.cc | 44 | ||||
| -rw-r--r-- | utils/weights.cc | 132 | ||||
| -rw-r--r-- | utils/weights.h | 14 | 
8 files changed, 249 insertions, 51 deletions
| diff --git a/utils/Makefile.am b/utils/Makefile.am index 94f9be30..c50747bf 100644 --- a/utils/Makefile.am +++ b/utils/Makefile.am @@ -1,5 +1,5 @@ -noinst_PROGRAMS = ts -TESTS = ts +noinst_PROGRAMS = ts phmt +TESTS = ts phmt  if HAVE_GTEST  noinst_PROGRAMS += \ @@ -27,6 +27,11 @@ libutils_a_SOURCES = \    verbose.cc \    weights.cc +if HAVE_CMPH +  libutils_a_SOURCES += perfect_hash.cc +endif + +phmt_SOURCES = phmt.cc  ts_SOURCES = ts.cc  dict_test_SOURCES = dict_test.cc  dict_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) diff --git a/utils/fdict.cc b/utils/fdict.cc index baa0b552..676c951c 100644 --- a/utils/fdict.cc +++ b/utils/fdict.cc @@ -9,6 +9,10 @@ using namespace std;  Dict FD::dict_;  bool FD::frozen_ = false; +#ifdef HAVE_CMPH +PerfectHashFunction* FD::hash_ = NULL; +#endif +  std::string FD::Convert(std::vector<WordID> const& v) {      return Convert(&*v.begin(),&*v.end());  } diff --git a/utils/fdict.h b/utils/fdict.h index f9673023..771e8b91 100644 --- a/utils/fdict.h +++ b/utils/fdict.h @@ -1,23 +1,56 @@  #ifndef _FDICT_H_  #define _FDICT_H_ +#include "config.h" + +#include <iostream>  #include <string>  #include <vector>  #include "dict.h" +#ifdef HAVE_CMPH +#include "perfect_hash.h" +#include "string_to.h" +#endif +  struct FD {    // once the FD is frozen, new features not already in the    // dictionary will return 0    static void Freeze() {      frozen_ = true;    } +  static bool UsingPerfectHashFunction() { +#ifdef HAVE_CMPH +    return hash_; +#else +    return false; +#endif +  } +  static void EnableHash(const std::string& cmph_file) { +#ifdef HAVE_CMPH +    hash_ = new PerfectHashFunction(cmph_file); +#endif +  }    static inline int NumFeats() { +#ifdef HAVE_CMPH +    if (hash_) return hash_->number_of_keys(); +#endif      return dict_.max() + 1;    }    static inline WordID Convert(const std::string& s) { +#ifdef HAVE_CMPH +    if (hash_) return (*hash_)(s); +#endif      return dict_.Convert(s, frozen_);    }    static inline const std::string& Convert(const WordID& w) { +#ifdef HAVE_CMPH +    if (hash_) { +      static std::string tls; +      tls = to_string(w); +      return tls; +    } +#endif      return dict_.Convert(w);    }    static std::string Convert(WordID const *i,WordID const* e); @@ -29,6 +62,9 @@ struct FD {    static Dict dict_;   private:    static bool frozen_; +#ifdef HAVE_CMPH +  static PerfectHashFunction* hash_; +#endif  };  #endif diff --git a/utils/perfect_hash.cc b/utils/perfect_hash.cc new file mode 100644 index 00000000..706e2741 --- /dev/null +++ b/utils/perfect_hash.cc @@ -0,0 +1,37 @@ +#include "config.h" + +#ifdef HAVE_CMPH + +#include "perfect_hash.h" + +#include <cstdio> +#include <iostream> + +using namespace std; + +PerfectHashFunction::~PerfectHashFunction() { +  cmph_destroy(mphf_); +} + +PerfectHashFunction::PerfectHashFunction(const string& fname) { +  FILE* f = fopen(fname.c_str(), "r"); +  if (!f) { +    cerr << "Failed to open file " << fname << " for reading: cannot load hash function.\n"; +    abort(); +  } +  mphf_ = cmph_load(f); +  if (!mphf_) { +    cerr << "cmph_load failed on " << fname << "!\n"; +    abort(); +  } +} + +size_t PerfectHashFunction::operator()(const string& key) const { +  return cmph_search(mphf_, &key[0], key.size()); +} + +size_t PerfectHashFunction::number_of_keys() const { +  return cmph_size(mphf_); +} + +#endif diff --git a/utils/perfect_hash.h b/utils/perfect_hash.h new file mode 100644 index 00000000..8ac11f18 --- /dev/null +++ b/utils/perfect_hash.h @@ -0,0 +1,24 @@ +#ifndef _PERFECT_HASH_MAP_H_ +#define _PERFECT_HASH_MAP_H_ + +#include "config.h" + +#ifndef HAVE_CMPH +#error libcmph is required to use PerfectHashFunction +#endif + +#include <vector> +#include <boost/utility.hpp> +#include "cmph.h" + +class PerfectHashFunction : boost::noncopyable { + public: +  explicit PerfectHashFunction(const std::string& fname); +  ~PerfectHashFunction(); +  size_t operator()(const std::string& key) const; +  size_t number_of_keys() const; + private: +  cmph_t *mphf_; +}; + +#endif diff --git a/utils/phmt.cc b/utils/phmt.cc new file mode 100644 index 00000000..1f59afaf --- /dev/null +++ b/utils/phmt.cc @@ -0,0 +1,44 @@ +#include "config.h" + +#ifndef HAVE_CMPH +int main() { +  return 0; +} +#else + +#include <iostream> +#include "weights.h" +#include "fdict.h" + +using namespace std; + +int main(int argc, char** argv) { +  if (argc != 2) { cerr << "Usage: " << argv[0] << " file.mphf\n"; return 1; } +  FD::EnableHash(argv[1]); +  cerr << "Number of keys: " << FD::NumFeats() << endl; +  cerr << "LexFE = " << FD::Convert("LexFE") << endl; +  cerr << "LexEF = " << FD::Convert("LexEF") << endl; +  { +    Weights w; +    vector<weight_t> v(FD::NumFeats()); +    v[FD::Convert("LexFE")] = 1.0; +    v[FD::Convert("LexEF")] = 0.5; +    w.InitFromVector(v); +    cerr << "Writing...\n"; +    w.WriteToFile("weights.bin"); +    cerr << "Done.\n"; +  } +  { +    Weights w; +    vector<weight_t> v(FD::NumFeats()); +    cerr << "Reading...\n"; +    w.InitFromFile("weights.bin"); +    cerr << "Done.\n"; +    w.InitVector(&v); +    assert(v[FD::Convert("LexFE")] == 1.0); +    assert(v[FD::Convert("LexEF")] == 0.5); +  } +} + +#endif + diff --git a/utils/weights.cc b/utils/weights.cc index b994a2fe..0916b72a 100644 --- a/utils/weights.cc +++ b/utils/weights.cc @@ -13,40 +13,75 @@ void Weights::InitFromFile(const std::string& filename, vector<string>* feature_    ReadFile in_file(filename);    istream& in = *in_file.stream();    assert(in); -  int weight_count = 0; -  bool fl = false; -  string buf; -  double val = 0; -  while (in) { -    getline(in, buf); -    if (buf.size() == 0) continue; -    if (buf[0] == '#') continue; -    for (int i = 0; i < buf.size(); ++i) -      if (buf[i] == '=') buf[i] = ' '; -    int start = 0; -    while(start < buf.size() && buf[start] == ' ') ++start; -    int end = 0; -    while(end < buf.size() && buf[end] != ' ') ++end; -    const int fid = FD::Convert(buf.substr(start, end - start)); -    while(end < buf.size() && buf[end] == ' ') ++end; -    val = strtod(&buf.c_str()[end], NULL); -    if (isnan(val)) { -      cerr << FD::Convert(fid) << " has weight NaN!\n"; -     abort(); +   +  bool read_text = true; +  if (1) { +    ReadFile hdrrf(filename); +    istream& hi = *hdrrf.stream(); +    assert(hi); +    char buf[10]; +    hi.get(buf, 6); +    assert(hi.good()); +    if (strncmp(buf, "_PHWf", 5) == 0) { +      read_text = false; +    } +  } + +  if (read_text) { +    int weight_count = 0; +    bool fl = false; +    string buf; +    weight_t val = 0; +    while (in) { +      getline(in, buf); +      if (buf.size() == 0) continue; +      if (buf[0] == '#') continue; +      if (buf[0] == ' ') { +        cerr << "Weights file lines may not start with whitespace.\n" << buf << endl; +        abort(); +      } +      for (int i = buf.size() - 1; i > 0; --i) +        if (buf[i] == '=' || buf[i] == '\t') { buf[i] = ' '; break; } +      int start = 0; +      while(start < buf.size() && buf[start] == ' ') ++start; +      int end = 0; +      while(end < buf.size() && buf[end] != ' ') ++end; +      const int fid = FD::Convert(buf.substr(start, end - start)); +      while(end < buf.size() && buf[end] == ' ') ++end; +      val = strtod(&buf.c_str()[end], NULL); +      if (isnan(val)) { +        cerr << FD::Convert(fid) << " has weight NaN!\n"; +        abort(); +      } +      if (wv_.size() <= fid) +        wv_.resize(fid + 1); +      wv_[fid] = val; +      if (feature_list) { feature_list->push_back(FD::Convert(fid)); } +      ++weight_count; +      if (!SILENT) { +        if (weight_count %   50000 == 0) { cerr << '.' << flush; fl = true; } +        if (weight_count % 2000000 == 0) { cerr << " [" << weight_count << "]\n"; fl = false; } +      }      } -    if (wv_.size() <= fid) -      wv_.resize(fid + 1); -    wv_[fid] = val; -    if (feature_list) { feature_list->push_back(FD::Convert(fid)); } -    ++weight_count;      if (!SILENT) { -      if (weight_count %   50000 == 0) { cerr << '.' << flush; fl = true; } -      if (weight_count % 2000000 == 0) { cerr << " [" << weight_count << "]\n"; fl = false; } +      if (fl) { cerr << endl; } +      cerr << "Loaded " << weight_count << " feature weights\n"; +    } +  } else {   // !read_text +    char buf[6]; +    in.get(buf, 6); +    size_t num_keys[2]; +    in.get(reinterpret_cast<char*>(&num_keys[0]), sizeof(size_t) + 1); +    if (num_keys[0] != FD::NumFeats()) { +      cerr << "Hash function reports " << FD::NumFeats() << " keys but weights file contains " << num_keys[0] << endl; +      abort(); +    } +    wv_.resize(num_keys[0]); +    in.get(reinterpret_cast<char*>(&wv_[0]), num_keys[0] * sizeof(weight_t)); +    if (!in.good()) { +      cerr << "Error loading weights!\n"; +      abort();      } -  } -  if (!SILENT) { -    if (fl) { cerr << endl; } -    cerr << "Loaded " << weight_count << " feature weights\n";    }  } @@ -54,37 +89,48 @@ void Weights::WriteToFile(const std::string& fname, bool hide_zero_value_feature    WriteFile out(fname);    ostream& o = *out.stream();    assert(o); -  if (extra) { o << "# " << *extra << endl; } -  o.precision(17); -  const int num_feats = FD::NumFeats(); -  for (int i = 1; i < num_feats; ++i) { -    const double val = (i < wv_.size() ? wv_[i] : 0.0); -    if (hide_zero_value_features && val == 0.0) continue; -    o << FD::Convert(i) << ' ' << val << endl; +  bool write_text = !FD::UsingPerfectHashFunction(); + +  if (write_text) { +    if (extra) { o << "# " << *extra << endl; } +    o.precision(17); +    const int num_feats = FD::NumFeats(); +    for (int i = 1; i < num_feats; ++i) { +      const weight_t val = (i < wv_.size() ? wv_[i] : 0.0); +      if (hide_zero_value_features && val == 0.0) continue; +      o << FD::Convert(i) << ' ' << val << endl; +    } +  } else { +    o.write("_PHWf", 5); +    const size_t keys = FD::NumFeats(); +    assert(keys <= wv_.size()); +    o.write(reinterpret_cast<const char*>(&keys), sizeof(keys)); +    o.write(reinterpret_cast<const char*>(&wv_[0]), keys * sizeof(weight_t));    }  } -void Weights::InitVector(std::vector<double>* w) const { +void Weights::InitVector(std::vector<weight_t>* w) const {    *w = wv_;  } -void Weights::InitSparseVector(SparseVector<double>* w) const { +void Weights::InitSparseVector(SparseVector<weight_t>* w) const {    for (int i = 1; i < wv_.size(); ++i) { -    const double& weight = wv_[i]; +    const weight_t& weight = wv_[i];      if (weight) w->set_value(i, weight);    }  } -void Weights::InitFromVector(const std::vector<double>& w) { +void Weights::InitFromVector(const std::vector<weight_t>& w) {    wv_ = w;    if (wv_.size() > FD::NumFeats())      cerr << "WARNING: initializing weight vector has more features than the global feature dictionary!\n";    wv_.resize(FD::NumFeats(), 0);  } -void Weights::InitFromVector(const SparseVector<double>& w) { +void Weights::InitFromVector(const SparseVector<weight_t>& w) {    wv_.clear();    wv_.resize(FD::NumFeats(), 0.0);    for (int i = 1; i < FD::NumFeats(); ++i)      wv_[i] = w.value(i);  } + diff --git a/utils/weights.h b/utils/weights.h index cc20283c..7664810b 100644 --- a/utils/weights.h +++ b/utils/weights.h @@ -2,21 +2,23 @@  #define _WEIGHTS_H_  #include <string> -#include <map>  #include <vector>  #include "sparse_vector.h" +// warning: in the future this will become float +typedef double weight_t; +  class Weights {   public:    Weights() {}    void InitFromFile(const std::string& fname, std::vector<std::string>* feature_list = NULL);    void WriteToFile(const std::string& fname, bool hide_zero_value_features = true, const std::string* extra = NULL) const; -  void InitVector(std::vector<double>* w) const; -  void InitSparseVector(SparseVector<double>* w) const; -  void InitFromVector(const std::vector<double>& w); -  void InitFromVector(const SparseVector<double>& w); +  void InitVector(std::vector<weight_t>* w) const; +  void InitSparseVector(SparseVector<weight_t>* w) const; +  void InitFromVector(const std::vector<weight_t>& w); +  void InitFromVector(const SparseVector<weight_t>& w);   private: -  std::vector<double> wv_; +  std::vector<weight_t> wv_;  };  #endif | 
