#include "ffset.h"

#include <algorithm>
#include <cassert>
#include <cstring>
#include <vector>

#include "ff.h"
#include "tdict.h"
#include "hg.h"

using namespace std;

// Builds a model set from a weight vector and a list of feature functions.
// The per-node state of all stateful models is packed into one combined
// buffer: model i owns the slice starting at model_state_pos_[i] and spanning
// models_[i]->StateSize() entries, and state_size_ is the total of all slices.
ModelSet::ModelSet(const vector<double>& w,
                   const vector<const FeatureFunction*>& models)
    : models_(models),
      weights_(w),
      state_size_(0),
      model_state_pos_(models.size()) {
  for (size_t i = 0; i < models_.size(); ++i) {
    model_state_pos_[i] = state_size_;
    state_size_ += models_[i]->StateSize();
  }
}

// Calls PrepareForInput(smeta) on every feature function before a sentence
// is decoded. The models are held through pointers-to-const, so const_cast is
// used to reach the hook.
void ModelSet::PrepareForInput(const SentenceMetadata& smeta) {
  for (size_t i = 0; i < models_.size(); ++i)
    const_cast<FeatureFunction*>(models_[i])->PrepareForInput(smeta);
}

// Applies every feature function to `edge`.
//
//   smeta        - sentence metadata forwarded to each model.
//   node_states  - combined state buffers indexed by node id; for a stateful
//                  model, the slice at model_state_pos_[i] of each tail node's
//                  buffer is passed as that model's antecedent state.
//   edge         - features are accumulated into edge->feature_values_, and
//                  edge->edge_prob_ is recomputed from them via logeq().
//   context      - resized to state_size_ and zero-filled; each stateful model
//                  writes its new state into its own slice.
//   combination_cost_estimate - if non-NULL, set via logeq() from the dot
//                  product of the models' estimated features with the weights.
void ModelSet::AddFeaturesToEdge(const SentenceMetadata& smeta,
                                 const Hypergraph& /* hg */,
                                 const FFStates& node_states,
                                 HG::Edge* edge,
                                 FFState* context,
                                 prob_t* combination_cost_estimate) const {
  //edge->reset_info();
  context->resize(state_size_);
  if (state_size_ > 0) {
    memset(&(*context)[0], 0, state_size_);
  }
  // Estimated features reported by the models. Every model receives a pointer
  // to this vector, so it is always filled; it is only folded into a score
  // when combination_cost_estimate is non-NULL.
  SparseVector<double> est_vals;
  if (combination_cost_estimate) *combination_cost_estimate = prob_t::One();

  // Antecedent state pointers, one per tail node. Allocated once per edge
  // (rather than once per model) and refilled for each model; stateless
  // models see all-NULL entries, exactly as a freshly built vector would give.
  const size_t num_tails = edge->tail_nodes_.size();
  vector<const void*> ants(num_tails);
  for (size_t i = 0; i < models_.size(); ++i) {
    const FeatureFunction& ff = *models_[i];
    void* cur_ff_context = NULL;
    if (ff.StateSize() > 0) {
      const int spos = model_state_pos_[i];
      cur_ff_context = &(*context)[spos];
      for (size_t j = 0; j < num_tails; ++j) {
        ants[j] = &node_states[edge->tail_nodes_[j]][spos];
      }
    } else {
      fill(ants.begin(), ants.end(), static_cast<const void*>(NULL));
    }
    ff.TraversalFeatures(smeta, *edge, ants, &edge->feature_values_,
                         &est_vals, cur_ff_context);
  }
  if (combination_cost_estimate)
    combination_cost_estimate->logeq(est_vals.dot(weights_));
  edge->edge_prob_.logeq(edge->feature_values_.dot(weights_));
}

// Adds each model's final features to `edge`, which must have arity 1
// (asserted). Presumably this is the edge into the goal node; confirm with
// callers. A stateful model receives its slice of `state` and a stateless one
// receives NULL. edge->edge_prob_ is then recomputed from the updated features.
// `smeta` is accepted for interface symmetry but is not used here.
void ModelSet::AddFinalFeatures(const FFState& state, HG::Edge* edge,
                                SentenceMetadata const& smeta) const {
  assert(1 == edge->rule_->Arity());
  //edge->reset_info();
  for (size_t i = 0; i < models_.size(); ++i) {
    const FeatureFunction& ff = *models_[i];
    const void* ant_state = NULL;
    if (ff.StateSize() > 0) {
      ant_state = &state[model_state_pos_[i]];
    }
    ff.FinalTraversalFeatures(ant_state, &edge->feature_values_);
  }
  edge->edge_prob_.logeq(edge->feature_values_.dot(weights_));
}