From 2af5a445f3905c69c42be5c758c52a2f21b17446 Mon Sep 17 00:00:00 2001 From: graehl Date: Tue, 31 Aug 2010 01:08:42 +0000 Subject: l2r bugfixes git-svn-id: https://ws10smt.googlecode.com/svn/trunk@634 ec762483-ff6d-05da-a07a-a48fb63a330f --- decoder/apply_fsa_models.cc | 86 +++++++++++++++++++++++++++++++++------------ decoder/cfg.h | 10 +++++- 2 files changed, 72 insertions(+), 24 deletions(-) (limited to 'decoder') diff --git a/decoder/apply_fsa_models.cc b/decoder/apply_fsa_models.cc index dddbddd9..4a928206 100755 --- a/decoder/apply_fsa_models.cc +++ b/decoder/apply_fsa_models.cc @@ -1,13 +1,16 @@ -#include #include "apply_fsa_models.h" +#include +#include +#include +#include + +#include "writer.h" #include "hg.h" #include "ff_fsa_dynamic.h" #include "ff_from_fsa.h" #include "feature_vector.h" #include "stringlib.h" #include "apply_models.h" -#include -#include #include "cfg.h" #include "hg_cfg.h" #include "utoa.h" @@ -16,7 +19,7 @@ #include "d_ary_heap.h" #include "agenda.h" #include "show.h" -#include +#include "string_to.h" #define DFSA(x) x //fsa earley chart @@ -24,7 +27,7 @@ #define DPFSA(x) x //prefix trie -#define DBUILDTRIE(x) x +#define DBUILDTRIE(x) #define PRINT_PREFIX 1 #if PRINT_PREFIX @@ -101,23 +104,54 @@ struct TrieBackP { FsaFeatureFunction const* print_fsa=0; CFG const* print_cfg=0; -inline void print_cfg_rhs(std::ostream &o,WordID w) { - if (print_cfg) - print_cfg->print_rhs_name(o,w); +inline ostream& print_cfg_rhs(std::ostream &o,WordID w,CFG const*pcfg=print_cfg) { + if (pcfg) + pcfg->print_rhs_name(o,w); else CFG::static_print_rhs_name(o,w); + return o; +} + +inline std::string nt_name(WordID n,CFG const*pcfg=print_cfg) { + if (pcfg) return pcfg->nt_name(n); + return CFG::static_nt_name(n); +} + +template +ostream& print_by_nt(std::ostream &o,V const& v,CFG const*pcfg=print_cfg,char const* header="\nNT -> X\n") { + o< "< +ostream& print_map_by_nt(std::ostream &o,V const& v,CFG const*pcfg=print_cfg,char const* header="\nNT -> X\n") { + o<first,pcfg) << " -> "<second<<"\n"; + } + return o; } + struct PrefixTrieEdge { -// PrefixTrieEdge() { } + PrefixTrieEdge() + // : dest(0),w(TD::max_wordid) + {} + PrefixTrieEdge(WordID w,NodeP dest) + : dest(dest),w(w) + {} // explicit PrefixTrieEdge(best_t p) : p(p),dest(0) { } - best_t p;// viterbi additional prob, i.e. product over path incl. p_final = total rule prob + + best_t p;// viterbi additional prob, i.e. product over path incl. p_final = total rule prob. note: for final edge, set this. //DPFSA() // we can probably just store deltas, but for debugging remember the full p // best_t delta; // NodeP dest; bool is_final() const { return dest==0; } - WordID w; // for lhs, this will be nonneg NTHandle instead. // not set if is_final() // actually, set to lhs nt index + best_t p_dest() const; + WordID w; // for root and and is_final(), this will be (negated) NTHandle. // for sorting most probable first in adj; actually >(p) inline bool operator <(PrefixTrieEdge const& o) const { @@ -218,7 +252,7 @@ public: for (int i=0,e=adj.size();i!=e;++i) { PrefixTrieEdge const& edge=adj[i]; // assert(edge.p.is_1()); // actually, after done_building, e will have telescoped dest->p/p. - NTHandle n=edge.w; + NTHandle n=-edge.w; assert(n>=0); SHOWM3(DPFSA,"index_lhs",i,edge,n); v[n]=edge.dest; @@ -228,7 +262,10 @@ public: template void done_root(PV &v) { assert(is_root()); + SHOWM1(DBUILDTRIE,"done_root",OSTRF1(print_map_by_nt,edge_for)); done_building_r(); //sets adj + SHOWM1(DBUILDTRIE,"done_root",OSTRF1(print_by_nt,adj)); +// SHOWM1(DBUILDTRIE,done_root,adj); // index_adj(); // we want an index for the root node?. don't think so - index_lhs handles it. also we stopped clearing edge_for. index_lhs(v); // uses adj } @@ -244,7 +281,7 @@ public: // for done_building; compute incremental (telescoped) edge p PrefixTrieEdge /*const&*/ operator()(PrefixTrieEdgeFor::value_type & pair) const { PrefixTrieEdge &e=pair.second;//const_cast(pair.second); - e.p=(e.dest->p)/p; + e.p=e.p_dest()/p; return e; } @@ -265,6 +302,7 @@ public: // (*this)(*i); } #endif + SHOWM1(DBUILDTRIE,"done building adj",prange(adj.begin(),adj.end(),true)); assert(adj.size()==edge_for.size()); // if (final) p_final/=p; std::sort(adj.begin(),adj.end()); @@ -287,18 +325,18 @@ public: inline NodeP build(W w,best_t rulep) { return build(lhs,w,rulep); } - inline NodeP build_lhs(NTHandle w,best_t rulep) { - return build(w,w,rulep); + inline NodeP build_lhs(NTHandle n,best_t rulep) { + return build(n,-n,rulep); } NodeP build(NTHandle lhs_,W w,best_t rulep) { PrefixTrieEdgeFor::iterator i=edge_for.find(w); if (i!=edge_for.end()) return improve_edge(i->second,rulep); - PrefixTrieEdge &e=i->second; NodeP r=new PrefixTrieNode(lhs_,rulep); IF_PRINT_PREFIX(r->backp=BP(w,this)); - e.dest=r; +// edge_for.insert(i,PrefixTrieEdgeFor::value_type(w,PrefixTrieEdge(w,r))); + add(edge_for,w,PrefixTrieEdge(w,r)); SHOWM4(DBUILDTRIE,"built node",this,w,*r,r); return r; } @@ -306,7 +344,7 @@ public: void set_final(NTHandle lhs_,best_t pf) { assert(no_adj()); final=true; - PrefixTrieEdge &e=edge_for[-1]; + PrefixTrieEdge &e=edge_for[null_wordid]; e.p=pf; e.dest=0; e.w=lhs_; @@ -335,6 +373,10 @@ public: PRINT_SELF(PrefixTrieNode) }; +inline best_t PrefixTrieEdge::p_dest() const { + return dest ? dest->p : p; // for final edge, p was set (no sentinel node) +} + //Trie starts with lhs (nonneg index), then continues w/ rhs (mixed >0 word, else NT) // trie ends with final edge, which points to a per-lhs prefix node @@ -358,7 +400,9 @@ struct PrefixTrie { SHOWM2(DBUILDTRIE,"PrefixTrie()",rulesp->size(),lhs2.size()); cfg.VisitRuleIds(*this); root.done_root(lhs2); - SHOWM4(DBUILDTRIE,"done w/ PrefixTrie: ",root,root.adj.size(),lhs2.size(),lhs2[0]); + SHOWM3(DBUILDTRIE,"done w/ PrefixTrie: ",root,root.adj.size(),lhs2.size()); + DBUILDTRIE(print_by_nt(cerr,lhs2,cfgp)); + SHOWM1(DBUILDTRIE,"lhs2",OSTRF2(print_by_nt,lhs2,cfgp)); } void operator()(int ri) { @@ -526,12 +570,8 @@ struct Chart { } else { break; } - } - - } - } Chart(CFG &cfg,SentenceMetadata const& smeta,FsaFF const& fsa,unsigned reserve=FSA_AGENDA_RESERVE) diff --git a/decoder/cfg.h b/decoder/cfg.h index 95cb5fd7..8cb29bb9 100755 --- a/decoder/cfg.h +++ b/decoder/cfg.h @@ -77,8 +77,16 @@ struct CFG { if (w<=0) return nt_name(-w); else return TD::Convert(w); } + static void static_print_nt_name(std::ostream &o,NTHandle n) { + o<<'['<