diff options
-rw-r--r-- | training/mira/kbest_cut_mira.cc | 15 | ||||
-rw-r--r-- | utils/weights.cc | 26 | ||||
-rw-r--r-- | utils/weights.h | 5 |
3 files changed, 45 insertions, 1 deletions
diff --git a/training/mira/kbest_cut_mira.cc b/training/mira/kbest_cut_mira.cc index a9a4aeb6..59fa860a 100644 --- a/training/mira/kbest_cut_mira.cc +++ b/training/mira/kbest_cut_mira.cc @@ -745,10 +745,23 @@ int main(int argc, char** argv) { delim = buf.find(" ||| "); ds->update(buf.substr(delim + 5)); buf = buf.substr(0, delim); + } else if (cmd == "WEIGHTS") { + // WEIGHTS ||| WRITE + if (buf == "WRITE") { + cout << Weights::GetString(dense_weights) << endl; + // WEIGHTS ||| f1=w1 f2=w2 ... + } else { + Weights::UpdateFromString(buf, dense_weights); + } + continue; + } else { + cerr << "Error: cannot parse command, skipping line:" << endl; + cerr << cmd << " ||| " << buf << endl; + continue; } - // TODO: additional commands } } + // Regular mode or LEARN line from stream mode //TODO: allow batch updating lambdas.init_vector(&dense_weights); dense_w_local = dense_weights; diff --git a/utils/weights.cc b/utils/weights.cc index 575877b6..1284f686 100644 --- a/utils/weights.cc +++ b/utils/weights.cc @@ -4,6 +4,7 @@ #include "fdict.h" #include "filelib.h" +#include "stringlib.h" #include "verbose.h" using namespace std; @@ -156,4 +157,29 @@ void Weights::ShowLargestFeatures(const vector<weight_t>& w) { cerr << endl; } +string Weights::GetString(const vector<weight_t>& w, + bool hide_zero_value_features) { + ostringstream os; + os.precision(17); + int nf = FD::NumFeats(); + for (unsigned i = 1; i < nf; i++) { + if (hide_zero_value_features && w[i] == 0.0) { + continue; + } + os << FD::Convert(i) << '=' << w[i]; + if (i < nf - 1) { + os << ' '; + } + } + return os.str(); +} +void Weights::UpdateFromString(string& w_string, + vector<weight_t>& w) { + vector<string> tok = SplitOnWhitespace(w_string); + for (vector<string>::iterator i = tok.begin(); i != tok.end(); i++) { + int delim = i->find('='); + int fid = FD::Convert(i->substr(0, delim)); + w[fid] = strtod(i->substr(delim + 1).c_str(), NULL); + } +} diff --git a/utils/weights.h b/utils/weights.h index 30f71db0..920fdd75 100644 --- a/utils/weights.h +++ b/utils/weights.h @@ -23,6 +23,11 @@ class Weights { static void SanityCheck(const std::vector<weight_t>& w); // write weights with largest magnitude to cerr static void ShowLargestFeatures(const std::vector<weight_t>& w); + static std::string GetString(const std::vector<weight_t>& w, + bool hide_zero_value_features = true); + // Assumes weights are already initialized for now + static void UpdateFromString(std::string& w_string, + std::vector<weight_t>& w); private: Weights(); }; |