From 8aa29810bb77611cc20b7a384897ff6703783ea1 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 18 Nov 2012 13:35:42 -0500 Subject: major restructure of the training code --- training/utils/candidate_set.cc | 169 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 169 insertions(+) create mode 100644 training/utils/candidate_set.cc (limited to 'training/utils/candidate_set.cc') diff --git a/training/utils/candidate_set.cc b/training/utils/candidate_set.cc new file mode 100644 index 00000000..087efec3 --- /dev/null +++ b/training/utils/candidate_set.cc @@ -0,0 +1,169 @@ +#include "candidate_set.h" + +#include + +#include + +#include "verbose.h" +#include "ns.h" +#include "filelib.h" +#include "wordid.h" +#include "tdict.h" +#include "hg.h" +#include "kbest.h" +#include "viterbi.h" + +using namespace std; + +namespace training { + +struct ApproxVectorHasher { + static const size_t MASK = 0xFFFFFFFFull; + union UType { + double f; // leave as double + size_t i; + }; + static inline double round(const double x) { + UType t; + t.f = x; + size_t r = t.i & MASK; + if ((r << 1) > MASK) + t.i += MASK - r + 1; + else + t.i &= (1ull - MASK); + return t.f; + } + size_t operator()(const SparseVector& x) const { + size_t h = 0x573915839; + for (SparseVector::const_iterator it = x.begin(); it != x.end(); ++it) { + UType t; + t.f = it->second; + if (t.f) { + size_t z = (t.i >> 32); + boost::hash_combine(h, it->first); + boost::hash_combine(h, z); + } + } + return h; + } +}; + +struct ApproxVectorEquals { + bool operator()(const SparseVector& a, const SparseVector& b) const { + SparseVector::const_iterator bit = b.begin(); + for (SparseVector::const_iterator ait = a.begin(); ait != a.end(); ++ait) { + if (bit == b.end() || + ait->first != bit->first || + ApproxVectorHasher::round(ait->second) != ApproxVectorHasher::round(bit->second)) + return false; + ++bit; + } + if (bit != b.end()) return false; + return true; + } +}; + +struct CandidateCompare { + bool operator()(const Candidate& a, const Candidate& b) const { + ApproxVectorEquals eq; + return (a.ewords == b.ewords && eq(a.fmap,b.fmap)); + } +}; + +struct CandidateHasher { + size_t operator()(const Candidate& x) const { + boost::hash > hhasher; + ApproxVectorHasher vhasher; + size_t ha = hhasher(x.ewords); + boost::hash_combine(ha, vhasher(x.fmap)); + return ha; + } +}; + +static void ParseSparseVector(string& line, size_t cur, SparseVector* out) { + SparseVector& x = *out; + size_t last_start = cur; + size_t last_comma = string::npos; + while(cur <= line.size()) { + if (line[cur] == ' ' || cur == line.size()) { + if (!(cur > last_start && last_comma != string::npos && cur > last_comma)) { + cerr << "[ERROR] " << line << endl << " position = " << cur << endl; + exit(1); + } + const int fid = FD::Convert(line.substr(last_start, last_comma - last_start)); + if (cur < line.size()) line[cur] = 0; + const double val = strtod(&line[last_comma + 1], NULL); + x.set_value(fid, val); + + last_comma = string::npos; + last_start = cur+1; + } else { + if (line[cur] == '=') + last_comma = cur; + } + ++cur; + } +} + +void CandidateSet::WriteToFile(const string& file) const { + WriteFile wf(file); + ostream& out = *wf.stream(); + out.precision(10); + string ss; + for (unsigned i = 0; i < cs.size(); ++i) { + out << TD::GetString(cs[i].ewords) << endl; + out << cs[i].fmap << endl; + cs[i].eval_feats.Encode(&ss); + out << ss << endl; + } +} + +void CandidateSet::ReadFromFile(const string& file) { + if(!SILENT) cerr << "Reading candidates from " << file << endl; + ReadFile rf(file); + istream& in = *rf.stream(); + string cand; + string feats; + string ss; + while(getline(in, cand)) { + getline(in, feats); + getline(in, ss); + assert(in); + cs.push_back(Candidate()); + TD::ConvertSentence(cand, &cs.back().ewords); + ParseSparseVector(feats, 0, &cs.back().fmap); + cs.back().eval_feats = SufficientStats(ss); + } + if(!SILENT) cerr << " read " << cs.size() << " candidates\n"; +} + +void CandidateSet::Dedup() { + if(!SILENT) cerr << "Dedup in=" << cs.size(); + tr1::unordered_set u; + while(cs.size() > 0) { + u.insert(cs.back()); + cs.pop_back(); + } + tr1::unordered_set::iterator it = u.begin(); + while (it != u.end()) { + cs.push_back(*it); + it = u.erase(it); + } + if(!SILENT) cerr << " out=" << cs.size() << endl; +} + +void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer) { + KBest::KBestDerivations, ESentenceTraversal> kbest(hg, kbest_size); + + for (unsigned i = 0; i < kbest_size; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + cs.push_back(Candidate(d->yield, d->feature_values)); + if (scorer) + scorer->Evaluate(d->yield, &cs.back().eval_feats); + } + Dedup(); +} + +} -- cgit v1.2.3