From 67456f9f7af754750faeea6f1e66b14b910d8751 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 18 Jun 2012 20:28:42 -0400 Subject: add non-const iterators to sparse vector, speed up model1 code --- utils/ccrp_onetable.h | 2 +- utils/corpus_tools.cc | 20 ++++++++++++ utils/corpus_tools.h | 4 +++ utils/fast_sparse_vector.h | 80 +++++++++++++++++++++++++++++++++++++++++++++- utils/sampler.h | 16 +++++++++- 5 files changed, 119 insertions(+), 3 deletions(-) (limited to 'utils') diff --git a/utils/ccrp_onetable.h b/utils/ccrp_onetable.h index 1fe01b0e..abe399ea 100644 --- a/utils/ccrp_onetable.h +++ b/utils/ccrp_onetable.h @@ -183,7 +183,7 @@ class CCRP_OneTable { assert(has_discount_prior() || has_alpha_prior()); DiscountResampler dr(*this); ConcentrationResampler cr(*this); - for (int iter = 0; iter < nloop; ++iter) { + for (unsigned iter = 0; iter < nloop; ++iter) { if (has_alpha_prior()) { alpha_ = slice_sampler1d(cr, alpha_, *rng, 0.0, std::numeric_limits::infinity(), 0.0, niterations, 100*niterations); diff --git a/utils/corpus_tools.cc b/utils/corpus_tools.cc index d17785af..191153a2 100644 --- a/utils/corpus_tools.cc +++ b/utils/corpus_tools.cc @@ -8,6 +8,26 @@ using namespace std; +void CorpusTools::ReadLine(const string& line, + vector* src, + vector* trg) { + static const WordID kDIV = TD::Convert("|||"); + static vector tmp; + src->clear(); + trg->clear(); + TD::ConvertSentence(line, &tmp); + unsigned i = 0; + while(i < tmp.size() && tmp[i] != kDIV) { + src->push_back(tmp[i]); + ++i; + } + if (i < tmp.size() && tmp[i] == kDIV) { + ++i; + for (; i < tmp.size() ; ++i) + trg->push_back(tmp[i]); + } +} + void CorpusTools::ReadFromFile(const string& filename, vector >* src, set* src_vocab, diff --git a/utils/corpus_tools.h b/utils/corpus_tools.h index 97bdaa94..f6699d87 100644 --- a/utils/corpus_tools.h +++ b/utils/corpus_tools.h @@ -7,6 +7,10 @@ #include "wordid.h" struct CorpusTools { + static void ReadLine(const std::string& line, + std::vector* src, + std::vector* trg); + static void ReadFromFile(const std::string& filename, std::vector >* src, std::set* src_vocab = NULL, diff --git a/utils/fast_sparse_vector.h b/utils/fast_sparse_vector.h index e86cbdc1..6e5dfb14 100644 --- a/utils/fast_sparse_vector.h +++ b/utils/fast_sparse_vector.h @@ -66,6 +66,60 @@ BOOST_STATIC_ASSERT(sizeof(PairIntT) == sizeof(std::pair) template class FastSparseVector { public: + struct iterator { + iterator(FastSparseVector& v, const bool is_end) : local_(!v.is_remote_) { + if (local_) { + local_it_ = &v.data_.local[is_end ? v.local_size_ : 0]; + } else { + if (is_end) + remote_it_ = v.data_.rbmap->end(); + else + remote_it_ = v.data_.rbmap->begin(); + } + } + iterator(FastSparseVector& v, const bool, const unsigned k) : local_(!v.is_remote_) { + if (local_) { + unsigned i = 0; + while(i < v.local_size_ && v.data_.local[i].first() != k) { ++i; } + local_it_ = &v.data_.local[i]; + } else { + remote_it_ = v.data_.rbmap->find(k); + } + } + const bool local_; + PairIntT* local_it_; + typename std::map::iterator remote_it_; + std::pair& operator*() const { + if (local_) + return *reinterpret_cast*>(local_it_); + else + return *remote_it_; + } + + std::pair* operator->() const { + if (local_) + return reinterpret_cast*>(local_it_); + else + return &*remote_it_; + } + + iterator& operator++() { + if (local_) ++local_it_; else ++remote_it_; + return *this; + } + + inline bool operator==(const iterator& o) const { + if (o.local_ != local_) return false; + if (local_) { + return local_it_ == o.local_it_; + } else { + return remote_it_ == o.remote_it_; + } + } + inline bool operator!=(const iterator& o) const { + return !(o == *this); + } + }; struct const_iterator { const_iterator(const FastSparseVector& v, const bool is_end) : local_(!v.is_remote_) { if (local_) { @@ -77,12 +131,21 @@ class FastSparseVector { remote_it_ = v.data_.rbmap->begin(); } } + const_iterator(const FastSparseVector& v, const bool, const unsigned k) : local_(!v.is_remote_) { + if (local_) { + unsigned i = 0; + while(i < v.local_size_ && v.data_.local[i].first() != k) { ++i; } + local_it_ = &v.data_.local[i]; + } else { + remote_it_ = v.data_.rbmap->find(k); + } + } const bool local_; const PairIntT* local_it_; typename std::map::const_iterator remote_it_; const std::pair& operator*() const { if (local_) - return *reinterpret_cast*>(local_it_); + return *reinterpret_cast*>(local_it_); else return *remote_it_; } @@ -160,6 +223,9 @@ class FastSparseVector { bool nonzero(unsigned k) const { return static_cast(value(k)); } + inline T& operator[](unsigned k) { + return get_or_create_bin(k); + } inline void set_value(unsigned k, const T& v) { get_or_create_bin(k) = v; } @@ -283,6 +349,18 @@ class FastSparseVector { } return o; } + iterator find(unsigned k) { + return iterator(*this, false, k); + } + iterator begin() { + return iterator(*this, false); + } + iterator end() { + return iterator(*this, true); + } + const_iterator find(unsigned k) const { + return const_iterator(*this, false, k); + } const_iterator begin() const { return const_iterator(*this, false); } diff --git a/utils/sampler.h b/utils/sampler.h index b237c716..3e4a4086 100644 --- a/utils/sampler.h +++ b/utils/sampler.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -76,6 +77,18 @@ struct RandomNumberGenerator { return boost::poisson_distribution(lambda)(m_random); } + double NextGamma(double shape, double scale = 1.0) { + boost::gamma_distribution<> gamma(shape); + boost::variate_generator > vg(m_generator, gamma); + return vg() * scale; + } + + double NextBeta(double alpha, double beta) { + double x = NextGamma(alpha); + double y = NextGamma(beta); + return x / (x + y); + } + bool AcceptMetropolisHastings(const prob_t& p_cur, const prob_t& p_prev, const prob_t& q_cur, @@ -123,11 +136,12 @@ size_t RandomNumberGenerator::SelectSample(const SampleSet& ss, double T const bool anneal = (T != 1.0); F sum = F(0); if (anneal) { - for (int i = 0; i < ss.m_scores.size(); ++i) + for (unsigned i = 0; i < ss.m_scores.size(); ++i) sum += pow(ss.m_scores[i], annealing_factor); // p^(1/T) } else { sum = std::accumulate(ss.m_scores.begin(), ss.m_scores.end(), F(0)); } + //std::cerr << "SUM: " << sum << std::endl; //for (size_t i = 0; i < ss.m_scores.size(); ++i) std::cerr << ss.m_scores[i] << ","; //std::cerr << std::endl; -- cgit v1.2.3