#ifndef _HG_KBEST_H_ #define _HG_KBEST_H_ #include <vector> #include <utility> #include <tr1/unordered_set> #include <boost/shared_ptr.hpp> #include "wordid.h" #include "hg.h" namespace KBest { // default, don't filter any derivations from the k-best list struct NoFilter { bool operator()(const std::vector<WordID>& yield) { (void) yield; return false; } }; // optional, filter unique yield strings struct FilterUnique { std::tr1::unordered_set<std::vector<WordID>, boost::hash<std::vector<WordID> > > unique; bool operator()(const std::vector<WordID>& yield) { return !unique.insert(yield).second; } }; // utility class to lazily create the k-best derivations from a forest, uses // the lazy k-best algorithm (Algorithm 3) from Huang and Chiang (IWPT 2005) template<typename T, // yield type (returned by Traversal) typename Traversal, typename DerivationFilter = NoFilter, typename WeightType = prob_t, typename WeightFunction = EdgeProb> struct KBestDerivations { KBestDerivations(const Hypergraph& hg, const size_t k, const Traversal& tf = Traversal(), const WeightFunction& wf = WeightFunction()) : traverse(tf), w(wf), g(hg), nds(g.nodes_.size()), k_prime(k) {} ~KBestDerivations() { for (int i = 0; i < freelist.size(); ++i) delete freelist[i]; } struct Derivation { Derivation(const Hypergraph::Edge& e, const SmallVectorInt& jv, const WeightType& w, const SparseVector<double>& f) : edge(&e), j(jv), score(w), feature_values(f) {} // dummy constructor, just for query Derivation(const Hypergraph::Edge& e, const SmallVectorInt& jv) : edge(&e), j(jv) {} T yield; const Hypergraph::Edge* const edge; const SmallVectorInt j; const WeightType score; const SparseVector<double> feature_values; }; struct HeapCompare { bool operator()(const Derivation* a, const Derivation* b) const { return a->score < b->score; } }; struct DerivationCompare { bool operator()(const Derivation* a, const Derivation* b) const { return a->score > b->score; } }; struct EdgeHandle { Derivation const* d; explicit EdgeHandle(Derivation const* d) : d(d) { } // operator bool() const { return d->edge; } operator Hypergraph::Edge const* () const { return d->edge; } // Hypergraph::Edge const * operator ->() const { return d->edge; } }; EdgeHandle operator()(int t,int taili,EdgeHandle const& parent) const { return EdgeHandle(nds[t].D[parent.d->j[taili]]); } std::string derivation_tree(Derivation const& d,bool indent=true,int show_mask=Hypergraph::SPAN|Hypergraph::RULE,int maxdepth=0x7FFFFFFF,int depth=0) const { return d.edge->derivation_tree(*this,EdgeHandle(&d),indent,show_mask,maxdepth,depth); } struct DerivationUniquenessHash { size_t operator()(const Derivation* d) const { size_t x = 5381; x = ((x << 5) + x) ^ d->edge->id_; for (int i = 0; i < d->j.size(); ++i) x = ((x << 5) + x) ^ d->j[i]; return x; } }; struct DerivationUniquenessEquals { bool operator()(const Derivation* a, const Derivation* b) const { return (a->edge == b->edge) && (a->j == b->j); } }; typedef std::vector<Derivation*> CandidateHeap; typedef std::vector<Derivation*> DerivationList; typedef std::tr1::unordered_set< const Derivation*, DerivationUniquenessHash, DerivationUniquenessEquals> UniqueDerivationSet; struct NodeDerivationState { CandidateHeap cand; DerivationList D; DerivationFilter filter; UniqueDerivationSet ds; explicit NodeDerivationState(const DerivationFilter& f = DerivationFilter()) : filter(f) {} }; Derivation* LazyKthBest(int v, int k) { NodeDerivationState& s = GetCandidates(v); CandidateHeap& cand = s.cand; DerivationList& D = s.D; DerivationFilter& filter = s.filter; bool add_next = true; while (D.size() <= k) { if (add_next && D.size() > 0) { const Derivation* d = D.back(); LazyNext(d, &cand, &s.ds); } add_next = false; if (cand.size() > 0) { std::pop_heap(cand.begin(), cand.end(), HeapCompare()); Derivation* d = cand.back(); cand.pop_back(); std::vector<const T*> ants(d->edge->Arity()); for (int j = 0; j < ants.size(); ++j) ants[j] = &LazyKthBest(d->edge->tail_nodes_[j], d->j[j])->yield; traverse(*d->edge, ants, &d->yield); if (!filter(d->yield)) { D.push_back(d); add_next = true; } } else { break; } } if (k < D.size()) return D[k]; else return NULL; } private: // creates a derivation object with all fields set but the yield // the yield is computed in LazyKthBest before the derivation is added to D // returns NULL if j refers to derivation numbers larger than the // antecedent structure define Derivation* CreateDerivation(const Hypergraph::Edge& e, const SmallVectorInt& j) { WeightType score = w(e); SparseVector<double> feats = e.feature_values_; for (int i = 0; i < e.Arity(); ++i) { const Derivation* ant = LazyKthBest(e.tail_nodes_[i], j[i]); if (!ant) { return NULL; } score *= ant->score; feats += ant->feature_values; } freelist.push_back(new Derivation(e, j, score, feats)); return freelist.back(); } NodeDerivationState& GetCandidates(int v) { NodeDerivationState& s = nds[v]; if (!s.D.empty() || !s.cand.empty()) return s; const Hypergraph::Node& node = g.nodes_[v]; for (int i = 0; i < node.in_edges_.size(); ++i) { const Hypergraph::Edge& edge = g.edges_[node.in_edges_[i]]; SmallVectorInt jv(edge.Arity(), 0); Derivation* d = CreateDerivation(edge, jv); assert(d); s.cand.push_back(d); } const int effective_k = std::min(k_prime, s.cand.size()); const typename CandidateHeap::iterator kth = s.cand.begin() + effective_k; std::nth_element(s.cand.begin(), kth, s.cand.end(), DerivationCompare()); s.cand.resize(effective_k); std::make_heap(s.cand.begin(), s.cand.end(), HeapCompare()); return s; } void LazyNext(const Derivation* d, CandidateHeap* cand, UniqueDerivationSet* ds) { for (int i = 0; i < d->j.size(); ++i) { SmallVectorInt j = d->j; ++j[i]; const Derivation* ant = LazyKthBest(d->edge->tail_nodes_[i], j[i]); if (ant) { Derivation query_unique(*d->edge, j); if (ds->count(&query_unique) == 0) { Derivation* new_d = CreateDerivation(*d->edge, j); if (new_d) { cand->push_back(new_d); std::push_heap(cand->begin(), cand->end(), HeapCompare()); bool inserted = ds->insert(new_d).second; // insert into uniqueness set assert(inserted); } } } } } const Traversal traverse; const WeightFunction w; const Hypergraph& g; std::vector<NodeDerivationState> nds; std::vector<Derivation*> freelist; const size_t k_prime; }; } #endif