summaryrefslogtreecommitdiff
path: root/decoder/ff.cc
diff options
context:
space:
mode:
authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-06-22 05:12:27 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-06-22 05:12:27 +0000
commit0172721855098ca02b207231a654dffa5e4eb1c9 (patch)
tree8069c3a62e2d72bd64a2cdeee9724b2679c8a56b /decoder/ff.cc
parent37728b8be4d0b3df9da81fdda2198ff55b4b2d91 (diff)
initial checkin
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@2 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'decoder/ff.cc')
-rw-r--r--decoder/ff.cc137
1 files changed, 137 insertions, 0 deletions
diff --git a/decoder/ff.cc b/decoder/ff.cc
new file mode 100644
index 00000000..61f4f0b6
--- /dev/null
+++ b/decoder/ff.cc
@@ -0,0 +1,137 @@
+#include "ff.h"
+
+#include "tdict.h"
+#include "hg.h"
+
+using namespace std;
+
+FeatureFunction::~FeatureFunction() {}
+
+
+void FeatureFunction::FinalTraversalFeatures(const void* ant_state,
+ SparseVector<double>* features) const {
+ (void) ant_state;
+ (void) features;
+}
+
+// Hiero and Joshua use log_10(e) as the value, so I do to
+WordPenalty::WordPenalty(const string& param) :
+ fid_(FD::Convert("WordPenalty")),
+ value_(-1.0 / log(10)) {
+ if (!param.empty()) {
+ cerr << "Warning WordPenalty ignoring parameter: " << param << endl;
+ }
+}
+
+void WordPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_states,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* state) const {
+ (void) smeta;
+ (void) ant_states;
+ (void) state;
+ (void) estimated_features;
+ features->set_value(fid_, edge.rule_->EWords() * value_);
+}
+
+SourceWordPenalty::SourceWordPenalty(const string& param) :
+ fid_(FD::Convert("SourceWordPenalty")),
+ value_(-1.0 / log(10)) {
+ if (!param.empty()) {
+ cerr << "Warning SourceWordPenalty ignoring parameter: " << param << endl;
+ }
+}
+
+void SourceWordPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_states,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* state) const {
+ (void) smeta;
+ (void) ant_states;
+ (void) state;
+ (void) estimated_features;
+ features->set_value(fid_, edge.rule_->FWords() * value_);
+}
+
+ArityPenalty::ArityPenalty(const std::string& param) :
+ value_(-1.0 / log(10)) {
+ string fname = "Arity_X";
+ for (int i = 0; i < 10; ++i) {
+ fname[6]=i + '0';
+ fids_[i] = FD::Convert(fname);
+ }
+}
+
+void ArityPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_states,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* state) const {
+ (void) smeta;
+ (void) ant_states;
+ (void) state;
+ (void) estimated_features;
+ features->set_value(fids_[edge.Arity()], value_);
+}
+
+ModelSet::ModelSet(const vector<double>& w, const vector<const FeatureFunction*>& models) :
+ models_(models),
+ weights_(w),
+ state_size_(0),
+ model_state_pos_(models.size()) {
+ for (int i = 0; i < models_.size(); ++i) {
+ model_state_pos_[i] = state_size_;
+ state_size_ += models_[i]->NumBytesContext();
+ }
+}
+
+void ModelSet::AddFeaturesToEdge(const SentenceMetadata& smeta,
+ const Hypergraph& hg,
+ const vector<string>& node_states,
+ Hypergraph::Edge* edge,
+ string* context,
+ prob_t* combination_cost_estimate) const {
+ context->resize(state_size_);
+ memset(&(*context)[0], 0, state_size_);
+ SparseVector<double> est_vals; // only computed if combination_cost_estimate is non-NULL
+ if (combination_cost_estimate) *combination_cost_estimate = prob_t::One();
+ for (int i = 0; i < models_.size(); ++i) {
+ const FeatureFunction& ff = *models_[i];
+ void* cur_ff_context = NULL;
+ vector<const void*> ants(edge->tail_nodes_.size());
+ bool has_context = ff.NumBytesContext() > 0;
+ if (has_context) {
+ int spos = model_state_pos_[i];
+ cur_ff_context = &(*context)[spos];
+ for (int i = 0; i < ants.size(); ++i) {
+ ants[i] = &node_states[edge->tail_nodes_[i]][spos];
+ }
+ }
+ ff.TraversalFeatures(smeta, *edge, ants, &edge->feature_values_, &est_vals, cur_ff_context);
+ }
+ if (combination_cost_estimate)
+ combination_cost_estimate->logeq(est_vals.dot(weights_));
+ edge->edge_prob_.logeq(edge->feature_values_.dot(weights_));
+}
+
+void ModelSet::AddFinalFeatures(const std::string& state, Hypergraph::Edge* edge) const {
+ assert(1 == edge->rule_->Arity());
+
+ for (int i = 0; i < models_.size(); ++i) {
+ const FeatureFunction& ff = *models_[i];
+ const void* ant_state = NULL;
+ bool has_context = ff.NumBytesContext() > 0;
+ if (has_context) {
+ int spos = model_state_pos_[i];
+ ant_state = &state[spos];
+ }
+ ff.FinalTraversalFeatures(ant_state, &edge->feature_values_);
+ }
+ edge->edge_prob_.logeq(edge->feature_values_.dot(weights_));
+}
+