From b6754386f1109b960b05cdf2eabbc97bdd38e8df Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Mon, 29 Apr 2013 15:24:39 +0200 Subject: fix, cleaned up headers --- training/dtrain/dtrain.cc | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) (limited to 'training/dtrain/dtrain.cc') diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 149f87d4..83e4e440 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -1,4 +1,10 @@ #include "dtrain.h" +#include "score.h" +#include "kbestget.h" +#include "ksampler.h" +#include "pairsampling.h" + +using namespace dtrain; bool @@ -138,23 +144,23 @@ main(int argc, char** argv) string scorer_str = cfg["scorer"].as(); LocalScorer* scorer; if (scorer_str == "bleu") { - scorer = dynamic_cast(new BleuScorer); + scorer = static_cast(new BleuScorer); } else if (scorer_str == "stupid_bleu") { - scorer = dynamic_cast(new StupidBleuScorer); + scorer = static_cast(new StupidBleuScorer); } else if (scorer_str == "fixed_stupid_bleu") { - scorer = dynamic_cast(new FixedStupidBleuScorer); + scorer = static_cast(new FixedStupidBleuScorer); } else if (scorer_str == "smooth_bleu") { - scorer = dynamic_cast(new SmoothBleuScorer); + scorer = static_cast(new SmoothBleuScorer); } else if (scorer_str == "sum_bleu") { - scorer = dynamic_cast(new SumBleuScorer); + scorer = static_cast(new SumBleuScorer); } else if (scorer_str == "sumexp_bleu") { - scorer = dynamic_cast(new SumExpBleuScorer); + scorer = static_cast(new SumExpBleuScorer); } else if (scorer_str == "sumwhatever_bleu") { - scorer = dynamic_cast(new SumWhateverBleuScorer); + scorer = static_cast(new SumWhateverBleuScorer); } else if (scorer_str == "approx_bleu") { - scorer = dynamic_cast(new ApproxBleuScorer(N, approx_bleu_d)); + scorer = static_cast(new ApproxBleuScorer(N, approx_bleu_d)); } else if (scorer_str == "lc_bleu") { - scorer = dynamic_cast(new LinearBleuScorer(N)); + scorer = static_cast(new LinearBleuScorer(N)); } else { cerr << "Don't know scoring metric: '" << scorer_str << "', exiting." << endl; exit(1); @@ -166,9 +172,9 @@ main(int argc, char** argv) MT19937 rng; // random number generator, only for forest sampling HypSampler* observer; if (sample_from == "kbest") - observer = dynamic_cast(new KBestGetter(k, filter_type)); + observer = static_cast(new KBestGetter(k, filter_type)); else - observer = dynamic_cast(new KSampler(k, &rng)); + observer = static_cast(new KSampler(k, &rng)); observer->SetScorer(scorer); // init weights -- cgit v1.2.3 From 72c9dedc8124977712462c6babbc0c1b0375f813 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Wed, 15 May 2013 13:25:36 +0200 Subject: fixed l1 reg --- training/dtrain/dtrain.cc | 46 +++++++++++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 19 deletions(-) (limited to 'training/dtrain/dtrain.cc') diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 83e4e440..e1d5a2d4 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -366,6 +366,9 @@ main(int argc, char** argv) PROsampling(samples, pairs, pair_threshold, max_pairs); npairs += pairs.size(); + SparseVector lambdas_copy; + if (l1naive||l1clip||l1cumul) lambdas_copy = lambdas; + for (vector >::iterator it = pairs.begin(); it != pairs.end(); it++) { bool rank_error; @@ -389,23 +392,26 @@ main(int argc, char** argv) } // l1 regularization - // please note that this penalizes _all_ weights - // (contrary to only the ones changed by the last update) - // after a _sentence_ (not after each example/pair) + // please note that this regularizations happen + // after a _sentence_ -- not after each example/pair! if (l1naive) { FastSparseVector::iterator it = lambdas.begin(); for (; it != lambdas.end(); ++it) { - it->second -= sign(it->second) * l1_reg; + if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) { + it->second -= sign(it->second) * l1_reg; + } } } else if (l1clip) { FastSparseVector::iterator it = lambdas.begin(); for (; it != lambdas.end(); ++it) { - if (it->second != 0) { - weight_t v = it->second; - if (v > 0) { - it->second = max(0., v - l1_reg); - } else { - it->second = min(0., v + l1_reg); + if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) { + if (it->second != 0) { + weight_t v = it->second; + if (v > 0) { + it->second = max(0., v - l1_reg); + } else { + it->second = min(0., v + l1_reg); + } } } } @@ -413,16 +419,18 @@ main(int argc, char** argv) weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input FastSparseVector::iterator it = lambdas.begin(); for (; it != lambdas.end(); ++it) { - if (it->second != 0) { - weight_t v = it->second; - weight_t penalized = 0.; - if (v > 0) { - penalized = max(0., v-(acc_penalty + cumulative_penalties.get(it->first))); - } else { - penalized = min(0., v+(acc_penalty - cumulative_penalties.get(it->first))); + if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) { + if (it->second != 0) { + weight_t v = it->second; + weight_t penalized = 0.; + if (v > 0) { + penalized = max(0., v-(acc_penalty + cumulative_penalties.get(it->first))); + } else { + penalized = min(0., v+(acc_penalty - cumulative_penalties.get(it->first))); + } + it->second = penalized; + cumulative_penalties.set_value(it->first, cumulative_penalties.get(it->first)+penalized); } - it->second = penalized; - cumulative_penalties.set_value(it->first, cumulative_penalties.get(it->first)+penalized); } } } -- cgit v1.2.3 From 4ee4f74ae8cf88fd2335267c26cbfb73f3ef8f28 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Mon, 27 May 2013 20:56:57 +0200 Subject: fix --- training/dtrain/dtrain.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'training/dtrain/dtrain.cc') diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index e1d5a2d4..0ee2f124 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -378,7 +378,7 @@ main(int argc, char** argv) margin = std::numeric_limits::max(); } else { rank_error = it->first.model <= it->second.model; - margin = fabs(fabs(it->first.model) - fabs(it->second.model)); + margin = fabs(it->first.model - it->second.model); if (!rank_error && margin < loss_margin) margin_violations++; } if (rank_error) rank_errors++; -- cgit v1.2.3