summaryrefslogtreecommitdiff
path: root/decoder/ffset.cc
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/ffset.cc')
-rw-r--r--decoder/ffset.cc72
1 files changed, 72 insertions, 0 deletions
diff --git a/decoder/ffset.cc b/decoder/ffset.cc
new file mode 100644
index 00000000..5820f421
--- /dev/null
+++ b/decoder/ffset.cc
@@ -0,0 +1,72 @@
+#include "ffset.h"
+
+#include "ff.h"
+#include "tdict.h"
+#include "hg.h"
+
+using namespace std;
+
+ModelSet::ModelSet(const vector<double>& w, const vector<const FeatureFunction*>& models) :
+ models_(models),
+ weights_(w),
+ state_size_(0),
+ model_state_pos_(models.size()) {
+ for (int i = 0; i < models_.size(); ++i) {
+ model_state_pos_[i] = state_size_;
+ state_size_ += models_[i]->StateSize();
+ }
+}
+
+void ModelSet::PrepareForInput(const SentenceMetadata& smeta) {
+ for (int i = 0; i < models_.size(); ++i)
+ const_cast<FeatureFunction*>(models_[i])->PrepareForInput(smeta);
+}
+
+void ModelSet::AddFeaturesToEdge(const SentenceMetadata& smeta,
+ const Hypergraph& /* hg */,
+ const FFStates& node_states,
+ HG::Edge* edge,
+ FFState* context,
+ prob_t* combination_cost_estimate) const {
+ //edge->reset_info();
+ context->resize(state_size_);
+ if (state_size_ > 0) {
+ memset(&(*context)[0], 0, state_size_);
+ }
+ SparseVector<double> est_vals; // only computed if combination_cost_estimate is non-NULL
+ if (combination_cost_estimate) *combination_cost_estimate = prob_t::One();
+ for (int i = 0; i < models_.size(); ++i) {
+ const FeatureFunction& ff = *models_[i];
+ void* cur_ff_context = NULL;
+ vector<const void*> ants(edge->tail_nodes_.size());
+ bool has_context = ff.StateSize() > 0;
+ if (has_context) {
+ int spos = model_state_pos_[i];
+ cur_ff_context = &(*context)[spos];
+ for (int i = 0; i < ants.size(); ++i) {
+ ants[i] = &node_states[edge->tail_nodes_[i]][spos];
+ }
+ }
+ ff.TraversalFeatures(smeta, *edge, ants, &edge->feature_values_, &est_vals, cur_ff_context);
+ }
+ if (combination_cost_estimate)
+ combination_cost_estimate->logeq(est_vals.dot(weights_));
+ edge->edge_prob_.logeq(edge->feature_values_.dot(weights_));
+}
+
+void ModelSet::AddFinalFeatures(const FFState& state, HG::Edge* edge,SentenceMetadata const& smeta) const {
+ assert(1 == edge->rule_->Arity());
+ //edge->reset_info();
+ for (int i = 0; i < models_.size(); ++i) {
+ const FeatureFunction& ff = *models_[i];
+ const void* ant_state = NULL;
+ bool has_context = ff.StateSize() > 0;
+ if (has_context) {
+ int spos = model_state_pos_[i];
+ ant_state = &state[spos];
+ }
+ ff.FinalTraversalFeatures(ant_state, &edge->feature_values_);
+ }
+ edge->edge_prob_.logeq(edge->feature_values_.dot(weights_));
+}
+