diff options
Diffstat (limited to 'gi/pf/learn_cfg.cc')
-rw-r--r-- | gi/pf/learn_cfg.cc | 46 |
1 files changed, 40 insertions, 6 deletions
diff --git a/gi/pf/learn_cfg.cc b/gi/pf/learn_cfg.cc index 6e574035..b2ca029a 100644 --- a/gi/pf/learn_cfg.cc +++ b/gi/pf/learn_cfg.cc @@ -30,6 +30,7 @@ vector<int> nt_id_to_index; static unsigned kMAX_RULE_SIZE = 0; static unsigned kMAX_ARITY = 0; static bool kALLOW_MIXED = true; // allow rules with mixed terminals and NTs +static bool kHIERARCHICAL_PRIOR = false; void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); @@ -40,11 +41,12 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("max_arity,a", po::value<unsigned>()->default_value(0), "Maximum number of nonterminals in a rule (0 for unlimited)") ("no_mixed_rules,M", "Do not mix terminals and nonterminals in a rule RHS") ("nonterminals,n", po::value<unsigned>()->default_value(1), "Size of nonterminal vocabulary") + ("hierarchical_prior,h", "Use hierarchical prior") ("random_seed,S",po::value<uint32_t>(), "Random seed"); po::options_description clo("Command line options"); clo.add_options() ("config", po::value<string>(), "Configuration file") - ("help,h", "Print this help message and exit"); + ("help", "Print this help message and exit"); po::options_description dconfig_options, dcmdline_options; dconfig_options.add(opts); dcmdline_options.add(opts).add(clo); @@ -119,19 +121,35 @@ struct BaseRuleModel { }; struct HieroLMModel { - explicit HieroLMModel(unsigned vocab_size, unsigned num_nts = 1) : p0(vocab_size, num_nts), nts(num_nts, CCRP<TRule>(1,1,1,1)) {} + explicit HieroLMModel(unsigned vocab_size, unsigned num_nts = 1) : + base(vocab_size, num_nts), + q0(1,1,1,1), + nts(num_nts, CCRP<TRule>(1,1,1,1)) {} prob_t Prob(const TRule& r) const { return nts[nt_id_to_index[-r.lhs_]].probT<prob_t>(r, p0(r)); } + inline prob_t p0(const TRule& r) const { + if (kHIERARCHICAL_PRIOR) + return q0.probT<prob_t>(r, base(r)); + else + return base(r); + } + int Increment(const TRule& r, MT19937* rng) { - return nts[nt_id_to_index[-r.lhs_]].incrementT<prob_t>(r, p0(r), rng); + const int delta = nts[nt_id_to_index[-r.lhs_]].incrementT<prob_t>(r, p0(r), rng); + if (kHIERARCHICAL_PRIOR && delta) + q0.incrementT<prob_t>(r, base(r), rng); + return delta; // return x.increment(r); } int Decrement(const TRule& r, MT19937* rng) { - return nts[nt_id_to_index[-r.lhs_]].decrement(r, rng); + const int delta = nts[nt_id_to_index[-r.lhs_]].decrement(r, rng); + if (kHIERARCHICAL_PRIOR && delta) + q0.decrement(r, rng); + return delta; //return x.decrement(r); } @@ -146,18 +164,32 @@ struct HieroLMModel { p *= tp; } } + if (kHIERARCHICAL_PRIOR) { + prob_t q; q.logeq(q0.log_crp_prob()); + p *= q; + for (CCRP<TRule>::const_iterator it = q0.begin(); it != q0.end(); ++it) { + prob_t tp = base(it->first); + tp.poweq(it->second.table_counts_.size()); + p *= tp; + } + } //for (CCRP_OneTable<TRule>::const_iterator it = x.begin(); it != x.end(); ++it) - // p *= p0(it->first); + // p *= base(it->first); return p; } void ResampleHyperparameters(MT19937* rng) { for (unsigned i = 0; i < nts.size(); ++i) nts[i].resample_hyperparameters(rng); + if (kHIERARCHICAL_PRIOR) { + q0.resample_hyperparameters(rng); + cerr << "[base d=" << q0.discount() << ", alpha=" << q0.discount() << "]"; + } cerr << " d=" << nts[0].discount() << ", alpha=" << nts[0].concentration() << endl; } - const BaseRuleModel p0; + const BaseRuleModel base; + CCRP<TRule> q0; vector<CCRP<TRule> > nts; //CCRP_OneTable<TRule> x; }; @@ -316,6 +348,8 @@ int main(int argc, char** argv) { } kALLOW_MIXED = !conf.count("no_mixed_rules"); + kHIERARCHICAL_PRIOR = conf.count("hierarchical_prior"); + if (conf.count("random_seed")) prng.reset(new MT19937(conf["random_seed"].as<uint32_t>())); else |