diff options
author | Victor Chahuneau <vchahune@cs.cmu.edu> | 2012-10-16 01:27:30 -0400 |
---|---|---|
committer | Victor Chahuneau <vchahune@cs.cmu.edu> | 2012-10-16 01:27:30 -0400 |
commit | 62bf03f0ec79b658c9208f0ae8976397e64a7dee (patch) | |
tree | d95676017f04b2abb6315c5f550064c03aeda702 /decoder/ffset.cc | |
parent | 501c0419200234f4ff6b8c3d9f9602596a96c28e (diff) | |
parent | 51b5c16c9110999ac573bd3383d7eb0e3f10fc37 (diff) |
Merge branch 'master' of github.com:redpony/cdec
Diffstat (limited to 'decoder/ffset.cc')
-rw-r--r-- | decoder/ffset.cc | 72 |
1 files changed, 72 insertions, 0 deletions
diff --git a/decoder/ffset.cc b/decoder/ffset.cc new file mode 100644 index 00000000..653a29f8 --- /dev/null +++ b/decoder/ffset.cc @@ -0,0 +1,72 @@ +#include "ffset.h" + +#include "ff.h" +#include "tdict.h" +#include "hg.h" + +using namespace std; + +ModelSet::ModelSet(const vector<double>& w, const vector<const FeatureFunction*>& models) : + models_(models), + weights_(w), + state_size_(0), + model_state_pos_(models.size()) { + for (int i = 0; i < models_.size(); ++i) { + model_state_pos_[i] = state_size_; + state_size_ += models_[i]->StateSize(); + } +} + +void ModelSet::PrepareForInput(const SentenceMetadata& smeta) { + for (int i = 0; i < models_.size(); ++i) + const_cast<FeatureFunction*>(models_[i])->PrepareForInput(smeta); +} + +void ModelSet::AddFeaturesToEdge(const SentenceMetadata& smeta, + const Hypergraph& /* hg */, + const FFStates& node_states, + HG::Edge* edge, + FFState* context, + prob_t* combination_cost_estimate) const { + //edge->reset_info(); + context->resize(state_size_); + if (state_size_ > 0) { + memset(&(*context)[0], 0, state_size_); + } + SparseVector<double> est_vals; // only computed if combination_cost_estimate is non-NULL + if (combination_cost_estimate) *combination_cost_estimate = prob_t::One(); + for (int i = 0; i < models_.size(); ++i) { + const FeatureFunction& ff = *models_[i]; + void* cur_ff_context = NULL; + vector<const void*> ants(edge->tail_nodes_.size()); + bool has_context = ff.StateSize() > 0; + if (has_context) { + int spos = model_state_pos_[i]; + cur_ff_context = &(*context)[spos]; + for (int i = 0; i < ants.size(); ++i) { + ants[i] = &node_states[edge->tail_nodes_[i]][spos]; + } + } + ff.TraversalFeatures(smeta, *edge, ants, &edge->feature_values_, &est_vals, cur_ff_context); + } + if (combination_cost_estimate) + combination_cost_estimate->logeq(est_vals.dot(weights_)); + edge->edge_prob_.logeq(edge->feature_values_.dot(weights_)); +} + +void ModelSet::AddFinalFeatures(const FFState& state, HG::Edge* edge,SentenceMetadata const& smeta) const { + assert(1 == edge->rule_->Arity()); + //edge->reset_info(); + for (int i = 0; i < models_.size(); ++i) { + const FeatureFunction& ff = *models_[i]; + const void* ant_state = NULL; + bool has_context = ff.StateSize() > 0; + if (has_context) { + int spos = model_state_pos_[i]; + ant_state = &state[spos]; + } + ff.FinalTraversalFeatures(smeta, *edge, ant_state, &edge->feature_values_); + } + edge->edge_prob_.logeq(edge->feature_values_.dot(weights_)); +} + |