diff options
Diffstat (limited to 'decoder/apply_models.cc')
-rw-r--r-- | decoder/apply_models.cc | 89 |
1 files changed, 78 insertions, 11 deletions
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc index b1d002f4..a340aa1a 100644 --- a/decoder/apply_models.cc +++ b/decoder/apply_models.cc @@ -296,14 +296,69 @@ public: }; struct NoPruningRescorer { - NoPruningRescorer(const ModelSet& m, const Hypergraph& i, Hypergraph* o) : + NoPruningRescorer(const ModelSet& m, const SentenceMetadata &sm, const Hypergraph& i, Hypergraph* o) : models(m), + smeta(sm), in(i), - out(*o) { + out(*o), + nodemap(i.nodes_.size()) { cerr << " Rescoring forest (full intersection)\n"; } - void RescoreNode(const int node_num, const bool is_goal) { + typedef unordered_map<string, int, boost::hash<string> > State2NodeIndex; + + void ExpandEdge(const Hypergraph::Edge& in_edge, bool is_goal, State2NodeIndex* state2node) { + const int arity = in_edge.Arity(); + Hypergraph::TailNodeVector ends(arity); + for (int i = 0; i < arity; ++i) + ends[i] = nodemap[in_edge.tail_nodes_[i]].size(); + + Hypergraph::TailNodeVector tail_iter(arity, 0); + bool done = false; + while (!done) { + Hypergraph::TailNodeVector tail(arity); + for (int i = 0; i < arity; ++i) + tail[i] = nodemap[in_edge.tail_nodes_[i]][tail_iter[i]]; + Hypergraph::Edge* new_edge = out.AddEdge(in_edge.rule_, tail); + new_edge->feature_values_ = in_edge.feature_values_; + new_edge->i_ = in_edge.i_; + new_edge->j_ = in_edge.j_; + new_edge->prev_i_ = in_edge.prev_i_; + new_edge->prev_j_ = in_edge.prev_j_; + string head_state; + if (is_goal) { + assert(tail.size() == 1); + const string& ant_state = out.nodes_[tail.front()].state_; + models.AddFinalFeatures(ant_state, new_edge); + } else { + prob_t edge_estimate; // this is a full intersection, so we disregard this + models.AddFeaturesToEdge(smeta, out, new_edge, &head_state, &edge_estimate); + } + int& head_plus1 = (*state2node)[head_state]; + if (!head_plus1) { + head_plus1 = out.AddNode(in_edge.rule_->GetLHS(), head_state)->id_ + 1; + nodemap[in_edge.head_node_].push_back(head_plus1 - 1); + } + const int head_index = head_plus1 - 1; + out.ConnectEdgeToHeadNode(new_edge->id_, head_index); + + int ii = 0; + for (; ii < arity; ++ii) { + ++tail_iter[ii]; + if (tail_iter[ii] < ends[ii]) break; + tail_iter[ii] = 0; + } + done = (ii == arity); + } + } + + void ProcessOneNode(const int node_num, const bool is_goal) { + State2NodeIndex state2node; + const Hypergraph::Node& node = in.nodes_[node_num]; + for (int i = 0; i < node.in_edges_.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[node.in_edges_[i]]; + ExpandEdge(edge, is_goal, &state2node); + } } void Apply() { @@ -316,29 +371,41 @@ struct NoPruningRescorer { cerr << " "; for (int i = 0; i < in.nodes_.size(); ++i) { if (i % every == 0) cerr << '.'; - RescoreNode(i, i == goal_id); + ProcessOneNode(i, i == goal_id); } cerr << endl; } private: const ModelSet& models; + const SentenceMetadata& smeta; const Hypergraph& in; Hypergraph& out; + + vector<vector<int> > nodemap; }; // each node in the graph has one of these, it keeps track of void ApplyModelSet(const Hypergraph& in, const SentenceMetadata& smeta, const ModelSet& models, - const PruningConfiguration& config, + const IntersectionConfiguration& config, Hypergraph* out) { - int pl = config.pop_limit; - if (pl > 100 && in.nodes_.size() > 80000) { - cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n"; - pl = 30; + // TODO special handling when all models are stateless + if (config.algorithm == 1) { + int pl = config.pop_limit; + if (pl > 100 && in.nodes_.size() > 80000) { + cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n"; + pl = 30; + } + CubePruningRescorer ma(models, smeta, in, pl, out); + ma.Apply(); + } else if (config.algorithm == 0) { + NoPruningRescorer ma(models, smeta, in, out); + ma.Apply(); + } else { + cerr << "Don't understand intersection algorithm " << config.algorithm << endl; + exit(1); } - CubePruningRescorer ma(models, smeta, in, pl, out); - ma.Apply(); } |