diff options
Diffstat (limited to 'decoder')
| -rwxr-xr-x | decoder/apply_fsa_models.cc | 11 | ||||
| -rwxr-xr-x | decoder/apply_fsa_models.h | 43 | ||||
| -rw-r--r-- | decoder/cdec.cc | 23 | ||||
| -rwxr-xr-x | decoder/cfg.cc | 13 | ||||
| -rwxr-xr-x | decoder/cfg.h | 2 | ||||
| -rwxr-xr-x | decoder/cfg_binarize.h | 72 | ||||
| -rwxr-xr-x | decoder/cfg_format.h | 9 | ||||
| -rwxr-xr-x | decoder/cfg_options.h | 49 | ||||
| -rwxr-xr-x | decoder/hg_cfg.h | 42 | 
9 files changed, 208 insertions, 56 deletions
| diff --git a/decoder/apply_fsa_models.cc b/decoder/apply_fsa_models.cc index 31b2002d..3a1e8050 100755 --- a/decoder/apply_fsa_models.cc +++ b/decoder/apply_fsa_models.cc @@ -8,6 +8,7 @@  #include <stdexcept>  #include <cassert>  #include "cfg.h" +#include "hg_cfg.h"  using namespace std; @@ -118,3 +119,13 @@ int ApplyFsaBy::BottomUpAlgorithm() const {      :IntersectionConfiguration::FULL;  } +void ApplyFsaModels(Hypergraph const& ih, +                    const SentenceMetadata& smeta, +                    const FsaFeatureFunction& fsa, +                    DenseWeightVector const& weights, // pre: in is weighted by these (except with fsa featval=0 before this) +                    ApplyFsaBy const& cfg, +                    Hypergraph* out) +{ +  HgCFG i(ih); +  ApplyFsaModels(i,smeta,fsa,weights,cfg,out); +} diff --git a/decoder/apply_fsa_models.h b/decoder/apply_fsa_models.h index 0227664a..5120fb4e 100755 --- a/decoder/apply_fsa_models.h +++ b/decoder/apply_fsa_models.h @@ -4,11 +4,11 @@  #include <string>  #include <iostream>  #include "feature_vector.h" -#include "cfg.h"  struct FsaFeatureFunction;  struct Hypergraph;  struct SentenceMetadata; +struct HgCFG;  struct ApplyFsaBy {    enum { @@ -36,40 +36,6 @@ struct ApplyFsaBy {    static std::string all_names(); // space separated  }; -// in case you might want the CFG whether or not you apply FSA models: -struct HgCFG { -  HgCFG(Hypergraph const& ih) : ih(ih) { have_cfg=false; } -  Hypergraph const& ih; -  CFG cfg; -  bool have_cfg; -  void InitCFG(CFG &to) { -    to.Init(ih,true,false,true); -  } - -  CFG &GetCFG() -  { -    if (!have_cfg) { -      have_cfg=true; -      InitCFG(cfg); -    } -    return cfg; -  } -  void GiveCFG(CFG &to) { -    if (!have_cfg) -      InitCFG(to); -    else { -      have_cfg=false; -      to.Clear(); -      to.Swap(cfg); -    } -  } -  CFG const& GetCFG() const { -    assert(have_cfg); -    return cfg; -  } -}; - -  void ApplyFsaModels(HgCFG &hg_or_cfg_in,                      const SentenceMetadata& smeta,                      const FsaFeatureFunction& fsa, @@ -77,15 +43,12 @@ void ApplyFsaModels(HgCFG &hg_or_cfg_in,                      ApplyFsaBy const& cfg,                      Hypergraph* out); -inline void ApplyFsaModels(Hypergraph const& ih, +void ApplyFsaModels(Hypergraph const& ih,                      const SentenceMetadata& smeta,                      const FsaFeatureFunction& fsa,                      DenseWeightVector const& weights, // pre: in is weighted by these (except with fsa featval=0 before this)                      ApplyFsaBy const& cfg, -                    Hypergraph* out) { -  HgCFG i(ih); -  ApplyFsaModels(i,smeta,fsa,weights,cfg,out); -} +                    Hypergraph* out);  #endif diff --git a/decoder/cdec.cc b/decoder/cdec.cc index a179e029..b5d768e8 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -37,7 +37,9 @@  #include "../vest/scorer.h"  #include "apply_fsa_models.h"  #include "program_options.h" -#include "cfg_format.h" +#include "cfg_options.h" + +CFGOptions cfg_options;  using namespace std;  using namespace std::tr1; @@ -62,7 +64,6 @@ void ConvertSV(const SparseVector<prob_t>& src, SparseVector<double>* trg) {      trg->set_value(it->first, it->second);  } -  inline string str(char const* name,po::variables_map const& conf) {    return conf[name].as<string>();  } @@ -106,9 +107,6 @@ void print_options(std::ostream &out,po::options_description const& opts) {    out << '"';  } - -CFGFormat cfgf; -  void InitCommandLine(int argc, char** argv, OracleBleu &ob, po::variables_map* confp) {    po::variables_map &conf=*confp;    po::options_description opts("Configuration options"); @@ -172,10 +170,8 @@ void InitCommandLine(int argc, char** argv, OracleBleu &ob, po::variables_map* c          ("forest_output,O",po::value<string>(),"Directory to write forests to")          ("minimal_forests,m","Write minimal forests (excludes Rule information). Such forests can be used for ML/MAP training, but not rescoring, etc.");    ob.AddOptions(&opts); -  po::options_description cfgo("CFG output options"); -  cfgo.add_options() -    ("cfg_output", po::value<string>(),"write final target CFG (before FSA rescorinn) to this file"); -  cfgf.AddOptions(&cfgo); +  po::options_description cfgo(cfg_options.description()); +  cfg_options.AddOptions(&cfgo);    po::options_description clo("Command line options");    clo.add_options()      ("config,c", po::value<vector<string> >(), "Configuration file(s) - latest has priority") @@ -555,6 +551,7 @@ int main(int argc, char** argv) {    const bool crf_uniform_empirical = conf.count("crf_uniform_empirical");    const bool get_oracle_forest = conf.count("get_oracle_forest"); +  cfg_options.Validate();    if (get_oracle_forest)      oracle.UseConf(conf); @@ -663,9 +660,11 @@ int main(int argc, char** argv) {      maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen);      HgCFG hgcfg(forest); -    if (conf.count("cfg_output")) { -      WriteFile o(str("cfg_output",conf)); -      hgcfg.GetCFG().Print(o.get(),cfgf); +    cfg_options.maybe_output(hgcfg); +    if (!cfg_options.cfg_output.empty()) { +      WriteFile o(cfg_options.cfg_output); +      CFG &cfg=hgcfg.GetCFG(); +      cfg.Print(o.get(),cfg_options.format);      }      if (!fsa_ffs.empty()) {        Timer t("Target FSA rescoring:"); diff --git a/decoder/cfg.cc b/decoder/cfg.cc index ace0ebb0..6a5d8342 100755 --- a/decoder/cfg.cc +++ b/decoder/cfg.cc @@ -1,9 +1,22 @@  #include "cfg.h"  #include "hg.h"  #include "cfg_format.h" +#include "cfg_binarize.h"  using namespace std; + +void CFG::Binarize(CFGBinarize const& b) { +  if (!b.Binarizing()) return; +  if (!b.bin_l2r) { +    assert(b.bin_l2r); +    return; +  } +  // l2r only so far: +  cerr << "Binarizing "<<b<<endl; +  //TODO. +} +  void CFG::Init(Hypergraph const& hg,bool target_side,bool copy_features,bool push_weights) {    uninit=false;    hg_=&hg; diff --git a/decoder/cfg.h b/decoder/cfg.h index e325c4cd..808c7a32 100755 --- a/decoder/cfg.h +++ b/decoder/cfg.h @@ -29,6 +29,7 @@  class Hypergraph;  class CFGFormat; // #include "cfg_format.h" +class CFGBinarize; // #include "cfg_binarize.h"  struct CFG {    typedef int RuleHandle; @@ -77,6 +78,7 @@ struct CFG {      swap(nts,o.nts);      swap(goal_nt,o.goal_nt);    } +  void Binarize(CFGBinarize const& binarize_options);  protected:    bool uninit;    Hypergraph const* hg_; // shouldn't be used for anything, esp. after binarization diff --git a/decoder/cfg_binarize.h b/decoder/cfg_binarize.h new file mode 100755 index 00000000..f76619d2 --- /dev/null +++ b/decoder/cfg_binarize.h @@ -0,0 +1,72 @@ +#ifndef CFG_BINARIZE_H +#define CFG_BINARIZE_H + +#include <iostream> + +/* +  binarization: decimate rhs of original rules until their rhs have been reduced to length 2 (or 1 if bin_unary).  also decimate rhs of newly binarized rules until length 2.  newly created rules are all binary (never unary/nullary). + +  bin_name_nts: nts[i].from will be initialized, including adding new names to TD + +  bin_l2r: right-branching (a (b c)) means suffixes are shared.  if requested, the only other option that matters is bin_unary + +  otherwise, greedy binarization: the pairs that are most frequent in the rules are binarized, one at a time.  this should be done efficiently: each pair has a count of and list of its left and right adjacent pair+count (or maybe a non-count collapsed list of adjacent instances).  this can be efficiently updated when a pair is chosen for replacement by a new virtual NT. + */ + +struct CFGBinarize { +  int bin_at; +  bool bin_l2r; +  bool bin_unary; +  bool bin_name_nts; +  template <class Opts> // template to support both printable_opts and boost nonprintable +  void AddOptions(Opts *opts) { +    opts->add_options() +      ("cfg_binarize_at", defaulted_value(&bin_at),"(if >0) binarize CFG rhs segments which appear at least this many times") +      ("cfg_binarize_unary", defaulted_value(&bin_unary),"if true, a rule-completing production A->BC may be binarized as A->U U->BC if U->BC would be used at least cfg_binarize_at times.") +      ("cfg_binarize_l2r", defaulted_value(&bin_l2r),"force left to right (a (b (c d))) binarization (ignore _at threshold)") +      ("cfg_binarize_name_nts", defaulted_value(&bin_name_nts),"create named virtual NT tokens e.g. 'A12+the' when binarizing 'B->[A12] the cat'") +    ; +  } +  void Validate() { +    if (bin_l2r) +      bin_at=0; +    if (bin_at>0&&!bin_l2r) { +      std::cerr<<"\nWARNING: greedy binarization not yet supported; using l2r (right branching) instead.\n"; +      bin_l2r=true; +    } +  } + +  bool Binarizing() const { +    return bin_l2r || bin_at>0; +  } +  void set_defaults() { +    bin_at=0; +    bin_unary=false; +    bin_name_nts=true; +    bin_l2r=false; +  } +  CFGBinarize() { set_defaults(); } +  void print(std::ostream &o) const { +    o<<'('; +    if (!Binarizing()) +      o << "Unbinarized"; +    else { +      if (bin_unary) +        o << "unary-sharing "; +      if (bin_l2r) +        o << "left->right"; +      else +        o << "greedy count>="<<bin_at; +      if (bin_name_nts) +        o << " named-NTs"; +    } +    o<<')'; +  } +  friend inline std::ostream &operator<<(std::ostream &o,CFGBinarize const& me) { +    me.print(o); return o; +  } + +}; + + +#endif diff --git a/decoder/cfg_format.h b/decoder/cfg_format.h index 169632a6..10361804 100755 --- a/decoder/cfg_format.h +++ b/decoder/cfg_format.h @@ -1,10 +1,10 @@  #ifndef CFG_FORMAT_H  #define CFG_FORMAT_H -#include <program_options.h>  #include <string>  #include "wordid.h"  #include "feature_vector.h" +#include "program_options.h"  struct CFGFormat {    bool identity_scfg; @@ -18,8 +18,8 @@ struct CFGFormat {    std::string partsep;    template <class Opts> // template to support both printable_opts and boost nonprintable    void AddOptions(Opts *opts) { -    using namespace boost::program_options; -    using namespace std; +    //using namespace boost::program_options; +    //using namespace std;      opts->add_options()        ("identity_scfg",defaulted_value(&identity_scfg),"output an identity SCFG: add an identity target side - '[X12] ||| [X13,1] a ||| [1] a ||| feat= ...' - the redundant target '[1] a |||' is omitted otherwise.")        ("features",defaulted_value(&features),"print the CFG feature vector") @@ -31,7 +31,7 @@ struct CFGFormat {        ("nt_span",defaulted_value(&nt_span),"prefix A(i,j) for NT coming from hypergraph node with category A on span [i,j).  this is after --nt_prefix if any")        ;    } - +  void Validate() {  }    template<class CFG>    void print_source_nt(std::ostream &o,CFG const&cfg,int id,int position=1) const {      o<<'['; @@ -105,4 +105,5 @@ struct CFGFormat {  }; +  #endif diff --git a/decoder/cfg_options.h b/decoder/cfg_options.h new file mode 100755 index 00000000..bc7fed5f --- /dev/null +++ b/decoder/cfg_options.h @@ -0,0 +1,49 @@ +#ifndef CFG_OPTIONS_H +#define CFG_OPTIONS_H + +#include "hg_cfg.h" +#include "cfg_format.h" +#include "cfg_binarize.h" +//#include "program_options.h" + +struct CFGOptions { +  CFGFormat format; +  CFGBinarize binarize; +  std::string cfg_output; +  void set_defaults() { +    format.set_defaults(); +    binarize.set_defaults(); +    cfg_output=""; +  } +  CFGOptions() { set_defaults(); } +  template <class Opts> // template to support both printable_opts and boost nonprintable +  void AddOptions(Opts *opts) { +    opts->add_options() +      ("cfg_output", defaulted_value(&cfg_output),"write final target CFG (before FSA rescorinn) to this file") +    ; +    binarize.AddOptions(opts); +    format.AddOptions(opts); +  } +  void Validate() { +    format.Validate(); +    binarize.Validate(); +  } +  char const* description() const { +    return "CFG output options"; +  } +  void maybe_output(HgCFG &hgcfg) { +    if (cfg_output.empty()) return; +    WriteFile o(cfg_output); +    maybe_binarize(hgcfg); +    hgcfg.GetCFG().Print(o.get(),format); +  } +  void maybe_binarize(HgCFG &hgcfg) { +    if (hgcfg.binarized) return; +    hgcfg.GetCFG().Binarize(binarize); +    hgcfg.binarized=true; +  } + +}; + + +#endif diff --git a/decoder/hg_cfg.h b/decoder/hg_cfg.h new file mode 100755 index 00000000..0a3eb53c --- /dev/null +++ b/decoder/hg_cfg.h @@ -0,0 +1,42 @@ +#ifndef HG_CFG_H +#define HG_CFG_H + +#include "cfg.h" + +class Hypergraph; + +// in case you might want the CFG whether or not you apply FSA models: +struct HgCFG { +  HgCFG(Hypergraph const& ih) : ih(ih) { have_cfg=binarized=false; } +  Hypergraph const& ih; +  CFG cfg; +  bool have_cfg; +  void InitCFG(CFG &to) { +    to.Init(ih,true,false,true); +  } +  bool binarized; +  CFG &GetCFG() +  { +    if (!have_cfg) { +      have_cfg=true; +      InitCFG(cfg); +    } +    return cfg; +  } +  void GiveCFG(CFG &to) { +    if (!have_cfg) +      InitCFG(to); +    else { +      have_cfg=false; +      to.Clear(); +      to.Swap(cfg); +    } +  } +  CFG const& GetCFG() const { +    assert(have_cfg); +    return cfg; +  } +}; + + +#endif | 
