diff options
| author | Chris Dyer <cdyer@cs.cmu.edu> | 2011-09-13 13:25:46 +0100 | 
|---|---|---|
| committer | Chris Dyer <cdyer@cs.cmu.edu> | 2011-09-13 13:25:46 +0100 | 
| commit | e7993fb83537105a56c274d78ed9d51a79a8a854 (patch) | |
| tree | da88ab86e1173025c40113c303f08e02e6476c14 | |
| parent | c41704e876930311539f0cfb5f5125f3401d08ae (diff) | |
Add optional support for perfect hashing of feature strings (requires libcmph), which saves a lot of memory
| -rw-r--r-- | decoder/decoder.cc | 22 | ||||
| -rw-r--r-- | utils/Makefile.am | 9 | ||||
| -rw-r--r-- | utils/fdict.cc | 4 | ||||
| -rw-r--r-- | utils/fdict.h | 36 | ||||
| -rw-r--r-- | utils/perfect_hash.cc | 37 | ||||
| -rw-r--r-- | utils/perfect_hash.h | 24 | ||||
| -rw-r--r-- | utils/phmt.cc | 44 | ||||
| -rw-r--r-- | utils/weights.cc | 132 | ||||
| -rw-r--r-- | utils/weights.h | 14 | 
9 files changed, 269 insertions, 53 deletions
| diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 76f31352..25eb2de4 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -328,6 +328,7 @@ struct DecoderImpl {    bool write_gradient; // TODO Observer    bool feature_expectations; // TODO Observer    bool output_training_vector; // TODO Observer +  bool remove_intersected_rule_annotations;    static void ConvertSV(const SparseVector<prob_t>& src, SparseVector<double>* trg) {      for (SparseVector<prob_t>::const_iterator it = src.begin(); it != src.end(); ++it) @@ -361,6 +362,9 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream          ("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)")          ("per_sentence_grammar_file", po::value<string>(), "Optional (and possibly not implemented) per sentence grammar file enables all per sentence grammars to be stored in a single large file and accessed by offset")          ("list_feature_functions,L","List available feature functions") +#ifdef HAVE_CMPH +        ("cmph_perfect_feature_hash,h", po::value<string>(), "Load perfect hash function for features") +#endif          ("weights,w",po::value<string>(),"Feature weights file (initial forest / pass 1)")          ("feature_function,F",po::value<vector<string> >()->composing(), "Pass 1 additional feature function(s) (-L for list)") @@ -433,7 +437,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream          ("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)")          ("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")          ("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)") -        
("forest_output,O",po::value<string>(),"Directory to write forests to"); +        ("forest_output,O",po::value<string>(),"Directory to write forests to") +        ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)");    // ob.AddOptions(&opts);  #ifdef FSA_RESCORING @@ -443,7 +448,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream    po::options_description clo("Command line options");    clo.add_options()      ("config,c", po::value<vector<string> >(&cfg_files), "Configuration file(s) - latest has priority") -        ("help,h", "Print this help message and exit") +        ("help,?", "Print this help message and exit")      ("usage,u", po::value<string>(), "Describe a feature function type")      ("compgen", "Print just option names suitable for bash command line completion builtin 'compgen'")      ; @@ -645,6 +650,12 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream      FD::Freeze(); // this means we can't see the feature names of not-weighted features    } +  if (conf.count("cmph_perfect_feature_hash")) { +    cerr << "Loading perfect hash function from " << conf["cmph_perfect_feature_hash"].as<string>() << " ...\n"; +    FD::EnableHash(conf["cmph_perfect_feature_hash"].as<string>()); +    cerr << "  " << FD::NumFeats() << " features in map\n"; +  } +    // set up translation back end    if (formalism == "scfg")      translator.reset(new SCFGTranslator(conf)); @@ -695,6 +706,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream    unique_kbest = conf.count("unique_k_best");    get_oracle_forest = conf.count("get_oracle_forest");    oracle.show_derivation=conf.count("show_derivations"); +  remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations");  #ifdef FSA_RESCORING    cfg_options.Validate(); @@ -1010,6 +1022,12 @@ bool DecoderImpl::Decode(const 
string& input, DecoderObserver* o) {  //        if (!SILENT) cerr << "  USING UNIFORM WEIGHTS\n";  //        for (int i = 0; i < forest.edges_.size(); ++i)  //          forest.edges_[i].edge_prob_=prob_t::One(); } +      if (remove_intersected_rule_annotations) { +        for (unsigned i = 0; i < forest.edges_.size(); ++i) +          if (forest.edges_[i].rule_ && +              forest.edges_[i].rule_->parent_rule_) +            forest.edges_[i].rule_ = forest.edges_[i].rule_->parent_rule_; +      }        forest.Reweight(last_weights);        if (!SILENT) forest_stats(forest,"  Constr. forest",show_tree_structure,oracle.show_derivation);        if (!SILENT) cerr << "  Constr. VitTree: " << ViterbiFTree(forest) << endl; diff --git a/utils/Makefile.am b/utils/Makefile.am index 94f9be30..c50747bf 100644 --- a/utils/Makefile.am +++ b/utils/Makefile.am @@ -1,5 +1,5 @@ -noinst_PROGRAMS = ts -TESTS = ts +noinst_PROGRAMS = ts phmt +TESTS = ts phmt  if HAVE_GTEST  noinst_PROGRAMS += \ @@ -27,6 +27,11 @@ libutils_a_SOURCES = \    verbose.cc \    weights.cc +if HAVE_CMPH +  libutils_a_SOURCES += perfect_hash.cc +endif + +phmt_SOURCES = phmt.cc  ts_SOURCES = ts.cc  dict_test_SOURCES = dict_test.cc  dict_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) diff --git a/utils/fdict.cc b/utils/fdict.cc index baa0b552..676c951c 100644 --- a/utils/fdict.cc +++ b/utils/fdict.cc @@ -9,6 +9,10 @@ using namespace std;  Dict FD::dict_;  bool FD::frozen_ = false; +#ifdef HAVE_CMPH +PerfectHashFunction* FD::hash_ = NULL; +#endif +  std::string FD::Convert(std::vector<WordID> const& v) {      return Convert(&*v.begin(),&*v.end());  } diff --git a/utils/fdict.h b/utils/fdict.h index f9673023..771e8b91 100644 --- a/utils/fdict.h +++ b/utils/fdict.h @@ -1,23 +1,56 @@  #ifndef _FDICT_H_  #define _FDICT_H_ +#include "config.h" + +#include <iostream>  #include <string>  #include <vector>  #include "dict.h" +#ifdef HAVE_CMPH +#include "perfect_hash.h" +#include "string_to.h" +#endif +  struct FD {    // 
once the FD is frozen, new features not already in the    // dictionary will return 0    static void Freeze() {      frozen_ = true;    } +  static bool UsingPerfectHashFunction() { +#ifdef HAVE_CMPH +    return hash_; +#else +    return false; +#endif +  } +  static void EnableHash(const std::string& cmph_file) { +#ifdef HAVE_CMPH +    hash_ = new PerfectHashFunction(cmph_file); +#endif +  }    static inline int NumFeats() { +#ifdef HAVE_CMPH +    if (hash_) return hash_->number_of_keys(); +#endif      return dict_.max() + 1;    }    static inline WordID Convert(const std::string& s) { +#ifdef HAVE_CMPH +    if (hash_) return (*hash_)(s); +#endif      return dict_.Convert(s, frozen_);    }    static inline const std::string& Convert(const WordID& w) { +#ifdef HAVE_CMPH +    if (hash_) { +      static std::string tls; +      tls = to_string(w); +      return tls; +    } +#endif      return dict_.Convert(w);    }    static std::string Convert(WordID const *i,WordID const* e); @@ -29,6 +62,9 @@ struct FD {    static Dict dict_;   private:    static bool frozen_; +#ifdef HAVE_CMPH +  static PerfectHashFunction* hash_; +#endif  };  #endif diff --git a/utils/perfect_hash.cc b/utils/perfect_hash.cc new file mode 100644 index 00000000..706e2741 --- /dev/null +++ b/utils/perfect_hash.cc @@ -0,0 +1,37 @@ +#include "config.h" + +#ifdef HAVE_CMPH + +#include "perfect_hash.h" + +#include <cstdio> +#include <iostream> + +using namespace std; + +PerfectHashFunction::~PerfectHashFunction() { +  cmph_destroy(mphf_); +} + +PerfectHashFunction::PerfectHashFunction(const string& fname) { +  FILE* f = fopen(fname.c_str(), "r"); +  if (!f) { +    cerr << "Failed to open file " << fname << " for reading: cannot load hash function.\n"; +    abort(); +  } +  mphf_ = cmph_load(f); +  if (!mphf_) { +    cerr << "cmph_load failed on " << fname << "!\n"; +    abort(); +  } +} + +size_t PerfectHashFunction::operator()(const string& key) const { +  return cmph_search(mphf_, &key[0], 
key.size()); +} + +size_t PerfectHashFunction::number_of_keys() const { +  return cmph_size(mphf_); +} + +#endif diff --git a/utils/perfect_hash.h b/utils/perfect_hash.h new file mode 100644 index 00000000..8ac11f18 --- /dev/null +++ b/utils/perfect_hash.h @@ -0,0 +1,24 @@ +#ifndef _PERFECT_HASH_MAP_H_ +#define _PERFECT_HASH_MAP_H_ + +#include "config.h" + +#ifndef HAVE_CMPH +#error libcmph is required to use PerfectHashFunction +#endif + +#include <vector> +#include <boost/utility.hpp> +#include "cmph.h" + +class PerfectHashFunction : boost::noncopyable { + public: +  explicit PerfectHashFunction(const std::string& fname); +  ~PerfectHashFunction(); +  size_t operator()(const std::string& key) const; +  size_t number_of_keys() const; + private: +  cmph_t *mphf_; +}; + +#endif diff --git a/utils/phmt.cc b/utils/phmt.cc new file mode 100644 index 00000000..1f59afaf --- /dev/null +++ b/utils/phmt.cc @@ -0,0 +1,44 @@ +#include "config.h" + +#ifndef HAVE_CMPH +int main() { +  return 0; +} +#else + +#include <iostream> +#include "weights.h" +#include "fdict.h" + +using namespace std; + +int main(int argc, char** argv) { +  if (argc != 2) { cerr << "Usage: " << argv[0] << " file.mphf\n"; return 1; } +  FD::EnableHash(argv[1]); +  cerr << "Number of keys: " << FD::NumFeats() << endl; +  cerr << "LexFE = " << FD::Convert("LexFE") << endl; +  cerr << "LexEF = " << FD::Convert("LexEF") << endl; +  { +    Weights w; +    vector<weight_t> v(FD::NumFeats()); +    v[FD::Convert("LexFE")] = 1.0; +    v[FD::Convert("LexEF")] = 0.5; +    w.InitFromVector(v); +    cerr << "Writing...\n"; +    w.WriteToFile("weights.bin"); +    cerr << "Done.\n"; +  } +  { +    Weights w; +    vector<weight_t> v(FD::NumFeats()); +    cerr << "Reading...\n"; +    w.InitFromFile("weights.bin"); +    cerr << "Done.\n"; +    w.InitVector(&v); +    assert(v[FD::Convert("LexFE")] == 1.0); +    assert(v[FD::Convert("LexEF")] == 0.5); +  } +} + +#endif + diff --git a/utils/weights.cc b/utils/weights.cc index 
b994a2fe..0916b72a 100644 --- a/utils/weights.cc +++ b/utils/weights.cc @@ -13,40 +13,75 @@ void Weights::InitFromFile(const std::string& filename, vector<string>* feature_    ReadFile in_file(filename);    istream& in = *in_file.stream();    assert(in); -  int weight_count = 0; -  bool fl = false; -  string buf; -  double val = 0; -  while (in) { -    getline(in, buf); -    if (buf.size() == 0) continue; -    if (buf[0] == '#') continue; -    for (int i = 0; i < buf.size(); ++i) -      if (buf[i] == '=') buf[i] = ' '; -    int start = 0; -    while(start < buf.size() && buf[start] == ' ') ++start; -    int end = 0; -    while(end < buf.size() && buf[end] != ' ') ++end; -    const int fid = FD::Convert(buf.substr(start, end - start)); -    while(end < buf.size() && buf[end] == ' ') ++end; -    val = strtod(&buf.c_str()[end], NULL); -    if (isnan(val)) { -      cerr << FD::Convert(fid) << " has weight NaN!\n"; -     abort(); +   +  bool read_text = true; +  if (1) { +    ReadFile hdrrf(filename); +    istream& hi = *hdrrf.stream(); +    assert(hi); +    char buf[10]; +    hi.get(buf, 6); +    assert(hi.good()); +    if (strncmp(buf, "_PHWf", 5) == 0) { +      read_text = false; +    } +  } + +  if (read_text) { +    int weight_count = 0; +    bool fl = false; +    string buf; +    weight_t val = 0; +    while (in) { +      getline(in, buf); +      if (buf.size() == 0) continue; +      if (buf[0] == '#') continue; +      if (buf[0] == ' ') { +        cerr << "Weights file lines may not start with whitespace.\n" << buf << endl; +        abort(); +      } +      for (int i = buf.size() - 1; i > 0; --i) +        if (buf[i] == '=' || buf[i] == '\t') { buf[i] = ' '; break; } +      int start = 0; +      while(start < buf.size() && buf[start] == ' ') ++start; +      int end = 0; +      while(end < buf.size() && buf[end] != ' ') ++end; +      const int fid = FD::Convert(buf.substr(start, end - start)); +      while(end < buf.size() && buf[end] == ' ') ++end; +      val = 
strtod(&buf.c_str()[end], NULL); +      if (isnan(val)) { +        cerr << FD::Convert(fid) << " has weight NaN!\n"; +        abort(); +      } +      if (wv_.size() <= fid) +        wv_.resize(fid + 1); +      wv_[fid] = val; +      if (feature_list) { feature_list->push_back(FD::Convert(fid)); } +      ++weight_count; +      if (!SILENT) { +        if (weight_count %   50000 == 0) { cerr << '.' << flush; fl = true; } +        if (weight_count % 2000000 == 0) { cerr << " [" << weight_count << "]\n"; fl = false; } +      }      } -    if (wv_.size() <= fid) -      wv_.resize(fid + 1); -    wv_[fid] = val; -    if (feature_list) { feature_list->push_back(FD::Convert(fid)); } -    ++weight_count;      if (!SILENT) { -      if (weight_count %   50000 == 0) { cerr << '.' << flush; fl = true; } -      if (weight_count % 2000000 == 0) { cerr << " [" << weight_count << "]\n"; fl = false; } +      if (fl) { cerr << endl; } +      cerr << "Loaded " << weight_count << " feature weights\n"; +    } +  } else {   // !read_text +    char buf[6]; +    in.get(buf, 6); +    size_t num_keys[2]; +    in.get(reinterpret_cast<char*>(&num_keys[0]), sizeof(size_t) + 1); +    if (num_keys[0] != FD::NumFeats()) { +      cerr << "Hash function reports " << FD::NumFeats() << " keys but weights file contains " << num_keys[0] << endl; +      abort(); +    } +    wv_.resize(num_keys[0]); +    in.get(reinterpret_cast<char*>(&wv_[0]), num_keys[0] * sizeof(weight_t)); +    if (!in.good()) { +      cerr << "Error loading weights!\n"; +      abort();      } -  } -  if (!SILENT) { -    if (fl) { cerr << endl; } -    cerr << "Loaded " << weight_count << " feature weights\n";    }  } @@ -54,37 +89,48 @@ void Weights::WriteToFile(const std::string& fname, bool hide_zero_value_feature    WriteFile out(fname);    ostream& o = *out.stream();    assert(o); -  if (extra) { o << "# " << *extra << endl; } -  o.precision(17); -  const int num_feats = FD::NumFeats(); -  for (int i = 1; i < num_feats; ++i) { -    
const double val = (i < wv_.size() ? wv_[i] : 0.0); -    if (hide_zero_value_features && val == 0.0) continue; -    o << FD::Convert(i) << ' ' << val << endl; +  bool write_text = !FD::UsingPerfectHashFunction(); + +  if (write_text) { +    if (extra) { o << "# " << *extra << endl; } +    o.precision(17); +    const int num_feats = FD::NumFeats(); +    for (int i = 1; i < num_feats; ++i) { +      const weight_t val = (i < wv_.size() ? wv_[i] : 0.0); +      if (hide_zero_value_features && val == 0.0) continue; +      o << FD::Convert(i) << ' ' << val << endl; +    } +  } else { +    o.write("_PHWf", 5); +    const size_t keys = FD::NumFeats(); +    assert(keys <= wv_.size()); +    o.write(reinterpret_cast<const char*>(&keys), sizeof(keys)); +    o.write(reinterpret_cast<const char*>(&wv_[0]), keys * sizeof(weight_t));    }  } -void Weights::InitVector(std::vector<double>* w) const { +void Weights::InitVector(std::vector<weight_t>* w) const {    *w = wv_;  } -void Weights::InitSparseVector(SparseVector<double>* w) const { +void Weights::InitSparseVector(SparseVector<weight_t>* w) const {    for (int i = 1; i < wv_.size(); ++i) { -    const double& weight = wv_[i]; +    const weight_t& weight = wv_[i];      if (weight) w->set_value(i, weight);    }  } -void Weights::InitFromVector(const std::vector<double>& w) { +void Weights::InitFromVector(const std::vector<weight_t>& w) {    wv_ = w;    if (wv_.size() > FD::NumFeats())      cerr << "WARNING: initializing weight vector has more features than the global feature dictionary!\n";    wv_.resize(FD::NumFeats(), 0);  } -void Weights::InitFromVector(const SparseVector<double>& w) { +void Weights::InitFromVector(const SparseVector<weight_t>& w) {    wv_.clear();    wv_.resize(FD::NumFeats(), 0.0);    for (int i = 1; i < FD::NumFeats(); ++i)      wv_[i] = w.value(i);  } + diff --git a/utils/weights.h b/utils/weights.h index cc20283c..7664810b 100644 --- a/utils/weights.h +++ b/utils/weights.h @@ -2,21 +2,23 @@  #define 
_WEIGHTS_H_  #include <string> -#include <map>  #include <vector>  #include "sparse_vector.h" +// warning: in the future this will become float +typedef double weight_t; +  class Weights {   public:    Weights() {}    void InitFromFile(const std::string& fname, std::vector<std::string>* feature_list = NULL);    void WriteToFile(const std::string& fname, bool hide_zero_value_features = true, const std::string* extra = NULL) const; -  void InitVector(std::vector<double>* w) const; -  void InitSparseVector(SparseVector<double>* w) const; -  void InitFromVector(const std::vector<double>& w); -  void InitFromVector(const SparseVector<double>& w); +  void InitVector(std::vector<weight_t>* w) const; +  void InitSparseVector(SparseVector<weight_t>* w) const; +  void InitFromVector(const std::vector<weight_t>& w); +  void InitFromVector(const SparseVector<weight_t>& w);   private: -  std::vector<double> wv_; +  std::vector<weight_t> wv_;  };  #endif | 
