summary | refs | log | tree | commit | diff
path: root/word-aligner/ttables.h
diff options
context:
space:
mode:
Diffstat (limited to 'word-aligner/ttables.h')
-rw-r--r-- word-aligner/ttables.h 50
1 files changed, 25 insertions, 25 deletions
diff --git a/word-aligner/ttables.h b/word-aligner/ttables.h
index d82aff72..b9964225 100644
--- a/word-aligner/ttables.h
+++ b/word-aligner/ttables.h
@@ -2,6 +2,7 @@
#define _TTABLES_H_
#include <iostream>
+#include <vector>
#ifndef HAVE_OLD_CPP
# include <unordered_map>
#else
@@ -18,11 +19,10 @@ class TTable {
public:
TTable() {}
typedef std::unordered_map<WordID, double> Word2Double;
- typedef std::unordered_map<WordID, Word2Double> Word2Word2Double;
+ typedef std::vector<Word2Double> Word2Word2Double;
inline double prob(const int& e, const int& f) const {
- const Word2Word2Double::const_iterator cit = ttable.find(e);
- if (cit != ttable.end()) {
- const Word2Double& cpd = cit->second;
+ if (e < static_cast<int>(ttable.size())) {
+ const Word2Double& cpd = ttable[e];
const Word2Double::const_iterator it = cpd.find(f);
if (it == cpd.end()) return 1e-9;
return it->second;
@@ -31,19 +31,21 @@ class TTable {
}
}
inline void Increment(const int& e, const int& f) {
+ if (e >= static_cast<int>(counts.size())) counts.resize(e + 1);  // bound-check counts (the vector indexed below), not ttable, so counts[e] is always in range
counts[e][f] += 1.0;
}
inline void Increment(const int& e, const int& f, double x) {
+ if (e >= static_cast<int>(counts.size())) counts.resize(e + 1);
counts[e][f] += x;
}
void NormalizeVB(const double alpha) {
ttable.swap(counts);
- for (Word2Word2Double::iterator cit = ttable.begin();
- cit != ttable.end(); ++cit) {
+ for (unsigned i = 0; i < ttable.size(); ++i) {
double tot = 0;
- Word2Double& cpd = cit->second;
+ Word2Double& cpd = ttable[i];
for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
tot += it->second + alpha;
+ if (!tot) tot = 1;
for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
it->second = exp(Md::digamma(it->second + alpha) - Md::digamma(tot));
}
@@ -51,12 +53,12 @@ class TTable {
}
void Normalize() {
ttable.swap(counts);
- for (Word2Word2Double::iterator cit = ttable.begin();
- cit != ttable.end(); ++cit) {
+ for (unsigned i = 0; i < ttable.size(); ++i) {
double tot = 0;
- Word2Double& cpd = cit->second;
+ Word2Double& cpd = ttable[i];
for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
tot += it->second;
+ if (!tot) tot = 1;
for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
it->second /= tot;
}
@@ -64,29 +66,27 @@ class TTable {
}
// adds counts from another TTable - probabilities remain unchanged
TTable& operator+=(const TTable& rhs) {
- for (Word2Word2Double::const_iterator it = rhs.counts.begin();
- it != rhs.counts.end(); ++it) {
- const Word2Double& cpd = it->second;
- Word2Double& tgt = counts[it->first];
- for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) {
- tgt[j->first] += j->second;
- }
+ if (rhs.counts.size() > counts.size()) counts.resize(rhs.counts.size());
+ for (unsigned i = 0; i < rhs.counts.size(); ++i) {
+ const Word2Double& cpd = rhs.counts[i];
+ Word2Double& tgt = counts[i];
+ for (auto p : cpd) tgt[p.first] += p.second;
}
return *this;
}
void ShowTTable() const {
- for (Word2Word2Double::const_iterator it = ttable.begin(); it != ttable.end(); ++it) {
- const Word2Double& cpd = it->second;
- for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) {
- std::cerr << "P(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl;
+ for (unsigned it = 0; it < ttable.size(); ++it) {
+ const Word2Double& cpd = ttable[it];
+ for (auto& p : cpd) {
+ std::cerr << "P(" << TD::Convert(p.first) << '|' << TD::Convert(it) << ") = " << p.second << std::endl;  // "P(" as before the patch, so ttable dumps stay distinguishable from ShowCounts' "c(" lines
}
}
}
void ShowCounts() const {
- for (Word2Word2Double::const_iterator it = counts.begin(); it != counts.end(); ++it) {
- const Word2Double& cpd = it->second;
- for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) {
- std::cerr << "c(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl;
+ for (unsigned it = 0; it < counts.size(); ++it) {
+ const Word2Double& cpd = counts[it];
+ for (auto& p : cpd) {
+ std::cerr << "c(" << TD::Convert(p.first) << '|' << TD::Convert(it) << ") = " << p.second << std::endl;
}
}
}