summaryrefslogtreecommitdiff
path: root/training/utils/candidate_set.cc
diff options
context:
space:
mode:
authorAvneesh Saluja <asaluja@gmail.com>2013-03-28 18:28:16 -0700
committerAvneesh Saluja <asaluja@gmail.com>2013-03-28 18:28:16 -0700
commit3d8d656fa7911524e0e6885647173474524e0784 (patch)
tree81b1ee2fcb67980376d03f0aa48e42e53abff222 /training/utils/candidate_set.cc
parentbe7f57fdd484e063775d7abf083b9fa4c403b610 (diff)
parent96fedabebafe7a38a6d5928be8fff767e411d705 (diff)
fixed conflicts
Diffstat (limited to 'training/utils/candidate_set.cc')
-rw-r--r--training/utils/candidate_set.cc169
1 files changed, 169 insertions, 0 deletions
diff --git a/training/utils/candidate_set.cc b/training/utils/candidate_set.cc
new file mode 100644
index 00000000..087efec3
--- /dev/null
+++ b/training/utils/candidate_set.cc
@@ -0,0 +1,169 @@
+#include "candidate_set.h"
+
+#include <tr1/unordered_set>
+
+#include <boost/functional/hash.hpp>
+
+#include "verbose.h"
+#include "ns.h"
+#include "filelib.h"
+#include "wordid.h"
+#include "tdict.h"
+#include "hg.h"
+#include "kbest.h"
+#include "viterbi.h"
+
+using namespace std;
+
+namespace training {
+
+struct ApproxVectorHasher {
+ static const size_t MASK = 0xFFFFFFFFull;
+ union UType {
+ double f; // leave as double
+ size_t i;
+ };
+ static inline double round(const double x) {
+ UType t;
+ t.f = x;
+ size_t r = t.i & MASK;
+ if ((r << 1) > MASK)
+ t.i += MASK - r + 1;
+ else
+ t.i &= (1ull - MASK);
+ return t.f;
+ }
+ size_t operator()(const SparseVector<double>& x) const {
+ size_t h = 0x573915839;
+ for (SparseVector<double>::const_iterator it = x.begin(); it != x.end(); ++it) {
+ UType t;
+ t.f = it->second;
+ if (t.f) {
+ size_t z = (t.i >> 32);
+ boost::hash_combine(h, it->first);
+ boost::hash_combine(h, z);
+ }
+ }
+ return h;
+ }
+};
+
+struct ApproxVectorEquals {
+ bool operator()(const SparseVector<double>& a, const SparseVector<double>& b) const {
+ SparseVector<double>::const_iterator bit = b.begin();
+ for (SparseVector<double>::const_iterator ait = a.begin(); ait != a.end(); ++ait) {
+ if (bit == b.end() ||
+ ait->first != bit->first ||
+ ApproxVectorHasher::round(ait->second) != ApproxVectorHasher::round(bit->second))
+ return false;
+ ++bit;
+ }
+ if (bit != b.end()) return false;
+ return true;
+ }
+};
+
+struct CandidateCompare {
+ bool operator()(const Candidate& a, const Candidate& b) const {
+ ApproxVectorEquals eq;
+ return (a.ewords == b.ewords && eq(a.fmap,b.fmap));
+ }
+};
+
+struct CandidateHasher {
+ size_t operator()(const Candidate& x) const {
+ boost::hash<vector<WordID> > hhasher;
+ ApproxVectorHasher vhasher;
+ size_t ha = hhasher(x.ewords);
+ boost::hash_combine(ha, vhasher(x.fmap));
+ return ha;
+ }
+};
+
+static void ParseSparseVector(string& line, size_t cur, SparseVector<double>* out) {
+ SparseVector<double>& x = *out;
+ size_t last_start = cur;
+ size_t last_comma = string::npos;
+ while(cur <= line.size()) {
+ if (line[cur] == ' ' || cur == line.size()) {
+ if (!(cur > last_start && last_comma != string::npos && cur > last_comma)) {
+ cerr << "[ERROR] " << line << endl << " position = " << cur << endl;
+ exit(1);
+ }
+ const int fid = FD::Convert(line.substr(last_start, last_comma - last_start));
+ if (cur < line.size()) line[cur] = 0;
+ const double val = strtod(&line[last_comma + 1], NULL);
+ x.set_value(fid, val);
+
+ last_comma = string::npos;
+ last_start = cur+1;
+ } else {
+ if (line[cur] == '=')
+ last_comma = cur;
+ }
+ ++cur;
+ }
+}
+
+void CandidateSet::WriteToFile(const string& file) const {
+ WriteFile wf(file);
+ ostream& out = *wf.stream();
+ out.precision(10);
+ string ss;
+ for (unsigned i = 0; i < cs.size(); ++i) {
+ out << TD::GetString(cs[i].ewords) << endl;
+ out << cs[i].fmap << endl;
+ cs[i].eval_feats.Encode(&ss);
+ out << ss << endl;
+ }
+}
+
+void CandidateSet::ReadFromFile(const string& file) {
+ if(!SILENT) cerr << "Reading candidates from " << file << endl;
+ ReadFile rf(file);
+ istream& in = *rf.stream();
+ string cand;
+ string feats;
+ string ss;
+ while(getline(in, cand)) {
+ getline(in, feats);
+ getline(in, ss);
+ assert(in);
+ cs.push_back(Candidate());
+ TD::ConvertSentence(cand, &cs.back().ewords);
+ ParseSparseVector(feats, 0, &cs.back().fmap);
+ cs.back().eval_feats = SufficientStats(ss);
+ }
+ if(!SILENT) cerr << " read " << cs.size() << " candidates\n";
+}
+
+void CandidateSet::Dedup() {
+ if(!SILENT) cerr << "Dedup in=" << cs.size();
+ tr1::unordered_set<Candidate, CandidateHasher, CandidateCompare> u;
+ while(cs.size() > 0) {
+ u.insert(cs.back());
+ cs.pop_back();
+ }
+ tr1::unordered_set<Candidate, CandidateHasher, CandidateCompare>::iterator it = u.begin();
+ while (it != u.end()) {
+ cs.push_back(*it);
+ it = u.erase(it);
+ }
+ if(!SILENT) cerr << " out=" << cs.size() << endl;
+}
+
+void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer) {
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg, kbest_size);
+
+ for (unsigned i = 0; i < kbest_size; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest.LazyKthBest(hg.nodes_.size() - 1, i);
+ if (!d) break;
+ cs.push_back(Candidate(d->yield, d->feature_values));
+ if (scorer)
+ scorer->Evaluate(d->yield, &cs.back().eval_feats);
+ }
+ Dedup();
+}
+
+}