diff options
Diffstat (limited to 'decoder/apply_models.cc')
-rw-r--r-- | decoder/apply_models.cc | 36 |
1 files changed, 30 insertions, 6 deletions
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc index 9f8bbead..18c83fd4 100644 --- a/decoder/apply_models.cc +++ b/decoder/apply_models.cc @@ -233,7 +233,20 @@ public: void IncorporateIntoPlusLMForest(size_t head_node_hash, Candidate* item, State2Node* s2n, CandidateList* freelist) { Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_); new_edge->edge_prob_ = item->out_edge_.edge_prob_; - Candidate*& o_item = (*s2n)[item->state_]; + + Candidate** o_item_ptr = nullptr; + if (item->state_.size() && models.NeedsStateErasure()) { + // When erasure of certain state bytes is needed, we must make a copy of + // the state instead of doing the erasure in-place because future + // candidates may require the information in the bytes to be erased. + FFState state(item->state_); + models.EraseIgnoredBytes(&state); + o_item_ptr = &(*s2n)[state]; + } else { + o_item_ptr = &(*s2n)[item->state_]; + } + Candidate*& o_item = *o_item_ptr; + if (!o_item) o_item = item; int& node_id = o_item->node_index_; @@ -254,7 +267,18 @@ public: // score is the same for all items with a common residual DP // state if (item->vit_prob_ > o_item->vit_prob_) { - assert(o_item->state_ == item->state_); // sanity check! + if (item->state_.size() && models.NeedsStateErasure()) { + // node_states_ should still point to the unerased state. + node_states_[o_item->node_index_] = item->state_; + // sanity check! + FFState item_state(item->state_), o_item_state(o_item->state_); + models.EraseIgnoredBytes(&item_state); + models.EraseIgnoredBytes(&o_item_state); + assert(item_state == o_item_state); + } else { + assert(o_item->state_ == item->state_); // sanity check! + } + o_item->est_prob_ = item->est_prob_; o_item->vit_prob_ = item->vit_prob_; } @@ -599,9 +623,10 @@ void ApplyModelSet(const Hypergraph& in, if (models.stateless() || config.algorithm == IntersectionConfiguration::FULL) { NoPruningRescorer ma(models, smeta, in, out); // avoid overhead of best-first when no state ma.Apply(); - } else if (config.algorithm == IntersectionConfiguration::CUBE - || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING - || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2) { + } else if (config.algorithm == IntersectionConfiguration::CUBE || + config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING || + config.algorithm == + IntersectionConfiguration::FAST_CUBE_PRUNING_2) { int pl = config.pop_limit; const int max_pl_for_large=50; if (pl > max_pl_for_large && in.nodes_.size() > 80000) { @@ -628,4 +653,3 @@ void ApplyModelSet(const Hypergraph& in, out->is_linear_chain_ = in.is_linear_chain_; // TODO remove when this is computed // automatically } - |