summary | refs | log | tree | commit | diff
path: root/dtrain
diff options
context:
space:
mode:
author Patrick Simianer <p@simianer.de> 2011-10-13 23:50:28 +0200
committer Patrick Simianer <p@simianer.de> 2011-10-13 23:50:28 +0200
commit 5a5b00f2ad1ef2cb50e9c58bcb77246f3ed99057 (patch)
tree 4a06fc2df328afc9151d3234cd3c19be991d0767 /dtrain
parent c7735ab60e22bfec7245dc7af7f14b74459dada8 (diff)
fixed approx bleu
Diffstat (limited to 'dtrain')
-rw-r--r-- dtrain/dtrain.cc 7
-rw-r--r-- dtrain/kbestget.h 21
-rw-r--r-- dtrain/ksampler.h 2
-rw-r--r-- dtrain/score.cc 56
-rw-r--r-- dtrain/score.h 42
-rw-r--r-- dtrain/test/example/dtrain.ini 1
6 files changed, 61 insertions, 68 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 25858738..f679c9f6 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -105,12 +105,13 @@ main(int argc, char** argv)
string scorer_str = cfg["scorer"].as<string>();
LocalScorer* scorer;
if (scorer_str == "bleu") {
+ scorer = dynamic_cast<BleuScorer*>(new BleuScorer);
} else if (scorer_str == "stupid_bleu") {
scorer = dynamic_cast<StupidBleuScorer*>(new StupidBleuScorer);
} else if (scorer_str == "smooth_bleu") {
scorer = dynamic_cast<SmoothBleuScorer*>(new SmoothBleuScorer);
} else if (scorer_str == "approx_bleu") {
- scorer = dynamic_cast<StupidBleuScorer*>(new StupidBleuScorer); // FIXME
+ scorer = dynamic_cast<ApproxBleuScorer*>(new ApproxBleuScorer(N));
} else {
cerr << "Don't know scoring metric: '" << scorer_str << "', exiting." << endl;
exit(1);
@@ -145,7 +146,7 @@ main(int argc, char** argv)
// input
string input_fn = cfg["input"].as<string>();
ReadFile input(input_fn);
- // buffer input for t > 0
+ // buffer input for t > 0
vector<string> src_str_buf; // source strings
vector<vector<WordID> > ref_ids_buf; // references as WordID vecs
vector<string> weights_files; // remember weights for each iteration
@@ -341,6 +342,8 @@ main(int argc, char** argv)
} // input loop
+ if (scorer_str == "approx_bleu") scorer->Reset();
+
if (t == 0) {
in_sz = ii; // remember size of input (# lines)
grammar_buf_out.close();
diff --git a/dtrain/kbestget.h b/dtrain/kbestget.h
index c0fd3f47..d141da60 100644
--- a/dtrain/kbestget.h
+++ b/dtrain/kbestget.h
@@ -14,7 +14,7 @@ namespace dtrain
{
-typedef double score_t; // float
+typedef double score_t;
struct ScoredHyp
{
@@ -31,9 +31,11 @@ struct LocalScorer
vector<score_t> w_;
virtual score_t
- Score(vector<WordID>& hyp, vector<WordID>& ref)=0;
+ Score(vector<WordID>& hyp, vector<WordID>& ref, const unsigned rank)=0;
- void
+ void Reset() {} // only for approx bleu
+
+ inline void
Init(unsigned N, vector<score_t> weights)
{
assert(N > 0);
@@ -42,7 +44,7 @@ struct LocalScorer
else w_ = weights;
}
- score_t
+ inline score_t
brevity_penaly(const unsigned hyp_len, const unsigned ref_len)
{
if (hyp_len > ref_len) return 1;
@@ -55,11 +57,10 @@ struct HypSampler : public DecoderObserver
LocalScorer* scorer_;
vector<WordID>* ref_;
virtual vector<ScoredHyp>* GetSamples()=0;
- void SetScorer(LocalScorer* scorer) { scorer_ = scorer; }
- void SetRef(vector<WordID>& ref) { ref_ = &ref; }
+ inline void SetScorer(LocalScorer* scorer) { scorer_ = scorer; }
+ inline void SetRef(vector<WordID>& ref) { ref_ = &ref; }
};
-/////////////////////////////////////////////////////////////////////
-// wtf
+///////////////////////////////////////////////////////////////////////////////
@@ -107,7 +108,7 @@ struct KBestGetter : public HypSampler
h.f = d->feature_values;
h.model = log(d->score);
h.rank = i;
- h.score = scorer_->Score(h.w, *ref_);
+ h.score = scorer_->Score(h.w, *ref_, i);
s_.push_back(h);
}
}
@@ -126,7 +127,7 @@ struct KBestGetter : public HypSampler
h.f = d->feature_values;
h.model = log(d->score);
h.rank = i;
- h.score = scorer_->Score(h.w, *ref_);
+ h.score = scorer_->Score(h.w, *ref_, i);
s_.push_back(h);
}
}
diff --git a/dtrain/ksampler.h b/dtrain/ksampler.h
index 7567f43a..276f2cc9 100644
--- a/dtrain/ksampler.h
+++ b/dtrain/ksampler.h
@@ -37,7 +37,7 @@ struct KSampler : public HypSampler
h.f = samples[i].fmap;
h.model = log(samples[i].model_score);
h.rank = i;
- h.score = scorer_->Score(h.w, *ref_);
+ h.score = scorer_->Score(h.w, *ref_, i);
s_.push_back(h);
}
}
diff --git a/dtrain/score.cc b/dtrain/score.cc
index 93c4e80b..f5e920a0 100644
--- a/dtrain/score.cc
+++ b/dtrain/score.cc
@@ -28,7 +28,8 @@ BleuScorer::Bleu(NgramCounts& counts, const unsigned hyp_len, const unsigned ref
}
score_t
-BleuScorer::Score(vector<WordID>& hyp, vector<WordID>& ref)
+BleuScorer::Score(vector<WordID>& hyp, vector<WordID>& ref,
+ const unsigned rank)
{
unsigned hyp_len = hyp.size(), ref_len = ref.size();
if (hyp_len == 0 || ref_len == 0) return 0;
@@ -47,7 +48,8 @@ BleuScorer::Score(vector<WordID>& hyp, vector<WordID>& ref)
* NOTE: 0 iff no 1gram match
*/
score_t
-StupidBleuScorer::Score(vector<WordID>& hyp, vector<WordID>& ref)
+StupidBleuScorer::Score(vector<WordID>& hyp, vector<WordID>& ref,
+ const unsigned rank)
{
unsigned hyp_len = hyp.size(), ref_len = ref.size();
if (hyp_len == 0 || ref_len == 0) return 0;
@@ -72,7 +74,8 @@ StupidBleuScorer::Score(vector<WordID>& hyp, vector<WordID>& ref)
* NOTE: max is 0.9375
*/
score_t
-SmoothBleuScorer::Score(vector<WordID>& hyp, vector<WordID>& ref)
+SmoothBleuScorer::Score(vector<WordID>& hyp, vector<WordID>& ref,
+ const unsigned rank)
{
unsigned hyp_len = hyp.size(), ref_len = ref.size();
if (hyp_len == 0 || ref_len == 0) return 0;
@@ -87,7 +90,6 @@ SmoothBleuScorer::Score(vector<WordID>& hyp, vector<WordID>& ref)
return brevity_penaly(hyp_len, ref_len) * sum;
}
-// FIXME
/*
* approx. bleu
*
@@ -95,38 +97,28 @@ SmoothBleuScorer::Score(vector<WordID>& hyp, vector<WordID>& ref)
* and Structural Translation Features"
* (Chiang et al. '08)
*/
-/*void
-ApproxBleuScorer::Prep(NgramCounts& counts, const unsigned hyp_len, const unsigned ref_len)
-{
- glob_onebest_counts += counts;
- glob_hyp_len += hyp_len;
- glob_ref_len += ref_len;
-}
-
-void
-ApproxBleuScorer::Reset()
-{
- glob_onebest_counts.Zero();
- glob_hyp_len = 0;
- glob_ref_len = 0;
-}
-
score_t
-ApproxBleuScorer::Score(ScoredHyp& hyp, vector<WordID>& ref_ids, unsigned id)
+ApproxBleuScorer::Score(vector<WordID>& hyp, vector<WordID>& ref,
+ const unsigned rank)
{
- NgramCounts counts = make_ngram_counts(hyp.w, ref_ids, N_);
- if (id == 0) reset();
- unsigned hyp_len = 0, ref_len = 0;
- if (hyp.rank == 0) { // 'context of 1best translations'
- scorer->prep(counts, hyp.w.size(), ref_ids.size());
- counts.reset();
+ unsigned hyp_len = hyp.size(), ref_len = ref.size();
+ if (hyp_len == 0 || ref_len == 0) return 0;
+ NgramCounts counts = make_ngram_counts(hyp, ref, N_);
+ NgramCounts tmp(N_);
+ if (rank == 0) { // 'context of 1best translations'
+ glob_onebest_counts += counts;
+ glob_hyp_len += hyp_len;
+ glob_ref_len += ref_len;
+ hyp_len = glob_hyp_len;
+ ref_len = glob_ref_len;
+ tmp = glob_onebest_counts;
} else {
- hyp_len = hyp.w.size();
- ref_len = ref_ids.size();
+ hyp_len = hyp.size();
+ ref_len = ref.size();
+ tmp = glob_onebest_counts + counts;
}
- return 0.9 * BleuScorer::Bleu(glob_onebest_counts + counts,
- glob_hyp_len + hyp_len, glob_ref_len + ref_len);
-}*/
+ return 0.9 * Bleu(tmp, hyp_len, ref_len);
+}
} // namespace
diff --git a/dtrain/score.h b/dtrain/score.h
index 9af56ef9..85cd0317 100644
--- a/dtrain/score.h
+++ b/dtrain/score.h
@@ -17,7 +17,7 @@ struct NgramCounts
NgramCounts(const unsigned N) : N_(N) { Zero(); }
- void
+ inline void
operator+=(const NgramCounts& rhs)
{
assert(N_ == rhs.N_);
@@ -27,7 +27,7 @@ struct NgramCounts
}
}
- const NgramCounts
+ inline const NgramCounts
operator+(const NgramCounts &other) const
{
NgramCounts result = *this;
@@ -35,8 +35,8 @@ struct NgramCounts
return result;
}
- void
- Add(unsigned count, unsigned ref_count, unsigned i)
+ inline void
+ Add(const unsigned count, const unsigned ref_count, const unsigned i)
{
assert(i < N_);
if (count > ref_count) {
@@ -47,7 +47,7 @@ struct NgramCounts
sum[i] += count;
}
- void
+ inline void
Zero()
{
unsigned i;
@@ -57,7 +57,7 @@ struct NgramCounts
}
}
- void
+ inline void
Print()
{
for (unsigned i = 0; i < N_; i++) {
@@ -106,38 +106,36 @@ make_ngram_counts(const vector<WordID>& hyp, const vector<WordID>& ref, const un
struct BleuScorer : public LocalScorer
{
score_t Bleu(NgramCounts& counts, const unsigned hyp_len, const unsigned ref_len);
- score_t Score(vector<WordID>& hyp, vector<WordID>& ref_ids);
+ score_t Score(vector<WordID>& hyp, vector<WordID>& ref, const unsigned rank);
};
struct StupidBleuScorer : public LocalScorer
{
- score_t Score(vector<WordID>& hyp, vector<WordID>& ref);
+ score_t Score(vector<WordID>& hyp, vector<WordID>& ref, const unsigned rank);
};
struct SmoothBleuScorer : public LocalScorer
{
- score_t Score(vector<WordID>& hyp, vector<WordID>& ref);
+ score_t Score(vector<WordID>& hyp, vector<WordID>& ref, const unsigned rank);
};
-// FIXME
-/*struct ApproxBleuScorer : public LocalScorer
+struct ApproxBleuScorer : public BleuScorer
{
- bool prepped;
-
- NgramCounts* glob_onebest_counts;
+ NgramCounts glob_onebest_counts;
unsigned glob_hyp_len, glob_ref_len;
- void Prep(NgramCounts& counts, const unsigned hyp_len, const unsigned ref_len);
- void Reset();
- score_t Score(ScoredHyp& hyp, vector<WordID>& ref_ids, unsigned id);
-
- ApproxBleuScorer()
+ ApproxBleuScorer(unsigned N) : glob_onebest_counts(NgramCounts(N))
{
+ glob_hyp_len = glob_ref_len = 0;
+ }
+
+ inline void Reset() {
glob_onebest_counts.Zero();
- glob_hyp_len = 0;
- glob_ref_len = 0;
+ glob_hyp_len = glob_ref_len = 0;
}
-};*/
+
+ score_t Score(vector<WordID>& hyp, vector<WordID>& ref, const unsigned rank);
+};
} // namespace
diff --git a/dtrain/test/example/dtrain.ini b/dtrain/test/example/dtrain.ini
index 3e5c2cd1..1e841824 100644
--- a/dtrain/test/example/dtrain.ini
+++ b/dtrain/test/example/dtrain.ini
@@ -2,7 +2,6 @@ decoder_config=test/example/cdec.ini
k=100
N=3
gamma=0
-#gamma=0.00001
epochs=4
input=test/example/nc-1k-tabs.gz
scorer=stupid_bleu