summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
Diffstat (limited to 'decoder')
-rw-r--r--decoder/apply_models.cc36
-rw-r--r--decoder/ff.h19
-rw-r--r--decoder/ffset.cc15
-rw-r--r--decoder/ffset.h8
4 files changed, 68 insertions, 10 deletions
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc
index 9f8bbead..18c83fd4 100644
--- a/decoder/apply_models.cc
+++ b/decoder/apply_models.cc
@@ -233,7 +233,20 @@ public:
void IncorporateIntoPlusLMForest(size_t head_node_hash, Candidate* item, State2Node* s2n, CandidateList* freelist) {
Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_);
new_edge->edge_prob_ = item->out_edge_.edge_prob_;
- Candidate*& o_item = (*s2n)[item->state_];
+
+ Candidate** o_item_ptr = nullptr;
+ if (item->state_.size() && models.NeedsStateErasure()) {
+ // When erasure of certain state bytes is needed, we must make a copy of
+ // the state instead of doing the erasure in-place because future
+ // candidates may require the information in the bytes to be erased.
+ FFState state(item->state_);
+ models.EraseIgnoredBytes(&state);
+ o_item_ptr = &(*s2n)[state];
+ } else {
+ o_item_ptr = &(*s2n)[item->state_];
+ }
+ Candidate*& o_item = *o_item_ptr;
+
if (!o_item) o_item = item;
int& node_id = o_item->node_index_;
@@ -254,7 +267,18 @@ public:
// score is the same for all items with a common residual DP
// state
if (item->vit_prob_ > o_item->vit_prob_) {
- assert(o_item->state_ == item->state_); // sanity check!
+ if (item->state_.size() && models.NeedsStateErasure()) {
+ // node_states_ should still point to the unerased state.
+ node_states_[o_item->node_index_] = item->state_;
+ // sanity check!
+ FFState item_state(item->state_), o_item_state(o_item->state_);
+ models.EraseIgnoredBytes(&item_state);
+ models.EraseIgnoredBytes(&o_item_state);
+ assert(item_state == o_item_state);
+ } else {
+ assert(o_item->state_ == item->state_); // sanity check!
+ }
+
o_item->est_prob_ = item->est_prob_;
o_item->vit_prob_ = item->vit_prob_;
}
@@ -599,9 +623,10 @@ void ApplyModelSet(const Hypergraph& in,
if (models.stateless() || config.algorithm == IntersectionConfiguration::FULL) {
NoPruningRescorer ma(models, smeta, in, out); // avoid overhead of best-first when no state
ma.Apply();
- } else if (config.algorithm == IntersectionConfiguration::CUBE
- || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING
- || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2) {
+ } else if (config.algorithm == IntersectionConfiguration::CUBE ||
+ config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING ||
+ config.algorithm ==
+ IntersectionConfiguration::FAST_CUBE_PRUNING_2) {
int pl = config.pop_limit;
const int max_pl_for_large=50;
if (pl > max_pl_for_large && in.nodes_.size() > 80000) {
@@ -628,4 +653,3 @@ void ApplyModelSet(const Hypergraph& in,
out->is_linear_chain_ = in.is_linear_chain_; // TODO remove when this is computed
// automatically
}
-
diff --git a/decoder/ff.h b/decoder/ff.h
index 3280592e..afa3dbca 100644
--- a/decoder/ff.h
+++ b/decoder/ff.h
@@ -17,11 +17,17 @@ class FeatureFunction {
friend class ExternalFeature;
public:
std::string name_; // set by FF factory using usage()
- FeatureFunction() : state_size_() {}
- explicit FeatureFunction(int state_size) : state_size_(state_size) {}
+ FeatureFunction() : state_size_(), ignored_state_size_() {}
+ explicit FeatureFunction(int state_size, int ignored_state_size = 0)
+ : state_size_(state_size), ignored_state_size_(ignored_state_size) {}
virtual ~FeatureFunction();
bool IsStateful() const { return state_size_ > 0; }
int StateSize() const { return state_size_; }
+ // Returns the number of bytes in the state that should be ignored during
+ // search. When non-zero, the last N bytes in the state should be ignored when
+ // splitting a hypernode by the state. This allows the feature function to
+ // store some side data and later retrieve it via the state bytes.
+ int IgnoredStateSize() const { return ignored_state_size_; }
// override this. not virtual because we want to expose this to factory template for help before creating a FF
static std::string usage(bool show_params,bool show_details) {
@@ -71,12 +77,17 @@ class FeatureFunction {
SparseVector<double>* estimated_features,
void* context) const;
- // !!! ONLY call this from subclass *CONSTRUCTORS* !!!
+ // !!! ONLY call these from subclass *CONSTRUCTORS* !!!
void SetStateSize(size_t state_size) {
state_size_ = state_size;
}
+
+ void SetIgnoredStateSize(size_t ignored_state_size) {
+ ignored_state_size_ = ignored_state_size;
+ }
+
private:
- int state_size_;
+ int state_size_, ignored_state_size_;
};
#endif
diff --git a/decoder/ffset.cc b/decoder/ffset.cc
index 5820f421..8ba70389 100644
--- a/decoder/ffset.cc
+++ b/decoder/ffset.cc
@@ -14,6 +14,11 @@ ModelSet::ModelSet(const vector<double>& w, const vector<const FeatureFunction*>
for (int i = 0; i < models_.size(); ++i) {
model_state_pos_[i] = state_size_;
state_size_ += models_[i]->StateSize();
+ int num_ignored_bytes = models_[i]->IgnoredStateSize();
+ if (num_ignored_bytes > 0) {
+ ranges_to_erase_.push_back(
+ {state_size_ - num_ignored_bytes, state_size_});
+ }
}
}
@@ -70,3 +75,13 @@ void ModelSet::AddFinalFeatures(const FFState& state, HG::Edge* edge,SentenceMet
edge->edge_prob_.logeq(edge->feature_values_.dot(weights_));
}
+bool ModelSet::NeedsStateErasure() const { return !ranges_to_erase_.empty(); }
+
+void ModelSet::EraseIgnoredBytes(FFState* state) const {
+ // TODO: can we memset?
+ for (const auto& range : ranges_to_erase_) {
+ for (int i = range.first; i < range.second; ++i) {
+ (*state)[i] = 0;
+ }
+ }
+}
diff --git a/decoder/ffset.h b/decoder/ffset.h
index 28aef667..a69a75fa 100644
--- a/decoder/ffset.h
+++ b/decoder/ffset.h
@@ -1,6 +1,7 @@
#ifndef _FFSET_H_
#define _FFSET_H_
+#include <utility>
#include <vector>
#include "value_array.h"
#include "prob.h"
@@ -47,11 +48,18 @@ class ModelSet {
bool stateless() const { return !state_size_; }
+ // Part of a feature state may be used for storing some side data for
+ // calculating feature values but not necessary for splitting hypernodes. Such
+ // bytes needs to be erased for hypernode splitting.
+ bool NeedsStateErasure() const;
+ void EraseIgnoredBytes(FFState* state) const;
+
private:
std::vector<const FeatureFunction*> models_;
const std::vector<double>& weights_;
int state_size_;
std::vector<int> model_state_pos_;
+ std::vector<std::pair<int, int> > ranges_to_erase_;
};
#endif