summaryrefslogtreecommitdiff
path: root/gi/markov_al/ml.cc
diff options
context:
space:
mode:
Diffstat (limited to 'gi/markov_al/ml.cc')
-rw-r--r--gi/markov_al/ml.cc470
1 files changed, 470 insertions, 0 deletions
diff --git a/gi/markov_al/ml.cc b/gi/markov_al/ml.cc
new file mode 100644
index 00000000..1e71edd6
--- /dev/null
+++ b/gi/markov_al/ml.cc
@@ -0,0 +1,470 @@
+#include <iostream>
+#include <tr1/unordered_map>
+
+#include <boost/shared_ptr.hpp>
+#include <boost/functional.hpp>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "tdict.h"
+#include "filelib.h"
+#include "sampler.h"
+#include "ccrp_onetable.h"
+#include "array2d.h"
+
+using namespace std;
+using namespace std::tr1;
+namespace po = boost::program_options;
+
+void PrintTopCustomers(const CCRP_OneTable<WordID>& crp) {
+ for (CCRP_OneTable<WordID>::const_iterator it = crp.begin(); it != crp.end(); ++it) {
+ cerr << " " << TD::Convert(it->first) << " = " << it->second << endl;
+ }
+}
+
+void PrintAlignment(const vector<WordID>& src, const vector<WordID>& trg, const vector<unsigned char>& a) {
+ cerr << TD::GetString(src) << endl << TD::GetString(trg) << endl;
+ Array2D<bool> al(src.size(), trg.size());
+ for (int i = 0; i < a.size(); ++i)
+ if (a[i] != 255) al(a[i], i) = true;
+ cerr << al << endl;
+}
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("samples,s",po::value<unsigned>()->default_value(1000),"Number of samples")
+ ("input,i",po::value<string>(),"Read parallel data from")
+ ("random_seed,S",po::value<uint32_t>(), "Random seed");
+ po::options_description clo("Command line options");
+ clo.add_options()
+ ("config", po::value<string>(), "Configuration file")
+ ("help,h", "Print this help message and exit");
+ po::options_description dconfig_options, dcmdline_options;
+ dconfig_options.add(opts);
+ dcmdline_options.add(opts).add(clo);
+
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ if (conf->count("config")) {
+ ifstream config((*conf)["config"].as<string>().c_str());
+ po::store(po::parse_config_file(config, dconfig_options), *conf);
+ }
+ po::notify(*conf);
+
+ if (conf->count("help") || (conf->count("input") == 0)) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+struct Unigram;
+struct Bigram {
+ Bigram() : trg(), cond() {}
+ Bigram(WordID prev, WordID cur, WordID t) : trg(t) { cond.first = prev; cond.second = cur; }
+ const pair<WordID,WordID>& ConditioningPair() const {
+ return cond;
+ }
+ WordID& prev_src() { return cond.first; }
+ WordID& cur_src() { return cond.second; }
+ const WordID& prev_src() const { return cond.first; }
+ const WordID& cur_src() const { return cond.second; }
+ WordID trg;
+ private:
+ pair<WordID, WordID> cond;
+};
+
+struct Unigram {
+ Unigram() : cur_src(), trg() {}
+ Unigram(WordID s, WordID t) : cur_src(s), trg(t) {}
+ WordID cur_src;
+ WordID trg;
+};
+
+ostream& operator<<(ostream& os, const Bigram& b) {
+ os << "( " << TD::Convert(b.trg) << " | " << TD::Convert(b.prev_src()) << " , " << TD::Convert(b.cur_src()) << " )";
+ return os;
+}
+
+ostream& operator<<(ostream& os, const Unigram& u) {
+ os << "( " << TD::Convert(u.trg) << " | " << TD::Convert(u.cur_src) << " )";
+ return os;
+}
+
+bool operator==(const Bigram& a, const Bigram& b) {
+ return a.trg == b.trg && a.cur_src() == b.cur_src() && a.prev_src() == b.prev_src();
+}
+
+bool operator==(const Unigram& a, const Unigram& b) {
+ return a.trg == b.trg && a.cur_src == b.cur_src;
+}
+
+size_t hash_value(const Bigram& b) {
+ size_t h = boost::hash_value(b.prev_src());
+ boost::hash_combine(h, boost::hash_value(b.cur_src()));
+ boost::hash_combine(h, boost::hash_value(b.trg));
+ return h;
+}
+
+size_t hash_value(const Unigram& u) {
+ size_t h = boost::hash_value(u.cur_src);
+ boost::hash_combine(h, boost::hash_value(u.trg));
+ return h;
+}
+
+void ReadParallelCorpus(const string& filename,
+ vector<vector<WordID> >* f,
+ vector<vector<WordID> >* e,
+ set<WordID>* vocab_f,
+ set<WordID>* vocab_e) {
+ f->clear();
+ e->clear();
+ vocab_f->clear();
+ vocab_e->clear();
+ istream* in;
+ if (filename == "-")
+ in = &cin;
+ else
+ in = new ifstream(filename.c_str());
+ assert(*in);
+ string line;
+ const WordID kDIV = TD::Convert("|||");
+ vector<WordID> tmp;
+ while(*in) {
+ getline(*in, line);
+ if (line.empty() && !*in) break;
+ e->push_back(vector<int>());
+ f->push_back(vector<int>());
+ vector<int>& le = e->back();
+ vector<int>& lf = f->back();
+ tmp.clear();
+ TD::ConvertSentence(line, &tmp);
+ bool isf = true;
+ for (unsigned i = 0; i < tmp.size(); ++i) {
+ const int cur = tmp[i];
+ if (isf) {
+ if (kDIV == cur) { isf = false; } else {
+ lf.push_back(cur);
+ vocab_f->insert(cur);
+ }
+ } else {
+ assert(cur != kDIV);
+ le.push_back(cur);
+ vocab_e->insert(cur);
+ }
+ }
+ assert(isf == false);
+ }
+ if (in != &cin) delete in;
+}
+
+struct UnigramModel {
+ UnigramModel(size_t src_voc_size, size_t trg_voc_size) :
+ unigrams(TD::NumWords() + 1, CCRP_OneTable<WordID>(1,1,1,1)),
+ p0(1.0 / trg_voc_size) {}
+
+ void increment(const Bigram& b) {
+ unigrams[b.cur_src()].increment(b.trg);
+ }
+
+ void decrement(const Bigram& b) {
+ unigrams[b.cur_src()].decrement(b.trg);
+ }
+
+ double prob(const Bigram& b) const {
+ const double q0 = unigrams[b.cur_src()].prob(b.trg, p0);
+ return q0;
+ }
+
+ double LogLikelihood() const {
+ double llh = 0;
+ for (unsigned i = 0; i < unigrams.size(); ++i) {
+ const CCRP_OneTable<WordID>& crp = unigrams[i];
+ if (crp.num_customers() > 0) {
+ llh += crp.log_crp_prob();
+ llh += crp.num_tables() * log(p0);
+ }
+ }
+ return llh;
+ }
+
+ void ResampleHyperparameters(MT19937* rng) {
+ for (unsigned i = 0; i < unigrams.size(); ++i)
+ unigrams[i].resample_hyperparameters(rng);
+ }
+
+ vector<CCRP_OneTable<WordID> > unigrams; // unigrams[src].prob(trg, p0) = p(trg|src)
+
+ const double p0;
+};
+
+struct BigramModel {
+ BigramModel(size_t src_voc_size, size_t trg_voc_size) :
+ unigrams(TD::NumWords() + 1, CCRP_OneTable<WordID>(1,1,1,1)),
+ p0(1.0 / trg_voc_size) {}
+
+ void increment(const Bigram& b) {
+ BigramMap::iterator it = bigrams.find(b.ConditioningPair());
+ if (it == bigrams.end()) {
+ it = bigrams.insert(make_pair(b.ConditioningPair(), CCRP_OneTable<WordID>(1,1,1,1))).first;
+ }
+ if (it->second.increment(b.trg))
+ unigrams[b.cur_src()].increment(b.trg);
+ }
+
+ void decrement(const Bigram& b) {
+ BigramMap::iterator it = bigrams.find(b.ConditioningPair());
+ assert(it != bigrams.end());
+ if (it->second.decrement(b.trg)) {
+ unigrams[b.cur_src()].decrement(b.trg);
+ if (it->second.num_customers() == 0)
+ bigrams.erase(it);
+ }
+ }
+
+ double prob(const Bigram& b) const {
+ const double q0 = unigrams[b.cur_src()].prob(b.trg, p0);
+ const BigramMap::const_iterator it = bigrams.find(b.ConditioningPair());
+ if (it == bigrams.end()) return q0;
+ return it->second.prob(b.trg, q0);
+ }
+
+ double LogLikelihood() const {
+ double llh = 0;
+ for (unsigned i = 0; i < unigrams.size(); ++i) {
+ const CCRP_OneTable<WordID>& crp = unigrams[i];
+ if (crp.num_customers() > 0) {
+ llh += crp.log_crp_prob();
+ llh += crp.num_tables() * log(p0);
+ }
+ }
+ for (BigramMap::const_iterator it = bigrams.begin(); it != bigrams.end(); ++it) {
+ const CCRP_OneTable<WordID>& crp = it->second;
+ const WordID cur_src = it->first.second;
+ llh += crp.log_crp_prob();
+ for (CCRP_OneTable<WordID>::const_iterator bit = crp.begin(); bit != crp.end(); ++bit) {
+ llh += log(unigrams[cur_src].prob(bit->second, p0));
+ }
+ }
+ return llh;
+ }
+
+ void ResampleHyperparameters(MT19937* rng) {
+ for (unsigned i = 0; i < unigrams.size(); ++i)
+ unigrams[i].resample_hyperparameters(rng);
+ for (BigramMap::iterator it = bigrams.begin(); it != bigrams.end(); ++it)
+ it->second.resample_hyperparameters(rng);
+ }
+
+ typedef unordered_map<pair<WordID,WordID>, CCRP_OneTable<WordID>, boost::hash<pair<WordID,WordID> > > BigramMap;
+ BigramMap bigrams; // bigrams[(src-1,src)].prob(trg, q0) = p(trg|src,src-1)
+ vector<CCRP_OneTable<WordID> > unigrams; // unigrams[src].prob(trg, p0) = p(trg|src)
+
+ const double p0;
+};
+
+struct BigramAlignmentModel {
+ BigramAlignmentModel(size_t src_voc_size, size_t trg_voc_size) : bigrams(TD::NumWords() + 1, CCRP_OneTable<WordID>(1,1,1,1)), p0(1.0 / src_voc_size) {}
+ void increment(WordID prev, WordID next) {
+ bigrams[prev].increment(next); // hierarchy?
+ }
+ void decrement(WordID prev, WordID next) {
+ bigrams[prev].decrement(next); // hierarchy?
+ }
+ double prob(WordID prev, WordID next) {
+ return bigrams[prev].prob(next, p0);
+ }
+ double LogLikelihood() const {
+ double llh = 0;
+ for (unsigned i = 0; i < bigrams.size(); ++i) {
+ const CCRP_OneTable<WordID>& crp = bigrams[i];
+ if (crp.num_customers() > 0) {
+ llh += crp.log_crp_prob();
+ llh += crp.num_tables() * log(p0);
+ }
+ }
+ return llh;
+ }
+
+ vector<CCRP_OneTable<WordID> > bigrams; // bigrams[prev].prob(next, p0) = p(next|prev)
+ const double p0;
+};
+
+struct Alignment {
+ vector<unsigned char> a;
+};
+
+int main(int argc, char** argv) {
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+ const unsigned samples = conf["samples"].as<unsigned>();
+
+ boost::shared_ptr<MT19937> prng;
+ if (conf.count("random_seed"))
+ prng.reset(new MT19937(conf["random_seed"].as<uint32_t>()));
+ else
+ prng.reset(new MT19937);
+ MT19937& rng = *prng;
+
+ vector<vector<WordID> > corpuse, corpusf;
+ set<WordID> vocabe, vocabf;
+ cerr << "Reading corpus...\n";
+ ReadParallelCorpus(conf["input"].as<string>(), &corpusf, &corpuse, &vocabf, &vocabe);
+ cerr << "F-corpus size: " << corpusf.size() << " sentences\t (" << vocabf.size() << " word types)\n";
+ cerr << "E-corpus size: " << corpuse.size() << " sentences\t (" << vocabe.size() << " word types)\n";
+ assert(corpusf.size() == corpuse.size());
+ const size_t corpus_len = corpusf.size();
+ const WordID kNULL = TD::Convert("<eps>");
+ const WordID kBOS = TD::Convert("<s>");
+ const WordID kEOS = TD::Convert("</s>");
+ Bigram TT(kBOS, TD::Convert("我"), TD::Convert("i"));
+ Bigram TT2(kBOS, TD::Convert("要"), TD::Convert("i"));
+
+ UnigramModel model(vocabf.size(), vocabe.size());
+ vector<Alignment> alignments(corpus_len);
+ for (unsigned ci = 0; ci < corpus_len; ++ci) {
+ const vector<WordID>& src = corpusf[ci];
+ const vector<WordID>& trg = corpuse[ci];
+ vector<unsigned char>& alg = alignments[ci].a;
+ alg.resize(trg.size());
+ int lenp1 = src.size() + 1;
+ WordID prev_src = kBOS;
+ for (int j = 0; j < trg.size(); ++j) {
+ int samp = lenp1 * rng.next();
+ --samp;
+ if (samp < 0) samp = 255;
+ alg[j] = samp;
+ WordID cur_src = (samp == 255 ? kNULL : src[alg[j]]);
+ Bigram b(prev_src, cur_src, trg[j]);
+ model.increment(b);
+ prev_src = cur_src;
+ }
+ Bigram b(prev_src, kEOS, kEOS);
+ model.increment(b);
+ }
+ cerr << "Initial LLH: " << model.LogLikelihood() << endl;
+
+ SampleSet<double> ss;
+ for (unsigned si = 0; si < 50; ++si) {
+ for (unsigned ci = 0; ci < corpus_len; ++ci) {
+ const vector<WordID>& src = corpusf[ci];
+ const vector<WordID>& trg = corpuse[ci];
+ vector<unsigned char>& alg = alignments[ci].a;
+ WordID prev_src = kBOS;
+ for (unsigned j = 0; j < trg.size(); ++j) {
+ unsigned char& a_j = alg[j];
+ WordID cur_e_a_j = (a_j == 255 ? kNULL : src[a_j]);
+ Bigram b(prev_src, cur_e_a_j, trg[j]);
+ //cerr << "DEC: " << b << "\t" << nextb << endl;
+ model.decrement(b);
+ ss.clear();
+ for (unsigned i = 0; i <= src.size(); ++i) {
+ const WordID cur_src = (i ? src[i-1] : kNULL);
+ b.cur_src() = cur_src;
+ ss.add(model.prob(b));
+ }
+ int sampled_a_j = rng.SelectSample(ss);
+ a_j = (sampled_a_j ? sampled_a_j - 1 : 255);
+ cur_e_a_j = (a_j == 255 ? kNULL : src[a_j]);
+ b.cur_src() = cur_e_a_j;
+ //cerr << "INC: " << b << "\t" << nextb << endl;
+ model.increment(b);
+ prev_src = cur_e_a_j;
+ }
+ }
+ cerr << '.' << flush;
+ if (si % 10 == 9) {
+ cerr << "[LLH prev=" << model.LogLikelihood();
+ //model.ResampleHyperparameters(&rng);
+ cerr << " new=" << model.LogLikelihood() << "]\n";
+ //pair<WordID,WordID> xx = make_pair(kBOS, TD::Convert("我"));
+ //PrintTopCustomers(model.bigrams.find(xx)->second);
+ cerr << "p(" << TT << ") = " << model.prob(TT) << endl;
+ cerr << "p(" << TT2 << ") = " << model.prob(TT2) << endl;
+ PrintAlignment(corpusf[0], corpuse[0], alignments[0].a);
+ }
+ }
+ {
+ // MODEL 2
+ BigramModel model(vocabf.size(), vocabe.size());
+ BigramAlignmentModel amodel(vocabf.size(), vocabe.size());
+ for (unsigned ci = 0; ci < corpus_len; ++ci) {
+ const vector<WordID>& src = corpusf[ci];
+ const vector<WordID>& trg = corpuse[ci];
+ vector<unsigned char>& alg = alignments[ci].a;
+ WordID prev_src = kBOS;
+ for (int j = 0; j < trg.size(); ++j) {
+ WordID cur_src = (alg[j] == 255 ? kNULL : src[alg[j]]);
+ Bigram b(prev_src, cur_src, trg[j]);
+ model.increment(b);
+ amodel.increment(prev_src, cur_src);
+ prev_src = cur_src;
+ }
+ amodel.increment(prev_src, kEOS);
+ Bigram b(prev_src, kEOS, kEOS);
+ model.increment(b);
+ }
+ cerr << "Initial LLH: " << model.LogLikelihood() << " " << amodel.LogLikelihood() << endl;
+
+ SampleSet<double> ss;
+ for (unsigned si = 0; si < samples; ++si) {
+ for (unsigned ci = 0; ci < corpus_len; ++ci) {
+ const vector<WordID>& src = corpusf[ci];
+ const vector<WordID>& trg = corpuse[ci];
+ vector<unsigned char>& alg = alignments[ci].a;
+ WordID prev_src = kBOS;
+ for (unsigned j = 0; j < trg.size(); ++j) {
+ unsigned char& a_j = alg[j];
+ WordID cur_e_a_j = (a_j == 255 ? kNULL : src[a_j]);
+ Bigram b(prev_src, cur_e_a_j, trg[j]);
+ WordID next_src = kEOS;
+ WordID next_trg = kEOS;
+ if (j < (trg.size() - 1)) {
+ next_src = (alg[j+1] == 255 ? kNULL : src[alg[j + 1]]);
+ next_trg = trg[j + 1];
+ }
+ Bigram nextb(cur_e_a_j, next_src, next_trg);
+ //cerr << "DEC: " << b << "\t" << nextb << endl;
+ model.decrement(b);
+ model.decrement(nextb);
+ amodel.decrement(prev_src, cur_e_a_j);
+ amodel.decrement(cur_e_a_j, next_src);
+ ss.clear();
+ for (unsigned i = 0; i <= src.size(); ++i) {
+ const WordID cur_src = (i ? src[i-1] : kNULL);
+ b.cur_src() = cur_src;
+ ss.add(model.prob(b) * model.prob(nextb) * amodel.prob(prev_src, cur_src) * amodel.prob(cur_src, next_src));
+ //cerr << log(ss[ss.size() - 1]) << "\t" << b << endl;
+ }
+ int sampled_a_j = rng.SelectSample(ss);
+ a_j = (sampled_a_j ? sampled_a_j - 1 : 255);
+ cur_e_a_j = (a_j == 255 ? kNULL : src[a_j]);
+ b.cur_src() = cur_e_a_j;
+ nextb.prev_src() = cur_e_a_j;
+ //cerr << "INC: " << b << "\t" << nextb << endl;
+ //exit(1);
+ model.increment(b);
+ model.increment(nextb);
+ amodel.increment(prev_src, cur_e_a_j);
+ amodel.increment(cur_e_a_j, next_src);
+ prev_src = cur_e_a_j;
+ }
+ }
+ cerr << '.' << flush;
+ if (si % 10 == 9) {
+ cerr << "[LLH prev=" << (model.LogLikelihood() + amodel.LogLikelihood());
+ //model.ResampleHyperparameters(&rng);
+ cerr << " new=" << model.LogLikelihood() << "]\n";
+ pair<WordID,WordID> xx = make_pair(kBOS, TD::Convert("我"));
+ cerr << "p(" << TT << ") = " << model.prob(TT) << endl;
+ cerr << "p(" << TT2 << ") = " << model.prob(TT2) << endl;
+ pair<WordID,WordID> xx2 = make_pair(kBOS, TD::Convert("要"));
+ PrintTopCustomers(model.bigrams.find(xx)->second);
+ //PrintTopCustomers(amodel.bigrams[TD::Convert("<s>")]);
+ //PrintTopCustomers(model.unigrams[TD::Convert("<eps>")]);
+ PrintAlignment(corpusf[0], corpuse[0], alignments[0].a);
+ }
+ }
+ }
+ return 0;
+}
+