diff options
-rw-r--r-- | decoder/cdec.cc | 3 | ||||
-rw-r--r-- | decoder/decoder.cc | 8 | ||||
-rw-r--r-- | phrasinator/Makefile.am | 8 | ||||
-rw-r--r-- | phrasinator/gibbs_train_plm.cc | 315 | ||||
-rwxr-xr-x | phrasinator/train-phrasinator.pl | 89 | ||||
-rw-r--r-- | phrasinator/train_plm.cc | 5 | ||||
-rw-r--r-- | utils/dict.h | 24 |
7 files changed, 439 insertions, 13 deletions
diff --git a/decoder/cdec.cc b/decoder/cdec.cc index 5b930c69..b47ab380 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -3,6 +3,7 @@ #include "filelib.h" #include "decoder.h" #include "ff_register.h" +#include "verbose.h" using namespace std; @@ -11,7 +12,7 @@ int main(int argc, char** argv) { Decoder decoder(argc, argv); const string input = decoder.GetConf()["input"].as<string>(); - cerr << "Reading input from " << ((input == "-") ? "STDIN" : input.c_str()) << endl; + if (!SILENT) cerr << "Reading input from " << ((input == "-") ? "STDIN" : input.c_str()) << endl; ReadFile in_read(input); istream *in = in_read.stream(); assert(*in); diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 478a1cf3..3d818429 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -453,7 +453,9 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream cout << endl; exit(0); } - ShowBanner(); + if (conf.count("quiet")) + SetSilent(true); + if (!SILENT) ShowBanner(); } if (conf.count("show_config")) // special handling needed because we only want to notify() once. show_config=true; @@ -467,6 +469,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream po::store(po::parse_config_file(*conff, dconfig_options), conf); } } + if (conf.count("quiet")) + SetSilent(true); if (cfg) po::store(po::parse_config_file(*cfg, dconfig_options), conf); po::notify(conf); if (show_config && !cfg_files.empty()) { @@ -482,8 +486,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream cerr<<" "<<argv[i]; cerr << "\n\n"; } - if (conf.count("quiet")) - SetSilent(true); if (conf.count("list_feature_functions")) { cerr << "Available feature functions (specify with -F; describe with -u FeatureName):\n"; diff --git a/phrasinator/Makefile.am b/phrasinator/Makefile.am index c9b2a513..0b15a250 100644 --- a/phrasinator/Makefile.am +++ b/phrasinator/Makefile.am @@ -1,6 +1,6 @@ -bin_PROGRAMS = train_plm +bin_PROGRAMS = gibbs_train_plm -train_plm_SOURCES = train_plm.cc -train_plm_LDADD = $(top_srcdir)/utils/libutils.a -lz +gibbs_train_plm_SOURCES = gibbs_train_plm.cc +gibbs_train_plm_LDADD = $(top_srcdir)/utils/libutils.a -lz -AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval +AM_CPPFLAGS = -funroll-loops -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/phrasinator/gibbs_train_plm.cc b/phrasinator/gibbs_train_plm.cc new file mode 100644 index 00000000..29b3d7ea --- /dev/null +++ b/phrasinator/gibbs_train_plm.cc @@ -0,0 +1,315 @@ +#include <iostream> +#include <tr1/memory> + +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "filelib.h" +#include "dict.h" +#include "sampler.h" +#include "ccrp.h" + +using namespace std; +using namespace std::tr1; +namespace po = boost::program_options; + +Dict d; // global dictionary + +string Join(char joiner, const vector<int>& phrase) { + ostringstream os; + for (int i = 0; i < phrase.size(); ++i) { + if (i > 0) os << joiner; + os << d.Convert(phrase[i]); + } + return os.str(); +} + +ostream& operator<<(ostream& os, const vector<int>& phrase) { + for (int i = 0; i < phrase.size(); ++i) + os << (i == 0 ? "" : " ") << d.Convert(phrase[i]); + return os; +} + +struct UnigramLM { + explicit UnigramLM(const string& fname) { + ifstream in(fname.c_str()); + assert(in); + } + + double logprob(int word) const { + assert(word < freqs_.size()); + return freqs_[word]; + } + + vector<double> freqs_; +}; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("samples,s",po::value<unsigned>()->default_value(1000),"Number of samples") + ("input,i",po::value<string>(),"Read file from") + ("random_seed,S",po::value<uint32_t>(), "Random seed") + ("write_cdec_grammar,g", po::value<string>(), "Write cdec grammar to this file") + ("write_cdec_weights,w", po::value<string>(), "Write cdec weights to this file") + ("poisson_length,p", "Use a Poisson distribution as the length of a phrase in the base distribuion") + ("no_hyperparameter_inference,N", "Disable hyperparameter inference"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value<string>(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as<string>().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || (conf->count("input") == 0)) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +void ReadCorpus(const string& filename, vector<vector<int> >* c, set<int>* vocab) { + c->clear(); + istream* in; + if (filename == "-") + in = &cin; + else + in = new ifstream(filename.c_str()); + assert(*in); + string line; + while(*in) { + getline(*in, line); + if (line.empty() && !*in) break; + c->push_back(vector<int>()); + vector<int>& v = c->back(); + d.ConvertWhitespaceDelimitedLine(line, &v); + for (int i = 0; i < v.size(); ++i) vocab->insert(v[i]); + } + if (in != &cin) delete in; +} + +double log_poisson(unsigned x, const double& lambda) { + assert(lambda > 0.0); + return log(lambda) * x - lgamma(x + 1) - lambda; +} + +struct UniphraseLM { + UniphraseLM(const vector<vector<int> >& corpus, + const set<int>& vocab, + const po::variables_map& conf) : + phrases_(1,1,1,1), + gen_(1,1,1,1), + corpus_(corpus), + uniform_word_(1.0 / vocab.size()), + gen_p0_(0.5), + p_end_(0.5), + use_poisson_(conf.count("poisson_length") > 0) {} + + double p0(const vector<int>& phrase) const { + static vector<double> p0s(10000, 0.0); + assert(phrase.size() < 10000); + double& p = p0s[phrase.size()]; + if (p) return p; + p = exp(log_p0(phrase)); + if (!p) { + cerr << "0 prob phrase: " << phrase << "\nAssigning std::numeric_limits<double>::min()\n"; + p = std::numeric_limits<double>::min(); + } + return p; + } + + double log_p0(const vector<int>& phrase) const { + double len_logprob; + if (use_poisson_) + len_logprob = log_poisson(phrase.size(), 1.0); + else + len_logprob = log(1 - p_end_) * (phrase.size() -1) + log(p_end_); + return log(uniform_word_) * phrase.size() + len_logprob; + } + + double llh() const { + double llh = gen_.log_crp_prob(); + llh += gen_.num_tables(false) * log(gen_p0_) + + gen_.num_tables(true) * log(1 - gen_p0_); + double llhr = phrases_.log_crp_prob(); + for (CCRP<vector<int> >::const_iterator it = phrases_.begin(); it != phrases_.end(); ++it) { + llhr += phrases_.num_tables(it->first) * log_p0(it->first); + //llhr += log_p0(it->first); + if (!isfinite(llh)) { + cerr << it->first << endl; + cerr << log_p0(it->first) << endl; + abort(); + } + } + return llh + llhr; + } + + void Sample(unsigned int samples, bool hyp_inf, MT19937* rng) { + cerr << "Initializing...\n"; + z_.resize(corpus_.size()); + int tc = 0; + for (int i = 0; i < corpus_.size(); ++i) { + const vector<int>& line = corpus_[i]; + const int ls = line.size(); + const int last_pos = ls - 1; + vector<bool>& z = z_[i]; + z.resize(ls); + int prev = 0; + for (int j = 0; j < ls; ++j) { + z[j] = rng->next() < 0.5; + if (j == last_pos) z[j] = true; // break phrase at the end of the sentence + if (z[j]) { + const vector<int> p(line.begin() + prev, line.begin() + j + 1); + phrases_.increment(p, p0(p), rng); + //cerr << p << ": " << p0(p) << endl; + prev = j + 1; + gen_.increment(false, gen_p0_, rng); + ++tc; // remove + } + } + ++tc; + gen_.increment(true, 1.0 - gen_p0_, rng); // end of utterance + } + cerr << "TC: " << tc << endl; + cerr << "Initial LLH: " << llh() << endl; + cerr << "Sampling...\n"; + cerr << gen_ << endl; + for (int s = 1; s < samples; ++s) { + cerr << '.'; + if (s % 10 == 0) { + cerr << " [" << s; + if (hyp_inf) ResampleHyperparameters(rng); + cerr << " LLH=" << llh() << "]\n"; + vector<int> z(z_[0].size(), 0); + //for (int j = 0; j < z.size(); ++j) z[j] = z_[0][j]; + //SegCorpus::Write(corpus_[0], z, d); + } + for (int i = 0; i < corpus_.size(); ++i) { + const vector<int>& line = corpus_[i]; + const int ls = line.size(); + const int last_pos = ls - 1; + vector<bool>& z = z_[i]; + int prev = 0; + for (int j = 0; j < last_pos; ++j) { // don't resample last position + int next = j+1; while(!z[next]) { ++next; } + const vector<int> p1p2(line.begin() + prev, line.begin() + next + 1); + const vector<int> p1(line.begin() + prev, line.begin() + j + 1); + const vector<int> p2(line.begin() + j + 1, line.begin() + next + 1); + + if (z[j]) { + phrases_.decrement(p1, rng); + phrases_.decrement(p2, rng); + gen_.decrement(false, rng); + gen_.decrement(false, rng); + } else { + phrases_.decrement(p1p2, rng); + gen_.decrement(false, rng); + } + + const double d1 = phrases_.prob(p1p2, p0(p1p2)) * gen_.prob(false, gen_p0_); + double d2 = phrases_.prob(p1, p0(p1)) * gen_.prob(false, gen_p0_); + phrases_.increment(p1, p0(p1), rng); + gen_.increment(false, gen_p0_, rng); + d2 *= phrases_.prob(p2, p0(p2)) * gen_.prob(false, gen_p0_); + phrases_.decrement(p1, rng); + gen_.decrement(false, rng); + z[j] = rng->SelectSample(d1, d2); + + if (z[j]) { + phrases_.increment(p1, p0(p1), rng); + phrases_.increment(p2, p0(p2), rng); + gen_.increment(false, gen_p0_, rng); + gen_.increment(false, gen_p0_, rng); + prev = j + 1; + } else { + phrases_.increment(p1p2, p0(p1p2), rng); + gen_.increment(false, gen_p0_, rng); + } + } + } + } +// cerr << endl << endl << gen_ << endl << phrases_ << endl; + cerr << gen_.prob(false, gen_p0_) << " " << gen_.prob(true, 1 - gen_p0_) << endl; + } + + void WriteCdecGrammarForCurrentSample(ostream* os) const { + CCRP<vector<int> >::const_iterator it = phrases_.begin(); + for (; it != phrases_.end(); ++it) { + (*os) << "[X] ||| " << Join(' ', it->first) << " ||| " + << Join('_', it->first) << " ||| C=1 P=" + << log(phrases_.prob(it->first, p0(it->first))) << endl; + } + } + + double OOVUnigramLogProb() const { + vector<int> x(1,99999999); + return log(phrases_.prob(x, p0(x))); + } + + void ResampleHyperparameters(MT19937* rng) { + phrases_.resample_hyperparameters(rng); + gen_.resample_hyperparameters(rng); + cerr << " d=" << phrases_.discount() << ",c=" << phrases_.concentration(); + } + + CCRP<vector<int> > phrases_; + CCRP<bool> gen_; + vector<vector<bool> > z_; // z_[i] is there a phrase boundary after the ith word + const vector<vector<int> >& corpus_; + const double uniform_word_; + const double gen_p0_; + const double p_end_; // in base length distribution, p of the end of a phrase + const bool use_poisson_; +}; + + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + shared_ptr<MT19937> prng; + if (conf.count("random_seed")) + prng.reset(new MT19937(conf["random_seed"].as<uint32_t>())); + else + prng.reset(new MT19937); + MT19937& rng = *prng; + + vector<vector<int> > corpus; + set<int> vocab; + ReadCorpus(conf["input"].as<string>(), &corpus, &vocab); + cerr << "Corpus size: " << corpus.size() << " sentences\n"; + cerr << "Vocabulary size: " << vocab.size() << " types\n"; + + UniphraseLM ulm(corpus, vocab, conf); + ulm.Sample(conf["samples"].as<unsigned>(), conf.count("no_hyperparameter_inference") == 0, &rng); + cerr << "OOV unigram prob: " << ulm.OOVUnigramLogProb() << endl; + + for (int i = 0; i < corpus.size(); ++i) +// SegCorpus::Write(corpus[i], shmmlm.z_[i], d); + ; + if (conf.count("write_cdec_grammar")) { + string fname = conf["write_cdec_grammar"].as<string>(); + cerr << "Writing model to " << fname << " ...\n"; + WriteFile wf(fname); + ulm.WriteCdecGrammarForCurrentSample(wf.stream()); + } + + if (conf.count("write_cdec_weights")) { + string fname = conf["write_cdec_weights"].as<string>(); + cerr << "Writing weights to " << fname << " .\n"; + WriteFile wf(fname); + ostream& os = *wf.stream(); + os << "# make C smaller to use more phrases\nP 1\nPassThrough " << ulm.OOVUnigramLogProb() << "\nC -3\n"; + } + + + + return 0; +} + diff --git a/phrasinator/train-phrasinator.pl b/phrasinator/train-phrasinator.pl new file mode 100755 index 00000000..de258caf --- /dev/null +++ b/phrasinator/train-phrasinator.pl @@ -0,0 +1,89 @@ +#!/usr/bin/perl -w +use strict; +my $script_dir; BEGIN { use Cwd qw/ abs_path cwd /; use File::Basename; $script_dir = dirname(abs_path($0)); push @INC, $script_dir; } +use Getopt::Long; +use File::Spec qw (rel2abs); + +my $DECODER = "$script_dir/../decoder/cdec"; +my $TRAINER = "$script_dir/gibbs_train_plm"; + +die "Can't find $TRAINER" unless -f $TRAINER; +die "Can't execute $TRAINER" unless -x $TRAINER; + +if (!GetOptions( + "decoder=s" => \$DECODER, +)) { usage(); } + +die "Can't find $DECODER" unless -f $DECODER; +die "Can't execute $DECODER" unless -x $DECODER; +if (scalar @ARGV != 2) { usage(); } +my $INFILE = shift @ARGV; +my $OUTDIR = shift @ARGV; +$OUTDIR = File::Spec->rel2abs($OUTDIR); +print STDERR " Input file: $INFILE\n"; +print STDERR "Output directory: $OUTDIR\n"; +open F, "<$INFILE" or die "Failed to open $INFILE for reading: $!"; +close F; +die "Please remove existing directory $OUTDIR\n" if (-f $OUTDIR || -d $OUTDIR); + +my $CMD = "mkdir $OUTDIR"; +safesystem($CMD) or die "Failed to create directory $OUTDIR\n$!"; + +my $grammar="$OUTDIR/grammar.gz"; +my $weights="$OUTDIR/weights"; +$CMD = "$TRAINER -w $weights -g $grammar -i $INFILE"; +safesystem($CMD) or die "Failed to train model!\n"; +my $cdecini = "$OUTDIR/cdec.ini"; +open C, ">$cdecini" or die "Failed to open $cdecini for writing: $!"; + +print C <<EOINI; +quiet=true +formalism=scfg +grammar=$grammar +add_pass_through_rules=true +weights=$OUTDIR/weights +EOINI + +close C; + +print <<EOT; + +Model trained successfully. Text can be decoded into phrasal units with +the following command: + + $DECODER -c $OUTDIR/cdec.ini < FILE.TXT + +EOT +exit(0); + +sub usage { + print <<EOT; +Usage: $0 [options] INPUT.TXT OUTPUT-DIRECTORY + + Infers a phrasal segmentation model from the tokenized text in INPUT.TXT + and writes it to OUTPUT-DIRECTORY/ so that it can be applied to other + text or have its granularity altered. + +EOT + exit(1); +} + +sub safesystem { + print STDERR "Executing: @_\n"; + system(@_); + if ($? == -1) { + print STDERR "ERROR: Failed to execute: @_\n $!\n"; + exit(1); + } + elsif ($? & 127) { + printf STDERR "ERROR: Execution of: @_\n died with signal %d, %s coredump\n", + ($? & 127), ($? & 128) ? 'with' : 'without'; + exit(1); + } + else { + my $exitcode = $? >> 8; + print STDERR "Exit code: $exitcode\n" if $exitcode; + return ! $exitcode; + } +} + diff --git a/phrasinator/train_plm.cc b/phrasinator/train_plm.cc deleted file mode 100644 index bb41ad26..00000000 --- a/phrasinator/train_plm.cc +++ /dev/null @@ -1,5 +0,0 @@ -#include <iostream> - -int main(int argc, char** argv) { - return 0; -} diff --git a/utils/dict.h b/utils/dict.h index 348a97e3..75ea3def 100644 --- a/utils/dict.h +++ b/utils/dict.h @@ -21,6 +21,30 @@ class Dict { inline int max() const { return words_.size(); } + static bool is_ws(char x) { + return (x == ' ' || x == '\t'); + } + + inline void ConvertWhitespaceDelimitedLine(const std::string& line, std::vector<int>* out) { + size_t cur = 0; + size_t last = 0; + int state = 0; + out->clear(); + while(cur < line.size()) { + if (is_ws(line[cur++])) { + if (state == 0) continue; + out->push_back(Convert(line.substr(last, cur - last - 1))); + state = 0; + } else { + if (state == 1) continue; + last = cur - 1; + state = 1; + } + } + if (state == 1) + out->push_back(Convert(line.substr(last, cur - last))); + } + inline WordID Convert(const std::string& word, bool frozen = false) { Map::iterator i = d_.find(word); if (i == d_.end()) { |