summaryrefslogtreecommitdiff
path: root/decoder/apply_models.cc
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/apply_models.cc')
-rw-r--r--decoder/apply_models.cc36
1 files changed, 30 insertions, 6 deletions
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc
index 9f8bbead..18c83fd4 100644
--- a/decoder/apply_models.cc
+++ b/decoder/apply_models.cc
@@ -233,7 +233,20 @@ public:
void IncorporateIntoPlusLMForest(size_t head_node_hash, Candidate* item, State2Node* s2n, CandidateList* freelist) {
Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_);
new_edge->edge_prob_ = item->out_edge_.edge_prob_;
- Candidate*& o_item = (*s2n)[item->state_];
+
+ Candidate** o_item_ptr = nullptr;
+ if (item->state_.size() && models.NeedsStateErasure()) {
+ // When erasure of certain state bytes is needed, we must make a copy of
+ // the state instead of doing the erasure in-place because future
+ // candidates may require the information in the bytes to be erased.
+ FFState state(item->state_);
+ models.EraseIgnoredBytes(&state);
+ o_item_ptr = &(*s2n)[state];
+ } else {
+ o_item_ptr = &(*s2n)[item->state_];
+ }
+ Candidate*& o_item = *o_item_ptr;
+
if (!o_item) o_item = item;
int& node_id = o_item->node_index_;
@@ -254,7 +267,18 @@ public:
// score is the same for all items with a common residual DP
// state
if (item->vit_prob_ > o_item->vit_prob_) {
- assert(o_item->state_ == item->state_); // sanity check!
+ if (item->state_.size() && models.NeedsStateErasure()) {
+ // node_states_ should still point to the unerased state.
+ node_states_[o_item->node_index_] = item->state_;
+ // sanity check!
+ FFState item_state(item->state_), o_item_state(o_item->state_);
+ models.EraseIgnoredBytes(&item_state);
+ models.EraseIgnoredBytes(&o_item_state);
+ assert(item_state == o_item_state);
+ } else {
+ assert(o_item->state_ == item->state_); // sanity check!
+ }
+
o_item->est_prob_ = item->est_prob_;
o_item->vit_prob_ = item->vit_prob_;
}
@@ -599,9 +623,10 @@ void ApplyModelSet(const Hypergraph& in,
if (models.stateless() || config.algorithm == IntersectionConfiguration::FULL) {
NoPruningRescorer ma(models, smeta, in, out); // avoid overhead of best-first when no state
ma.Apply();
- } else if (config.algorithm == IntersectionConfiguration::CUBE
- || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING
- || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2) {
+ } else if (config.algorithm == IntersectionConfiguration::CUBE ||
+ config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING ||
+ config.algorithm ==
+ IntersectionConfiguration::FAST_CUBE_PRUNING_2) {
int pl = config.pop_limit;
const int max_pl_for_large=50;
if (pl > max_pl_for_large && in.nodes_.size() > 80000) {
@@ -628,4 +653,3 @@ void ApplyModelSet(const Hypergraph& in,
out->is_linear_chain_ = in.is_linear_chain_; // TODO remove when this is computed
// automatically
}
-