diff options
Diffstat (limited to 'gi/pf')
| -rw-r--r-- | gi/pf/learn_cfg.cc | 46 | 
1 files changed, 40 insertions, 6 deletions
diff --git a/gi/pf/learn_cfg.cc b/gi/pf/learn_cfg.cc index 6e574035..b2ca029a 100644 --- a/gi/pf/learn_cfg.cc +++ b/gi/pf/learn_cfg.cc @@ -30,6 +30,7 @@ vector<int> nt_id_to_index;  static unsigned kMAX_RULE_SIZE = 0;  static unsigned kMAX_ARITY = 0;  static bool kALLOW_MIXED = true;  // allow rules with mixed terminals and NTs +static bool kHIERARCHICAL_PRIOR = false;  void InitCommandLine(int argc, char** argv, po::variables_map* conf) {    po::options_description opts("Configuration options"); @@ -40,11 +41,12 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {          ("max_arity,a", po::value<unsigned>()->default_value(0), "Maximum number of nonterminals in a rule (0 for unlimited)")          ("no_mixed_rules,M", "Do not mix terminals and nonterminals in a rule RHS")          ("nonterminals,n", po::value<unsigned>()->default_value(1), "Size of nonterminal vocabulary") +        ("hierarchical_prior,h", "Use hierarchical prior")          ("random_seed,S",po::value<uint32_t>(), "Random seed");    po::options_description clo("Command line options");    clo.add_options()          ("config", po::value<string>(), "Configuration file") -        ("help,h", "Print this help message and exit"); +        ("help", "Print this help message and exit");    po::options_description dconfig_options, dcmdline_options;    dconfig_options.add(opts);    dcmdline_options.add(opts).add(clo); @@ -119,19 +121,35 @@ struct BaseRuleModel {  };  struct HieroLMModel { -  explicit HieroLMModel(unsigned vocab_size, unsigned num_nts = 1) : p0(vocab_size, num_nts), nts(num_nts, CCRP<TRule>(1,1,1,1)) {} +  explicit HieroLMModel(unsigned vocab_size, unsigned num_nts = 1) : +      base(vocab_size, num_nts), +      q0(1,1,1,1), +      nts(num_nts, CCRP<TRule>(1,1,1,1)) {}    prob_t Prob(const TRule& r) const {      return nts[nt_id_to_index[-r.lhs_]].probT<prob_t>(r, p0(r));    } +  inline prob_t p0(const TRule& r) const { +    if (kHIERARCHICAL_PRIOR) +      return q0.probT<prob_t>(r, base(r)); +    else +      return base(r); +  } +    int Increment(const TRule& r, MT19937* rng) { -    return nts[nt_id_to_index[-r.lhs_]].incrementT<prob_t>(r, p0(r), rng); +    const int delta = nts[nt_id_to_index[-r.lhs_]].incrementT<prob_t>(r, p0(r), rng); +    if (kHIERARCHICAL_PRIOR && delta) +      q0.incrementT<prob_t>(r, base(r), rng); +    return delta;      // return x.increment(r);    }    int Decrement(const TRule& r, MT19937* rng) { -    return nts[nt_id_to_index[-r.lhs_]].decrement(r, rng); +    const int delta = nts[nt_id_to_index[-r.lhs_]].decrement(r, rng); +    if (kHIERARCHICAL_PRIOR && delta) +      q0.decrement(r, rng); +    return delta;      //return x.decrement(r);    } @@ -146,18 +164,32 @@ struct HieroLMModel {          p *= tp;        }      } +    if (kHIERARCHICAL_PRIOR) { +      prob_t q; q.logeq(q0.log_crp_prob()); +      p *= q; +      for (CCRP<TRule>::const_iterator it = q0.begin(); it != q0.end(); ++it) { +        prob_t tp = base(it->first); +        tp.poweq(it->second.table_counts_.size()); +        p *= tp; +      } +    }      //for (CCRP_OneTable<TRule>::const_iterator it = x.begin(); it != x.end(); ++it) -    //    p *= p0(it->first); +    //    p *= base(it->first);      return p;    }    void ResampleHyperparameters(MT19937* rng) {      for (unsigned i = 0; i < nts.size(); ++i)        nts[i].resample_hyperparameters(rng); +    if (kHIERARCHICAL_PRIOR) { +      q0.resample_hyperparameters(rng); +      cerr << "[base d=" << q0.discount() << ", alpha=" << q0.discount() << "]"; +    }      cerr << " d=" << nts[0].discount() << ", alpha=" << nts[0].concentration() << endl;    } -  const BaseRuleModel p0; +  const BaseRuleModel base; +  CCRP<TRule> q0;    vector<CCRP<TRule> > nts;    //CCRP_OneTable<TRule> x;  }; @@ -316,6 +348,8 @@ int main(int argc, char** argv) {    }    kALLOW_MIXED = !conf.count("no_mixed_rules"); +  kHIERARCHICAL_PRIOR = conf.count("hierarchical_prior"); +    if (conf.count("random_seed"))      prng.reset(new MT19937(conf["random_seed"].as<uint32_t>()));    else  | 
