#include "sparse_vector.h"

#include <iostream>
#include <cstring>

#include "b64tools.h"

using namespace std;

namespace B64 {

void Encode(double objective, const SparseVector<double>& v, ostream* out) {
  const int num_feats = v.num_active();
  size_t tot_size = 0;
  const size_t off_objective = tot_size;
  tot_size += sizeof(double);                   // objective
  const size_t off_num_feats = tot_size;
  tot_size += sizeof(int);                      // num_feats
  const size_t off_data = tot_size;
  tot_size += sizeof(unsigned char) * num_feats; // lengths of feature names;
  typedef SparseVector<double>::const_iterator const_iterator;
  for (const_iterator it = v.begin(); it != v.end(); ++it)
    tot_size += FD::Convert(it->first).size();   // feature names;
  tot_size += sizeof(double) * num_feats;        // gradient
  const size_t off_magic = tot_size;
  tot_size += 4;                                 // magic

  // size_t b64_size = tot_size * 4 / 3;
  // cerr << "Sparse vector binary size: " << tot_size << "  (b64 size=" << b64_size << ")\n";
  char* data = new char[tot_size];
  *reinterpret_cast<double*>(&data[off_objective]) = objective;
  *reinterpret_cast<int*>(&data[off_num_feats]) = num_feats;
  char* cur = &data[off_data];
  assert(cur - data == off_data);
  for (const_iterator it = v.begin(); it != v.end(); ++it) {
    const string& fname = FD::Convert(it->first);
    *cur++ = static_cast<char>(fname.size());   // name len
    memcpy(cur, &fname[0], fname.size());
    cur += fname.size();
    *reinterpret_cast<double*>(cur) = it->second;
    cur += sizeof(double);
  }
  assert(cur - data == off_magic);
  *reinterpret_cast<unsigned int*>(cur) = 0xBAABABBAu;
  cur += sizeof(unsigned int);
  assert(cur - data == tot_size);
  b64encode(data, tot_size, out);
  delete[] data;
}

bool Decode(double* objective, SparseVector<double>* v, const char* in, size_t size) {
  v->clear();
  if (size % 4 != 0) {
    cerr << "B64 error - line % 4 != 0\n";
    return false;
  }
  const size_t decoded_size = size * 3 / 4 - sizeof(unsigned int);
  const size_t buf_size = decoded_size + sizeof(unsigned int);
  if (decoded_size < 6) { cerr << "SparseVector decoding error: too short!\n"; return false; }
  char* data = new char[buf_size];
  if (!b64decode(reinterpret_cast<const unsigned char*>(in), size, data, buf_size)) {
    delete[] data;
    return false;
  }
  size_t cur = 0;
  *objective = *reinterpret_cast<double*>(data);
  cur += sizeof(double);
  const int num_feats = *reinterpret_cast<int*>(&data[cur]);
  cur += sizeof(int);
  int fc = 0;
  while(fc < num_feats && cur < decoded_size) {
    ++fc;
    const int fname_len = data[cur++];
    assert(fname_len > 0);
    assert(fname_len < 256);
    string fname(fname_len, '\0');
    memcpy(&fname[0], &data[cur], fname_len);
    cur += fname_len;
    const double val = *reinterpret_cast<double*>(&data[cur]);
    cur += sizeof(double);
    int fid = FD::Convert(fname);
    v->set_value(fid, val);
  }
  if(num_feats != fc) {
    cerr << "Expected " << num_feats << " but only decoded " << fc << "!\n";
    delete[] data;
    return false;
  }
  if (*reinterpret_cast<unsigned int*>(&data[cur]) != 0xBAABABBAu) {
    cerr << "SparseVector decodeding error : magic does not match!\n";
    delete[] data;
    return false;
  }
  delete[] data;
  return true;
}

}