diff options
author | Avneesh Saluja <asaluja@gmail.com> | 2013-03-28 18:28:16 -0700 |
---|---|---|
committer | Avneesh Saluja <asaluja@gmail.com> | 2013-03-28 18:28:16 -0700 |
commit | 3d8d656fa7911524e0e6885647173474524e0784 (patch) | |
tree | 81b1ee2fcb67980376d03f0aa48e42e53abff222 /decoder/ffset.cc | |
parent | be7f57fdd484e063775d7abf083b9fa4c403b610 (diff) | |
parent | 96fedabebafe7a38a6d5928be8fff767e411d705 (diff) |
fixed conflicts
Diffstat (limited to 'decoder/ffset.cc')
-rw-r--r-- | decoder/ffset.cc | 72 |
1 files changed, 72 insertions, 0 deletions
diff --git a/decoder/ffset.cc b/decoder/ffset.cc new file mode 100644 index 00000000..5820f421 --- /dev/null +++ b/decoder/ffset.cc @@ -0,0 +1,72 @@ +#include "ffset.h" + +#include "ff.h" +#include "tdict.h" +#include "hg.h" + +using namespace std; + +ModelSet::ModelSet(const vector<double>& w, const vector<const FeatureFunction*>& models) : + models_(models), + weights_(w), + state_size_(0), + model_state_pos_(models.size()) { + for (int i = 0; i < models_.size(); ++i) { + model_state_pos_[i] = state_size_; + state_size_ += models_[i]->StateSize(); + } +} + +void ModelSet::PrepareForInput(const SentenceMetadata& smeta) { + for (int i = 0; i < models_.size(); ++i) + const_cast<FeatureFunction*>(models_[i])->PrepareForInput(smeta); +} + +void ModelSet::AddFeaturesToEdge(const SentenceMetadata& smeta, + const Hypergraph& /* hg */, + const FFStates& node_states, + HG::Edge* edge, + FFState* context, + prob_t* combination_cost_estimate) const { + //edge->reset_info(); + context->resize(state_size_); + if (state_size_ > 0) { + memset(&(*context)[0], 0, state_size_); + } + SparseVector<double> est_vals; // only computed if combination_cost_estimate is non-NULL + if (combination_cost_estimate) *combination_cost_estimate = prob_t::One(); + for (int i = 0; i < models_.size(); ++i) { + const FeatureFunction& ff = *models_[i]; + void* cur_ff_context = NULL; + vector<const void*> ants(edge->tail_nodes_.size()); + bool has_context = ff.StateSize() > 0; + if (has_context) { + int spos = model_state_pos_[i]; + cur_ff_context = &(*context)[spos]; + for (int i = 0; i < ants.size(); ++i) { + ants[i] = &node_states[edge->tail_nodes_[i]][spos]; + } + } + ff.TraversalFeatures(smeta, *edge, ants, &edge->feature_values_, &est_vals, cur_ff_context); + } + if (combination_cost_estimate) + combination_cost_estimate->logeq(est_vals.dot(weights_)); + edge->edge_prob_.logeq(edge->feature_values_.dot(weights_)); +} + +void ModelSet::AddFinalFeatures(const FFState& state, HG::Edge* edge,SentenceMetadata const& smeta) const { + assert(1 == edge->rule_->Arity()); + //edge->reset_info(); + for (int i = 0; i < models_.size(); ++i) { + const FeatureFunction& ff = *models_[i]; + const void* ant_state = NULL; + bool has_context = ff.StateSize() > 0; + if (has_context) { + int spos = model_state_pos_[i]; + ant_state = &state[spos]; + } + ff.FinalTraversalFeatures(ant_state, &edge->feature_values_); + } + edge->edge_prob_.logeq(edge->feature_values_.dot(weights_)); +} + |