From 51b5c16c9110999ac573bd3383d7eb0e3f10fc37 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 16 Oct 2012 00:37:21 -0400 Subject: clean up of bad header includes --- decoder/ffset.cc | 72 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 decoder/ffset.cc (limited to 'decoder/ffset.cc') diff --git a/decoder/ffset.cc b/decoder/ffset.cc new file mode 100644 index 00000000..653a29f8 --- /dev/null +++ b/decoder/ffset.cc @@ -0,0 +1,72 @@ +#include "ffset.h" + +#include "ff.h" +#include "tdict.h" +#include "hg.h" + +using namespace std; + +ModelSet::ModelSet(const vector& w, const vector& models) : + models_(models), + weights_(w), + state_size_(0), + model_state_pos_(models.size()) { + for (int i = 0; i < models_.size(); ++i) { + model_state_pos_[i] = state_size_; + state_size_ += models_[i]->StateSize(); + } +} + +void ModelSet::PrepareForInput(const SentenceMetadata& smeta) { + for (int i = 0; i < models_.size(); ++i) + const_cast(models_[i])->PrepareForInput(smeta); +} + +void ModelSet::AddFeaturesToEdge(const SentenceMetadata& smeta, + const Hypergraph& /* hg */, + const FFStates& node_states, + HG::Edge* edge, + FFState* context, + prob_t* combination_cost_estimate) const { + //edge->reset_info(); + context->resize(state_size_); + if (state_size_ > 0) { + memset(&(*context)[0], 0, state_size_); + } + SparseVector est_vals; // only computed if combination_cost_estimate is non-NULL + if (combination_cost_estimate) *combination_cost_estimate = prob_t::One(); + for (int i = 0; i < models_.size(); ++i) { + const FeatureFunction& ff = *models_[i]; + void* cur_ff_context = NULL; + vector ants(edge->tail_nodes_.size()); + bool has_context = ff.StateSize() > 0; + if (has_context) { + int spos = model_state_pos_[i]; + cur_ff_context = &(*context)[spos]; + for (int i = 0; i < ants.size(); ++i) { + ants[i] = &node_states[edge->tail_nodes_[i]][spos]; + } + } + ff.TraversalFeatures(smeta, *edge, ants, &edge->feature_values_, &est_vals, cur_ff_context); + } + if (combination_cost_estimate) + combination_cost_estimate->logeq(est_vals.dot(weights_)); + edge->edge_prob_.logeq(edge->feature_values_.dot(weights_)); +} + +void ModelSet::AddFinalFeatures(const FFState& state, HG::Edge* edge,SentenceMetadata const& smeta) const { + assert(1 == edge->rule_->Arity()); + //edge->reset_info(); + for (int i = 0; i < models_.size(); ++i) { + const FeatureFunction& ff = *models_[i]; + const void* ant_state = NULL; + bool has_context = ff.StateSize() > 0; + if (has_context) { + int spos = model_state_pos_[i]; + ant_state = &state[spos]; + } + ff.FinalTraversalFeatures(smeta, *edge, ant_state, &edge->feature_values_); + } + edge->edge_prob_.logeq(edge->feature_values_.dot(weights_)); +} + -- cgit v1.2.3 From 21825a09d97c2e0afd20512f306fb25fed55e529 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 16 Oct 2012 11:59:21 -0400 Subject: remove confusing function --- decoder/ff.cc | 7 ------- decoder/ff.h | 10 +--------- decoder/ffset.cc | 2 +- 3 files changed, 2 insertions(+), 17 deletions(-) (limited to 'decoder/ffset.cc') diff --git a/decoder/ff.cc b/decoder/ff.cc index 6e276a5e..a6a035b5 100644 --- a/decoder/ff.cc +++ b/decoder/ff.cc @@ -25,13 +25,6 @@ string FeatureFunction::usage_helper(std::string const& name,std::string const& return r; } -void FeatureFunction::FinalTraversalFeatures(const SentenceMetadata& /* smeta */, - const HG::Edge& /* edge */, - const void* residual_state, - SparseVector* final_features) const { - FinalTraversalFeatures(residual_state,final_features); -} - void FeatureFunction::TraversalFeaturesImpl(const SentenceMetadata&, const Hypergraph::Edge&, const std::vector&, diff --git a/decoder/ff.h b/decoder/ff.h index 4acbb7e3..3280592e 100644 --- a/decoder/ff.h +++ b/decoder/ff.h @@ -51,18 +51,10 @@ class FeatureFunction { } // if there's some state left when you transition to the goal state, score - // it here. For example, the language model computes the cost of adding + // it here. For example, a language model might the cost of adding // and . - -protected: virtual void FinalTraversalFeatures(const void* residual_state, SparseVector* final_features) const; -public: - //override either this or one of above. - virtual void FinalTraversalFeatures(const SentenceMetadata& /* smeta */, - const HG::Edge& /* edge */, - const void* residual_state, - SparseVector* final_features) const; protected: // context is a pointer to a buffer of size NumBytesContext() that the diff --git a/decoder/ffset.cc b/decoder/ffset.cc index 653a29f8..5820f421 100644 --- a/decoder/ffset.cc +++ b/decoder/ffset.cc @@ -65,7 +65,7 @@ void ModelSet::AddFinalFeatures(const FFState& state, HG::Edge* edge,SentenceMet int spos = model_state_pos_[i]; ant_state = &state[spos]; } - ff.FinalTraversalFeatures(smeta, *edge, ant_state, &edge->feature_values_); + ff.FinalTraversalFeatures(ant_state, &edge->feature_values_); } edge->edge_prob_.logeq(edge->feature_values_.dot(weights_)); } -- cgit v1.2.3