From 6cc769d102bfcf87822ceeb499cf45ff1e79e5f6 Mon Sep 17 00:00:00 2001 From: "graehl@gmail.com" Date: Mon, 16 Aug 2010 09:11:03 +0000 Subject: greedy binarization - needs testing, may have broke l2r git-svn-id: https://ws10smt.googlecode.com/svn/trunk@560 ec762483-ff6d-05da-a07a-a48fb63a330f --- decoder/cfg.cc | 297 ++++++++++++++++++++++++++++++++++++++++++------- decoder/cfg.h | 4 + decoder/cfg_binarize.h | 35 ++++-- decoder/hg.h | 1 + 4 files changed, 286 insertions(+), 51 deletions(-) (limited to 'decoder') diff --git a/decoder/cfg.cc b/decoder/cfg.cc index 3076c75e..0fbb6a03 100755 --- a/decoder/cfg.cc +++ b/decoder/cfg.cc @@ -11,8 +11,20 @@ #define DUNIQ(x) x #define DBIN(x) +#define DSP(x) x +//SP:binarize by splitting. #define DCFG(x) IF_CFG_DEBUG(x) +#undef CFG_FOR_RULES +#define CFG_FOR_RULES(i,expr) \ + for (CFG::NTs::const_iterator n=nts.begin(),nn=nts.end();n!=nn;++n) { \ + NT const& nt=*n; \ + for (CFG::Ruleids::const_iterator ir=nt.ruleids.begin(),er=nt.ruleids.end();ir!=er;++ir) { \ + RuleHandle i=*ir; \ + expr; \ + } \ + } + using namespace std; @@ -178,75 +190,286 @@ string BinStr(CFG::BinRhs const& b,CFG::NTs const& N,CFG::NTs const& M) return o.str(); } +string BinStr(RHS const& r,CFG::NTs const& N,CFG::NTs const& M) +{ + int nn=N.size(); + ostringstream o; + for (int i=0,e=r.size();i!=e;++i) { + if (i) + o<<'+'; + BinNameOWORD(r[i]); + } + return o.str(); +} + + WordID BinName(CFG::BinRhs const& b,CFG::NTs const& N,CFG::NTs const& M) { return TD::Convert(BinStr(b,N,M)); } +WordID BinName(RHS const& b,CFG::NTs const& N,CFG::NTs const& M) +{ + return TD::Convert(BinStr(b,N,M)); +} + +template +struct add_virtual_rules { + typedef CFG::RuleHandle RuleHandle; + typedef CFG::NTHandle NTHandle; + CFG::NTs &nts,new_nts; + CFG::Rules &rules, new_rules; +// above will be appended at the end, so we don't have to worry about iterator invalidation + NTHandle newnt; //not negative. TODO: i think we use it most often as negative. either way, be careful. + RuleHandle newruleid; + HASH_MAP > rhs2lhs; + bool name_nts; + add_virtual_rules(CFG &cfg,bool name_nts=false) : nts(cfg.nts),rules(cfg.rules),newnt(nts.size()),newruleid(rules.size()),name_nts(name_nts) { + HASH_MAP_EMPTY(rhs2lhs,null_for::null); + } + NTHandle get_virt(Rhs const& r) { + NTHandle nt=get_default(rhs2lhs,r,newnt); + if (newnt==nt) { + create_nt(r); + create_rule(r); + } + return nt; + } + NTHandle get_nt(Rhs const& r) { + NTHandle nt=get_default(rhs2lhs,r,newnt); + if (newnt==nt) { + create(r); + } + return nt; + } + inline void set_nt_name(Rhs const& r) { + if (name_nts) + new_nts.back().from.nt=BinName(r,nts,new_nts); + } + inline void create_nt(Rhs const& rhs) { + new_nts.push_back(CFG::NT(newruleid++)); + set_nt_name(rhs); + } + inline void create_rule(Rhs const& rhs) { + new_rules.push_back(CFG::Rule(newnt++,rhs)); + } + inline void create(Rhs const& rhs) { + create_nt(rhs); + create_rule(rhs); + assert(newruleid==rules.size()+new_rules.size());assert(newnt==nts.size()+new_nts.size()); + } + + ~add_virtual_rules() { + append_rules(); + } + void append_rules() { + // marginally more efficient + batched_append_swap(nts,new_nts); + batched_append_swap(rules,new_rules); + } + inline bool have(Rhs const& rhs,NTHandle &h) const { + return rhs2lhs.find(rhs)!=rhs2lhs.end(); + } + //HACK: prevent this for instantiating for BinRhs. we need to use rule index because we'll be adding rules before we can update. + // returns 1 per replaced NT (0,1, or 2) + template + int split_rhs(RHSi &rhs,bool only_free=false,bool only_reusing_1=false) { + //TODO: don't actually build substrings of rhs; define key type that stores ref to rhs in new_nts by index (because it may grow), and also allows an int* [b,e) range to serve as key (i.e. don't insert the latter type of key). + int n=rhs.size(); + if (n<=2) return 0; + int longest1=0; + int mid=n/2; + int best_k=mid; + bool haver,havel; + NTHandle ntr,ntl; + NTHandle bestntr,bestntl; + WordID *b=&rhs.front(),*e=b+n; + for (int k=1;klongest1) { + longest1=k; + haver=false; + havel=true; + bestntl=ntl; + best_k=k; + } + } else if (rlen>=longest1) { + Rhs r(wk,e); + if (have(r,ntr)) { + longest1=rlen; + havel=false; + haver=true; + bestntr=ntr; + best_k=k; + } + } + } + if (only_free) { + if (havel) { + rhs.erase(rhs.begin()+1,rhs.begin()+best_k); + rhs[0]=-ntl; + } else if (haver) { + rhs.erase(rhs.begin()+best_k+1,rhs.end()); + rhs[best_k]=-ntr; + } else + return 0; + return 1; + } + if (only_reusing_1 && longest1==0) return 0; + //TODO: swap order of consideration (l first or r?) depending on pre/post midpoint? one will be useless to check for beating the longest single match so far. check that second. + WordID *best_wk=b+best_k; + Rhs l(b,best_wk); + Rhs r(best_wk,e); + //we build these first because adding rules may invalidate the underlying pointers (we end up binarizing already split virt rules)! + rhs.resize(2); + int nnt=newnt; + rhs[0]=-(havel?bestntl:nnt++); + rhs[1]=-(haver?bestntr:nnt); + // now that we've set rhs, we can actually safely add rules + if (!havel) + create(l); + if (!haver) + create(r); + return 2; + } +}; + + +template +struct null_for; + +typedef CFG::BinRhs BinRhs; + +template <> +struct null_for { + static BinRhs null; +}; +BinRhs null_for::null(std::numeric_limits::min(),std::numeric_limits::min()); + +template <> +struct null_for { + static RHS null; +}; +RHS null_for::null(1,std::numeric_limits::min()); + }//ns +void CFG::BinarizeSplit(CFGBinarize const& b) { + add_virtual_rules v(*this,b.bin_name_nts); + CFG_FOR_RULES(i,v.split_rhs(rules[i].rhs,false,false)); +#undef CFG_FOR_VIRT +#define CFG_FOR_VIRT(r,expr) \ + for (Rules::iterator ri=v.new_rules.begin(),e=v.new_rules.end();ri!=e;++ri) { \ + Rule &r=*ri;expr; } + + int n_changed_total=0; + +#define CFG_SPLIT_PASS(N,free,just1) \ + for (int i=0;i v(*this,name); +cerr << "Binarizing left->right " << (bin_unary?"real to unary":"stop at binary") < > bin2lhs; // we're going to hash cons rather than build an explicit trie from right to left. HASH_MAP_EMPTY(bin2lhs,null_bin_rhs); // iterate using indices and not iterators because we'll be adding to both nts and rules list? we could instead pessimistically reserve space for both, but this is simpler. also: store original end of nts since we won't need to reprocess newly added ones. - int rhsmin=b.bin_unary?0:1; - NTs new_nts; // these will be appended at the end, so we don't have to worry about iterator invalidation - Rules new_rules; + int rhsmin=bin_unary?0:1; + //NTs new_nts; + //Rules new_rules; //TODO: this could be factored easily into in-place (append to new_* like below) and functional (nondestructive copy) versions (copy orig to target and append to target) - int newnt=-nts.size(); // we're going to store binary rhs with -nt to keep distinct from words (>=0) - int newruleid=rules.size(); +// int newnt=nts.size(); // we're going to store binary rhs with -nt to keep distinct from words (>=0) +// int newruleid=rules.size(); BinRhs bin; - for (NTs::const_iterator n=nts.begin(),nn=nts.end();n!=nn;++n) { + CFG_FOR_RULES(ruleid, +/* for (NTs::const_iterator n=nts.begin(),nn=nts.end();n!=nn;++n) { NT const& nt=*n; for (Ruleids::const_iterator ir=nt.ruleids.begin(),er=nt.ruleids.end();ir!=er;++ir) { - RuleHandle ruleid=*ir; - SHOW2(DBIN,ruleid,ShowRule(ruleid)) - RHS &rhs=rules[ruleid].rhs; // we're going to binarize this while adding newly created rules to new_... + RuleHandle ruleid=*ir;*/ +// SHOW2(DBIN,ruleid,ShowRule(ruleid)); + Rule & rule=rules[ruleid]; + RHS &rhs=rule.rhs; // we're going to binarize this while adding newly created rules to new_... if (rhs.empty()) continue; int r=rhs.size()-2; // loop below: [r,r+1) is to be reduced into a (maybe new) binary NT if (rhsmin<=r) { // means r>=0 also bin.second=rhs[r+1]; int bin_to; // the replacement for bin - assert(newruleid==rules.size()+new_rules.size());assert(-newnt==nts.size()+new_nts.size()); +// assert(newruleid==rules.size()+new_rules.size());assert(newnt==nts.size()+new_nts.size()); // also true at start/end of loop: for (;;) { // pairs from right to left (normally we leave the last pair alone) - bin.first=rhs[r]; - bin_to=get_default(bin2lhs,bin,newnt); - SHOW(DBIN,r) SHOW(DBIN,newnt) SHOWP(DBIN,"bin="<") SHOW(DBIN,bin_to); - if (newnt==bin_to) { // it's new! + bin_to=v.get_virt(bin); +/* bin_to=get_default(bin2lhs,bin,v.newnt); +// SHOW(DBIN,r) SHOW(DBIN,newnt) SHOWP(DBIN,"bin="<") SHOW(DBIN,bin_to); + if (v.newnt==bin_to) { // it's new! new_nts.push_back(NT(newruleid++)); - //now -newnt is the index of the last (after new_nts is appended) nt. bin is its rhs. bin_to is its lhs - new_rules.push_back(Rule(-newnt,bin)); - --newnt; - if (b.bin_name_nts) - new_nts.back().from.nt=BinName(bin,nts,new_nts); + //now newnt is the index of the last (after new_nts is appended) nt. bin is its rhs. bin_to is its lhs + new_rules.push_back(Rule(newnt,bin)); + ++newnt; + if (name) new_nts.back().from.nt=BinName(bin,nts,new_nts); } - bin.second=bin_to; +*/ + bin.second=-bin_to; --r; - if (r NTs; NTs nts; diff --git a/decoder/cfg_binarize.h b/decoder/cfg_binarize.h index 82c4dd1a..3aba5e9f 100755 --- a/decoder/cfg_binarize.h +++ b/decoder/cfg_binarize.h @@ -14,39 +14,50 @@ */ struct CFGBinarize { - int bin_at; + int bin_thresh; bool bin_l2r; - bool bin_unary; + int bin_unary; bool bin_name_nts; bool bin_topo; + bool bin_split; + int split_passes,split_share1_passes,split_free_passes; template // template to support both printable_opts and boost nonprintable void AddOptions(Opts *opts) { opts->add_options() - ("cfg_binarize_at", defaulted_value(&bin_at),"(if >0) binarize CFG rhs segments which appear at least this many times") - ("cfg_binarize_unary", defaulted_value(&bin_unary),"if true, a rule-completing production A->BC may be binarized as A->U U->BC if U->BC would be used at least cfg_binarize_at times.") + ("cfg_binarize_threshold", defaulted_value(&bin_thresh),"(if >0) repeatedly binarize CFG rhs bigrams which appear at least this many times, most frequent first. resulting rules may be 1,2, or >2-ary. this happens before the other types of binarization.") +// ("cfg_binarize_unary_threshold", defaulted_value(&bin_unary),"if >0, a rule-completing production A->BC may be binarized as A->U U->BC if U->BC would be used at least this many times. this happens last.") + ("cfg_binarize_greedy_split", defaulted_value(&bin_split),"(DeNero et al) for each rule until binarized, pick a split point k of L->r[0..n) to make rules L->V1 V2, V1->r[0..k) V2->r[k..n), to minimize the number of new rules created") + ("cfg_split_full_passes", defaulted_value(&split_passes),"pass through the virtual rules only (up to) this many times (all real rules will have been split if not already binary)") + ("cfg_split_share1_passes", defaulted_value(&split_share1_passes),"after the full passes, for up to this many times split when at least 1 of the items has been seen before") + ("cfg_split_free_passes", defaulted_value(&split_free_passes),"only split off from virtual nts pre/post nts that already exist - could check for interior phrases but after a few splits everything should be tiny already.") ("cfg_binarize_l2r", defaulted_value(&bin_l2r),"force left to right (a (b (c d))) binarization (ignore _at threshold)") ("cfg_binarize_name_nts", defaulted_value(&bin_name_nts),"create named virtual NT tokens e.g. 'A12+the' when binarizing 'B->[A12] the cat'") ("cfg_binarize_topo", defaulted_value(&bin_topo),"reorder nonterminals after binarization to maintain definition before use (topological order). otherwise the virtual NTs will all appear after the regular NTs") ; } void Validate() { - if (bin_l2r) - bin_at=0; - if (bin_at>0&&!bin_l2r) { + if (bin_thresh>0&&!bin_l2r) { std::cerr<<"\nWARNING: greedy binarization not yet supported; using l2r (right branching) instead.\n"; bin_l2r=true; } + if (false && bin_l2r && bin_split) { // actually, split may be slightly incomplete due to finite number of passes. + std::cerr<<"\nWARNING: l2r and split are both complete binarization and redundant. Using split.\n"; + bin_l2r=false; + } + } bool Binarizing() const { - return bin_l2r || bin_at>0; + return bin_split || bin_l2r || bin_thresh>0; } void set_defaults() { + bin_split=false; bin_topo=false; - bin_at=0; - bin_unary=false; + bin_thresh=0; + bin_unary=0; bin_name_nts=true; bin_l2r=false; + split_passes=10;split_share1_passes=0;split_free_passes=10; } CFGBinarize() { set_defaults(); } void print(std::ostream &o) const { @@ -56,10 +67,12 @@ struct CFGBinarize { else { if (bin_unary) o << "unary-sharing "; + if (bin_thresh) + o<<"greedy bigram count>="<right"; else - o << "greedy count>="<