diff options
| author | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-09-29 21:06:33 +0000 | 
|---|---|---|
| committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-09-29 21:06:33 +0000 | 
| commit | adfa44a02ea08cde8b1490258aefbd766617f447 (patch) | |
| tree | 380c061340bb468b8909d94e79f5e70e904714fb | |
| parent | f412aaab3d10fb82b20a2190f2cb1424959c599a (diff) | |
move model 1 code
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@665 ec762483-ff6d-05da-a07a-a48fb63a330f
| -rw-r--r-- | decoder/Makefile.am | 1 | ||||
| -rw-r--r-- | training/Makefile.am | 2 | ||||
| -rw-r--r-- | training/model1.cc | 25 | ||||
| -rw-r--r-- | training/ttables.cc (renamed from decoder/ttables.cc) | 2 | ||||
| -rw-r--r-- | training/ttables.h (renamed from decoder/ttables.h) | 15 | 
5 files changed, 22 insertions, 23 deletions
| diff --git a/decoder/Makefile.am b/decoder/Makefile.am index a6dbfca4..bf368c6d 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -55,7 +55,6 @@ libcdec_a_SOURCES = \    earley_composer.cc \    phrasetable_fst.cc \    trule.cc \ -  ttables.cc \    ff.cc \    ff_lm.cc \    ff_ruleshape.cc \ diff --git a/training/Makefile.am b/training/Makefile.am index 2679adea..7cdf10d7 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -34,7 +34,7 @@ online_train_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutil  atools_SOURCES = atools.cc  atools_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz -model1_SOURCES = model1.cc +model1_SOURCES = model1.cc ttables.cc  model1_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz  grammar_convert_SOURCES = grammar_convert.cc diff --git a/training/model1.cc b/training/model1.cc index f571700f..92a70985 100644 --- a/training/model1.cc +++ b/training/model1.cc @@ -1,4 +1,5 @@  #include <iostream> +#include <cmath>  #include "lattice.h"  #include "stringlib.h" @@ -14,7 +15,7 @@ int main(int argc, char** argv) {      return 1;    }    const int ITERATIONS = 5; -  const prob_t BEAM_THRESHOLD(0.0001); +  const double BEAM_THRESHOLD = 0.0001;    TTable tt;    const WordID kNULL = TD::Convert("<eps>");    bool use_null = true; @@ -24,7 +25,7 @@ int main(int argc, char** argv) {      cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl;      ReadFile rf(argv[1]);      istream& in = *rf.stream(); -    prob_t likelihood = prob_t::One(); +    double likelihood = 0;      double denom = 0.0;      int lc = 0;      bool flag = false; @@ -43,10 +44,10 @@ int main(int argc, char** argv) {        assert(src.size() > 0);        assert(trg.size() > 0);        denom += 1.0; -      vector<prob_t> probs(src.size() + 1); +      vector<double> probs(src.size() + 1);        for (int j = 0; j < trg.size(); ++j) {          const WordID& f_j = trg[j][0].label; -        prob_t sum = prob_t::Zero(); +        double sum = 0;          if (use_null) {            probs[0] = tt.prob(kNULL, f_j);            sum += probs[0]; @@ -57,7 +58,7 @@ int main(int argc, char** argv) {          }          if (final_iteration) {            WordID max_i = 0; -          prob_t max_p = prob_t::Zero(); +          double max_p = -1;            if (use_null) {              max_i = kNULL;              max_p = probs[0]; @@ -75,23 +76,23 @@ int main(int argc, char** argv) {            for (int i = 1; i <= src.size(); ++i)              tt.Increment(src[i-1][0].label, f_j, probs[i] / sum);          } -        likelihood *= sum; +        likelihood += log(sum);        }      }      if (flag) { cerr << endl; } -    cerr << "  log likelihood: " << log(likelihood) << endl; -    cerr << "    cross entopy: " << (-log(likelihood) / denom) << endl; -    cerr << "      perplexity: " << pow(2.0, -log(likelihood) / denom) << endl; +    cerr << "  log likelihood: " << likelihood << endl; +    cerr << "    cross entopy: " << (-likelihood / denom) << endl; +    cerr << "      perplexity: " << pow(2.0, -likelihood / denom) << endl;      if (!final_iteration) tt.Normalize();    }    for (TTable::Word2Word2Double::iterator ei = tt.ttable.begin(); ei != tt.ttable.end(); ++ei) {      const TTable::Word2Double& cpd = ei->second;      const TTable::Word2Double& vit = was_viterbi[ei->first];      const string& esym = TD::Convert(ei->first); -    prob_t max_p = prob_t::Zero(); +    double max_p = -1;      for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) -      if (fi->second > max_p) max_p = prob_t(fi->second); -    const prob_t threshold = max_p * BEAM_THRESHOLD; +      if (fi->second > max_p) max_p = fi->second; +    const double threshold = max_p * BEAM_THRESHOLD;      for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) {        if (fi->second > threshold || (vit.count(fi->first) > 0)) {          cout << esym << ' ' << TD::Convert(fi->first) << ' ' << log(fi->second) << endl; diff --git a/decoder/ttables.cc b/training/ttables.cc index 2ea960f0..45bf14c5 100644 --- a/decoder/ttables.cc +++ b/training/ttables.cc @@ -16,7 +16,7 @@ void TTable::DeserializeProbsFromText(std::istream* in) {      (*in) >> e >> f >> p;      if (e.empty()) break;      ++c; -    ttable[TD::Convert(e)][TD::Convert(f)] = prob_t(p); +    ttable[TD::Convert(e)][TD::Convert(f)] = p;    }    cerr << "Loaded " << c << " translation parameters.\n";  } diff --git a/decoder/ttables.h b/training/ttables.h index 3ffc238a..04e54f9d 100644 --- a/decoder/ttables.h +++ b/training/ttables.h @@ -2,26 +2,25 @@  #define _TTABLES_H_  #include <iostream> -#include <map> +#include <tr1/unordered_map>  #include "wordid.h" -#include "prob.h"  #include "tdict.h"  class TTable {   public:    TTable() {} -  typedef std::map<WordID, double> Word2Double; -  typedef std::map<WordID, Word2Double> Word2Word2Double; -  inline const prob_t prob(const int& e, const int& f) const { +  typedef std::tr1::unordered_map<WordID, double> Word2Double; +  typedef std::tr1::unordered_map<WordID, Word2Double> Word2Word2Double; +  inline const double prob(const int& e, const int& f) const {      const Word2Word2Double::const_iterator cit = ttable.find(e);      if (cit != ttable.end()) {        const Word2Double& cpd = cit->second;        const Word2Double::const_iterator it = cpd.find(f); -      if (it == cpd.end()) return prob_t(0.00001); -      return prob_t(it->second); +      if (it == cpd.end()) return 1e-9; +      return it->second;      } else { -      return prob_t(0.00001); +      return 1e-9;      }    }    inline void Increment(const int& e, const int& f) { | 
