diff options
author | Chris Dyer <prguest11@taipan.cs> | 2012-06-19 23:07:51 +0100 |
---|---|---|
committer | Chris Dyer <prguest11@taipan.cs> | 2012-06-19 23:07:51 +0100 |
commit | dc67307d5fc703941a129da0ce7b23fe3712127b (patch) | |
tree | 0790a84aeb194f5b88a1cfa8c8401607f553880b /training/mpi_compute_cllh.cc | |
parent | 5975dcaa50adb5ce7a05b83583b8f9ddc45f3f0a (diff) |
compute held-out ppl in mpi_batch_optimize
Diffstat (limited to 'training/mpi_compute_cllh.cc')
-rw-r--r-- | training/mpi_compute_cllh.cc | 59 |
1 files changed, 1 insertions, 58 deletions
diff --git a/training/mpi_compute_cllh.cc b/training/mpi_compute_cllh.cc
index d5caa745..066389d0 100644
--- a/training/mpi_compute_cllh.cc
+++ b/training/mpi_compute_cllh.cc
@@ -10,6 +10,7 @@
 #include <boost/program_options.hpp>
 #include <boost/program_options/variables_map.hpp>
 
+#include "cllh_observer.h"
 #include "sentence_metadata.h"
 #include "verbose.h"
 #include "hg.h"
@@ -67,64 +68,6 @@ void ReadInstances(const string& fname, int rank, int size, vector<string>* c) {
 
 static const double kMINUS_EPSILON = -1e-6;
 
-struct ConditionalLikelihoodObserver : public DecoderObserver {
-
-  ConditionalLikelihoodObserver() : trg_words(), acc_obj(), cur_obj() {}
-
-  virtual void NotifyDecodingStart(const SentenceMetadata&) {
-    cur_obj = 0;
-    state = 1;
-  }
-
-  // compute model expectations, denominator of objective
-  virtual void NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) {
-    assert(state == 1);
-    state = 2;
-    SparseVector<prob_t> cur_model_exp;
-    const prob_t z = InsideOutside<prob_t,
-                                   EdgeProb,
-                                   SparseVector<prob_t>,
-                                   EdgeFeaturesAndProbWeightFunction>(*hg, &cur_model_exp);
-    cur_obj = log(z);
-  }
-
-  // compute "empirical" expectations, numerator of objective
-  virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) {
-    assert(state == 2);
-    state = 3;
-    SparseVector<prob_t> ref_exp;
-    const prob_t ref_z = InsideOutside<prob_t,
-                                       EdgeProb,
-                                       SparseVector<prob_t>,
-                                       EdgeFeaturesAndProbWeightFunction>(*hg, &ref_exp);
-
-    double log_ref_z;
-#if 0
-    if (crf_uniform_empirical) {
-      log_ref_z = ref_exp.dot(feature_weights);
-    } else {
-      log_ref_z = log(ref_z);
-    }
-#else
-    log_ref_z = log(ref_z);
-#endif
-
-    // rounding errors means that <0 is too strict
-    if ((cur_obj - log_ref_z) < kMINUS_EPSILON) {
-      cerr << "DIFF. ERR! log_model_z < log_ref_z: " << cur_obj << " " << log_ref_z << endl;
-      exit(1);
-    }
-    assert(!isnan(log_ref_z));
-    acc_obj += (cur_obj - log_ref_z);
-    trg_words += smeta.GetReference().size();
-  }
-
-  unsigned trg_words;
-  double acc_obj;
-  double cur_obj;
-  int state;
-};
-
 #ifdef HAVE_MPI
 namespace mpi = boost::mpi;
 #endif