diff options
author | Chris Dyer <cdyer@allegro.clab.cs.cmu.edu> | 2012-11-18 13:35:42 -0500 |
---|---|---|
committer | Chris Dyer <cdyer@allegro.clab.cs.cmu.edu> | 2012-11-18 13:35:42 -0500 |
commit | 8aa29810bb77611cc20b7a384897ff6703783ea1 (patch) | |
tree | 8635daa8fffb3f2cd90e30b41e27f4f9e0909447 /word-aligner/ttables.h | |
parent | fbdacabc85bea65d735f2cb7f92b98e08ce72d04 (diff) |
major restructure of the training code
Diffstat (limited to 'word-aligner/ttables.h')
-rw-r--r-- | word-aligner/ttables.h | 101 |
1 files changed, 101 insertions, 0 deletions
diff --git a/word-aligner/ttables.h b/word-aligner/ttables.h new file mode 100644 index 00000000..9baa13ca --- /dev/null +++ b/word-aligner/ttables.h @@ -0,0 +1,101 @@ +#ifndef _TTABLES_H_ +#define _TTABLES_H_ + +#include <iostream> +#include <tr1/unordered_map> + +#include "sparse_vector.h" +#include "m.h" +#include "wordid.h" +#include "tdict.h" + +class TTable { + public: + TTable() {} + typedef std::tr1::unordered_map<WordID, double> Word2Double; + typedef std::tr1::unordered_map<WordID, Word2Double> Word2Word2Double; + inline double prob(const int& e, const int& f) const { + const Word2Word2Double::const_iterator cit = ttable.find(e); + if (cit != ttable.end()) { + const Word2Double& cpd = cit->second; + const Word2Double::const_iterator it = cpd.find(f); + if (it == cpd.end()) return 1e-9; + return it->second; + } else { + return 1e-9; + } + } + inline void Increment(const int& e, const int& f) { + counts[e][f] += 1.0; + } + inline void Increment(const int& e, const int& f, double x) { + counts[e][f] += x; + } + void NormalizeVB(const double alpha) { + ttable.swap(counts); + for (Word2Word2Double::iterator cit = ttable.begin(); + cit != ttable.end(); ++cit) { + double tot = 0; + Word2Double& cpd = cit->second; + for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) + tot += it->second + alpha; + for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) + it->second = exp(Md::digamma(it->second + alpha) - Md::digamma(tot)); + } + counts.clear(); + } + void Normalize() { + ttable.swap(counts); + for (Word2Word2Double::iterator cit = ttable.begin(); + cit != ttable.end(); ++cit) { + double tot = 0; + Word2Double& cpd = cit->second; + for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) + tot += it->second; + for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) + it->second /= tot; + } + counts.clear(); + } + // adds counts from another TTable - probabilities remain unchanged + TTable& operator+=(const TTable& rhs) { + for (Word2Word2Double::const_iterator it = rhs.counts.begin(); + it != rhs.counts.end(); ++it) { + const Word2Double& cpd = it->second; + Word2Double& tgt = counts[it->first]; + for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) { + tgt[j->first] += j->second; + } + } + return *this; + } + void ShowTTable() const { + for (Word2Word2Double::const_iterator it = ttable.begin(); it != ttable.end(); ++it) { + const Word2Double& cpd = it->second; + for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) { + std::cerr << "P(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl; + } + } + } + void ShowCounts() const { + for (Word2Word2Double::const_iterator it = counts.begin(); it != counts.end(); ++it) { + const Word2Double& cpd = it->second; + for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) { + std::cerr << "c(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl; + } + } + } + void DeserializeProbsFromText(std::istream* in); + void SerializeCounts(std::string* out) const { SerializeHelper(out, counts); } + void DeserializeCounts(const std::string& in) { DeserializeHelper(in, &counts); } + void SerializeProbs(std::string* out) const { SerializeHelper(out, ttable); } + void DeserializeProbs(const std::string& in) { DeserializeHelper(in, &ttable); } + private: + static void SerializeHelper(std::string*, const Word2Word2Double& o); + static void DeserializeHelper(const std::string&, Word2Word2Double* o); + public: + Word2Word2Double ttable; + Word2Word2Double counts; +}; + +#endif |