summaryrefslogtreecommitdiff
path: root/training/model1.cc
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-02-18 15:16:17 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2012-02-18 15:16:17 -0500
commitd3ccf26cf501cb15ed300bc0ad17596a4e59fbeb (patch)
treeb01f0f2a89bcf54c36c2b3b19badf31f569a5e3b /training/model1.cc
parenta38b3fa383412e56eb958db998662c026bc08f4b (diff)
fix diagonal model
Diffstat (limited to 'training/model1.cc')
-rw-r--r--training/model1.cc29
1 files changed, 17 insertions, 12 deletions
diff --git a/training/model1.cc b/training/model1.cc
index a87d388f..73104304 100644
--- a/training/model1.cc
+++ b/training/model1.cc
@@ -4,6 +4,7 @@
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
+#include "m.h"
#include "lattice.h"
#include "stringlib.h"
#include "filelib.h"
@@ -13,11 +14,6 @@
namespace po = boost::program_options;
using namespace std;
-inline double log_poisson(unsigned x, const double& lambda) {
- assert(lambda > 0.0);
- return log(lambda) * x - lgamma(x + 1) - lambda;
-}
-
bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
po::options_description opts("Configuration options");
opts.add_options()
@@ -82,6 +78,7 @@ int main(int argc, char** argv) {
TTable::Word2Word2Double was_viterbi;
double tot_len_ratio = 0;
double mean_srclen_multiplier = 0;
+ vector<double> unnormed_a_i;
for (int iter = 0; iter < ITERATIONS; ++iter) {
const bool final_iteration = (iter == (ITERATIONS - 1));
cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl;
@@ -108,6 +105,8 @@ int main(int argc, char** argv) {
assert(src.size() > 0);
assert(trg.size() > 0);
}
+ if (src.size() > unnormed_a_i.size())
+ unnormed_a_i.resize(src.size());
if (iter == 0)
tot_len_ratio += static_cast<double>(trg.size()) / static_cast<double>(src.size());
denom += trg.size();
@@ -125,13 +124,15 @@ int main(int argc, char** argv) {
}
double az = 0;
if (favor_diagonal) {
- for (int ta = 0; ta < src.size(); ++ta)
- az += exp(-fabs(double(ta) / src.size() - j_over_ts) * diagonal_tension);
+ for (int ta = 0; ta < src.size(); ++ta) {
+ unnormed_a_i[ta] = exp(-fabs(double(ta) / src.size() - j_over_ts) * diagonal_tension);
+ az += unnormed_a_i[ta];
+ }
az /= prob_align_not_null;
}
for (int i = 1; i <= src.size(); ++i) {
if (favor_diagonal)
- prob_a_i = exp(-fabs(double(i) / src.size() - j_over_ts) * diagonal_tension) / az;
+ prob_a_i = unnormed_a_i[i-1] / az;
probs[i] = tt.prob(src[i-1][0].label, f_j) * prob_a_i;
sum += probs[i];
}
@@ -202,7 +203,9 @@ int main(int argc, char** argv) {
Lattice src, trg;
LatticeTools::ConvertTextToLattice(ssrc, &src);
LatticeTools::ConvertTextToLattice(strg, &trg);
- double log_prob = log_poisson(trg.size(), 0.05 + src.size() * mean_srclen_multiplier);
+ double log_prob = Md::log_poisson(trg.size(), 0.05 + src.size() * mean_srclen_multiplier);
+ if (src.size() > unnormed_a_i.size())
+ unnormed_a_i.resize(src.size());
// compute likelihood
for (int j = 0; j < trg.size(); ++j) {
@@ -216,13 +219,15 @@ int main(int argc, char** argv) {
}
double az = 0;
if (favor_diagonal) {
- for (int ta = 0; ta < src.size(); ++ta)
- az += exp(-fabs(double(ta) / src.size() - j_over_ts) * diagonal_tension);
+ for (int ta = 0; ta < src.size(); ++ta) {
+ unnormed_a_i[ta] = exp(-fabs(double(ta) / src.size() - j_over_ts) * diagonal_tension);
+ az += unnormed_a_i[ta];
+ }
az /= prob_align_not_null;
}
for (int i = 1; i <= src.size(); ++i) {
if (favor_diagonal)
- prob_a_i = exp(-fabs(double(i) / src.size() - j_over_ts) * diagonal_tension) / az;
+ prob_a_i = unnormed_a_i[i-1] / az;
sum += tt.prob(src[i-1][0].label, f_j) * prob_a_i;
}
log_prob += log(sum);