diff options
author | Chris Dyer <prguest11@taipan.cs> | 2012-02-27 02:19:34 +0000 |
---|---|---|
committer | Chris Dyer <prguest11@taipan.cs> | 2012-02-27 02:19:34 +0000 |
commit | dc2b2fc395ad496851f723c4da59181445c07047 (patch) | |
tree | 7774b4a528ae467abfc702d9a7f6510f60cb3b48 | |
parent | e279f1fd267bc18763fa8ff456462c5e677689e9 (diff) |
generic bayesian cfg learner with a bunch of cfg grammar types
-rw-r--r-- | .gitignore | 1 | ||||
-rw-r--r-- | decoder/trule.cc | 16 | ||||
-rw-r--r-- | gi/pf/Makefile.am | 4 | ||||
-rw-r--r-- | gi/pf/learn_cfg.cc (renamed from gi/pf/hierolm.cc) | 175 |
4 files changed, 134 insertions, 62 deletions
@@ -57,6 +57,7 @@ training/mpi_extract_reachable klm/lm/build_binary extools/extractor_monolingual gi/pf/.deps +gi/pf/learn_cfg gi/pf/brat gi/pf/cbgi gi/pf/dpnaive diff --git a/decoder/trule.cc b/decoder/trule.cc index 40235542..141b8faa 100644 --- a/decoder/trule.cc +++ b/decoder/trule.cc @@ -232,16 +232,6 @@ void TRule::ComputeArity() { arity_ = 1 - min; } -static string AnonymousStrVar(int i) { - string res("[v]"); - if(!(i <= 0 && i >= -8)) { - cerr << "Can't handle more than 9 non-terminals: index=" << (-i) << endl; - abort(); - } - res[1] = '1' - i; - return res; -} - string TRule::AsString(bool verbose) const { ostringstream os; int idx = 0; @@ -259,15 +249,11 @@ string TRule::AsString(bool verbose) const { } } os << " ||| "; - if (idx > 9) { - cerr << "Too many non-terminals!\n partial: " << os.str() << endl; - exit(1); - } for (int i =0; i<e_.size(); ++i) { if (i) os << ' '; const WordID& w = e_[i]; if (w < 1) - os << AnonymousStrVar(w); + os << '[' << (1-w) << ']'; else os << TD::Convert(w); } diff --git a/gi/pf/Makefile.am b/gi/pf/Makefile.am index ed5b6fd3..0cf0bc63 100644 --- a/gi/pf/Makefile.am +++ b/gi/pf/Makefile.am @@ -1,4 +1,4 @@ -bin_PROGRAMS = cbgi brat dpnaive pfbrat pfdist itg pfnaive condnaive align-lexonly align-lexonly-pyp hierolm +bin_PROGRAMS = cbgi brat dpnaive pfbrat pfdist itg pfnaive condnaive align-lexonly align-lexonly-pyp learn_cfg noinst_LIBRARIES = libpf.a libpf_a_SOURCES = base_distributions.cc reachability.cc cfg_wfst_composer.cc corpus.cc unigrams.cc ngram_base.cc @@ -9,7 +9,7 @@ align_lexonly_pyp_SOURCES = align-lexonly-pyp.cc itg_SOURCES = itg.cc -hierolm_SOURCES = hierolm.cc +learn_cfg_SOURCES = learn_cfg.cc condnaive_SOURCES = condnaive.cc diff --git a/gi/pf/hierolm.cc b/gi/pf/learn_cfg.cc index afb12fef..3d202816 100644 --- a/gi/pf/hierolm.cc +++ b/gi/pf/learn_cfg.cc @@ -25,12 +25,21 @@ using namespace tr1; namespace po = boost::program_options; shared_ptr<MT19937> prng; +vector<int> nt_vocab; +vector<int> nt_id_to_index; +static unsigned kMAX_RULE_SIZE = 0; +static unsigned kMAX_ARITY = 0; +static bool kALLOW_MIXED = true; // allow rules with mixed terminals and NTs void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() ("samples,s",po::value<unsigned>()->default_value(1000),"Number of samples") ("input,i",po::value<string>(),"Read parallel data from") + ("max_rule_size,m", po::value<unsigned>()->default_value(0), "Maximum rule size (0 for unlimited)") + ("max_arity,a", po::value<unsigned>()->default_value(0), "Maximum number of nonterminals in a rule (0 for unlimited)") + ("no_mixed_rules,M", "Do not mix terminals and nonterminals in a rule RHS") + ("nonterminals,n", po::value<unsigned>()->default_value(1), "Size of nonterminal vocabulary") ("random_seed,S",po::value<uint32_t>(), "Random seed"); po::options_description clo("Command line options"); clo.add_options() @@ -53,9 +62,9 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { } } -void ReadCorpus(const string& filename, - vector<vector<WordID> >* e, - set<WordID>* vocab_e) { +unsigned ReadCorpus(const string& filename, + vector<vector<WordID> >* e, + set<WordID>* vocab_e) { e->clear(); vocab_e->clear(); istream* in; @@ -65,6 +74,7 @@ void ReadCorpus(const string& filename, in = new ifstream(filename.c_str()); assert(*in); string line; + unsigned toks = 0; while(*in) { getline(*in, line); if (line.empty() && !*in) break; @@ -73,8 +83,10 @@ void ReadCorpus(const string& filename, TD::ConvertSentence(line, &le); for (unsigned i = 0; i < le.size(); ++i) vocab_e->insert(le[i]); + toks += le.size(); } if (in != &cin) delete in; + return toks; } struct Grid { @@ -107,29 +119,32 @@ struct BaseRuleModel { }; struct HieroLMModel { - explicit HieroLMModel(unsigned vocab_size) : p0(vocab_size), x(1,1,1,1) {} + explicit HieroLMModel(unsigned vocab_size, unsigned num_nts = 1) : p0(vocab_size, num_nts), nts(num_nts, CCRP<TRule>(1,1,1,1)) {} prob_t Prob(const TRule& r) const { - return x.probT<prob_t>(r, p0(r)); + return nts[nt_id_to_index[-r.lhs_]].probT<prob_t>(r, p0(r)); } int Increment(const TRule& r, MT19937* rng) { - return x.incrementT<prob_t>(r, p0(r), rng); + return nts[nt_id_to_index[-r.lhs_]].incrementT<prob_t>(r, p0(r), rng); // return x.increment(r); } int Decrement(const TRule& r, MT19937* rng) { - return x.decrement(r, rng); + return nts[nt_id_to_index[-r.lhs_]].decrement(r, rng); //return x.decrement(r); } prob_t Likelihood() const { - prob_t p; - p.logeq(x.log_crp_prob()); - for (CCRP<TRule>::const_iterator it = x.begin(); it != x.end(); ++it) { - prob_t tp = p0(it->first); - tp.poweq(it->second.table_counts_.size()); - p *= tp; + prob_t p = prob_t::One(); + for (unsigned i = 0; i < nts.size(); ++i) { + prob_t q; q.logeq(nts[i].log_crp_prob()); + p *= q; + for (CCRP<TRule>::const_iterator it = nts[i].begin(); it != nts[i].end(); ++it) { + prob_t tp = p0(it->first); + tp.poweq(it->second.table_counts_.size()); + p *= tp; + } } //for (CCRP_OneTable<TRule>::const_iterator it = x.begin(); it != x.end(); ++it) // p *= p0(it->first); @@ -137,12 +152,13 @@ struct HieroLMModel { } void ResampleHyperparameters(MT19937* rng) { - x.resample_hyperparameters(rng); - cerr << " d=" << x.discount() << ", alpha=" << x.concentration() << endl; + for (unsigned i = 0; i < nts.size(); ++i) + nts[i].resample_hyperparameters(rng); + cerr << " d=" << nts[0].discount() << ", alpha=" << nts[0].concentration() << endl; } const BaseRuleModel p0; - CCRP<TRule> x; + vector<CCRP<TRule> > nts; //CCRP_OneTable<TRule> x; }; @@ -152,24 +168,29 @@ HieroLMModel* plm; struct NPGrammarIter : public GrammarIter, public RuleBin { NPGrammarIter() : arity() { tofreelist.push_back(this); } - NPGrammarIter(const TRulePtr& inr, const int a, int symbol) : arity(a + (symbol < 0 ? 1 : 0)) { + NPGrammarIter(const TRulePtr& inr, const int a, int symbol) : arity(a) { if (inr) { r.reset(new TRule(*inr)); } else { - static const int kLHS = -TD::Convert("X"); r.reset(new TRule); - r->lhs_ = kLHS; } TRule& rr = *r; + rr.lhs_ = nt_vocab[0]; rr.f_.push_back(symbol); rr.e_.push_back(symbol < 0 ? (1-int(arity)) : symbol); tofreelist.push_back(this); } + inline static unsigned NextArity(int cur_a, int symbol) { + return cur_a + (symbol <= 0 ? 1 : 0); + } virtual int GetNumRules() const { - if (r) return 1; else return 0; + if (r) return nt_vocab.size(); else return 0; } - virtual TRulePtr GetIthRule(int) const { - return r; + virtual TRulePtr GetIthRule(int i) const { + if (i == 0) return r; + TRulePtr nr(new TRule(*r)); + nr->lhs_ = nt_vocab[i]; + return nr; } virtual int Arity() const { return arity; @@ -178,7 +199,18 @@ struct NPGrammarIter : public GrammarIter, public RuleBin { if (!r) return NULL; else return this; } virtual const GrammarIter* Extend(int symbol) const { - return new NPGrammarIter(r, arity, symbol); + const int next_arity = NextArity(arity, symbol); + if (kMAX_ARITY && next_arity > kMAX_ARITY) + return NULL; + if (!kALLOW_MIXED && r) { + bool t1 = r->f_.front() <= 0; + bool t2 = symbol <= 0; + if (t1 != t2) return NULL; + } + if (!kMAX_RULE_SIZE || !r || (r->f_.size() < kMAX_RULE_SIZE)) + return new NPGrammarIter(r, next_arity, symbol); + else + return NULL; } const unsigned char arity; TRulePtr r; @@ -190,12 +222,15 @@ struct NPGrammar : public Grammar { } }; -void SampleDerivation(const Hypergraph& hg, MT19937* rng, vector<unsigned>* sampled_deriv, HieroLMModel* plm) { - HieroLMModel& lm = *plm; +prob_t TotalProb(const Hypergraph& hg) { + return Inside<prob_t, EdgeProb>(hg); +} + +void SampleDerivation(const Hypergraph& hg, MT19937* rng, vector<unsigned>* sampled_deriv) { vector<prob_t> node_probs; - const prob_t total_prob = Inside<prob_t, EdgeProb>(hg, &node_probs); + Inside<prob_t, EdgeProb>(hg, &node_probs); queue<unsigned> q; - q.push(hg.nodes_.size() - 3); + q.push(hg.nodes_.size() - 2); while(!q.empty()) { unsigned cur_node_id = q.front(); // cerr << "NODE=" << cur_node_id << endl; @@ -248,53 +283,95 @@ void DecrementDerivation(const Hypergraph& hg, const vector<unsigned>& d, HieroL int main(int argc, char** argv) { po::variables_map conf; + + InitCommandLine(argc, argv, &conf); + nt_vocab.resize(conf["nonterminals"].as<unsigned>()); + assert(nt_vocab.size() > 0); + assert(nt_vocab.size() < 26); + { + string nt = "X"; + for (unsigned i = 0; i < nt_vocab.size(); ++i) { + if (nt_vocab.size() > 1) nt[0] = ('A' + i); + int pid = TD::Convert(nt); + nt_vocab[i] = -pid; + if (pid >= nt_id_to_index.size()) { + nt_id_to_index.resize(pid + 1, -1); + } + nt_id_to_index[pid] = i; + } + } vector<GrammarPtr> grammars; grammars.push_back(GrammarPtr(new NPGrammar)); - InitCommandLine(argc, argv, &conf); const unsigned samples = conf["samples"].as<unsigned>(); + kMAX_RULE_SIZE = conf["max_rule_size"].as<unsigned>(); + if (kMAX_RULE_SIZE == 1) { + cerr << "Invalid maximum rule size: must be 0 or >1\n"; + return 1; + } + kMAX_ARITY = conf["max_arity"].as<unsigned>(); + if (kMAX_ARITY == 1) { + cerr << "Invalid maximum arity: must be 0 or >1\n"; + return 1; + } + kALLOW_MIXED = !conf.count("no_mixed_rules"); if (conf.count("random_seed")) prng.reset(new MT19937(conf["random_seed"].as<uint32_t>())); else prng.reset(new MT19937); MT19937& rng = *prng; - vector<vector<WordID> > corpuse; set<WordID> vocabe; cerr << "Reading corpus...\n"; - ReadCorpus(conf["input"].as<string>(), &corpuse, &vocabe); + const unsigned toks = ReadCorpus(conf["input"].as<string>(), &corpuse, &vocabe); cerr << "E-corpus size: " << corpuse.size() << " sentences\t (" << vocabe.size() << " word types)\n"; - HieroLMModel lm(vocabe.size()); + HieroLMModel lm(vocabe.size(), nt_vocab.size()); plm = &lm; - ExhaustiveBottomUpParser parser("X", grammars); + ExhaustiveBottomUpParser parser(TD::Convert(-nt_vocab[0]), grammars); Hypergraph hg; - const int kX = -TD::Convert("X"); + const int kGoal = -TD::Convert("Goal"); const int kLP = FD::Convert("LogProb"); SparseVector<double> v; v.set_value(kLP, 1.0); vector<vector<unsigned> > derivs(corpuse.size()); + vector<Lattice> cl(corpuse.size()); + for (int ci = 0; ci < corpuse.size(); ++ci) { + vector<int>& src = corpuse[ci]; + Lattice& lat = cl[ci]; + lat.resize(src.size()); + for (unsigned i = 0; i < src.size(); ++i) + lat[i].push_back(LatticeArc(src[i], 0.0, 1)); + } for (int SS=0; SS < samples; ++SS) { + const bool is_last = ((samples - 1) == SS); + prob_t dlh = prob_t::One(); for (int ci = 0; ci < corpuse.size(); ++ci) { - vector<int>& src = corpuse[ci]; - Lattice lat(src.size()); - for (unsigned i = 0; i < src.size(); ++i) - lat[i].push_back(LatticeArc(src[i], 0.0, 1)); + const vector<int>& src = corpuse[ci]; + const Lattice& lat = cl[ci]; cerr << TD::GetString(src) << endl; hg.clear(); parser.Parse(lat, &hg); // exhaustive parse - DecrementDerivation(hg, derivs[ci], &lm, &rng); + vector<unsigned>& d = derivs[ci]; + if (!is_last) DecrementDerivation(hg, d, &lm, &rng); for (unsigned i = 0; i < hg.edges_.size(); ++i) { TRule& r = *hg.edges_[i].rule_; - if (r.lhs_ == kX) + if (r.lhs_ == kGoal) + hg.edges_[i].edge_prob_ = prob_t::One(); + else hg.edges_[i].edge_prob_ = lm.Prob(r); } - vector<unsigned> d; - SampleDerivation(hg, &rng, &d, &lm); - derivs[ci] = d; - IncrementDerivation(hg, derivs[ci], &lm, &rng); - if (tofreelist.size() > 100000) { + if (!is_last) { + d.clear(); + SampleDerivation(hg, &rng, &d); + IncrementDerivation(hg, derivs[ci], &lm, &rng); + } else { + prob_t p = TotalProb(hg); + dlh *= p; + cerr << " p(sentence) = " << log(p) << "\t" << log(dlh) << endl; + } + if (tofreelist.size() > 200000) { cerr << "Freeing ... "; for (unsigned i = 0; i < tofreelist.size(); ++i) delete tofreelist[i]; @@ -302,8 +379,16 @@ int main(int argc, char** argv) { cerr << "Freed.\n"; } } - cerr << "LLH=" << lm.Likelihood() << endl; + double llh = log(lm.Likelihood()); + cerr << "LLH=" << llh << "\tENTROPY=" << (-llh / log(2) / toks) << "\tPPL=" << pow(2, -llh / log(2) / toks) << endl; + if (SS % 10 == 9) lm.ResampleHyperparameters(&rng); + if (is_last) { + double z = log(dlh); + cerr << "TOTAL_PROB=" << z << "\tENTROPY=" << (-z / log(2) / toks) << "\tPPL=" << pow(2, -z / log(2) / toks) << endl; + } } + for (unsigned i = 0; i < nt_vocab.size(); ++i) + cerr << lm.nts[i] << endl; return 0; } |