#include "apply_models.h" #include <vector> #include <algorithm> #include <tr1/unordered_map> #include <tr1/unordered_set> #include <boost/tuple/tuple.hpp> #include <boost/functional/hash.hpp> #include "tdict.h" #include "hg.h" #include "ff.h" using boost::tuple; using namespace std; using namespace std::tr1; namespace Hack { struct Candidate; typedef SmallVectorInt JVector; typedef vector<Candidate*> CandidateHeap; typedef vector<Candidate*> CandidateList; // life cycle: candidates are created, placed on the heap // and retrieved by their estimated cost, when they're // retrieved, they're incorporated into the +LM hypergraph // where they also know the head node index they are // attached to. After they are added to the +LM hypergraph // inside_prob_ and est_prob_ fields may be updated as better // derivations are found (this happens since the successor's // of derivation d may have a better score- they are // explored lazily). However, the updates don't happen // when a candidate is in the heap so maintaining the heap // property is not an issue. struct Candidate { int node_index_; // -1 until incorporated // into the +LM forest const Hypergraph::Edge* in_edge_; // in -LM forest Hypergraph::Edge out_edge_; vector<WordID> state_; const JVector j_; prob_t inside_prob_; // these are fixed until the cand // is popped, then they may be updated prob_t est_prob_; Candidate(const Hypergraph::Edge& e, const JVector& j, const vector<CandidateList>& D, bool is_goal) : node_index_(-1), in_edge_(&e), j_(j) { InitializeCandidate(D, is_goal); } // used to query uniqueness Candidate(const Hypergraph::Edge& e, const JVector& j) : in_edge_(&e), j_(j) {} bool IsIncorporatedIntoHypergraph() const { return node_index_ >= 0; } void InitializeCandidate(const vector<vector<Candidate*> >& D, const bool is_goal) { const Hypergraph::Edge& in_edge = *in_edge_; out_edge_.rule_ = in_edge.rule_; out_edge_.feature_values_ = in_edge.feature_values_; Hypergraph::TailNodeVector& tail = out_edge_.tail_nodes_; tail.resize(j_.size()); prob_t p = prob_t::One(); // cerr << "\nEstimating application of " << in_edge.rule_->AsString() << endl; vector<const vector<WordID>* > ants(tail.size()); for (int i = 0; i < tail.size(); ++i) { const Candidate& ant = *D[in_edge.tail_nodes_[i]][j_[i]]; ants[i] = &ant.state_; assert(ant.IsIncorporatedIntoHypergraph()); tail[i] = ant.node_index_; p *= ant.inside_prob_; } prob_t edge_estimate = prob_t::One(); if (is_goal) { assert(tail.size() == 1); out_edge_.edge_prob_ = in_edge.edge_prob_; } else { in_edge.rule_->ESubstitute(ants, &state_); out_edge_.edge_prob_ = in_edge.edge_prob_; } inside_prob_ = out_edge_.edge_prob_ * p; est_prob_ = inside_prob_ * edge_estimate; } }; ostream& operator<<(ostream& os, const Candidate& cand) { os << "CAND["; if (!cand.IsIncorporatedIntoHypergraph()) { os << "PENDING "; } else { os << "+LM_node=" << cand.node_index_; } os << " edge=" << cand.in_edge_->id_; os << " j=<"; for (int i = 0; i < cand.j_.size(); ++i) os << (i==0 ? "" : " ") << cand.j_[i]; os << "> vit=" << log(cand.inside_prob_); os << " est=" << log(cand.est_prob_); return os << ']'; } struct HeapCandCompare { bool operator()(const Candidate* l, const Candidate* r) const { return l->est_prob_ < r->est_prob_; } }; struct EstProbSorter { bool operator()(const Candidate* l, const Candidate* r) const { return l->est_prob_ > r->est_prob_; } }; // the same candidate <edge, j> can be added multiple times if // j is multidimensional (if you're going NW in Manhattan, you // can first go north, then west, or you can go west then north) // this is a hash function on the relevant variables from // Candidate to enforce this. struct CandidateUniquenessHash { size_t operator()(const Candidate* c) const { size_t x = 5381; x = ((x << 5) + x) ^ c->in_edge_->id_; for (int i = 0; i < c->j_.size(); ++i) x = ((x << 5) + x) ^ c->j_[i]; return x; } }; struct CandidateUniquenessEquals { bool operator()(const Candidate* a, const Candidate* b) const { return (a->in_edge_ == b->in_edge_) && (a->j_ == b->j_); } }; typedef unordered_set<const Candidate*, CandidateUniquenessHash, CandidateUniquenessEquals> UniqueCandidateSet; typedef unordered_map<vector<WordID>, Candidate*, boost::hash<vector<WordID> > > State2Node; class MaxTransBeamSearch { public: MaxTransBeamSearch(const Hypergraph& i, int pop_limit, Hypergraph* o) : in(i), out(*o), D(in.nodes_.size()), pop_limit_(pop_limit) { cerr << " Finding max translation (cube pruning, pop_limit = " << pop_limit_ << ')' << endl; } void Apply() { int num_nodes = in.nodes_.size(); int goal_id = num_nodes - 1; int pregoal = goal_id - 1; assert(in.nodes_[pregoal].out_edges_.size() == 1); cerr << " "; for (int i = 0; i < in.nodes_.size(); ++i) { cerr << '.'; KBest(i, i == goal_id); } cerr << endl; int best_node = D[goal_id].front()->in_edge_->tail_nodes_.front(); Candidate& best = *D[best_node].front(); cerr << " Best path: " << log(best.inside_prob_) << "\t" << log(best.est_prob_) << endl; cout << TD::GetString(D[best_node].front()->state_) << endl; FreeAll(); } private: void FreeAll() { for (int i = 0; i < D.size(); ++i) { CandidateList& D_i = D[i]; for (int j = 0; j < D_i.size(); ++j) delete D_i[j]; } D.clear(); } void IncorporateIntoPlusLMForest(Candidate* item, State2Node* s2n, CandidateList* freelist) { Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_.rule_, item->out_edge_.tail_nodes_); new_edge->feature_values_ = item->out_edge_.feature_values_; new_edge->edge_prob_ = item->out_edge_.edge_prob_; Candidate*& o_item = (*s2n)[item->state_]; if (!o_item) o_item = item; int& node_id = o_item->node_index_; if (node_id < 0) { Hypergraph::Node* new_node = out.AddNode(in.nodes_[item->in_edge_->head_node_].cat_); node_id = new_node->id_; } Hypergraph::Node* node = &out.nodes_[node_id]; out.ConnectEdgeToHeadNode(new_edge, node); if (item != o_item) { assert(o_item->state_ == item->state_); // sanity check! o_item->est_prob_ += item->est_prob_; o_item->inside_prob_ += item->inside_prob_; freelist->push_back(item); } } void KBest(const int vert_index, const bool is_goal) { // cerr << "KBest(" << vert_index << ")\n"; CandidateList& D_v = D[vert_index]; assert(D_v.empty()); const Hypergraph::Node& v = in.nodes_[vert_index]; // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; const vector<int>& in_edges = v.in_edges_; CandidateHeap cand; CandidateList freelist; cand.reserve(in_edges.size()); UniqueCandidateSet unique_cands; for (int i = 0; i < in_edges.size(); ++i) { const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; const JVector j(edge.tail_nodes_.size(), 0); cand.push_back(new Candidate(edge, j, D, is_goal)); assert(unique_cands.insert(cand.back()).second); // these should all be unique! } // cerr << " making heap of " << cand.size() << " candidates\n"; make_heap(cand.begin(), cand.end(), HeapCandCompare()); State2Node state2node; // "buf" in Figure 2 int pops = 0; while(!cand.empty() && pops < pop_limit_) { pop_heap(cand.begin(), cand.end(), HeapCandCompare()); Candidate* item = cand.back(); cand.pop_back(); // cerr << "POPPED: " << *item << endl; PushSucc(*item, is_goal, &cand, &unique_cands); IncorporateIntoPlusLMForest(item, &state2node, &freelist); ++pops; } D_v.resize(state2node.size()); int c = 0; for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i) D_v[c++] = i->second; sort(D_v.begin(), D_v.end(), EstProbSorter()); // cerr << " expanded to " << D_v.size() << " nodes\n"; for (int i = 0; i < cand.size(); ++i) delete cand[i]; // freelist is necessary since even after an item merged, it still stays in // the unique set so it can't be deleted til now for (int i = 0; i < freelist.size(); ++i) delete freelist[i]; } void PushSucc(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* cs) { CandidateHeap& cand = *pcand; for (int i = 0; i < item.j_.size(); ++i) { JVector j = item.j_; ++j[i]; if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { Candidate query_unique(*item.in_edge_, j); if (cs->count(&query_unique) == 0) { Candidate* new_cand = new Candidate(*item.in_edge_, j, D, is_goal); cand.push_back(new_cand); push_heap(cand.begin(), cand.end(), HeapCandCompare()); assert(cs->insert(new_cand).second); // insert into uniqueness set, sanity check } } } } const Hypergraph& in; Hypergraph& out; vector<CandidateList> D; // maps nodes in in-HG to the // equivalent nodes (many due to state // splits) in the out-HG. const int pop_limit_; }; // each node in the graph has one of these, it keeps track of void MaxTrans(const Hypergraph& in, int beam_size) { Hypergraph out; MaxTransBeamSearch ma(in, beam_size, &out); ma.Apply(); } }