diff options
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/apply_models.cc | 36 | ||||
| -rw-r--r-- | decoder/ff.h | 19 | ||||
| -rw-r--r-- | decoder/ffset.cc | 15 | ||||
| -rw-r--r-- | decoder/ffset.h | 8 | 
4 files changed, 68 insertions, 10 deletions
| diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc index 9f8bbead..18c83fd4 100644 --- a/decoder/apply_models.cc +++ b/decoder/apply_models.cc @@ -233,7 +233,20 @@ public:    void IncorporateIntoPlusLMForest(size_t head_node_hash, Candidate* item, State2Node* s2n, CandidateList* freelist) {      Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_);      new_edge->edge_prob_ = item->out_edge_.edge_prob_; -    Candidate*& o_item = (*s2n)[item->state_]; + +    Candidate** o_item_ptr = nullptr; +    if (item->state_.size() && models.NeedsStateErasure()) { +      // When erasure of certain state bytes is needed, we must make a copy of +      // the state instead of doing the erasure in-place because future +      // candidates may require the information in the bytes to be erased. +      FFState state(item->state_); +      models.EraseIgnoredBytes(&state); +      o_item_ptr = &(*s2n)[state]; +    } else { +      o_item_ptr = &(*s2n)[item->state_]; +    } +    Candidate*& o_item = *o_item_ptr; +      if (!o_item) o_item = item;      int& node_id = o_item->node_index_; @@ -254,7 +267,18 @@ public:      // score is the same for all items with a common residual DP      // state      if (item->vit_prob_ > o_item->vit_prob_) { -      assert(o_item->state_ == item->state_);    // sanity check! +      if (item->state_.size() && models.NeedsStateErasure()) { +        // node_states_ should still point to the unerased state. +        node_states_[o_item->node_index_] = item->state_; +        // sanity check! +        FFState item_state(item->state_), o_item_state(o_item->state_); +        models.EraseIgnoredBytes(&item_state); +        models.EraseIgnoredBytes(&o_item_state); +        assert(item_state == o_item_state); +      } else { +        assert(o_item->state_ == item->state_);  // sanity check! +      } +        o_item->est_prob_ = item->est_prob_;        o_item->vit_prob_ = item->vit_prob_;      } @@ -599,9 +623,10 @@ void ApplyModelSet(const Hypergraph& in,    if (models.stateless() || config.algorithm == IntersectionConfiguration::FULL) {      NoPruningRescorer ma(models, smeta, in, out); // avoid overhead of best-first when no state      ma.Apply(); -  } else if (config.algorithm == IntersectionConfiguration::CUBE  -             || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING -             || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2) { +  } else if (config.algorithm == IntersectionConfiguration::CUBE || +             config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING || +             config.algorithm == +                 IntersectionConfiguration::FAST_CUBE_PRUNING_2) {      int pl = config.pop_limit;      const int max_pl_for_large=50;      if (pl > max_pl_for_large && in.nodes_.size() > 80000) { @@ -628,4 +653,3 @@ void ApplyModelSet(const Hypergraph& in,    out->is_linear_chain_ = in.is_linear_chain_;  // TODO remove when this is computed                                                  // automatically  } - diff --git a/decoder/ff.h b/decoder/ff.h index 3280592e..afa3dbca 100644 --- a/decoder/ff.h +++ b/decoder/ff.h @@ -17,11 +17,17 @@ class FeatureFunction {    friend class ExternalFeature;   public:    std::string name_; // set by FF factory using usage() -  FeatureFunction() : state_size_() {} -  explicit FeatureFunction(int state_size) : state_size_(state_size) {} +  FeatureFunction() : state_size_(), ignored_state_size_() {} +  explicit FeatureFunction(int state_size, int ignored_state_size = 0) +      : state_size_(state_size), ignored_state_size_(ignored_state_size) {}    virtual ~FeatureFunction();    bool IsStateful() const { return state_size_ > 0; }    int StateSize() const { return state_size_; } +  // Returns the number of bytes in the state that should be ignored during +  // search. When non-zero, the last N bytes in the state should be ignored when +  // splitting a hypernode by the state. This allows the feature function to +  // store some side data and later retrieve it via the state bytes. +  int IgnoredStateSize() const { return ignored_state_size_; }    // override this.  not virtual because we want to expose this to factory template for help before creating a FF    static std::string usage(bool show_params,bool show_details) { @@ -71,12 +77,17 @@ class FeatureFunction {                                       SparseVector<double>* estimated_features,                                       void* context) const; -  // !!! ONLY call this from subclass *CONSTRUCTORS* !!! +  // !!! ONLY call these from subclass *CONSTRUCTORS* !!!    void SetStateSize(size_t state_size) {      state_size_ = state_size;    } + +  void SetIgnoredStateSize(size_t ignored_state_size) { +    ignored_state_size_ = ignored_state_size; +  } +   private: -  int state_size_; +  int state_size_, ignored_state_size_;  };  #endif diff --git a/decoder/ffset.cc b/decoder/ffset.cc index 5820f421..8ba70389 100644 --- a/decoder/ffset.cc +++ b/decoder/ffset.cc @@ -14,6 +14,11 @@ ModelSet::ModelSet(const vector<double>& w, const vector<const FeatureFunction*>    for (int i = 0; i < models_.size(); ++i) {      model_state_pos_[i] = state_size_;      state_size_ += models_[i]->StateSize(); +    int num_ignored_bytes = models_[i]->IgnoredStateSize(); +    if (num_ignored_bytes > 0) { +      ranges_to_erase_.push_back( +          {state_size_ - num_ignored_bytes, state_size_}); +    }    }  } @@ -70,3 +75,13 @@ void ModelSet::AddFinalFeatures(const FFState& state, HG::Edge* edge,SentenceMet    edge->edge_prob_.logeq(edge->feature_values_.dot(weights_));  } +bool ModelSet::NeedsStateErasure() const { return !ranges_to_erase_.empty(); } + +void ModelSet::EraseIgnoredBytes(FFState* state) const { +  // TODO: can we memset? +  for (const auto& range : ranges_to_erase_) { +    for (int i = range.first; i < range.second; ++i) { +      (*state)[i] = 0; +    } +  } +} diff --git a/decoder/ffset.h b/decoder/ffset.h index 28aef667..a69a75fa 100644 --- a/decoder/ffset.h +++ b/decoder/ffset.h @@ -1,6 +1,7 @@  #ifndef _FFSET_H_  #define _FFSET_H_ +#include <utility>  #include <vector>  #include "value_array.h"  #include "prob.h" @@ -47,11 +48,18 @@ class ModelSet {    bool stateless() const { return !state_size_; } +  // Part of a feature state may be used for storing some side data for +  // calculating feature values but not necessary for splitting hypernodes. Such +  // bytes needs to be erased for hypernode splitting. +  bool NeedsStateErasure() const; +  void EraseIgnoredBytes(FFState* state) const; +   private:    std::vector<const FeatureFunction*> models_;    const std::vector<double>& weights_;    int state_size_;    std::vector<int> model_state_pos_; +  std::vector<std::pair<int, int> > ranges_to_erase_;  };  #endif | 
