From ca9c1f40cad1f99f00beb2871dc50bf7222d44d4 Mon Sep 17 00:00:00 2001 From: graehl Date: Sat, 21 Aug 2010 03:07:42 +0000 Subject: agenda for fsa git-svn-id: https://ws10smt.googlecode.com/svn/trunk@612 ec762483-ff6d-05da-a07a-a48fb63a330f --- decoder/apply_fsa_models.cc | 158 ++++++++++++++++++++++++++++++++++---------- decoder/cfg.cc | 21 ++++-- 2 files changed, 136 insertions(+), 43 deletions(-) (limited to 'decoder') diff --git a/decoder/apply_fsa_models.cc b/decoder/apply_fsa_models.cc index 7cd5fc6d..f9c94ec3 100755 --- a/decoder/apply_fsa_models.cc +++ b/decoder/apply_fsa_models.cc @@ -13,6 +13,8 @@ #include "utoa.h" #include "hash.h" #include "value_array.h" +#include "d_ary_heap.h" +#include "agenda.h" #define DFSA(x) x #define DPFSA(x) x @@ -72,12 +74,15 @@ struct get_second { struct PrefixTrieNode; struct PrefixTrieEdge { +// PrefixTrieEdge() { } +// explicit PrefixTrieEdge(prob_t p) : p(p),dest(0) { } prob_t p;// viterbi additional prob, i.e. product over path incl. p_final = total rule prob //DPFSA() // we can probably just store deltas, but for debugging remember the full p // prob_t delta; // PrefixTrieNode *dest; - WordID w; // for lhs, this will be positive NTHandle instead + bool is_final() const { return dest==0; } + WordID w; // for lhs, this will be nonneg NTHandle instead. // not set if is_final() // actually, set to lhs nt index // for sorting most probable first in adj; actually >(p) inline bool operator <(PrefixTrieEdge const& o) const { @@ -88,33 +93,55 @@ struct PrefixTrieEdge { struct PrefixTrieNode { prob_t p; // viterbi (max prob) of rule this node leads to - when building. telescope later onto edges for best-first. #if TRIE_START_LHS - bool final; // may also have successors, of course - prob_t p_final; // additional prob beyond what we already paid. while building, this is the total prob +// bool final; // may also have successors, of course. we don't really need to track this; a null dest edge in the adj list lets us encounter the fact in best first order. +// prob_t p_final; // additional prob beyond what we already paid. while building, this is the total prob +// instead of storing final, we'll say that an edge with a NULL dest is a final edge. this way it gets sorted into the list of adj. + // instead of completed map, we have trie start w/ lhs. - NTHandle lhs; // instead of storing this in Item. + NTHandle lhs; // nonneg. - instead of storing this in Item. #else typedef FSA_MAP(LHS,RuleHandle) Completed; // can only have one rule w/ a given signature (duplicates should be collapsed when making CFG). but there may be multiple rules, with different LHS Completed completed; #endif - explicit PrefixTrieNode(prob_t p=1) : p(p),final(false) { } + enum { ROOT=-1 }; + explicit PrefixTrieNode(NTHandle lhs=ROOT,prob_t p=1) : p(p),lhs(lhs) { + //final=false; + } + bool is_root() const { return lhs==ROOT; } // means adj are the nonneg lhs indices, and we have the index edge_for still available // outgoing edges will be ordered highest p to worst p typedef FSA_MAP(WordID,PrefixTrieEdge) PrefixTrieEdgeFor; public: PrefixTrieEdgeFor edge_for; //TODO: move builder elsewhere? then need 2nd hash or edge include pointer to builder. just clear this later + bool have_adj() const { + return adj.size()>=edge_for.size(); + } + bool no_adj() const { + return adj.empty(); + } + void index_adj() { index_adj(edge_for); } - template void index_adj(M &m) { + assert(have_adj()); m.clear(); for (int i=0;i + void index_root(PV &v) { + v.resize(adj.size()); + for (int i=0,e=adj.size();i!=e;++i) { + PrefixTrieEdge const& e=adj[i]; + // assert(e.p.is_1()); // actually, after done_building, e will have telescoped dest->p/p. + v[e.w]=e.dest; + } + } // call only once. void done_building_r() { @@ -124,18 +151,18 @@ public: } // for done_building; compute incremental (telescoped) edge p - PrefixTrieEdge const& operator()(PrefixTrieEdgeFor::value_type &pair) const { - PrefixTrieEdge &e=pair.second; + PrefixTrieEdge const& operator()(PrefixTrieEdgeFor::value_type const& pair) const { + PrefixTrieEdge &e=const_cast(pair.second); e.p=(e.dest->p)/p; return e; } // call only once. void done_building() { - adj.reinit_map(edge_for.begin(),edge_for.end(),*this); - if (final) - p_final/=p; + adj.reinit_map(edge_for,*this); +// if (final) p_final/=p; std::sort(adj.begin(),adj.end()); + //TODO: store adjacent differences on edges (compared to } typedef ValueArray Adj; @@ -143,8 +170,6 @@ public: Adj adj; typedef WordID W; - typedef NTHandle N; // not negative - typedef W const* RI; // let's compute p_min so that every rule reachable from the created node has p at least this low. PrefixTrieNode *improve_edge(PrefixTrieEdge const& e,prob_t rulep) { @@ -153,28 +178,46 @@ public: return d; } - PrefixTrieNode *build(W w,prob_t rulep) { + inline PrefixTrieNode *build(W w,prob_t rulep) { + return build(lhs,w,rulep); + } + inline PrefixTrieNode *build_lhs(NTHandle w,prob_t rulep) { + return build(w,w,rulep); + } + + PrefixTrieNode *build(NTHandle lhs_,W w,prob_t rulep) { PrefixTrieEdgeFor::iterator i=edge_for.find(w); if (i!=edge_for.end()) return improve_edge(i->second,rulep); PrefixTrieEdge &e=edge_for[w]; - return e.dest=new PrefixTrieNode(rulep); + return e.dest=new PrefixTrieNode(lhs_,rulep); } - void set_final(prob_t pf) { - final=true;p_final=pf; + void set_final(NTHandle lhs_,prob_t pf) { + assert(no_adj()); +// final=true; // don't really need to track this. + PrefixTrieEdge &e=edge_for[-1]; + e.p=pf; + e.dest=0; + e.w=lhs_; + if (pf>p) + p=pf; } -#ifdef HAVE_TAIL_RECURSE - // add string i...end - void build(RI i,RI end, prob_t rulep) { - if (i==end) { - set_final(rulep); - } else - // tail recursion: - build(*i)->build(i+1,end,rulep); +private: + void destroy_children() { + assert(adj.size()>=edge_for.size()); + for (int i=0,e=adj.size();i(root).build(r.lhs,p); + PrefixTrieNode *n=const_cast(root).build_lhs(lhs,p); for (RHS::const_iterator i=r.rhs.begin(),e=r.rhs.end();;++i) { if (i==e) { - n->set_final(p); + n->set_final(lhs,p); break; } n=n->build(*i,p); } -#ifdef HAVE_TAIL_RECURSE - root.build(r.lhs,r.p)->build(r.rhs,r.p); -#endif +// root.build(lhs,r.p)->build(r.rhs,r.p); } + + }; -// these should go in a global best-first queue +typedef std::size_t ItemHash; + struct Item { - prob_t forward; + explicit Item(PrefixTrieNode *dot,int next=0) : dot(dot),next(next) { } + PrefixTrieNode *dot; // dot is a function of the stuff already recognized, and gives a set of suffixes y to complete to finish a rhs for lhs() -> dot y. for a lhs A -> . *, this will point to lh2[A] + int next; // index of dot->adj to complete (if dest==0), or predict (if NT), or scan (if word) + NTHandle lhs() const { return dot->lhs; } + inline ItemHash hash() const { + return GOLDEN_MEAN_FRACTION*next^((ItemHash)dot>>4); // / sizeof(PrefixTrieNode), approx., i.e. lower order bits of ptr are nonrandom + } +}; + +inline ItemHash hash_value(Item const& x) { + return x.hash(); +} + +Item null_item((PrefixTrieNode*)0); + +// these should go in a global best-first queue +struct ItemP { + ItemP() : forward(init_0()),inner(init_0()) { } + prob_t forward; // includes inner prob. // NOTE: sum = viterbi (max) /* The forward probability alpha_i(X[k]->x.y) is the sum of the probabilities of all constrained paths of length that end in state X[k]->x.y*/ prob_t inner; /* The inner probability beta_i(X[k]->x.y) is the sum of the probabilities of all paths of length i-k that start in state X[k,k]->.xy and end in X[k,i]->x.y, and generate the input symbols x[k,...,i-1] */ - PrefixTrieNode *dot; // dot is a function of the stuff already recognized, and gives a set of suffixes y to complete to finish a rhs for lhs() -> dot y - NTHandle lhs() const { return dot->lhs; } }; +struct Chart { + //Agenda a; + //typedef HASH_MAP(Item,ItemP,boost::hash) Items; + //typedef Items::iterator FindItem; + //typedef std::pair InsertItem; +// Items items; + CFG &cfg; // TODO: remove this from Chart + NTHandle goal_nt; + PrefixTrie trie; + typedef std::vector LhsToTrie; // will have to check lhs2[lhs].p for best cost of some rule with that lhs, then use edge deltas after? they're just caching a very cheap computation, really + LhsToTrie lhs2; // no reason to use a map or hash table; every NT in the CFG will have some rule rhses. lhs_to_trie[i]=root.edge_for[i], i.e. we still have a root trie node conceptually, we just access through this since it's faster. + + void enqueue(Item const& item,ItemP const& p) { +// FindItem f=items.find(item); +// if (f==items.end()) ; + + } + + Chart(CFG &cfg) :cfg(cfg),trie(cfg) { + goal_nt=cfg.goal_nt; + trie.root.index_root(lhs2); + } +}; + + }//anon ns @@ -279,7 +365,7 @@ void ApplyFsa::ApplyBottomUp() void ApplyFsa::ApplyEarley() { hgcfg.GiveCFG(cfg); - PrefixTrie rt(cfg); + Chart chart(cfg); // don't need to uniq - option to do that already exists in cfg_options //TODO: } diff --git a/decoder/cfg.cc b/decoder/cfg.cc index b2219193..651978d2 100755 --- a/decoder/cfg.cc +++ b/decoder/cfg.cc @@ -8,6 +8,7 @@ #include "fast_lexical_cast.hpp" //#include "indices_after.h" #include "show.h" +#include "null_traits.h" #define DUNIQ(x) x #define DBIN(x) @@ -31,6 +32,7 @@ using namespace std; typedef CFG::Rule Rule; typedef CFG::NTOrder NTOrder; typedef CFG::RHS RHS; +typedef CFG::BinRhs BinRhs; /////index ruleids: void CFG::UnindexRules() { @@ -166,11 +168,11 @@ void CFG::SortLocalBestFirst(NTHandle ni) { /////binarization: namespace { -CFG::BinRhs null_bin_rhs(std::numeric_limits::min(),std::numeric_limits::min()); +BinRhs null_bin_rhs(std::numeric_limits::min(),std::numeric_limits::min()); // index i >= N.size()? then it's in M[i-N.size()] //WordID first,WordID second, -string BinStr(CFG::BinRhs const& b,CFG::NTs const& N,CFG::NTs const& M) +string BinStr(BinRhs const& b,CFG::NTs const& N,CFG::NTs const& M) { int nn=N.size(); ostringstream o; @@ -203,7 +205,7 @@ string BinStr(RHS const& r,CFG::NTs const& N,CFG::NTs const& M) } -WordID BinName(CFG::BinRhs const& b,CFG::NTs const& N,CFG::NTs const& M) +WordID BinName(BinRhs const& b,CFG::NTs const& N,CFG::NTs const& M) { return TD::Convert(BinStr(b,N,M)); } @@ -213,22 +215,27 @@ WordID BinName(RHS const& b,CFG::NTs const& N,CFG::NTs const& M) return TD::Convert(BinStr(b,N,M)); } +/* template struct null_for; -typedef CFG::BinRhs BinRhs; template <> struct null_for { static BinRhs null; }; -BinRhs null_for::null(std::numeric_limits::min(),std::numeric_limits::min()); template <> struct null_for { static RHS null; }; -RHS null_for::null(1,std::numeric_limits::min()); +*/ + +template <> +BinRhs null_traits::null(std::numeric_limits::min(),std::numeric_limits::min()); + +template <> +RHS null_traits::null(1,std::numeric_limits::min()); template struct add_virtual_rules { @@ -243,7 +250,7 @@ struct add_virtual_rules { R2L rhs2lhs; // an rhs maps to this -virtntid, or original id if length 1 bool name_nts; add_virtual_rules(CFG &cfg,bool name_nts=false) : nts(cfg.nts),rules(cfg.rules),newnt(-nts.size()),newruleid(rules.size()),name_nts(name_nts) { - HASH_MAP_EMPTY(rhs2lhs,null_for::null); + HASH_MAP_EMPTY(rhs2lhs,null_traits::null); } NTHandle get_virt(Rhs const& r) { NTHandle nt=get_default(rhs2lhs,r,newnt); -- cgit v1.2.3