From 1b8181bf0d6e9137e6b9ccdbe414aec37377a1a9 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 18 Nov 2012 13:35:42 -0500 Subject: major restructure of the training code --- training/utils/candidate_set.h | 60 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 training/utils/candidate_set.h (limited to 'training/utils/candidate_set.h') diff --git a/training/utils/candidate_set.h b/training/utils/candidate_set.h new file mode 100644 index 00000000..9d326ed0 --- /dev/null +++ b/training/utils/candidate_set.h @@ -0,0 +1,60 @@ +#ifndef _CANDIDATE_SET_H_ +#define _CANDIDATE_SET_H_ + +#include +#include + +#include "ns.h" +#include "wordid.h" +#include "sparse_vector.h" + +class Hypergraph; + +namespace training { + +struct Candidate { + Candidate() {} + Candidate(const std::vector& e, const SparseVector& fm) : + ewords(e), + fmap(fm) {} + Candidate(const std::vector& e, + const SparseVector& fm, + const SegmentEvaluator& se) : + ewords(e), + fmap(fm) { + se.Evaluate(ewords, &eval_feats); + } + + void swap(Candidate& other) { + eval_feats.swap(other.eval_feats); + ewords.swap(other.ewords); + fmap.swap(other.fmap); + } + + std::vector ewords; + SparseVector fmap; + SufficientStats eval_feats; +}; + +// represents some kind of collection of translation candidates, e.g. +// aggregated k-best lists, sample lists, etc. +class CandidateSet { + public: + CandidateSet() {} + inline size_t size() const { return cs.size(); } + const Candidate& operator[](size_t i) const { return cs[i]; } + + void ReadFromFile(const std::string& file); + void WriteToFile(const std::string& file) const; + void AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer = NULL); + // TODO add code to do unique k-best + // TODO add code to draw k samples + + private: + void Dedup(); + std::vector cs; +}; + +} + +#endif -- cgit v1.2.3