summaryrefslogtreecommitdiff
path: root/training/dtrain/dtrain.cc
diff options
context:
space:
mode:
authorMichael Denkowski <mdenkows@cs.cmu.edu>2013-08-19 08:24:48 -0700
committerMichael Denkowski <mdenkows@cs.cmu.edu>2013-08-19 08:24:48 -0700
commitcd666f441d91109d402e4f3993a9ec3c45306dd0 (patch)
tree3ef5083b5a52929b89ed18730104aace4934faf6 /training/dtrain/dtrain.cc
parentac469cdf4c70154a1c2cedce9edf5cdc3bdb2d61 (diff)
parent951e7daa9539ffe640f9421897c374f786af53e7 (diff)
Merge branch 'master' of github.com:redpony/cdec
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r--training/dtrain/dtrain.cc76
1 files changed, 45 insertions, 31 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 149f87d4..0ee2f124 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -1,4 +1,10 @@
#include "dtrain.h"
+#include "score.h"
+#include "kbestget.h"
+#include "ksampler.h"
+#include "pairsampling.h"
+
+using namespace dtrain;
bool
@@ -138,23 +144,23 @@ main(int argc, char** argv)
string scorer_str = cfg["scorer"].as<string>();
LocalScorer* scorer;
if (scorer_str == "bleu") {
- scorer = dynamic_cast<BleuScorer*>(new BleuScorer);
+ scorer = static_cast<BleuScorer*>(new BleuScorer);
} else if (scorer_str == "stupid_bleu") {
- scorer = dynamic_cast<StupidBleuScorer*>(new StupidBleuScorer);
+ scorer = static_cast<StupidBleuScorer*>(new StupidBleuScorer);
} else if (scorer_str == "fixed_stupid_bleu") {
- scorer = dynamic_cast<FixedStupidBleuScorer*>(new FixedStupidBleuScorer);
+ scorer = static_cast<FixedStupidBleuScorer*>(new FixedStupidBleuScorer);
} else if (scorer_str == "smooth_bleu") {
- scorer = dynamic_cast<SmoothBleuScorer*>(new SmoothBleuScorer);
+ scorer = static_cast<SmoothBleuScorer*>(new SmoothBleuScorer);
} else if (scorer_str == "sum_bleu") {
- scorer = dynamic_cast<SumBleuScorer*>(new SumBleuScorer);
+ scorer = static_cast<SumBleuScorer*>(new SumBleuScorer);
} else if (scorer_str == "sumexp_bleu") {
- scorer = dynamic_cast<SumExpBleuScorer*>(new SumExpBleuScorer);
+ scorer = static_cast<SumExpBleuScorer*>(new SumExpBleuScorer);
} else if (scorer_str == "sumwhatever_bleu") {
- scorer = dynamic_cast<SumWhateverBleuScorer*>(new SumWhateverBleuScorer);
+ scorer = static_cast<SumWhateverBleuScorer*>(new SumWhateverBleuScorer);
} else if (scorer_str == "approx_bleu") {
- scorer = dynamic_cast<ApproxBleuScorer*>(new ApproxBleuScorer(N, approx_bleu_d));
+ scorer = static_cast<ApproxBleuScorer*>(new ApproxBleuScorer(N, approx_bleu_d));
} else if (scorer_str == "lc_bleu") {
- scorer = dynamic_cast<LinearBleuScorer*>(new LinearBleuScorer(N));
+ scorer = static_cast<LinearBleuScorer*>(new LinearBleuScorer(N));
} else {
cerr << "Don't know scoring metric: '" << scorer_str << "', exiting." << endl;
exit(1);
@@ -166,9 +172,9 @@ main(int argc, char** argv)
MT19937 rng; // random number generator, only for forest sampling
HypSampler* observer;
if (sample_from == "kbest")
- observer = dynamic_cast<KBestGetter*>(new KBestGetter(k, filter_type));
+ observer = static_cast<KBestGetter*>(new KBestGetter(k, filter_type));
else
- observer = dynamic_cast<KSampler*>(new KSampler(k, &rng));
+ observer = static_cast<KSampler*>(new KSampler(k, &rng));
observer->SetScorer(scorer);
// init weights
@@ -360,6 +366,9 @@ main(int argc, char** argv)
PROsampling(samples, pairs, pair_threshold, max_pairs);
npairs += pairs.size();
+ SparseVector<weight_t> lambdas_copy;
+ if (l1naive||l1clip||l1cumul) lambdas_copy = lambdas;
+
for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
it != pairs.end(); it++) {
bool rank_error;
@@ -369,7 +378,7 @@ main(int argc, char** argv)
margin = std::numeric_limits<float>::max();
} else {
rank_error = it->first.model <= it->second.model;
- margin = fabs(fabs(it->first.model) - fabs(it->second.model));
+ margin = fabs(it->first.model - it->second.model);
if (!rank_error && margin < loss_margin) margin_violations++;
}
if (rank_error) rank_errors++;
@@ -383,23 +392,26 @@ main(int argc, char** argv)
}
// l1 regularization
- // please note that this penalizes _all_ weights
- // (contrary to only the ones changed by the last update)
- // after a _sentence_ (not after each example/pair)
+ // please note that this regularizations happen
+ // after a _sentence_ -- not after each example/pair!
if (l1naive) {
FastSparseVector<weight_t>::iterator it = lambdas.begin();
for (; it != lambdas.end(); ++it) {
- it->second -= sign(it->second) * l1_reg;
+ if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
+ it->second -= sign(it->second) * l1_reg;
+ }
}
} else if (l1clip) {
FastSparseVector<weight_t>::iterator it = lambdas.begin();
for (; it != lambdas.end(); ++it) {
- if (it->second != 0) {
- weight_t v = it->second;
- if (v > 0) {
- it->second = max(0., v - l1_reg);
- } else {
- it->second = min(0., v + l1_reg);
+ if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
+ if (it->second != 0) {
+ weight_t v = it->second;
+ if (v > 0) {
+ it->second = max(0., v - l1_reg);
+ } else {
+ it->second = min(0., v + l1_reg);
+ }
}
}
}
@@ -407,16 +419,18 @@ main(int argc, char** argv)
weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input
FastSparseVector<weight_t>::iterator it = lambdas.begin();
for (; it != lambdas.end(); ++it) {
- if (it->second != 0) {
- weight_t v = it->second;
- weight_t penalized = 0.;
- if (v > 0) {
- penalized = max(0., v-(acc_penalty + cumulative_penalties.get(it->first)));
- } else {
- penalized = min(0., v+(acc_penalty - cumulative_penalties.get(it->first)));
+ if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
+ if (it->second != 0) {
+ weight_t v = it->second;
+ weight_t penalized = 0.;
+ if (v > 0) {
+ penalized = max(0., v-(acc_penalty + cumulative_penalties.get(it->first)));
+ } else {
+ penalized = min(0., v+(acc_penalty - cumulative_penalties.get(it->first)));
+ }
+ it->second = penalized;
+ cumulative_penalties.set_value(it->first, cumulative_penalties.get(it->first)+penalized);
}
- it->second = penalized;
- cumulative_penalties.set_value(it->first, cumulative_penalties.get(it->first)+penalized);
}
}
}