summaryrefslogtreecommitdiff
path: root/utils/weights.cc
blob: 6b7e58edfb247e9e1ae429b9f4bae6ded12af60f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#include "weights.h"

#include <sstream>

#include "fdict.h"
#include "filelib.h"
#include "verbose.h"

using namespace std;

// Loads feature weights from a text file into wv_.
//
// Every line that is non-empty and does not begin with '#' is parsed as
//   <feature-name> <value>
// where the two fields are separated by spaces and/or '=' characters (each
// '=' is rewritten to a space before parsing). Leading separators are
// ignored and lines made up solely of separators are skipped. The name is
// mapped to an id with FD::Convert and the value is stored at that index,
// growing wv_ when the id is past its current end.
//
//   filename      handed to ReadFile to open the input
//   feature_list  may be NULL; otherwise the name of each feature read is
//                 appended to it, in file order
//
// Calls abort() if a value parses as NaN. Unless SILENT is set, progress
// dots and a final count are written to cerr.
void Weights::InitFromFile(const std::string& filename, vector<string>* feature_list) {
  if (!SILENT) cerr << "Reading weights from " << filename << endl;
  ReadFile in_file(filename);
  istream& in = *in_file.stream();
  assert(in);
  int weight_count = 0;
  bool fl = false;  // true while a row of progress dots is still unterminated
  string buf;
  // Loop on the result of getline itself. Testing the stream before reading
  // lets a failed read at end-of-file fall through with the previous line
  // still in buf, so a file without a trailing newline would have its last
  // line processed twice.
  while (getline(in, buf)) {
    if (buf.empty()) continue;
    if (buf[0] == '#') continue;
    const size_t len = buf.size();
    for (size_t i = 0; i < len; ++i)
      if (buf[i] == '=') buf[i] = ' ';
    // Skip leading separators; nothing left means there is no feature here.
    size_t start = 0;
    while (start < len && buf[start] == ' ') ++start;
    if (start == len) continue;
    // The name runs from start to the next separator. The scan must begin
    // at start (not 0), otherwise the length passed to substr goes negative
    // whenever the line has leading separators.
    size_t end = start;
    while (end < len && buf[end] != ' ') ++end;
    const int fid = FD::Convert(buf.substr(start, end - start));
    while (end < len && buf[end] == ' ') ++end;
    const double val = strtod(buf.c_str() + end, NULL);
    if (isnan(val)) {
      cerr << FD::Convert(fid) << " has weight NaN!\n";
      abort();
    }
    if (wv_.size() <= fid)
      wv_.resize(fid + 1);
    wv_[fid] = val;
    if (feature_list) { feature_list->push_back(FD::Convert(fid)); }
    ++weight_count;
    if (!SILENT) {
      if (weight_count %   50000 == 0) { cerr << '.' << flush; fl = true; }
      if (weight_count % 2000000 == 0) { cerr << " [" << weight_count << "]\n"; fl = false; }
    }
  }
  if (!SILENT) {
    if (fl) { cerr << endl; }
    cerr << "Loaded " << weight_count << " feature weights\n";
  }
}

// Writes the weights as text, one "<feature-name> <value>" line per feature
// id in [1, FD::NumFeats()).
//
//   fname                     handed to WriteFile to open the output
//   hide_zero_value_features  when true, features whose weight compares
//                             equal to 0.0 are omitted
//   extra                     may be NULL; otherwise written first as a
//                             "# ..." line
//
// Ids that lie beyond the end of wv_ are treated as having weight 0.0.
// The stream precision is set to 17 before any value is written.
void Weights::WriteToFile(const std::string& fname, bool hide_zero_value_features, const string* extra) const {
  WriteFile out(fname);
  ostream& o = *out.stream();
  assert(o);
  // Lines are terminated with '\n' rather than endl: endl forces a flush on
  // every line, which is needless work when there are many features. The
  // stream is flushed once after the last line instead.
  if (extra) { o << "# " << *extra << '\n'; }
  o.precision(17);
  const int num_feats = FD::NumFeats();
  for (int i = 1; i < num_feats; ++i) {
    // i >= 1 here, so the cast to an unsigned index is safe.
    const double val = (static_cast<size_t>(i) < wv_.size() ? wv_[i] : 0.0);
    if (hide_zero_value_features && val == 0.0) continue;
    o << FD::Convert(i) << ' ' << val << '\n';
  }
  o.flush();
}

// Overwrites *w with a copy of the dense weight vector wv_.
void Weights::InitVector(std::vector<double>* w) const {
  std::vector<double>& target = *w;
  target = wv_;
}

// Copies the weights at indices 1 and above into *w via set_value, skipping
// entries that are zero. Index 0 is never copied, and *w is not cleared
// first.
void Weights::InitSparseVector(SparseVector<double>* w) const {
  for (int fid = 1; fid < wv_.size(); ++fid) {
    const double value = wv_[fid];
    if (value == 0.0) continue;
    w->set_value(fid, value);
  }
}

// Replaces wv_ with a copy of w, then resizes it to exactly FD::NumFeats()
// entries (new entries are 0). A warning is printed to cerr when w holds
// more entries than FD::NumFeats(); the resize then drops the surplus.
void Weights::InitFromVector(const std::vector<double>& w) {
  const bool has_extra_features = w.size() > FD::NumFeats();
  if (has_extra_features) {
    cerr << "WARNING: initializing weight vector has more features than the global feature dictionary!\n";
  }
  wv_ = w;
  wv_.resize(FD::NumFeats(), 0);
}

// Rebuilds wv_ from a sparse vector: wv_ becomes FD::NumFeats() zeros, and
// every index from 1 upward is then filled with w.value(index). Index 0 is
// left at 0.0.
void Weights::InitFromVector(const SparseVector<double>& w) {
  const int num_feats = FD::NumFeats();
  wv_.assign(num_feats, 0.0);
  for (int fid = 1; fid < num_feats; ++fid) {
    wv_[fid] = w.value(fid);
  }
}

// Sets the weight of the feature called fname: the name is turned into an
// id with FD::Convert and the work is forwarded to the id-based overload.
// The resolved id is also printed to stdout.
// NOTE(review): the stdout line looks like leftover debugging output --
// confirm whether callers rely on it before removing.
void Weights::SetWeight(SparseVector<double>* v, const string fname, const double w) {
  const WordID feature_id = FD::Convert(fname);
  cout << "fid " << feature_id << endl;
  SetWeight(v, feature_id, w);
}

// Stores w as the weight of feature fid in wv_, after resizing wv_ to
// FD::NumFeats() entries (new entries are 0.0).
// NOTE(review): v is currently unused (a v->set_value(fid, w) call was
// commented out), and fid is not checked against the resized vector --
// confirm callers only pass ids below FD::NumFeats().
void Weights::SetWeight(SparseVector<double>* v, const WordID fid, const double w) {
  wv_.resize(FD::NumFeats(), 0.0);
  double& slot = wv_[fid];
  slot = w;
}

// Prints the current number of entries in wv_ to stdout.
void Weights::sz()
{
  const std::vector<double>::size_type entry_count = wv_.size();
  cout << "wv_.size() " << entry_count << endl;
}