summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2016-04-26 16:17:25 +0200
committerPatrick Simianer <p@simianer.de>2016-04-26 16:17:25 +0200
commitd903f9e05865d134d272b87a8ca989f159188164 (patch)
treeb59ae70a1745a09db3a330f3b2a4d55a641c40f5
parentbc5be91606bd57243cd7f1ba10404c9477630497 (diff)
dtrain: all pairs
-rw-r--r--training/dtrain/dtrain.cc5
-rw-r--r--training/dtrain/dtrain.h4
-rw-r--r--training/dtrain/update.h27
3 files changed, 25 insertions, 11 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index bbe14547..30c64037 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -48,6 +48,7 @@ main(int argc, char** argv)
const string adadelta_output = conf["adadelta_output"].as<string>();
const size_t max_input = conf["stop_after"].as<size_t>();
const bool batch = conf["batch"].as<bool>();
+ const bool all = conf["all"].as<bool>();
// setup decoder
register_feature_functions();
@@ -267,14 +268,14 @@ main(int argc, char** argv)
num_up += update_structured(sample, updates, margin,
out_up, i);
else if (all_pairs)
- num_up += updates_all(sample, updates, max_up, threshold,
+ num_up += updates_all(sample, updates, max_up, margin, threshold, all,
out_up, i);
else if (pro)
num_up += updates_pro(sample, updates, cut, max_up, threshold,
out_up, i);
else
num_up += updates_multipartite(sample, updates, cut, margin,
- max_up, threshold, adjust_cut,
+ max_up, threshold, adjust_cut, all,
out_up, i);
SparseVector<weight_t> lambdas_copy;
diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h
index 883e6028..83dcd945 100644
--- a/training/dtrain/dtrain.h
+++ b/training/dtrain/dtrain.h
@@ -71,6 +71,8 @@ dtrain_init(int argc,
"use top/bottom 10% (default) of k-best as 'good' and 'bad' for pair sampling, 0 to use all pairs")
("adjust,A", po::bool_switch()->default_value(false),
"adjust cut for optimal pos. in k-best to cut")
+ ("all,A", po::bool_switch()->default_value(false),
+ "update using all pairs, ignoring margin and threshold")
("score,s", po::value<string>()->default_value("nakov"),
"per-sentence BLEU (approx.)")
("nakov_fix", po::value<weight_t>()->default_value(1.0),
@@ -106,7 +108,7 @@ dtrain_init(int argc,
"output raw data (e.g. k-best lists) [to filename]")
("stop_after", po::value<size_t>()->default_value(numeric_limits<size_t>::max()),
"only look at this number of segments")
- ("print_weights,P", po::value<string>()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV"),
+ ("print_weights", po::value<string>()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV"),
"list of weights to print after each iteration");
po::options_description clopts("Command Line Options");
clopts.add_options()
diff --git a/training/dtrain/update.h b/training/dtrain/update.h
index 405a3f76..edcfe391 100644
--- a/training/dtrain/update.h
+++ b/training/dtrain/update.h
@@ -20,6 +20,7 @@ updates_multipartite(vector<Hyp>* sample,
size_t max_up,
weight_t threshold,
bool adjust,
+ bool all,
WriteFile& output,
size_t id)
{
@@ -44,9 +45,11 @@ updates_multipartite(vector<Hyp>* sample,
for (size_t i = 0; i < sep_hi; i++) {
for (size_t j = sep_hi; j < sz; j++) {
Hyp& first=(*sample)[i], second=(*sample)[j];
- if ((first.model-second.model)>margin
- || (first.gold==second.gold)
- || (threshold && (first.gold-second.gold < threshold)))
+ if (first.gold==second.gold)
+ continue;
+ if (!all
+ && (((first.model-second.model)>margin)
+ || (threshold && (first.gold-second.gold < threshold))))
continue;
if (output)
*output << id << "\t" << first.f-second.f << endl;
@@ -64,9 +67,11 @@ updates_multipartite(vector<Hyp>* sample,
for (size_t i = sep_hi; i < sep_lo; i++) {
for (size_t j = sep_lo; j < sz; j++) {
Hyp& first=(*sample)[i], second=(*sample)[j];
- if ((first.model-second.model)>margin
- || (first.gold==second.gold)
- || (threshold && (first.gold-second.gold < threshold)))
+ if (first.gold==second.gold)
+ continue;
+ if (!all
+ && (((first.model-second.model)>margin)
+ || (threshold && (first.gold-second.gold < threshold))))
continue;
if (output)
*output << id << "\t" << first.f-second.f << endl;
@@ -83,13 +88,16 @@ updates_multipartite(vector<Hyp>* sample,
* all pairs
* only ignore a pair if gold scores are
* identical
+ * FIXME: that's really _all_
*
*/
inline size_t
updates_all(vector<Hyp>* sample,
SparseVector<weight_t>& updates,
size_t max_up,
+ weight_t margin,
weight_t threshold,
+ bool all,
WriteFile output,
size_t id)
{
@@ -102,8 +110,11 @@ updates_all(vector<Hyp>* sample,
for (size_t i = 0; i < sz-1; i++) {
for (size_t j = i+1; j < sz; j++) {
Hyp& first=(*sample)[i], second=(*sample)[j];
- if ((first.gold == second.gold)
- || (threshold && (first.gold-second.gold < threshold)))
+ if (first.gold == second.gold)
+ continue;
+ if (!all
+ && (((first.model-second.model)>margin)
+ || (threshold && (first.gold-second.gold < threshold))))
continue;
if (output)
*output << id << "\t" << first.f-second.f << endl;