summary | refs | log | tree | commit | diff
path: root/training/cllh_observer.cc
diff options
context:
space:
mode:
author: Chris Dyer <prguest11@taipan.cs> 2012-06-19 23:07:51 +0100
committer: Chris Dyer <prguest11@taipan.cs> 2012-06-19 23:07:51 +0100
commit: 9d4cfa88a71c0cba9a7d3e21cb2b58f78b097b48 (patch)
tree: 315a93974903c8a90dad4737367157959899dedf /training/cllh_observer.cc
parent: fcd8e74ca9c16fe0e3001906ae2bd0ac0686f813 (diff)
compute held-out ppl in mpi_batch_optimize
Diffstat (limited to 'training/cllh_observer.cc')
-rw-r--r-- training/cllh_observer.cc 52
1 files changed, 52 insertions, 0 deletions
diff --git a/training/cllh_observer.cc b/training/cllh_observer.cc
new file mode 100644
index 00000000..58232769
--- /dev/null
+++ b/training/cllh_observer.cc
@@ -0,0 +1,52 @@
+#include "cllh_observer.h"
+
+#include <cmath>
+#include <cassert>
+
+#include "inside_outside.h"
+#include "hg.h"
+#include "sentence_metadata.h"
+
+using namespace std;
+
+// Tolerance used by NotifyAlignmentForest when comparing the two log
+// partition values: a difference (model minus reference) is only treated as
+// an error once it drops below this slightly negative threshold.
+static const double kMINUS_EPSILON = -1e-6;
+
+// Out-of-line destructor; the body is intentionally empty.
+ConditionalLikelihoodObserver::~ConditionalLikelihoodObserver() {
+}
+
+// Called when decoding of a sentence begins. Puts the observer into state 1
+// (the state NotifyTranslationForest asserts on entry) and clears the
+// per-sentence objective value.
+void ConditionalLikelihoodObserver::NotifyDecodingStart(const SentenceMetadata&) {
+  state = 1;
+  cur_obj = 0;
+}
+
+// Called with the translation forest for the current sentence. Expects to be
+// in state 1 and moves to state 2, then runs InsideOutside over *hg and
+// stores the log of the value it returns in cur_obj.
+void ConditionalLikelihoodObserver::NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) {
+  assert(state == 1);
+  state = 2;
+
+  // Handed to InsideOutside by pointer; not read again in this function.
+  SparseVector<prob_t> model_expectations;
+  const prob_t model_z =
+      InsideOutside<prob_t,
+                    EdgeProb,
+                    SparseVector<prob_t>,
+                    EdgeFeaturesAndProbWeightFunction>(*hg, &model_expectations);
+
+  cur_obj = log(model_z);
+}
+
+// Called with the alignment forest for the current sentence. Expects to be in
+// state 2 and moves to state 3. Runs InsideOutside over *hg, takes the log of
+// the value it returns (log_ref_z), and adds (cur_obj - log_ref_z) to acc_obj
+// and smeta.GetReference().size() to trg_words.
+//
+// Terminates the process with exit(1) if the difference is NaN or falls below
+// kMINUS_EPSILON.
+void ConditionalLikelihoodObserver::NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) {
+  assert(state == 2);
+  state = 3;
+  SparseVector<prob_t> ref_exp;
+  const prob_t ref_z = InsideOutside<prob_t,
+                                     EdgeProb,
+                                     SparseVector<prob_t>,
+                                     EdgeFeaturesAndProbWeightFunction>(*hg, &ref_exp);
+
+  const double log_ref_z = log(ref_z);
+  const double diff = cur_obj - log_ref_z;
+
+  // Any ordered comparison with NaN is false, so a NaN difference would pass
+  // the tolerance check below and be added to acc_obj. Test for it explicitly
+  // (and unconditionally, rather than via assert, which NDEBUG compiles out).
+  if (isnan(diff)) {
+    cerr << "DIFF. ERR! NaN in log_model_z - log_ref_z: " << cur_obj << " " << log_ref_z << endl;
+    exit(1);
+  }
+
+  // Rounding error can make the difference slightly negative, so compare
+  // against a small negative tolerance instead of 0.
+  if (diff < kMINUS_EPSILON) {
+    cerr << "DIFF. ERR! log_model_z < log_ref_z: " << cur_obj << " " << log_ref_z << endl;
+    exit(1);
+  }
+
+  acc_obj += diff;
+  trg_words += smeta.GetReference().size();
+}
+