summaryrefslogtreecommitdiff
path: root/phrasinator
diff options
context:
space:
mode:
Diffstat (limited to 'phrasinator')
-rw-r--r--phrasinator/Makefile.am10
-rw-r--r--phrasinator/ccrp_nt.h170
-rw-r--r--phrasinator/gibbs_train_plm.notables.cc335
-rwxr-xr-xphrasinator/train-phrasinator.pl2
4 files changed, 515 insertions, 2 deletions
diff --git a/phrasinator/Makefile.am b/phrasinator/Makefile.am
index 0b15a250..aba98601 100644
--- a/phrasinator/Makefile.am
+++ b/phrasinator/Makefile.am
@@ -1,6 +1,14 @@
-bin_PROGRAMS = gibbs_train_plm
+bin_PROGRAMS = gibbs_train_plm gibbs_train_plm_notables
+
+#head_bigram_model
+
+gibbs_train_plm_notables_SOURCES = gibbs_train_plm.notables.cc
+gibbs_train_plm_notables_LDADD = $(top_srcdir)/utils/libutils.a -lz
gibbs_train_plm_SOURCES = gibbs_train_plm.cc
gibbs_train_plm_LDADD = $(top_srcdir)/utils/libutils.a -lz
+#head_bigram_model_SOURCES = head_bigram_model.cc
+#head_bigram_model_LDADD = $(top_srcdir)/utils/libutils.a -lz
+
AM_CPPFLAGS = -funroll-loops -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval
diff --git a/phrasinator/ccrp_nt.h b/phrasinator/ccrp_nt.h
new file mode 100644
index 00000000..811bce73
--- /dev/null
+++ b/phrasinator/ccrp_nt.h
@@ -0,0 +1,170 @@
+#ifndef _CCRP_NT_H_
+#define _CCRP_NT_H_
+
+#include <numeric>
+#include <cassert>
+#include <cmath>
+#include <list>
+#include <iostream>
+#include <vector>
+#include <tr1/unordered_map>
+#include <boost/functional/hash.hpp>
+#include "sampler.h"
+#include "slice_sampler.h"
+
+// Chinese restaurant process (Pitman-Yor parameters) with table tracking.
+
+template <typename Dish, typename DishHash = boost::hash<Dish> >
+class CCRP_NoTable {
+ public:
+ explicit CCRP_NoTable(double conc) :
+ num_customers_(),
+ concentration_(conc),
+ concentration_prior_shape_(std::numeric_limits<double>::quiet_NaN()),
+ concentration_prior_rate_(std::numeric_limits<double>::quiet_NaN()) {}
+
+ CCRP_NoTable(double c_shape, double c_rate, double c = 10.0) :
+ num_customers_(),
+ concentration_(c),
+ concentration_prior_shape_(c_shape),
+ concentration_prior_rate_(c_rate) {}
+
+ double concentration() const { return concentration_; }
+
+ bool has_concentration_prior() const {
+ return !std::isnan(concentration_prior_shape_);
+ }
+
+ void clear() {
+ num_customers_ = 0;
+ custs_.clear();
+ }
+
+ unsigned num_customers() const {
+ return num_customers_;
+ }
+
+ unsigned num_customers(const Dish& dish) const {
+ const typename std::tr1::unordered_map<Dish, unsigned, DishHash>::const_iterator it = custs_.find(dish);
+ if (it == custs_.end()) return 0;
+ return it->second;
+ }
+
+ int increment(const Dish& dish) {
+ int table_diff = 0;
+ if (++custs_[dish] == 1)
+ table_diff = 1;
+ ++num_customers_;
+ return table_diff;
+ }
+
+ int decrement(const Dish& dish) {
+ int table_diff = 0;
+ int nc = --custs_[dish];
+ if (nc == 0) {
+ custs_.erase(dish);
+ table_diff = -1;
+ } else if (nc < 0) {
+ std::cerr << "Dish counts dropped below zero for: " << dish << std::endl;
+ abort();
+ }
+ --num_customers_;
+ return table_diff;
+ }
+
+ double prob(const Dish& dish, const double& p0) const {
+ const unsigned at_table = num_customers(dish);
+ return (at_table + p0 * concentration_) / (num_customers_ + concentration_);
+ }
+
+ double logprob(const Dish& dish, const double& logp0) const {
+ const unsigned at_table = num_customers(dish);
+ return log(at_table + exp(logp0 + log(concentration_))) - log(num_customers_ + concentration_);
+ }
+
+ double log_crp_prob() const {
+ return log_crp_prob(concentration_);
+ }
+
+ static double log_gamma_density(const double& x, const double& shape, const double& rate) {
+ assert(x >= 0.0);
+ assert(shape > 0.0);
+ assert(rate > 0.0);
+ const double lp = (shape-1)*log(x) - shape*log(rate) - x/rate - lgamma(shape);
+ return lp;
+ }
+
+ // taken from http://en.wikipedia.org/wiki/Chinese_restaurant_process
+ // does not include P_0's
+ double log_crp_prob(const double& concentration) const {
+ double lp = 0.0;
+ if (has_concentration_prior())
+ lp += log_gamma_density(concentration, concentration_prior_shape_, concentration_prior_rate_);
+ assert(lp <= 0.0);
+ if (num_customers_) {
+ lp += lgamma(concentration) - lgamma(concentration + num_customers_) +
+ custs_.size() * log(concentration);
+ assert(std::isfinite(lp));
+ for (typename std::tr1::unordered_map<Dish, unsigned, DishHash>::const_iterator it = custs_.begin();
+ it != custs_.end(); ++it) {
+ lp += lgamma(it->second);
+ }
+ }
+ assert(std::isfinite(lp));
+ return lp;
+ }
+
+ void resample_hyperparameters(MT19937* rng, const unsigned nloop = 5, const unsigned niterations = 10) {
+ assert(has_concentration_prior());
+ ConcentrationResampler cr(*this);
+ for (int iter = 0; iter < nloop; ++iter) {
+ concentration_ = slice_sampler1d(cr, concentration_, *rng, 0.0,
+ std::numeric_limits<double>::infinity(), 0.0, niterations, 100*niterations);
+ }
+ }
+
+ struct ConcentrationResampler {
+ ConcentrationResampler(const CCRP_NoTable& crp) : crp_(crp) {}
+ const CCRP_NoTable& crp_;
+ double operator()(const double& proposed_concentration) const {
+ return crp_.log_crp_prob(proposed_concentration);
+ }
+ };
+
+ void Print(std::ostream* out) const {
+ (*out) << "DP(alpha=" << concentration_ << ") customers=" << num_customers_ << std::endl;
+ int cc = 0;
+ for (typename std::tr1::unordered_map<Dish, unsigned, DishHash>::const_iterator it = custs_.begin();
+ it != custs_.end(); ++it) {
+ (*out) << " " << it->first << "(" << it->second << " eating)";
+ ++cc;
+ if (cc > 10) { (*out) << " ..."; break; }
+ }
+ (*out) << std::endl;
+ }
+
+ unsigned num_customers_;
+ std::tr1::unordered_map<Dish, unsigned, DishHash> custs_;
+
+ typedef typename std::tr1::unordered_map<Dish, unsigned, DishHash>::const_iterator const_iterator;
+ const_iterator begin() const {
+ return custs_.begin();
+ }
+ const_iterator end() const {
+ return custs_.end();
+ }
+
+ double concentration_;
+
+ // optional gamma prior on concentration_ (NaN if no prior)
+ double concentration_prior_shape_;
+ double concentration_prior_rate_;
+};
+
+template <typename T,typename H>
+std::ostream& operator<<(std::ostream& o, const CCRP_NoTable<T,H>& c) {
+ c.Print(&o);
+ return o;
+}
+
+#endif
diff --git a/phrasinator/gibbs_train_plm.notables.cc b/phrasinator/gibbs_train_plm.notables.cc
new file mode 100644
index 00000000..4b431b90
--- /dev/null
+++ b/phrasinator/gibbs_train_plm.notables.cc
@@ -0,0 +1,335 @@
+#include <iostream>
+#include <tr1/memory>
+
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "filelib.h"
+#include "dict.h"
+#include "sampler.h"
+#include "ccrp.h"
+#include "ccrp_nt.h"
+
+using namespace std;
+using namespace std::tr1;
+namespace po = boost::program_options;
+
+Dict d; // global dictionary
+
+string Join(char joiner, const vector<int>& phrase) {
+ ostringstream os;
+ for (int i = 0; i < phrase.size(); ++i) {
+ if (i > 0) os << joiner;
+ os << d.Convert(phrase[i]);
+ }
+ return os.str();
+}
+
+template <typename BType>
+void WriteSeg(const vector<int>& line, const vector<BType>& label, const Dict& d) {
+ assert(line.size() == label.size());
+ assert(label.back());
+ int prev = 0;
+ int cur = 0;
+ while (cur < line.size()) {
+ if (label[cur]) {
+ if (prev) cout << ' ';
+ cout << "{{";
+ for (int i = prev; i <= cur; ++i)
+ cout << (i == prev ? "" : " ") << d.Convert(line[i]);
+ cout << "}}:" << label[cur];
+ prev = cur + 1;
+ }
+ ++cur;
+ }
+ cout << endl;
+}
+
+ostream& operator<<(ostream& os, const vector<int>& phrase) {
+ for (int i = 0; i < phrase.size(); ++i)
+ os << (i == 0 ? "" : " ") << d.Convert(phrase[i]);
+ return os;
+}
+
+struct UnigramLM {
+ explicit UnigramLM(const string& fname) {
+ ifstream in(fname.c_str());
+ assert(in);
+ }
+
+ double logprob(int word) const {
+ assert(word < freqs_.size());
+ return freqs_[word];
+ }
+
+ vector<double> freqs_;
+};
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("samples,s",po::value<unsigned>()->default_value(1000),"Number of samples")
+ ("input,i",po::value<string>(),"Read file from")
+ ("random_seed,S",po::value<uint32_t>(), "Random seed")
+ ("write_cdec_grammar,g", po::value<string>(), "Write cdec grammar to this file")
+ ("write_cdec_weights,w", po::value<string>(), "Write cdec weights to this file")
+ ("poisson_length,p", "Use a Poisson distribution as the length of a phrase in the base distribuion")
+ ("no_hyperparameter_inference,N", "Disable hyperparameter inference");
+ po::options_description clo("Command line options");
+ clo.add_options()
+ ("config", po::value<string>(), "Configuration file")
+ ("help,h", "Print this help message and exit");
+ po::options_description dconfig_options, dcmdline_options;
+ dconfig_options.add(opts);
+ dcmdline_options.add(opts).add(clo);
+
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ if (conf->count("config")) {
+ ifstream config((*conf)["config"].as<string>().c_str());
+ po::store(po::parse_config_file(config, dconfig_options), *conf);
+ }
+ po::notify(*conf);
+
+ if (conf->count("help") || (conf->count("input") == 0)) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+void ReadCorpus(const string& filename, vector<vector<int> >* c, set<int>* vocab) {
+ c->clear();
+ istream* in;
+ if (filename == "-")
+ in = &cin;
+ else
+ in = new ifstream(filename.c_str());
+ assert(*in);
+ string line;
+ while(*in) {
+ getline(*in, line);
+ if (line.empty() && !*in) break;
+ c->push_back(vector<int>());
+ vector<int>& v = c->back();
+ d.ConvertWhitespaceDelimitedLine(line, &v);
+ for (int i = 0; i < v.size(); ++i) vocab->insert(v[i]);
+ }
+ if (in != &cin) delete in;
+}
+
+double log_poisson(unsigned x, const double& lambda) {
+ assert(lambda > 0.0);
+ return log(lambda) * x - lgamma(x + 1) - lambda;
+}
+
+struct UniphraseLM {
+ UniphraseLM(const vector<vector<int> >& corpus,
+ const set<int>& vocab,
+ const po::variables_map& conf) :
+ phrases_(1,1),
+ gen_(1,1),
+ corpus_(corpus),
+ uniform_word_(1.0 / vocab.size()),
+ gen_p0_(0.5),
+ p_end_(0.5),
+ use_poisson_(conf.count("poisson_length") > 0) {}
+
+ double p0(const vector<int>& phrase) const {
+ static vector<double> p0s(10000, 0.0);
+ assert(phrase.size() < 10000);
+ double& p = p0s[phrase.size()];
+ if (p) return p;
+ p = exp(log_p0(phrase));
+ if (!p) {
+ cerr << "0 prob phrase: " << phrase << "\nAssigning std::numeric_limits<double>::min()\n";
+ p = std::numeric_limits<double>::min();
+ }
+ return p;
+ }
+
+ double log_p0(const vector<int>& phrase) const {
+ double len_logprob;
+ if (use_poisson_)
+ len_logprob = log_poisson(phrase.size(), 1.0);
+ else
+ len_logprob = log(1 - p_end_) * (phrase.size() -1) + log(p_end_);
+ return log(uniform_word_) * phrase.size() + len_logprob;
+ }
+
+ double llh() const {
+ double llh = gen_.log_crp_prob();
+ llh += log(gen_p0_) + log(1 - gen_p0_);
+ double llhr = phrases_.log_crp_prob();
+ for (CCRP_NoTable<vector<int> >::const_iterator it = phrases_.begin(); it != phrases_.end(); ++it) {
+ llhr += log_p0(it->first);
+ //llhr += log_p0(it->first);
+ if (!isfinite(llh)) {
+ cerr << it->first << endl;
+ cerr << log_p0(it->first) << endl;
+ abort();
+ }
+ }
+ return llh + llhr;
+ }
+
+ void Sample(unsigned int samples, bool hyp_inf, MT19937* rng) {
+ cerr << "Initializing...\n";
+ z_.resize(corpus_.size());
+ int tc = 0;
+ for (int i = 0; i < corpus_.size(); ++i) {
+ const vector<int>& line = corpus_[i];
+ const int ls = line.size();
+ const int last_pos = ls - 1;
+ vector<bool>& z = z_[i];
+ z.resize(ls);
+ int prev = 0;
+ for (int j = 0; j < ls; ++j) {
+ z[j] = rng->next() < 0.5;
+ if (j == last_pos) z[j] = true; // break phrase at the end of the sentence
+ if (z[j]) {
+ const vector<int> p(line.begin() + prev, line.begin() + j + 1);
+ phrases_.increment(p);
+ //cerr << p << ": " << p0(p) << endl;
+ prev = j + 1;
+ gen_.increment(false);
+ ++tc; // remove
+ }
+ }
+ ++tc;
+ gen_.increment(true); // end of utterance
+ }
+ cerr << "TC: " << tc << endl;
+ cerr << "Initial LLH: " << llh() << endl;
+ cerr << "Sampling...\n";
+ cerr << gen_ << endl;
+ for (int s = 1; s < samples; ++s) {
+ cerr << '.';
+ if (s % 10 == 0) {
+ cerr << " [" << s;
+ if (hyp_inf) ResampleHyperparameters(rng);
+ cerr << " LLH=" << llh() << "]\n";
+ vector<int> z(z_[0].size(), 0);
+ //for (int j = 0; j < z.size(); ++j) z[j] = z_[0][j];
+ //SegCorpus::Write(corpus_[0], z, d);
+ }
+ for (int i = 0; i < corpus_.size(); ++i) {
+ const vector<int>& line = corpus_[i];
+ const int ls = line.size();
+ const int last_pos = ls - 1;
+ vector<bool>& z = z_[i];
+ int prev = 0;
+ for (int j = 0; j < last_pos; ++j) { // don't resample last position
+ int next = j+1; while(!z[next]) { ++next; }
+ const vector<int> p1p2(line.begin() + prev, line.begin() + next + 1);
+ const vector<int> p1(line.begin() + prev, line.begin() + j + 1);
+ const vector<int> p2(line.begin() + j + 1, line.begin() + next + 1);
+
+ if (z[j]) {
+ phrases_.decrement(p1);
+ phrases_.decrement(p2);
+ gen_.decrement(false);
+ gen_.decrement(false);
+ } else {
+ phrases_.decrement(p1p2);
+ gen_.decrement(false);
+ }
+
+ const double d1 = phrases_.prob(p1p2, p0(p1p2)) * gen_.prob(false, gen_p0_);
+ double d2 = phrases_.prob(p1, p0(p1)) * gen_.prob(false, gen_p0_);
+ phrases_.increment(p1);
+ gen_.increment(false);
+ d2 *= phrases_.prob(p2, p0(p2)) * gen_.prob(false, gen_p0_);
+ phrases_.decrement(p1);
+ gen_.decrement(false);
+ z[j] = rng->SelectSample(d1, d2);
+
+ if (z[j]) {
+ phrases_.increment(p1);
+ phrases_.increment(p2);
+ gen_.increment(false);
+ gen_.increment(false);
+ prev = j + 1;
+ } else {
+ phrases_.increment(p1p2);
+ gen_.increment(false);
+ }
+ }
+ }
+ }
+// cerr << endl << endl << gen_ << endl << phrases_ << endl;
+ cerr << gen_.prob(false, gen_p0_) << " " << gen_.prob(true, 1 - gen_p0_) << endl;
+ }
+
+ void WriteCdecGrammarForCurrentSample(ostream* os) const {
+ CCRP_NoTable<vector<int> >::const_iterator it = phrases_.begin();
+ for (; it != phrases_.end(); ++it) {
+ (*os) << "[X] ||| " << Join(' ', it->first) << " ||| "
+ << Join('_', it->first) << " ||| C=1 P="
+ << log(phrases_.prob(it->first, p0(it->first))) << endl;
+ }
+ }
+
+ double OOVUnigramLogProb() const {
+ vector<int> x(1,99999999);
+ return log(phrases_.prob(x, p0(x)));
+ }
+
+ void ResampleHyperparameters(MT19937* rng) {
+ phrases_.resample_hyperparameters(rng);
+ gen_.resample_hyperparameters(rng);
+ cerr << " " << phrases_.concentration();
+ }
+
+ CCRP_NoTable<vector<int> > phrases_;
+ CCRP_NoTable<bool> gen_;
+ vector<vector<bool> > z_; // z_[i] is there a phrase boundary after the ith word
+ const vector<vector<int> >& corpus_;
+ const double uniform_word_;
+ const double gen_p0_;
+ const double p_end_; // in base length distribution, p of the end of a phrase
+ const bool use_poisson_;
+};
+
+
+int main(int argc, char** argv) {
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+ shared_ptr<MT19937> prng;
+ if (conf.count("random_seed"))
+ prng.reset(new MT19937(conf["random_seed"].as<uint32_t>()));
+ else
+ prng.reset(new MT19937);
+ MT19937& rng = *prng;
+
+ vector<vector<int> > corpus;
+ set<int> vocab;
+ ReadCorpus(conf["input"].as<string>(), &corpus, &vocab);
+ cerr << "Corpus size: " << corpus.size() << " sentences\n";
+ cerr << "Vocabulary size: " << vocab.size() << " types\n";
+
+ UniphraseLM ulm(corpus, vocab, conf);
+ ulm.Sample(conf["samples"].as<unsigned>(), conf.count("no_hyperparameter_inference") == 0, &rng);
+ cerr << "OOV unigram prob: " << ulm.OOVUnigramLogProb() << endl;
+
+ for (int i = 0; i < corpus.size(); ++i)
+ WriteSeg(corpus[i], ulm.z_[i], d);
+
+ if (conf.count("write_cdec_grammar")) {
+ string fname = conf["write_cdec_grammar"].as<string>();
+ cerr << "Writing model to " << fname << " ...\n";
+ WriteFile wf(fname);
+ ulm.WriteCdecGrammarForCurrentSample(wf.stream());
+ }
+
+ if (conf.count("write_cdec_weights")) {
+ string fname = conf["write_cdec_weights"].as<string>();
+ cerr << "Writing weights to " << fname << " .\n";
+ WriteFile wf(fname);
+ ostream& os = *wf.stream();
+ os << "# make C smaller to use more phrases\nP 1\nPassThrough " << ulm.OOVUnigramLogProb() << "\nC -3\n";
+ }
+
+
+
+ return 0;
+}
+
diff --git a/phrasinator/train-phrasinator.pl b/phrasinator/train-phrasinator.pl
index de258caf..c50b8e68 100755
--- a/phrasinator/train-phrasinator.pl
+++ b/phrasinator/train-phrasinator.pl
@@ -5,7 +5,7 @@ use Getopt::Long;
use File::Spec qw (rel2abs);
my $DECODER = "$script_dir/../decoder/cdec";
-my $TRAINER = "$script_dir/gibbs_train_plm";
+my $TRAINER = "$script_dir/gibbs_train_plm_notables";
die "Can't find $TRAINER" unless -f $TRAINER;
die "Can't execute $TRAINER" unless -x $TRAINER;