From c4cb48ad003de65f97e0a6013e9da4329c89faf1 Mon Sep 17 00:00:00 2001 From: graehl Date: Wed, 7 Jul 2010 21:26:51 +0000 Subject: safe hg pruning without needing additional inside reachability pass (max margin tightness is less at bottom of derivation tree) git-svn-id: https://ws10smt.googlecode.com/svn/trunk@181 ec762483-ff6d-05da-a07a-a48fb63a330f --- decoder/ff_lm.cc | 4 +++- decoder/hg.cc | 29 ++++++++++++++++++----------- decoder/hg.h | 32 ++++++++++++++++++-------------- decoder/logval.h | 16 +++++++++++++--- 4 files changed, 52 insertions(+), 29 deletions(-) (limited to 'decoder') diff --git a/decoder/ff_lm.cc b/decoder/ff_lm.cc index 03dc2054..e6f7912e 100644 --- a/decoder/ff_lm.cc +++ b/decoder/ff_lm.cc @@ -1,3 +1,5 @@ +//TODO: backoff wordclasses for named entity xltns, esp. numbers. e.g. digits -> @. idealy rule features would specify replacement lm tokens/classes + //TODO: extra int in state to hold "GAP" token is not needed. if there are less than (N-1) words, then null terminate the e.g. left words. however, this would mean treating gapless items differently. not worth the potential bugs right now. //TODO: allow features to reorder by heuristic*weight the rules' terminal phrases (or of hyperedges'). if first pass has pruning, then compute over whole ruleset as part of heuristic @@ -311,7 +313,7 @@ class LanguageModelImpl { double sum=0; for (;rend>rbegin;--rend) { sum+=clamp(WordProb(rend[-1],rend)); - UNIDBG(","<* preserve_mask) + const vector* preserve_mask + , bool safe_inside + ) { assert(density >= 1.0); const int plen = ViterbiPathLength(*this); @@ -221,14 +227,15 @@ void Hypergraph::DensityPruneInsideOutside(const double scale, assert(edges_.size() == io.size()); vector sorted = io; nth_element(sorted.begin(), sorted.begin() + rnum, sorted.end(), greater()); - Hypergraph_finish_prune(*this,io,sorted[rnum],preserve_mask); + MarginPrune(io,sorted[rnum],preserve_mask,safe_inside); } void Hypergraph::BeamPruneInsideOutside( const double scale, const bool use_sum_prod_semiring, const double alpha, - const vector* preserve_mask) { + const vector* preserve_mask + ,bool safe_inside) { assert(alpha >= 0.0); vector io(edges_.size()); if (use_sum_prod_semiring) { @@ -240,7 +247,7 @@ void Hypergraph::BeamPruneInsideOutside( prob_t best; // initializes to zero for (int i = 0; i < io.size(); ++i) if (io[i] > best) best = io[i]; - Hypergraph_finish_prune(*this,io,best*exp(-alpha),preserve_mask); + MarginPrune(io,best*prob_t::exp(-alpha),preserve_mask,safe_inside); } void Hypergraph::PrintGraphviz() const { diff --git a/decoder/hg.h b/decoder/hg.h index e54b4bcc..ab90650c 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -132,20 +132,23 @@ class Hypergraph { } } + typedef std::vector EdgeProbs; + typedef std::vector EdgeMask; + // computes inside and outside scores for each // edge in the hypergraph // alpha->size = edges_.size = beta->size // returns inside prob of goal node prob_t ComputeEdgePosteriors(double scale, - std::vector* posts) const; + EdgeProbs* posts) const; // find the score of the very best path passing through each edge - prob_t ComputeBestPathThroughEdges(std::vector* posts) const; + prob_t ComputeBestPathThroughEdges(EdgeProbs* posts) const; // create a new hypergraph consisting only of the nodes / edges // in the Viterbi derivation of this hypergraph // if edges is set, use the EdgeSelectEdgeWeightFunction - Hypergraph* CreateViterbiHypergraph(const std::vector* edges = NULL) const; + Hypergraph* CreateViterbiHypergraph(const EdgeMask* edges = NULL) const; // move weights as near to the source as possible, resulting in a // stochastic automaton. ONLY FUNCTIONAL FOR *LATTICES*. @@ -164,20 +167,20 @@ class Hypergraph { void RemoveNoncoaccessibleStates(int goal_node_id = -1); // remove edges from the hypergraph if prune_edge[edge_id] is true - // TODO need to investigate why this shouldn't be run for the forest trans - // case. To investigate, change false to true and see where ftrans crashes - //TODO: what does the above "TODO" comment mean? that PruneEdges can lead to a crash? or that run_inside_algorithm should be false? there definitely is an unsolved bug, see hg.cc - workaround added - void PruneEdges(const std::vector& prune_edge, bool run_inside_algorithm = false); + // note: if run_inside_algorithm is false, then consumers may be unhappy if you pruned nodes that are built on by nodes that are kept. + void PruneEdges(const EdgeMask& prune_edge, bool run_inside_algorithm = false); // for density>=1.0, keep this many times the edges needed for the 1best derivation // if you don't know, use_sum_prod_semiring should be false - void DensityPruneInsideOutside(const double scale, const bool use_sum_prod_semiring, const double density, - const std::vector* preserve_mask = NULL); + void DensityPruneInsideOutside(const double scale, const bool use_sum_prod_semiring, const double density,const EdgeMask* preserve_mask = NULL,bool safe_inside=false); + + /// drop edge i if edge_margin[i] < prune_below, unless preserve_mask[i] + void MarginPrune(EdgeProbs const& edge_margin,prob_t prune_below,EdgeMask const* preserve_mask=0,bool safe_inside=false,bool verbose=false); + // safe_inside: if true, a theoretically redundant (but practically important .001% of the time due to rounding error) inside pruning pass will happen after max-marginal pruning. if you don't do this, it's possible that the pruned hypergraph will contain outside-reachable (but not inside-buildable) nodes. that is, a parent will be kept whose children were pruned. if output, those forests may confuse (crash) e.g. mr_vest_map. however, if the hyperedges occur in defined-before-use (all edges with head h occur before h is used as a tail) order, then a grace margin for keeping edges that starts leniently and becomes more forbidding will make it impossible for this to occur, i.e. safe_inside=true is not needed. // prunes any edge whose score on the best path taking that edge is more than alpha away // from the score of the global best past (or the highest edge posterior) - void BeamPruneInsideOutside(const double scale, const bool use_sum_prod_semiring, const double alpha, - const std::vector* preserve_mask = NULL); + void BeamPruneInsideOutside(const double scale, const bool use_sum_prod_semiring, const double alpha,const EdgeMask* preserve_mask = NULL,bool safe_inside=false); // report nodes, edges, paths std::string stats(std::string const& name="forest") const; @@ -204,7 +207,7 @@ class Hypergraph { // reorder nodes_ so they are in topological order // source nodes at 0 sink nodes at size-1 void TopologicallySortNodesAndEdges(int goal_idx, - const std::vector* prune_edges = NULL); + const EdgeMask* prune_edges = NULL); private: Hypergraph(int num_nodes, int num_edges, bool is_lc) : is_linear_chain_(is_lc), nodes_(num_nodes), edges_(num_edges) {} @@ -219,13 +222,14 @@ struct EdgeProb { }; struct EdgeSelectEdgeWeightFunction { - EdgeSelectEdgeWeightFunction(const std::vector& v) : v_(v) {} + typedef std::vector EdgeMask; + EdgeSelectEdgeWeightFunction(const EdgeMask& v) : v_(v) {} inline prob_t operator()(const Hypergraph::Edge& e) const { if (v_[e.id_]) return prob_t::One(); else return prob_t::Zero(); } private: - const std::vector& v_; + const EdgeMask& v_; }; struct ScaledEdgeProb { diff --git a/decoder/logval.h b/decoder/logval.h index 622b308e..457818e7 100644 --- a/decoder/logval.h +++ b/decoder/logval.h @@ -1,6 +1,8 @@ #ifndef LOGVAL_H_ #define LOGVAL_H_ +#define LOGVAL_CHECK_NEG_POW false + #include #include #include @@ -11,9 +13,12 @@ class LogVal { public: LogVal() : s_(), v_(-std::numeric_limits::infinity()) {} explicit LogVal(double x) : s_(std::signbit(x)), v_(s_ ? std::log(-x) : std::log(x)) {} + LogVal(double lnx,bool sign) : s_(sign),v_(lnx) {} + static LogVal exp(T lnx) { return LogVal(lnx,false); } + static LogVal One() { return LogVal(1); } static LogVal Zero() { return LogVal(); } - + static LogVal e() { return LogVal(1,false); } void logeq(const T& v) { s_ = false; v_ = v; } LogVal& operator+=(const LogVal& a) { @@ -54,12 +59,13 @@ class LogVal { } LogVal& poweq(const T& power) { +#if LOGVAL_CHECK_NEG_POW if (s_) { std::cerr << "poweq(T) not implemented when s_ is true\n"; std::abort(); - } else { + } else +#endif v_ *= power; - } return *this; } @@ -71,6 +77,10 @@ class LogVal { return res; } + LogVal root(const T& root) const { + return pow(1/root); + } + operator T() const { if (s_) return -std::exp(v_); else return std::exp(v_); } -- cgit v1.2.3