summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorGuest_account Guest_account prguest11 <prguest11@taipan.cs>2011-10-11 16:16:53 +0100
committerGuest_account Guest_account prguest11 <prguest11@taipan.cs>2011-10-11 16:16:53 +0100
commit08c4a7fae8f0bec4f76c4e0928e357100eb7a1ca (patch)
tree44030db9ef1625ce130ab08acfd308643d568d1f /decoder
parentffaae62e4f1cedabbc6eb1982af129e7294d33eb (diff)
remove implicit conversion-to-double operator from LogVal<T> that caused overflow errors, clean up some pf code
Diffstat (limited to 'decoder')
-rw-r--r--decoder/aligner.cc2
-rwxr-xr-xdecoder/cfg.cc2
-rwxr-xr-xdecoder/cfg_format.h2
-rw-r--r--decoder/decoder.cc10
-rw-r--r--decoder/hg.cc4
-rw-r--r--decoder/rule_lexer.l2
-rw-r--r--decoder/trule.h15
7 files changed, 26 insertions, 11 deletions
diff --git a/decoder/aligner.cc b/decoder/aligner.cc
index 292ee123..53e059fb 100644
--- a/decoder/aligner.cc
+++ b/decoder/aligner.cc
@@ -165,7 +165,7 @@ inline void WriteProbGrid(const Array2D<prob_t>& m, ostream* pos) {
if (m(i,j) == prob_t::Zero()) {
os << "\t---X---";
} else {
- snprintf(b, 1024, "%0.5f", static_cast<double>(m(i,j)));
+ snprintf(b, 1024, "%0.5f", m(i,j).as_float());
os << '\t' << b;
}
}
diff --git a/decoder/cfg.cc b/decoder/cfg.cc
index 651978d2..cd7e66e9 100755
--- a/decoder/cfg.cc
+++ b/decoder/cfg.cc
@@ -639,7 +639,7 @@ void CFG::Print(std::ostream &o,CFGFormat const& f) const {
o << '['<<f.goal_nt_name <<']';
WordID rhs=-goal_nt;
f.print_rhs(o,*this,&rhs,&rhs+1);
- if (pushed_inside!=1)
+ if (pushed_inside!=prob_t::One())
f.print_features(o,pushed_inside);
o<<'\n';
}
diff --git a/decoder/cfg_format.h b/decoder/cfg_format.h
index c6a594b8..2f40d483 100755
--- a/decoder/cfg_format.h
+++ b/decoder/cfg_format.h
@@ -101,7 +101,7 @@ struct CFGFormat {
}
void print_features(std::ostream &o,prob_t p,FeatureVector const& fv=FeatureVector()) const {
- bool logp=(logprob_feat && p!=1);
+ bool logp=(logprob_feat && p!=prob_t::One());
if (features || logp) {
o << partsep;
if (logp)
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index c4fe3c4d..3b53fd6b 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -325,7 +325,7 @@ struct DecoderImpl {
static void ConvertSV(const SparseVector<prob_t>& src, SparseVector<double>* trg) {
for (SparseVector<prob_t>::const_iterator it = src.begin(); it != src.end(); ++it)
- trg->set_value(it->first, it->second);
+ trg->set_value(it->first, it->second.as_float());
}
};
@@ -788,10 +788,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
const bool show_tree_structure=conf.count("show_tree_structure");
if (!SILENT) forest_stats(forest," Init. forest",show_tree_structure,oracle.show_derivation);
if (conf.count("show_expected_length")) {
- const PRPair<double, double> res =
- Inside<PRPair<double, double>,
- PRWeightFunction<double, EdgeProb, double, ELengthWeightFunction> >(forest);
- cerr << " Expected length (words): " << res.r / res.p << "\t" << res << endl;
+ const PRPair<prob_t, prob_t> res =
+ Inside<PRPair<prob_t, prob_t>,
+ PRWeightFunction<prob_t, EdgeProb, prob_t, ELengthWeightFunction> >(forest);
+ cerr << " Expected length (words): " << (res.r / res.p).as_float() << "\t" << res << endl;
}
if (conf.count("show_partition")) {
diff --git a/decoder/hg.cc b/decoder/hg.cc
index 3ad17f1a..180986d7 100644
--- a/decoder/hg.cc
+++ b/decoder/hg.cc
@@ -157,14 +157,14 @@ prob_t Hypergraph::ComputeEdgePosteriors(double scale, vector<prob_t>* posts) co
const ScaledEdgeProb weight(scale);
const ScaledTransitionEventWeightFunction w2(scale);
SparseVector<prob_t> pv;
- const double inside = InsideOutside<prob_t,
+ const prob_t inside = InsideOutside<prob_t,
ScaledEdgeProb,
SparseVector<prob_t>,
ScaledTransitionEventWeightFunction>(*this, &pv, weight, w2);
posts->resize(edges_.size());
for (int i = 0; i < edges_.size(); ++i)
(*posts)[i] = prob_t(pv.value(i));
- return prob_t(inside);
+ return inside;
}
prob_t Hypergraph::ComputeBestPathThroughEdges(vector<prob_t>* post) const {
diff --git a/decoder/rule_lexer.l b/decoder/rule_lexer.l
index 9331d8ed..083a5bb1 100644
--- a/decoder/rule_lexer.l
+++ b/decoder/rule_lexer.l
@@ -220,6 +220,8 @@ NT [^\t \[\],]+
std::cerr << "Line " << lex_line << ": LHS and RHS arity mismatch!\n";
abort();
}
+ // const bool ignore_grammar_features = false;
+ // if (ignore_grammar_features) scfglex_num_feats = 0;
TRulePtr rp(new TRule(scfglex_lhs, scfglex_src_rhs, scfglex_src_rhs_size, scfglex_trg_rhs, scfglex_trg_rhs_size, scfglex_feat_ids, scfglex_feat_vals, scfglex_num_feats, scfglex_src_arity, scfglex_als, scfglex_num_als));
check_and_update_ctf_stack(rp);
TRulePtr coarse_rp = ((ctf_level == 0) ? TRulePtr() : ctf_rule_stack.top());
diff --git a/decoder/trule.h b/decoder/trule.h
index 4df4ec90..8eb2a059 100644
--- a/decoder/trule.h
+++ b/decoder/trule.h
@@ -5,7 +5,9 @@
#include <vector>
#include <cassert>
#include <iostream>
-#include <boost/shared_ptr.hpp>
+
+#include "boost/shared_ptr.hpp"
+#include "boost/functional/hash.hpp"
#include "sparse_vector.h"
#include "wordid.h"
@@ -162,4 +164,15 @@ class TRule {
bool SanityCheck() const;
};
+inline size_t hash_value(const TRule& r) {
+ size_t h = boost::hash_value(r.e_);
+ boost::hash_combine(h, -r.lhs_);
+ boost::hash_combine(h, boost::hash_value(r.f_));
+ return h;
+}
+
+inline bool operator==(const TRule& a, const TRule& b) {
+ return (a.lhs_ == b.lhs_ && a.e_ == b.e_ && a.f_ == b.f_);
+}
+
#endif