summaryrefslogtreecommitdiff
path: root/vest
diff options
context:
space:
mode:
authorgraehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-20 00:11:45 +0000
committergraehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-20 00:11:45 +0000
commiteb0a8931eb835d8a23c939ef61867a08c19ebf6b (patch)
tree7112b6af490a04c8bfdc40eb88718702060823c7 /vest
parentbcad98e114d468edf16369b3a30d98556a3d0e61 (diff)
Score::TimesEquals for vlad-mira, intrusive refcount for Score, shared_ptr compile fixes for decoder progs
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@331 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'vest')
-rw-r--r--vest/fast_score.cc6
-rw-r--r--vest/mbr_kbest.cc6
-rw-r--r--vest/scorer.cc12
-rw-r--r--vest/scorer.h8
4 files changed, 22 insertions, 10 deletions
diff --git a/vest/fast_score.cc b/vest/fast_score.cc
index 0d611d56..5ee264a6 100644
--- a/vest/fast_score.cc
+++ b/vest/fast_score.cc
@@ -41,7 +41,7 @@ int main(int argc, char** argv) {
cerr << "Loaded " << ds.size() << " references for scoring with " << loss_function << endl;
ReadFile rf(conf["in_file"].as<string>());
- Score* acc = NULL;
+ ScoreP acc;
istream& in = *rf.stream();
int lc = 0;
while(in) {
@@ -50,10 +50,9 @@ int main(int argc, char** argv) {
if (line.empty() && !in) break;
vector<WordID> sent;
TD::ConvertSentence(line, &sent);
- Score* sentscore = ds[lc]->ScoreCandidate(sent);
+ ScoreP sentscore = ds[lc]->ScoreCandidate(sent);
if (!acc) { acc = sentscore->GetZero(); }
acc->PlusEquals(*sentscore);
- delete sentscore;
++lc;
}
assert(lc > 0);
@@ -67,7 +66,6 @@ int main(int argc, char** argv) {
float score = acc->ComputeScore();
string details;
acc->ScoreDetails(&details);
- delete acc;
cerr << details << endl;
cout << score << endl;
return 0;
diff --git a/vest/mbr_kbest.cc b/vest/mbr_kbest.cc
index 5d70b4e2..2867b36b 100644
--- a/vest/mbr_kbest.cc
+++ b/vest/mbr_kbest.cc
@@ -101,14 +101,13 @@ int main(int argc, char** argv) {
for (int i = 0 ; i < list.size(); ++i) {
vector<vector<WordID> > refs(1, list[i].first);
//cerr << i << ": " << list[i].second <<"\t" << TD::GetString(list[i].first) << endl;
- SentenceScorer* scorer = SentenceScorer::CreateSentenceScorer(type, refs);
+ ScorerP scorer = SentenceScorer::CreateSentenceScorer(type, refs);
double wl_acc = 0;
for (int j = 0; j < list.size(); ++j) {
if (i != j) {
- Score* s = scorer->ScoreCandidate(list[j].first);
+ ScoreP s = scorer->ScoreCandidate(list[j].first);
double loss = 1.0 - s->ComputeScore();
if (type == TER || type == AER) loss = 1.0 - loss;
- delete s;
double weighted_loss = loss * (joints[j] / marginal);
wl_acc += weighted_loss;
if ((!output_list) && wl_acc > mbr_loss) break;
@@ -119,7 +118,6 @@ int main(int argc, char** argv) {
mbr_loss = wl_acc;
mbr_idx = i;
}
- delete scorer;
}
// cerr << "ML translation: " << TD::GetString(list[0].first) << endl;
cerr << "MBR Best idx: " << mbr_idx << endl;
diff --git a/vest/scorer.cc b/vest/scorer.cc
index 86894c32..05269a3b 100644
--- a/vest/scorer.cc
+++ b/vest/scorer.cc
@@ -28,6 +28,10 @@ using namespace std;
const bool minimize_segments = true; // if adjacent segments have equal scores, merge them
+void Score::TimesEquals(float scale) {
+ cerr<<"UNIMPLEMENTED except for BLEU (for MIRA): Score::TimesEquals"<<endl;abort();
+}
+
ScoreType ScoreTypeFromString(const string& st) {
const string sl = LowercaseString(st);
if (sl == "ser")
@@ -159,6 +163,7 @@ class BLEUScore : public Score {
float ComputeScore() const;
float ComputePartialScore() const;
void ScoreDetails(string* details) const;
+ void TimesEquals(float scale);
void PlusEquals(const Score& delta);
void PlusEquals(const Score& delta, const float scale);
void PlusPartialEquals(const Score& delta, int oracle_e_cover, int oracle_f_cover, int src_len);
@@ -566,6 +571,13 @@ void BLEUScore::PlusEquals(const Score& delta) {
hyp_len += d.hyp_len;
}
+void BLEUScore::TimesEquals(float scale) {
+ correct_ngram_hit_counts *= scale;
+ hyp_ngram_counts *= scale;
+ ref_len *= scale;
+ hyp_len *= scale;
+}
+
void BLEUScore::PlusEquals(const Score& delta, const float scale) {
const BLEUScore& d = static_cast<const BLEUScore&>(delta);
correct_ngram_hit_counts = correct_ngram_hit_counts + (d.correct_ngram_hit_counts * scale);
diff --git a/vest/scorer.h b/vest/scorer.h
index 29ba5377..0d90f378 100644
--- a/vest/scorer.h
+++ b/vest/scorer.h
@@ -5,21 +5,23 @@
#include <boost/shared_ptr.hpp>
//TODO: use intrusive shared_ptr in Score (because there are many of them on ErrorSurfaces)
#include "wordid.h"
+#include "intrusive_refcount.hpp"
class Score;
class SentenceScorer;
-typedef boost::shared_ptr<Score> ScoreP;
+typedef boost::intrusive_ptr<Score> ScoreP;
typedef boost::shared_ptr<SentenceScorer> ScorerP;
class ViterbiEnvelope;
class ErrorSurface;
class Hypergraph; // needed for alignment
+//TODO: BLEU N (N separate arg, not part of enum)?
enum ScoreType { IBM_BLEU, NIST_BLEU, Koehn_BLEU, TER, BLEU_minus_TER_over_2, SER, AER, IBM_BLEU_3 };
ScoreType ScoreTypeFromString(const std::string& st);
std::string StringFromScoreType(ScoreType st);
-class Score {
+class Score : public boost::intrusive_refcount<Score> {
public:
virtual ~Score();
virtual float ComputeScore() const = 0;
@@ -30,6 +32,8 @@ class Score {
ScoreDetails(&d);
return d;
}
+ virtual void TimesEquals(float scale); // only for bleu; for mira oracle
+ /// same as rhs.TimesEquals(scale);PlusEquals(rhs) except doesn't modify rhs.
virtual void PlusEquals(const Score& rhs, const float scale) = 0;
virtual void PlusEquals(const Score& rhs) = 0;
virtual void PlusPartialEquals(const Score& rhs, int oracle_e_cover, int oracle_f_cover, int src_len) = 0;