summaryrefslogtreecommitdiff
path: root/decoder/apply_models.cc
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/apply_models.cc')
-rw-r--r--decoder/apply_models.cc37
1 files changed, 33 insertions, 4 deletions
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc
index 9f8bbead..3f3f6a79 100644
--- a/decoder/apply_models.cc
+++ b/decoder/apply_models.cc
@@ -233,7 +233,25 @@ public:
void IncorporateIntoPlusLMForest(size_t head_node_hash, Candidate* item, State2Node* s2n, CandidateList* freelist) {
Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_);
new_edge->edge_prob_ = item->out_edge_.edge_prob_;
- Candidate*& o_item = (*s2n)[item->state_];
+ //start: new code by lijunhui
+ FFState real_state;
+ FFState* real_state_ref;
+ if (models.HaveEraseState()) {
+ models.GetRealFFState(item->state_, real_state);
+ real_state_ref = &real_state;
+ }
+ else
+ real_state_ref = &(item->state_);
+ Candidate*& o_item = (*s2n)[(*real_state_ref)];
+ /*FFState real_state;
+ models.GetRealFFState(item->state_, real_state);
+ Candidate*& o_item = (*s2n)[real_state];*/
+ //end: new code by lijunhui
+
+ //start: original code
+ //Candidate*& o_item = (*s2n)[item->state_];
+ //end: original code
+
if (!o_item) o_item = item;
int& node_id = o_item->node_index_;
@@ -254,7 +272,19 @@ public:
// score is the same for all items with a common residual DP
// state
if (item->vit_prob_ > o_item->vit_prob_) {
- assert(o_item->state_ == item->state_); // sanity check!
+ //start: new code by lijunhui
+ if (models.HaveEraseState()) {
+ assert(models.GetRealFFState(o_item->state_) == models.GetRealFFState(item->state_)); // sanity check!
+ node_states_[o_item->node_index_] = item->state_;
+ } else {
+ assert(o_item->state_ == item->state_); // sanity check!
+ }
+ //end: new code by lijunhui
+
+ //start: original code
+ //assert(o_item->state_ == item->state_); // sanity check!
+ //end: original code
+
o_item->est_prob_ = item->est_prob_;
o_item->vit_prob_ = item->vit_prob_;
}
@@ -599,7 +629,7 @@ void ApplyModelSet(const Hypergraph& in,
if (models.stateless() || config.algorithm == IntersectionConfiguration::FULL) {
NoPruningRescorer ma(models, smeta, in, out); // avoid overhead of best-first when no state
ma.Apply();
- } else if (config.algorithm == IntersectionConfiguration::CUBE
+ } else if (config.algorithm == IntersectionConfiguration::CUBE
|| config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING
|| config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2) {
int pl = config.pop_limit;
@@ -628,4 +658,3 @@ void ApplyModelSet(const Hypergraph& in,
out->is_linear_chain_ = in.is_linear_chain_; // TODO remove when this is computed
// automatically
}
-