summaryrefslogtreecommitdiff
path: root/training/utils
diff options
context:
space:
mode:
Diffstat (limited to 'training/utils')
-rw-r--r--training/utils/candidate_set.cc15
-rw-r--r--training/utils/candidate_set.h2
2 files changed, 16 insertions, 1 deletions
diff --git a/training/utils/candidate_set.cc b/training/utils/candidate_set.cc
index 33dae9a3..36f5b271 100644
--- a/training/utils/candidate_set.cc
+++ b/training/utils/candidate_set.cc
@@ -171,4 +171,19 @@ void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, c
Dedup();
}
+void CandidateSet::AddUniqueKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer) {
+ typedef KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique> K;
+ K kbest(hg, kbest_size);
+
+ for (unsigned i = 0; i < kbest_size; ++i) {
+ const K::Derivation* d =
+ kbest.LazyKthBest(hg.nodes_.size() - 1, i);
+ if (!d) break;
+ cs.push_back(Candidate(d->yield, d->feature_values));
+ if (scorer)
+ scorer->Evaluate(d->yield, &cs.back().eval_feats);
+ }
+ Dedup();
+}
+
}
diff --git a/training/utils/candidate_set.h b/training/utils/candidate_set.h
index 9d326ed0..17a650f5 100644
--- a/training/utils/candidate_set.h
+++ b/training/utils/candidate_set.h
@@ -47,7 +47,7 @@ class CandidateSet {
void ReadFromFile(const std::string& file);
void WriteToFile(const std::string& file) const;
void AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer = NULL);
- // TODO add code to do unique k-best
+ void AddUniqueKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer = NULL);
// TODO add code to draw k samples
private: