summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/Makefile.am2
-rw-r--r--training/model1.cc25
-rw-r--r--training/ttables.cc31
-rw-r--r--training/ttables.h86
4 files changed, 131 insertions, 13 deletions
diff --git a/training/Makefile.am b/training/Makefile.am
index 2679adea..7cdf10d7 100644
--- a/training/Makefile.am
+++ b/training/Makefile.am
@@ -34,7 +34,7 @@ online_train_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutil
atools_SOURCES = atools.cc
atools_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz
-model1_SOURCES = model1.cc
+model1_SOURCES = model1.cc ttables.cc
model1_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz
grammar_convert_SOURCES = grammar_convert.cc
diff --git a/training/model1.cc b/training/model1.cc
index f571700f..92a70985 100644
--- a/training/model1.cc
+++ b/training/model1.cc
@@ -1,4 +1,5 @@
#include <iostream>
+#include <cmath>
#include "lattice.h"
#include "stringlib.h"
@@ -14,7 +15,7 @@ int main(int argc, char** argv) {
return 1;
}
const int ITERATIONS = 5;
- const prob_t BEAM_THRESHOLD(0.0001);
+ const double BEAM_THRESHOLD = 0.0001;
TTable tt;
const WordID kNULL = TD::Convert("<eps>");
bool use_null = true;
@@ -24,7 +25,7 @@ int main(int argc, char** argv) {
cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl;
ReadFile rf(argv[1]);
istream& in = *rf.stream();
- prob_t likelihood = prob_t::One();
+ double likelihood = 0;
double denom = 0.0;
int lc = 0;
bool flag = false;
@@ -43,10 +44,10 @@ int main(int argc, char** argv) {
assert(src.size() > 0);
assert(trg.size() > 0);
denom += 1.0;
- vector<prob_t> probs(src.size() + 1);
+ vector<double> probs(src.size() + 1);
for (int j = 0; j < trg.size(); ++j) {
const WordID& f_j = trg[j][0].label;
- prob_t sum = prob_t::Zero();
+ double sum = 0;
if (use_null) {
probs[0] = tt.prob(kNULL, f_j);
sum += probs[0];
@@ -57,7 +58,7 @@ int main(int argc, char** argv) {
}
if (final_iteration) {
WordID max_i = 0;
- prob_t max_p = prob_t::Zero();
+ double max_p = -1;
if (use_null) {
max_i = kNULL;
max_p = probs[0];
@@ -75,23 +76,23 @@ int main(int argc, char** argv) {
for (int i = 1; i <= src.size(); ++i)
tt.Increment(src[i-1][0].label, f_j, probs[i] / sum);
}
- likelihood *= sum;
+ likelihood += log(sum);
}
}
if (flag) { cerr << endl; }
- cerr << " log likelihood: " << log(likelihood) << endl;
- cerr << " cross entopy: " << (-log(likelihood) / denom) << endl;
- cerr << " perplexity: " << pow(2.0, -log(likelihood) / denom) << endl;
+ cerr << " log likelihood: " << likelihood << endl;
+ cerr << " cross entopy: " << (-likelihood / denom) << endl;
+ cerr << " perplexity: " << pow(2.0, -likelihood / denom) << endl;
if (!final_iteration) tt.Normalize();
}
for (TTable::Word2Word2Double::iterator ei = tt.ttable.begin(); ei != tt.ttable.end(); ++ei) {
const TTable::Word2Double& cpd = ei->second;
const TTable::Word2Double& vit = was_viterbi[ei->first];
const string& esym = TD::Convert(ei->first);
- prob_t max_p = prob_t::Zero();
+ double max_p = -1;
for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi)
- if (fi->second > max_p) max_p = prob_t(fi->second);
- const prob_t threshold = max_p * BEAM_THRESHOLD;
+ if (fi->second > max_p) max_p = fi->second;
+ const double threshold = max_p * BEAM_THRESHOLD;
for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) {
if (fi->second > threshold || (vit.count(fi->first) > 0)) {
cout << esym << ' ' << TD::Convert(fi->first) << ' ' << log(fi->second) << endl;
diff --git a/training/ttables.cc b/training/ttables.cc
new file mode 100644
index 00000000..45bf14c5
--- /dev/null
+++ b/training/ttables.cc
@@ -0,0 +1,31 @@
+#include "ttables.h"
+
+#include <cassert>
+
+#include "dict.h"
+
+using namespace std;
+using namespace std::tr1;
+
+void TTable::DeserializeProbsFromText(std::istream* in) {
+ int c = 0;
+ while(*in) {
+ string e;
+ string f;
+ double p;
+ (*in) >> e >> f >> p;
+ if (e.empty()) break;
+ ++c;
+ ttable[TD::Convert(e)][TD::Convert(f)] = p;
+ }
+ cerr << "Loaded " << c << " translation parameters.\n";
+}
+
+void TTable::SerializeHelper(string* out, const Word2Word2Double& o) {
+ assert(!"not implemented");
+}
+
+void TTable::DeserializeHelper(const string& in, Word2Word2Double* o) {
+ assert(!"not implemented");
+}
+
diff --git a/training/ttables.h b/training/ttables.h
new file mode 100644
index 00000000..04e54f9d
--- /dev/null
+++ b/training/ttables.h
@@ -0,0 +1,86 @@
+#ifndef _TTABLES_H_
+#define _TTABLES_H_
+
+#include <iostream>
+#include <tr1/unordered_map>
+
+#include "wordid.h"
+#include "tdict.h"
+
+class TTable {
+ public:
+ TTable() {}
+ typedef std::tr1::unordered_map<WordID, double> Word2Double;
+ typedef std::tr1::unordered_map<WordID, Word2Double> Word2Word2Double;
+ inline const double prob(const int& e, const int& f) const {
+ const Word2Word2Double::const_iterator cit = ttable.find(e);
+ if (cit != ttable.end()) {
+ const Word2Double& cpd = cit->second;
+ const Word2Double::const_iterator it = cpd.find(f);
+ if (it == cpd.end()) return 1e-9;
+ return it->second;
+ } else {
+ return 1e-9;
+ }
+ }
+ inline void Increment(const int& e, const int& f) {
+ counts[e][f] += 1.0;
+ }
+ inline void Increment(const int& e, const int& f, double x) {
+ counts[e][f] += x;
+ }
+ void Normalize() {
+ ttable.swap(counts);
+ for (Word2Word2Double::iterator cit = ttable.begin();
+ cit != ttable.end(); ++cit) {
+ double tot = 0;
+ Word2Double& cpd = cit->second;
+ for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
+ tot += it->second;
+ for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
+ it->second /= tot;
+ }
+ counts.clear();
+ }
+ // adds counts from another TTable - probabilities remain unchanged
+ TTable& operator+=(const TTable& rhs) {
+ for (Word2Word2Double::const_iterator it = rhs.counts.begin();
+ it != rhs.counts.end(); ++it) {
+ const Word2Double& cpd = it->second;
+ Word2Double& tgt = counts[it->first];
+ for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) {
+ tgt[j->first] += j->second;
+ }
+ }
+ return *this;
+ }
+ void ShowTTable() {
+ for (Word2Word2Double::iterator it = ttable.begin(); it != ttable.end(); ++it) {
+ Word2Double& cpd = it->second;
+ for (Word2Double::iterator j = cpd.begin(); j != cpd.end(); ++j) {
+ std::cerr << "P(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl;
+ }
+ }
+ }
+ void ShowCounts() {
+ for (Word2Word2Double::iterator it = counts.begin(); it != counts.end(); ++it) {
+ Word2Double& cpd = it->second;
+ for (Word2Double::iterator j = cpd.begin(); j != cpd.end(); ++j) {
+ std::cerr << "c(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl;
+ }
+ }
+ }
+ void DeserializeProbsFromText(std::istream* in);
+ void SerializeCounts(std::string* out) const { SerializeHelper(out, counts); }
+ void DeserializeCounts(const std::string& in) { DeserializeHelper(in, &counts); }
+ void SerializeProbs(std::string* out) const { SerializeHelper(out, ttable); }
+ void DeserializeProbs(const std::string& in) { DeserializeHelper(in, &ttable); }
+ private:
+ static void SerializeHelper(std::string*, const Word2Word2Double& o);
+ static void DeserializeHelper(const std::string&, Word2Word2Double* o);
+ public:
+ Word2Word2Double ttable;
+ Word2Word2Double counts;
+};
+
+#endif