diff options
Diffstat (limited to 'decoder')
-rw-r--r-- | decoder/cdec.cc | 3 | ||||
-rwxr-xr-x | decoder/cfg.cc | 26 | ||||
-rwxr-xr-x | decoder/cfg.h | 7 | ||||
-rwxr-xr-x | decoder/cfg_format.h | 1 | ||||
-rwxr-xr-x | decoder/cfg_options.h | 15 | ||||
-rw-r--r-- | decoder/trule.cc | 2 | ||||
-rw-r--r-- | decoder/trule.h | 2 |
7 files changed, 40 insertions, 16 deletions
diff --git a/decoder/cdec.cc b/decoder/cdec.cc index 9696fb69..8c4a25e0 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -380,7 +380,6 @@ void show_models(po::variables_map const& conf,ModelSet &ms,char const* header) ms.show_features(cerr,cerr,conf.count("warn_0_weight")); } - template <class V> bool store_conf(po::variables_map const& conf,std::string const& name,V *v) { if (conf.count(name)) { @@ -642,6 +641,8 @@ int main(int argc, char** argv) { maybe_prune(forest,conf,"prelm_beam_prune","prelm_density_prune","-LM",srclen); + cfg_options.maybe_output_source(forest); + bool has_late_models = !late_models.empty(); if (has_late_models) { Timer t("Forest rescoring:"); diff --git a/decoder/cfg.cc b/decoder/cfg.cc index c43ff9d0..0dfd04d5 100755 --- a/decoder/cfg.cc +++ b/decoder/cfg.cc @@ -17,6 +17,13 @@ void CFG::Binarize(CFGBinarize const& b) { //TODO. } +namespace { +inline int nt_index(int nvar,Hypergraph::TailNodeVector const& t,bool target_side,int w) { + assert(w<0 || (target_side&&w==0)); + return t[target_side?-w:nvar]; +} +} + void CFG::Init(Hypergraph const& hg,bool target_side,bool copy_features,bool push_weights) { uninit=false; hg_=&hg; @@ -34,8 +41,6 @@ void CFG::Init(Hypergraph const& hg,bool target_side,bool copy_features,bool pus for (int i=0;i<ne;++i) { Rule &cfgr=rules[i]; Hypergraph::Edge const& e=hg.edges_[i]; - TRule const& er=*e.rule_; vector<WordID> const& rule_rhs=target_side?er.e():er.f(); - RHS &rhs=cfgr.rhs; prob_t &crp=cfgr.p; crp=e.edge_prob_; cfgr.lhs=e.head_node_; @@ -44,18 +49,27 @@ void CFG::Init(Hypergraph const& hg,bool target_side,bool copy_features,bool pus #endif if (copy_features) cfgr.f=e.feature_values_; if (push_weights) crp /=np[e.head_node_]; + TRule const& er=*e.rule_; + vector<WordID> const& rule_rhs=target_side?er.e():er.f(); int nr=rule_rhs.size(); - rhs.resize(nr); + RHS &rhs_out=cfgr.rhs; + rhs_out.resize(nr); + Hypergraph::TailNodeVector const& tails=e.tail_nodes_; + int nvar=0; + //split out into separate target_side, source_side loops? for (int j=0;j<nr;++j) { WordID w=rule_rhs[j]; if (w>0) - rhs[j]=w; + rhs_out[j]=w; else { - int n=e.tail_nodes_[-w]; + int n=nt_index(nvar,tails,target_side,w); + ++nvar; if (push_weights) crp*=np[n]; - rhs[j]=n; + rhs_out[j]=n; } } + assert(nvar==er.Arity()); + assert(nvar==tails.size()); } } diff --git a/decoder/cfg.h b/decoder/cfg.h index 808c7a32..a390ece9 100755 --- a/decoder/cfg.h +++ b/decoder/cfg.h @@ -1,6 +1,7 @@ -#ifndef CFG_H -#define CFG_H +#ifndef CDEC_CFG_H +#define CDEC_CFG_H +// for now, debug means remembering and printing the TRule behind each CFG rule #ifndef CFG_DEBUG # define CFG_DEBUG 1 #endif @@ -9,7 +10,7 @@ question: how much does making a copy (essentially) of hg simplify things? is the space used worth it? is the node in/out edges index really that much of a waste? is the use of indices that annoying? - the only thing that excites me right now about an explicit cfg is that access to the target rhs can be less painful, and binarization *on the target side* is easier to define + answer: access to the source side and target side rhs is less painful - less indirection; if not a word (w>0) then -w is the NT index. also, non-synchronous ops like binarization make sense. hg is a somewhat bulky encoding of non-synchronous forest using indices to refer to NTs saves space (32 bit index vs 64 bit pointer) and allows more efficient ancillary maps for e.g. chart info (if we used pointers to actual node structures, it would be tempting to add various void * or other slots for use by mapped-during-computation ephemera) */ diff --git a/decoder/cfg_format.h b/decoder/cfg_format.h index 1066c510..ccf6e3fa 100755 --- a/decoder/cfg_format.h +++ b/decoder/cfg_format.h @@ -17,6 +17,7 @@ struct CFGFormat { std::string nt_prefix; std::string logprob_feat_name; std::string partsep; + bool goal_nt() const { return !goal_nt_name.empty(); } template <class Opts> // template to support both printable_opts and boost nonprintable void AddOptions(Opts *opts) { //using namespace boost::program_options; diff --git a/decoder/cfg_options.h b/decoder/cfg_options.h index cbbe3b42..956586f0 100755 --- a/decoder/cfg_options.h +++ b/decoder/cfg_options.h @@ -9,7 +9,7 @@ struct CFGOptions { CFGFormat format; CFGBinarize binarize; - std::string cfg_output; + std::string cfg_output,source_cfg_output; void set_defaults() { format.set_defaults(); binarize.set_defaults(); @@ -19,7 +19,8 @@ struct CFGOptions { template <class Opts> // template to support both printable_opts and boost nonprintable void AddOptions(Opts *opts) { opts->add_options() - ("cfg_output", defaulted_value(&cfg_output),"write final target CFG (before FSA rescorinn) to this file") + ("cfg_output", defaulted_value(&cfg_output),"write final target CFG (before FSA rescoring) to this file") + ("source_cfg_output", defaulted_value(&source_cfg_output),"write source CFG (after prelm-scoring, prelm-prune) to this file") ; binarize.AddOptions(opts); format.AddOptions(opts); @@ -31,9 +32,16 @@ struct CFGOptions { char const* description() const { return "CFG output options"; } + void maybe_output_source(Hypergraph const& hg) { + if (source_cfg_output.empty()) return; + std::cerr<<"Printing source CFG to "<<source_cfg_output<<": "<<format<<'\n'; + WriteFile o(source_cfg_output); + CFG cfg(hg,false,format.features,format.goal_nt()); + cfg.Print(o.get(),format); + } void maybe_output(HgCFG &hgcfg) { if (cfg_output.empty()) return; - std::cerr<<"Printing CFG to "<<cfg_output<<": "<<format<<'\n'; + std::cerr<<"Printing target CFG to "<<cfg_output<<": "<<format<<'\n'; WriteFile o(cfg_output); maybe_binarize(hgcfg); hgcfg.GetCFG().Print(o.get(),format); @@ -43,7 +51,6 @@ struct CFGOptions { hgcfg.GetCFG().Binarize(binarize); hgcfg.binarized=true; } - }; diff --git a/decoder/trule.cc b/decoder/trule.cc index b9494951..a40c4e14 100644 --- a/decoder/trule.cc +++ b/decoder/trule.cc @@ -10,7 +10,7 @@ using namespace std; ostream &operator<<(ostream &o,TRule const& r) { - o<<r.AsString(true); + return o<<r.AsString(true); } bool TRule::IsGoal() const { diff --git a/decoder/trule.h b/decoder/trule.h index e73fd0fe..acdbc5cf 100644 --- a/decoder/trule.h +++ b/decoder/trule.h @@ -125,7 +125,7 @@ class TRule { WordID GetLHS() const { return lhs_; } void ComputeArity(); - // 0 = first variable, -1 = second variable, -2 = third ... + // 0 = first variable, -1 = second variable, -2 = third ..., i.e. tail_nodes_[-w] if w<=0, TD::Convert(w) otherwise std::vector<WordID> e_; // < 0: *-1 = encoding of category of variable std::vector<WordID> f_; |