diff options
| author | graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-08-21 03:07:42 +0000 | 
|---|---|---|
| committer | graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-08-21 03:07:42 +0000 | 
| commit | ca9c1f40cad1f99f00beb2871dc50bf7222d44d4 (patch) | |
| tree | 183f19411904bb2a23cc5f916f1887a484c6574b /decoder | |
| parent | d8dbfdcc460754bd5f45182495ff14b39b94b24d (diff) | |
agenda for fsa
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@612 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'decoder')
| -rwxr-xr-x | decoder/apply_fsa_models.cc | 158 | ||||
| -rwxr-xr-x | decoder/cfg.cc | 21 | 
2 files changed, 136 insertions, 43 deletions
| diff --git a/decoder/apply_fsa_models.cc b/decoder/apply_fsa_models.cc index 7cd5fc6d..f9c94ec3 100755 --- a/decoder/apply_fsa_models.cc +++ b/decoder/apply_fsa_models.cc @@ -13,6 +13,8 @@  #include "utoa.h"  #include "hash.h"  #include "value_array.h" +#include "d_ary_heap.h" +#include "agenda.h"  #define DFSA(x) x  #define DPFSA(x) x @@ -72,12 +74,15 @@ struct get_second {  struct PrefixTrieNode;  struct PrefixTrieEdge { +//  PrefixTrieEdge() {  } +//  explicit PrefixTrieEdge(prob_t p) : p(p),dest(0) {  }    prob_t p;// viterbi additional prob, i.e. product over path incl. p_final = total rule prob    //DPFSA()    // we can probably just store deltas, but for debugging remember the full p    //    prob_t delta; //    PrefixTrieNode *dest; -  WordID w; // for lhs, this will be positive NTHandle instead +  bool is_final() const { return dest==0; } +  WordID w; // for lhs, this will be nonneg NTHandle instead.  //  not set if is_final() // actually, set to lhs nt index    // for sorting most probable first in adj; actually >(p)    inline bool operator <(PrefixTrieEdge const& o) const { @@ -88,33 +93,55 @@ struct PrefixTrieEdge {  struct PrefixTrieNode {    prob_t p; // viterbi (max prob) of rule this node leads to - when building.  telescope later onto edges for best-first.  #if TRIE_START_LHS -  bool final; // may also have successors, of course -  prob_t p_final; // additional prob beyond what we already paid. while building, this is the total prob +//  bool final; // may also have successors, of course.  we don't really need to track this; a null dest edge in the adj list lets us encounter the fact in best first order. +//  prob_t p_final; // additional prob beyond what we already paid. while building, this is the total prob +// instead of storing final, we'll say that an edge with a NULL dest is a final edge.  this way it gets sorted into the list of adj. +    // instead of completed map, we have trie start w/ lhs. -  NTHandle lhs; // instead of storing this in Item. +  NTHandle lhs; // nonneg. - instead of storing this in Item.  #else    typedef FSA_MAP(LHS,RuleHandle) Completed; // can only have one rule w/ a given signature (duplicates should be collapsed when making CFG).  but there may be multiple rules, with different LHS    Completed completed;  #endif -  explicit PrefixTrieNode(prob_t p=1) : p(p),final(false) {  } +  enum { ROOT=-1 }; +  explicit PrefixTrieNode(NTHandle lhs=ROOT,prob_t p=1) : p(p),lhs(lhs) { +    //final=false; +  } +  bool is_root() const { return lhs==ROOT; } // means adj are the nonneg lhs indices, and we have the index edge_for still available    // outgoing edges will be ordered highest p to worst p    typedef FSA_MAP(WordID,PrefixTrieEdge) PrefixTrieEdgeFor;  public:    PrefixTrieEdgeFor edge_for; //TODO: move builder elsewhere?  then need 2nd hash or edge include pointer to builder.  just clear this later +  bool have_adj() const { +    return adj.size()>=edge_for.size(); +  } +  bool no_adj() const { +    return adj.empty(); +  } +    void index_adj() {      index_adj(edge_for);    } -    template <class M>    void index_adj(M &m) { +    assert(have_adj());      m.clear();      for (int i=0;i<adj.size();++i) {        PrefixTrieEdge const& e=adj[i];        m[e.w]=e;      }    } +  template <class PV> +  void index_root(PV &v) { +    v.resize(adj.size()); +    for (int i=0,e=adj.size();i!=e;++i) { +      PrefixTrieEdge const& e=adj[i]; +      // assert(e.p.is_1());  // actually, after done_building, e will have telescoped dest->p/p. +      v[e.w]=e.dest; +    } +  }    // call only once.    void done_building_r() { @@ -124,18 +151,18 @@ public:    }    // for done_building; compute incremental (telescoped) edge p -  PrefixTrieEdge const& operator()(PrefixTrieEdgeFor::value_type &pair) const { -    PrefixTrieEdge &e=pair.second; +  PrefixTrieEdge const& operator()(PrefixTrieEdgeFor::value_type const& pair) const { +    PrefixTrieEdge &e=const_cast<PrefixTrieEdge&>(pair.second);      e.p=(e.dest->p)/p;      return e;    }    // call only once.    void done_building() { -    adj.reinit_map(edge_for.begin(),edge_for.end(),*this); -    if (final) -      p_final/=p; +    adj.reinit_map(edge_for,*this); +//    if (final) p_final/=p;      std::sort(adj.begin(),adj.end()); +    //TODO: store adjacent differences on edges (compared to    }    typedef ValueArray<PrefixTrieEdge>  Adj; @@ -143,8 +170,6 @@ public:    Adj adj;    typedef WordID W; -  typedef NTHandle N; // not negative -  typedef W const* RI;    // let's compute p_min so that every rule reachable from the created node has p at least this low.    PrefixTrieNode *improve_edge(PrefixTrieEdge const& e,prob_t rulep) { @@ -153,28 +178,46 @@ public:      return d;    } -  PrefixTrieNode *build(W w,prob_t rulep) { +  inline PrefixTrieNode *build(W w,prob_t rulep) { +    return build(lhs,w,rulep); +  } +  inline PrefixTrieNode *build_lhs(NTHandle w,prob_t rulep) { +    return build(w,w,rulep); +  } + +  PrefixTrieNode *build(NTHandle lhs_,W w,prob_t rulep) {      PrefixTrieEdgeFor::iterator i=edge_for.find(w);      if (i!=edge_for.end())        return improve_edge(i->second,rulep);      PrefixTrieEdge &e=edge_for[w]; -    return e.dest=new PrefixTrieNode(rulep); +    return e.dest=new PrefixTrieNode(lhs_,rulep);    } -  void set_final(prob_t pf) { -    final=true;p_final=pf; +  void set_final(NTHandle lhs_,prob_t pf) { +    assert(no_adj()); +//    final=true; // don't really need to track this. +    PrefixTrieEdge &e=edge_for[-1]; +    e.p=pf; +    e.dest=0; +    e.w=lhs_; +    if (pf>p) +      p=pf;    } -#ifdef HAVE_TAIL_RECURSE -  // add string i...end -  void build(RI i,RI end, prob_t rulep) { -    if (i==end) { -      set_final(rulep); -    } else -    // tail recursion: -      build(*i)->build(i+1,end,rulep); +private: +  void destroy_children() { +    assert(adj.size()>=edge_for.size()); +    for (int i=0,e=adj.size();i<e;++i) { +      PrefixTrieNode *c=adj[i].dest; +      if (c) { // final state has no end +        delete c; +      } +    } +  } +public: +  ~PrefixTrieNode() { +    destroy_children();    } -#endif  };  #if TRIE_START_LHS @@ -200,34 +243,77 @@ struct PrefixTrie {    }    void operator()(int ri) const {      Rule const& r=rules()[ri]; +    NTHandle lhs=r.lhs;      prob_t p=r.p; -    PrefixTrieNode *n=const_cast<PrefixTrieNode&>(root).build(r.lhs,p); +    PrefixTrieNode *n=const_cast<PrefixTrieNode&>(root).build_lhs(lhs,p);      for (RHS::const_iterator i=r.rhs.begin(),e=r.rhs.end();;++i) {        if (i==e) { -        n->set_final(p); +        n->set_final(lhs,p);          break;        }        n=n->build(*i,p);      } -#ifdef HAVE_TAIL_RECURSE -    root.build(r.lhs,r.p)->build(r.rhs,r.p); -#endif +//    root.build(lhs,r.p)->build(r.rhs,r.p);    } + +  }; -// these should go in a global best-first queue +typedef std::size_t ItemHash; +  struct Item { -  prob_t forward; +  explicit Item(PrefixTrieNode *dot,int next=0) : dot(dot),next(next) {  } +  PrefixTrieNode *dot; // dot is a function of the stuff already recognized, and gives a set of suffixes y to complete to finish a rhs for lhs() -> dot y.  for a lhs A -> . *, this will point to lh2[A] +  int next; // index of dot->adj to complete (if dest==0), or predict (if NT), or scan (if word) +  NTHandle lhs() const { return dot->lhs; } +  inline ItemHash hash() const { +    return GOLDEN_MEAN_FRACTION*next^((ItemHash)dot>>4); // / sizeof(PrefixTrieNode), approx., i.e. lower order bits of ptr are nonrandom +  } +}; + +inline ItemHash hash_value(Item const& x) { +  return x.hash(); +} + +Item null_item((PrefixTrieNode*)0); + +// these should go in a global best-first queue +struct ItemP { +  ItemP() : forward(init_0()),inner(init_0()) {  } +  prob_t forward; // includes inner prob.    // NOTE: sum = viterbi (max)    /* The forward probability alpha_i(X[k]->x.y) is the sum of the probabilities of all       constrained paths of length that end in state X[k]->x.y*/    prob_t inner;    /* The inner probability beta_i(X[k]->x.y) is the sum of the probabilities of all       paths of length i-k that start in state X[k,k]->.xy and end in X[k,i]->x.y, and generate the input symbols x[k,...,i-1] */ -  PrefixTrieNode *dot; // dot is a function of the stuff already recognized, and gives a set of suffixes y to complete to finish a rhs for lhs() -> dot y -  NTHandle lhs() const { return dot->lhs; }  }; +struct Chart { +  //Agenda<Item> a; +  //typedef HASH_MAP(Item,ItemP,boost::hash<Item>) Items; +  //typedef Items::iterator FindItem; +  //typedef std::pair<FindItem,bool> InsertItem; +//  Items items; +  CFG &cfg; // TODO: remove this from Chart +  NTHandle goal_nt; +  PrefixTrie trie; +  typedef std::vector<PrefixTrieNode *> LhsToTrie; // will have to check lhs2[lhs].p for best cost of some rule with that lhs, then use edge deltas after?  they're just caching a very cheap computation, really +  LhsToTrie lhs2; // no reason to use a map or hash table; every NT in the CFG will have some rule rhses.  lhs_to_trie[i]=root.edge_for[i], i.e. we still have a root trie node conceptually, we just access through this since it's faster. + +  void enqueue(Item const& item,ItemP const& p) { +//    FindItem f=items.find(item); +//    if (f==items.end()) ; + +  } + +  Chart(CFG &cfg) :cfg(cfg),trie(cfg) { +    goal_nt=cfg.goal_nt; +    trie.root.index_root(lhs2); +  } +}; + +  }//anon ns @@ -279,7 +365,7 @@ void ApplyFsa::ApplyBottomUp()  void ApplyFsa::ApplyEarley()  {    hgcfg.GiveCFG(cfg); -  PrefixTrie rt(cfg); +  Chart chart(cfg);    // don't need to uniq - option to do that already exists in cfg_options    //TODO:  } diff --git a/decoder/cfg.cc b/decoder/cfg.cc index b2219193..651978d2 100755 --- a/decoder/cfg.cc +++ b/decoder/cfg.cc @@ -8,6 +8,7 @@  #include "fast_lexical_cast.hpp"  //#include "indices_after.h"  #include "show.h" +#include "null_traits.h"  #define DUNIQ(x) x  #define DBIN(x) @@ -31,6 +32,7 @@ using namespace std;  typedef CFG::Rule Rule;  typedef CFG::NTOrder NTOrder;  typedef CFG::RHS RHS; +typedef CFG::BinRhs BinRhs;  /////index ruleids:  void CFG::UnindexRules() { @@ -166,11 +168,11 @@ void CFG::SortLocalBestFirst(NTHandle ni) {  /////binarization:  namespace { -CFG::BinRhs null_bin_rhs(std::numeric_limits<int>::min(),std::numeric_limits<int>::min()); +BinRhs null_bin_rhs(std::numeric_limits<int>::min(),std::numeric_limits<int>::min());  // index i >= N.size()?  then it's in M[i-N.size()]  //WordID first,WordID second, -string BinStr(CFG::BinRhs const& b,CFG::NTs const& N,CFG::NTs const& M) +string BinStr(BinRhs const& b,CFG::NTs const& N,CFG::NTs const& M)  {    int nn=N.size();    ostringstream o; @@ -203,7 +205,7 @@ string BinStr(RHS const& r,CFG::NTs const& N,CFG::NTs const& M)  } -WordID BinName(CFG::BinRhs const& b,CFG::NTs const& N,CFG::NTs const& M) +WordID BinName(BinRhs const& b,CFG::NTs const& N,CFG::NTs const& M)  {    return TD::Convert(BinStr(b,N,M));  } @@ -213,22 +215,27 @@ WordID BinName(RHS const& b,CFG::NTs const& N,CFG::NTs const& M)    return TD::Convert(BinStr(b,N,M));  } +/*  template <class Rhs>  struct null_for; -typedef CFG::BinRhs BinRhs;  template <>  struct null_for<BinRhs> {    static BinRhs null;  }; -BinRhs null_for<BinRhs>::null(std::numeric_limits<int>::min(),std::numeric_limits<int>::min());  template <>  struct null_for<RHS> {    static RHS null;  }; -RHS null_for<RHS>::null(1,std::numeric_limits<int>::min()); +*/ + +template <> +BinRhs null_traits<BinRhs>::null(std::numeric_limits<int>::min(),std::numeric_limits<int>::min()); + +template <> +RHS null_traits<RHS>::null(1,std::numeric_limits<int>::min());  template <class Rhs>  struct add_virtual_rules { @@ -243,7 +250,7 @@ struct add_virtual_rules {    R2L rhs2lhs; // an rhs maps to this -virtntid, or original id if length 1    bool name_nts;    add_virtual_rules(CFG &cfg,bool name_nts=false) : nts(cfg.nts),rules(cfg.rules),newnt(-nts.size()),newruleid(rules.size()),name_nts(name_nts) { -    HASH_MAP_EMPTY(rhs2lhs,null_for<Rhs>::null); +    HASH_MAP_EMPTY(rhs2lhs,null_traits<Rhs>::null);    }    NTHandle get_virt(Rhs const& r) {      NTHandle nt=get_default(rhs2lhs,r,newnt); | 
