From 3565a2515d81865ecaa464e22fb6b884082c3544 Mon Sep 17 00:00:00 2001 From: redpony Date: Wed, 11 Aug 2010 03:17:28 +0000 Subject: forgotten files git-svn-id: https://ws10smt.googlecode.com/svn/trunk@520 ec762483-ff6d-05da-a07a-a48fb63a330f --- vest/ces.cc | 87 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ vest/ces.h | 12 +++++++++ 2 files changed, 99 insertions(+) create mode 100644 vest/ces.cc create mode 100644 vest/ces.h diff --git a/vest/ces.cc b/vest/ces.cc new file mode 100644 index 00000000..aa341058 --- /dev/null +++ b/vest/ces.cc @@ -0,0 +1,87 @@ +#include "ces.h" + +#include +#include +#include + +#include "aligner.h" +#include "lattice.h" +#include "viterbi_envelope.h" +#include "error_surface.h" + +using boost::shared_ptr; +using namespace std; + +const bool minimize_segments = true; // if adjacent segments have equal scores, merge them + +void ComputeErrorSurface(const SentenceScorer& ss, const ViterbiEnvelope& ve, ErrorSurface* env, const ScoreType type, const Hypergraph& hg) { + vector prev_trans; + const vector >& ienv = ve.GetSortedSegs(); + env->resize(ienv.size()); + ScoreP prev_score; + int j = 0; + for (int i = 0; i < ienv.size(); ++i) { + const Segment& seg = *ienv[i]; + vector trans; + if (type == AER) { + vector edges(hg.edges_.size(), false); + seg.CollectEdgesUsed(&edges); // get the set of edges in the viterbi + // alignment + ostringstream os; + const string* psrc = ss.GetSource(); + if (psrc == NULL) { + cerr << "AER scoring in VEST requires source, but it is missing!\n"; + abort(); + } + size_t pos = psrc->rfind(" ||| "); + if (pos == string::npos) { + cerr << "Malformed source for AER: expected |||\nINPUT: " << *psrc << endl; + abort(); + } + Lattice src; + Lattice ref; + LatticeTools::ConvertTextOrPLF(psrc->substr(0, pos), &src); + LatticeTools::ConvertTextOrPLF(psrc->substr(pos + 5), &ref); + AlignerTools::WriteAlignment(src, ref, hg, &os, true, &edges); + string tstr = os.str(); + TD::ConvertSentence(tstr.substr(tstr.rfind(" ||| ") + 5), &trans); + } else { + seg.ConstructTranslation(&trans); + } + // cerr << "Scoring: " << TD::GetString(trans) << endl; + if (trans == prev_trans) { + if (!minimize_segments) { + assert(prev_score); // if this fails, it means + // the decoder can generate null translations + ErrorSegment& out = (*env)[j]; + out.delta = prev_score->GetZero(); + out.x = seg.x; + ++j; + } + // cerr << "Identical translation, skipping scoring\n"; + } else { + ScoreP score = ss.ScoreCandidate(trans); + // cerr << "score= " << score->ComputeScore() << "\n"; + ScoreP cur_delta_p = score->GetZero(); + Score* cur_delta = cur_delta_p.get(); + // just record the score diffs + if (!prev_score) + prev_score = score->GetZero(); + + score->Subtract(*prev_score, cur_delta); + prev_trans.swap(trans); + prev_score = score; + if ((!minimize_segments) || (!cur_delta->IsAdditiveIdentity())) { + ErrorSegment& out = (*env)[j]; + out.delta = cur_delta_p; + out.x = seg.x; + ++j; + } + } + } + // cerr << " In segments: " << ienv.size() << endl; + // cerr << "Out segments: " << j << endl; + assert(j > 0); + env->resize(j); +} + diff --git a/vest/ces.h b/vest/ces.h new file mode 100644 index 00000000..2f098990 --- /dev/null +++ b/vest/ces.h @@ -0,0 +1,12 @@ +#ifndef _CES_H_ +#define _CES_H_ + +#include "scorer.h" + +class ViterbiEnvelope; +class Hypergraph; +class ErrorSurface; + +void ComputeErrorSurface(const SentenceScorer& ss, const ViterbiEnvelope& ve, ErrorSurface* es, const ScoreType type, const Hypergraph& hg); + +#endif -- cgit v1.2.3