diff options
Diffstat (limited to 'decoder/apply_models.cc')
-rw-r--r-- | decoder/apply_models.cc | 37 |
1 files changed, 33 insertions, 4 deletions
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc index 9f8bbead..3f3f6a79 100644 --- a/decoder/apply_models.cc +++ b/decoder/apply_models.cc @@ -233,7 +233,25 @@ public: void IncorporateIntoPlusLMForest(size_t head_node_hash, Candidate* item, State2Node* s2n, CandidateList* freelist) { Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_); new_edge->edge_prob_ = item->out_edge_.edge_prob_; - Candidate*& o_item = (*s2n)[item->state_]; + //start: new code by lijunhui + FFState real_state; + FFState* real_state_ref; + if (models.HaveEraseState()) { + models.GetRealFFState(item->state_, real_state); + real_state_ref = &real_state; + } + else + real_state_ref = &(item->state_); + Candidate*& o_item = (*s2n)[(*real_state_ref)]; + /*FFState real_state; + models.GetRealFFState(item->state_, real_state); + Candidate*& o_item = (*s2n)[real_state];*/ + //end: new code by lijunhui + + //start: original code + //Candidate*& o_item = (*s2n)[item->state_]; + //end: original code + if (!o_item) o_item = item; int& node_id = o_item->node_index_; @@ -254,7 +272,19 @@ public: // score is the same for all items with a common residual DP // state if (item->vit_prob_ > o_item->vit_prob_) { - assert(o_item->state_ == item->state_); // sanity check! + //start: new code by lijunhui + if (models.HaveEraseState()) { + assert(models.GetRealFFState(o_item->state_) == models.GetRealFFState(item->state_)); // sanity check! + node_states_[o_item->node_index_] = item->state_; + } else { + assert(o_item->state_ == item->state_); // sanity check! + } + //end: new code by lijunhui + + //start: original code + //assert(o_item->state_ == item->state_); // sanity check! + //end: original code + o_item->est_prob_ = item->est_prob_; o_item->vit_prob_ = item->vit_prob_; } @@ -599,7 +629,7 @@ void ApplyModelSet(const Hypergraph& in, if (models.stateless() || config.algorithm == IntersectionConfiguration::FULL) { NoPruningRescorer ma(models, smeta, in, out); // avoid overhead of best-first when no state ma.Apply(); - } else if (config.algorithm == IntersectionConfiguration::CUBE + } else if (config.algorithm == IntersectionConfiguration::CUBE || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING || config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2) { int pl = config.pop_limit; @@ -628,4 +658,3 @@ void ApplyModelSet(const Hypergraph& in, out->is_linear_chain_ = in.is_linear_chain_; // TODO remove when this is computed // automatically } - |