summary | refs | log | tree | commit | diff
path: root/utils/weights.cc
diff options
context:
space:
mode:
author Patrick Simianer <p@simianer.de> 2011-10-20 02:31:25 +0200
committer Patrick Simianer <p@simianer.de> 2011-10-20 02:31:25 +0200
commit a5a92ebe23c5819ed104313426012011e32539da (patch)
tree 3416818c758d5ece4e71fe522c571e75ea04f100 /utils/weights.cc
parent b88332caac2cbe737c99b8098813f868ca876d8b (diff)
parent 78baccbb4231bb84a456702d4f574f8e601a8182 (diff)
finalized merge
Diffstat (limited to 'utils/weights.cc')
-rw-r--r-- utils/weights.cc 195
1 files changed, 122 insertions, 73 deletions
diff --git a/utils/weights.cc b/utils/weights.cc
index 6b7e58ed..ac407dfb 100644
--- a/utils/weights.cc
+++ b/utils/weights.cc
@@ -8,101 +8,150 @@
using namespace std;
-void Weights::InitFromFile(const std::string& filename, vector<string>* feature_list) {
+void Weights::InitFromFile(const string& filename,
+ vector<weight_t>* pweights,
+ vector<string>* feature_list) {
+ vector<weight_t>& weights = *pweights;
if (!SILENT) cerr << "Reading weights from " << filename << endl;
ReadFile in_file(filename);
istream& in = *in_file.stream();
assert(in);
- int weight_count = 0;
- bool fl = false;
- string buf;
- double val = 0;
- while (in) {
- getline(in, buf);
- if (buf.size() == 0) continue;
- if (buf[0] == '#') continue;
- for (int i = 0; i < buf.size(); ++i)
- if (buf[i] == '=') buf[i] = ' ';
- int start = 0;
- while(start < buf.size() && buf[start] == ' ') ++start;
- int end = 0;
- while(end < buf.size() && buf[end] != ' ') ++end;
- const int fid = FD::Convert(buf.substr(start, end - start));
- while(end < buf.size() && buf[end] == ' ') ++end;
- val = strtod(&buf.c_str()[end], NULL);
- if (isnan(val)) {
- cerr << FD::Convert(fid) << " has weight NaN!\n";
- abort();
+
+ bool read_text = true;
+ if (1) {
+ ReadFile hdrrf(filename);
+ istream& hi = *hdrrf.stream();
+ assert(hi);
+ char buf[10];
+ hi.read(buf, 5);
+ assert(hi.good());
+ if (strncmp(buf, "_PHWf", 5) == 0) {
+ read_text = false;
+ }
+ }
+
+ if (read_text) {
+ int weight_count = 0;
+ bool fl = false;
+ string buf;
+ weight_t val = 0;
+ while (in) {
+ getline(in, buf);
+ if (buf.size() == 0) continue;
+ if (buf[0] == '#') continue;
+ if (buf[0] == ' ') {
+ cerr << "Weights file lines may not start with whitespace.\n" << buf << endl;
+ abort();
+ }
+ for (int i = buf.size() - 1; i > 0; --i)
+ if (buf[i] == '=' || buf[i] == '\t') { buf[i] = ' '; break; }
+ int start = 0;
+ while(start < buf.size() && buf[start] == ' ') ++start;
+ int end = 0;
+ while(end < buf.size() && buf[end] != ' ') ++end;
+ const int fid = FD::Convert(buf.substr(start, end - start));
+ if (feature_list) { feature_list->push_back(buf.substr(start, end - start)); }
+ while(end < buf.size() && buf[end] == ' ') ++end;
+ val = strtod(&buf.c_str()[end], NULL);
+ if (isnan(val)) {
+ cerr << FD::Convert(fid) << " has weight NaN!\n";
+ abort();
+ }
+ if (weights.size() <= fid)
+ weights.resize(fid + 1);
+ weights[fid] = val;
+ ++weight_count;
+ if (!SILENT) {
+ if (weight_count % 50000 == 0) { cerr << '.' << flush; fl = true; }
+ if (weight_count % 2000000 == 0) { cerr << " [" << weight_count << "]\n"; fl = false; }
+ }
}
- if (wv_.size() <= fid)
- wv_.resize(fid + 1);
- wv_[fid] = val;
- if (feature_list) { feature_list->push_back(FD::Convert(fid)); }
- ++weight_count;
if (!SILENT) {
- if (weight_count % 50000 == 0) { cerr << '.' << flush; fl = true; }
- if (weight_count % 2000000 == 0) { cerr << " [" << weight_count << "]\n"; fl = false; }
+ if (fl) { cerr << endl; }
+ cerr << "Loaded " << weight_count << " feature weights\n";
+ }
+ } else { // !read_text
+ char buf[6];
+ in.read(buf, 5);
+ size_t num_keys;
+ in.read(reinterpret_cast<char*>(&num_keys), sizeof(size_t));
+ if (num_keys != FD::NumFeats()) {
+ cerr << "Hash function reports " << FD::NumFeats() << " keys but weights file contains " << num_keys << endl;
+ abort();
+ }
+ weights.resize(num_keys);
+ in.read(reinterpret_cast<char*>(&weights.front()), num_keys * sizeof(weight_t));
+ if (!in.good()) {
+ cerr << "Error loading weights!\n";
+ abort();
+ } else {
+ cerr << " Successfully loaded " << (num_keys * sizeof(weight_t)) << " bytes\n";
}
- }
- if (!SILENT) {
- if (fl) { cerr << endl; }
- cerr << "Loaded " << weight_count << " feature weights\n";
}
}
-void Weights::WriteToFile(const std::string& fname, bool hide_zero_value_features, const string* extra) const {
+void Weights::WriteToFile(const string& fname,
+ const vector<weight_t>& weights,
+ bool hide_zero_value_features,
+ const string* extra) {
WriteFile out(fname);
ostream& o = *out.stream();
assert(o);
- if (extra) { o << "# " << *extra << endl; }
- o.precision(17);
- const int num_feats = FD::NumFeats();
- for (int i = 1; i < num_feats; ++i) {
- const double val = (i < wv_.size() ? wv_[i] : 0.0);
- if (hide_zero_value_features && val == 0.0) continue;
- o << FD::Convert(i) << ' ' << val << endl;
- }
-}
+ bool write_text = !FD::UsingPerfectHashFunction();
-void Weights::InitVector(std::vector<double>* w) const {
- *w = wv_;
-}
-
-void Weights::InitSparseVector(SparseVector<double>* w) const {
- for (int i = 1; i < wv_.size(); ++i) {
- const double& weight = wv_[i];
- if (weight) w->set_value(i, weight);
+ if (write_text) {
+ if (extra) { o << "# " << *extra << endl; }
+ o.precision(17);
+ const int num_feats = FD::NumFeats();
+ for (int i = 1; i < num_feats; ++i) {
+ const weight_t val = (i < weights.size() ? weights[i] : 0.0);
+ if (hide_zero_value_features && val == 0.0) continue;
+ o << FD::Convert(i) << ' ' << val << endl;
+ }
+ } else {
+ o.write("_PHWf", 5);
+ const size_t keys = FD::NumFeats();
+ assert(keys <= weights.size());
+ o.write(reinterpret_cast<const char*>(&keys), sizeof(keys));
+ o.write(reinterpret_cast<const char*>(&weights[0]), keys * sizeof(weight_t));
}
}
-void Weights::InitFromVector(const std::vector<double>& w) {
- wv_ = w;
- if (wv_.size() > FD::NumFeats())
- cerr << "WARNING: initializing weight vector has more features than the global feature dictionary!\n";
- wv_.resize(FD::NumFeats(), 0);
+void Weights::InitSparseVector(const vector<weight_t>& dv,
+ SparseVector<weight_t>* sv) {
+ sv->clear();
+ for (unsigned i = 1; i < dv.size(); ++i) {
+ if (dv[i]) sv->set_value(i, dv[i]);
+ }
}
-void Weights::InitFromVector(const SparseVector<double>& w) {
- wv_.clear();
- wv_.resize(FD::NumFeats(), 0.0);
- for (int i = 1; i < FD::NumFeats(); ++i)
- wv_[i] = w.value(i);
+void Weights::SanityCheck(const vector<weight_t>& w) {
+ for (int i = 0; i < w.size(); ++i) {
+ assert(!isnan(w[i]));
+ assert(!isinf(w[i]));
+ }
}
-void Weights::SetWeight(SparseVector<double>* v, const string fname, const double w) {
- WordID fid = FD::Convert(fname);
- cout << "fid " << fid << endl;
- SetWeight(v, fid, w);
-}
+struct FComp {
+ const vector<weight_t>& w_;
+ FComp(const vector<weight_t>& w) : w_(w) {}
+ bool operator()(int a, int b) const {
+ return fabs(w_[a]) > fabs(w_[b]);
+ }
+};
-void Weights::SetWeight(SparseVector<double>* v, const WordID fid, const double w) {
- wv_.resize(FD::NumFeats(), 0.0);
- wv_[fid] = w;
- //v->set_value(fid, w);
+void Weights::ShowLargestFeatures(const vector<weight_t>& w) {
+ vector<int> fnums(w.size());
+ for (int i = 0; i < w.size(); ++i)
+ fnums[i] = i;
+ vector<int>::iterator mid = fnums.begin();
+ mid += (w.size() > 10 ? 10 : w.size());
+ partial_sort(fnums.begin(), mid, fnums.end(), FComp(w));
+ cerr << "TOP FEATURES:";
+ for (vector<int>::iterator i = fnums.begin(); i != mid; ++i) {
+ cerr << ' ' << FD::Convert(*i) << '=' << w[*i];
+ }
+ cerr << endl;
}
-void Weights::sz()
-{
- cout << "wv_.size() " << wv_.size() << endl;
-}