diff options
author | Guest_account Guest_account prguest11 <prguest11@taipan.cs> | 2011-10-11 16:16:53 +0100 |
---|---|---|
committer | Guest_account Guest_account prguest11 <prguest11@taipan.cs> | 2011-10-11 16:16:53 +0100 |
commit | 08c4a7fae8f0bec4f76c4e0928e357100eb7a1ca (patch) | |
tree | 44030db9ef1625ce130ab08acfd308643d568d1f /decoder | |
parent | ffaae62e4f1cedabbc6eb1982af129e7294d33eb (diff) |
remove implicit conversion-to-double operator from LogVal<T> that caused overflow errors, clean up some pf code
Diffstat (limited to 'decoder')
-rw-r--r-- | decoder/aligner.cc | 2 | ||||
-rwxr-xr-x | decoder/cfg.cc | 2 | ||||
-rwxr-xr-x | decoder/cfg_format.h | 2 | ||||
-rw-r--r-- | decoder/decoder.cc | 10 | ||||
-rw-r--r-- | decoder/hg.cc | 4 | ||||
-rw-r--r-- | decoder/rule_lexer.l | 2 | ||||
-rw-r--r-- | decoder/trule.h | 15 |
7 files changed, 26 insertions, 11 deletions
diff --git a/decoder/aligner.cc b/decoder/aligner.cc index 292ee123..53e059fb 100644 --- a/decoder/aligner.cc +++ b/decoder/aligner.cc @@ -165,7 +165,7 @@ inline void WriteProbGrid(const Array2D<prob_t>& m, ostream* pos) { if (m(i,j) == prob_t::Zero()) { os << "\t---X---"; } else { - snprintf(b, 1024, "%0.5f", static_cast<double>(m(i,j))); + snprintf(b, 1024, "%0.5f", m(i,j).as_float()); os << '\t' << b; } } diff --git a/decoder/cfg.cc b/decoder/cfg.cc index 651978d2..cd7e66e9 100755 --- a/decoder/cfg.cc +++ b/decoder/cfg.cc @@ -639,7 +639,7 @@ void CFG::Print(std::ostream &o,CFGFormat const& f) const { o << '['<<f.goal_nt_name <<']'; WordID rhs=-goal_nt; f.print_rhs(o,*this,&rhs,&rhs+1); - if (pushed_inside!=1) + if (pushed_inside!=prob_t::One()) f.print_features(o,pushed_inside); o<<'\n'; } diff --git a/decoder/cfg_format.h b/decoder/cfg_format.h index c6a594b8..2f40d483 100755 --- a/decoder/cfg_format.h +++ b/decoder/cfg_format.h @@ -101,7 +101,7 @@ struct CFGFormat { } void print_features(std::ostream &o,prob_t p,FeatureVector const& fv=FeatureVector()) const { - bool logp=(logprob_feat && p!=1); + bool logp=(logprob_feat && p!=prob_t::One()); if (features || logp) { o << partsep; if (logp) diff --git a/decoder/decoder.cc b/decoder/decoder.cc index c4fe3c4d..3b53fd6b 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -325,7 +325,7 @@ struct DecoderImpl { static void ConvertSV(const SparseVector<prob_t>& src, SparseVector<double>* trg) { for (SparseVector<prob_t>::const_iterator it = src.begin(); it != src.end(); ++it) - trg->set_value(it->first, it->second); + trg->set_value(it->first, it->second.as_float()); } }; @@ -788,10 +788,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { const bool show_tree_structure=conf.count("show_tree_structure"); if (!SILENT) forest_stats(forest," Init. forest",show_tree_structure,oracle.show_derivation); if (conf.count("show_expected_length")) { - const PRPair<double, double> res = - Inside<PRPair<double, double>, - PRWeightFunction<double, EdgeProb, double, ELengthWeightFunction> >(forest); - cerr << " Expected length (words): " << res.r / res.p << "\t" << res << endl; + const PRPair<prob_t, prob_t> res = + Inside<PRPair<prob_t, prob_t>, + PRWeightFunction<prob_t, EdgeProb, prob_t, ELengthWeightFunction> >(forest); + cerr << " Expected length (words): " << (res.r / res.p).as_float() << "\t" << res << endl; } if (conf.count("show_partition")) { diff --git a/decoder/hg.cc b/decoder/hg.cc index 3ad17f1a..180986d7 100644 --- a/decoder/hg.cc +++ b/decoder/hg.cc @@ -157,14 +157,14 @@ prob_t Hypergraph::ComputeEdgePosteriors(double scale, vector<prob_t>* posts) co const ScaledEdgeProb weight(scale); const ScaledTransitionEventWeightFunction w2(scale); SparseVector<prob_t> pv; - const double inside = InsideOutside<prob_t, + const prob_t inside = InsideOutside<prob_t, ScaledEdgeProb, SparseVector<prob_t>, ScaledTransitionEventWeightFunction>(*this, &pv, weight, w2); posts->resize(edges_.size()); for (int i = 0; i < edges_.size(); ++i) (*posts)[i] = prob_t(pv.value(i)); - return prob_t(inside); + return inside; } prob_t Hypergraph::ComputeBestPathThroughEdges(vector<prob_t>* post) const { diff --git a/decoder/rule_lexer.l b/decoder/rule_lexer.l index 9331d8ed..083a5bb1 100644 --- a/decoder/rule_lexer.l +++ b/decoder/rule_lexer.l @@ -220,6 +220,8 @@ NT [^\t \[\],]+ std::cerr << "Line " << lex_line << ": LHS and RHS arity mismatch!\n"; abort(); } + // const bool ignore_grammar_features = false; + // if (ignore_grammar_features) scfglex_num_feats = 0; TRulePtr rp(new TRule(scfglex_lhs, scfglex_src_rhs, scfglex_src_rhs_size, scfglex_trg_rhs, scfglex_trg_rhs_size, scfglex_feat_ids, scfglex_feat_vals, scfglex_num_feats, scfglex_src_arity, scfglex_als, scfglex_num_als)); check_and_update_ctf_stack(rp); TRulePtr coarse_rp = ((ctf_level == 0) ? TRulePtr() : ctf_rule_stack.top()); diff --git a/decoder/trule.h b/decoder/trule.h index 4df4ec90..8eb2a059 100644 --- a/decoder/trule.h +++ b/decoder/trule.h @@ -5,7 +5,9 @@ #include <vector> #include <cassert> #include <iostream> -#include <boost/shared_ptr.hpp> + +#include "boost/shared_ptr.hpp" +#include "boost/functional/hash.hpp" #include "sparse_vector.h" #include "wordid.h" @@ -162,4 +164,15 @@ class TRule { bool SanityCheck() const; }; +inline size_t hash_value(const TRule& r) { + size_t h = boost::hash_value(r.e_); + boost::hash_combine(h, -r.lhs_); + boost::hash_combine(h, boost::hash_value(r.f_)); + return h; +} + +inline bool operator==(const TRule& a, const TRule& b) { + return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_); +} + #endif |