diff options
author | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-09-29 21:06:33 +0000 |
---|---|---|
committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-09-29 21:06:33 +0000 |
commit | 446fde8f67d4ad8c2699a8e9327a8988c3380723 (patch) | |
tree | d0e968a2dbb84dab205f6df17c6ffab1331fd55f /training/ttables.h | |
parent | 252d628d38b792acbd98e2f74e129994cc9fe3b4 (diff) |
move model 1 code
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@665 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'training/ttables.h')
-rw-r--r-- | training/ttables.h | 86 |
1 files changed, 86 insertions, 0 deletions
diff --git a/training/ttables.h b/training/ttables.h new file mode 100644 index 00000000..04e54f9d --- /dev/null +++ b/training/ttables.h @@ -0,0 +1,86 @@ +#ifndef _TTABLES_H_ +#define _TTABLES_H_ + +#include <iostream> +#include <tr1/unordered_map> + +#include "wordid.h" +#include "tdict.h" + +class TTable { + public: + TTable() {} + typedef std::tr1::unordered_map<WordID, double> Word2Double; + typedef std::tr1::unordered_map<WordID, Word2Double> Word2Word2Double; + inline const double prob(const int& e, const int& f) const { + const Word2Word2Double::const_iterator cit = ttable.find(e); + if (cit != ttable.end()) { + const Word2Double& cpd = cit->second; + const Word2Double::const_iterator it = cpd.find(f); + if (it == cpd.end()) return 1e-9; + return it->second; + } else { + return 1e-9; + } + } + inline void Increment(const int& e, const int& f) { + counts[e][f] += 1.0; + } + inline void Increment(const int& e, const int& f, double x) { + counts[e][f] += x; + } + void Normalize() { + ttable.swap(counts); + for (Word2Word2Double::iterator cit = ttable.begin(); + cit != ttable.end(); ++cit) { + double tot = 0; + Word2Double& cpd = cit->second; + for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) + tot += it->second; + for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) + it->second /= tot; + } + counts.clear(); + } + // adds counts from another TTable - probabilities remain unchanged + TTable& operator+=(const TTable& rhs) { + for (Word2Word2Double::const_iterator it = rhs.counts.begin(); + it != rhs.counts.end(); ++it) { + const Word2Double& cpd = it->second; + Word2Double& tgt = counts[it->first]; + for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) { + tgt[j->first] += j->second; + } + } + return *this; + } + void ShowTTable() { + for (Word2Word2Double::iterator it = ttable.begin(); it != ttable.end(); ++it) { + Word2Double& cpd = it->second; + for (Word2Double::iterator j = cpd.begin(); j != cpd.end(); ++j) { + std::cerr << "P(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl; + } + } + } + void ShowCounts() { + for (Word2Word2Double::iterator it = counts.begin(); it != counts.end(); ++it) { + Word2Double& cpd = it->second; + for (Word2Double::iterator j = cpd.begin(); j != cpd.end(); ++j) { + std::cerr << "c(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl; + } + } + } + void DeserializeProbsFromText(std::istream* in); + void SerializeCounts(std::string* out) const { SerializeHelper(out, counts); } + void DeserializeCounts(const std::string& in) { DeserializeHelper(in, &counts); } + void SerializeProbs(std::string* out) const { SerializeHelper(out, ttable); } + void DeserializeProbs(const std::string& in) { DeserializeHelper(in, &ttable); } + private: + static void SerializeHelper(std::string*, const Word2Word2Double& o); + static void DeserializeHelper(const std::string&, Word2Word2Double* o); + public: + Word2Word2Double ttable; + Word2Word2Double counts; +}; + +#endif |