From d903f9e05865d134d272b87a8ca989f159188164 Mon Sep 17 00:00:00 2001
From: Patrick Simianer
Date: Tue, 26 Apr 2016 16:17:25 +0200
Subject: dtrain: all pairs
---
training/dtrain/dtrain.cc | 5 +++--
training/dtrain/dtrain.h | 4 +++-
training/dtrain/update.h | 27 +++++++++++++++++++--------
3 files changed, 25 insertions(+), 11 deletions(-)
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index bbe14547..30c64037 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -48,6 +48,7 @@ main(int argc, char** argv)
const string adadelta_output = conf["adadelta_output"].as();
const size_t max_input = conf["stop_after"].as();
const bool batch = conf["batch"].as();
+ const bool all = conf["all"].as();
// setup decoder
register_feature_functions();
@@ -267,14 +268,14 @@ main(int argc, char** argv)
num_up += update_structured(sample, updates, margin,
out_up, i);
else if (all_pairs)
- num_up += updates_all(sample, updates, max_up, threshold,
+ num_up += updates_all(sample, updates, max_up, margin, threshold, all,
out_up, i);
else if (pro)
num_up += updates_pro(sample, updates, cut, max_up, threshold,
out_up, i);
else
num_up += updates_multipartite(sample, updates, cut, margin,
- max_up, threshold, adjust_cut,
+ max_up, threshold, adjust_cut, all,
out_up, i);
SparseVector lambdas_copy;
diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h
index 883e6028..83dcd945 100644
--- a/training/dtrain/dtrain.h
+++ b/training/dtrain/dtrain.h
@@ -71,6 +71,8 @@ dtrain_init(int argc,
"use top/bottom 10% (default) of k-best as 'good' and 'bad' for pair sampling, 0 to use all pairs")
("adjust,A", po::bool_switch()->default_value(false),
"adjust cut for optimal pos. in k-best to cut")
+ ("all,A", po::bool_switch()->default_value(false),
+ "update using all pairs, ignoring margin and threshold")
("score,s", po::value()->default_value("nakov"),
"per-sentence BLEU (approx.)")
("nakov_fix", po::value()->default_value(1.0),
@@ -106,7 +108,7 @@ dtrain_init(int argc,
"output raw data (e.g. k-best lists) [to filename]")
("stop_after", po::value()->default_value(numeric_limits::max()),
"only look at this number of segments")
- ("print_weights,P", po::value()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV"),
+ ("print_weights", po::value()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV"),
"list of weights to print after each iteration");
po::options_description clopts("Command Line Options");
clopts.add_options()
diff --git a/training/dtrain/update.h b/training/dtrain/update.h
index 405a3f76..edcfe391 100644
--- a/training/dtrain/update.h
+++ b/training/dtrain/update.h
@@ -20,6 +20,7 @@ updates_multipartite(vector* sample,
size_t max_up,
weight_t threshold,
bool adjust,
+ bool all,
WriteFile& output,
size_t id)
{
@@ -44,9 +45,11 @@ updates_multipartite(vector* sample,
for (size_t i = 0; i < sep_hi; i++) {
for (size_t j = sep_hi; j < sz; j++) {
Hyp& first=(*sample)[i], second=(*sample)[j];
- if ((first.model-second.model)>margin
- || (first.gold==second.gold)
- || (threshold && (first.gold-second.gold < threshold)))
+ if (first.gold==second.gold)
+ continue;
+ if (!all
+ && (((first.model-second.model)>margin)
+ || (threshold && (first.gold-second.gold < threshold))))
continue;
if (output)
*output << id << "\t" << first.f-second.f << endl;
@@ -64,9 +67,11 @@ updates_multipartite(vector* sample,
for (size_t i = sep_hi; i < sep_lo; i++) {
for (size_t j = sep_lo; j < sz; j++) {
Hyp& first=(*sample)[i], second=(*sample)[j];
- if ((first.model-second.model)>margin
- || (first.gold==second.gold)
- || (threshold && (first.gold-second.gold < threshold)))
+ if (first.gold==second.gold)
+ continue;
+ if (!all
+ && (((first.model-second.model)>margin)
+ || (threshold && (first.gold-second.gold < threshold))))
continue;
if (output)
*output << id << "\t" << first.f-second.f << endl;
@@ -83,13 +88,16 @@ updates_multipartite(vector* sample,
* all pairs
* only ignore a pair if gold scores are
* identical
+ * FIXME: that's really _all_
*
*/
inline size_t
updates_all(vector* sample,
SparseVector& updates,
size_t max_up,
+ weight_t margin,
weight_t threshold,
+ bool all,
WriteFile output,
size_t id)
{
@@ -102,8 +110,11 @@ updates_all(vector* sample,
for (size_t i = 0; i < sz-1; i++) {
for (size_t j = i+1; j < sz; j++) {
Hyp& first=(*sample)[i], second=(*sample)[j];
- if ((first.gold == second.gold)
- || (threshold && (first.gold-second.gold < threshold)))
+ if (first.gold == second.gold)
+ continue;
+ if (!all
+ && (((first.model-second.model)>margin)
+ || (threshold && (first.gold-second.gold < threshold))))
continue;
if (output)
*output << id << "\t" << first.f-second.f << endl;
--
cgit v1.2.3