summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Denkowski <mdenkows@cs.cmu.edu>2013-09-17 12:46:02 -0700
committerMichael Denkowski <mdenkows@cs.cmu.edu>2013-09-17 12:46:02 -0700
commit769dfa1e69c22d4aea37840a955db7fd2cf3a4d7 (patch)
tree0f7cad7c0bb484af673e583c9906cfa6a2e09a10
parent895dfd64ea5599ab16981cbfb538ec5f4073c8c1 (diff)
Save/load weights in stream mira
-rw-r--r--training/mira/kbest_cut_mira.cc15
-rw-r--r--utils/weights.cc26
-rw-r--r--utils/weights.h5
3 files changed, 45 insertions, 1 deletions
diff --git a/training/mira/kbest_cut_mira.cc b/training/mira/kbest_cut_mira.cc
index a9a4aeb6..59fa860a 100644
--- a/training/mira/kbest_cut_mira.cc
+++ b/training/mira/kbest_cut_mira.cc
@@ -745,10 +745,23 @@ int main(int argc, char** argv) {
delim = buf.find(" ||| ");
ds->update(buf.substr(delim + 5));
buf = buf.substr(0, delim);
+ } else if (cmd == "WEIGHTS") {
+ // WEIGHTS ||| WRITE
+ if (buf == "WRITE") {
+ cout << Weights::GetString(dense_weights) << endl;
+ // WEIGHTS ||| f1=w1 f2=w2 ...
+ } else {
+ Weights::UpdateFromString(buf, dense_weights);
+ }
+ continue;
+ } else {
+ cerr << "Error: cannot parse command, skipping line:" << endl;
+ cerr << cmd << " ||| " << buf << endl;
+ continue;
}
- // TODO: additional commands
}
}
+ // Regular mode or LEARN line from stream mode
//TODO: allow batch updating
lambdas.init_vector(&dense_weights);
dense_w_local = dense_weights;
diff --git a/utils/weights.cc b/utils/weights.cc
index 575877b6..1284f686 100644
--- a/utils/weights.cc
+++ b/utils/weights.cc
@@ -4,6 +4,7 @@
#include "fdict.h"
#include "filelib.h"
+#include "stringlib.h"
#include "verbose.h"
using namespace std;
@@ -156,4 +157,29 @@ void Weights::ShowLargestFeatures(const vector<weight_t>& w) {
cerr << endl;
}
+string Weights::GetString(const vector<weight_t>& w,
+ bool hide_zero_value_features) {
+ ostringstream os;
+ os.precision(17);
+ int nf = FD::NumFeats();
+ for (unsigned i = 1; i < nf; i++) {
+ if (hide_zero_value_features && w[i] == 0.0) {
+ continue;
+ }
+ os << FD::Convert(i) << '=' << w[i];
+ if (i < nf - 1) {
+ os << ' ';
+ }
+ }
+ return os.str();
+}
+void Weights::UpdateFromString(string& w_string,
+ vector<weight_t>& w) {
+ vector<string> tok = SplitOnWhitespace(w_string);
+ for (vector<string>::iterator i = tok.begin(); i != tok.end(); i++) {
+ int delim = i->find('=');
+ int fid = FD::Convert(i->substr(0, delim));
+ w[fid] = strtod(i->substr(delim + 1).c_str(), NULL);
+ }
+}
diff --git a/utils/weights.h b/utils/weights.h
index 30f71db0..920fdd75 100644
--- a/utils/weights.h
+++ b/utils/weights.h
@@ -23,6 +23,11 @@ class Weights {
static void SanityCheck(const std::vector<weight_t>& w);
// write weights with largest magnitude to cerr
static void ShowLargestFeatures(const std::vector<weight_t>& w);
+ static std::string GetString(const std::vector<weight_t>& w,
+ bool hide_zero_value_features = true);
+ // Assumes weights are already initialized for now
+ static void UpdateFromString(std::string& w_string,
+ std::vector<weight_t>& w);
private:
Weights();
};