author    redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>  2010-09-28 17:06:08 +0000
committer redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>  2010-09-28 17:06:08 +0000
commit    3db6b004ae1e2319f52862d428c20be5a1538993 (patch)
tree      b35c697027afd92324d8d9a63c8e6b27c32d2339
parent    521dc2fdbf7eee7d6a86410f490ba7a76691590b (diff)
use boost mpi, fix L1 stochastic optimizer
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@659 ec762483-ff6d-05da-a07a-a48fb63a330f
-rw-r--r--  configure.ac                      10
-rw-r--r--  decoder/apply_models.cc            2
-rw-r--r--  decoder/decoder.cc                 5
-rw-r--r--  decoder/decoder.h                  1
-rw-r--r--  decoder/ff_wordalign.cc           14
-rw-r--r--  decoder/viterbi.cc                 2
-rw-r--r--  training/Makefile.am               4
-rw-r--r--  training/mpi_online_optimize.cc  164
-rw-r--r--  training/online_optimizer.h       50
-rw-r--r--  utils/sparse_vector.h             33
-rw-r--r--  utils/weights.cc                   7
-rw-r--r--  utils/weights.h                    1
12 files changed, 177 insertions, 116 deletions
diff --git a/configure.ac b/configure.ac
index ab8c93a1..d143f33d 100644
--- a/configure.ac
+++ b/configure.ac
@@ -8,11 +8,8 @@ AC_PROG_CXX
AC_LANG_CPLUSPLUS
BOOST_REQUIRE
BOOST_PROGRAM_OPTIONS
-# BOOST_REGEX
BOOST_THREADS
CPPFLAGS="$CPPFLAGS $BOOST_CPPFLAGS"
-#LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS $BOOST_REGEX_LDFLAGS $BOOST_THREAD_LDFLAGS"
-#LIBS="$LIBS $BOOST_PROGRAM_OPTIONS_LIBS $BOOST_REGEX_LIBS $BOOST_THREAD_LIBS"
LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS $BOOST_THREAD_LDFLAGS"
LIBS="$LIBS $BOOST_PROGRAM_OPTIONS_LIBS $BOOST_THREAD_LIBS"
@@ -31,6 +28,13 @@ AC_ARG_ENABLE(mpi,
])
AM_CONDITIONAL([MPI], [test "x$mpi" = xyes])
+if test "x$mpi" = xyes
+then
+ BOOST_SERIALIZATION
+ # TODO BOOST_MPI needs to be implemented
+ LIBS="$LIBS -lboost_mpi $BOOST_SERIALIZATION_LIBS -lmpi++ -lmpi"
+fi
+
AM_CONDITIONAL([SRI_LM], false)
AC_ARG_WITH(srilm,
[AC_HELP_STRING([--with-srilm=PATH], [(optional) path to SRI's LM toolkit])],
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc
index aec31a4f..18460950 100644
--- a/decoder/apply_models.cc
+++ b/decoder/apply_models.cc
@@ -319,7 +319,7 @@ struct NoPruningRescorer {
in(i),
out(*o),
nodemap(i.nodes_.size()) {
- cerr << " Rescoring forest (full intersection)\n";
+ if (!SILENT) cerr << " Rescoring forest (full intersection)\n";
node_states_.reserve(kRESERVE_NUM_NODES);
}
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 1a233fc5..537fdffa 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -147,6 +147,7 @@ struct DecoderImpl {
void SetWeights(const vector<double>& weights) {
feature_weights = weights;
}
+ void SetId(int next_sent_id) { sent_id = next_sent_id - 1; }
void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_features,WeightVector *weights=0,bool show_deriv=false) {
cerr << viterbi_stats(forest,name,true,show_tree,show_deriv);
@@ -622,6 +623,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
Decoder::Decoder(istream* cfg) { pimpl_.reset(new DecoderImpl(conf,0,0,cfg)); }
Decoder::Decoder(int argc, char** argv) { pimpl_.reset(new DecoderImpl(conf,argc, argv, 0)); }
Decoder::~Decoder() {}
+void Decoder::SetId(int next_sent_id) { pimpl_->SetId(next_sent_id); }
bool Decoder::Decode(const string& input, DecoderObserver* o) {
bool del = false;
if (!o) { o = new DecoderObserver; del = true; }
@@ -818,8 +820,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
HypergraphIO::WriteAsCFG(forest);
if (has_ref) {
if (HG::Intersect(ref, &forest)) {
- if (!SILENT) cerr << " Constr. forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
- if (!SILENT) cerr << " Constr. forest (paths): " << forest.NumberOfPaths() << endl;
+ if (!SILENT) forest_stats(forest," Constr. forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation);
if (crf_uniform_empirical) {
if (!SILENT) cerr << " USING UNIFORM WEIGHTS\n";
for (int i = 0; i < forest.edges_.size(); ++i)
diff --git a/decoder/decoder.h b/decoder/decoder.h
index 2a1a43ce..abaf3740 100644
--- a/decoder/decoder.h
+++ b/decoder/decoder.h
@@ -26,6 +26,7 @@ struct Decoder {
Decoder(std::istream* config_file);
bool Decode(const std::string& input, DecoderObserver* observer = NULL);
void SetWeights(const std::vector<double>& weights);
+ void SetId(int id);
~Decoder();
const boost::program_options::variables_map& GetConf() const { return conf; }
private:
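
The new SetId hook exists so a driver can tell the decoder which corpus sentence it is about to decode, letting per-sentence data (such as RelativeSentencePosition's per-sentence fclass table below) be indexed by the right sentence id. A minimal usage sketch, assuming only the decoder.h interface above; the helper name and corpus layout are illustrative, not part of the patch:

#include <string>
#include <vector>
#include "decoder.h"   // Decoder, DecoderObserver (include path assumed)

// Illustrative helper: decode one training sentence by corpus index.
void decode_example(Decoder& decoder, const std::vector<std::string>& corpus,
                    int ex_num, DecoderObserver* observer) {
  decoder.SetId(ex_num);                      // sentence id seen by SentenceMetadata
  decoder.Decode(corpus[ex_num], observer);   // decode exactly that sentence
}

The same SetId-then-Decode sequence appears in the rewritten training loop in mpi_online_optimize.cc further down.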
diff --git a/decoder/ff_wordalign.cc b/decoder/ff_wordalign.cc
index 087bff0c..a1968159 100644
--- a/decoder/ff_wordalign.cc
+++ b/decoder/ff_wordalign.cc
@@ -72,11 +72,11 @@ RelativeSentencePosition::RelativeSentencePosition(const string& param) :
pos_.push_back(v);
for (int i = 0; i < v.size(); ++i)
classes.insert(v[i]);
- for (set<WordID>::iterator i = classes.begin(); i != classes.end(); ++i) {
- ostringstream os;
- os << "RelPos_FC:" << TD::Convert(*i);
- fids_[*i] = FD::Convert(os.str());
- }
+ }
+ for (set<WordID>::iterator i = classes.begin(); i != classes.end(); ++i) {
+ ostringstream os;
+ os << "RelPos_FC:" << TD::Convert(*i);
+ fids_[*i] = FD::Convert(os.str());
}
} else {
condition_on_fclass_ = false;
@@ -104,7 +104,9 @@ void RelativeSentencePosition::TraversalFeaturesImpl(const SentenceMetadata& sme
if (condition_on_fclass_) {
assert(smeta.GetSentenceID() < pos_.size());
const WordID cur_fclass = pos_[smeta.GetSentenceID()][edge.i_];
- const int fid = fids_.find(cur_fclass)->second;
+ std::map<WordID, int>::const_iterator fidit = fids_.find(cur_fclass);
+ assert(fidit != fids_.end());
+ const int fid = fidit->second;
features->set_value(fid, val);
}
// cerr << f_len_ << " " << e_len_ << " [" << edge.i_ << "," << edge.j_ << "|" << edge.prev_i_ << "," << edge.prev_j_ << "]\t" << edge.rule_->AsString() << "\tVAL=" << val << endl;
diff --git a/decoder/viterbi.cc b/decoder/viterbi.cc
index fac9dd70..a8192a9f 100644
--- a/decoder/viterbi.cc
+++ b/decoder/viterbi.cc
@@ -16,7 +16,7 @@ std::string viterbi_stats(Hypergraph const& hg, std::string const& name, bool es
if (estring) {
vector<WordID> trans;
const prob_t vs = ViterbiESentence(hg, &trans);
- o<<name<<" Viterbi logp: "<<log(vs)<<endl;
+ o<<name<<" Viterbi logp: "<<log(vs)<< " (norm=" << log(vs)/trans.size() << ")" << endl;
o<<name<<" Viterbi: "<<TD::GetString(trans)<<endl;
}
if (etree) {
diff --git a/training/Makefile.am b/training/Makefile.am
index ea637d9e..2679adea 100644
--- a/training/Makefile.am
+++ b/training/Makefile.am
@@ -22,10 +22,10 @@ bin_PROGRAMS += mpi_batch_optimize \
mpi_online_optimize
mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc optimize.cc
-mpi_batch_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz -lmpi++ -lmpi
+mpi_batch_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz
mpi_online_optimize_SOURCES = mpi_online_optimize.cc online_optimizer.cc
-mpi_online_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz -lmpi++ -lmpi
+mpi_online_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz
endif
online_train_SOURCES = online_train.cc online_optimizer.cc
diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc
index 62821aa3..6f5988a4 100644
--- a/training/mpi_online_optimize.cc
+++ b/training/mpi_online_optimize.cc
@@ -6,6 +6,7 @@
#include <cmath>
#include <mpi.h>
+#include <boost/mpi.hpp>
#include <boost/shared_ptr.hpp>
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
@@ -66,10 +67,11 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("minibatch_size_per_proc,s", po::value<unsigned>()->default_value(5), "Number of training instances evaluated per processor in each minibatch")
("freeze_feature_set,Z", "The feature set specified in the initial weights file is frozen throughout the duration of training")
("optimization_method,m", po::value<string>()->default_value("sgd"), "Optimization method (sgd)")
+ ("fully_random,r", "Fully random draws from the training corpus")
+ ("random_seed,S", po::value<uint32_t>(), "Random seed (if not specified, /dev/random will be used)")
("eta_0,e", po::value<double>()->default_value(0.2), "Initial learning rate for SGD (eta_0)")
("L1,1","Use L1 regularization")
- ("gaussian_prior,g","Use a Gaussian prior on the weights")
- ("sigma_squared", po::value<double>()->default_value(1.0), "Sigma squared term for spherical Gaussian prior");
+ ("regularization_strength,C", po::value<double>()->default_value(1.0), "Regularization strength (C)");
po::options_description clo("Command line options");
clo.add_options()
("config", po::value<string>(), "Configuration file")
@@ -165,7 +167,7 @@ struct TrainingObserver : public DecoderObserver {
}
assert(!isnan(log_ref_z));
ref_exp -= cur_model_exp;
- acc_grad -= ref_exp;
+ acc_grad += ref_exp;
acc_obj += (cur_obj - log_ref_z);
}
@@ -176,6 +178,12 @@ struct TrainingObserver : public DecoderObserver {
}
}
+ void GetGradient(SparseVector<double>* g) const {
+ g->clear();
+ for (SparseVector<prob_t>::const_iterator it = acc_grad.begin(); it != acc_grad.end(); ++it)
+ g->set_value(it->first, it->second);
+ }
+
int total_complete;
SparseVector<prob_t> cur_model_exp;
SparseVector<prob_t> acc_grad;
@@ -193,10 +201,20 @@ inline void Shuffle(vector<T>* c, MT19937* rng) {
}
}
+namespace mpi = boost::mpi;
+
+namespace boost { namespace mpi {
+ template<>
+ struct is_commutative<std::plus<SparseVector<double> >, SparseVector<double> >
+ : mpl::true_ { };
+} } // end namespace boost::mpi
+
+
int main(int argc, char** argv) {
- MPI::Init(argc, argv);
- const int size = MPI::COMM_WORLD.Get_size();
- const int rank = MPI::COMM_WORLD.Get_rank();
+ mpi::environment env(argc, argv);
+ mpi::communicator world;
+ const int size = world.size();
+ const int rank = world.rank();
SetSilent(true); // turn off verbose decoder output
cerr << "MPI: I am " << rank << '/' << size << endl;
register_feature_functions();
@@ -219,7 +237,7 @@ int main(int argc, char** argv) {
Decoder decoder(ini_rf.stream());
if (decoder.GetConf()["input"].as<string>() != "-") {
cerr << "cdec.ini must not set an input file\n";
- MPI::COMM_WORLD.Abort(1);
+ abort();
}
vector<string> corpus;
@@ -228,105 +246,87 @@ int main(int argc, char** argv) {
std::tr1::shared_ptr<OnlineOptimizer> o;
std::tr1::shared_ptr<LearningRateSchedule> lr;
- vector<int> order;
+ vector<int> order(corpus.size());
+
+ const bool fully_random = conf.count("fully_random");
+ const unsigned size_per_proc = conf["minibatch_size_per_proc"].as<unsigned>();
+ const unsigned batch_size = size_per_proc * size;
if (rank == 0) {
+ cerr << "Corpus: " << corpus.size() << " batch size: " << batch_size << endl;
+ if (batch_size > corpus.size()) {
+ cerr << " Reduce minibatch_size_per_proc!";
+ abort();
+ }
+
// TODO config
- lr.reset(new ExponentialDecayLearningRate(corpus.size(), conf["eta_0"].as<double>()));
+ lr.reset(new ExponentialDecayLearningRate(batch_size, conf["eta_0"].as<double>()));
const string omethod = conf["optimization_method"].as<string>();
if (omethod == "sgd") {
- const double C = 1.0;
+ const double C = conf["regularization_strength"].as<double>();
o.reset(new CumulativeL1OnlineOptimizer(lr, corpus.size(), C));
} else {
assert(!"fail");
}
- // randomize corpus
- rng = new MT19937;
- order.resize(corpus.size());
for (unsigned i = 0; i < order.size(); ++i) order[i]=i;
- Shuffle(&order, rng);
+ // randomize corpus
+ if (conf.count("random_seed"))
+ rng = new MT19937(conf["random_seed"].as<uint32_t>());
+ else
+ rng = new MT19937;
}
+ SparseVector<double> x;
+ int miter = corpus.size(); // hack to cause initial broadcast of order info
+ TrainingObserver observer;
double objective = 0;
- vector<double> lambdas;
- weights.InitVector(&lambdas);
bool converged = false;
- const unsigned size_per_proc = conf["minibatch_size_per_proc"].as<unsigned>();
- for (int i = 0; i < size_per_proc; ++i)
- cerr << "i=" << i << ": " << order[i] << endl;
- abort();
- TrainingObserver observer;
+
+ int iter = -1;
+ vector<double> lambdas;
while (!converged) {
+ weights.InitFromVector(x);
+ weights.InitVector(&lambdas);
+ ++miter; ++iter;
observer.Reset();
- if (rank == 0) {
- cerr << "Starting decoding... (~" << corpus.size() << " sentences / proc)\n";
- }
decoder.SetWeights(lambdas);
-#if 0
- for (int i = 0; i < corpus.size(); ++i)
- decoder.Decode(corpus[i], &observer);
-
- fill(gradient.begin(), gradient.end(), 0);
- fill(rcv_grad.begin(), rcv_grad.end(), 0);
- observer.SetLocalGradientAndObjective(&gradient, &objective);
-
- double to = 0;
- MPI::COMM_WORLD.Reduce(const_cast<double*>(&gradient.data()[0]), &rcv_grad[0], num_feats, MPI::DOUBLE, MPI::SUM, 0);
- MPI::COMM_WORLD.Reduce(&objective, &to, 1, MPI::DOUBLE, MPI::SUM, 0);
- swap(gradient, rcv_grad);
- objective = to;
-
- if (rank == 0) { // run optimizer only on rank=0 node
- if (gaussian_prior) {
- const double sigsq = conf["sigma_squared"].as<double>();
- double norm = 0;
- for (int k = 1; k < lambdas.size(); ++k) {
- const double& lambda_k = lambdas[k];
- if (lambda_k) {
- const double param = (lambda_k - means[k]);
- norm += param * param;
- gradient[k] += param / sigsq;
- }
- }
- const double reg = norm / (2.0 * sigsq);
- cerr << "REGULARIZATION TERM: " << reg << endl;
- objective += reg;
- }
- cerr << "EVALUATION #" << o->EvaluationCount() << " OBJECTIVE: " << objective << endl;
- double gnorm = 0;
- for (int i = 0; i < gradient.size(); ++i)
- gnorm += gradient[i] * gradient[i];
- cerr << " GNORM=" << sqrt(gnorm) << endl;
- vector<double> old = lambdas;
- int c = 0;
- while (old == lambdas) {
- ++c;
- if (c > 1) { cerr << "Same lambdas, repeating optimization\n"; }
- o->Optimize(objective, gradient, &lambdas);
- assert(c < 5);
- }
- old.clear();
+ if (rank == 0) {
SanityCheck(lambdas);
ShowLargestFeatures(lambdas);
- weights.InitFromVector(lambdas);
-
- converged = o->HasConverged();
- if (converged) { cerr << "OPTIMIZER REPORTS CONVERGENCE!\n"; }
-
string fname = "weights.cur.gz";
if (converged) { fname = "weights.final.gz"; }
ostringstream vv;
- vv << "Objective = " << objective << " (eval count=" << o->EvaluationCount() << ")";
+ vv << "Objective = " << objective; // << " (eval count=" << o->EvaluationCount() << ")";
const string svv = vv.str();
weights.WriteToFile(fname, true, &svv);
- } // rank == 0
- int cint = converged;
- MPI::COMM_WORLD.Bcast(const_cast<double*>(&lambdas.data()[0]), num_feats, MPI::DOUBLE, 0);
- MPI::COMM_WORLD.Bcast(&cint, 1, MPI::INT, 0);
- MPI::COMM_WORLD.Barrier();
- converged = cint;
-#endif
+ }
+
+ if (fully_random || size * size_per_proc * miter > corpus.size()) {
+ if (rank == 0)
+ Shuffle(&order, rng);
+ miter = 0;
+ broadcast(world, order, 0);
+ }
+ if (rank == 0)
+ cerr << "Starting decoding. minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << x.size() << " training data proc. = " << (iter * batch_size / static_cast<double>(corpus.size())) << " eta=" << lr->eta(iter) << endl;
+
+ const int beg = size * miter * size_per_proc + rank * size_per_proc;
+ const int end = beg + size_per_proc;
+ for (int i = beg; i < end; ++i) {
+ int ex_num = order[i % order.size()];
+ if (rank ==0 && size < 3) cerr << rank << ": ex_num=" << ex_num << endl;
+ decoder.SetId(ex_num);
+ decoder.Decode(corpus[ex_num], &observer);
+ }
+ SparseVector<double> local_grad, g;
+ observer.GetGradient(&local_grad);
+ reduce(world, local_grad, g, std::plus<SparseVector<double> >(), 0);
+ if (rank == 0) {
+ g /= batch_size;
+ o->UpdateWeights(g, FD::NumFeats(), &x);
+ }
+ broadcast(world, x, 0);
+ world.barrier();
}
- MPI::Finalize();
return 0;
}
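
The rewritten main loop replaces raw MPI_Reduce/MPI_Bcast calls with Boost.MPI's reduce and broadcast, which serialize SparseVector<double> via Boost.Serialization; the is_commutative specialization is a hint that the reduction may be reordered freely. A minimal, self-contained sketch of that pattern, assuming only Boost.MPI and an MPI runtime (the variables are stand-ins, not the patch's own):

#include <functional>
#include <iostream>
#include <vector>
#include <boost/mpi.hpp>
#include <boost/serialization/vector.hpp>

namespace mpi = boost::mpi;

int main(int argc, char** argv) {
  mpi::environment env(argc, argv);   // replaces MPI::Init / MPI::Finalize
  mpi::communicator world;

  // stand-in for each rank's local gradient/objective contribution
  double local_obj = world.rank() + 1.0;
  double total_obj = 0.0;
  reduce(world, local_obj, total_obj, std::plus<double>(), 0);  // rank 0 gets the sum

  // stand-in for the weight vector x that rank 0 updates and rebroadcasts
  std::vector<double> x;
  if (world.rank() == 0) {
    std::cerr << "reduced objective: " << total_obj << std::endl;
    x.assign(3, total_obj / world.size());
  }
  broadcast(world, x, 0);   // every rank now sees rank 0's state
  world.barrier();
  return 0;
}

Built with -lboost_mpi -lboost_serialization and launched under mpirun, every rank ends up holding the state computed on rank 0, which is how the weight vector x is kept in sync above.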
diff --git a/training/online_optimizer.h b/training/online_optimizer.h
index d2718f93..963c0380 100644
--- a/training/online_optimizer.h
+++ b/training/online_optimizer.h
@@ -8,16 +8,22 @@
struct LearningRateSchedule {
virtual ~LearningRateSchedule();
- // returns the learning rate for iteration k
+ // returns the learning rate for the kth iteration
virtual double eta(int k) const = 0;
};
+// TODO in the Tsuruoka et al. (ACL 2009) paper, they use N
+// to mean the batch size in most places, but it doesn't completely
+// make sense to me in the learning rate schedules-- this needs
+// to be worked out to make sure they didn't mean corpus size
+// in some places and batch size in others (since in the paper they
+// only ever work with batch sizes of 1)
struct StandardLearningRate : public LearningRateSchedule {
StandardLearningRate(
- size_t training_instances,
+ size_t batch_size, // batch size, not corpus size!
double eta_0 = 0.2) :
eta_0_(eta_0),
- N_(static_cast<double>(training_instances)) {}
+ N_(static_cast<double>(batch_size)) {}
virtual double eta(int k) const;
@@ -28,11 +34,11 @@ struct StandardLearningRate : public LearningRateSchedule {
struct ExponentialDecayLearningRate : public LearningRateSchedule {
ExponentialDecayLearningRate(
- size_t training_instances,
+ size_t batch_size, // batch size, not corpus size!
double eta_0 = 0.2,
double alpha = 0.85 // recommended by Tsuruoka et al. (ACL 2009)
) : eta_0_(eta_0),
- N_(static_cast<double>(training_instances)),
+ N_(static_cast<double>(batch_size)),
alpha_(alpha) {
assert(alpha > 0);
assert(alpha < 1.0);
@@ -50,17 +56,17 @@ class OnlineOptimizer {
public:
virtual ~OnlineOptimizer();
OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s,
- size_t training_instances)
- : N_(training_instances),schedule_(s),k_() {}
- void UpdateWeights(const SparseVector<double>& approx_g, SparseVector<double>* weights) {
+ size_t batch_size)
+ : N_(batch_size),schedule_(s),k_() {}
+ void UpdateWeights(const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) {
++k_;
const double eta = schedule_->eta(k_);
- UpdateWeightsImpl(eta, approx_g, weights);
+ UpdateWeightsImpl(eta, approx_g, max_feat, weights);
}
protected:
- virtual void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, SparseVector<double>* weights) = 0;
- const size_t N_; // number of training instances
+ virtual void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) = 0;
+ const size_t N_; // number of training instances per batch
private:
std::tr1::shared_ptr<LearningRateSchedule> schedule_;
@@ -74,11 +80,11 @@ class CumulativeL1OnlineOptimizer : public OnlineOptimizer {
OnlineOptimizer(s, training_instances), C_(C), u_() {}
protected:
- void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, SparseVector<double>* weights) {
+ void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) {
u_ += eta * C_ / N_;
(*weights) += eta * approx_g;
- for (SparseVector<double>::const_iterator it = approx_g.begin(); it != approx_g.end(); ++it)
- ApplyPenalty(it->first, weights);
+ for (int i = 1; i < max_feat; ++i)
+ ApplyPenalty(i, weights);
}
private:
@@ -86,13 +92,19 @@ class CumulativeL1OnlineOptimizer : public OnlineOptimizer {
const double z = w->value(i);
double w_i = z;
double q_i = q_.value(i);
- if (w_i > 0)
+ if (w_i > 0.0)
w_i = std::max(0.0, w_i - (u_ + q_i));
- else
- w_i = std::max(0.0, w_i + (u_ - q_i));
+ else if (w_i < 0.0)
+ w_i = std::min(0.0, w_i + (u_ - q_i));
q_i += w_i - z;
- q_.set_value(i, q_i);
- w->set_value(i, w_i);
+ if (q_i == 0.0)
+ q_.erase(i);
+ else
+ q_.set_value(i, q_i);
+ if (w_i == 0.0)
+ w->erase(i);
+ else
+ w->set_value(i, w_i);
}
const double C_; // regularization strength
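
This class is the "fix L1 stochastic optimizer" half of the commit: it applies the cumulative L1 penalty of Tsuruoka et al. (ACL 2009), where u_ tracks the total penalty each weight could have absorbed so far, q_ tracks what each feature has actually absorbed, and weights are clipped at zero rather than pushed across it. A compilable plain-vector sketch of that update, with all names illustrative rather than taken from the patch:

#include <algorithm>
#include <vector>

// Illustrative stand-in for CumulativeL1OnlineOptimizer::ApplyPenalty.
void apply_penalty(int i, double u, std::vector<double>* w, std::vector<double>* q) {
  const double z = (*w)[i];
  double w_i = z;
  const double q_i = (*q)[i];
  if (w_i > 0.0)
    w_i = std::max(0.0, w_i - (u + q_i));   // positive weights shrink, clipped at 0
  else if (w_i < 0.0)
    w_i = std::min(0.0, w_i + (u - q_i));   // negative weights shrink, clipped at 0
  (*q)[i] = q_i + (w_i - z);                // record the penalty actually applied
  (*w)[i] = w_i;
}

int main() {
  std::vector<double> w(3, 0.0), q(3, 0.0);
  w[1] = 0.4;  w[2] = -0.05;
  apply_penalty(1, 0.1, &w, &q);   // large weight moves toward zero
  apply_penalty(2, 0.1, &w, &q);   // small weight is clipped to exactly zero
  return (w[2] == 0.0) ? 0 : 1;
}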
diff --git a/utils/sparse_vector.h b/utils/sparse_vector.h
index 5d0dac27..c5e18b96 100644
--- a/utils/sparse_vector.h
+++ b/utils/sparse_vector.h
@@ -56,6 +56,10 @@ TODO: specialize for int value types, where it probably makes sense to check if
#include "small_vector.h"
#include "string_to.h"
+#if HAVE_BOOST_ARCHIVE_TEXT_OARCHIVE_HPP
+#include <boost/serialization/map.hpp>
+#endif
+
template <class T>
inline T & extend_vector(std::vector<T> &v,int i) {
if (i>=v.size())
@@ -510,6 +514,35 @@ public:
private:
MapType values_;
+
+#if HAVE_BOOST_ARCHIVE_TEXT_OARCHIVE_HPP
+ friend class boost::serialization::access;
+ template<class Archive>
+ void save(Archive & ar, const unsigned int version) const {
+ (void) version;
+ int eff_size = values_.size();
+ const_iterator it = this->begin();
+ if (values_.find(0) != values_.end()) { ++it; --eff_size; }
+ ar & eff_size;
+ while (it != this->end()) {
+ const std::pair<std::string, T> wire_pair(FD::Convert(it->first), it->second);
+ ar & wire_pair;
+ ++it;
+ }
+ }
+ template<class Archive>
+ void load(Archive & ar, const unsigned int version) {
+ (void) version;
+ this->clear();
+ int sz; ar & sz;
+ for (int i = 0; i < sz; ++i) {
+ std::pair<std::string, T> wire_pair;
+ ar & wire_pair;
+ this->set_value(FD::Convert(wire_pair.first), wire_pair.second);
+ }
+ }
+ BOOST_SERIALIZATION_SPLIT_MEMBER()
+#endif
};
template <class T>
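
Because SparseVector keys are process-local feature ids, the save/load pair above writes feature-name strings to the archive and converts them back to local ids on the receiving side, using Boost.Serialization's split-member idiom. A standalone sketch of that idiom with a plain std::map; the class and field names are invented for illustration:

#include <map>
#include <sstream>
#include <string>
#include <boost/archive/text_iarchive.hpp>
#include <boost/archive/text_oarchive.hpp>
#include <boost/serialization/split_member.hpp>
#include <boost/serialization/string.hpp>
#include <boost/serialization/utility.hpp>

struct NamedCounts {
  std::map<std::string, double> values;

  template<class Archive>
  void save(Archive& ar, const unsigned int) const {
    int sz = values.size();
    ar & sz;
    for (std::map<std::string, double>::const_iterator it = values.begin();
         it != values.end(); ++it) {
      const std::pair<std::string, double> wire(it->first, it->second);  // name on the wire
      ar & wire;
    }
  }
  template<class Archive>
  void load(Archive& ar, const unsigned int) {
    values.clear();
    int sz; ar & sz;
    for (int i = 0; i < sz; ++i) {
      std::pair<std::string, double> wire;
      ar & wire;
      values[wire.first] = wire.second;   // map the name back to a local entry
    }
  }
  BOOST_SERIALIZATION_SPLIT_MEMBER()
};

int main() {
  NamedCounts a;
  a.values["LanguageModel"] = 1.5;
  std::ostringstream os;
  { boost::archive::text_oarchive oa(os); oa << static_cast<const NamedCounts&>(a); }
  NamedCounts b;
  std::istringstream is(os.str());
  { boost::archive::text_iarchive ia(is); ia >> b; }
  return (b.values["LanguageModel"] == 1.5) ? 0 : 1;
}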
diff --git a/utils/weights.cc b/utils/weights.cc
index ea8bd816..53089f89 100644
--- a/utils/weights.cc
+++ b/utils/weights.cc
@@ -81,3 +81,10 @@ void Weights::InitFromVector(const std::vector<double>& w) {
cerr << "WARNING: initializing weight vector has more features than the global feature dictionary!\n";
wv_.resize(FD::NumFeats(), 0);
}
+
+void Weights::InitFromVector(const SparseVector<double>& w) {
+ wv_.clear();
+ wv_.resize(FD::NumFeats(), 0.0);
+ for (int i = 1; i < FD::NumFeats(); ++i)
+ wv_[i] = w.value(i);
+}
diff --git a/utils/weights.h b/utils/weights.h
index 1849f959..cc20283c 100644
--- a/utils/weights.h
+++ b/utils/weights.h
@@ -14,6 +14,7 @@ class Weights {
void InitVector(std::vector<double>* w) const;
void InitSparseVector(SparseVector<double>* w) const;
void InitFromVector(const std::vector<double>& w);
+ void InitFromVector(const SparseVector<double>& w);
private:
std::vector<double> wv_;
};