summaryrefslogtreecommitdiff
path: root/training/utils
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2013-11-13 18:12:10 +0100
committerPatrick Simianer <p@simianer.de>2013-11-13 18:12:10 +0100
commitd6e6babf2cfe49fed040b651624b7e34d1a9b507 (patch)
tree2a00ab18f10a7f93e7e172551c01b48cc9f20b8c /training/utils
parent2d2d5eced93d58bc77894d8c328195cd9950b96d (diff)
parent8a24bb77bc2e9fd17a6f6529a2942cde96a6af49 (diff)
merge w/ upstream
Diffstat (limited to 'training/utils')
-rw-r--r--training/utils/candidate_set.cc11
-rw-r--r--training/utils/online_optimizer.h8
-rw-r--r--training/utils/optimize_test.cc6
3 files changed, 14 insertions, 11 deletions
diff --git a/training/utils/candidate_set.cc b/training/utils/candidate_set.cc
index 087efec3..33dae9a3 100644
--- a/training/utils/candidate_set.cc
+++ b/training/utils/candidate_set.cc
@@ -1,6 +1,11 @@
#include "candidate_set.h"
-#include <tr1/unordered_set>
+#ifndef HAVE_OLD_CPP
+# include <unordered_set>
+#else
+# include <tr1/unordered_set>
+namespace std { using std::tr1::unordered_set; }
+#endif
#include <boost/functional/hash.hpp>
@@ -139,12 +144,12 @@ void CandidateSet::ReadFromFile(const string& file) {
void CandidateSet::Dedup() {
if(!SILENT) cerr << "Dedup in=" << cs.size();
- tr1::unordered_set<Candidate, CandidateHasher, CandidateCompare> u;
+ unordered_set<Candidate, CandidateHasher, CandidateCompare> u;
while(cs.size() > 0) {
u.insert(cs.back());
cs.pop_back();
}
- tr1::unordered_set<Candidate, CandidateHasher, CandidateCompare>::iterator it = u.begin();
+ unordered_set<Candidate, CandidateHasher, CandidateCompare>::iterator it = u.begin();
while (it != u.end()) {
cs.push_back(*it);
it = u.erase(it);
diff --git a/training/utils/online_optimizer.h b/training/utils/online_optimizer.h
index 28d89344..19223e9d 100644
--- a/training/utils/online_optimizer.h
+++ b/training/utils/online_optimizer.h
@@ -1,10 +1,10 @@
#ifndef _ONL_OPTIMIZE_H_
#define _ONL_OPTIMIZE_H_
-#include <tr1/memory>
#include <set>
#include <string>
#include <cmath>
+#include <boost/shared_ptr.hpp>
#include "sparse_vector.h"
struct LearningRateSchedule {
@@ -56,7 +56,7 @@ struct ExponentialDecayLearningRate : public LearningRateSchedule {
class OnlineOptimizer {
public:
virtual ~OnlineOptimizer();
- OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s,
+ OnlineOptimizer(const boost::shared_ptr<LearningRateSchedule>& s,
size_t batch_size,
const std::vector<int>& frozen_feats = std::vector<int>())
: N_(batch_size),schedule_(s),k_() {
@@ -77,13 +77,13 @@ class OnlineOptimizer {
std::set<int> frozen_; // frozen (non-optimizing) features
private:
- std::tr1::shared_ptr<LearningRateSchedule> schedule_;
+ boost::shared_ptr<LearningRateSchedule> schedule_;
int k_; // iteration count
};
class CumulativeL1OnlineOptimizer : public OnlineOptimizer {
public:
- CumulativeL1OnlineOptimizer(const std::tr1::shared_ptr<LearningRateSchedule>& s,
+ CumulativeL1OnlineOptimizer(const boost::shared_ptr<LearningRateSchedule>& s,
size_t training_instances, double C,
const std::vector<int>& frozen) :
OnlineOptimizer(s, training_instances, frozen), C_(C), u_() {}
diff --git a/training/utils/optimize_test.cc b/training/utils/optimize_test.cc
index bff2ca03..72fcef6d 100644
--- a/training/utils/optimize_test.cc
+++ b/training/utils/optimize_test.cc
@@ -2,6 +2,7 @@
#include <iostream>
#include <sstream>
#include <boost/program_options/variables_map.hpp>
+#include <boost/shared_ptr.hpp>
#include "optimize.h"
#include "online_optimizer.h"
#include "sparse_vector.h"
@@ -96,14 +97,11 @@ void TestOptimizerVariants(int num_vars) {
cerr << oa.Name() << " SUCCESS\n";
}
-using namespace std::tr1;
-
void TestOnline() {
size_t N = 20;
double C = 1.0;
double eta0 = 0.2;
- std::tr1::shared_ptr<LearningRateSchedule> r(new ExponentialDecayLearningRate(N, eta0, 0.85));
- //shared_ptr<LearningRateSchedule> r(new StandardLearningRate(N, eta0));
+ boost::shared_ptr<LearningRateSchedule> r(new ExponentialDecayLearningRate(N, eta0, 0.85));
CumulativeL1OnlineOptimizer opt(r, N, C, std::vector<int>());
assert(r->eta(10) < r->eta(1));
}