diff options
author | graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-08-10 23:49:33 +0000 |
---|---|---|
committer | graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-08-10 23:49:33 +0000 |
commit | 667675465486e1f9729931c0b38b5a1124a1c000 (patch) | |
tree | 731cbe983050220b995c643421beb04b4e271560 /decoder | |
parent | 131c2280809e890a817688b708f03a231025fd77 (diff) |
CFG binarize opts
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@503 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'decoder')
-rwxr-xr-x | decoder/apply_fsa_models.cc | 11 | ||||
-rwxr-xr-x | decoder/apply_fsa_models.h | 43 | ||||
-rw-r--r-- | decoder/cdec.cc | 23 | ||||
-rwxr-xr-x | decoder/cfg.cc | 13 | ||||
-rwxr-xr-x | decoder/cfg.h | 2 | ||||
-rwxr-xr-x | decoder/cfg_binarize.h | 72 | ||||
-rwxr-xr-x | decoder/cfg_format.h | 9 | ||||
-rwxr-xr-x | decoder/cfg_options.h | 49 | ||||
-rwxr-xr-x | decoder/hg_cfg.h | 42 |
9 files changed, 208 insertions, 56 deletions
diff --git a/decoder/apply_fsa_models.cc b/decoder/apply_fsa_models.cc index 31b2002d..3a1e8050 100755 --- a/decoder/apply_fsa_models.cc +++ b/decoder/apply_fsa_models.cc @@ -8,6 +8,7 @@ #include <stdexcept> #include <cassert> #include "cfg.h" +#include "hg_cfg.h" using namespace std; @@ -118,3 +119,13 @@ int ApplyFsaBy::BottomUpAlgorithm() const { :IntersectionConfiguration::FULL; } +void ApplyFsaModels(Hypergraph const& ih, + const SentenceMetadata& smeta, + const FsaFeatureFunction& fsa, + DenseWeightVector const& weights, // pre: in is weighted by these (except with fsa featval=0 before this) + ApplyFsaBy const& cfg, + Hypergraph* out) +{ + HgCFG i(ih); + ApplyFsaModels(i,smeta,fsa,weights,cfg,out); +} diff --git a/decoder/apply_fsa_models.h b/decoder/apply_fsa_models.h index 0227664a..5120fb4e 100755 --- a/decoder/apply_fsa_models.h +++ b/decoder/apply_fsa_models.h @@ -4,11 +4,11 @@ #include <string> #include <iostream> #include "feature_vector.h" -#include "cfg.h" struct FsaFeatureFunction; struct Hypergraph; struct SentenceMetadata; +struct HgCFG; struct ApplyFsaBy { enum { @@ -36,40 +36,6 @@ struct ApplyFsaBy { static std::string all_names(); // space separated }; -// in case you might want the CFG whether or not you apply FSA models: -struct HgCFG { - HgCFG(Hypergraph const& ih) : ih(ih) { have_cfg=false; } - Hypergraph const& ih; - CFG cfg; - bool have_cfg; - void InitCFG(CFG &to) { - to.Init(ih,true,false,true); - } - - CFG &GetCFG() - { - if (!have_cfg) { - have_cfg=true; - InitCFG(cfg); - } - return cfg; - } - void GiveCFG(CFG &to) { - if (!have_cfg) - InitCFG(to); - else { - have_cfg=false; - to.Clear(); - to.Swap(cfg); - } - } - CFG const& GetCFG() const { - assert(have_cfg); - return cfg; - } -}; - - void ApplyFsaModels(HgCFG &hg_or_cfg_in, const SentenceMetadata& smeta, const FsaFeatureFunction& fsa, @@ -77,15 +43,12 @@ void ApplyFsaModels(HgCFG &hg_or_cfg_in, ApplyFsaBy const& cfg, Hypergraph* out); -inline void ApplyFsaModels(Hypergraph const& ih, +void ApplyFsaModels(Hypergraph const& ih, const SentenceMetadata& smeta, const FsaFeatureFunction& fsa, DenseWeightVector const& weights, // pre: in is weighted by these (except with fsa featval=0 before this) ApplyFsaBy const& cfg, - Hypergraph* out) { - HgCFG i(ih); - ApplyFsaModels(i,smeta,fsa,weights,cfg,out); -} + Hypergraph* out); #endif diff --git a/decoder/cdec.cc b/decoder/cdec.cc index a179e029..b5d768e8 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -37,7 +37,9 @@ #include "../vest/scorer.h" #include "apply_fsa_models.h" #include "program_options.h" -#include "cfg_format.h" +#include "cfg_options.h" + +CFGOptions cfg_options; using namespace std; using namespace std::tr1; @@ -62,7 +64,6 @@ void ConvertSV(const SparseVector<prob_t>& src, SparseVector<double>* trg) { trg->set_value(it->first, it->second); } - inline string str(char const* name,po::variables_map const& conf) { return conf[name].as<string>(); } @@ -106,9 +107,6 @@ void print_options(std::ostream &out,po::options_description const& opts) { out << '"'; } - -CFGFormat cfgf; - void InitCommandLine(int argc, char** argv, OracleBleu &ob, po::variables_map* confp) { po::variables_map &conf=*confp; po::options_description opts("Configuration options"); @@ -172,10 +170,8 @@ void InitCommandLine(int argc, char** argv, OracleBleu &ob, po::variables_map* c ("forest_output,O",po::value<string>(),"Directory to write forests to") ("minimal_forests,m","Write minimal forests (excludes Rule information). Such forests can be used for ML/MAP training, but not rescoring, etc."); ob.AddOptions(&opts); - po::options_description cfgo("CFG output options"); - cfgo.add_options() - ("cfg_output", po::value<string>(),"write final target CFG (before FSA rescorinn) to this file"); - cfgf.AddOptions(&cfgo); + po::options_description cfgo(cfg_options.description()); + cfg_options.AddOptions(&cfgo); po::options_description clo("Command line options"); clo.add_options() ("config,c", po::value<vector<string> >(), "Configuration file(s) - latest has priority") @@ -555,6 +551,7 @@ int main(int argc, char** argv) { const bool crf_uniform_empirical = conf.count("crf_uniform_empirical"); const bool get_oracle_forest = conf.count("get_oracle_forest"); + cfg_options.Validate(); if (get_oracle_forest) oracle.UseConf(conf); @@ -663,9 +660,11 @@ int main(int argc, char** argv) { maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen); HgCFG hgcfg(forest); - if (conf.count("cfg_output")) { - WriteFile o(str("cfg_output",conf)); - hgcfg.GetCFG().Print(o.get(),cfgf); + cfg_options.maybe_output(hgcfg); + if (!cfg_options.cfg_output.empty()) { + WriteFile o(cfg_options.cfg_output); + CFG &cfg=hgcfg.GetCFG(); + cfg.Print(o.get(),cfg_options.format); } if (!fsa_ffs.empty()) { Timer t("Target FSA rescoring:"); diff --git a/decoder/cfg.cc b/decoder/cfg.cc index ace0ebb0..6a5d8342 100755 --- a/decoder/cfg.cc +++ b/decoder/cfg.cc @@ -1,9 +1,22 @@ #include "cfg.h" #include "hg.h" #include "cfg_format.h" +#include "cfg_binarize.h" using namespace std; + +void CFG::Binarize(CFGBinarize const& b) { + if (!b.Binarizing()) return; + if (!b.bin_l2r) { + assert(b.bin_l2r); + return; + } + // l2r only so far: + cerr << "Binarizing "<<b<<endl; + //TODO. +} + void CFG::Init(Hypergraph const& hg,bool target_side,bool copy_features,bool push_weights) { uninit=false; hg_=&hg; diff --git a/decoder/cfg.h b/decoder/cfg.h index e325c4cd..808c7a32 100755 --- a/decoder/cfg.h +++ b/decoder/cfg.h @@ -29,6 +29,7 @@ class Hypergraph; class CFGFormat; // #include "cfg_format.h" +class CFGBinarize; // #include "cfg_binarize.h" struct CFG { typedef int RuleHandle; @@ -77,6 +78,7 @@ struct CFG { swap(nts,o.nts); swap(goal_nt,o.goal_nt); } + void Binarize(CFGBinarize const& binarize_options); protected: bool uninit; Hypergraph const* hg_; // shouldn't be used for anything, esp. after binarization diff --git a/decoder/cfg_binarize.h b/decoder/cfg_binarize.h new file mode 100755 index 00000000..f76619d2 --- /dev/null +++ b/decoder/cfg_binarize.h @@ -0,0 +1,72 @@ +#ifndef CFG_BINARIZE_H +#define CFG_BINARIZE_H + +#include <iostream> + +/* + binarization: decimate rhs of original rules until their rhs have been reduced to length 2 (or 1 if bin_unary). also decimate rhs of newly binarized rules until length 2. newly created rules are all binary (never unary/nullary). + + bin_name_nts: nts[i].from will be initialized, including adding new names to TD + + bin_l2r: right-branching (a (b c)) means suffixes are shared. if requested, the only other option that matters is bin_unary + + otherwise, greedy binarization: the pairs that are most frequent in the rules are binarized, one at a time. this should be done efficiently: each pair has a count of and list of its left and right adjacent pair+count (or maybe a non-count collapsed list of adjacent instances). this can be efficiently updated when a pair is chosen for replacement by a new virtual NT. + */ + +struct CFGBinarize { + int bin_at; + bool bin_l2r; + bool bin_unary; + bool bin_name_nts; + template <class Opts> // template to support both printable_opts and boost nonprintable + void AddOptions(Opts *opts) { + opts->add_options() + ("cfg_binarize_at", defaulted_value(&bin_at),"(if >0) binarize CFG rhs segments which appear at least this many times") + ("cfg_binarize_unary", defaulted_value(&bin_unary),"if true, a rule-completing production A->BC may be binarized as A->U U->BC if U->BC would be used at least cfg_binarize_at times.") + ("cfg_binarize_l2r", defaulted_value(&bin_l2r),"force left to right (a (b (c d))) binarization (ignore _at threshold)") + ("cfg_binarize_name_nts", defaulted_value(&bin_name_nts),"create named virtual NT tokens e.g. 'A12+the' when binarizing 'B->[A12] the cat'") + ; + } + void Validate() { + if (bin_l2r) + bin_at=0; + if (bin_at>0&&!bin_l2r) { + std::cerr<<"\nWARNING: greedy binarization not yet supported; using l2r (right branching) instead.\n"; + bin_l2r=true; + } + } + + bool Binarizing() const { + return bin_l2r || bin_at>0; + } + void set_defaults() { + bin_at=0; + bin_unary=false; + bin_name_nts=true; + bin_l2r=false; + } + CFGBinarize() { set_defaults(); } + void print(std::ostream &o) const { + o<<'('; + if (!Binarizing()) + o << "Unbinarized"; + else { + if (bin_unary) + o << "unary-sharing "; + if (bin_l2r) + o << "left->right"; + else + o << "greedy count>="<<bin_at; + if (bin_name_nts) + o << " named-NTs"; + } + o<<')'; + } + friend inline std::ostream &operator<<(std::ostream &o,CFGBinarize const& me) { + me.print(o); return o; + } + +}; + + +#endif diff --git a/decoder/cfg_format.h b/decoder/cfg_format.h index 169632a6..10361804 100755 --- a/decoder/cfg_format.h +++ b/decoder/cfg_format.h @@ -1,10 +1,10 @@ #ifndef CFG_FORMAT_H #define CFG_FORMAT_H -#include <program_options.h> #include <string> #include "wordid.h" #include "feature_vector.h" +#include "program_options.h" struct CFGFormat { bool identity_scfg; @@ -18,8 +18,8 @@ struct CFGFormat { std::string partsep; template <class Opts> // template to support both printable_opts and boost nonprintable void AddOptions(Opts *opts) { - using namespace boost::program_options; - using namespace std; + //using namespace boost::program_options; + //using namespace std; opts->add_options() ("identity_scfg",defaulted_value(&identity_scfg),"output an identity SCFG: add an identity target side - '[X12] ||| [X13,1] a ||| [1] a ||| feat= ...' - the redundant target '[1] a |||' is omitted otherwise.") ("features",defaulted_value(&features),"print the CFG feature vector") @@ -31,7 +31,7 @@ struct CFGFormat { ("nt_span",defaulted_value(&nt_span),"prefix A(i,j) for NT coming from hypergraph node with category A on span [i,j). this is after --nt_prefix if any") ; } - + void Validate() { } template<class CFG> void print_source_nt(std::ostream &o,CFG const&cfg,int id,int position=1) const { o<<'['; @@ -105,4 +105,5 @@ struct CFGFormat { }; + #endif diff --git a/decoder/cfg_options.h b/decoder/cfg_options.h new file mode 100755 index 00000000..bc7fed5f --- /dev/null +++ b/decoder/cfg_options.h @@ -0,0 +1,49 @@ +#ifndef CFG_OPTIONS_H +#define CFG_OPTIONS_H + +#include "hg_cfg.h" +#include "cfg_format.h" +#include "cfg_binarize.h" +//#include "program_options.h" + +struct CFGOptions { + CFGFormat format; + CFGBinarize binarize; + std::string cfg_output; + void set_defaults() { + format.set_defaults(); + binarize.set_defaults(); + cfg_output=""; + } + CFGOptions() { set_defaults(); } + template <class Opts> // template to support both printable_opts and boost nonprintable + void AddOptions(Opts *opts) { + opts->add_options() + ("cfg_output", defaulted_value(&cfg_output),"write final target CFG (before FSA rescorinn) to this file") + ; + binarize.AddOptions(opts); + format.AddOptions(opts); + } + void Validate() { + format.Validate(); + binarize.Validate(); + } + char const* description() const { + return "CFG output options"; + } + void maybe_output(HgCFG &hgcfg) { + if (cfg_output.empty()) return; + WriteFile o(cfg_output); + maybe_binarize(hgcfg); + hgcfg.GetCFG().Print(o.get(),format); + } + void maybe_binarize(HgCFG &hgcfg) { + if (hgcfg.binarized) return; + hgcfg.GetCFG().Binarize(binarize); + hgcfg.binarized=true; + } + +}; + + +#endif diff --git a/decoder/hg_cfg.h b/decoder/hg_cfg.h new file mode 100755 index 00000000..0a3eb53c --- /dev/null +++ b/decoder/hg_cfg.h @@ -0,0 +1,42 @@ +#ifndef HG_CFG_H +#define HG_CFG_H + +#include "cfg.h" + +class Hypergraph; + +// in case you might want the CFG whether or not you apply FSA models: +struct HgCFG { + HgCFG(Hypergraph const& ih) : ih(ih) { have_cfg=binarized=false; } + Hypergraph const& ih; + CFG cfg; + bool have_cfg; + void InitCFG(CFG &to) { + to.Init(ih,true,false,true); + } + bool binarized; + CFG &GetCFG() + { + if (!have_cfg) { + have_cfg=true; + InitCFG(cfg); + } + return cfg; + } + void GiveCFG(CFG &to) { + if (!have_cfg) + InitCFG(to); + else { + have_cfg=false; + to.Clear(); + to.Swap(cfg); + } + } + CFG const& GetCFG() const { + assert(have_cfg); + return cfg; + } +}; + + +#endif |