diff options
author | Michael Denkowski <michael.j.denkowski@gmail.com> | 2012-12-22 16:01:23 -0500 |
---|---|---|
committer | Michael Denkowski <michael.j.denkowski@gmail.com> | 2012-12-22 16:01:23 -0500 |
commit | 597d89c11db53e91bc011eab70fd613bbe6453e8 (patch) | |
tree | 83c87c07d1ff6d3ee4e3b1626f7eddd49c61095b /training/crf/cllh_observer.cc | |
parent | 65e958ff2678a41c22be7171456a63f002ef370b (diff) | |
parent | 201af2acd394415a05072fbd53d42584875aa4b4 (diff) |
Merge branch 'master' of git://github.com/redpony/cdec
Diffstat (limited to 'training/crf/cllh_observer.cc')
-rw-r--r-- | training/crf/cllh_observer.cc | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/training/crf/cllh_observer.cc b/training/crf/cllh_observer.cc new file mode 100644 index 00000000..4ec2fa65 --- /dev/null +++ b/training/crf/cllh_observer.cc @@ -0,0 +1,52 @@ +#include "cllh_observer.h" + +#include <cmath> +#include <cassert> + +#include "inside_outside.h" +#include "hg.h" +#include "sentence_metadata.h" + +using namespace std; + +static const double kMINUS_EPSILON = -1e-6; + +ConditionalLikelihoodObserver::~ConditionalLikelihoodObserver() {} + +void ConditionalLikelihoodObserver::NotifyDecodingStart(const SentenceMetadata&) { + cur_obj = 0; + state = 1; +} + +void ConditionalLikelihoodObserver::NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) { + assert(state == 1); + state = 2; + SparseVector<prob_t> cur_model_exp; + const prob_t z = InsideOutside<prob_t, + EdgeProb, + SparseVector<prob_t>, + EdgeFeaturesAndProbWeightFunction>(*hg, &cur_model_exp); + cur_obj = log(z); +} + +void ConditionalLikelihoodObserver::NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { + assert(state == 2); + state = 3; + SparseVector<prob_t> ref_exp; + const prob_t ref_z = InsideOutside<prob_t, + EdgeProb, + SparseVector<prob_t>, + EdgeFeaturesAndProbWeightFunction>(*hg, &ref_exp); + + double log_ref_z = log(ref_z); + + // rounding errors means that <0 is too strict + if ((cur_obj - log_ref_z) < kMINUS_EPSILON) { + cerr << "DIFF. ERR! log_model_z < log_ref_z: " << cur_obj << " " << log_ref_z << endl; + exit(1); + } + assert(!std::isnan(log_ref_z)); + acc_obj += (cur_obj - log_ref_z); + trg_words += smeta.GetReference().size(); +} + |