diff options
Diffstat (limited to 'decoder/ff.cc')
-rw-r--r-- | decoder/ff.cc | 137 |
1 files changed, 137 insertions, 0 deletions
diff --git a/decoder/ff.cc b/decoder/ff.cc new file mode 100644 index 00000000..61f4f0b6 --- /dev/null +++ b/decoder/ff.cc @@ -0,0 +1,137 @@ +#include "ff.h" + +#include "tdict.h" +#include "hg.h" + +using namespace std; + +FeatureFunction::~FeatureFunction() {} + + +void FeatureFunction::FinalTraversalFeatures(const void* ant_state, + SparseVector<double>* features) const { + (void) ant_state; + (void) features; +} + +// Hiero and Joshua use log_10(e) as the value, so I do to +WordPenalty::WordPenalty(const string& param) : + fid_(FD::Convert("WordPenalty")), + value_(-1.0 / log(10)) { + if (!param.empty()) { + cerr << "Warning WordPenalty ignoring parameter: " << param << endl; + } +} + +void WordPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_states, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* state) const { + (void) smeta; + (void) ant_states; + (void) state; + (void) estimated_features; + features->set_value(fid_, edge.rule_->EWords() * value_); +} + +SourceWordPenalty::SourceWordPenalty(const string& param) : + fid_(FD::Convert("SourceWordPenalty")), + value_(-1.0 / log(10)) { + if (!param.empty()) { + cerr << "Warning SourceWordPenalty ignoring parameter: " << param << endl; + } +} + +void SourceWordPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_states, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* state) const { + (void) smeta; + (void) ant_states; + (void) state; + (void) estimated_features; + features->set_value(fid_, edge.rule_->FWords() * value_); +} + +ArityPenalty::ArityPenalty(const std::string& param) : + value_(-1.0 / log(10)) { + string fname = "Arity_X"; + for (int i = 0; i < 10; ++i) { + fname[6]=i + '0'; + fids_[i] = FD::Convert(fname); + } +} + +void ArityPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_states, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* state) const { + (void) smeta; + (void) ant_states; + (void) state; + (void) estimated_features; + features->set_value(fids_[edge.Arity()], value_); +} + +ModelSet::ModelSet(const vector<double>& w, const vector<const FeatureFunction*>& models) : + models_(models), + weights_(w), + state_size_(0), + model_state_pos_(models.size()) { + for (int i = 0; i < models_.size(); ++i) { + model_state_pos_[i] = state_size_; + state_size_ += models_[i]->NumBytesContext(); + } +} + +void ModelSet::AddFeaturesToEdge(const SentenceMetadata& smeta, + const Hypergraph& hg, + const vector<string>& node_states, + Hypergraph::Edge* edge, + string* context, + prob_t* combination_cost_estimate) const { + context->resize(state_size_); + memset(&(*context)[0], 0, state_size_); + SparseVector<double> est_vals; // only computed if combination_cost_estimate is non-NULL + if (combination_cost_estimate) *combination_cost_estimate = prob_t::One(); + for (int i = 0; i < models_.size(); ++i) { + const FeatureFunction& ff = *models_[i]; + void* cur_ff_context = NULL; + vector<const void*> ants(edge->tail_nodes_.size()); + bool has_context = ff.NumBytesContext() > 0; + if (has_context) { + int spos = model_state_pos_[i]; + cur_ff_context = &(*context)[spos]; + for (int i = 0; i < ants.size(); ++i) { + ants[i] = &node_states[edge->tail_nodes_[i]][spos]; + } + } + ff.TraversalFeatures(smeta, *edge, ants, &edge->feature_values_, &est_vals, cur_ff_context); + } + if (combination_cost_estimate) + combination_cost_estimate->logeq(est_vals.dot(weights_)); + edge->edge_prob_.logeq(edge->feature_values_.dot(weights_)); +} + +void ModelSet::AddFinalFeatures(const std::string& state, Hypergraph::Edge* edge) const { + assert(1 == edge->rule_->Arity()); + + for (int i = 0; i < models_.size(); ++i) { + const FeatureFunction& ff = *models_[i]; + const void* ant_state = NULL; + bool has_context = ff.NumBytesContext() > 0; + if (has_context) { + int spos = model_state_pos_[i]; + ant_state = &state[spos]; + } + ff.FinalTraversalFeatures(ant_state, &edge->feature_values_); + } + edge->edge_prob_.logeq(edge->feature_values_.dot(weights_)); +} + |