From 1da72bb6211f196a210302f18c1ef020c0c84f12 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 5 Jul 2011 23:19:43 -0400 Subject: fast phrasinator that uses DPs instead of PYPs --- phrasinator/Makefile.am | 8 +- phrasinator/ccrp_nt.h | 154 +++++++++++++++ phrasinator/gibbs_train_plm.notables.cc | 335 ++++++++++++++++++++++++++++++++ phrasinator/train-phrasinator.pl | 2 +- 4 files changed, 497 insertions(+), 2 deletions(-) create mode 100644 phrasinator/ccrp_nt.h create mode 100644 phrasinator/gibbs_train_plm.notables.cc (limited to 'phrasinator') diff --git a/phrasinator/Makefile.am b/phrasinator/Makefile.am index 0b15a250..95a603df 100644 --- a/phrasinator/Makefile.am +++ b/phrasinator/Makefile.am @@ -1,6 +1,12 @@ -bin_PROGRAMS = gibbs_train_plm +bin_PROGRAMS = gibbs_train_plm head_bigram_model gibbs_train_plm_notables + +gibbs_train_plm_notables_SOURCES = gibbs_train_plm.notables.cc +gibbs_train_plm_notables_LDADD = $(top_srcdir)/utils/libutils.a -lz gibbs_train_plm_SOURCES = gibbs_train_plm.cc gibbs_train_plm_LDADD = $(top_srcdir)/utils/libutils.a -lz +head_bigram_model_SOURCES = head_bigram_model.cc +head_bigram_model_LDADD = $(top_srcdir)/utils/libutils.a -lz + AM_CPPFLAGS = -funroll-loops -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/phrasinator/ccrp_nt.h b/phrasinator/ccrp_nt.h new file mode 100644 index 00000000..163b643a --- /dev/null +++ b/phrasinator/ccrp_nt.h @@ -0,0 +1,154 @@ +#ifndef _CCRP_NT_H_ +#define _CCRP_NT_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "sampler.h" +#include "slice_sampler.h" + +// Chinese restaurant process with a single concentration parameter (a Dirichlet process, no Pitman-Yor discount) and no table tracking: only per-dish customer counts are kept. 
+ +template > +class CCRP_NoTable { + public: + explicit CCRP_NoTable(double conc) : + num_customers_(), + concentration_(conc), + concentration_prior_shape_(std::numeric_limits::quiet_NaN()), + concentration_prior_rate_(std::numeric_limits::quiet_NaN()) {} + + CCRP_NoTable(double c_shape, double c_rate, double c = 10.0) : + num_customers_(), + concentration_(c), + concentration_prior_shape_(c_shape), + concentration_prior_rate_(c_rate) {} + + double concentration() const { return concentration_; } + + bool has_concentration_prior() const { + return !std::isnan(concentration_prior_shape_); + } + + void clear() { + num_customers_ = 0; + custs_.clear(); + } + + unsigned num_customers() const { + return num_customers_; + } + + unsigned num_customers(const Dish& dish) const { + const typename std::tr1::unordered_map::const_iterator it = custs_.find(dish); + if (it == custs_.end()) return 0; + return it->second; + } + + void increment(const Dish& dish) { + ++custs_[dish]; + ++num_customers_; + } + + void decrement(const Dish& dish) { + if ((--custs_[dish]) == 0) + custs_.erase(dish); + --num_customers_; + } + + double prob(const Dish& dish, const double& p0) const { + const unsigned at_table = num_customers(dish); + return (at_table + p0 * concentration_) / (num_customers_ + concentration_); + } + + double log_crp_prob() const { + return log_crp_prob(concentration_); + } + + static double log_gamma_density(const double& x, const double& shape, const double& rate) { + assert(x >= 0.0); + assert(shape > 0.0); + assert(rate > 0.0); + const double lp = (shape-1)*log(x) - shape*log(rate) - x/rate - lgamma(shape); + return lp; + } + + // taken from http://en.wikipedia.org/wiki/Chinese_restaurant_process + // does not include P_0's + double log_crp_prob(const double& concentration) const { + double lp = 0.0; + if (has_concentration_prior()) + lp += log_gamma_density(concentration, concentration_prior_shape_, concentration_prior_rate_); + assert(lp <= 0.0); + if 
(num_customers_) { + lp += lgamma(concentration) - lgamma(concentration + num_customers_) + + custs_.size() * log(concentration); + assert(std::isfinite(lp)); + for (typename std::tr1::unordered_map::const_iterator it = custs_.begin(); + it != custs_.end(); ++it) { + lp += lgamma(it->second); + } + } + assert(std::isfinite(lp)); + return lp; + } + + void resample_hyperparameters(MT19937* rng, const unsigned nloop = 5, const unsigned niterations = 10) { + assert(has_concentration_prior()); + ConcentrationResampler cr(*this); + for (int iter = 0; iter < nloop; ++iter) { + concentration_ = slice_sampler1d(cr, concentration_, *rng, 0.0, + std::numeric_limits::infinity(), 0.0, niterations, 100*niterations); + } + } + + struct ConcentrationResampler { + ConcentrationResampler(const CCRP_NoTable& crp) : crp_(crp) {} + const CCRP_NoTable& crp_; + double operator()(const double& proposed_concentration) const { + return crp_.log_crp_prob(proposed_concentration); + } + }; + + void Print(std::ostream* out) const { + (*out) << "DP(alpha=" << concentration_ << ") customers=" << num_customers_ << std::endl; + int cc = 0; + for (typename std::tr1::unordered_map::const_iterator it = custs_.begin(); + it != custs_.end(); ++it) { + (*out) << " " << it->first << "(" << it->second << " eating)"; + ++cc; + if (cc > 10) { (*out) << " ..."; break; } + } + (*out) << std::endl; + } + + unsigned num_customers_; + std::tr1::unordered_map custs_; + + typedef typename std::tr1::unordered_map::const_iterator const_iterator; + const_iterator begin() const { + return custs_.begin(); + } + const_iterator end() const { + return custs_.end(); + } + + double concentration_; + + // optional gamma prior on concentration_ (NaN if no prior) + double concentration_prior_shape_; + double concentration_prior_rate_; +}; + +template +std::ostream& operator<<(std::ostream& o, const CCRP_NoTable& c) { + c.Print(&o); + return o; +} + +#endif diff --git a/phrasinator/gibbs_train_plm.notables.cc 
b/phrasinator/gibbs_train_plm.notables.cc new file mode 100644 index 00000000..4b431b90 --- /dev/null +++ b/phrasinator/gibbs_train_plm.notables.cc @@ -0,0 +1,335 @@ +#include +#include + +#include +#include + +#include "filelib.h" +#include "dict.h" +#include "sampler.h" +#include "ccrp.h" +#include "ccrp_nt.h" + +using namespace std; +using namespace std::tr1; +namespace po = boost::program_options; + +Dict d; // global dictionary + +string Join(char joiner, const vector& phrase) { + ostringstream os; + for (int i = 0; i < phrase.size(); ++i) { + if (i > 0) os << joiner; + os << d.Convert(phrase[i]); + } + return os.str(); +} + +template +void WriteSeg(const vector& line, const vector& label, const Dict& d) { + assert(line.size() == label.size()); + assert(label.back()); + int prev = 0; + int cur = 0; + while (cur < line.size()) { + if (label[cur]) { + if (prev) cout << ' '; + cout << "{{"; + for (int i = prev; i <= cur; ++i) + cout << (i == prev ? "" : " ") << d.Convert(line[i]); + cout << "}}:" << label[cur]; + prev = cur + 1; + } + ++cur; + } + cout << endl; +} + +ostream& operator<<(ostream& os, const vector& phrase) { + for (int i = 0; i < phrase.size(); ++i) + os << (i == 0 ? 
"" : " ") << d.Convert(phrase[i]); + return os; +} + +struct UnigramLM { + explicit UnigramLM(const string& fname) { + ifstream in(fname.c_str()); + assert(in); + } + + double logprob(int word) const { + assert(word < freqs_.size()); + return freqs_[word]; + } + + vector freqs_; +}; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("samples,s",po::value()->default_value(1000),"Number of samples") + ("input,i",po::value(),"Read file from") + ("random_seed,S",po::value(), "Random seed") + ("write_cdec_grammar,g", po::value(), "Write cdec grammar to this file") + ("write_cdec_weights,w", po::value(), "Write cdec weights to this file") + ("poisson_length,p", "Use a Poisson distribution as the length of a phrase in the base distribuion") + ("no_hyperparameter_inference,N", "Disable hyperparameter inference"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || (conf->count("input") == 0)) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +void ReadCorpus(const string& filename, vector >* c, set* vocab) { + c->clear(); + istream* in; + if (filename == "-") + in = &cin; + else + in = new ifstream(filename.c_str()); + assert(*in); + string line; + while(*in) { + getline(*in, line); + if (line.empty() && !*in) break; + c->push_back(vector()); + vector& v = c->back(); + d.ConvertWhitespaceDelimitedLine(line, &v); + for (int i = 0; i < v.size(); 
++i) vocab->insert(v[i]); + } + if (in != &cin) delete in; +} + +double log_poisson(unsigned x, const double& lambda) { + assert(lambda > 0.0); + return log(lambda) * x - lgamma(x + 1) - lambda; +} + +struct UniphraseLM { + UniphraseLM(const vector >& corpus, + const set& vocab, + const po::variables_map& conf) : + phrases_(1,1), + gen_(1,1), + corpus_(corpus), + uniform_word_(1.0 / vocab.size()), + gen_p0_(0.5), + p_end_(0.5), + use_poisson_(conf.count("poisson_length") > 0) {} + + double p0(const vector& phrase) const { + static vector p0s(10000, 0.0); + assert(phrase.size() < 10000); + double& p = p0s[phrase.size()]; + if (p) return p; + p = exp(log_p0(phrase)); + if (!p) { + cerr << "0 prob phrase: " << phrase << "\nAssigning std::numeric_limits::min()\n"; + p = std::numeric_limits::min(); + } + return p; + } + + double log_p0(const vector& phrase) const { + double len_logprob; + if (use_poisson_) + len_logprob = log_poisson(phrase.size(), 1.0); + else + len_logprob = log(1 - p_end_) * (phrase.size() -1) + log(p_end_); + return log(uniform_word_) * phrase.size() + len_logprob; + } + + double llh() const { + double llh = gen_.log_crp_prob(); + llh += log(gen_p0_) + log(1 - gen_p0_); + double llhr = phrases_.log_crp_prob(); + for (CCRP_NoTable >::const_iterator it = phrases_.begin(); it != phrases_.end(); ++it) { + llhr += log_p0(it->first); + //llhr += log_p0(it->first); + if (!isfinite(llh)) { + cerr << it->first << endl; + cerr << log_p0(it->first) << endl; + abort(); + } + } + return llh + llhr; + } + + void Sample(unsigned int samples, bool hyp_inf, MT19937* rng) { + cerr << "Initializing...\n"; + z_.resize(corpus_.size()); + int tc = 0; + for (int i = 0; i < corpus_.size(); ++i) { + const vector& line = corpus_[i]; + const int ls = line.size(); + const int last_pos = ls - 1; + vector& z = z_[i]; + z.resize(ls); + int prev = 0; + for (int j = 0; j < ls; ++j) { + z[j] = rng->next() < 0.5; + if (j == last_pos) z[j] = true; // break phrase at the end of the 
sentence + if (z[j]) { + const vector p(line.begin() + prev, line.begin() + j + 1); + phrases_.increment(p); + //cerr << p << ": " << p0(p) << endl; + prev = j + 1; + gen_.increment(false); + ++tc; // remove + } + } + ++tc; + gen_.increment(true); // end of utterance + } + cerr << "TC: " << tc << endl; + cerr << "Initial LLH: " << llh() << endl; + cerr << "Sampling...\n"; + cerr << gen_ << endl; + for (int s = 1; s < samples; ++s) { + cerr << '.'; + if (s % 10 == 0) { + cerr << " [" << s; + if (hyp_inf) ResampleHyperparameters(rng); + cerr << " LLH=" << llh() << "]\n"; + vector z(z_[0].size(), 0); + //for (int j = 0; j < z.size(); ++j) z[j] = z_[0][j]; + //SegCorpus::Write(corpus_[0], z, d); + } + for (int i = 0; i < corpus_.size(); ++i) { + const vector& line = corpus_[i]; + const int ls = line.size(); + const int last_pos = ls - 1; + vector& z = z_[i]; + int prev = 0; + for (int j = 0; j < last_pos; ++j) { // don't resample last position + int next = j+1; while(!z[next]) { ++next; } + const vector p1p2(line.begin() + prev, line.begin() + next + 1); + const vector p1(line.begin() + prev, line.begin() + j + 1); + const vector p2(line.begin() + j + 1, line.begin() + next + 1); + + if (z[j]) { + phrases_.decrement(p1); + phrases_.decrement(p2); + gen_.decrement(false); + gen_.decrement(false); + } else { + phrases_.decrement(p1p2); + gen_.decrement(false); + } + + const double d1 = phrases_.prob(p1p2, p0(p1p2)) * gen_.prob(false, gen_p0_); + double d2 = phrases_.prob(p1, p0(p1)) * gen_.prob(false, gen_p0_); + phrases_.increment(p1); + gen_.increment(false); + d2 *= phrases_.prob(p2, p0(p2)) * gen_.prob(false, gen_p0_); + phrases_.decrement(p1); + gen_.decrement(false); + z[j] = rng->SelectSample(d1, d2); + + if (z[j]) { + phrases_.increment(p1); + phrases_.increment(p2); + gen_.increment(false); + gen_.increment(false); + prev = j + 1; + } else { + phrases_.increment(p1p2); + gen_.increment(false); + } + } + } + } +// cerr << endl << endl << gen_ << endl << phrases_ 
<< endl; + cerr << gen_.prob(false, gen_p0_) << " " << gen_.prob(true, 1 - gen_p0_) << endl; + } + + void WriteCdecGrammarForCurrentSample(ostream* os) const { + CCRP_NoTable >::const_iterator it = phrases_.begin(); + for (; it != phrases_.end(); ++it) { + (*os) << "[X] ||| " << Join(' ', it->first) << " ||| " + << Join('_', it->first) << " ||| C=1 P=" + << log(phrases_.prob(it->first, p0(it->first))) << endl; + } + } + + double OOVUnigramLogProb() const { + vector x(1,99999999); + return log(phrases_.prob(x, p0(x))); + } + + void ResampleHyperparameters(MT19937* rng) { + phrases_.resample_hyperparameters(rng); + gen_.resample_hyperparameters(rng); + cerr << " " << phrases_.concentration(); + } + + CCRP_NoTable > phrases_; + CCRP_NoTable gen_; + vector > z_; // z_[i] is there a phrase boundary after the ith word + const vector >& corpus_; + const double uniform_word_; + const double gen_p0_; + const double p_end_; // in base length distribution, p of the end of a phrase + const bool use_poisson_; +}; + + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + shared_ptr prng; + if (conf.count("random_seed")) + prng.reset(new MT19937(conf["random_seed"].as())); + else + prng.reset(new MT19937); + MT19937& rng = *prng; + + vector > corpus; + set vocab; + ReadCorpus(conf["input"].as(), &corpus, &vocab); + cerr << "Corpus size: " << corpus.size() << " sentences\n"; + cerr << "Vocabulary size: " << vocab.size() << " types\n"; + + UniphraseLM ulm(corpus, vocab, conf); + ulm.Sample(conf["samples"].as(), conf.count("no_hyperparameter_inference") == 0, &rng); + cerr << "OOV unigram prob: " << ulm.OOVUnigramLogProb() << endl; + + for (int i = 0; i < corpus.size(); ++i) + WriteSeg(corpus[i], ulm.z_[i], d); + + if (conf.count("write_cdec_grammar")) { + string fname = conf["write_cdec_grammar"].as(); + cerr << "Writing model to " << fname << " ...\n"; + WriteFile wf(fname); + ulm.WriteCdecGrammarForCurrentSample(wf.stream()); + } + + 
if (conf.count("write_cdec_weights")) { + string fname = conf["write_cdec_weights"].as(); + cerr << "Writing weights to " << fname << " .\n"; + WriteFile wf(fname); + ostream& os = *wf.stream(); + os << "# make C smaller to use more phrases\nP 1\nPassThrough " << ulm.OOVUnigramLogProb() << "\nC -3\n"; + } + + + + return 0; +} + diff --git a/phrasinator/train-phrasinator.pl b/phrasinator/train-phrasinator.pl index de258caf..c50b8e68 100755 --- a/phrasinator/train-phrasinator.pl +++ b/phrasinator/train-phrasinator.pl @@ -5,7 +5,7 @@ use Getopt::Long; use File::Spec qw (rel2abs); my $DECODER = "$script_dir/../decoder/cdec"; -my $TRAINER = "$script_dir/gibbs_train_plm"; +my $TRAINER = "$script_dir/gibbs_train_plm_notables"; die "Can't find $TRAINER" unless -f $TRAINER; die "Can't execute $TRAINER" unless -x $TRAINER; -- cgit v1.2.3