diff options
-rw-r--r-- | configure.ac | 10 | ||||
-rw-r--r-- | decoder/apply_models.cc | 2 | ||||
-rw-r--r-- | decoder/decoder.cc | 5 | ||||
-rw-r--r-- | decoder/decoder.h | 1 | ||||
-rw-r--r-- | decoder/ff_wordalign.cc | 14 | ||||
-rw-r--r-- | decoder/viterbi.cc | 2 | ||||
-rw-r--r-- | training/Makefile.am | 4 | ||||
-rw-r--r-- | training/mpi_online_optimize.cc | 164 | ||||
-rw-r--r-- | training/online_optimizer.h | 50 | ||||
-rw-r--r-- | utils/sparse_vector.h | 33 | ||||
-rw-r--r-- | utils/weights.cc | 7 | ||||
-rw-r--r-- | utils/weights.h | 1 |
12 files changed, 177 insertions, 116 deletions
diff --git a/configure.ac b/configure.ac index ab8c93a1..d143f33d 100644 --- a/configure.ac +++ b/configure.ac @@ -8,11 +8,8 @@ AC_PROG_CXX AC_LANG_CPLUSPLUS BOOST_REQUIRE BOOST_PROGRAM_OPTIONS -# BOOST_REGEX BOOST_THREADS CPPFLAGS="$CPPFLAGS $BOOST_CPPFLAGS" -#LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS $BOOST_REGEX_LDFLAGS $BOOST_THREAD_LDFLAGS" -#LIBS="$LIBS $BOOST_PROGRAM_OPTIONS_LIBS $BOOST_REGEX_LIBS $BOOST_THREAD_LIBS" LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS $BOOST_THREAD_LDFLAGS" LIBS="$LIBS $BOOST_PROGRAM_OPTIONS_LIBS $BOOST_THREAD_LIBS" @@ -31,6 +28,13 @@ AC_ARG_ENABLE(mpi, ]) AM_CONDITIONAL([MPI], [test "x$mpi" = xyes]) +if test "x$mpi" = xyes +then + BOOST_SERIALIZATION + # TODO BOOST_MPI needs to be implemented + LIBS="$LIBS -lboost_mpi $BOOST_SERIALIZATION_LIBS -lmpi++ -lmpi" +fi + AM_CONDITIONAL([SRI_LM], false) AC_ARG_WITH(srilm, [AC_HELP_STRING([--with-srilm=PATH], [(optional) path to SRI's LM toolkit])], diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc index aec31a4f..18460950 100644 --- a/decoder/apply_models.cc +++ b/decoder/apply_models.cc @@ -319,7 +319,7 @@ struct NoPruningRescorer { in(i), out(*o), nodemap(i.nodes_.size()) { - cerr << " Rescoring forest (full intersection)\n"; + if (!SILENT) cerr << " Rescoring forest (full intersection)\n"; node_states_.reserve(kRESERVE_NUM_NODES); } diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 1a233fc5..537fdffa 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -147,6 +147,7 @@ struct DecoderImpl { void SetWeights(const vector<double>& weights) { feature_weights = weights; } + void SetId(int next_sent_id) { sent_id = next_sent_id - 1; } void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_features,WeightVector *weights=0,bool show_deriv=false) { cerr << viterbi_stats(forest,name,true,show_tree,show_deriv); @@ -622,6 +623,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream Decoder::Decoder(istream* cfg) { pimpl_.reset(new DecoderImpl(conf,0,0,cfg)); } Decoder::Decoder(int argc, char** argv) { pimpl_.reset(new DecoderImpl(conf,argc, argv, 0)); } Decoder::~Decoder() {} +void Decoder::SetId(int next_sent_id) { pimpl_->SetId(next_sent_id); } bool Decoder::Decode(const string& input, DecoderObserver* o) { bool del = false; if (!o) { o = new DecoderObserver; del = true; } @@ -818,8 +820,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { HypergraphIO::WriteAsCFG(forest); if (has_ref) { if (HG::Intersect(ref, &forest)) { - if (!SILENT) cerr << " Constr. forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; - if (!SILENT) cerr << " Constr. forest (paths): " << forest.NumberOfPaths() << endl; + if (!SILENT) forest_stats(forest," Constr. forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation); if (crf_uniform_empirical) { if (!SILENT) cerr << " USING UNIFORM WEIGHTS\n"; for (int i = 0; i < forest.edges_.size(); ++i) diff --git a/decoder/decoder.h b/decoder/decoder.h index 2a1a43ce..abaf3740 100644 --- a/decoder/decoder.h +++ b/decoder/decoder.h @@ -26,6 +26,7 @@ struct Decoder { Decoder(std::istream* config_file); bool Decode(const std::string& input, DecoderObserver* observer = NULL); void SetWeights(const std::vector<double>& weights); + void SetId(int id); ~Decoder(); const boost::program_options::variables_map& GetConf() const { return conf; } private: diff --git a/decoder/ff_wordalign.cc b/decoder/ff_wordalign.cc index 087bff0c..a1968159 100644 --- a/decoder/ff_wordalign.cc +++ b/decoder/ff_wordalign.cc @@ -72,11 +72,11 @@ RelativeSentencePosition::RelativeSentencePosition(const string& param) : pos_.push_back(v); for (int i = 0; i < v.size(); ++i) classes.insert(v[i]); - for (set<WordID>::iterator i = classes.begin(); i != classes.end(); ++i) { - ostringstream os; - os << "RelPos_FC:" << TD::Convert(*i); - fids_[*i] = FD::Convert(os.str()); - } + } + for (set<WordID>::iterator i = classes.begin(); i != classes.end(); ++i) { + ostringstream os; + os << "RelPos_FC:" << TD::Convert(*i); + fids_[*i] = FD::Convert(os.str()); } } else { condition_on_fclass_ = false; @@ -104,7 +104,9 @@ void RelativeSentencePosition::TraversalFeaturesImpl(const SentenceMetadata& sme if (condition_on_fclass_) { assert(smeta.GetSentenceID() < pos_.size()); const WordID cur_fclass = pos_[smeta.GetSentenceID()][edge.i_]; - const int fid = fids_.find(cur_fclass)->second; + std::map<WordID, int>::const_iterator fidit = fids_.find(cur_fclass); + assert(fidit != fids_.end()); + const int fid = fidit->second; features->set_value(fid, val); } // cerr << f_len_ << " " << e_len_ << " [" << edge.i_ << "," << edge.j_ << "|" << edge.prev_i_ << "," << edge.prev_j_ << "]\t" << edge.rule_->AsString() << "\tVAL=" << val << endl; diff --git a/decoder/viterbi.cc b/decoder/viterbi.cc index fac9dd70..a8192a9f 100644 --- a/decoder/viterbi.cc +++ b/decoder/viterbi.cc @@ -16,7 +16,7 @@ std::string viterbi_stats(Hypergraph const& hg, std::string const& name, bool es if (estring) { vector<WordID> trans; const prob_t vs = ViterbiESentence(hg, &trans); - o<<name<<" Viterbi logp: "<<log(vs)<<endl; + o<<name<<" Viterbi logp: "<<log(vs)<< " (norm=" << log(vs)/trans.size() << ")" << endl; o<<name<<" Viterbi: "<<TD::GetString(trans)<<endl; } if (etree) { diff --git a/training/Makefile.am b/training/Makefile.am index ea637d9e..2679adea 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -22,10 +22,10 @@ bin_PROGRAMS += mpi_batch_optimize \ mpi_online_optimize mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc optimize.cc -mpi_batch_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz -lmpi++ -lmpi +mpi_batch_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz mpi_online_optimize_SOURCES = mpi_online_optimize.cc online_optimizer.cc -mpi_online_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz -lmpi++ -lmpi +mpi_online_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz endif online_train_SOURCES = online_train.cc online_optimizer.cc diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc index 62821aa3..6f5988a4 100644 --- a/training/mpi_online_optimize.cc +++ b/training/mpi_online_optimize.cc @@ -6,6 +6,7 @@ #include <cmath> #include <mpi.h> +#include <boost/mpi.hpp> #include <boost/shared_ptr.hpp> #include <boost/program_options.hpp> #include <boost/program_options/variables_map.hpp> @@ -66,10 +67,11 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("minibatch_size_per_proc,s", po::value<unsigned>()->default_value(5), "Number of training instances evaluated per processor in each minibatch") ("freeze_feature_set,Z", "The feature set specified in the initial weights file is frozen throughout the duration of training") ("optimization_method,m", po::value<string>()->default_value("sgd"), "Optimization method (sgd)") + ("fully_random,r", "Fully random draws from the training corpus") + ("random_seed,S", po::value<uint32_t>(), "Random seed (if not specified, /dev/random will be used)") ("eta_0,e", po::value<double>()->default_value(0.2), "Initial learning rate for SGD (eta_0)") ("L1,1","Use L1 regularization") - ("gaussian_prior,g","Use a Gaussian prior on the weights") - ("sigma_squared", po::value<double>()->default_value(1.0), "Sigma squared term for spherical Gaussian prior"); + ("regularization_strength,C", po::value<double>()->default_value(1.0), "Regularization strength (C)"); po::options_description clo("Command line options"); clo.add_options() ("config", po::value<string>(), "Configuration file") @@ -165,7 +167,7 @@ struct TrainingObserver : public DecoderObserver { } assert(!isnan(log_ref_z)); ref_exp -= cur_model_exp; - acc_grad -= ref_exp; + acc_grad += ref_exp; acc_obj += (cur_obj - log_ref_z); } @@ -176,6 +178,12 @@ struct TrainingObserver : public DecoderObserver { } } + void GetGradient(SparseVector<double>* g) const { + g->clear(); + for (SparseVector<prob_t>::const_iterator it = acc_grad.begin(); it != acc_grad.end(); ++it) + g->set_value(it->first, it->second); + } + int total_complete; SparseVector<prob_t> cur_model_exp; SparseVector<prob_t> acc_grad; @@ -193,10 +201,20 @@ inline void Shuffle(vector<T>* c, MT19937* rng) { } } +namespace mpi = boost::mpi; + +namespace boost { namespace mpi { + template<> + struct is_commutative<std::plus<SparseVector<double> >, SparseVector<double> > + : mpl::true_ { }; +} } // end namespace boost::mpi + + int main(int argc, char** argv) { - MPI::Init(argc, argv); - const int size = MPI::COMM_WORLD.Get_size(); - const int rank = MPI::COMM_WORLD.Get_rank(); + mpi::environment env(argc, argv); + mpi::communicator world; + const int size = world.size(); + const int rank = world.rank(); SetSilent(true); // turn off verbose decoder output cerr << "MPI: I am " << rank << '/' << size << endl; register_feature_functions(); @@ -219,7 +237,7 @@ int main(int argc, char** argv) { Decoder decoder(ini_rf.stream()); if (decoder.GetConf()["input"].as<string>() != "-") { cerr << "cdec.ini must not set an input file\n"; - MPI::COMM_WORLD.Abort(1); + abort(); } vector<string> corpus; @@ -228,105 +246,87 @@ int main(int argc, char** argv) { std::tr1::shared_ptr<OnlineOptimizer> o; std::tr1::shared_ptr<LearningRateSchedule> lr; - vector<int> order; + vector<int> order(corpus.size()); + + const bool fully_random = conf.count("fully_random"); + const unsigned size_per_proc = conf["minibatch_size_per_proc"].as<unsigned>(); + const unsigned batch_size = size_per_proc * size; if (rank == 0) { + cerr << "Corpus: " << corpus.size() << " batch size: " << batch_size << endl; + if (batch_size > corpus.size()) { + cerr << " Reduce minibatch_size_per_proc!"; + abort(); + } + // TODO config - lr.reset(new ExponentialDecayLearningRate(corpus.size(), conf["eta_0"].as<double>())); + lr.reset(new ExponentialDecayLearningRate(batch_size, conf["eta_0"].as<double>())); const string omethod = conf["optimization_method"].as<string>(); if (omethod == "sgd") { - const double C = 1.0; + const double C = conf["regularization_strength"].as<double>(); o.reset(new CumulativeL1OnlineOptimizer(lr, corpus.size(), C)); } else { assert(!"fail"); } - // randomize corpus - rng = new MT19937; - order.resize(corpus.size()); for (unsigned i = 0; i < order.size(); ++i) order[i]=i; - Shuffle(&order, rng); + // randomize corpus + if (conf.count("random_seed")) + rng = new MT19937(conf["random_seed"].as<uint32_t>()); + else + rng = new MT19937; } + SparseVector<double> x; + int miter = corpus.size(); // hack to cause initial broadcast of order info + TrainingObserver observer; double objective = 0; - vector<double> lambdas; - weights.InitVector(&lambdas); bool converged = false; - const unsigned size_per_proc = conf["minibatch_size_per_proc"].as<unsigned>(); - for (int i = 0; i < size_per_proc; ++i) - cerr << "i=" << i << ": " << order[i] << endl; - abort(); - TrainingObserver observer; + + int iter = -1; + vector<double> lambdas; while (!converged) { + weights.InitFromVector(x); + weights.InitVector(&lambdas); + ++miter; ++iter; observer.Reset(); - if (rank == 0) { - cerr << "Starting decoding... (~" << corpus.size() << " sentences / proc)\n"; - } decoder.SetWeights(lambdas); -#if 0 - for (int i = 0; i < corpus.size(); ++i) - decoder.Decode(corpus[i], &observer); - - fill(gradient.begin(), gradient.end(), 0); - fill(rcv_grad.begin(), rcv_grad.end(), 0); - observer.SetLocalGradientAndObjective(&gradient, &objective); - - double to = 0; - MPI::COMM_WORLD.Reduce(const_cast<double*>(&gradient.data()[0]), &rcv_grad[0], num_feats, MPI::DOUBLE, MPI::SUM, 0); - MPI::COMM_WORLD.Reduce(&objective, &to, 1, MPI::DOUBLE, MPI::SUM, 0); - swap(gradient, rcv_grad); - objective = to; - - if (rank == 0) { // run optimizer only on rank=0 node - if (gaussian_prior) { - const double sigsq = conf["sigma_squared"].as<double>(); - double norm = 0; - for (int k = 1; k < lambdas.size(); ++k) { - const double& lambda_k = lambdas[k]; - if (lambda_k) { - const double param = (lambda_k - means[k]); - norm += param * param; - gradient[k] += param / sigsq; - } - } - const double reg = norm / (2.0 * sigsq); - cerr << "REGULARIZATION TERM: " << reg << endl; - objective += reg; - } - cerr << "EVALUATION #" << o->EvaluationCount() << " OBJECTIVE: " << objective << endl; - double gnorm = 0; - for (int i = 0; i < gradient.size(); ++i) - gnorm += gradient[i] * gradient[i]; - cerr << " GNORM=" << sqrt(gnorm) << endl; - vector<double> old = lambdas; - int c = 0; - while (old == lambdas) { - ++c; - if (c > 1) { cerr << "Same lambdas, repeating optimization\n"; } - o->Optimize(objective, gradient, &lambdas); - assert(c < 5); - } - old.clear(); + if (rank == 0) { SanityCheck(lambdas); ShowLargestFeatures(lambdas); - weights.InitFromVector(lambdas); - - converged = o->HasConverged(); - if (converged) { cerr << "OPTIMIZER REPORTS CONVERGENCE!\n"; } - string fname = "weights.cur.gz"; if (converged) { fname = "weights.final.gz"; } ostringstream vv; - vv << "Objective = " << objective << " (eval count=" << o->EvaluationCount() << ")"; + vv << "Objective = " << objective; // << " (eval count=" << o->EvaluationCount() << ")"; const string svv = vv.str(); weights.WriteToFile(fname, true, &svv); - } // rank == 0 - int cint = converged; - MPI::COMM_WORLD.Bcast(const_cast<double*>(&lambdas.data()[0]), num_feats, MPI::DOUBLE, 0); - MPI::COMM_WORLD.Bcast(&cint, 1, MPI::INT, 0); - MPI::COMM_WORLD.Barrier(); - converged = cint; -#endif + } + + if (fully_random || size * size_per_proc * miter > corpus.size()) { + if (rank == 0) + Shuffle(&order, rng); + miter = 0; + broadcast(world, order, 0); + } + if (rank == 0) + cerr << "Starting decoding. minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << x.size() << " training data proc. = " << (iter * batch_size / static_cast<double>(corpus.size())) << " eta=" << lr->eta(iter) << endl; + + const int beg = size * miter * size_per_proc + rank * size_per_proc; + const int end = beg + size_per_proc; + for (int i = beg; i < end; ++i) { + int ex_num = order[i % order.size()]; + if (rank ==0 && size < 3) cerr << rank << ": ex_num=" << ex_num << endl; + decoder.SetId(ex_num); + decoder.Decode(corpus[ex_num], &observer); + } + SparseVector<double> local_grad, g; + observer.GetGradient(&local_grad); + reduce(world, local_grad, g, std::plus<SparseVector<double> >(), 0); + if (rank == 0) { + g /= batch_size; + o->UpdateWeights(g, FD::NumFeats(), &x); + } + broadcast(world, x, 0); + world.barrier(); } - MPI::Finalize(); return 0; } diff --git a/training/online_optimizer.h b/training/online_optimizer.h index d2718f93..963c0380 100644 --- a/training/online_optimizer.h +++ b/training/online_optimizer.h @@ -8,16 +8,22 @@ struct LearningRateSchedule { virtual ~LearningRateSchedule(); - // returns the learning rate for iteration k + // returns the learning rate for the kth iteration virtual double eta(int k) const = 0; }; +// TODO in the Tsoruoaka et al. (ACL 2009) paper, they use N +// to mean the batch size in most places, but it doesn't completely +// make sense to me in the learning rate schedules-- this needs +// to be worked out to make sure they didn't mean corpus size +// in some places and batch size in others (since in the paper they +// only ever work with batch sizes of 1) struct StandardLearningRate : public LearningRateSchedule { StandardLearningRate( - size_t training_instances, + size_t batch_size, // batch size, not corpus size! double eta_0 = 0.2) : eta_0_(eta_0), - N_(static_cast<double>(training_instances)) {} + N_(static_cast<double>(batch_size)) {} virtual double eta(int k) const; @@ -28,11 +34,11 @@ struct StandardLearningRate : public LearningRateSchedule { struct ExponentialDecayLearningRate : public LearningRateSchedule { ExponentialDecayLearningRate( - size_t training_instances, + size_t batch_size, // batch size, not corpus size! double eta_0 = 0.2, double alpha = 0.85 // recommended by Tsuruoka et al. (ACL 2009) ) : eta_0_(eta_0), - N_(static_cast<double>(training_instances)), + N_(static_cast<double>(batch_size)), alpha_(alpha) { assert(alpha > 0); assert(alpha < 1.0); @@ -50,17 +56,17 @@ class OnlineOptimizer { public: virtual ~OnlineOptimizer(); OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s, - size_t training_instances) - : N_(training_instances),schedule_(s),k_() {} - void UpdateWeights(const SparseVector<double>& approx_g, SparseVector<double>* weights) { + size_t batch_size) + : N_(batch_size),schedule_(s),k_() {} + void UpdateWeights(const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) { ++k_; const double eta = schedule_->eta(k_); - UpdateWeightsImpl(eta, approx_g, weights); + UpdateWeightsImpl(eta, approx_g, max_feat, weights); } protected: - virtual void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, SparseVector<double>* weights) = 0; - const size_t N_; // number of training instances + virtual void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) = 0; + const size_t N_; // number of training instances per batch private: std::tr1::shared_ptr<LearningRateSchedule> schedule_; @@ -74,11 +80,11 @@ class CumulativeL1OnlineOptimizer : public OnlineOptimizer { OnlineOptimizer(s, training_instances), C_(C), u_() {} protected: - void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, SparseVector<double>* weights) { + void UpdateWeightsImpl(const double& eta, const SparseVector<double>& approx_g, int max_feat, SparseVector<double>* weights) { u_ += eta * C_ / N_; (*weights) += eta * approx_g; - for (SparseVector<double>::const_iterator it = approx_g.begin(); it != approx_g.end(); ++it) - ApplyPenalty(it->first, weights); + for (int i = 1; i < max_feat; ++i) + ApplyPenalty(i, weights); } private: @@ -86,13 +92,19 @@ class CumulativeL1OnlineOptimizer : public OnlineOptimizer { const double z = w->value(i); double w_i = z; double q_i = q_.value(i); - if (w_i > 0) + if (w_i > 0.0) w_i = std::max(0.0, w_i - (u_ + q_i)); - else - w_i = std::max(0.0, w_i + (u_ - q_i)); + else if (w_i < 0.0) + w_i = std::min(0.0, w_i + (u_ - q_i)); q_i += w_i - z; - q_.set_value(i, q_i); - w->set_value(i, w_i); + if (q_i == 0.0) + q_.erase(i); + else + q_.set_value(i, q_i); + if (w_i == 0.0) + w->erase(i); + else + w->set_value(i, w_i); } const double C_; // reguarlization strength diff --git a/utils/sparse_vector.h b/utils/sparse_vector.h index 5d0dac27..c5e18b96 100644 --- a/utils/sparse_vector.h +++ b/utils/sparse_vector.h @@ -56,6 +56,10 @@ TODO: specialize for int value types, where it probably makes sense to check if #include "small_vector.h" #include "string_to.h" +#if HAVE_BOOST_ARCHIVE_TEXT_OARCHIVE_HPP +#include <boost/serialization/map.hpp> +#endif + template <class T> inline T & extend_vector(std::vector<T> &v,int i) { if (i>=v.size()) @@ -510,6 +514,35 @@ public: private: MapType values_; + +#if HAVE_BOOST_ARCHIVE_TEXT_OARCHIVE_HPP + friend class boost::serialization::access; + template<class Archive> + void save(Archive & ar, const unsigned int version) const { + (void) version; + int eff_size = values_.size(); + const_iterator it = this->begin(); + if (values_.find(0) != values_.end()) { ++it; --eff_size; } + ar & eff_size; + while (it != this->end()) { + const std::pair<std::string, T> wire_pair(FD::Convert(it->first), it->second); + ar & wire_pair; + ++it; + } + } + template<class Archive> + void load(Archive & ar, const unsigned int version) { + (void) version; + this->clear(); + int sz; ar & sz; + for (int i = 0; i < sz; ++i) { + std::pair<std::string, T> wire_pair; + ar & wire_pair; + this->set_value(FD::Convert(wire_pair.first), wire_pair.second); + } + } + BOOST_SERIALIZATION_SPLIT_MEMBER() +#endif }; template <class T> diff --git a/utils/weights.cc b/utils/weights.cc index ea8bd816..53089f89 100644 --- a/utils/weights.cc +++ b/utils/weights.cc @@ -81,3 +81,10 @@ void Weights::InitFromVector(const std::vector<double>& w) { cerr << "WARNING: initializing weight vector has more features than the global feature dictionary!\n"; wv_.resize(FD::NumFeats(), 0); } + +void Weights::InitFromVector(const SparseVector<double>& w) { + wv_.clear(); + wv_.resize(FD::NumFeats(), 0.0); + for (int i = 1; i < FD::NumFeats(); ++i) + wv_[i] = w.value(i); +} diff --git a/utils/weights.h b/utils/weights.h index 1849f959..cc20283c 100644 --- a/utils/weights.h +++ b/utils/weights.h @@ -14,6 +14,7 @@ class Weights { void InitVector(std::vector<double>* w) const; void InitSparseVector(SparseVector<double>* w) const; void InitFromVector(const std::vector<double>& w); + void InitFromVector(const SparseVector<double>& w); private: std::vector<double> wv_; }; |