From b6754386f1109b960b05cdf2eabbc97bdd38e8df Mon Sep 17 00:00:00 2001
From: Patrick Simianer
Date: Mon, 29 Apr 2013 15:24:39 +0200
Subject: fix, cleaned up headers
---
training/dtrain/dtrain.cc | 28 +++++++++++++++++-----------
1 file changed, 17 insertions(+), 11 deletions(-)
(limited to 'training/dtrain/dtrain.cc')
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 149f87d4..83e4e440 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -1,4 +1,10 @@
#include "dtrain.h"
+#include "score.h"
+#include "kbestget.h"
+#include "ksampler.h"
+#include "pairsampling.h"
+
+using namespace dtrain;
bool
@@ -138,23 +144,23 @@ main(int argc, char** argv)
string scorer_str = cfg["scorer"].as<string>();
LocalScorer* scorer;
if (scorer_str == "bleu") {
- scorer = dynamic_cast<LocalScorer*>(new BleuScorer);
+ scorer = static_cast<LocalScorer*>(new BleuScorer);
} else if (scorer_str == "stupid_bleu") {
- scorer = dynamic_cast<LocalScorer*>(new StupidBleuScorer);
+ scorer = static_cast<LocalScorer*>(new StupidBleuScorer);
} else if (scorer_str == "fixed_stupid_bleu") {
- scorer = dynamic_cast<LocalScorer*>(new FixedStupidBleuScorer);
+ scorer = static_cast<LocalScorer*>(new FixedStupidBleuScorer);
} else if (scorer_str == "smooth_bleu") {
- scorer = dynamic_cast<LocalScorer*>(new SmoothBleuScorer);
+ scorer = static_cast<LocalScorer*>(new SmoothBleuScorer);
} else if (scorer_str == "sum_bleu") {
- scorer = dynamic_cast<LocalScorer*>(new SumBleuScorer);
+ scorer = static_cast<LocalScorer*>(new SumBleuScorer);
} else if (scorer_str == "sumexp_bleu") {
- scorer = dynamic_cast<LocalScorer*>(new SumExpBleuScorer);
+ scorer = static_cast<LocalScorer*>(new SumExpBleuScorer);
} else if (scorer_str == "sumwhatever_bleu") {
- scorer = dynamic_cast<LocalScorer*>(new SumWhateverBleuScorer);
+ scorer = static_cast<LocalScorer*>(new SumWhateverBleuScorer);
} else if (scorer_str == "approx_bleu") {
- scorer = dynamic_cast<LocalScorer*>(new ApproxBleuScorer(N, approx_bleu_d));
+ scorer = static_cast<LocalScorer*>(new ApproxBleuScorer(N, approx_bleu_d));
} else if (scorer_str == "lc_bleu") {
- scorer = dynamic_cast<LocalScorer*>(new LinearBleuScorer(N));
+ scorer = static_cast<LocalScorer*>(new LinearBleuScorer(N));
} else {
cerr << "Don't know scoring metric: '" << scorer_str << "', exiting." << endl;
exit(1);
@@ -166,9 +172,9 @@ main(int argc, char** argv)
MT19937 rng; // random number generator, only for forest sampling
HypSampler* observer;
if (sample_from == "kbest")
- observer = dynamic_cast<HypSampler*>(new KBestGetter(k, filter_type));
+ observer = static_cast<HypSampler*>(new KBestGetter(k, filter_type));
else
- observer = dynamic_cast<HypSampler*>(new KSampler(k, &rng));
+ observer = static_cast<HypSampler*>(new KSampler(k, &rng));
observer->SetScorer(scorer);
// init weights
--
cgit v1.2.3
From 72c9dedc8124977712462c6babbc0c1b0375f813 Mon Sep 17 00:00:00 2001
From: Patrick Simianer
Date: Wed, 15 May 2013 13:25:36 +0200
Subject: fixed l1 reg
---
training/dtrain/dtrain.cc | 46 +++++++++++++++++++++++++++-------------------
1 file changed, 27 insertions(+), 19 deletions(-)
(limited to 'training/dtrain/dtrain.cc')
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 83e4e440..e1d5a2d4 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -366,6 +366,9 @@ main(int argc, char** argv)
PROsampling(samples, pairs, pair_threshold, max_pairs);
npairs += pairs.size();
+ SparseVector<weight_t> lambdas_copy;
+ if (l1naive||l1clip||l1cumul) lambdas_copy = lambdas;
+
for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
it != pairs.end(); it++) {
bool rank_error;
@@ -389,23 +392,26 @@ main(int argc, char** argv)
}
// l1 regularization
- // please note that this penalizes _all_ weights
- // (contrary to only the ones changed by the last update)
- // after a _sentence_ (not after each example/pair)
+ // please note that this regularizations happen
+ // after a _sentence_ -- not after each example/pair!
if (l1naive) {
FastSparseVector<weight_t>::iterator it = lambdas.begin();
for (; it != lambdas.end(); ++it) {
- it->second -= sign(it->second) * l1_reg;
+ if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
+ it->second -= sign(it->second) * l1_reg;
+ }
}
} else if (l1clip) {
FastSparseVector<weight_t>::iterator it = lambdas.begin();
for (; it != lambdas.end(); ++it) {
- if (it->second != 0) {
- weight_t v = it->second;
- if (v > 0) {
- it->second = max(0., v - l1_reg);
- } else {
- it->second = min(0., v + l1_reg);
+ if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
+ if (it->second != 0) {
+ weight_t v = it->second;
+ if (v > 0) {
+ it->second = max(0., v - l1_reg);
+ } else {
+ it->second = min(0., v + l1_reg);
+ }
}
}
}
@@ -413,16 +419,18 @@ main(int argc, char** argv)
weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input
FastSparseVector<weight_t>::iterator it = lambdas.begin();
for (; it != lambdas.end(); ++it) {
- if (it->second != 0) {
- weight_t v = it->second;
- weight_t penalized = 0.;
- if (v > 0) {
- penalized = max(0., v-(acc_penalty + cumulative_penalties.get(it->first)));
- } else {
- penalized = min(0., v+(acc_penalty - cumulative_penalties.get(it->first)));
+ if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
+ if (it->second != 0) {
+ weight_t v = it->second;
+ weight_t penalized = 0.;
+ if (v > 0) {
+ penalized = max(0., v-(acc_penalty + cumulative_penalties.get(it->first)));
+ } else {
+ penalized = min(0., v+(acc_penalty - cumulative_penalties.get(it->first)));
+ }
+ it->second = penalized;
+ cumulative_penalties.set_value(it->first, cumulative_penalties.get(it->first)+penalized);
}
- it->second = penalized;
- cumulative_penalties.set_value(it->first, cumulative_penalties.get(it->first)+penalized);
}
}
}
--
cgit v1.2.3
From 4ee4f74ae8cf88fd2335267c26cbfb73f3ef8f28 Mon Sep 17 00:00:00 2001
From: Patrick Simianer
Date: Mon, 27 May 2013 20:56:57 +0200
Subject: fix
---
training/dtrain/dtrain.cc | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'training/dtrain/dtrain.cc')
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index e1d5a2d4..0ee2f124 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -378,7 +378,7 @@ main(int argc, char** argv)
margin = std::numeric_limits<weight_t>::max();
} else {
rank_error = it->first.model <= it->second.model;
- margin = fabs(fabs(it->first.model) - fabs(it->second.model));
+ margin = fabs(it->first.model - it->second.model);
if (!rank_error && margin < loss_margin) margin_violations++;
}
if (rank_error) rank_errors++;
--
cgit v1.2.3