summaryrefslogtreecommitdiff
path: root/utils/weights.cc
diff options
context:
space:
mode:
Diffstat (limited to 'utils/weights.cc')
-rw-r--r--utils/weights.cc132
1 files changed, 89 insertions, 43 deletions
diff --git a/utils/weights.cc b/utils/weights.cc
index b994a2fe..0916b72a 100644
--- a/utils/weights.cc
+++ b/utils/weights.cc
@@ -13,40 +13,75 @@ void Weights::InitFromFile(const std::string& filename, vector<string>* feature_
ReadFile in_file(filename);
istream& in = *in_file.stream();
assert(in);
- int weight_count = 0;
- bool fl = false;
- string buf;
- double val = 0;
- while (in) {
- getline(in, buf);
- if (buf.size() == 0) continue;
- if (buf[0] == '#') continue;
- for (int i = 0; i < buf.size(); ++i)
- if (buf[i] == '=') buf[i] = ' ';
- int start = 0;
- while(start < buf.size() && buf[start] == ' ') ++start;
- int end = 0;
- while(end < buf.size() && buf[end] != ' ') ++end;
- const int fid = FD::Convert(buf.substr(start, end - start));
- while(end < buf.size() && buf[end] == ' ') ++end;
- val = strtod(&buf.c_str()[end], NULL);
- if (isnan(val)) {
- cerr << FD::Convert(fid) << " has weight NaN!\n";
- abort();
+
+ bool read_text = true;
+ if (1) {
+ ReadFile hdrrf(filename);
+ istream& hi = *hdrrf.stream();
+ assert(hi);
+ char buf[10];
+ hi.get(buf, 6);
+ assert(hi.good());
+ if (strncmp(buf, "_PHWf", 5) == 0) {
+ read_text = false;
+ }
+ }
+
+ if (read_text) {
+ int weight_count = 0;
+ bool fl = false;
+ string buf;
+ weight_t val = 0;
+ while (in) {
+ getline(in, buf);
+ if (buf.size() == 0) continue;
+ if (buf[0] == '#') continue;
+ if (buf[0] == ' ') {
+ cerr << "Weights file lines may not start with whitespace.\n" << buf << endl;
+ abort();
+ }
+ for (int i = buf.size() - 1; i > 0; --i)
+ if (buf[i] == '=' || buf[i] == '\t') { buf[i] = ' '; break; }
+ int start = 0;
+ while(start < buf.size() && buf[start] == ' ') ++start;
+ int end = 0;
+ while(end < buf.size() && buf[end] != ' ') ++end;
+ const int fid = FD::Convert(buf.substr(start, end - start));
+ while(end < buf.size() && buf[end] == ' ') ++end;
+ val = strtod(&buf.c_str()[end], NULL);
+ if (isnan(val)) {
+ cerr << FD::Convert(fid) << " has weight NaN!\n";
+ abort();
+ }
+ if (wv_.size() <= fid)
+ wv_.resize(fid + 1);
+ wv_[fid] = val;
+ if (feature_list) { feature_list->push_back(FD::Convert(fid)); }
+ ++weight_count;
+ if (!SILENT) {
+ if (weight_count % 50000 == 0) { cerr << '.' << flush; fl = true; }
+ if (weight_count % 2000000 == 0) { cerr << " [" << weight_count << "]\n"; fl = false; }
+ }
}
- if (wv_.size() <= fid)
- wv_.resize(fid + 1);
- wv_[fid] = val;
- if (feature_list) { feature_list->push_back(FD::Convert(fid)); }
- ++weight_count;
if (!SILENT) {
- if (weight_count % 50000 == 0) { cerr << '.' << flush; fl = true; }
- if (weight_count % 2000000 == 0) { cerr << " [" << weight_count << "]\n"; fl = false; }
+ if (fl) { cerr << endl; }
+ cerr << "Loaded " << weight_count << " feature weights\n";
+ }
+ } else { // !read_text
+ char buf[6];
+ in.get(buf, 6);
+ size_t num_keys[2];
+ in.get(reinterpret_cast<char*>(&num_keys[0]), sizeof(size_t) + 1);
+ if (num_keys[0] != FD::NumFeats()) {
+ cerr << "Hash function reports " << FD::NumFeats() << " keys but weights file contains " << num_keys[0] << endl;
+ abort();
+ }
+ wv_.resize(num_keys[0]);
+ in.get(reinterpret_cast<char*>(&wv_[0]), num_keys[0] * sizeof(weight_t));
+ if (!in.good()) {
+ cerr << "Error loading weights!\n";
+ abort();
}
- }
- if (!SILENT) {
- if (fl) { cerr << endl; }
- cerr << "Loaded " << weight_count << " feature weights\n";
}
}
@@ -54,37 +89,48 @@ void Weights::WriteToFile(const std::string& fname, bool hide_zero_value_feature
WriteFile out(fname);
ostream& o = *out.stream();
assert(o);
- if (extra) { o << "# " << *extra << endl; }
- o.precision(17);
- const int num_feats = FD::NumFeats();
- for (int i = 1; i < num_feats; ++i) {
- const double val = (i < wv_.size() ? wv_[i] : 0.0);
- if (hide_zero_value_features && val == 0.0) continue;
- o << FD::Convert(i) << ' ' << val << endl;
+ bool write_text = !FD::UsingPerfectHashFunction();
+
+ if (write_text) {
+ if (extra) { o << "# " << *extra << endl; }
+ o.precision(17);
+ const int num_feats = FD::NumFeats();
+ for (int i = 1; i < num_feats; ++i) {
+ const weight_t val = (i < wv_.size() ? wv_[i] : 0.0);
+ if (hide_zero_value_features && val == 0.0) continue;
+ o << FD::Convert(i) << ' ' << val << endl;
+ }
+ } else {
+ o.write("_PHWf", 5);
+ const size_t keys = FD::NumFeats();
+ assert(keys <= wv_.size());
+ o.write(reinterpret_cast<const char*>(&keys), sizeof(keys));
+ o.write(reinterpret_cast<const char*>(&wv_[0]), keys * sizeof(weight_t));
}
}
-void Weights::InitVector(std::vector<double>* w) const {
+void Weights::InitVector(std::vector<weight_t>* w) const {
*w = wv_;
}
-void Weights::InitSparseVector(SparseVector<double>* w) const {
+void Weights::InitSparseVector(SparseVector<weight_t>* w) const {
for (int i = 1; i < wv_.size(); ++i) {
- const double& weight = wv_[i];
+ const weight_t& weight = wv_[i];
if (weight) w->set_value(i, weight);
}
}
-void Weights::InitFromVector(const std::vector<double>& w) {
+void Weights::InitFromVector(const std::vector<weight_t>& w) {
wv_ = w;
if (wv_.size() > FD::NumFeats())
cerr << "WARNING: initializing weight vector has more features than the global feature dictionary!\n";
wv_.resize(FD::NumFeats(), 0);
}
-void Weights::InitFromVector(const SparseVector<double>& w) {
+void Weights::InitFromVector(const SparseVector<weight_t>& w) {
wv_.clear();
wv_.resize(FD::NumFeats(), 0.0);
for (int i = 1; i < FD::NumFeats(); ++i)
wv_[i] = w.value(i);
}
+