summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-09-13 13:25:46 +0100
committerChris Dyer <cdyer@cs.cmu.edu>2011-09-13 13:25:46 +0100
commite7993fb83537105a56c274d78ed9d51a79a8a854 (patch)
treeda88ab86e1173025c40113c303f08e02e6476c14
parentc41704e876930311539f0cfb5f5125f3401d08ae (diff)
optional support for doing perfect hashing of feature strings to save lots of memory
-rw-r--r--decoder/decoder.cc22
-rw-r--r--utils/Makefile.am9
-rw-r--r--utils/fdict.cc4
-rw-r--r--utils/fdict.h36
-rw-r--r--utils/perfect_hash.cc37
-rw-r--r--utils/perfect_hash.h24
-rw-r--r--utils/phmt.cc44
-rw-r--r--utils/weights.cc132
-rw-r--r--utils/weights.h14
9 files changed, 269 insertions, 53 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 76f31352..25eb2de4 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -328,6 +328,7 @@ struct DecoderImpl {
bool write_gradient; // TODO Observer
bool feature_expectations; // TODO Observer
bool output_training_vector; // TODO Observer
+ bool remove_intersected_rule_annotations;
static void ConvertSV(const SparseVector<prob_t>& src, SparseVector<double>* trg) {
for (SparseVector<prob_t>::const_iterator it = src.begin(); it != src.end(); ++it)
@@ -361,6 +362,9 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)")
("per_sentence_grammar_file", po::value<string>(), "Optional (and possibly not implemented) per sentence grammar file enables all per sentence grammars to be stored in a single large file and accessed by offset")
("list_feature_functions,L","List available feature functions")
+#ifdef HAVE_CMPH
+ ("cmph_perfect_feature_hash,h", po::value<string>(), "Load perfect hash function for features")
+#endif
("weights,w",po::value<string>(),"Feature weights file (initial forest / pass 1)")
("feature_function,F",po::value<vector<string> >()->composing(), "Pass 1 additional feature function(s) (-L for list)")
@@ -433,7 +437,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)")
("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")
("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)")
- ("forest_output,O",po::value<string>(),"Directory to write forests to");
+ ("forest_output,O",po::value<string>(),"Directory to write forests to")
+ ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)");
// ob.AddOptions(&opts);
#ifdef FSA_RESCORING
@@ -443,7 +448,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
po::options_description clo("Command line options");
clo.add_options()
("config,c", po::value<vector<string> >(&cfg_files), "Configuration file(s) - latest has priority")
- ("help,h", "Print this help message and exit")
+ ("help,?", "Print this help message and exit")
("usage,u", po::value<string>(), "Describe a feature function type")
("compgen", "Print just option names suitable for bash command line completion builtin 'compgen'")
;
@@ -645,6 +650,12 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
FD::Freeze(); // this means we can't see the feature names of not-weighted features
}
+ if (conf.count("cmph_perfect_feature_hash")) {
+ cerr << "Loading perfect hash function from " << conf["cmph_perfect_feature_hash"].as<string>() << " ...\n";
+ FD::EnableHash(conf["cmph_perfect_feature_hash"].as<string>());
+ cerr << " " << FD::NumFeats() << " features in map\n";
+ }
+
// set up translation back end
if (formalism == "scfg")
translator.reset(new SCFGTranslator(conf));
@@ -695,6 +706,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
unique_kbest = conf.count("unique_k_best");
get_oracle_forest = conf.count("get_oracle_forest");
oracle.show_derivation=conf.count("show_derivations");
+ remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations");
#ifdef FSA_RESCORING
cfg_options.Validate();
@@ -1010,6 +1022,12 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
// if (!SILENT) cerr << " USING UNIFORM WEIGHTS\n";
// for (int i = 0; i < forest.edges_.size(); ++i)
// forest.edges_[i].edge_prob_=prob_t::One(); }
+ if (remove_intersected_rule_annotations) {
+ for (unsigned i = 0; i < forest.edges_.size(); ++i)
+ if (forest.edges_[i].rule_ &&
+ forest.edges_[i].rule_->parent_rule_)
+ forest.edges_[i].rule_ = forest.edges_[i].rule_->parent_rule_;
+ }
forest.Reweight(last_weights);
if (!SILENT) forest_stats(forest," Constr. forest",show_tree_structure,oracle.show_derivation);
if (!SILENT) cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl;
diff --git a/utils/Makefile.am b/utils/Makefile.am
index 94f9be30..c50747bf 100644
--- a/utils/Makefile.am
+++ b/utils/Makefile.am
@@ -1,5 +1,5 @@
-noinst_PROGRAMS = ts
-TESTS = ts
+noinst_PROGRAMS = ts phmt
+TESTS = ts phmt
if HAVE_GTEST
noinst_PROGRAMS += \
@@ -27,6 +27,11 @@ libutils_a_SOURCES = \
verbose.cc \
weights.cc
+if HAVE_CMPH
+ libutils_a_SOURCES += perfect_hash.cc
+endif
+
+phmt_SOURCES = phmt.cc
ts_SOURCES = ts.cc
dict_test_SOURCES = dict_test.cc
dict_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS)
diff --git a/utils/fdict.cc b/utils/fdict.cc
index baa0b552..676c951c 100644
--- a/utils/fdict.cc
+++ b/utils/fdict.cc
@@ -9,6 +9,10 @@ using namespace std;
Dict FD::dict_;
bool FD::frozen_ = false;
+#ifdef HAVE_CMPH
+PerfectHashFunction* FD::hash_ = NULL;
+#endif
+
std::string FD::Convert(std::vector<WordID> const& v) {
return Convert(&*v.begin(),&*v.end());
}
diff --git a/utils/fdict.h b/utils/fdict.h
index f9673023..771e8b91 100644
--- a/utils/fdict.h
+++ b/utils/fdict.h
@@ -1,23 +1,56 @@
#ifndef _FDICT_H_
#define _FDICT_H_
+#include "config.h"
+
+#include <iostream>
#include <string>
#include <vector>
#include "dict.h"
+#ifdef HAVE_CMPH
+#include "perfect_hash.h"
+#include "string_to.h"
+#endif
+
struct FD {
// once the FD is frozen, new features not already in the
// dictionary will return 0
static void Freeze() {
frozen_ = true;
}
+ static bool UsingPerfectHashFunction() {
+#ifdef HAVE_CMPH
+ return hash_;
+#else
+ return false;
+#endif
+ }
+ static void EnableHash(const std::string& cmph_file) {
+#ifdef HAVE_CMPH
+ hash_ = new PerfectHashFunction(cmph_file);
+#endif
+ }
static inline int NumFeats() {
+#ifdef HAVE_CMPH
+ if (hash_) return hash_->number_of_keys();
+#endif
return dict_.max() + 1;
}
static inline WordID Convert(const std::string& s) {
+#ifdef HAVE_CMPH
+ if (hash_) return (*hash_)(s);
+#endif
return dict_.Convert(s, frozen_);
}
static inline const std::string& Convert(const WordID& w) {
+#ifdef HAVE_CMPH
+ if (hash_) {
+ static std::string tls;
+ tls = to_string(w);
+ return tls;
+ }
+#endif
return dict_.Convert(w);
}
static std::string Convert(WordID const *i,WordID const* e);
@@ -29,6 +62,9 @@ struct FD {
static Dict dict_;
private:
static bool frozen_;
+#ifdef HAVE_CMPH
+ static PerfectHashFunction* hash_;
+#endif
};
#endif
diff --git a/utils/perfect_hash.cc b/utils/perfect_hash.cc
new file mode 100644
index 00000000..706e2741
--- /dev/null
+++ b/utils/perfect_hash.cc
@@ -0,0 +1,37 @@
+#include "config.h"
+
+#ifdef HAVE_CMPH
+
+#include "perfect_hash.h"
+
+#include <cstdio>
+#include <iostream>
+
+using namespace std;
+
+PerfectHashFunction::~PerfectHashFunction() {
+ cmph_destroy(mphf_);
+}
+
+PerfectHashFunction::PerfectHashFunction(const string& fname) {
+ FILE* f = fopen(fname.c_str(), "r");
+ if (!f) {
+ cerr << "Failed to open file " << fname << " for reading: cannot load hash function.\n";
+ abort();
+ }
+ mphf_ = cmph_load(f);
+ if (!mphf_) {
+ cerr << "cmph_load failed on " << fname << "!\n";
+ abort();
+ }
+}
+
+size_t PerfectHashFunction::operator()(const string& key) const {
+ return cmph_search(mphf_, &key[0], key.size());
+}
+
+size_t PerfectHashFunction::number_of_keys() const {
+ return cmph_size(mphf_);
+}
+
+#endif
diff --git a/utils/perfect_hash.h b/utils/perfect_hash.h
new file mode 100644
index 00000000..8ac11f18
--- /dev/null
+++ b/utils/perfect_hash.h
@@ -0,0 +1,24 @@
+#ifndef _PERFECT_HASH_MAP_H_
+#define _PERFECT_HASH_MAP_H_
+
+#include "config.h"
+
+#ifndef HAVE_CMPH
+#error libcmph is required to use PerfectHashFunction
+#endif
+
+#include <vector>
+#include <boost/utility.hpp>
+#include "cmph.h"
+
+class PerfectHashFunction : boost::noncopyable {
+ public:
+ explicit PerfectHashFunction(const std::string& fname);
+ ~PerfectHashFunction();
+ size_t operator()(const std::string& key) const;
+ size_t number_of_keys() const;
+ private:
+ cmph_t *mphf_;
+};
+
+#endif
diff --git a/utils/phmt.cc b/utils/phmt.cc
new file mode 100644
index 00000000..1f59afaf
--- /dev/null
+++ b/utils/phmt.cc
@@ -0,0 +1,44 @@
+#include "config.h"
+
+#ifndef HAVE_CMPH
+int main() {
+ return 0;
+}
+#else
+
+#include <iostream>
+#include "weights.h"
+#include "fdict.h"
+
+using namespace std;
+
+int main(int argc, char** argv) {
+ if (argc != 2) { cerr << "Usage: " << argv[0] << " file.mphf\n"; return 1; }
+ FD::EnableHash(argv[1]);
+ cerr << "Number of keys: " << FD::NumFeats() << endl;
+ cerr << "LexFE = " << FD::Convert("LexFE") << endl;
+ cerr << "LexEF = " << FD::Convert("LexEF") << endl;
+ {
+ Weights w;
+ vector<weight_t> v(FD::NumFeats());
+ v[FD::Convert("LexFE")] = 1.0;
+ v[FD::Convert("LexEF")] = 0.5;
+ w.InitFromVector(v);
+ cerr << "Writing...\n";
+ w.WriteToFile("weights.bin");
+ cerr << "Done.\n";
+ }
+ {
+ Weights w;
+ vector<weight_t> v(FD::NumFeats());
+ cerr << "Reading...\n";
+ w.InitFromFile("weights.bin");
+ cerr << "Done.\n";
+ w.InitVector(&v);
+ assert(v[FD::Convert("LexFE")] == 1.0);
+ assert(v[FD::Convert("LexEF")] == 0.5);
+ }
+}
+
+#endif
+
diff --git a/utils/weights.cc b/utils/weights.cc
index b994a2fe..0916b72a 100644
--- a/utils/weights.cc
+++ b/utils/weights.cc
@@ -13,40 +13,75 @@ void Weights::InitFromFile(const std::string& filename, vector<string>* feature_
ReadFile in_file(filename);
istream& in = *in_file.stream();
assert(in);
- int weight_count = 0;
- bool fl = false;
- string buf;
- double val = 0;
- while (in) {
- getline(in, buf);
- if (buf.size() == 0) continue;
- if (buf[0] == '#') continue;
- for (int i = 0; i < buf.size(); ++i)
- if (buf[i] == '=') buf[i] = ' ';
- int start = 0;
- while(start < buf.size() && buf[start] == ' ') ++start;
- int end = 0;
- while(end < buf.size() && buf[end] != ' ') ++end;
- const int fid = FD::Convert(buf.substr(start, end - start));
- while(end < buf.size() && buf[end] == ' ') ++end;
- val = strtod(&buf.c_str()[end], NULL);
- if (isnan(val)) {
- cerr << FD::Convert(fid) << " has weight NaN!\n";
- abort();
+
+ bool read_text = true;
+ if (1) {
+ ReadFile hdrrf(filename);
+ istream& hi = *hdrrf.stream();
+ assert(hi);
+ char buf[10];
+ hi.get(buf, 6);
+ assert(hi.good());
+ if (strncmp(buf, "_PHWf", 5) == 0) {
+ read_text = false;
+ }
+ }
+
+ if (read_text) {
+ int weight_count = 0;
+ bool fl = false;
+ string buf;
+ weight_t val = 0;
+ while (in) {
+ getline(in, buf);
+ if (buf.size() == 0) continue;
+ if (buf[0] == '#') continue;
+ if (buf[0] == ' ') {
+ cerr << "Weights file lines may not start with whitespace.\n" << buf << endl;
+ abort();
+ }
+ for (int i = buf.size() - 1; i > 0; --i)
+ if (buf[i] == '=' || buf[i] == '\t') { buf[i] = ' '; break; }
+ int start = 0;
+ while(start < buf.size() && buf[start] == ' ') ++start;
+ int end = 0;
+ while(end < buf.size() && buf[end] != ' ') ++end;
+ const int fid = FD::Convert(buf.substr(start, end - start));
+ while(end < buf.size() && buf[end] == ' ') ++end;
+ val = strtod(&buf.c_str()[end], NULL);
+ if (isnan(val)) {
+ cerr << FD::Convert(fid) << " has weight NaN!\n";
+ abort();
+ }
+ if (wv_.size() <= fid)
+ wv_.resize(fid + 1);
+ wv_[fid] = val;
+ if (feature_list) { feature_list->push_back(FD::Convert(fid)); }
+ ++weight_count;
+ if (!SILENT) {
+ if (weight_count % 50000 == 0) { cerr << '.' << flush; fl = true; }
+ if (weight_count % 2000000 == 0) { cerr << " [" << weight_count << "]\n"; fl = false; }
+ }
}
- if (wv_.size() <= fid)
- wv_.resize(fid + 1);
- wv_[fid] = val;
- if (feature_list) { feature_list->push_back(FD::Convert(fid)); }
- ++weight_count;
if (!SILENT) {
- if (weight_count % 50000 == 0) { cerr << '.' << flush; fl = true; }
- if (weight_count % 2000000 == 0) { cerr << " [" << weight_count << "]\n"; fl = false; }
+ if (fl) { cerr << endl; }
+ cerr << "Loaded " << weight_count << " feature weights\n";
+ }
+ } else { // !read_text
+ char buf[6];
+ in.get(buf, 6);
+ size_t num_keys[2];
+ in.get(reinterpret_cast<char*>(&num_keys[0]), sizeof(size_t) + 1);
+ if (num_keys[0] != FD::NumFeats()) {
+ cerr << "Hash function reports " << FD::NumFeats() << " keys but weights file contains " << num_keys[0] << endl;
+ abort();
+ }
+ wv_.resize(num_keys[0]);
+ in.get(reinterpret_cast<char*>(&wv_[0]), num_keys[0] * sizeof(weight_t));
+ if (!in.good()) {
+ cerr << "Error loading weights!\n";
+ abort();
}
- }
- if (!SILENT) {
- if (fl) { cerr << endl; }
- cerr << "Loaded " << weight_count << " feature weights\n";
}
}
@@ -54,37 +89,48 @@ void Weights::WriteToFile(const std::string& fname, bool hide_zero_value_feature
WriteFile out(fname);
ostream& o = *out.stream();
assert(o);
- if (extra) { o << "# " << *extra << endl; }
- o.precision(17);
- const int num_feats = FD::NumFeats();
- for (int i = 1; i < num_feats; ++i) {
- const double val = (i < wv_.size() ? wv_[i] : 0.0);
- if (hide_zero_value_features && val == 0.0) continue;
- o << FD::Convert(i) << ' ' << val << endl;
+ bool write_text = !FD::UsingPerfectHashFunction();
+
+ if (write_text) {
+ if (extra) { o << "# " << *extra << endl; }
+ o.precision(17);
+ const int num_feats = FD::NumFeats();
+ for (int i = 1; i < num_feats; ++i) {
+ const weight_t val = (i < wv_.size() ? wv_[i] : 0.0);
+ if (hide_zero_value_features && val == 0.0) continue;
+ o << FD::Convert(i) << ' ' << val << endl;
+ }
+ } else {
+ o.write("_PHWf", 5);
+ const size_t keys = FD::NumFeats();
+ assert(keys <= wv_.size());
+ o.write(reinterpret_cast<const char*>(&keys), sizeof(keys));
+ o.write(reinterpret_cast<const char*>(&wv_[0]), keys * sizeof(weight_t));
}
}
-void Weights::InitVector(std::vector<double>* w) const {
+void Weights::InitVector(std::vector<weight_t>* w) const {
*w = wv_;
}
-void Weights::InitSparseVector(SparseVector<double>* w) const {
+void Weights::InitSparseVector(SparseVector<weight_t>* w) const {
for (int i = 1; i < wv_.size(); ++i) {
- const double& weight = wv_[i];
+ const weight_t& weight = wv_[i];
if (weight) w->set_value(i, weight);
}
}
-void Weights::InitFromVector(const std::vector<double>& w) {
+void Weights::InitFromVector(const std::vector<weight_t>& w) {
wv_ = w;
if (wv_.size() > FD::NumFeats())
cerr << "WARNING: initializing weight vector has more features than the global feature dictionary!\n";
wv_.resize(FD::NumFeats(), 0);
}
-void Weights::InitFromVector(const SparseVector<double>& w) {
+void Weights::InitFromVector(const SparseVector<weight_t>& w) {
wv_.clear();
wv_.resize(FD::NumFeats(), 0.0);
for (int i = 1; i < FD::NumFeats(); ++i)
wv_[i] = w.value(i);
}
+
diff --git a/utils/weights.h b/utils/weights.h
index cc20283c..7664810b 100644
--- a/utils/weights.h
+++ b/utils/weights.h
@@ -2,21 +2,23 @@
#define _WEIGHTS_H_
#include <string>
-#include <map>
#include <vector>
#include "sparse_vector.h"
+// warning: in the future this will become float
+typedef double weight_t;
+
class Weights {
public:
Weights() {}
void InitFromFile(const std::string& fname, std::vector<std::string>* feature_list = NULL);
void WriteToFile(const std::string& fname, bool hide_zero_value_features = true, const std::string* extra = NULL) const;
- void InitVector(std::vector<double>* w) const;
- void InitSparseVector(SparseVector<double>* w) const;
- void InitFromVector(const std::vector<double>& w);
- void InitFromVector(const SparseVector<double>& w);
+ void InitVector(std::vector<weight_t>* w) const;
+ void InitSparseVector(SparseVector<weight_t>* w) const;
+ void InitFromVector(const std::vector<weight_t>& w);
+ void InitFromVector(const SparseVector<weight_t>& w);
private:
- std::vector<double> wv_;
+ std::vector<weight_t> wv_;
};
#endif