summaryrefslogtreecommitdiff
path: root/decoder/sparse_vector.cc
diff options
context:
space:
mode:
authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-06-22 05:12:27 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-06-22 05:12:27 +0000
commit7cc92b65a3185aa242088d830e166e495674efc9 (patch)
tree681fe5237612a4e96ce36fb9fabef00042c8ee61 /decoder/sparse_vector.cc
parent37728b8be4d0b3df9da81fdda2198ff55b4b2d91 (diff)
initial checkin
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@2 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'decoder/sparse_vector.cc')
-rw-r--r--decoder/sparse_vector.cc98
1 files changed, 98 insertions, 0 deletions
diff --git a/decoder/sparse_vector.cc b/decoder/sparse_vector.cc
new file mode 100644
index 00000000..4035b9ef
--- /dev/null
+++ b/decoder/sparse_vector.cc
@@ -0,0 +1,98 @@
+#include "sparse_vector.h"
+
+#include <iostream>
+#include <cstring>
+
+#include "hg_io.h"
+
+using namespace std;
+
+namespace B64 {
+
+void Encode(double objective, const SparseVector<double>& v, ostream* out) {
+ const int num_feats = v.num_active();
+ size_t tot_size = 0;
+ const size_t off_objective = tot_size;
+ tot_size += sizeof(double); // objective
+ const size_t off_num_feats = tot_size;
+ tot_size += sizeof(int); // num_feats
+ const size_t off_data = tot_size;
+ tot_size += sizeof(unsigned char) * num_feats; // lengths of feature names;
+ typedef SparseVector<double>::const_iterator const_iterator;
+ for (const_iterator it = v.begin(); it != v.end(); ++it)
+ tot_size += FD::Convert(it->first).size(); // feature names;
+ tot_size += sizeof(double) * num_feats; // gradient
+ const size_t off_magic = tot_size;
+ tot_size += 4; // magic
+
+ // size_t b64_size = tot_size * 4 / 3;
+ // cerr << "Sparse vector binary size: " << tot_size << " (b64 size=" << b64_size << ")\n";
+ char* data = new char[tot_size];
+ *reinterpret_cast<double*>(&data[off_objective]) = objective;
+ *reinterpret_cast<int*>(&data[off_num_feats]) = num_feats;
+ char* cur = &data[off_data];
+ assert(cur - data == off_data);
+ for (const_iterator it = v.begin(); it != v.end(); ++it) {
+ const string& fname = FD::Convert(it->first);
+ *cur++ = static_cast<char>(fname.size()); // name len
+ memcpy(cur, &fname[0], fname.size());
+ cur += fname.size();
+ *reinterpret_cast<double*>(cur) = it->second;
+ cur += sizeof(double);
+ }
+ assert(cur - data == off_magic);
+ *reinterpret_cast<unsigned int*>(cur) = 0xBAABABBAu;
+ cur += sizeof(unsigned int);
+ assert(cur - data == tot_size);
+ b64encode(data, tot_size, out);
+ delete[] data;
+}
+
+bool Decode(double* objective, SparseVector<double>* v, const char* in, size_t size) {
+ v->clear();
+ if (size % 4 != 0) {
+ cerr << "B64 error - line % 4 != 0\n";
+ return false;
+ }
+ const size_t decoded_size = size * 3 / 4 - sizeof(unsigned int);
+ const size_t buf_size = decoded_size + sizeof(unsigned int);
+ if (decoded_size < 6) { cerr << "SparseVector decoding error: too short!\n"; return false; }
+ char* data = new char[buf_size];
+ if (!b64decode(reinterpret_cast<const unsigned char*>(in), size, data, buf_size)) {
+ delete[] data;
+ return false;
+ }
+ size_t cur = 0;
+ *objective = *reinterpret_cast<double*>(data);
+ cur += sizeof(double);
+ const int num_feats = *reinterpret_cast<int*>(&data[cur]);
+ cur += sizeof(int);
+ int fc = 0;
+ while(fc < num_feats && cur < decoded_size) {
+ ++fc;
+ const int fname_len = data[cur++];
+ assert(fname_len > 0);
+ assert(fname_len < 256);
+ string fname(fname_len, '\0');
+ memcpy(&fname[0], &data[cur], fname_len);
+ cur += fname_len;
+ const double val = *reinterpret_cast<double*>(&data[cur]);
+ cur += sizeof(double);
+ int fid = FD::Convert(fname);
+ v->set_value(fid, val);
+ }
+ if(num_feats != fc) {
+ cerr << "Expected " << num_feats << " but only decoded " << fc << "!\n";
+ delete[] data;
+ return false;
+ }
+ if (*reinterpret_cast<unsigned int*>(&data[cur]) != 0xBAABABBAu) {
+ cerr << "SparseVector decodeding error : magic does not match!\n";
+ delete[] data;
+ return false;
+ }
+ delete[] data;
+ return true;
+}
+
+}