summaryrefslogtreecommitdiff
path: root/utils/b64featvector.cc
blob: c7d08b298d9b938477d9c48f96c8d09d31b41438 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#include "b64featvector.h"

#include <cassert>
#include <cstring>
#include <sstream>
#include <boost/scoped_array.hpp>
#include "b64tools.h"
#include "fdict.h"

using namespace std;

static inline void EncodeFeatureWeight(const string &featname, weight_t weight,
                                       ostream *output) {
  output->write(featname.data(), featname.size() + 1);
  output->write(reinterpret_cast<char *>(&weight), sizeof(weight_t));
}

// Serializes |vec| into a base64 string.
// Every entry whose weight is non-zero is packed (via EncodeFeatureWeight) as
// its feature name from FD::Convert followed by the weight bytes. The packed
// byte stream is then handed to B64::b64encode, whose output is returned.
string EncodeFeatureVector(const SparseVector<weight_t> &vec) {
  ostringstream packed;
  for (SparseVector<weight_t>::const_iterator entry = vec.begin();
       entry != vec.end(); ++entry) {
    if (entry->second == 0) continue;  // zero weights are not serialized
    EncodeFeatureWeight(FD::Convert(entry->first), entry->second, &packed);
  }

  const string raw_bytes = packed.str();
  ostringstream encoded;
  B64::b64encode(raw_bytes.data(), raw_bytes.size(), &encoded);
  return encoded.str();
}

void DecodeFeatureVector(const string &data, SparseVector<weight_t> *vec) {
  vec->clear();
  if (data.empty()) return;
  // Decode data
  size_t b64_len = data.size(), len = b64_len / 4 * 3;
  boost::scoped_array<char> buf(new char[len]);
  bool res =
      B64::b64decode(reinterpret_cast<const unsigned char *>(data.data()),
                     b64_len, buf.get(), len);
  assert(res);
  // Apply updates
  size_t cur = 0;
  while (cur < len) {
    string feat_name(buf.get() + cur);
    if (feat_name.empty()) break;  // Encountered trailing \0
    int feat_id = FD::Convert(feat_name);
    weight_t feat_delta =
        *reinterpret_cast<weight_t *>(buf.get() + cur + feat_name.size() + 1);
    (*vec)[feat_id] = feat_delta;
    cur += feat_name.size() + 1 + sizeof(weight_t);
  }
}