diff options
Diffstat (limited to 'decoder/kbest.h')
-rw-r--r-- | decoder/kbest.h | 207 |
1 files changed, 207 insertions, 0 deletions
diff --git a/decoder/kbest.h b/decoder/kbest.h new file mode 100644 index 00000000..cd9b6c2b --- /dev/null +++ b/decoder/kbest.h @@ -0,0 +1,207 @@ +#ifndef _HG_KBEST_H_ +#define _HG_KBEST_H_ + +#include <vector> +#include <utility> +#include <tr1/unordered_set> + +#include <boost/shared_ptr.hpp> + +#include "wordid.h" +#include "hg.h" + +namespace KBest { + // default, don't filter any derivations from the k-best list + struct NoFilter { + bool operator()(const std::vector<WordID>& yield) { + (void) yield; + return false; + } + }; + + // optional, filter unique yield strings + struct FilterUnique { + std::tr1::unordered_set<std::vector<WordID>, boost::hash<std::vector<WordID> > > unique; + + bool operator()(const std::vector<WordID>& yield) { + return !unique.insert(yield).second; + } + }; + + // utility class to lazily create the k-best derivations from a forest, uses + // the lazy k-best algorithm (Algorithm 3) from Huang and Chiang (IWPT 2005) + template<typename T, // yield type (returned by Traversal) + typename Traversal, + typename DerivationFilter = NoFilter, + typename WeightType = prob_t, + typename WeightFunction = EdgeProb> + struct KBestDerivations { + KBestDerivations(const Hypergraph& hg, + const size_t k, + const Traversal& tf = Traversal(), + const WeightFunction& wf = WeightFunction()) : + traverse(tf), w(wf), g(hg), nds(g.nodes_.size()), k_prime(k) {} + + ~KBestDerivations() { + for (int i = 0; i < freelist.size(); ++i) + delete freelist[i]; + } + + struct Derivation { + Derivation(const Hypergraph::Edge& e, + const SmallVector& jv, + const WeightType& w, + const SparseVector<double>& f) : + edge(&e), + j(jv), + score(w), + feature_values(f) {} + + // dummy constructor, just for query + Derivation(const Hypergraph::Edge& e, + const SmallVector& jv) : edge(&e), j(jv) {} + + T yield; + const Hypergraph::Edge* const edge; + const SmallVector j; + const WeightType score; + const SparseVector<double> feature_values; + }; + struct HeapCompare { + bool operator()(const Derivation* a, const Derivation* b) const { + return a->score < b->score; + } + }; + struct DerivationCompare { + bool operator()(const Derivation* a, const Derivation* b) const { + return a->score > b->score; + } + }; + struct DerivationUniquenessHash { + size_t operator()(const Derivation* d) const { + size_t x = 5381; + x = ((x << 5) + x) ^ d->edge->id_; + for (int i = 0; i < d->j.size(); ++i) + x = ((x << 5) + x) ^ d->j[i]; + return x; + } + }; + struct DerivationUniquenessEquals { + bool operator()(const Derivation* a, const Derivation* b) const { + return (a->edge == b->edge) && (a->j == b->j); + } + }; + typedef std::vector<Derivation*> CandidateHeap; + typedef std::vector<Derivation*> DerivationList; + typedef std::tr1::unordered_set< + const Derivation*, DerivationUniquenessHash, DerivationUniquenessEquals> UniqueDerivationSet; + + struct NodeDerivationState { + CandidateHeap cand; + DerivationList D; + DerivationFilter filter; + UniqueDerivationSet ds; + explicit NodeDerivationState(const DerivationFilter& f = DerivationFilter()) : filter(f) {} + }; + + Derivation* LazyKthBest(int v, int k) { + NodeDerivationState& s = GetCandidates(v); + CandidateHeap& cand = s.cand; + DerivationList& D = s.D; + DerivationFilter& filter = s.filter; + bool add_next = true; + while (D.size() <= k) { + if (add_next && D.size() > 0) { + const Derivation* d = D.back(); + LazyNext(d, &cand, &s.ds); + } + add_next = false; + + if (cand.size() > 0) { + std::pop_heap(cand.begin(), cand.end(), HeapCompare()); + Derivation* d = cand.back(); + cand.pop_back(); + std::vector<const T*> ants(d->edge->Arity()); + for (int j = 0; j < ants.size(); ++j) + ants[j] = &LazyKthBest(d->edge->tail_nodes_[j], d->j[j])->yield; + traverse(*d->edge, ants, &d->yield); + if (!filter(d->yield)) { + D.push_back(d); + add_next = true; + } + } else { + break; + } + } + if (k < D.size()) return D[k]; else return NULL; + } + + private: + // creates a derivation object with all fields set but the yield + // the yield is computed in LazyKthBest before the derivation is added to D + // returns NULL if j refers to derivation numbers larger than the + // antecedent structure define + Derivation* CreateDerivation(const Hypergraph::Edge& e, const SmallVector& j) { + WeightType score = w(e); + SparseVector<double> feats = e.feature_values_; + for (int i = 0; i < e.Arity(); ++i) { + const Derivation* ant = LazyKthBest(e.tail_nodes_[i], j[i]); + if (!ant) { return NULL; } + score *= ant->score; + feats += ant->feature_values; + } + freelist.push_back(new Derivation(e, j, score, feats)); + return freelist.back(); + } + + NodeDerivationState& GetCandidates(int v) { + NodeDerivationState& s = nds[v]; + if (!s.D.empty() || !s.cand.empty()) return s; + + const Hypergraph::Node& node = g.nodes_[v]; + for (int i = 0; i < node.in_edges_.size(); ++i) { + const Hypergraph::Edge& edge = g.edges_[node.in_edges_[i]]; + SmallVector jv(edge.Arity(), 0); + Derivation* d = CreateDerivation(edge, jv); + assert(d); + s.cand.push_back(d); + } + + const int effective_k = std::min(k_prime, s.cand.size()); + const typename CandidateHeap::iterator kth = s.cand.begin() + effective_k; + std::nth_element(s.cand.begin(), kth, s.cand.end(), DerivationCompare()); + s.cand.resize(effective_k); + std::make_heap(s.cand.begin(), s.cand.end(), HeapCompare()); + + return s; + } + + void LazyNext(const Derivation* d, CandidateHeap* cand, UniqueDerivationSet* ds) { + for (int i = 0; i < d->j.size(); ++i) { + SmallVector j = d->j; + ++j[i]; + const Derivation* ant = LazyKthBest(d->edge->tail_nodes_[i], j[i]); + if (ant) { + Derivation query_unique(*d->edge, j); + if (ds->count(&query_unique) == 0) { + Derivation* new_d = CreateDerivation(*d->edge, j); + if (new_d) { + cand->push_back(new_d); + std::push_heap(cand->begin(), cand->end(), HeapCompare()); + assert(ds->insert(new_d).second); // insert into uniqueness set, sanity check + } + } + } + } + } + + const Traversal traverse; + const WeightFunction w; + const Hypergraph& g; + std::vector<NodeDerivationState> nds; + std::vector<Derivation*> freelist; + const size_t k_prime; + }; +} + +#endif |