summaryrefslogtreecommitdiff
path: root/gi
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-02-28 00:47:20 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2012-02-28 00:47:20 -0500
commit5c63dae2edca73b2fa1c668d708b8b0c3ff1f7dc (patch)
tree5c6b57e2bb6fe9d57654addbf7c1d9882106bfdb /gi
parentc9fecc7613c075dc2e998479a9d39a538807e609 (diff)
optional hierarchical prior
Diffstat (limited to 'gi')
-rw-r--r--gi/pf/learn_cfg.cc46
1 files changed, 40 insertions, 6 deletions
diff --git a/gi/pf/learn_cfg.cc b/gi/pf/learn_cfg.cc
index 6e574035..b2ca029a 100644
--- a/gi/pf/learn_cfg.cc
+++ b/gi/pf/learn_cfg.cc
@@ -30,6 +30,7 @@ vector<int> nt_id_to_index;
static unsigned kMAX_RULE_SIZE = 0;
static unsigned kMAX_ARITY = 0;
static bool kALLOW_MIXED = true; // allow rules with mixed terminals and NTs
+static bool kHIERARCHICAL_PRIOR = false;
void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
po::options_description opts("Configuration options");
@@ -40,11 +41,12 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("max_arity,a", po::value<unsigned>()->default_value(0), "Maximum number of nonterminals in a rule (0 for unlimited)")
("no_mixed_rules,M", "Do not mix terminals and nonterminals in a rule RHS")
("nonterminals,n", po::value<unsigned>()->default_value(1), "Size of nonterminal vocabulary")
+ ("hierarchical_prior,h", "Use hierarchical prior")
("random_seed,S",po::value<uint32_t>(), "Random seed");
po::options_description clo("Command line options");
clo.add_options()
("config", po::value<string>(), "Configuration file")
- ("help,h", "Print this help message and exit");
+ ("help", "Print this help message and exit");
po::options_description dconfig_options, dcmdline_options;
dconfig_options.add(opts);
dcmdline_options.add(opts).add(clo);
@@ -119,19 +121,35 @@ struct BaseRuleModel {
};
struct HieroLMModel {
- explicit HieroLMModel(unsigned vocab_size, unsigned num_nts = 1) : p0(vocab_size, num_nts), nts(num_nts, CCRP<TRule>(1,1,1,1)) {}
+ explicit HieroLMModel(unsigned vocab_size, unsigned num_nts = 1) :
+ base(vocab_size, num_nts),
+ q0(1,1,1,1),
+ nts(num_nts, CCRP<TRule>(1,1,1,1)) {}
prob_t Prob(const TRule& r) const {
return nts[nt_id_to_index[-r.lhs_]].probT<prob_t>(r, p0(r));
}
+ inline prob_t p0(const TRule& r) const {
+ if (kHIERARCHICAL_PRIOR)
+ return q0.probT<prob_t>(r, base(r));
+ else
+ return base(r);
+ }
+
int Increment(const TRule& r, MT19937* rng) {
- return nts[nt_id_to_index[-r.lhs_]].incrementT<prob_t>(r, p0(r), rng);
+ const int delta = nts[nt_id_to_index[-r.lhs_]].incrementT<prob_t>(r, p0(r), rng);
+ if (kHIERARCHICAL_PRIOR && delta)
+ q0.incrementT<prob_t>(r, base(r), rng);
+ return delta;
// return x.increment(r);
}
int Decrement(const TRule& r, MT19937* rng) {
- return nts[nt_id_to_index[-r.lhs_]].decrement(r, rng);
+ const int delta = nts[nt_id_to_index[-r.lhs_]].decrement(r, rng);
+ if (kHIERARCHICAL_PRIOR && delta)
+ q0.decrement(r, rng);
+ return delta;
//return x.decrement(r);
}
@@ -146,18 +164,32 @@ struct HieroLMModel {
p *= tp;
}
}
+ if (kHIERARCHICAL_PRIOR) {
+ prob_t q; q.logeq(q0.log_crp_prob());
+ p *= q;
+ for (CCRP<TRule>::const_iterator it = q0.begin(); it != q0.end(); ++it) {
+ prob_t tp = base(it->first);
+ tp.poweq(it->second.table_counts_.size());
+ p *= tp;
+ }
+ }
//for (CCRP_OneTable<TRule>::const_iterator it = x.begin(); it != x.end(); ++it)
- // p *= p0(it->first);
+ // p *= base(it->first);
return p;
}
void ResampleHyperparameters(MT19937* rng) {
for (unsigned i = 0; i < nts.size(); ++i)
nts[i].resample_hyperparameters(rng);
+ if (kHIERARCHICAL_PRIOR) {
+ q0.resample_hyperparameters(rng);
+ cerr << "[base d=" << q0.discount() << ", alpha=" << q0.discount() << "]";
+ }
cerr << " d=" << nts[0].discount() << ", alpha=" << nts[0].concentration() << endl;
}
- const BaseRuleModel p0;
+ const BaseRuleModel base;
+ CCRP<TRule> q0;
vector<CCRP<TRule> > nts;
//CCRP_OneTable<TRule> x;
};
@@ -316,6 +348,8 @@ int main(int argc, char** argv) {
}
kALLOW_MIXED = !conf.count("no_mixed_rules");
+ kHIERARCHICAL_PRIOR = conf.count("hierarchical_prior");
+
if (conf.count("random_seed"))
prng.reset(new MT19937(conf["random_seed"].as<uint32_t>()));
else