diff options
author | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-09-29 21:06:33 +0000 |
---|---|---|
committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-09-29 21:06:33 +0000 |
commit | adfa44a02ea08cde8b1490258aefbd766617f447 (patch) | |
tree | 380c061340bb468b8909d94e79f5e70e904714fb /training/model1.cc | |
parent | f412aaab3d10fb82b20a2190f2cb1424959c599a (diff) |
move model 1 code
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@665 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'training/model1.cc')
-rw-r--r-- | training/model1.cc | 25 |
1 files changed, 13 insertions, 12 deletions
diff --git a/training/model1.cc b/training/model1.cc index f571700f..92a70985 100644 --- a/training/model1.cc +++ b/training/model1.cc @@ -1,4 +1,5 @@ #include <iostream> +#include <cmath> #include "lattice.h" #include "stringlib.h" @@ -14,7 +15,7 @@ int main(int argc, char** argv) { return 1; } const int ITERATIONS = 5; - const prob_t BEAM_THRESHOLD(0.0001); + const double BEAM_THRESHOLD = 0.0001; TTable tt; const WordID kNULL = TD::Convert("<eps>"); bool use_null = true; @@ -24,7 +25,7 @@ int main(int argc, char** argv) { cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl; ReadFile rf(argv[1]); istream& in = *rf.stream(); - prob_t likelihood = prob_t::One(); + double likelihood = 0; double denom = 0.0; int lc = 0; bool flag = false; @@ -43,10 +44,10 @@ int main(int argc, char** argv) { assert(src.size() > 0); assert(trg.size() > 0); denom += 1.0; - vector<prob_t> probs(src.size() + 1); + vector<double> probs(src.size() + 1); for (int j = 0; j < trg.size(); ++j) { const WordID& f_j = trg[j][0].label; - prob_t sum = prob_t::Zero(); + double sum = 0; if (use_null) { probs[0] = tt.prob(kNULL, f_j); sum += probs[0]; @@ -57,7 +58,7 @@ int main(int argc, char** argv) { } if (final_iteration) { WordID max_i = 0; - prob_t max_p = prob_t::Zero(); + double max_p = -1; if (use_null) { max_i = kNULL; max_p = probs[0]; @@ -75,23 +76,23 @@ int main(int argc, char** argv) { for (int i = 1; i <= src.size(); ++i) tt.Increment(src[i-1][0].label, f_j, probs[i] / sum); } - likelihood *= sum; + likelihood += log(sum); } } if (flag) { cerr << endl; } - cerr << " log likelihood: " << log(likelihood) << endl; - cerr << " cross entopy: " << (-log(likelihood) / denom) << endl; - cerr << " perplexity: " << pow(2.0, -log(likelihood) / denom) << endl; + cerr << " log likelihood: " << likelihood << endl; + cerr << " cross entopy: " << (-likelihood / denom) << endl; + cerr << " perplexity: " << pow(2.0, -likelihood / denom) << endl; if (!final_iteration) tt.Normalize(); } for (TTable::Word2Word2Double::iterator ei = tt.ttable.begin(); ei != tt.ttable.end(); ++ei) { const TTable::Word2Double& cpd = ei->second; const TTable::Word2Double& vit = was_viterbi[ei->first]; const string& esym = TD::Convert(ei->first); - prob_t max_p = prob_t::Zero(); + double max_p = -1; for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) - if (fi->second > max_p) max_p = prob_t(fi->second); - const prob_t threshold = max_p * BEAM_THRESHOLD; + if (fi->second > max_p) max_p = fi->second; + const double threshold = max_p * BEAM_THRESHOLD; for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) { if (fi->second > threshold || (vit.count(fi->first) > 0)) { cout << esym << ' ' << TD::Convert(fi->first) << ' ' << log(fi->second) << endl; |