summaryrefslogtreecommitdiff
path: root/decoder/apply_models.cc
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/apply_models.cc')
-rw-r--r--decoder/apply_models.cc89
1 files changed, 78 insertions, 11 deletions
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc
index b1d002f4..a340aa1a 100644
--- a/decoder/apply_models.cc
+++ b/decoder/apply_models.cc
@@ -296,14 +296,69 @@ public:
};
struct NoPruningRescorer {
- NoPruningRescorer(const ModelSet& m, const Hypergraph& i, Hypergraph* o) :
+ NoPruningRescorer(const ModelSet& m, const SentenceMetadata &sm, const Hypergraph& i, Hypergraph* o) :
models(m),
+ smeta(sm),
in(i),
- out(*o) {
+ out(*o),
+ nodemap(i.nodes_.size()) {
cerr << " Rescoring forest (full intersection)\n";
}
- void RescoreNode(const int node_num, const bool is_goal) {
+ typedef unordered_map<string, int, boost::hash<string> > State2NodeIndex;
+
+ void ExpandEdge(const Hypergraph::Edge& in_edge, bool is_goal, State2NodeIndex* state2node) {
+ const int arity = in_edge.Arity();
+ Hypergraph::TailNodeVector ends(arity);
+ for (int i = 0; i < arity; ++i)
+ ends[i] = nodemap[in_edge.tail_nodes_[i]].size();
+
+ Hypergraph::TailNodeVector tail_iter(arity, 0);
+ bool done = false;
+ while (!done) {
+ Hypergraph::TailNodeVector tail(arity);
+ for (int i = 0; i < arity; ++i)
+ tail[i] = nodemap[in_edge.tail_nodes_[i]][tail_iter[i]];
+ Hypergraph::Edge* new_edge = out.AddEdge(in_edge.rule_, tail);
+ new_edge->feature_values_ = in_edge.feature_values_;
+ new_edge->i_ = in_edge.i_;
+ new_edge->j_ = in_edge.j_;
+ new_edge->prev_i_ = in_edge.prev_i_;
+ new_edge->prev_j_ = in_edge.prev_j_;
+ string head_state;
+ if (is_goal) {
+ assert(tail.size() == 1);
+ const string& ant_state = out.nodes_[tail.front()].state_;
+ models.AddFinalFeatures(ant_state, new_edge);
+ } else {
+ prob_t edge_estimate; // this is a full intersection, so we disregard this
+ models.AddFeaturesToEdge(smeta, out, new_edge, &head_state, &edge_estimate);
+ }
+ int& head_plus1 = (*state2node)[head_state];
+ if (!head_plus1) {
+ head_plus1 = out.AddNode(in_edge.rule_->GetLHS(), head_state)->id_ + 1;
+ nodemap[in_edge.head_node_].push_back(head_plus1 - 1);
+ }
+ const int head_index = head_plus1 - 1;
+ out.ConnectEdgeToHeadNode(new_edge->id_, head_index);
+
+ int ii = 0;
+ for (; ii < arity; ++ii) {
+ ++tail_iter[ii];
+ if (tail_iter[ii] < ends[ii]) break;
+ tail_iter[ii] = 0;
+ }
+ done = (ii == arity);
+ }
+ }
+
+ void ProcessOneNode(const int node_num, const bool is_goal) {
+ State2NodeIndex state2node;
+ const Hypergraph::Node& node = in.nodes_[node_num];
+ for (int i = 0; i < node.in_edges_.size(); ++i) {
+ const Hypergraph::Edge& edge = in.edges_[node.in_edges_[i]];
+ ExpandEdge(edge, is_goal, &state2node);
+ }
}
void Apply() {
@@ -316,29 +371,41 @@ struct NoPruningRescorer {
cerr << " ";
for (int i = 0; i < in.nodes_.size(); ++i) {
if (i % every == 0) cerr << '.';
- RescoreNode(i, i == goal_id);
+ ProcessOneNode(i, i == goal_id);
}
cerr << endl;
}
private:
const ModelSet& models;
+ const SentenceMetadata& smeta;
const Hypergraph& in;
Hypergraph& out;
+
+ vector<vector<int> > nodemap;
};
// each node in the graph has one of these, it keeps track of
void ApplyModelSet(const Hypergraph& in,
const SentenceMetadata& smeta,
const ModelSet& models,
- const PruningConfiguration& config,
+ const IntersectionConfiguration& config,
Hypergraph* out) {
- int pl = config.pop_limit;
- if (pl > 100 && in.nodes_.size() > 80000) {
- cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n";
- pl = 30;
+ // TODO special handling when all models are stateless
+ if (config.algorithm == 1) {
+ int pl = config.pop_limit;
+ if (pl > 100 && in.nodes_.size() > 80000) {
+ cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n";
+ pl = 30;
+ }
+ CubePruningRescorer ma(models, smeta, in, pl, out);
+ ma.Apply();
+ } else if (config.algorithm == 0) {
+ NoPruningRescorer ma(models, smeta, in, out);
+ ma.Apply();
+ } else {
+ cerr << "Don't understand intersection algorithm " << config.algorithm << endl;
+ exit(1);
}
- CubePruningRescorer ma(models, smeta, in, pl, out);
- ma.Apply();
}