diff options
Diffstat (limited to 'training/candidate_set.cc')
-rw-r--r-- | training/candidate_set.cc | 169 |
1 files changed, 0 insertions, 169 deletions
diff --git a/training/candidate_set.cc b/training/candidate_set.cc deleted file mode 100644 index 087efec3..00000000 --- a/training/candidate_set.cc +++ /dev/null @@ -1,169 +0,0 @@ -#include "candidate_set.h" - -#include <tr1/unordered_set> - -#include <boost/functional/hash.hpp> - -#include "verbose.h" -#include "ns.h" -#include "filelib.h" -#include "wordid.h" -#include "tdict.h" -#include "hg.h" -#include "kbest.h" -#include "viterbi.h" - -using namespace std; - -namespace training { - -struct ApproxVectorHasher { - static const size_t MASK = 0xFFFFFFFFull; - union UType { - double f; // leave as double - size_t i; - }; - static inline double round(const double x) { - UType t; - t.f = x; - size_t r = t.i & MASK; - if ((r << 1) > MASK) - t.i += MASK - r + 1; - else - t.i &= (1ull - MASK); - return t.f; - } - size_t operator()(const SparseVector<double>& x) const { - size_t h = 0x573915839; - for (SparseVector<double>::const_iterator it = x.begin(); it != x.end(); ++it) { - UType t; - t.f = it->second; - if (t.f) { - size_t z = (t.i >> 32); - boost::hash_combine(h, it->first); - boost::hash_combine(h, z); - } - } - return h; - } -}; - -struct ApproxVectorEquals { - bool operator()(const SparseVector<double>& a, const SparseVector<double>& b) const { - SparseVector<double>::const_iterator bit = b.begin(); - for (SparseVector<double>::const_iterator ait = a.begin(); ait != a.end(); ++ait) { - if (bit == b.end() || - ait->first != bit->first || - ApproxVectorHasher::round(ait->second) != ApproxVectorHasher::round(bit->second)) - return false; - ++bit; - } - if (bit != b.end()) return false; - return true; - } -}; - -struct CandidateCompare { - bool operator()(const Candidate& a, const Candidate& b) const { - ApproxVectorEquals eq; - return (a.ewords == b.ewords && eq(a.fmap,b.fmap)); - } -}; - -struct CandidateHasher { - size_t operator()(const Candidate& x) const { - boost::hash<vector<WordID> > hhasher; - ApproxVectorHasher vhasher; - size_t ha = hhasher(x.ewords); - boost::hash_combine(ha, vhasher(x.fmap)); - return ha; - } -}; - -static void ParseSparseVector(string& line, size_t cur, SparseVector<double>* out) { - SparseVector<double>& x = *out; - size_t last_start = cur; - size_t last_comma = string::npos; - while(cur <= line.size()) { - if (line[cur] == ' ' || cur == line.size()) { - if (!(cur > last_start && last_comma != string::npos && cur > last_comma)) { - cerr << "[ERROR] " << line << endl << " position = " << cur << endl; - exit(1); - } - const int fid = FD::Convert(line.substr(last_start, last_comma - last_start)); - if (cur < line.size()) line[cur] = 0; - const double val = strtod(&line[last_comma + 1], NULL); - x.set_value(fid, val); - - last_comma = string::npos; - last_start = cur+1; - } else { - if (line[cur] == '=') - last_comma = cur; - } - ++cur; - } -} - -void CandidateSet::WriteToFile(const string& file) const { - WriteFile wf(file); - ostream& out = *wf.stream(); - out.precision(10); - string ss; - for (unsigned i = 0; i < cs.size(); ++i) { - out << TD::GetString(cs[i].ewords) << endl; - out << cs[i].fmap << endl; - cs[i].eval_feats.Encode(&ss); - out << ss << endl; - } -} - -void CandidateSet::ReadFromFile(const string& file) { - if(!SILENT) cerr << "Reading candidates from " << file << endl; - ReadFile rf(file); - istream& in = *rf.stream(); - string cand; - string feats; - string ss; - while(getline(in, cand)) { - getline(in, feats); - getline(in, ss); - assert(in); - cs.push_back(Candidate()); - TD::ConvertSentence(cand, &cs.back().ewords); - ParseSparseVector(feats, 0, &cs.back().fmap); - cs.back().eval_feats = SufficientStats(ss); - } - if(!SILENT) cerr << " read " << cs.size() << " candidates\n"; -} - -void CandidateSet::Dedup() { - if(!SILENT) cerr << "Dedup in=" << cs.size(); - tr1::unordered_set<Candidate, CandidateHasher, CandidateCompare> u; - while(cs.size() > 0) { - u.insert(cs.back()); - cs.pop_back(); - } - tr1::unordered_set<Candidate, CandidateHasher, CandidateCompare>::iterator it = u.begin(); - while (it != u.end()) { - cs.push_back(*it); - it = u.erase(it); - } - if(!SILENT) cerr << " out=" << cs.size() << endl; -} - -void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer) { - KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg, kbest_size); - - for (unsigned i = 0; i < kbest_size; ++i) { - const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d = - kbest.LazyKthBest(hg.nodes_.size() - 1, i); - if (!d) break; - cs.push_back(Candidate(d->yield, d->feature_values)); - if (scorer) - scorer->Evaluate(d->yield, &cs.back().eval_feats); - } - Dedup(); -} - -} |