summary | refs | log | tree | commit | diff
path: root/training/crf/cllh_observer.cc
diff options
context:
space:
mode:
author Michael Denkowski <michael.j.denkowski@gmail.com> 2012-12-22 16:01:23 -0500
committer Michael Denkowski <michael.j.denkowski@gmail.com> 2012-12-22 16:01:23 -0500
commit 597d89c11db53e91bc011eab70fd613bbe6453e8 (patch)
tree 83c87c07d1ff6d3ee4e3b1626f7eddd49c61095b /training/crf/cllh_observer.cc
parent 65e958ff2678a41c22be7171456a63f002ef370b (diff)
parent 201af2acd394415a05072fbd53d42584875aa4b4 (diff)
Merge branch 'master' of git://github.com/redpony/cdec
Diffstat (limited to 'training/crf/cllh_observer.cc')
-rw-r--r-- training/crf/cllh_observer.cc 52
1 files changed, 52 insertions, 0 deletions
diff --git a/training/crf/cllh_observer.cc b/training/crf/cllh_observer.cc
new file mode 100644
index 00000000..4ec2fa65
--- /dev/null
+++ b/training/crf/cllh_observer.cc
@@ -0,0 +1,52 @@
+#include "cllh_observer.h"
+
+#include <cmath>
+#include <cassert>
+
+#include "inside_outside.h"
+#include "hg.h"
+#include "sentence_metadata.h"
+
+using namespace std;
+
+// Tolerance for the (cur_obj - log_ref_z) check in NotifyAlignmentForest:
+// differences below this value are treated as a fatal error, while small
+// negative differences above it are accepted because rounding error makes
+// a strict "< 0" test too strict.
+static const double kMINUS_EPSILON = -1e-6;
+
+// Out-of-line destructor; the body is intentionally empty.
+ConditionalLikelihoodObserver::~ConditionalLikelihoodObserver() {
+}
+
+// Called when decoding of a sentence starts: moves the observer to state 1
+// and clears the per-sentence objective. The sentence metadata is unused.
+void ConditionalLikelihoodObserver::NotifyDecodingStart(const SentenceMetadata&) {
+ state = 1;    // the value NotifyTranslationForest asserts on
+ cur_obj = 0;  // overwritten later by NotifyTranslationForest
+}
+
+// Second step of the per-sentence sequence (state 1 -> 2): runs InsideOutside
+// over the translation forest `hg` and stores the log of the returned value
+// in cur_obj. The sentence metadata is unused.
+void ConditionalLikelihoodObserver::NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) {
+ assert(state == 1);  // expects the state set by NotifyDecodingStart
+ state = 2;
+ typedef SparseVector<prob_t> Expectations;
+ // Passed to InsideOutside as its second argument; never read in this function.
+ Expectations model_expectations;
+ const prob_t model_z =
+     InsideOutside<prob_t, EdgeProb, Expectations, EdgeFeaturesAndProbWeightFunction>(
+         *hg, &model_expectations);
+ cur_obj = log(model_z);
+}
+
+// Third step of the per-sentence sequence (state 2 -> 3): runs InsideOutside
+// over the alignment (reference) forest `hg`, compares the log of its result
+// with cur_obj, and accumulates the difference into acc_obj. Also adds the
+// size of the sentence's reference (smeta.GetReference()) to trg_words.
+//
+// Terminates the process with exit(1) when cur_obj falls below the reference
+// log value by more than the kMINUS_EPSILON tolerance.
+void ConditionalLikelihoodObserver::NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) {
+ assert(state == 2);  // expects the state set by NotifyTranslationForest
+ state = 3;
+
+ typedef SparseVector<prob_t> Expectations;
+ // Passed to InsideOutside as its second argument; never read in this function.
+ Expectations ref_expectations;
+ const prob_t reference_z =
+     InsideOutside<prob_t, EdgeProb, Expectations, EdgeFeaturesAndProbWeightFunction>(
+         *hg, &ref_expectations);
+ double log_ref_z = log(reference_z);
+
+ // Rounding errors mean that a strict "< 0" test would be too strict, so a
+ // small negative difference (down to kMINUS_EPSILON) is tolerated.
+ const bool model_below_reference = (cur_obj - log_ref_z) < kMINUS_EPSILON;
+ if (model_below_reference) {
+  cerr << "DIFF. ERR! log_model_z < log_ref_z: " << cur_obj << " " << log_ref_z << endl;
+  exit(1);
+ }
+
+ // A NaN log_ref_z makes the comparison above false, so it is caught here
+ // (only in builds where assert is enabled).
+ assert(!std::isnan(log_ref_z));
+
+ acc_obj += (cur_obj - log_ref_z);
+ trg_words += smeta.GetReference().size();
+}
+