summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorgraehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-20 00:11:45 +0000
committergraehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-20 00:11:45 +0000
commitf28e6d45671035d39dcfc25070c72f6e120032e1 (patch)
treed3627fd7c440bae67b81f1578fea9576d4836603
parentb6cf365f217bc7b528243071af733d4a251ff77c (diff)
Score::TimesEquals for vlad-mira, intrusive refcount for Score, shared_ptr compile fixes for decoder progs
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@331 ec762483-ff6d-05da-a07a-a48fb63a330f
-rwxr-xr-xdecoder/intrusive_refcount.hpp84
-rwxr-xr-xdecoder/oracle_bleu.h3
-rw-r--r--vest/fast_score.cc6
-rw-r--r--vest/mbr_kbest.cc6
-rw-r--r--vest/scorer.cc12
-rw-r--r--vest/scorer.h8
6 files changed, 108 insertions, 11 deletions
diff --git a/decoder/intrusive_refcount.hpp b/decoder/intrusive_refcount.hpp
new file mode 100755
index 00000000..4a4b0187
--- /dev/null
+++ b/decoder/intrusive_refcount.hpp
@@ -0,0 +1,84 @@
+#ifndef GRAEHL__SHARED__INTRUSIVE_REFCOUNT_HPP
+#define GRAEHL__SHARED__INTRUSIVE_REFCOUNT_HPP
+
+#include <boost/intrusive_ptr.hpp>
+#include <boost/noncopyable.hpp>
+#include <boost/detail/atomic_count.hpp>
+#include <cassert>
+
+/** usage:
+ struct mine : public boost::instrusive_refcount<mine> {};
+
+ boost::intrusive_ptr<mine> p(new mine());
+*/
+
+namespace boost {
+// note: the free functions need to be in boost namespace, OR namespace of involved type. this is the only way to do it.
+
+template <class T>
+class intrusive_refcount;
+
+template <class T>
+class atomic_intrusive_refcount;
+
+template<class T>
+void intrusive_ptr_add_ref(intrusive_refcount<T>* ptr)
+{
+ ++(ptr->refs);
+}
+
+template<class T>
+void intrusive_ptr_release(intrusive_refcount<T>* ptr)
+{
+ if (!--(ptr->refs)) delete static_cast<T*>(ptr);
+}
+
+
+//WARNING: only 2^32 (unsigned) refs allowed. hope that's ok :)
+template<class T>
+class intrusive_refcount : boost::noncopyable
+{
+ protected:
+// typedef intrusive_refcount<T> pointed_type;
+ friend void intrusive_ptr_add_ref<T>(intrusive_refcount<T>* ptr);
+ friend void intrusive_ptr_release<T>(intrusive_refcount<T>* ptr);
+// friend class intrusive_ptr<T>;
+
+ intrusive_refcount(): refs(0) {}
+ ~intrusive_refcount() { assert(refs==0); }
+
+private:
+ unsigned refs;
+};
+
+
+template<class T>
+void intrusive_ptr_add_ref(atomic_intrusive_refcount<T>* ptr)
+{
+ ++(ptr->refs);
+}
+
+template<class T>
+void intrusive_ptr_release(atomic_intrusive_refcount<T>* ptr)
+{
+ if(!--(ptr->refs)) delete static_cast<T*>(ptr);
+}
+
+template<class T>
+class atomic_intrusive_refcount : boost::noncopyable
+{
+ protected:
+ friend void intrusive_ptr_add_ref<T>(atomic_intrusive_refcount<T>* ptr);
+ friend void intrusive_ptr_release<T>(atomic_intrusive_refcount<T>* ptr);
+
+ atomic_intrusive_refcount(): refs(0) {}
+ ~atomic_intrusive_refcount() { assert(refs==0); }
+
+private:
+ boost::detail::atomic_count refs;
+};
+
+}
+
+
+#endif
diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h
index 94548c18..2ccace61 100755
--- a/decoder/oracle_bleu.h
+++ b/decoder/oracle_bleu.h
@@ -231,7 +231,8 @@ struct OracleBleu {
void IncludeLastScore(std::ostream *out=0) {
double bleu_scale_ = doc_src_length * doc_score->ComputeScore();
- doc_score->PlusEquals(*sentscore, scale_oracle);
+ doc_score->PlusEquals(*sentscore);
+ doc_score->TimesEquals(scale_oracle);
sentscore.reset();
doc_src_length = (doc_src_length + tmp_src_length) * scale_oracle;
if (out) {
diff --git a/vest/fast_score.cc b/vest/fast_score.cc
index 0d611d56..5ee264a6 100644
--- a/vest/fast_score.cc
+++ b/vest/fast_score.cc
@@ -41,7 +41,7 @@ int main(int argc, char** argv) {
cerr << "Loaded " << ds.size() << " references for scoring with " << loss_function << endl;
ReadFile rf(conf["in_file"].as<string>());
- Score* acc = NULL;
+ ScoreP acc;
istream& in = *rf.stream();
int lc = 0;
while(in) {
@@ -50,10 +50,9 @@ int main(int argc, char** argv) {
if (line.empty() && !in) break;
vector<WordID> sent;
TD::ConvertSentence(line, &sent);
- Score* sentscore = ds[lc]->ScoreCandidate(sent);
+ ScoreP sentscore = ds[lc]->ScoreCandidate(sent);
if (!acc) { acc = sentscore->GetZero(); }
acc->PlusEquals(*sentscore);
- delete sentscore;
++lc;
}
assert(lc > 0);
@@ -67,7 +66,6 @@ int main(int argc, char** argv) {
float score = acc->ComputeScore();
string details;
acc->ScoreDetails(&details);
- delete acc;
cerr << details << endl;
cout << score << endl;
return 0;
diff --git a/vest/mbr_kbest.cc b/vest/mbr_kbest.cc
index 5d70b4e2..2867b36b 100644
--- a/vest/mbr_kbest.cc
+++ b/vest/mbr_kbest.cc
@@ -101,14 +101,13 @@ int main(int argc, char** argv) {
for (int i = 0 ; i < list.size(); ++i) {
vector<vector<WordID> > refs(1, list[i].first);
//cerr << i << ": " << list[i].second <<"\t" << TD::GetString(list[i].first) << endl;
- SentenceScorer* scorer = SentenceScorer::CreateSentenceScorer(type, refs);
+ ScorerP scorer = SentenceScorer::CreateSentenceScorer(type, refs);
double wl_acc = 0;
for (int j = 0; j < list.size(); ++j) {
if (i != j) {
- Score* s = scorer->ScoreCandidate(list[j].first);
+ ScoreP s = scorer->ScoreCandidate(list[j].first);
double loss = 1.0 - s->ComputeScore();
if (type == TER || type == AER) loss = 1.0 - loss;
- delete s;
double weighted_loss = loss * (joints[j] / marginal);
wl_acc += weighted_loss;
if ((!output_list) && wl_acc > mbr_loss) break;
@@ -119,7 +118,6 @@ int main(int argc, char** argv) {
mbr_loss = wl_acc;
mbr_idx = i;
}
- delete scorer;
}
// cerr << "ML translation: " << TD::GetString(list[0].first) << endl;
cerr << "MBR Best idx: " << mbr_idx << endl;
diff --git a/vest/scorer.cc b/vest/scorer.cc
index 86894c32..05269a3b 100644
--- a/vest/scorer.cc
+++ b/vest/scorer.cc
@@ -28,6 +28,10 @@ using namespace std;
const bool minimize_segments = true; // if adjacent segments have equal scores, merge them
+void Score::TimesEquals(float scale) {
+ cerr<<"UNIMPLEMENTED except for BLEU (for MIRA): Score::TimesEquals"<<endl;abort();
+}
+
ScoreType ScoreTypeFromString(const string& st) {
const string sl = LowercaseString(st);
if (sl == "ser")
@@ -159,6 +163,7 @@ class BLEUScore : public Score {
float ComputeScore() const;
float ComputePartialScore() const;
void ScoreDetails(string* details) const;
+ void TimesEquals(float scale);
void PlusEquals(const Score& delta);
void PlusEquals(const Score& delta, const float scale);
void PlusPartialEquals(const Score& delta, int oracle_e_cover, int oracle_f_cover, int src_len);
@@ -566,6 +571,13 @@ void BLEUScore::PlusEquals(const Score& delta) {
hyp_len += d.hyp_len;
}
+void BLEUScore::TimesEquals(float scale) {
+ correct_ngram_hit_counts *= scale;
+ hyp_ngram_counts *= scale;
+ ref_len *= scale;
+ hyp_len *= scale;
+}
+
void BLEUScore::PlusEquals(const Score& delta, const float scale) {
const BLEUScore& d = static_cast<const BLEUScore&>(delta);
correct_ngram_hit_counts = correct_ngram_hit_counts + (d.correct_ngram_hit_counts * scale);
diff --git a/vest/scorer.h b/vest/scorer.h
index 29ba5377..0d90f378 100644
--- a/vest/scorer.h
+++ b/vest/scorer.h
@@ -5,21 +5,23 @@
#include <boost/shared_ptr.hpp>
//TODO: use intrusive shared_ptr in Score (because there are many of them on ErrorSurfaces)
#include "wordid.h"
+#include "intrusive_refcount.hpp"
class Score;
class SentenceScorer;
-typedef boost::shared_ptr<Score> ScoreP;
+typedef boost::intrusive_ptr<Score> ScoreP;
typedef boost::shared_ptr<SentenceScorer> ScorerP;
class ViterbiEnvelope;
class ErrorSurface;
class Hypergraph; // needed for alignment
+//TODO: BLEU N (N separate arg, not part of enum)?
enum ScoreType { IBM_BLEU, NIST_BLEU, Koehn_BLEU, TER, BLEU_minus_TER_over_2, SER, AER, IBM_BLEU_3 };
ScoreType ScoreTypeFromString(const std::string& st);
std::string StringFromScoreType(ScoreType st);
-class Score {
+class Score : public boost::intrusive_refcount<Score> {
public:
virtual ~Score();
virtual float ComputeScore() const = 0;
@@ -30,6 +32,8 @@ class Score {
ScoreDetails(&d);
return d;
}
+ virtual void TimesEquals(float scale); // only for bleu; for mira oracle
+ /// same as rhs.TimesEquals(scale);PlusEquals(rhs) except doesn't modify rhs.
virtual void PlusEquals(const Score& rhs, const float scale) = 0;
virtual void PlusEquals(const Score& rhs) = 0;
virtual void PlusPartialEquals(const Score& rhs, int oracle_e_cover, int oracle_f_cover, int src_len) = 0;