summaryrefslogtreecommitdiff
path: root/training/model1.cc
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2009-12-03 16:33:55 -0500
committerChris Dyer <redpony@gmail.com>2009-12-03 16:33:55 -0500
commit671c21451542e2dd20e45b4033d44d8e8735f87b (patch)
treeb1773b077dd65b826f067a423d26f7942ce4e043 /training/model1.cc
initial check in
Diffstat (limited to 'training/model1.cc')
-rw-r--r--training/model1.cc103
1 files changed, 103 insertions, 0 deletions
diff --git a/training/model1.cc b/training/model1.cc
new file mode 100644
index 00000000..f571700f
--- /dev/null
+++ b/training/model1.cc
@@ -0,0 +1,103 @@
+#include <iostream>
+
+#include "lattice.h"
+#include "stringlib.h"
+#include "filelib.h"
+#include "ttables.h"
+#include "tdict.h"
+
+using namespace std;
+
+int main(int argc, char** argv) {
+ if (argc != 2) {
+ cerr << "Usage: " << argv[0] << " corpus.fr-en\n";
+ return 1;
+ }
+ const int ITERATIONS = 5;
+ const prob_t BEAM_THRESHOLD(0.0001);
+ TTable tt;
+ const WordID kNULL = TD::Convert("<eps>");
+ bool use_null = true;
+ TTable::Word2Word2Double was_viterbi;
+ for (int iter = 0; iter < ITERATIONS; ++iter) {
+ const bool final_iteration = (iter == (ITERATIONS - 1));
+ cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl;
+ ReadFile rf(argv[1]);
+ istream& in = *rf.stream();
+ prob_t likelihood = prob_t::One();
+ double denom = 0.0;
+ int lc = 0;
+ bool flag = false;
+ while(true) {
+ string line;
+ getline(in, line);
+ if (!in) break;
+ ++lc;
+ if (lc % 1000 == 0) { cerr << '.'; flag = true; }
+ if (lc %50000 == 0) { cerr << " [" << lc << "]\n" << flush; flag = false; }
+ string ssrc, strg;
+ ParseTranslatorInput(line, &ssrc, &strg);
+ Lattice src, trg;
+ LatticeTools::ConvertTextToLattice(ssrc, &src);
+ LatticeTools::ConvertTextToLattice(strg, &trg);
+ assert(src.size() > 0);
+ assert(trg.size() > 0);
+ denom += 1.0;
+ vector<prob_t> probs(src.size() + 1);
+ for (int j = 0; j < trg.size(); ++j) {
+ const WordID& f_j = trg[j][0].label;
+ prob_t sum = prob_t::Zero();
+ if (use_null) {
+ probs[0] = tt.prob(kNULL, f_j);
+ sum += probs[0];
+ }
+ for (int i = 1; i <= src.size(); ++i) {
+ probs[i] = tt.prob(src[i-1][0].label, f_j);
+ sum += probs[i];
+ }
+ if (final_iteration) {
+ WordID max_i = 0;
+ prob_t max_p = prob_t::Zero();
+ if (use_null) {
+ max_i = kNULL;
+ max_p = probs[0];
+ }
+ for (int i = 1; i <= src.size(); ++i) {
+ if (probs[i] > max_p) {
+ max_p = probs[i];
+ max_i = src[i-1][0].label;
+ }
+ }
+ was_viterbi[max_i][f_j] = 1.0;
+ } else {
+ if (use_null)
+ tt.Increment(kNULL, f_j, probs[0] / sum);
+ for (int i = 1; i <= src.size(); ++i)
+ tt.Increment(src[i-1][0].label, f_j, probs[i] / sum);
+ }
+ likelihood *= sum;
+ }
+ }
+ if (flag) { cerr << endl; }
+ cerr << " log likelihood: " << log(likelihood) << endl;
+ cerr << " cross entopy: " << (-log(likelihood) / denom) << endl;
+ cerr << " perplexity: " << pow(2.0, -log(likelihood) / denom) << endl;
+ if (!final_iteration) tt.Normalize();
+ }
+ for (TTable::Word2Word2Double::iterator ei = tt.ttable.begin(); ei != tt.ttable.end(); ++ei) {
+ const TTable::Word2Double& cpd = ei->second;
+ const TTable::Word2Double& vit = was_viterbi[ei->first];
+ const string& esym = TD::Convert(ei->first);
+ prob_t max_p = prob_t::Zero();
+ for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi)
+ if (fi->second > max_p) max_p = prob_t(fi->second);
+ const prob_t threshold = max_p * BEAM_THRESHOLD;
+ for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) {
+ if (fi->second > threshold || (vit.count(fi->first) > 0)) {
+ cout << esym << ' ' << TD::Convert(fi->first) << ' ' << log(fi->second) << endl;
+ }
+ }
+ }
+ return 0;
+}
+