summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorgraehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-08-10 23:49:33 +0000
committergraehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-08-10 23:49:33 +0000
commit46dd30c7d4da68b83ebfd5975153521ee237311f (patch)
treeaba07e1f24639c56b341393880b67987b08854da
parent0fec6f37266ecbaf6f72d2f3d7e652c78004af17 (diff)
CFG binarize opts
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@503 ec762483-ff6d-05da-a07a-a48fb63a330f
-rwxr-xr-xdecoder/apply_fsa_models.cc11
-rwxr-xr-xdecoder/apply_fsa_models.h43
-rw-r--r--decoder/cdec.cc23
-rwxr-xr-xdecoder/cfg.cc13
-rwxr-xr-xdecoder/cfg.h2
-rwxr-xr-xdecoder/cfg_binarize.h72
-rwxr-xr-xdecoder/cfg_format.h9
-rwxr-xr-xdecoder/cfg_options.h49
-rwxr-xr-xdecoder/hg_cfg.h42
9 files changed, 208 insertions, 56 deletions
diff --git a/decoder/apply_fsa_models.cc b/decoder/apply_fsa_models.cc
index 31b2002d..3a1e8050 100755
--- a/decoder/apply_fsa_models.cc
+++ b/decoder/apply_fsa_models.cc
@@ -8,6 +8,7 @@
#include <stdexcept>
#include <cassert>
#include "cfg.h"
+#include "hg_cfg.h"
using namespace std;
@@ -118,3 +119,13 @@ int ApplyFsaBy::BottomUpAlgorithm() const {
:IntersectionConfiguration::FULL;
}
+// Convenience overload taking a plain hypergraph: wraps ih in a temporary
+// HgCFG and forwards every other argument unchanged to the HgCFG overload.
+// pre: the input is weighted by `weights` (except with fsa featval=0 before this)
+void ApplyFsaModels(Hypergraph const& ih,
+                    const SentenceMetadata& smeta,
+                    const FsaFeatureFunction& fsa,
+                    DenseWeightVector const& weights,
+                    ApplyFsaBy const& cfg,
+                    Hypergraph* out)
+{
+  HgCFG hg_or_cfg_in(ih);  // lives only for the duration of this call
+  ApplyFsaModels(hg_or_cfg_in, smeta, fsa, weights, cfg, out);
+}
diff --git a/decoder/apply_fsa_models.h b/decoder/apply_fsa_models.h
index 0227664a..5120fb4e 100755
--- a/decoder/apply_fsa_models.h
+++ b/decoder/apply_fsa_models.h
@@ -4,11 +4,11 @@
#include <string>
#include <iostream>
#include "feature_vector.h"
-#include "cfg.h"
struct FsaFeatureFunction;
struct Hypergraph;
struct SentenceMetadata;
+struct HgCFG;
struct ApplyFsaBy {
enum {
@@ -36,40 +36,6 @@ struct ApplyFsaBy {
static std::string all_names(); // space separated
};
-// in case you might want the CFG whether or not you apply FSA models:
-struct HgCFG {
- HgCFG(Hypergraph const& ih) : ih(ih) { have_cfg=false; }
- Hypergraph const& ih;
- CFG cfg;
- bool have_cfg;
- void InitCFG(CFG &to) {
- to.Init(ih,true,false,true);
- }
-
- CFG &GetCFG()
- {
- if (!have_cfg) {
- have_cfg=true;
- InitCFG(cfg);
- }
- return cfg;
- }
- void GiveCFG(CFG &to) {
- if (!have_cfg)
- InitCFG(to);
- else {
- have_cfg=false;
- to.Clear();
- to.Swap(cfg);
- }
- }
- CFG const& GetCFG() const {
- assert(have_cfg);
- return cfg;
- }
-};
-
-
void ApplyFsaModels(HgCFG &hg_or_cfg_in,
const SentenceMetadata& smeta,
const FsaFeatureFunction& fsa,
@@ -77,15 +43,12 @@ void ApplyFsaModels(HgCFG &hg_or_cfg_in,
ApplyFsaBy const& cfg,
Hypergraph* out);
-inline void ApplyFsaModels(Hypergraph const& ih,
+void ApplyFsaModels(Hypergraph const& ih,
const SentenceMetadata& smeta,
const FsaFeatureFunction& fsa,
DenseWeightVector const& weights, // pre: in is weighted by these (except with fsa featval=0 before this)
ApplyFsaBy const& cfg,
- Hypergraph* out) {
- HgCFG i(ih);
- ApplyFsaModels(i,smeta,fsa,weights,cfg,out);
-}
+ Hypergraph* out);
#endif
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index a179e029..b5d768e8 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -37,7 +37,9 @@
#include "../vest/scorer.h"
#include "apply_fsa_models.h"
#include "program_options.h"
-#include "cfg_format.h"
+#include "cfg_options.h"
+
+CFGOptions cfg_options;
using namespace std;
using namespace std::tr1;
@@ -62,7 +64,6 @@ void ConvertSV(const SparseVector<prob_t>& src, SparseVector<double>* trg) {
trg->set_value(it->first, it->second);
}
-
inline string str(char const* name,po::variables_map const& conf) {
return conf[name].as<string>();
}
@@ -106,9 +107,6 @@ void print_options(std::ostream &out,po::options_description const& opts) {
out << '"';
}
-
-CFGFormat cfgf;
-
void InitCommandLine(int argc, char** argv, OracleBleu &ob, po::variables_map* confp) {
po::variables_map &conf=*confp;
po::options_description opts("Configuration options");
@@ -172,10 +170,8 @@ void InitCommandLine(int argc, char** argv, OracleBleu &ob, po::variables_map* c
("forest_output,O",po::value<string>(),"Directory to write forests to")
("minimal_forests,m","Write minimal forests (excludes Rule information). Such forests can be used for ML/MAP training, but not rescoring, etc.");
ob.AddOptions(&opts);
- po::options_description cfgo("CFG output options");
- cfgo.add_options()
- ("cfg_output", po::value<string>(),"write final target CFG (before FSA rescorinn) to this file");
- cfgf.AddOptions(&cfgo);
+ po::options_description cfgo(cfg_options.description());
+ cfg_options.AddOptions(&cfgo);
po::options_description clo("Command line options");
clo.add_options()
("config,c", po::value<vector<string> >(), "Configuration file(s) - latest has priority")
@@ -555,6 +551,7 @@ int main(int argc, char** argv) {
const bool crf_uniform_empirical = conf.count("crf_uniform_empirical");
const bool get_oracle_forest = conf.count("get_oracle_forest");
+ cfg_options.Validate();
if (get_oracle_forest)
oracle.UseConf(conf);
@@ -663,9 +660,11 @@ int main(int argc, char** argv) {
maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen);
HgCFG hgcfg(forest);
- if (conf.count("cfg_output")) {
- WriteFile o(str("cfg_output",conf));
- hgcfg.GetCFG().Print(o.get(),cfgf);
+ cfg_options.maybe_output(hgcfg);
+ if (!cfg_options.cfg_output.empty()) {
+ WriteFile o(cfg_options.cfg_output);
+ CFG &cfg=hgcfg.GetCFG();
+ cfg.Print(o.get(),cfg_options.format);
}
if (!fsa_ffs.empty()) {
Timer t("Target FSA rescoring:");
diff --git a/decoder/cfg.cc b/decoder/cfg.cc
index ace0ebb0..6a5d8342 100755
--- a/decoder/cfg.cc
+++ b/decoder/cfg.cc
@@ -1,9 +1,22 @@
#include "cfg.h"
#include "hg.h"
#include "cfg_format.h"
+#include "cfg_binarize.h"
using namespace std;
+
+// Binarizes this CFG according to `b`. Does nothing when b.Binarizing() is
+// false. Only the l2r mode is accepted so far: any other mode trips the
+// assert (and returns without doing anything when asserts are compiled out).
+// The binarization itself is still TODO; for now the chosen options are just
+// reported on stderr.
+void CFG::Binarize(CFGBinarize const& b) {
+  if (!b.Binarizing())
+    return;
+  const bool l2r = b.bin_l2r;
+  assert(l2r);  // l2r only so far
+  if (!l2r)
+    return;
+  cerr << "Binarizing " << b << endl;
+  //TODO.
+}
+
void CFG::Init(Hypergraph const& hg,bool target_side,bool copy_features,bool push_weights) {
uninit=false;
hg_=&hg;
diff --git a/decoder/cfg.h b/decoder/cfg.h
index e325c4cd..808c7a32 100755
--- a/decoder/cfg.h
+++ b/decoder/cfg.h
@@ -29,6 +29,7 @@
class Hypergraph;
class CFGFormat; // #include "cfg_format.h"
+class CFGBinarize; // #include "cfg_binarize.h"
struct CFG {
typedef int RuleHandle;
@@ -77,6 +78,7 @@ struct CFG {
swap(nts,o.nts);
swap(goal_nt,o.goal_nt);
}
+ void Binarize(CFGBinarize const& binarize_options);
protected:
bool uninit;
Hypergraph const* hg_; // shouldn't be used for anything, esp. after binarization
diff --git a/decoder/cfg_binarize.h b/decoder/cfg_binarize.h
new file mode 100755
index 00000000..f76619d2
--- /dev/null
+++ b/decoder/cfg_binarize.h
@@ -0,0 +1,72 @@
+#ifndef CFG_BINARIZE_H
+#define CFG_BINARIZE_H
+
+#include <iostream>
+
+/*
+ binarization: decimate rhs of original rules until their rhs have been reduced to length 2 (or 1 if bin_unary). also decimate rhs of newly binarized rules until length 2. newly created rules are all binary (never unary/nullary).
+
+ bin_name_nts: nts[i].from will be initialized, including adding new names to TD
+
+ bin_l2r: right-branching (a (b c)) means suffixes are shared. if requested, the only other option that matters is bin_unary
+
+ otherwise, greedy binarization: the pairs that are most frequent in the rules are binarized, one at a time. this should be done efficiently: each pair keeps a count, plus a list of its left- and right-adjacent pairs with their counts (or maybe a collapsed list of adjacent instances, without counts). this can be efficiently updated when a pair is chosen for replacement by a new virtual NT.
+ */
+
+// Options controlling CFG binarization, plus their command-line registration,
+// validation and printing.
+struct CFGBinarize {
+  int bin_at;         // --cfg_binarize_at: count threshold (>0 enables greedy mode)
+  bool bin_l2r;       // --cfg_binarize_l2r: force left-to-right binarization
+  bool bin_unary;     // --cfg_binarize_unary
+  bool bin_name_nts;  // --cfg_binarize_name_nts
+
+  CFGBinarize() { set_defaults(); }
+
+  // Resets every option: no binarization, no unary sharing, named NTs on.
+  void set_defaults() {
+    bin_l2r=false;
+    bin_name_nts=true;
+    bin_unary=false;
+    bin_at=0;
+  }
+
+  // Registers the cfg_binarize_* options on *opts.
+  template <class Opts> // template to support both printable_opts and boost nonprintable
+  void AddOptions(Opts *opts) {
+    opts->add_options()
+      ("cfg_binarize_at", defaulted_value(&bin_at),"(if >0) binarize CFG rhs segments which appear at least this many times")
+      ("cfg_binarize_unary", defaulted_value(&bin_unary),"if true, a rule-completing production A->BC may be binarized as A->U U->BC if U->BC would be used at least cfg_binarize_at times.")
+      ("cfg_binarize_l2r", defaulted_value(&bin_l2r),"force left to right (a (b (c d))) binarization (ignore _at threshold)")
+      ("cfg_binarize_name_nts", defaulted_value(&bin_name_nts),"create named virtual NT tokens e.g. 'A12+the' when binarizing 'B->[A12] the cat'")
+      ;
+  }
+
+  // Normalizes the options after parsing: l2r overrides (zeroes) the count
+  // threshold; a bare threshold falls back to l2r with a warning, because
+  // greedy binarization is not implemented yet.
+  void Validate() {
+    if (bin_l2r) {
+      bin_at=0;
+      return;
+    }
+    if (bin_at>0) {
+      std::cerr<<"\nWARNING: greedy binarization not yet supported; using l2r (right branching) instead.\n";
+      bin_l2r=true;
+    }
+  }
+
+  // True when any binarization mode has been requested.
+  bool Binarizing() const {
+    return bin_at>0 || bin_l2r;
+  }
+
+  // Writes a parenthesized, human-readable summary of the options to o.
+  void print(std::ostream &o) const {
+    o<<'(';
+    if (Binarizing()) {
+      if (bin_unary)
+        o << "unary-sharing ";
+      if (!bin_l2r)
+        o << "greedy count>="<<bin_at;
+      else
+        o << "left->right";
+      if (bin_name_nts)
+        o << " named-NTs";
+    } else {
+      o << "Unbinarized";
+    }
+    o<<')';
+  }
+
+  friend inline std::ostream &operator<<(std::ostream &o,CFGBinarize const& me) {
+    me.print(o);
+    return o;
+  }
+};
+
+
+#endif
diff --git a/decoder/cfg_format.h b/decoder/cfg_format.h
index 169632a6..10361804 100755
--- a/decoder/cfg_format.h
+++ b/decoder/cfg_format.h
@@ -1,10 +1,10 @@
#ifndef CFG_FORMAT_H
#define CFG_FORMAT_H
-#include <program_options.h>
#include <string>
#include "wordid.h"
#include "feature_vector.h"
+#include "program_options.h"
struct CFGFormat {
bool identity_scfg;
@@ -18,8 +18,8 @@ struct CFGFormat {
std::string partsep;
template <class Opts> // template to support both printable_opts and boost nonprintable
void AddOptions(Opts *opts) {
- using namespace boost::program_options;
- using namespace std;
+ //using namespace boost::program_options;
+ //using namespace std;
opts->add_options()
("identity_scfg",defaulted_value(&identity_scfg),"output an identity SCFG: add an identity target side - '[X12] ||| [X13,1] a ||| [1] a ||| feat= ...' - the redundant target '[1] a |||' is omitted otherwise.")
("features",defaulted_value(&features),"print the CFG feature vector")
@@ -31,7 +31,7 @@ struct CFGFormat {
("nt_span",defaulted_value(&nt_span),"prefix A(i,j) for NT coming from hypergraph node with category A on span [i,j). this is after --nt_prefix if any")
;
}
-
+ void Validate() { }
template<class CFG>
void print_source_nt(std::ostream &o,CFG const&cfg,int id,int position=1) const {
o<<'[';
@@ -105,4 +105,5 @@ struct CFGFormat {
};
+
#endif
diff --git a/decoder/cfg_options.h b/decoder/cfg_options.h
new file mode 100755
index 00000000..bc7fed5f
--- /dev/null
+++ b/decoder/cfg_options.h
@@ -0,0 +1,49 @@
+#ifndef CFG_OPTIONS_H
+#define CFG_OPTIONS_H
+
+#include "hg_cfg.h"
+#include "cfg_format.h"
+#include "cfg_binarize.h"
+//#include "program_options.h"
+
+// Bundles every CFG-related command-line option: where to write the CFG
+// (cfg_output), how to print it (format) and whether/how to binarize it first
+// (binarize).
+struct CFGOptions {
+  CFGFormat format;
+  CFGBinarize binarize;
+  std::string cfg_output;  // output file name; empty means "do not write the CFG"
+
+  // Resets format, binarize and cfg_output to their defaults.
+  void set_defaults() {
+    format.set_defaults();
+    binarize.set_defaults();
+    cfg_output.clear();
+  }
+  CFGOptions() { set_defaults(); }
+
+  // Registers --cfg_output plus the binarize and format options on *opts.
+  template <class Opts> // template to support both printable_opts and boost nonprintable
+  void AddOptions(Opts *opts) {
+    opts->add_options()
+      ("cfg_output", defaulted_value(&cfg_output),"write final target CFG (before FSA rescoring) to this file")
+      ;
+    binarize.AddOptions(opts);
+    format.AddOptions(opts);
+  }
+
+  // Validates the sub-option groups; call once after the command line is parsed.
+  void Validate() {
+    format.Validate();
+    binarize.Validate();
+  }
+
+  // Title for the options_description these options are added to.
+  char const* description() const {
+    return "CFG output options";
+  }
+
+  // If an output file was requested: binarizes hgcfg's CFG (at most once) and
+  // prints it to that file using `format`. No-op when cfg_output is empty.
+  void maybe_output(HgCFG &hgcfg) {
+    if (cfg_output.empty()) return;
+    WriteFile o(cfg_output);
+    maybe_binarize(hgcfg);
+    hgcfg.GetCFG().Print(o.get(),format);
+  }
+
+  // Binarizes hgcfg's CFG per `binarize` unless hgcfg is already flagged as
+  // binarized; sets the flag afterwards so repeated calls are cheap.
+  void maybe_binarize(HgCFG &hgcfg) {
+    if (hgcfg.binarized) return;
+    hgcfg.GetCFG().Binarize(binarize);
+    hgcfg.binarized=true;
+  }
+
+};
+
+
+#endif
diff --git a/decoder/hg_cfg.h b/decoder/hg_cfg.h
new file mode 100755
index 00000000..0a3eb53c
--- /dev/null
+++ b/decoder/hg_cfg.h
@@ -0,0 +1,42 @@
+#ifndef HG_CFG_H
+#define HG_CFG_H
+
+#include "cfg.h"
+
+class Hypergraph;
+
+// Lazily builds and caches the CFG for a hypergraph, in case you might want
+// the CFG whether or not you apply FSA models.
+struct HgCFG {
+  // Stores a reference to ih (no copy); the CFG is not built until requested.
+  HgCFG(Hypergraph const& ih) : ih(ih), have_cfg(false), binarized(false) { }
+  Hypergraph const& ih;
+  CFG cfg;
+  bool have_cfg;  // true while `cfg` holds a CFG built by GetCFG()
+  // Initializes `to` from ih with target_side=true, copy_features=false,
+  // push_weights=true.
+  void InitCFG(CFG &to) {
+    to.Init(ih,true,false,true);
+  }
+  bool binarized;  // true once the cached `cfg` has been binarized
+  // Returns the cached CFG, building it from ih on first use.
+  CFG &GetCFG()
+  {
+    if (!have_cfg) {
+      have_cfg=true;
+      InitCFG(cfg);
+    }
+    return cfg;
+  }
+  // Transfers a CFG into `to`: builds it directly in `to` when nothing is
+  // cached, otherwise swaps the cached one out. Afterwards this object holds
+  // no CFG, so `binarized` is reset too; otherwise a later GetCFG() would
+  // rebuild an unbinarized CFG that is still flagged as binarized.
+  void GiveCFG(CFG &to) {
+    if (!have_cfg)
+      InitCFG(to);
+    else {
+      have_cfg=false;
+      to.Clear();
+      to.Swap(cfg);
+    }
+    binarized=false;
+  }
+  // Const access; the CFG must already have been built via non-const GetCFG().
+  CFG const& GetCFG() const {
+    assert(have_cfg);
+    return cfg;
+  }
+};
+
+
+#endif