summaryrefslogtreecommitdiff
path: root/training/mpi_compute_cllh.cc
diff options
context:
space:
mode:
Diffstat (limited to 'training/mpi_compute_cllh.cc')
-rw-r--r--training/mpi_compute_cllh.cc59
1 files changed, 1 insertions, 58 deletions
diff --git a/training/mpi_compute_cllh.cc b/training/mpi_compute_cllh.cc
index d5caa745..066389d0 100644
--- a/training/mpi_compute_cllh.cc
+++ b/training/mpi_compute_cllh.cc
@@ -10,6 +10,7 @@
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
+#include "cllh_observer.h"
#include "sentence_metadata.h"
#include "verbose.h"
#include "hg.h"
@@ -67,64 +68,6 @@ void ReadInstances(const string& fname, int rank, int size, vector<string>* c) {
static const double kMINUS_EPSILON = -1e-6;
-struct ConditionalLikelihoodObserver : public DecoderObserver {
-
- ConditionalLikelihoodObserver() : trg_words(), acc_obj(), cur_obj() {}
-
- virtual void NotifyDecodingStart(const SentenceMetadata&) {
- cur_obj = 0;
- state = 1;
- }
-
- // compute model expectations, denominator of objective
- virtual void NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) {
- assert(state == 1);
- state = 2;
- SparseVector<prob_t> cur_model_exp;
- const prob_t z = InsideOutside<prob_t,
- EdgeProb,
- SparseVector<prob_t>,
- EdgeFeaturesAndProbWeightFunction>(*hg, &cur_model_exp);
- cur_obj = log(z);
- }
-
- // compute "empirical" expectations, numerator of objective
- virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) {
- assert(state == 2);
- state = 3;
- SparseVector<prob_t> ref_exp;
- const prob_t ref_z = InsideOutside<prob_t,
- EdgeProb,
- SparseVector<prob_t>,
- EdgeFeaturesAndProbWeightFunction>(*hg, &ref_exp);
-
- double log_ref_z;
-#if 0
- if (crf_uniform_empirical) {
- log_ref_z = ref_exp.dot(feature_weights);
- } else {
- log_ref_z = log(ref_z);
- }
-#else
- log_ref_z = log(ref_z);
-#endif
-
- // rounding errors means that <0 is too strict
- if ((cur_obj - log_ref_z) < kMINUS_EPSILON) {
- cerr << "DIFF. ERR! log_model_z < log_ref_z: " << cur_obj << " " << log_ref_z << endl;
- exit(1);
- }
- assert(!isnan(log_ref_z));
- acc_obj += (cur_obj - log_ref_z);
- trg_words += smeta.GetReference().size();
- }
-
- unsigned trg_words;
- double acc_obj;
- double cur_obj;
- int state;
-};
-
#ifdef HAVE_MPI
namespace mpi = boost::mpi;
#endif