From eb0a8931eb835d8a23c939ef61867a08c19ebf6b Mon Sep 17 00:00:00 2001 From: graehl Date: Tue, 20 Jul 2010 00:11:45 +0000 Subject: Score::TimesEquals for vlad-mira, intrusive refcount for Score, shared_ptr compile fixes for decoder progs git-svn-id: https://ws10smt.googlecode.com/svn/trunk@331 ec762483-ff6d-05da-a07a-a48fb63a330f --- decoder/intrusive_refcount.hpp | 84 ++++++++++++++++++++++++++++++++++++++++++ decoder/oracle_bleu.h | 3 +- vest/fast_score.cc | 6 +-- vest/mbr_kbest.cc | 6 +-- vest/scorer.cc | 12 ++++++ vest/scorer.h | 8 +++- 6 files changed, 108 insertions(+), 11 deletions(-) create mode 100755 decoder/intrusive_refcount.hpp diff --git a/decoder/intrusive_refcount.hpp b/decoder/intrusive_refcount.hpp new file mode 100755 index 00000000..4a4b0187 --- /dev/null +++ b/decoder/intrusive_refcount.hpp @@ -0,0 +1,84 @@ +#ifndef GRAEHL__SHARED__INTRUSIVE_REFCOUNT_HPP +#define GRAEHL__SHARED__INTRUSIVE_REFCOUNT_HPP + +#include +#include +#include +#include + +/** usage: + struct mine : public boost::instrusive_refcount {}; + + boost::intrusive_ptr p(new mine()); +*/ + +namespace boost { +// note: the free functions need to be in boost namespace, OR namespace of involved type. this is the only way to do it. + +template +class intrusive_refcount; + +template +class atomic_intrusive_refcount; + +template +void intrusive_ptr_add_ref(intrusive_refcount* ptr) +{ + ++(ptr->refs); +} + +template +void intrusive_ptr_release(intrusive_refcount* ptr) +{ + if (!--(ptr->refs)) delete static_cast(ptr); +} + + +//WARNING: only 2^32 (unsigned) refs allowed. hope that's ok :) +template +class intrusive_refcount : boost::noncopyable +{ + protected: +// typedef intrusive_refcount pointed_type; + friend void intrusive_ptr_add_ref(intrusive_refcount* ptr); + friend void intrusive_ptr_release(intrusive_refcount* ptr); +// friend class intrusive_ptr; + + intrusive_refcount(): refs(0) {} + ~intrusive_refcount() { assert(refs==0); } + +private: + unsigned refs; +}; + + +template +void intrusive_ptr_add_ref(atomic_intrusive_refcount* ptr) +{ + ++(ptr->refs); +} + +template +void intrusive_ptr_release(atomic_intrusive_refcount* ptr) +{ + if(!--(ptr->refs)) delete static_cast(ptr); +} + +template +class atomic_intrusive_refcount : boost::noncopyable +{ + protected: + friend void intrusive_ptr_add_ref(atomic_intrusive_refcount* ptr); + friend void intrusive_ptr_release(atomic_intrusive_refcount* ptr); + + atomic_intrusive_refcount(): refs(0) {} + ~atomic_intrusive_refcount() { assert(refs==0); } + +private: + boost::detail::atomic_count refs; +}; + +} + + +#endif diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index 94548c18..2ccace61 100755 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -231,7 +231,8 @@ struct OracleBleu { void IncludeLastScore(std::ostream *out=0) { double bleu_scale_ = doc_src_length * doc_score->ComputeScore(); - doc_score->PlusEquals(*sentscore, scale_oracle); + doc_score->PlusEquals(*sentscore); + doc_score->TimesEquals(scale_oracle); sentscore.reset(); doc_src_length = (doc_src_length + tmp_src_length) * scale_oracle; if (out) { diff --git a/vest/fast_score.cc b/vest/fast_score.cc index 0d611d56..5ee264a6 100644 --- a/vest/fast_score.cc +++ b/vest/fast_score.cc @@ -41,7 +41,7 @@ int main(int argc, char** argv) { cerr << "Loaded " << ds.size() << " references for scoring with " << loss_function << endl; ReadFile rf(conf["in_file"].as()); - Score* acc = NULL; + ScoreP acc; istream& in = *rf.stream(); int lc = 0; while(in) { @@ -50,10 +50,9 @@ int main(int argc, char** argv) { if (line.empty() && !in) break; vector sent; TD::ConvertSentence(line, &sent); - Score* sentscore = ds[lc]->ScoreCandidate(sent); + ScoreP sentscore = ds[lc]->ScoreCandidate(sent); if (!acc) { acc = sentscore->GetZero(); } acc->PlusEquals(*sentscore); - delete sentscore; ++lc; } assert(lc > 0); @@ -67,7 +66,6 @@ int main(int argc, char** argv) { float score = acc->ComputeScore(); string details; acc->ScoreDetails(&details); - delete acc; cerr << details << endl; cout << score << endl; return 0; diff --git a/vest/mbr_kbest.cc b/vest/mbr_kbest.cc index 5d70b4e2..2867b36b 100644 --- a/vest/mbr_kbest.cc +++ b/vest/mbr_kbest.cc @@ -101,14 +101,13 @@ int main(int argc, char** argv) { for (int i = 0 ; i < list.size(); ++i) { vector > refs(1, list[i].first); //cerr << i << ": " << list[i].second <<"\t" << TD::GetString(list[i].first) << endl; - SentenceScorer* scorer = SentenceScorer::CreateSentenceScorer(type, refs); + ScorerP scorer = SentenceScorer::CreateSentenceScorer(type, refs); double wl_acc = 0; for (int j = 0; j < list.size(); ++j) { if (i != j) { - Score* s = scorer->ScoreCandidate(list[j].first); + ScoreP s = scorer->ScoreCandidate(list[j].first); double loss = 1.0 - s->ComputeScore(); if (type == TER || type == AER) loss = 1.0 - loss; - delete s; double weighted_loss = loss * (joints[j] / marginal); wl_acc += weighted_loss; if ((!output_list) && wl_acc > mbr_loss) break; @@ -119,7 +118,6 @@ int main(int argc, char** argv) { mbr_loss = wl_acc; mbr_idx = i; } - delete scorer; } // cerr << "ML translation: " << TD::GetString(list[0].first) << endl; cerr << "MBR Best idx: " << mbr_idx << endl; diff --git a/vest/scorer.cc b/vest/scorer.cc index 86894c32..05269a3b 100644 --- a/vest/scorer.cc +++ b/vest/scorer.cc @@ -28,6 +28,10 @@ using namespace std; const bool minimize_segments = true; // if adjacent segments have equal scores, merge them +void Score::TimesEquals(float scale) { + cerr<<"UNIMPLEMENTED except for BLEU (for MIRA): Score::TimesEquals"<(delta); correct_ngram_hit_counts = correct_ngram_hit_counts + (d.correct_ngram_hit_counts * scale); diff --git a/vest/scorer.h b/vest/scorer.h index 29ba5377..0d90f378 100644 --- a/vest/scorer.h +++ b/vest/scorer.h @@ -5,21 +5,23 @@ #include //TODO: use intrusive shared_ptr in Score (because there are many of them on ErrorSurfaces) #include "wordid.h" +#include "intrusive_refcount.hpp" class Score; class SentenceScorer; -typedef boost::shared_ptr ScoreP; +typedef boost::intrusive_ptr ScoreP; typedef boost::shared_ptr ScorerP; class ViterbiEnvelope; class ErrorSurface; class Hypergraph; // needed for alignment +//TODO: BLEU N (N separate arg, not part of enum)? enum ScoreType { IBM_BLEU, NIST_BLEU, Koehn_BLEU, TER, BLEU_minus_TER_over_2, SER, AER, IBM_BLEU_3 }; ScoreType ScoreTypeFromString(const std::string& st); std::string StringFromScoreType(ScoreType st); -class Score { +class Score : public boost::intrusive_refcount { public: virtual ~Score(); virtual float ComputeScore() const = 0; @@ -30,6 +32,8 @@ class Score { ScoreDetails(&d); return d; } + virtual void TimesEquals(float scale); // only for bleu; for mira oracle + /// same as rhs.TimesEquals(scale);PlusEquals(rhs) except doesn't modify rhs. virtual void PlusEquals(const Score& rhs, const float scale) = 0; virtual void PlusEquals(const Score& rhs) = 0; virtual void PlusPartialEquals(const Score& rhs, int oracle_e_cover, int oracle_f_cover, int src_len) = 0; -- cgit v1.2.3