summaryrefslogtreecommitdiff
path: root/dtrain/kbestget.h
diff options
context:
space:
mode:
Diffstat (limited to 'dtrain/kbestget.h')
-rw-r--r--dtrain/kbestget.h11
1 files changed, 9 insertions, 2 deletions
diff --git a/dtrain/kbestget.h b/dtrain/kbestget.h
index 77d4a139..dd8882e1 100644
--- a/dtrain/kbestget.h
+++ b/dtrain/kbestget.h
@@ -59,9 +59,12 @@ struct HypSampler : public DecoderObserver
{
LocalScorer* scorer_;
vector<WordID>* ref_;
+ unsigned f_count_, sz_;
virtual vector<ScoredHyp>* GetSamples()=0;
inline void SetScorer(LocalScorer* scorer) { scorer_ = scorer; }
inline void SetRef(vector<WordID>& ref) { ref_ = &ref; }
+ inline unsigned get_f_count() { return f_count_; }
+ inline unsigned get_sz() { return sz_; }
};
////////////////////////////////////////////////////////////////////////////////
@@ -100,7 +103,7 @@ struct KBestGetter : public HypSampler
void
KBestUnique(const Hypergraph& forest)
{
- s_.clear();
+ s_.clear(); sz_ = f_count_ = 0;
KBest::KBestDerivations<vector<WordID>, ESentenceTraversal,
KBest::FilterUnique, prob_t, EdgeProb> kbest(forest, k_);
for (unsigned i = 0; i < k_; ++i) {
@@ -115,13 +118,15 @@ struct KBestGetter : public HypSampler
h.rank = i;
h.score = scorer_->Score(h.w, *ref_, i, src_len_);
s_.push_back(h);
+ sz_++;
+ f_count_ += h.f.size();
}
}
void
KBestNoFilter(const Hypergraph& forest)
{
- s_.clear();
+ s_.clear(); sz_ = f_count_ = 0;
KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k_);
for (unsigned i = 0; i < k_; ++i) {
const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
@@ -134,6 +139,8 @@ struct KBestGetter : public HypSampler
h.rank = i;
h.score = scorer_->Score(h.w, *ref_, i, src_len_);
s_.push_back(h);
+ sz_++;
+ f_count_ += h.f.size();
}
}
};