summaryrefslogtreecommitdiff
path: root/vest
diff options
context:
space:
mode:
authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-08-11 03:17:28 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-08-11 03:17:28 +0000
commit3c5a49564698b86ba186f7295057ef934ec4047b (patch)
tree3e1764f392fc4ac2e48afd15e75630ee6735a330 /vest
parent90b9ca8b746333f4a47231c244b841f134417706 (diff)
forgotten files
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@520 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'vest')
-rw-r--r--vest/ces.cc87
-rw-r--r--vest/ces.h12
2 files changed, 99 insertions, 0 deletions
diff --git a/vest/ces.cc b/vest/ces.cc
new file mode 100644
index 00000000..aa341058
--- /dev/null
+++ b/vest/ces.cc
@@ -0,0 +1,87 @@
+#include "ces.h"
+
+#include <vector>
+#include <sstream>
+#include <boost/shared_ptr.hpp>
+
+#include "aligner.h"
+#include "lattice.h"
+#include "viterbi_envelope.h"
+#include "error_surface.h"
+
+using boost::shared_ptr;
+using namespace std;
+
+const bool minimize_segments = true; // if adjacent segments have equal scores, merge them
+
+void ComputeErrorSurface(const SentenceScorer& ss, const ViterbiEnvelope& ve, ErrorSurface* env, const ScoreType type, const Hypergraph& hg) {
+ vector<WordID> prev_trans;
+ const vector<shared_ptr<Segment> >& ienv = ve.GetSortedSegs();
+ env->resize(ienv.size());
+ ScoreP prev_score;
+ int j = 0;
+ for (int i = 0; i < ienv.size(); ++i) {
+ const Segment& seg = *ienv[i];
+ vector<WordID> trans;
+ if (type == AER) {
+ vector<bool> edges(hg.edges_.size(), false);
+ seg.CollectEdgesUsed(&edges); // get the set of edges in the viterbi
+ // alignment
+ ostringstream os;
+ const string* psrc = ss.GetSource();
+ if (psrc == NULL) {
+ cerr << "AER scoring in VEST requires source, but it is missing!\n";
+ abort();
+ }
+ size_t pos = psrc->rfind(" ||| ");
+ if (pos == string::npos) {
+ cerr << "Malformed source for AER: expected |||\nINPUT: " << *psrc << endl;
+ abort();
+ }
+ Lattice src;
+ Lattice ref;
+ LatticeTools::ConvertTextOrPLF(psrc->substr(0, pos), &src);
+ LatticeTools::ConvertTextOrPLF(psrc->substr(pos + 5), &ref);
+ AlignerTools::WriteAlignment(src, ref, hg, &os, true, &edges);
+ string tstr = os.str();
+ TD::ConvertSentence(tstr.substr(tstr.rfind(" ||| ") + 5), &trans);
+ } else {
+ seg.ConstructTranslation(&trans);
+ }
+ // cerr << "Scoring: " << TD::GetString(trans) << endl;
+ if (trans == prev_trans) {
+ if (!minimize_segments) {
+ assert(prev_score); // if this fails, it means
+ // the decoder can generate null translations
+ ErrorSegment& out = (*env)[j];
+ out.delta = prev_score->GetZero();
+ out.x = seg.x;
+ ++j;
+ }
+ // cerr << "Identical translation, skipping scoring\n";
+ } else {
+ ScoreP score = ss.ScoreCandidate(trans);
+ // cerr << "score= " << score->ComputeScore() << "\n";
+ ScoreP cur_delta_p = score->GetZero();
+ Score* cur_delta = cur_delta_p.get();
+ // just record the score diffs
+ if (!prev_score)
+ prev_score = score->GetZero();
+
+ score->Subtract(*prev_score, cur_delta);
+ prev_trans.swap(trans);
+ prev_score = score;
+ if ((!minimize_segments) || (!cur_delta->IsAdditiveIdentity())) {
+ ErrorSegment& out = (*env)[j];
+ out.delta = cur_delta_p;
+ out.x = seg.x;
+ ++j;
+ }
+ }
+ }
+ // cerr << " In segments: " << ienv.size() << endl;
+ // cerr << "Out segments: " << j << endl;
+ assert(j > 0);
+ env->resize(j);
+}
+
diff --git a/vest/ces.h b/vest/ces.h
new file mode 100644
index 00000000..2f098990
--- /dev/null
+++ b/vest/ces.h
@@ -0,0 +1,12 @@
+#ifndef _CES_H_
+#define _CES_H_
+
+#include "scorer.h"
+
+class ViterbiEnvelope;
+class Hypergraph;
+class ErrorSurface;
+
+void ComputeErrorSurface(const SentenceScorer& ss, const ViterbiEnvelope& ve, ErrorSurface* es, const ScoreType type, const Hypergraph& hg);
+
+#endif