From 7fd9fe26f00cf31a7b407364399d37b4eaf04eba Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 7 Mar 2012 20:25:53 -0500 Subject: lattice builder --- gi/pf/align-tl.cc | 334 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 334 insertions(+) create mode 100644 gi/pf/align-tl.cc (limited to 'gi/pf/align-tl.cc') diff --git a/gi/pf/align-tl.cc b/gi/pf/align-tl.cc new file mode 100644 index 00000000..0e0454e5 --- /dev/null +++ b/gi/pf/align-tl.cc @@ -0,0 +1,334 @@ +#include +#include +#include + +#include +#include +#include + +#include "array2d.h" +#include "base_distributions.h" +#include "monotonic_pseg.h" +#include "conditional_pseg.h" +#include "trule.h" +#include "tdict.h" +#include "stringlib.h" +#include "filelib.h" +#include "dict.h" +#include "sampler.h" +#include "mfcr.h" +#include "corpus.h" +#include "ngram_base.h" +#include "transliterations.h" + +using namespace std; +using namespace tr1; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("samples,s",po::value()->default_value(1000),"Number of samples") + ("input,i",po::value(),"Read parallel data from") + ("random_seed,S",po::value(), "Random seed"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || (conf->count("input") == 0)) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +shared_ptr prng; + +struct LexicalAlignment { + unsigned char src_index; + bool is_transliteration; + vector > derivation; +}; + +struct AlignedSentencePair { + vector src; + vector trg; + vector a; + Array2D posterior; +}; + +struct HierarchicalWordBase { + explicit HierarchicalWordBase(const unsigned vocab_e_size) : + base(prob_t::One()), r(1,1,1,1,0.66,50.0), u0(-log(vocab_e_size)), l(1,prob_t::One()), v(1, prob_t::Zero()) {} + + void ResampleHyperparameters(MT19937* rng) { + r.resample_hyperparameters(rng); + } + + inline double logp0(const vector& s) const { + return Md::log_poisson(s.size(), 7.5) + s.size() * u0; + } + + // return p0 of rule.e_ + prob_t operator()(const TRule& rule) const { + v[0].logeq(logp0(rule.e_)); + return r.prob(rule.e_, v.begin(), l.begin()); + } + + void Increment(const TRule& rule) { + v[0].logeq(logp0(rule.e_)); + if (r.increment(rule.e_, v.begin(), l.begin(), &*prng).count) { + base *= v[0] * l[0]; + } + } + + void Decrement(const TRule& rule) { + if (r.decrement(rule.e_, &*prng).count) { + base /= prob_t(exp(logp0(rule.e_))); + } + } + + prob_t Likelihood() const { + prob_t p; p.logeq(r.log_crp_prob()); + p *= base; + return p; + } + + void Summary() const { + cerr << "NUMBER OF CUSTOMERS: " << r.num_customers() << " (d=" << r.discount() << ",s=" << r.strength() << ')' << endl; + for (MFCR<1,vector >::const_iterator it = r.begin(); it != r.end(); ++it) + cerr << " " << it->second.total_dish_count_ << " (on " << it->second.table_counts_.size() << " tables) " << TD::GetString(it->first) << endl; + } + + prob_t base; + MFCR<1,vector > r; + const double u0; + const vector l; + mutable vector v; +}; + +struct BasicLexicalAlignment { + explicit BasicLexicalAlignment(const vector >& lets, + const unsigned words_e, + const unsigned letters_e, + vector* corp) : + letters(lets), + corpus(*corp), + //up0(words_e), + //up0("en.chars.1gram", letters_e), + //up0("en.words.1gram"), + up0(letters_e), + //up0("en.chars.2gram"), + tmodel(up0) { + } + + void InstantiateRule(const WordID src, + const WordID trg, + TRule* rule) const { + static const WordID kX = TD::Convert("X") * -1; + rule->lhs_ = kX; + rule->e_ = letters[trg]; + rule->f_ = letters[src]; + } + + void InitializeRandom() { + const WordID kNULL = TD::Convert("NULL"); + cerr << "Initializing with random alignments ...\n"; + for (unsigned i = 0; i < corpus.size(); ++i) { + AlignedSentencePair& asp = corpus[i]; + asp.a.resize(asp.trg.size()); + for (unsigned j = 0; j < asp.trg.size(); ++j) { + const unsigned char a_j = prng->next() * (1 + asp.src.size()); + const WordID f_a_j = (a_j ? asp.src[a_j - 1] : kNULL); + TRule r; + InstantiateRule(f_a_j, asp.trg[j], &r); + asp.a[j].is_transliteration = false; + asp.a[j].src_index = a_j; + if (tmodel.IncrementRule(r, &*prng)) + up0.Increment(r); + } + } + cerr << " LLH = " << Likelihood() << endl; + } + + prob_t Likelihood() const { + prob_t p = tmodel.Likelihood(); + p *= up0.Likelihood(); + return p; + } + + void ResampleHyperparemeters() { + tmodel.ResampleHyperparameters(&*prng); + up0.ResampleHyperparameters(&*prng); + cerr << " (base d=" << up0.r.discount() << ",s=" << up0.r.strength() << ")\n"; + } + + void ResampleCorpus(); + + const vector >& letters; // spelling dictionary + vector& corpus; + //PhraseConditionalUninformativeBase up0; + //PhraseConditionalUninformativeUnigramBase up0; + //UnigramWordBase up0; + //HierarchicalUnigramBase up0; + HierarchicalWordBase up0; + //CompletelyUniformBase up0; + //FixedNgramBase up0; + //ConditionalTranslationModel tmodel; + //ConditionalTranslationModel tmodel; + //ConditionalTranslationModel tmodel; + //ConditionalTranslationModel tmodel; + MConditionalTranslationModel tmodel; + //ConditionalTranslationModel tmodel; + //ConditionalTranslationModel tmodel; +}; + +void BasicLexicalAlignment::ResampleCorpus() { + static const WordID kNULL = TD::Convert("NULL"); + for (unsigned i = 0; i < corpus.size(); ++i) { + AlignedSentencePair& asp = corpus[i]; + SampleSet ss; ss.resize(asp.src.size() + 1); + for (unsigned j = 0; j < asp.trg.size(); ++j) { + TRule r; + unsigned char& a_j = asp.a[j].src_index; + WordID f_a_j = (a_j ? asp.src[a_j - 1] : kNULL); + InstantiateRule(f_a_j, asp.trg[j], &r); + if (tmodel.DecrementRule(r, &*prng)) + up0.Decrement(r); + + for (unsigned prop_a_j = 0; prop_a_j <= asp.src.size(); ++prop_a_j) { + const WordID prop_f = (prop_a_j ? asp.src[prop_a_j - 1] : kNULL); + InstantiateRule(prop_f, asp.trg[j], &r); + ss[prop_a_j] = tmodel.RuleProbability(r); + } + a_j = prng->SelectSample(ss); + f_a_j = (a_j ? asp.src[a_j - 1] : kNULL); + InstantiateRule(f_a_j, asp.trg[j], &r); + if (tmodel.IncrementRule(r, &*prng)) + up0.Increment(r); + } + } + cerr << " LLH = " << Likelihood() << endl; +} + +void ExtractLetters(const set& v, vector >* l, set* letset = NULL) { + for (set::const_iterator it = v.begin(); it != v.end(); ++it) { + vector& letters = (*l)[*it]; + if (letters.size()) continue; // if e and f have the same word + + const string& w = TD::Convert(*it); + + size_t cur = 0; + while (cur < w.size()) { + const size_t len = UTF8Len(w[cur]); + letters.push_back(TD::Convert(w.substr(cur, len))); + if (letset) letset->insert(letters.back()); + cur += len; + } + } +} + +void Debug(const AlignedSentencePair& asp) { + cerr << TD::GetString(asp.src) << endl << TD::GetString(asp.trg) << endl; + Array2D a(asp.src.size(), asp.trg.size()); + for (unsigned j = 0; j < asp.trg.size(); ++j) + if (asp.a[j].src_index) a(asp.a[j].src_index - 1, j) = true; + cerr << a << endl; +} + +void AddSample(AlignedSentencePair* asp) { + for (unsigned j = 0; j < asp->trg.size(); ++j) + asp->posterior(asp->a[j].src_index, j)++; +} + +void WriteAlignments(const AlignedSentencePair& asp) { + bool first = true; + for (unsigned j = 0; j < asp.trg.size(); ++j) { + int src_index = -1; + int mc = -1; + for (unsigned i = 0; i <= asp.src.size(); ++i) { + if (asp.posterior(i, j) > mc) { + mc = asp.posterior(i, j); + src_index = i; + } + } + + if (src_index) { + if (first) first = false; else cout << ' '; + cout << (src_index - 1) << '-' << j; + } + } + cout << endl; +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + + if (conf.count("random_seed")) + prng.reset(new MT19937(conf["random_seed"].as())); + else + prng.reset(new MT19937); +// MT19937& rng = *prng; + + vector > corpuse, corpusf; + set vocabe, vocabf; + corpus::ReadParallelCorpus(conf["input"].as(), &corpusf, &corpuse, &vocabf, &vocabe); + cerr << "f-Corpus size: " << corpusf.size() << " sentences\n"; + cerr << "f-Vocabulary size: " << vocabf.size() << " types\n"; + cerr << "f-Corpus size: " << corpuse.size() << " sentences\n"; + cerr << "f-Vocabulary size: " << vocabe.size() << " types\n"; + assert(corpusf.size() == corpuse.size()); + + vector corpus(corpuse.size()); + for (unsigned i = 0; i < corpuse.size(); ++i) { + corpus[i].src.swap(corpusf[i]); + corpus[i].trg.swap(corpuse[i]); + corpus[i].posterior.resize(corpus[i].src.size() + 1, corpus[i].trg.size()); + } + corpusf.clear(); corpuse.clear(); + + vocabf.insert(TD::Convert("NULL")); + vector > letters(TD::NumWords()); + set letset; + ExtractLetters(vocabe, &letters, &letset); + ExtractLetters(vocabf, &letters, NULL); + letters[TD::Convert("NULL")].clear(); + + Transliterations tl; + + // TODO CONFIGURE THIS + int min_trans_src = 4; + + cerr << "Initializing transliteration DPs ...\n"; + for (int i = 0; i < corpus.size(); ++i) { + const vector& src = corpus[i].src; + const vector& trg = corpus[i].trg; + cerr << '.' << flush; + if (i % 80 == 79) cerr << endl; + for (int j = 0; j < src.size(); ++j) { + const vector& src_let = letters[src[j]]; + for (int k = 0; k < trg.size(); ++k) { + const vector& trg_let = letters[trg[k]]; + if (src_let.size() < min_trans_src) + tl.Forbid(src[j], trg[k]); + else + tl.Initialize(src[j], src_let, trg[k], trg_let); + } + } + } + cerr << endl; + tl.GraphSummary(); + + return 0; +} -- cgit v1.2.3 From e2998fc79c9dd549b1c1bad537fdf1052276f82c Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 8 Mar 2012 01:46:32 -0500 Subject: simple context feature for tagger --- decoder/Makefile.am | 1 + decoder/cdec_ff.cc | 2 + decoder/ff_context.cc | 99 +++++++++++++++++++++++ decoder/ff_context.h | 23 ++++++ gi/pf/align-tl.cc | 6 +- gi/pf/reachability.cc | 2 + gi/pf/reachability.h | 6 +- gi/pf/transliterations.cc | 198 ++++++++++++++-------------------------------- gi/pf/transliterations.h | 5 +- 9 files changed, 194 insertions(+), 148 deletions(-) create mode 100644 decoder/ff_context.cc create mode 100644 decoder/ff_context.h (limited to 'gi/pf/align-tl.cc') diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 30eaf04d..a00b18af 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -63,6 +63,7 @@ libcdec_a_SOURCES = \ ff.cc \ ff_rules.cc \ ff_wordset.cc \ + ff_context.cc \ ff_charset.cc \ ff_lm.cc \ ff_klm.cc \ diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 4ce5749e..b516c386 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -1,6 +1,7 @@ #include #include "ff.h" +#include "ff_context.h" #include "ff_spans.h" #include "ff_lm.h" #include "ff_klm.h" @@ -42,6 +43,7 @@ void register_feature_functions() { #endif ff_registry.Register("SpanFeatures", new FFFactory()); ff_registry.Register("NgramFeatures", new FFFactory()); + ff_registry.Register("RuleContextFeatures", new FFFactory()); ff_registry.Register("RuleIdentityFeatures", new FFFactory()); ff_registry.Register("SourceSyntaxFeatures", new FFFactory); ff_registry.Register("SourceSpanSizeFeatures", new FFFactory); diff --git a/decoder/ff_context.cc b/decoder/ff_context.cc new file mode 100644 index 00000000..19f9a413 --- /dev/null +++ b/decoder/ff_context.cc @@ -0,0 +1,99 @@ +#include "ff_context.h" + +#include +#include +#include + +#include "filelib.h" +#include "stringlib.h" +#include "sentence_metadata.h" +#include "lattice.h" +#include "fdict.h" +#include "verbose.h" + +using namespace std; + +namespace { + string Escape(const string& x) { + string y = x; + for (int i = 0; i < y.size(); ++i) { + if (y[i] == '=') y[i]='_'; + if (y[i] == ';') y[i]='_'; + } + return y; + } +} + +RuleContextFeatures::RuleContextFeatures(const std::string& param) { + kSOS = TD::Convert(""); + kEOS = TD::Convert(""); + + // TODO param lets you pass in a string from the cdec.ini file +} + +void RuleContextFeatures::PrepareForInput(const SentenceMetadata& smeta) { + const Lattice& sl = smeta.GetSourceLattice(); + current_input.resize(sl.size()); + for (unsigned i = 0; i < sl.size(); ++i) { + if (sl[i].size() != 1) { + cerr << "Context features not supported with lattice inputs!\nid=" << smeta.GetSentenceId() << endl; + abort(); + } + current_input[i] = sl[i][0].label; + } +} + +void RuleContextFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const { + const TRule& rule = *edge.rule_; + + if (rule.Arity() != 0 || // arity = 0, no nonterminals + rule.e_.size() != 1) return; // size = 1, predicted label is a single token + + + // you can see the current label "for free" + const WordID cur_label = rule.e_[0]; + // (if you want to see more labels, you have to be very careful, and muck + // about with contexts and ant_contexts) + + // but... you can look at as much of the source as you want! + const int from_src_index = edge.i_; // start of the span in the input being labeled + const int to_src_index = edge.j_; // end of the span in the input + // (note: in the case of tagging the size of the spans being labeled will + // always be 1, but in other formalisms, you can have bigger spans.) + + // this is the current token being labeled: + const WordID cur_input = current_input[from_src_index]; + + // let's get the previous token in the input (may be to the left of the start + // of the sentence!) + WordID prev_input = kSOS; + if (from_src_index > 0) { prev_input = current_input[from_src_index - 1]; } + // let's get the next token (may be to the left of the start of the sentence!) + WordID next_input = kEOS; + if (to_src_index < current_input.size()) { next_input = current_input[to_src_index]; } + + // now, build a feature string + ostringstream os; + // TD::Convert converts from the internal integer representation of a token + // to the actual token + os << "C1:" << TD::Convert(prev_input) << '_' + << TD::Convert(cur_input) << '|' << TD::Convert(cur_label); + // C1 is just to prevent a name clash + + // pick a value + double fval = 1.0; // can be any real value + + // add it to the feature vector FD::Convert converts the feature string to a + // feature int, Escape makes sure the feature string doesn't have any bad + // symbols that could confuse a parser somewhere + features->add_value(FD::Convert(Escape(os.str())), fval); + // that's it! + + // create more features if you like... +} + diff --git a/decoder/ff_context.h b/decoder/ff_context.h new file mode 100644 index 00000000..0d22b027 --- /dev/null +++ b/decoder/ff_context.h @@ -0,0 +1,23 @@ +#ifndef _FF_CONTEXT_H_ +#define _FF_CONTEXT_H_ + +#include +#include "ff.h" + +class RuleContextFeatures : public FeatureFunction { + public: + RuleContextFeatures(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const; + virtual void PrepareForInput(const SentenceMetadata& smeta); + private: + std::vector current_input; + WordID kSOS, kEOS; +}; + +#endif diff --git a/gi/pf/align-tl.cc b/gi/pf/align-tl.cc index 0e0454e5..6bb8c886 100644 --- a/gi/pf/align-tl.cc +++ b/gi/pf/align-tl.cc @@ -310,18 +310,16 @@ int main(int argc, char** argv) { // TODO CONFIGURE THIS int min_trans_src = 4; - cerr << "Initializing transliteration DPs ...\n"; + cerr << "Initializing transliteration graph structures ...\n"; for (int i = 0; i < corpus.size(); ++i) { const vector& src = corpus[i].src; const vector& trg = corpus[i].trg; - cerr << '.' << flush; - if (i % 80 == 79) cerr << endl; for (int j = 0; j < src.size(); ++j) { const vector& src_let = letters[src[j]]; for (int k = 0; k < trg.size(); ++k) { const vector& trg_let = letters[trg[k]]; if (src_let.size() < min_trans_src) - tl.Forbid(src[j], trg[k]); + tl.Forbid(src[j], src_let, trg[k], trg_let); else tl.Initialize(src[j], src_let, trg[k], trg_let); } diff --git a/gi/pf/reachability.cc b/gi/pf/reachability.cc index 73dd8d39..70fb76da 100644 --- a/gi/pf/reachability.cc +++ b/gi/pf/reachability.cc @@ -47,6 +47,7 @@ void Reachability::ComputeReachability(int srclen, int trglen, int src_max_phras r[prevs[k].prev_src_covered][prevs[k].prev_trg_covered] = true; int src_delta = i - prevs[k].prev_src_covered; edges[prevs[k].prev_src_covered][prevs[k].prev_trg_covered][src_delta][j - prevs[k].prev_trg_covered] = true; + valid_deltas[prevs[k].prev_src_covered][prevs[k].prev_trg_covered].push_back(make_pair(src_delta,j - prevs[k].prev_trg_covered)); short &msd = max_src_delta[prevs[k].prev_src_covered][prevs[k].prev_trg_covered]; if (src_delta > msd) msd = src_delta; } @@ -56,6 +57,7 @@ void Reachability::ComputeReachability(int srclen, int trglen, int src_max_phras assert(!edges[0][0][0][1]); assert(!edges[0][0][0][0]); assert(max_src_delta[0][0] > 0); + cerr << "Sentence with length (" << srclen << ',' << trglen << ") has " << valid_deltas[0][0].size() << " out edges in its root node\n"; //cerr << "First cell contains " << b[0][0].size() << " forward pointers\n"; //for (int i = 0; i < b[0][0].size(); ++i) { // cerr << " -> (" << b[0][0][i].next_src_covered << "," << b[0][0][i].next_trg_covered << ")\n"; diff --git a/gi/pf/reachability.h b/gi/pf/reachability.h index 98450ec1..fb2f4965 100644 --- a/gi/pf/reachability.h +++ b/gi/pf/reachability.h @@ -12,12 +12,14 @@ // currently forbids 0 -> n and n -> 0 alignments struct Reachability { - boost::multi_array edges; // edges[src_covered][trg_covered][x][trg_delta] is this edge worth exploring? + boost::multi_array edges; // edges[src_covered][trg_covered][src_delta][trg_delta] is this edge worth exploring? boost::multi_array max_src_delta; // msd[src_covered][trg_covered] -- the largest src delta that's valid + boost::multi_array >, 2> valid_deltas; // valid_deltas[src_covered][trg_covered] list of valid transitions leaving a particular node Reachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len) : edges(boost::extents[srclen][trglen][src_max_phrase_len+1][trg_max_phrase_len+1]), - max_src_delta(boost::extents[srclen][trglen]) { + max_src_delta(boost::extents[srclen][trglen]), + valid_deltas(boost::extents[srclen][trglen]) { ComputeReachability(srclen, trglen, src_max_phrase_len, trg_max_phrase_len); } diff --git a/gi/pf/transliterations.cc b/gi/pf/transliterations.cc index 6e0c2e93..e29334fd 100644 --- a/gi/pf/transliterations.cc +++ b/gi/pf/transliterations.cc @@ -2,173 +2,92 @@ #include #include -#include -#include "grammar.h" -#include "bottom_up_parser.h" -#include "hg.h" -#include "hg_intersect.h" +#include "boost/shared_ptr.hpp" + #include "filelib.h" #include "ccrp.h" #include "m.h" -#include "lattice.h" -#include "verbose.h" +#include "reachability.h" using namespace std; using namespace std::tr1; -static WordID kX; -static int kMAX_SRC_SIZE = 0; -static vector > cur_trg_chunks; - -vector tlttofreelist; - -static void InitTargetChunks(int max_size, const vector& trg) { - cur_trg_chunks.clear(); - vector tmp; - unordered_set, boost::hash > > u; - for (int len = 1; len <= max_size; ++len) { - int end = trg.size() + 1; - end -= len; - for (int i = 0; i < end; ++i) { - tmp.clear(); - for (int j = 0; j < len; ++j) - tmp.push_back(trg[i + j]); - if (u.insert(tmp).second) cur_trg_chunks.push_back(tmp); - } - } -} - -struct TransliterationGrammarIter : public GrammarIter, public RuleBin { - TransliterationGrammarIter() { tlttofreelist.push_back(this); } - TransliterationGrammarIter(const TRulePtr& inr, int symbol) { - if (inr) { - r.reset(new TRule(*inr)); - } else { - r.reset(new TRule); - } - TRule& rr = *r; - rr.lhs_ = kX; - rr.f_.push_back(symbol); - tlttofreelist.push_back(this); - } - virtual int GetNumRules() const { - if (!r) return 0; - return cur_trg_chunks.size(); - } - virtual TRulePtr GetIthRule(int i) const { - TRulePtr nr(new TRule(*r)); - nr->e_ = cur_trg_chunks[i]; - //cerr << nr->AsString() << endl; - return nr; - } - virtual int Arity() const { - return 0; - } - virtual const RuleBin* GetRules() const { - if (!r) return NULL; else return this; - } - virtual const GrammarIter* Extend(int symbol) const { - if (symbol <= 0) return NULL; - if (!r || !kMAX_SRC_SIZE || r->f_.size() < kMAX_SRC_SIZE) - return new TransliterationGrammarIter(r, symbol); - else - return NULL; - } - TRulePtr r; -}; - -struct TransliterationGrammar : public Grammar { - virtual const GrammarIter* GetRoot() const { - return new TransliterationGrammarIter; - } - virtual bool HasRuleForSpan(int, int, int distance) const { - return (distance < kMAX_SRC_SIZE); - } -}; - -struct TInfo { - TInfo() : initialized(false) {} +struct GraphStructure { + GraphStructure() : initialized(false) {} + boost::shared_ptr r; bool initialized; - Hypergraph lattice; // may be empty if transliteration is not possible - prob_t est_prob; // will be zero if not possible }; struct TransliterationsImpl { TransliterationsImpl() { - kX = TD::Convert("X")*-1; - kMAX_SRC_SIZE = 4; - grammars.push_back(GrammarPtr(new TransliterationGrammar)); - grammars.push_back(GrammarPtr(new GlueGrammar("S", "X"))); - SetSilent(true); } void Initialize(WordID src, const vector& src_lets, WordID trg, const vector& trg_lets) { - if (src >= graphs.size()) graphs.resize(src + 1); - if (graphs[src][trg].initialized) return; - int kMAX_TRG_SIZE = 4; - InitTargetChunks(kMAX_TRG_SIZE, trg_lets); - ExhaustiveBottomUpParser parser("S", grammars); - Lattice lat(src_lets.size()), tlat(trg_lets.size()); - for (unsigned i = 0; i < src_lets.size(); ++i) - lat[i].push_back(LatticeArc(src_lets[i], 0.0, 1)); - for (unsigned i = 0; i < trg_lets.size(); ++i) - tlat[i].push_back(LatticeArc(trg_lets[i], 0.0, 1)); - //cerr << "Creating lattice for: " << TD::Convert(src) << " --> " << TD::Convert(trg) << endl; - //cerr << "'" << TD::GetString(src_lets) << "' --> " << TD::GetString(trg_lets) << endl; - if (!parser.Parse(lat, &graphs[src][trg].lattice)) { - //cerr << "Failed to parse " << TD::GetString(src_lets) << endl; - abort(); - } - if (HG::Intersect(tlat, &graphs[src][trg].lattice)) { - graphs[src][trg].est_prob = prob_t(1e-4); + const size_t src_len = src_lets.size(); + const size_t trg_len = trg_lets.size(); + if (src_len >= graphs.size()) graphs.resize(src_len + 1); + if (trg_len >= graphs[src_len].size()) graphs[src_len].resize(trg_len + 1); + if (graphs[src_len][trg_len].initialized) return; + graphs[src_len][trg_len].r.reset(new Reachability(src_len, trg_len, 4, 4)); + +#if 0 + if (HG::Intersect(tlat, &hg)) { + // TODO } else { - graphs[src][trg].lattice.clear(); - //cerr << "Failed to intersect " << TD::GetString(src_lets) << " ||| " << TD::GetString(trg_lets) << endl; - graphs[src][trg].est_prob = prob_t::Zero(); + cerr << "No transliteration lattice possible for src_len=" << src_len << " trg_len=" << trg_len << endl; + hg.clear(); } - for (unsigned i = 0; i < tlttofreelist.size(); ++i) - delete tlttofreelist[i]; - tlttofreelist.clear(); //cerr << "Number of paths: " << graphs[src][trg].lattice.NumberOfPaths() << endl; - graphs[src][trg].initialized = true; +#endif + graphs[src_len][trg_len].initialized = true; } - const prob_t& EstimateProbability(WordID src, WordID trg) const { - assert(src < graphs.size()); - const unordered_map& um = graphs[src]; - const unordered_map::const_iterator it = um.find(trg); - assert(it != um.end()); - assert(it->second.initialized); - return it->second.est_prob; + void Forbid(WordID src, const vector& src_lets, WordID trg, const vector& trg_lets) { + const size_t src_len = src_lets.size(); + const size_t trg_len = trg_lets.size(); + if (src_len >= graphs.size()) graphs.resize(src_len + 1); + if (trg_len >= graphs[src_len].size()) graphs[src_len].resize(trg_len + 1); + graphs[src_len][trg_len].r.reset(); + graphs[src_len][trg_len].initialized = true; } - void Forbid(WordID src, WordID trg) { - if (src >= graphs.size()) graphs.resize(src + 1); - graphs[src][trg].est_prob = prob_t::Zero(); - graphs[src][trg].initialized = true; + prob_t EstimateProbability(WordID s, const vector& src, WordID t, const vector& trg) const { + assert(src.size() < graphs.size()); + const vector& tv = graphs[src.size()]; + assert(trg.size() < tv.size()); + const GraphStructure& gs = tv[trg.size()]; + // TODO: do prob + return prob_t::Zero(); } void GraphSummary() const { - double tlp = 0; - int tt = 0; + double to = 0; + double tn = 0; + double tt = 0; for (int i = 0; i < graphs.size(); ++i) { - const unordered_map& um = graphs[i]; - unordered_map::const_iterator it; - for (it = um.begin(); it != um.end(); ++it) { - if (it->second.lattice.empty()) continue; - //cerr << TD::Convert(i) << " --> " << TD::Convert(it->first) << ": " << it->second.lattice.NumberOfPaths() << endl; - tlp += log(it->second.lattice.NumberOfPaths()); + const vector& vt = graphs[i]; + for (int j = 0; j < vt.size(); ++j) { + const GraphStructure& gs = vt[j]; + if (!gs.r) continue; tt++; + for (int k = 0; k < i; ++k) { + for (int l = 0; l < j; ++l) { + size_t c = gs.r->valid_deltas[k][l].size(); + if (c) { + tn += 1; + to += c; + } + } + } } } - tlp /= tt; - cerr << "E[log paths] = " << tlp << endl; - cerr << "exp(E[log paths]) = " << exp(tlp) << endl; + cerr << " Average nodes = " << (tn / tt) << endl; + cerr << "Average out-degree = " << (to / tn) << endl; + cerr << " Unique structures = " << tt << endl; } - vector > graphs; - vector grammars; + vector > graphs; // graphs[src_len][trg_len] }; Transliterations::Transliterations() : pimpl_(new TransliterationsImpl) {} @@ -178,16 +97,15 @@ void Transliterations::Initialize(WordID src, const vector& src_lets, Wo pimpl_->Initialize(src, src_lets, trg, trg_lets); } -prob_t Transliterations::EstimateProbability(WordID src, WordID trg) const { - return pimpl_->EstimateProbability(src,trg); +prob_t Transliterations::EstimateProbability(WordID s, const vector& src, WordID t, const vector& trg) const { + return pimpl_->EstimateProbability(s, src,t, trg); } -void Transliterations::Forbid(WordID src, WordID trg) { - pimpl_->Forbid(src, trg); +void Transliterations::Forbid(WordID src, const vector& src_lets, WordID trg, const vector& trg_lets) { + pimpl_->Forbid(src, src_lets, trg, trg_lets); } void Transliterations::GraphSummary() const { pimpl_->GraphSummary(); } - diff --git a/gi/pf/transliterations.h b/gi/pf/transliterations.h index a548aacf..76eb2a05 100644 --- a/gi/pf/transliterations.h +++ b/gi/pf/transliterations.h @@ -10,9 +10,10 @@ struct Transliterations { explicit Transliterations(); ~Transliterations(); void Initialize(WordID src, const std::vector& src_lets, WordID trg, const std::vector& trg_lets); - void Forbid(WordID src, WordID trg); + void Forbid(WordID src, const std::vector& src_lets, WordID trg, const std::vector& trg_lets); void GraphSummary() const; - prob_t EstimateProbability(WordID src, WordID trg) const; + prob_t EstimateProbability(WordID s, const std::vector& src, WordID t, const std::vector& trg) const; + private: TransliterationsImpl* pimpl_; }; -- cgit v1.2.3 From 9a8256604686a9283e7afce04e6feaab4922dd45 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 8 Mar 2012 13:32:41 -0500 Subject: tl stuff --- gi/pf/Makefile.am | 8 +++-- gi/pf/align-tl.cc | 8 +++-- gi/pf/reachability.cc | 17 +++++++--- gi/pf/reachability.h | 4 +++ gi/pf/transliterations.cc | 82 ++++++++++++++++++++++++++++++++++------------- gi/pf/transliterations.h | 3 +- 6 files changed, 88 insertions(+), 34 deletions(-) (limited to 'gi/pf/align-tl.cc') diff --git a/gi/pf/Makefile.am b/gi/pf/Makefile.am index 5e89f02a..9888a70b 100644 --- a/gi/pf/Makefile.am +++ b/gi/pf/Makefile.am @@ -2,15 +2,17 @@ bin_PROGRAMS = cbgi brat dpnaive pfbrat pfdist itg pfnaive condnaive align-lexon noinst_LIBRARIES = libpf.a -libpf_a_SOURCES = base_distributions.cc reachability.cc cfg_wfst_composer.cc corpus.cc unigrams.cc ngram_base.cc +libpf_a_SOURCES = base_distributions.cc reachability.cc cfg_wfst_composer.cc corpus.cc unigrams.cc ngram_base.cc transliterations.cc -nuisance_test_SOURCES = nuisance_test.cc transliterations.cc +nuisance_test_SOURCES = nuisance_test.cc +nuisance_test_LDADD = libpf.a align_lexonly_SOURCES = align-lexonly.cc align_lexonly_pyp_SOURCES = align-lexonly-pyp.cc -align_tl_SOURCES = align-tl.cc transliterations.cc +align_tl_SOURCES = align-tl.cc +align_tl_LDADD = libpf.a itg_SOURCES = itg.cc diff --git a/gi/pf/align-tl.cc b/gi/pf/align-tl.cc index 6bb8c886..fe8950b5 100644 --- a/gi/pf/align-tl.cc +++ b/gi/pf/align-tl.cc @@ -305,7 +305,10 @@ int main(int argc, char** argv) { ExtractLetters(vocabf, &letters, NULL); letters[TD::Convert("NULL")].clear(); - Transliterations tl; + // TODO configure this + int max_src_chunk = 4; + int max_trg_chunk = 4; + Transliterations tl(max_src_chunk, max_trg_chunk); // TODO CONFIGURE THIS int min_trans_src = 4; @@ -318,10 +321,9 @@ int main(int argc, char** argv) { const vector& src_let = letters[src[j]]; for (int k = 0; k < trg.size(); ++k) { const vector& trg_let = letters[trg[k]]; + tl.Initialize(src[j], src_let, trg[k], trg_let); if (src_let.size() < min_trans_src) tl.Forbid(src[j], src_let, trg[k], trg_let); - else - tl.Initialize(src[j], src_let, trg[k], trg_let); } } } diff --git a/gi/pf/reachability.cc b/gi/pf/reachability.cc index 70fb76da..59bc6ace 100644 --- a/gi/pf/reachability.cc +++ b/gi/pf/reachability.cc @@ -39,6 +39,7 @@ void Reachability::ComputeReachability(int srclen, int trglen, int src_max_phras typedef boost::multi_array rarray_type; rarray_type r(boost::extents[srclen + 1][trglen + 1]); r[srclen][trglen] = true; + nodes = 0; for (int i = srclen; i >= 0; --i) { for (int j = trglen; j >= 0; --j) { vector& prevs = a[i][j]; @@ -57,10 +58,16 @@ void Reachability::ComputeReachability(int srclen, int trglen, int src_max_phras assert(!edges[0][0][0][1]); assert(!edges[0][0][0][0]); assert(max_src_delta[0][0] > 0); - cerr << "Sentence with length (" << srclen << ',' << trglen << ") has " << valid_deltas[0][0].size() << " out edges in its root node\n"; - //cerr << "First cell contains " << b[0][0].size() << " forward pointers\n"; - //for (int i = 0; i < b[0][0].size(); ++i) { - // cerr << " -> (" << b[0][0][i].next_src_covered << "," << b[0][0][i].next_trg_covered << ")\n"; - //} + nodes = 0; + for (int i = 0; i < srclen; ++i) { + for (int j = 0; j < trglen; ++j) { + if (valid_deltas[i][j].size() > 0) { + node_addresses[i][j] = nodes++; + } else { + node_addresses[i][j] = -1; + } + } + } + cerr << "Sequence pair with lengths (" << srclen << ',' << trglen << ") has " << valid_deltas[0][0].size() << " out edges in its root node, " << nodes << " nodes in total, and outside estimate matrix will require " << sizeof(float)*nodes << " bytes\n"; } diff --git a/gi/pf/reachability.h b/gi/pf/reachability.h index fb2f4965..1e22c76a 100644 --- a/gi/pf/reachability.h +++ b/gi/pf/reachability.h @@ -12,13 +12,17 @@ // currently forbids 0 -> n and n -> 0 alignments struct Reachability { + unsigned nodes; boost::multi_array edges; // edges[src_covered][trg_covered][src_delta][trg_delta] is this edge worth exploring? boost::multi_array max_src_delta; // msd[src_covered][trg_covered] -- the largest src delta that's valid + boost::multi_array node_addresses; // na[src_covered][trg_covered] -- the index of the node in a one-dimensional array (of size "nodes") boost::multi_array >, 2> valid_deltas; // valid_deltas[src_covered][trg_covered] list of valid transitions leaving a particular node Reachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len) : + nodes(), edges(boost::extents[srclen][trglen][src_max_phrase_len+1][trg_max_phrase_len+1]), max_src_delta(boost::extents[srclen][trglen]), + node_addresses(boost::extents[srclen][trglen]), valid_deltas(boost::extents[srclen][trglen]) { ComputeReachability(srclen, trglen, src_max_phrase_len, trg_max_phrase_len); } diff --git a/gi/pf/transliterations.cc b/gi/pf/transliterations.cc index e29334fd..61e95b82 100644 --- a/gi/pf/transliterations.cc +++ b/gi/pf/transliterations.cc @@ -14,42 +14,75 @@ using namespace std; using namespace std::tr1; struct GraphStructure { - GraphStructure() : initialized(false) {} - boost::shared_ptr r; - bool initialized; + GraphStructure() : r() {} + // leak memory - these are basically static + const Reachability* r; + bool IsReachable() const { return r->nodes > 0; } +}; + +struct BackwardEstimates { + BackwardEstimates() : gs(), backward() {} + explicit BackwardEstimates(const GraphStructure& g) : + gs(&g), backward() { + if (g.r->nodes > 0) + backward = new float[g.r->nodes]; + } + // leak memory, these are static + + // returns an estimate of the marginal probability + double MarginalEstimate() const { + if (!backward) return 0; + return backward[0]; + } + + // returns an backward estimate + double operator()(int src_covered, int trg_covered) const { + if (!backward) return 0; + int ind = gs->r->node_addresses[src_covered][trg_covered]; + if (ind < 0) return 0; + return backward[ind]; + } + private: + const GraphStructure* gs; + float* backward; }; struct TransliterationsImpl { - TransliterationsImpl() { + TransliterationsImpl(int max_src, int max_trg) : + kMAX_SRC_CHUNK(max_src), + kMAX_TRG_CHUNK(max_trg), + tot_pairs() { } void Initialize(WordID src, const vector& src_lets, WordID trg, const vector& trg_lets) { const size_t src_len = src_lets.size(); const size_t trg_len = trg_lets.size(); + + // init graph structure if (src_len >= graphs.size()) graphs.resize(src_len + 1); if (trg_len >= graphs[src_len].size()) graphs[src_len].resize(trg_len + 1); - if (graphs[src_len][trg_len].initialized) return; - graphs[src_len][trg_len].r.reset(new Reachability(src_len, trg_len, 4, 4)); - -#if 0 - if (HG::Intersect(tlat, &hg)) { - // TODO - } else { - cerr << "No transliteration lattice possible for src_len=" << src_len << " trg_len=" << trg_len << endl; - hg.clear(); - } - //cerr << "Number of paths: " << graphs[src][trg].lattice.NumberOfPaths() << endl; -#endif - graphs[src_len][trg_len].initialized = true; + GraphStructure& gs = graphs[src_len][trg_len]; + if (!gs.r) + gs.r = new Reachability(src_len, trg_len, kMAX_SRC_CHUNK, kMAX_TRG_CHUNK); + const Reachability& r = *gs.r; + + // init backward estimates + if (src >= bes.size()) bes.resize(src + 1); + unordered_map::iterator it = bes[src].find(trg); + if (it != bes[src].end()) return; // already initialized + + it = bes[src].insert(make_pair(trg, BackwardEstimates(gs))).first; + BackwardEstimates& b = it->second; + if (!gs.r->nodes) return; // not derivable subject to length constraints + + // TODO + tot_pairs++; } void Forbid(WordID src, const vector& src_lets, WordID trg, const vector& trg_lets) { const size_t src_len = src_lets.size(); const size_t trg_len = trg_lets.size(); - if (src_len >= graphs.size()) graphs.resize(src_len + 1); - if (trg_len >= graphs[src_len].size()) graphs[src_len].resize(trg_len + 1); - graphs[src_len][trg_len].r.reset(); - graphs[src_len][trg_len].initialized = true; + // TODO } prob_t EstimateProbability(WordID s, const vector& src, WordID t, const vector& trg) const { @@ -85,12 +118,17 @@ struct TransliterationsImpl { cerr << " Average nodes = " << (tn / tt) << endl; cerr << "Average out-degree = " << (to / tn) << endl; cerr << " Unique structures = " << tt << endl; + cerr << " Unique pairs = " << tot_pairs << endl; } + const int kMAX_SRC_CHUNK; + const int kMAX_TRG_CHUNK; + unsigned tot_pairs; vector > graphs; // graphs[src_len][trg_len] + vector > bes; // bes[src][trg] }; -Transliterations::Transliterations() : pimpl_(new TransliterationsImpl) {} +Transliterations::Transliterations(int max_src, int max_trg) : pimpl_(new TransliterationsImpl(max_src, max_trg)) {} Transliterations::~Transliterations() { delete pimpl_; } void Transliterations::Initialize(WordID src, const vector& src_lets, WordID trg, const vector& trg_lets) { diff --git a/gi/pf/transliterations.h b/gi/pf/transliterations.h index 76eb2a05..e025547e 100644 --- a/gi/pf/transliterations.h +++ b/gi/pf/transliterations.h @@ -7,7 +7,8 @@ struct TransliterationsImpl; struct Transliterations { - explicit Transliterations(); + // max_src and max_trg indicate how big the transliteration phrases can be + explicit Transliterations(int max_src, int max_trg); ~Transliterations(); void Initialize(WordID src, const std::vector& src_lets, WordID trg, const std::vector& trg_lets); void Forbid(WordID src, const std::vector& src_lets, WordID trg, const std::vector& trg_lets); -- cgit v1.2.3 From 78bf1457f606dd3880c2bc912201c4945d3f85b4 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 8 Mar 2012 14:29:42 -0500 Subject: moar --- gi/pf/align-tl.cc | 15 +++++++++------ gi/pf/reachability.cc | 9 +++++---- gi/pf/reachability.h | 8 +++++--- gi/pf/transliterations.cc | 14 ++++++++++---- gi/pf/transliterations.h | 3 ++- 5 files changed, 31 insertions(+), 18 deletions(-) (limited to 'gi/pf/align-tl.cc') diff --git a/gi/pf/align-tl.cc b/gi/pf/align-tl.cc index fe8950b5..fc9b7ca5 100644 --- a/gi/pf/align-tl.cc +++ b/gi/pf/align-tl.cc @@ -30,6 +30,10 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { opts.add_options() ("samples,s",po::value()->default_value(1000),"Number of samples") ("input,i",po::value(),"Read parallel data from") + ("max_src_chunk", po::value()->default_value(4), "Maximum size of translitered chunk in source") + ("max_trg_chunk", po::value()->default_value(4), "Maximum size of translitered chunk in target") + ("min_transliterated_src_length", po::value()->default_value(3), "Minimum length of source words considered for transliteration") + ("filter_ratio", po::value()->default_value(0.66), "Filter ratio: basically, if the lengths differ by less than this ratio, mark the pair as non-transliteratable") ("random_seed,S",po::value(), "Random seed"); po::options_description clo("Command line options"); clo.add_options() @@ -306,12 +310,11 @@ int main(int argc, char** argv) { letters[TD::Convert("NULL")].clear(); // TODO configure this - int max_src_chunk = 4; - int max_trg_chunk = 4; - Transliterations tl(max_src_chunk, max_trg_chunk); - - // TODO CONFIGURE THIS - int min_trans_src = 4; + const int max_src_chunk = conf["max_src_chunk"].as(); + const int max_trg_chunk = conf["max_trg_chunk"].as(); + const double filter_rat = conf["filter_ratio"].as(); + const int min_trans_src = conf["min_transliterated_src_length"].as(); + Transliterations tl(max_src_chunk, max_trg_chunk, filter_rat); cerr << "Initializing transliteration graph structures ...\n"; for (int i = 0; i < corpus.size(); ++i) { diff --git a/gi/pf/reachability.cc b/gi/pf/reachability.cc index 59bc6ace..c10000f2 100644 --- a/gi/pf/reachability.cc +++ b/gi/pf/reachability.cc @@ -12,7 +12,7 @@ struct SState { int prev_trg_covered; }; -void Reachability::ComputeReachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len) { +void Reachability::ComputeReachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len, double filter_ratio) { typedef boost::multi_array, 2> array_type; array_type a(boost::extents[srclen + 1][trglen + 1]); a[0][0].push_back(SState()); @@ -30,9 +30,10 @@ void Reachability::ComputeReachability(int srclen, int trglen, int src_max_phras } } a[0][0].clear(); - //cerr << "Final cell contains " << a[srclen][trglen].size() << " back pointers\n"; - if (a[srclen][trglen].size() == 0) { - cerr << "Sentence with length (" << srclen << ',' << trglen << ") violates reachability constraints\n"; + //cerr << srclen << "," << trglen << ": Final cell contains " << a[srclen][trglen].size() << " back pointers\n"; + size_t min_allowed = (src_max_phrase_len + 1) * (trg_max_phrase_len + 1) * (filter_ratio * filter_ratio); + if (a[srclen][trglen].size() < min_allowed) { + cerr << "Sequence pair with lengths (" << srclen << ',' << trglen << ") violates reachability constraint of min indegree " << min_allowed << " with " << a[srclen][trglen].size() << " in edges\n"; return; } diff --git a/gi/pf/reachability.h b/gi/pf/reachability.h index 1e22c76a..03967d44 100644 --- a/gi/pf/reachability.h +++ b/gi/pf/reachability.h @@ -18,17 +18,19 @@ struct Reachability { boost::multi_array node_addresses; // na[src_covered][trg_covered] -- the index of the node in a one-dimensional array (of size "nodes") boost::multi_array >, 2> valid_deltas; // valid_deltas[src_covered][trg_covered] list of valid transitions leaving a particular node - Reachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len) : + // filter_ratio says if the number of outgoing edges from the first cell is less than + // src_max * trg_max * filter_rat^2 then mark as non reachable + Reachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len, double filter_ratio = 0.0) : nodes(), edges(boost::extents[srclen][trglen][src_max_phrase_len+1][trg_max_phrase_len+1]), max_src_delta(boost::extents[srclen][trglen]), node_addresses(boost::extents[srclen][trglen]), valid_deltas(boost::extents[srclen][trglen]) { - ComputeReachability(srclen, trglen, src_max_phrase_len, trg_max_phrase_len); + ComputeReachability(srclen, trglen, src_max_phrase_len, trg_max_phrase_len, filter_ratio); } private: - void ComputeReachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len); + void ComputeReachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len, double filter_ratio); }; #endif diff --git a/gi/pf/transliterations.cc b/gi/pf/transliterations.cc index 61e95b82..8ea4ebd2 100644 --- a/gi/pf/transliterations.cc +++ b/gi/pf/transliterations.cc @@ -48,10 +48,11 @@ struct BackwardEstimates { }; struct TransliterationsImpl { - TransliterationsImpl(int max_src, int max_trg) : + TransliterationsImpl(int max_src, int max_trg, double fr) : kMAX_SRC_CHUNK(max_src), kMAX_TRG_CHUNK(max_trg), - tot_pairs() { + kFILTER_RATIO(fr), + tot_pairs(), tot_mem() { } void Initialize(WordID src, const vector& src_lets, WordID trg, const vector& trg_lets) { @@ -63,7 +64,7 @@ struct TransliterationsImpl { if (trg_len >= graphs[src_len].size()) graphs[src_len].resize(trg_len + 1); GraphStructure& gs = graphs[src_len][trg_len]; if (!gs.r) - gs.r = new Reachability(src_len, trg_len, kMAX_SRC_CHUNK, kMAX_TRG_CHUNK); + gs.r = new Reachability(src_len, trg_len, kMAX_SRC_CHUNK, kMAX_TRG_CHUNK, kFILTER_RATIO); const Reachability& r = *gs.r; // init backward estimates @@ -77,6 +78,7 @@ struct TransliterationsImpl { // TODO tot_pairs++; + tot_mem += sizeof(float) * gs.r->nodes; } void Forbid(WordID src, const vector& src_lets, WordID trg, const vector& trg_lets) { @@ -119,16 +121,20 @@ struct TransliterationsImpl { cerr << "Average out-degree = " << (to / tn) << endl; cerr << " Unique structures = " << tt << endl; cerr << " Unique pairs = " << tot_pairs << endl; + cerr << " BEs size = " << (tot_mem / (1024.0*1024.0)) << " MB" << endl; } const int kMAX_SRC_CHUNK; const int kMAX_TRG_CHUNK; + const double kFILTER_RATIO; unsigned tot_pairs; + size_t tot_mem; vector > graphs; // graphs[src_len][trg_len] vector > bes; // bes[src][trg] }; -Transliterations::Transliterations(int max_src, int max_trg) : pimpl_(new TransliterationsImpl(max_src, max_trg)) {} +Transliterations::Transliterations(int max_src, int max_trg, double fr) : + pimpl_(new TransliterationsImpl(max_src, max_trg, fr)) {} Transliterations::~Transliterations() { delete pimpl_; } void Transliterations::Initialize(WordID src, const vector& src_lets, WordID trg, const vector& trg_lets) { diff --git a/gi/pf/transliterations.h b/gi/pf/transliterations.h index e025547e..ea9f9d3f 100644 --- a/gi/pf/transliterations.h +++ b/gi/pf/transliterations.h @@ -8,7 +8,8 @@ struct TransliterationsImpl; struct Transliterations { // max_src and max_trg indicate how big the transliteration phrases can be - explicit Transliterations(int max_src, int max_trg); + // see reachability.h for information about filter_ratio + explicit Transliterations(int max_src, int max_trg, double filter_ratio); ~Transliterations(); void Initialize(WordID src, const std::vector& src_lets, WordID trg, const std::vector& trg_lets); void Forbid(WordID src, const std::vector& src_lets, WordID trg, const std::vector& trg_lets); -- cgit v1.2.3 From 113317266853abff2e1c0c3e889017d0eee55c93 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 9 Mar 2012 22:23:50 -0500 Subject: moar --- gi/pf/Makefile.am | 3 +- gi/pf/align-lexonly-pyp.cc | 207 ++++++++++------------------------------- gi/pf/align-tl.cc | 18 ++-- gi/pf/backward.cc | 89 ++++++++++++++++++ gi/pf/backward.h | 33 +++++++ gi/pf/base_distributions.h | 8 +- gi/pf/guess-translits.pl | 2 +- gi/pf/nuisance_test.cc | 6 +- gi/pf/pyp_lm.cc | 2 +- gi/pf/pyp_tm.cc | 113 +++++++++++++++++++++++ gi/pf/pyp_tm.h | 34 +++++++ gi/pf/pyp_word_model.cc | 20 ++++ gi/pf/pyp_word_model.h | 58 ++++++++++++ gi/pf/reachability.cc | 8 +- gi/pf/reachability.h | 8 +- gi/pf/transliterations.cc | 223 ++++++++++++++++++++++++++++++++++++++++----- gi/pf/transliterations.h | 3 +- utils/ccrp_nt.h | 17 ++-- 18 files changed, 628 insertions(+), 224 deletions(-) create mode 100644 gi/pf/backward.cc create mode 100644 gi/pf/backward.h create mode 100644 gi/pf/pyp_tm.cc create mode 100644 gi/pf/pyp_tm.h create mode 100644 gi/pf/pyp_word_model.cc create mode 100644 gi/pf/pyp_word_model.h (limited to 'gi/pf/align-tl.cc') diff --git a/gi/pf/Makefile.am b/gi/pf/Makefile.am index 94364c3d..4ce72ba1 100644 --- a/gi/pf/Makefile.am +++ b/gi/pf/Makefile.am @@ -2,7 +2,7 @@ bin_PROGRAMS = cbgi brat dpnaive pfbrat pfdist itg pfnaive condnaive align-lexon noinst_LIBRARIES = libpf.a -libpf_a_SOURCES = base_distributions.cc reachability.cc cfg_wfst_composer.cc corpus.cc unigrams.cc ngram_base.cc transliterations.cc +libpf_a_SOURCES = base_distributions.cc reachability.cc cfg_wfst_composer.cc corpus.cc unigrams.cc ngram_base.cc transliterations.cc backward.cc pyp_word_model.cc pyp_tm.cc nuisance_test_SOURCES = nuisance_test.cc nuisance_test_LDADD = libpf.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a $(top_srcdir)/klm/lm/libklm.a $(top_srcdir)/klm/util/libklm_util.a -lz @@ -10,6 +10,7 @@ nuisance_test_LDADD = libpf.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mtev align_lexonly_SOURCES = align-lexonly.cc align_lexonly_pyp_SOURCES = align-lexonly-pyp.cc +align_lexonly_pyp_LDADD = libpf.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a $(top_srcdir)/klm/lm/libklm.a $(top_srcdir)/klm/util/libklm_util.a -lz align_tl_SOURCES = align-tl.cc align_tl_LDADD = libpf.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a $(top_srcdir)/klm/lm/libklm.a $(top_srcdir)/klm/util/libklm_util.a -lz diff --git a/gi/pf/align-lexonly-pyp.cc b/gi/pf/align-lexonly-pyp.cc index 13a3a487..d68a4b8f 100644 --- a/gi/pf/align-lexonly-pyp.cc +++ b/gi/pf/align-lexonly-pyp.cc @@ -1,27 +1,18 @@ #include -#include #include -#include #include #include -#include "array2d.h" -#include "base_distributions.h" -#include "monotonic_pseg.h" -#include "conditional_pseg.h" -#include "trule.h" #include "tdict.h" #include "stringlib.h" #include "filelib.h" -#include "dict.h" +#include "array2d.h" #include "sampler.h" -#include "mfcr.h" #include "corpus.h" -#include "ngram_base.h" +#include "pyp_tm.h" using namespace std; -using namespace tr1; namespace po = boost::program_options; void InitCommandLine(int argc, char** argv, po::variables_map* conf) { @@ -51,7 +42,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { } } -shared_ptr prng; +MT19937* prng; struct LexicalAlignment { unsigned char src_index; @@ -66,159 +57,59 @@ struct AlignedSentencePair { Array2D posterior; }; -struct HierarchicalWordBase { - explicit HierarchicalWordBase(const unsigned vocab_e_size) : - base(prob_t::One()), r(1,1,1,1,0.66,50.0), u0(-log(vocab_e_size)), l(1,prob_t::One()), v(1, prob_t::Zero()) {} - - void ResampleHyperparameters(MT19937* rng) { - r.resample_hyperparameters(rng); - } - - inline double logp0(const vector& s) const { - return Md::log_poisson(s.size(), 7.5) + s.size() * u0; - } - - // return p0 of rule.e_ - prob_t operator()(const TRule& rule) const { - v[0].logeq(logp0(rule.e_)); - return r.prob(rule.e_, v.begin(), l.begin()); - } - - void Increment(const TRule& rule) { - v[0].logeq(logp0(rule.e_)); - if (r.increment(rule.e_, v.begin(), l.begin(), &*prng).count) { - base *= v[0] * l[0]; - } - } - - void Decrement(const TRule& rule) { - if (r.decrement(rule.e_, &*prng).count) { - base /= prob_t(exp(logp0(rule.e_))); - } - } - - prob_t Likelihood() const { - prob_t p; p.logeq(r.log_crp_prob()); - p *= base; - return p; +struct Aligner { + Aligner(const vector >& lets, int num_letters, vector* c) : + corpus(*c), + model(lets, num_letters), + kNULL(TD::Convert("NULL")) { + assert(lets[kNULL].size() == 0); } - void Summary() const { - cerr << "NUMBER OF CUSTOMERS: " << r.num_customers() << " (d=" << r.discount() << ",s=" << r.strength() << ')' << endl; - for (MFCR<1,vector >::const_iterator it = r.begin(); it != r.end(); ++it) - cerr << " " << it->second.total_dish_count_ << " (on " << it->second.table_counts_.size() << " tables) " << TD::GetString(it->first) << endl; - } - - prob_t base; - MFCR<1,vector > r; - const double u0; - const vector l; - mutable vector v; -}; - -struct BasicLexicalAlignment { - explicit BasicLexicalAlignment(const vector >& lets, - const unsigned words_e, - const unsigned letters_e, - vector* corp) : - letters(lets), - corpus(*corp), - //up0(words_e), - //up0("en.chars.1gram", letters_e), - //up0("en.words.1gram"), - up0(letters_e), - //up0("en.chars.2gram"), - tmodel(up0) { - } + vector& corpus; + PYPLexicalTranslation model; + const WordID kNULL; - void InstantiateRule(const WordID src, - const WordID trg, - TRule* rule) const { - static const WordID kX = TD::Convert("X") * -1; - rule->lhs_ = kX; - rule->e_ = letters[trg]; - rule->f_ = letters[src]; + void ResampleHyperparameters() { + model.ResampleHyperparameters(prng); } void InitializeRandom() { - const WordID kNULL = TD::Convert("NULL"); cerr << "Initializing with random alignments ...\n"; for (unsigned i = 0; i < corpus.size(); ++i) { AlignedSentencePair& asp = corpus[i]; asp.a.resize(asp.trg.size()); for (unsigned j = 0; j < asp.trg.size(); ++j) { - const unsigned char a_j = prng->next() * (1 + asp.src.size()); + unsigned char& a_j = asp.a[j].src_index; + a_j = prng->next() * (1 + asp.src.size()); const WordID f_a_j = (a_j ? asp.src[a_j - 1] : kNULL); - TRule r; - InstantiateRule(f_a_j, asp.trg[j], &r); - asp.a[j].is_transliteration = false; - asp.a[j].src_index = a_j; - if (tmodel.IncrementRule(r, &*prng)) - up0.Increment(r); + model.Increment(f_a_j, asp.trg[j], &*prng); } } - cerr << " LLH = " << Likelihood() << endl; - } - - prob_t Likelihood() const { - prob_t p = tmodel.Likelihood(); - p *= up0.Likelihood(); - return p; - } - - void ResampleHyperparemeters() { - tmodel.ResampleHyperparameters(&*prng); - up0.ResampleHyperparameters(&*prng); - cerr << " (base d=" << up0.r.discount() << ",s=" << up0.r.strength() << ")\n"; + cerr << "Corpus intialized randomly. LLH = " << model.Likelihood() << endl; } - void ResampleCorpus(); - - const vector >& letters; // spelling dictionary - vector& corpus; - //PhraseConditionalUninformativeBase up0; - //PhraseConditionalUninformativeUnigramBase up0; - //UnigramWordBase up0; - //HierarchicalUnigramBase up0; - HierarchicalWordBase up0; - //CompletelyUniformBase up0; - //FixedNgramBase up0; - //ConditionalTranslationModel tmodel; - //ConditionalTranslationModel tmodel; - //ConditionalTranslationModel tmodel; - //ConditionalTranslationModel tmodel; - MConditionalTranslationModel tmodel; - //ConditionalTranslationModel tmodel; - //ConditionalTranslationModel tmodel; -}; - -void BasicLexicalAlignment::ResampleCorpus() { - static const WordID kNULL = TD::Convert("NULL"); - for (unsigned i = 0; i < corpus.size(); ++i) { - AlignedSentencePair& asp = corpus[i]; - SampleSet ss; ss.resize(asp.src.size() + 1); - for (unsigned j = 0; j < asp.trg.size(); ++j) { - TRule r; - unsigned char& a_j = asp.a[j].src_index; - WordID f_a_j = (a_j ? asp.src[a_j - 1] : kNULL); - InstantiateRule(f_a_j, asp.trg[j], &r); - if (tmodel.DecrementRule(r, &*prng)) - up0.Decrement(r); - - for (unsigned prop_a_j = 0; prop_a_j <= asp.src.size(); ++prop_a_j) { - const WordID prop_f = (prop_a_j ? asp.src[prop_a_j - 1] : kNULL); - InstantiateRule(prop_f, asp.trg[j], &r); - ss[prop_a_j] = tmodel.RuleProbability(r); + void ResampleCorpus() { + for (unsigned i = 0; i < corpus.size(); ++i) { + AlignedSentencePair& asp = corpus[i]; + SampleSet ss; ss.resize(asp.src.size() + 1); + for (unsigned j = 0; j < asp.trg.size(); ++j) { + unsigned char& a_j = asp.a[j].src_index; + const WordID e_j = asp.trg[j]; + WordID f_a_j = (a_j ? asp.src[a_j - 1] : kNULL); + model.Decrement(f_a_j, e_j, prng); + + for (unsigned prop_a_j = 0; prop_a_j <= asp.src.size(); ++prop_a_j) { + const WordID prop_f = (prop_a_j ? asp.src[prop_a_j - 1] : kNULL); + ss[prop_a_j] = model.Prob(prop_f, e_j); + } + a_j = prng->SelectSample(ss); + f_a_j = (a_j ? asp.src[a_j - 1] : kNULL); + model.Increment(f_a_j, e_j, prng); } - a_j = prng->SelectSample(ss); - f_a_j = (a_j ? asp.src[a_j - 1] : kNULL); - InstantiateRule(f_a_j, asp.trg[j], &r); - if (tmodel.IncrementRule(r, &*prng)) - up0.Increment(r); } + cerr << "LLH = " << model.Likelihood() << " " << model.UniqueConditioningContexts() << endl; } - cerr << " LLH = " << Likelihood() << endl; -} +}; void ExtractLetters(const set& v, vector >* l, set* letset = NULL) { for (set::const_iterator it = v.begin(); it != v.end(); ++it) { @@ -240,8 +131,10 @@ void ExtractLetters(const set& v, vector >* l, set a(asp.src.size(), asp.trg.size()); - for (unsigned j = 0; j < asp.trg.size(); ++j) + for (unsigned j = 0; j < asp.trg.size(); ++j) { + assert(asp.a[j].src_index <= asp.src.size()); if (asp.a[j].src_index) a(asp.a[j].src_index - 1, j) = true; + } cerr << a << endl; } @@ -275,10 +168,9 @@ int main(int argc, char** argv) { InitCommandLine(argc, argv, &conf); if (conf.count("random_seed")) - prng.reset(new MT19937(conf["random_seed"].as())); + prng = new MT19937(conf["random_seed"].as()); else - prng.reset(new MT19937); -// MT19937& rng = *prng; + prng = new MT19937; vector > corpuse, corpusf; set vocabe, vocabf; @@ -304,23 +196,18 @@ int main(int argc, char** argv) { ExtractLetters(vocabf, &letters, NULL); letters[TD::Convert("NULL")].clear(); - BasicLexicalAlignment x(letters, vocabe.size(), letset.size(), &corpus); - x.InitializeRandom(); + Aligner aligner(letters, letset.size(), &corpus); + aligner.InitializeRandom(); + const unsigned samples = conf["samples"].as(); for (int i = 0; i < samples; ++i) { for (int j = 65; j < 67; ++j) Debug(corpus[j]); - cerr << i << "\t" << x.tmodel.r.size() << "\t"; - if (i % 7 == 6) x.ResampleHyperparemeters(); - x.ResampleCorpus(); + if (i % 7 == 6) aligner.ResampleHyperparameters(); + aligner.ResampleCorpus(); if (i > (samples / 5) && (i % 10 == 9)) for (int j = 0; j < corpus.size(); ++j) AddSample(&corpus[j]); } for (unsigned i = 0; i < corpus.size(); ++i) WriteAlignments(corpus[i]); - //ModelAndData posterior(x, &corpus, vocabe, vocabf); - x.tmodel.Summary(); - x.up0.Summary(); - - //posterior.Sample(); return 0; } diff --git a/gi/pf/align-tl.cc b/gi/pf/align-tl.cc index fc9b7ca5..cbe8c6c8 100644 --- a/gi/pf/align-tl.cc +++ b/gi/pf/align-tl.cc @@ -6,6 +6,7 @@ #include #include +#include "backward.h" #include "array2d.h" #include "base_distributions.h" #include "monotonic_pseg.h" @@ -30,10 +31,11 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { opts.add_options() ("samples,s",po::value()->default_value(1000),"Number of samples") ("input,i",po::value(),"Read parallel data from") + ("s2t", po::value(), "character level source-to-target prior transliteration probabilities") + ("t2s", po::value(), "character level target-to-source prior transliteration probabilities") ("max_src_chunk", po::value()->default_value(4), "Maximum size of translitered chunk in source") ("max_trg_chunk", po::value()->default_value(4), "Maximum size of translitered chunk in target") - ("min_transliterated_src_length", po::value()->default_value(3), "Minimum length of source words considered for transliteration") - ("filter_ratio", po::value()->default_value(0.66), "Filter ratio: basically, if the lengths differ by less than this ratio, mark the pair as non-transliteratable") + ("expected_src_to_trg_ratio", po::value()->default_value(1.0), "If a word is transliterated, what is the expected length ratio from source to target?") ("random_seed,S",po::value(), "Random seed"); po::options_description clo("Command line options"); clo.add_options() @@ -303,7 +305,7 @@ int main(int argc, char** argv) { corpusf.clear(); corpuse.clear(); vocabf.insert(TD::Convert("NULL")); - vector > letters(TD::NumWords()); + vector > letters(TD::NumWords() + 1); set letset; ExtractLetters(vocabe, &letters, &letset); ExtractLetters(vocabf, &letters, NULL); @@ -312,9 +314,9 @@ int main(int argc, char** argv) { // TODO configure this const int max_src_chunk = conf["max_src_chunk"].as(); const int max_trg_chunk = conf["max_trg_chunk"].as(); - const double filter_rat = conf["filter_ratio"].as(); - const int min_trans_src = conf["min_transliterated_src_length"].as(); - Transliterations tl(max_src_chunk, max_trg_chunk, filter_rat); + const double s2t_rat = conf["expected_src_to_trg_ratio"].as(); + const BackwardEstimator be(conf["s2t"].as(), conf["t2s"].as()); + Transliterations tl(max_src_chunk, max_trg_chunk, s2t_rat, be); cerr << "Initializing transliteration graph structures ...\n"; for (int i = 0; i < corpus.size(); ++i) { @@ -325,8 +327,8 @@ int main(int argc, char** argv) { for (int k = 0; k < trg.size(); ++k) { const vector& trg_let = letters[trg[k]]; tl.Initialize(src[j], src_let, trg[k], trg_let); - if (src_let.size() < min_trans_src) - tl.Forbid(src[j], src_let, trg[k], trg_let); + //if (src_let.size() < min_trans_src) + // tl.Forbid(src[j], src_let, trg[k], trg_let); } } } diff --git a/gi/pf/backward.cc b/gi/pf/backward.cc new file mode 100644 index 00000000..b92629fd --- /dev/null +++ b/gi/pf/backward.cc @@ -0,0 +1,89 @@ +#include "backward.h" + +#include +#include + +#include "array2d.h" +#include "reachability.h" +#include "base_distributions.h" + +using namespace std; + +BackwardEstimator::BackwardEstimator(const string& s2t, + const string& t2s) : m1(new Model1(s2t)), m1inv(new Model1(t2s)) {} + +BackwardEstimator::~BackwardEstimator() { + delete m1; m1 = NULL; + delete m1inv; m1inv = NULL; +} + +float BackwardEstimator::ComputeBackwardProb(const std::vector& src, + const std::vector& trg, + unsigned src_covered, + unsigned trg_covered, + double s2t_ratio) const { + if (src_covered == src.size() || trg_covered == trg.size()) { + assert(src_covered == src.size()); + assert(trg_covered == trg.size()); + return 0; + } + static const WordID kNULL = TD::Convert(""); + const prob_t uniform_alignment(1.0 / (src.size() - src_covered + 1)); + // TODO factor in expected length ratio + prob_t e; e.logeq(Md::log_poisson(trg.size() - trg_covered, (src.size() - src_covered) * s2t_ratio)); // p(trg len remaining | src len remaining) + for (unsigned j = trg_covered; j < trg.size(); ++j) { + prob_t p = (*m1)(kNULL, trg[j]) + prob_t(1e-12); + for (unsigned i = src_covered; i < src.size(); ++i) + p += (*m1)(src[i], trg[j]); + if (p.is_0()) { + cerr << "ERROR: p(" << TD::Convert(trg[j]) << " | " << TD::GetString(src) << ") = 0!\n"; + assert(!"failed"); + } + p *= uniform_alignment; + e *= p; + } + // TODO factor in expected length ratio + const prob_t inv_uniform(1.0 / (trg.size() - trg_covered + 1.0)); + prob_t inv; + inv.logeq(Md::log_poisson(src.size() - src_covered, (trg.size() - trg_covered) / s2t_ratio)); + for (unsigned i = src_covered; i < src.size(); ++i) { + prob_t p = (*m1inv)(kNULL, src[i]) + prob_t(1e-12); + for (unsigned j = trg_covered; j < trg.size(); ++j) + p += (*m1inv)(trg[j], src[i]); + if (p.is_0()) { + cerr << "ERROR: p_inv(" << TD::Convert(src[i]) << " | " << TD::GetString(trg) << ") = 0!\n"; + assert(!"failed"); + } + p *= inv_uniform; + inv *= p; + } + return (log(e) + log(inv)) / 2; +} + +void BackwardEstimator::InitializeGrid(const vector& src, + const vector& trg, + const Reachability& r, + double s2t_ratio, + float* grid) const { + queue > q; + q.push(make_pair(0,0)); + Array2D done(src.size()+1, trg.size()+1, false); + //cerr << TD::GetString(src) << " ||| " << TD::GetString(trg) << endl; + while(!q.empty()) { + const pair n = q.front(); + q.pop(); + if (done(n.first,n.second)) continue; + done(n.first,n.second) = true; + + float lp = ComputeBackwardProb(src, trg, n.first, n.second, s2t_ratio); + if (n.first == 0 && n.second == 0) grid[0] = lp; + //cerr << " " << n.first << "," << n.second << "\t" << lp << endl; + + if (n.first == src.size() || n.second == trg.size()) continue; + const vector >& edges = r.valid_deltas[n.first][n.second]; + for (int i = 0; i < edges.size(); ++i) + q.push(make_pair(n.first + edges[i].first, n.second + edges[i].second)); + } + //static int cc = 0; ++cc; if (cc == 80) exit(1); +} + diff --git a/gi/pf/backward.h b/gi/pf/backward.h new file mode 100644 index 00000000..e67eff0c --- /dev/null +++ b/gi/pf/backward.h @@ -0,0 +1,33 @@ +#ifndef _BACKWARD_H_ +#define _BACKWARD_H_ + +#include +#include +#include "wordid.h" + +struct Reachability; +struct Model1; + +struct BackwardEstimator { + BackwardEstimator(const std::string& s2t, + const std::string& t2s); + ~BackwardEstimator(); + + void InitializeGrid(const std::vector& src, + const std::vector& trg, + const Reachability& r, + double src2trg_ratio, + float* grid) const; + + private: + float ComputeBackwardProb(const std::vector& src, + const std::vector& trg, + unsigned src_covered, + unsigned trg_covered, + double src2trg_ratio) const; + + Model1* m1; + Model1* m1inv; +}; + +#endif diff --git a/gi/pf/base_distributions.h b/gi/pf/base_distributions.h index 0d597c5c..84dacdf2 100644 --- a/gi/pf/base_distributions.h +++ b/gi/pf/base_distributions.h @@ -14,13 +14,7 @@ #include "tdict.h" #include "sampler.h" #include "m.h" - -inline std::ostream& operator<<(std::ostream& os, const std::vector& p) { - os << '['; - for (int i = 0; i < p.size(); ++i) - os << (i==0 ? "" : " ") << TD::Convert(p[i]); - return os << ']'; -} +#include "os_phrase.h" struct Model1 { explicit Model1(const std::string& fname) : diff --git a/gi/pf/guess-translits.pl b/gi/pf/guess-translits.pl index aafec13a..d00c2168 100755 --- a/gi/pf/guess-translits.pl +++ b/gi/pf/guess-translits.pl @@ -69,4 +69,4 @@ for my $f (keys %fs) { } } print STDERR "Extracted $num pairs.\n"; -print STDERR "Recommend running:\n ../../training/model1 -t -99999 output.txt\n"; +print STDERR "Recommend running:\n ../../training/model1 -v -d -t -99999 output.txt\n"; diff --git a/gi/pf/nuisance_test.cc b/gi/pf/nuisance_test.cc index 0f44fe95..fc0af9cb 100644 --- a/gi/pf/nuisance_test.cc +++ b/gi/pf/nuisance_test.cc @@ -124,9 +124,9 @@ int main(int argc, char** argv) { WordID y = TD::Convert("remember"); vector src; TD::ConvertSentence("s o u v e n o n s", &src); vector trg; TD::ConvertSentence("r e m e m b e r", &trg); - Transliterations xx; - xx.Initialize(x, src, y, trg); - return 1; +// Transliterations xx; +// xx.Initialize(x, src, y, trg); +// return 1; for (int j = 0; j < ITERS; ++j) { Base b; diff --git a/gi/pf/pyp_lm.cc b/gi/pf/pyp_lm.cc index 104f356b..52e6be2c 100644 --- a/gi/pf/pyp_lm.cc +++ b/gi/pf/pyp_lm.cc @@ -18,7 +18,7 @@ // I use templates to handle the recursive formalation of the prior, so // the order of the model has to be specified here, at compile time: -#define kORDER 4 +#define kORDER 3 using namespace std; using namespace tr1; diff --git a/gi/pf/pyp_tm.cc b/gi/pf/pyp_tm.cc new file mode 100644 index 00000000..94cbe7c3 --- /dev/null +++ b/gi/pf/pyp_tm.cc @@ -0,0 +1,113 @@ +#include "pyp_tm.h" + +#include +#include +#include + +#include "base_distributions.h" +#include "monotonic_pseg.h" +#include "conditional_pseg.h" +#include "tdict.h" +#include "ccrp.h" +#include "pyp_word_model.h" + +using namespace std; +using namespace std::tr1; + +template +struct ConditionalPYPWordModel { + ConditionalPYPWordModel(Base* b) : base(*b) {} + + void Summary() const { + cerr << "Number of conditioning contexts: " << r.size() << endl; + for (RuleModelHash::const_iterator it = r.begin(); it != r.end(); ++it) { + cerr << TD::Convert(it->first) << " \tPYP(d=" << it->second.discount() << ",s=" << it->second.strength() << ") --------------------------" << endl; + for (CCRP >::const_iterator i2 = it->second.begin(); i2 != it->second.end(); ++i2) + cerr << " " << i2->second.total_dish_count_ << '\t' << TD::GetString(i2->first) << endl; + } + } + + void ResampleHyperparameters(MT19937* rng) { + for (RuleModelHash::iterator it = r.begin(); it != r.end(); ++it) + it->second.resample_hyperparameters(rng); + } + + prob_t Prob(const WordID src, const vector& trglets) const { + RuleModelHash::const_iterator it = r.find(src); + if (it == r.end()) { + return base(trglets); + } else { + return it->second.prob(trglets, base(trglets)); + } + } + + void Increment(const WordID src, const vector& trglets, MT19937* rng) { + RuleModelHash::iterator it = r.find(src); + if (it == r.end()) + it = r.insert(make_pair(src, CCRP >(1,1,1,1,0.5,1.0))).first; + if (it->second.increment(trglets, base(trglets), rng)) + base.Increment(trglets, rng); + } + + void Decrement(const WordID src, const vector& trglets, MT19937* rng) { + RuleModelHash::iterator it = r.find(src); + assert(it != r.end()); + if (it->second.decrement(trglets, rng)) { + base.Decrement(trglets, rng); + if (it->second.num_customers() == 0) + r.erase(it); + } + } + + prob_t Likelihood() const { + prob_t p = prob_t::One(); + for (RuleModelHash::const_iterator it = r.begin(); it != r.end(); ++it) { + prob_t q; q.logeq(it->second.log_crp_prob()); + p *= q; + } + return p; + } + + unsigned UniqueConditioningContexts() const { + return r.size(); + } + + Base& base; + typedef unordered_map > > RuleModelHash; + RuleModelHash r; +}; + +PYPLexicalTranslation::PYPLexicalTranslation(const vector >& lets, + const unsigned num_letters) : + letters(lets), + up0(new PYPWordModel(num_letters)), + tmodel(new ConditionalPYPWordModel(up0)), + kX(-TD::Convert("X")) {} + +prob_t PYPLexicalTranslation::Likelihood() const { + prob_t p = up0->Likelihood(); + p *= tmodel->Likelihood(); + return p; +} + +void PYPLexicalTranslation::ResampleHyperparameters(MT19937* rng) { + tmodel->ResampleHyperparameters(rng); + up0->ResampleHyperparameters(rng); +} + +unsigned PYPLexicalTranslation::UniqueConditioningContexts() const { + return tmodel->UniqueConditioningContexts(); +} + +prob_t PYPLexicalTranslation::Prob(WordID src, WordID trg) const { + return tmodel->Prob(src, letters[trg]); +} + +void PYPLexicalTranslation::Increment(WordID src, WordID trg, MT19937* rng) { + tmodel->Increment(src, letters[trg], rng); +} + +void PYPLexicalTranslation::Decrement(WordID src, WordID trg, MT19937* rng) { + tmodel->Decrement(src, letters[trg], rng); +} + diff --git a/gi/pf/pyp_tm.h b/gi/pf/pyp_tm.h new file mode 100644 index 00000000..fa0fb28f --- /dev/null +++ b/gi/pf/pyp_tm.h @@ -0,0 +1,34 @@ +#ifndef PYP_LEX_TRANS +#define PYP_LEX_TRANS + +#include +#include "wordid.h" +#include "prob.h" +#include "sampler.h" + +struct TRule; +struct PYPWordModel; +template struct ConditionalPYPWordModel; + +struct PYPLexicalTranslation { + explicit PYPLexicalTranslation(const std::vector >& lets, + const unsigned num_letters); + + prob_t Likelihood() const; + + void ResampleHyperparameters(MT19937* rng); + prob_t Prob(WordID src, WordID trg) const; // return p(trg | src) + void Summary() const; + void Increment(WordID src, WordID trg, MT19937* rng); + void Decrement(WordID src, WordID trg, MT19937* rng); + unsigned UniqueConditioningContexts() const; + + private: + const std::vector >& letters; // spelling dictionary + PYPWordModel* up0; // base distribuction (model English word) + ConditionalPYPWordModel* tmodel; // translation distributions + // (model English word | French word) + const WordID kX; +}; + +#endif diff --git a/gi/pf/pyp_word_model.cc b/gi/pf/pyp_word_model.cc new file mode 100644 index 00000000..12df4abf --- /dev/null +++ b/gi/pf/pyp_word_model.cc @@ -0,0 +1,20 @@ +#include "pyp_word_model.h" + +#include + +using namespace std; + +void PYPWordModel::ResampleHyperparameters(MT19937* rng) { + r.resample_hyperparameters(rng); + cerr << " PYPWordModel(d=" << r.discount() << ",s=" << r.strength() << ")\n"; +} + +void PYPWordModel::Summary() const { + cerr << "PYPWordModel: generations=" << r.num_customers() + << " PYP(d=" << r.discount() << ",s=" << r.strength() << ')' << endl; + for (CCRP >::const_iterator it = r.begin(); it != r.end(); ++it) + cerr << " " << it->second.total_dish_count_ + << " (on " << it->second.table_counts_.size() << " tables) " + << TD::GetString(it->first) << endl; +} + diff --git a/gi/pf/pyp_word_model.h b/gi/pf/pyp_word_model.h new file mode 100644 index 00000000..800a4fd7 --- /dev/null +++ b/gi/pf/pyp_word_model.h @@ -0,0 +1,58 @@ +#ifndef _PYP_WORD_MODEL_H_ +#define _PYP_WORD_MODEL_H_ + +#include +#include +#include +#include "prob.h" +#include "ccrp.h" +#include "m.h" +#include "tdict.h" +#include "os_phrase.h" + +// PYP(d,s,poisson-uniform) represented as a CRP +struct PYPWordModel { + explicit PYPWordModel(const unsigned vocab_e_size, const double mean_len = 7.5) : + base(prob_t::One()), r(1,1,1,1,0.66,50.0), u0(-std::log(vocab_e_size)), mean_length(mean_len) {} + + void ResampleHyperparameters(MT19937* rng); + + inline prob_t operator()(const std::vector& s) const { + return r.prob(s, p0(s)); + } + + inline void Increment(const std::vector& s, MT19937* rng) { + if (r.increment(s, p0(s), rng)) + base *= p0(s); + } + + inline void Decrement(const std::vector& s, MT19937 *rng) { + if (r.decrement(s, rng)) + base /= p0(s); + } + + inline prob_t Likelihood() const { + prob_t p; p.logeq(r.log_crp_prob()); + p *= base; + return p; + } + + void Summary() const; + + private: + inline double logp0(const std::vector& s) const { + return Md::log_poisson(s.size(), mean_length) + s.size() * u0; + } + + inline prob_t p0(const std::vector& s) const { + prob_t p; p.logeq(logp0(s)); + return p; + } + + prob_t base; // keeps track of the draws from the base distribution + CCRP > r; + const double u0; // uniform log prob of generating a letter + const double mean_length; // mean length of a word in the base distribution +}; + +#endif diff --git a/gi/pf/reachability.cc b/gi/pf/reachability.cc index c10000f2..7d0d04ac 100644 --- a/gi/pf/reachability.cc +++ b/gi/pf/reachability.cc @@ -12,7 +12,7 @@ struct SState { int prev_trg_covered; }; -void Reachability::ComputeReachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len, double filter_ratio) { +void Reachability::ComputeReachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len) { typedef boost::multi_array, 2> array_type; array_type a(boost::extents[srclen + 1][trglen + 1]); a[0][0].push_back(SState()); @@ -31,9 +31,9 @@ void Reachability::ComputeReachability(int srclen, int trglen, int src_max_phras } a[0][0].clear(); //cerr << srclen << "," << trglen << ": Final cell contains " << a[srclen][trglen].size() << " back pointers\n"; - size_t min_allowed = (src_max_phrase_len + 1) * (trg_max_phrase_len + 1) * (filter_ratio * filter_ratio); - if (a[srclen][trglen].size() < min_allowed) { - cerr << "Sequence pair with lengths (" << srclen << ',' << trglen << ") violates reachability constraint of min indegree " << min_allowed << " with " << a[srclen][trglen].size() << " in edges\n"; + if (a[srclen][trglen].empty()) { + cerr << "Sequence pair with lengths (" << srclen << ',' << trglen << ") violates reachability constraints\n"; + nodes = 0; return; } diff --git a/gi/pf/reachability.h b/gi/pf/reachability.h index 03967d44..1e22c76a 100644 --- a/gi/pf/reachability.h +++ b/gi/pf/reachability.h @@ -18,19 +18,17 @@ struct Reachability { boost::multi_array node_addresses; // na[src_covered][trg_covered] -- the index of the node in a one-dimensional array (of size "nodes") boost::multi_array >, 2> valid_deltas; // valid_deltas[src_covered][trg_covered] list of valid transitions leaving a particular node - // filter_ratio says if the number of outgoing edges from the first cell is less than - // src_max * trg_max * filter_rat^2 then mark as non reachable - Reachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len, double filter_ratio = 0.0) : + Reachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len) : nodes(), edges(boost::extents[srclen][trglen][src_max_phrase_len+1][trg_max_phrase_len+1]), max_src_delta(boost::extents[srclen][trglen]), node_addresses(boost::extents[srclen][trglen]), valid_deltas(boost::extents[srclen][trglen]) { - ComputeReachability(srclen, trglen, src_max_phrase_len, trg_max_phrase_len, filter_ratio); + ComputeReachability(srclen, trglen, src_max_phrase_len, trg_max_phrase_len); } private: - void ComputeReachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len, double filter_ratio); + void ComputeReachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len); }; #endif diff --git a/gi/pf/transliterations.cc b/gi/pf/transliterations.cc index 8ea4ebd2..2200715e 100644 --- a/gi/pf/transliterations.cc +++ b/gi/pf/transliterations.cc @@ -5,14 +5,173 @@ #include "boost/shared_ptr.hpp" +#include "backward.h" #include "filelib.h" -#include "ccrp.h" +#include "tdict.h" +#include "trule.h" +#include "filelib.h" +#include "ccrp_nt.h" #include "m.h" #include "reachability.h" using namespace std; using namespace std::tr1; +struct TruncatedConditionalLengthModel { + TruncatedConditionalLengthModel(unsigned max_src_size, unsigned max_trg_size, double expected_src_to_trg_ratio) : + plens(max_src_size+1, vector(max_trg_size+1, 0.0)) { + for (unsigned i = 1; i <= max_src_size; ++i) { + prob_t z = prob_t::Zero(); + for (unsigned j = 1; j <= max_trg_size; ++j) + z += (plens[i][j] = prob_t(0.01 + exp(Md::log_poisson(j, i * expected_src_to_trg_ratio)))); + for (unsigned j = 1; j <= max_trg_size; ++j) + plens[i][j] /= z; + //for (unsigned j = 1; j <= max_trg_size; ++j) + // cerr << "P(trg_len=" << j << " | src_len=" << i << ") = " << plens[i][j] << endl; + } + } + + // return p(tlen | slen) for *chunks* not full words + inline const prob_t& operator()(int slen, int tlen) const { + return plens[slen][tlen]; + } + + vector > plens; +}; + +struct CondBaseDist { + CondBaseDist(unsigned max_src_size, unsigned max_trg_size, double expected_src_to_trg_ratio) : + tclm(max_src_size, max_trg_size, expected_src_to_trg_ratio) {} + + prob_t operator()(const vector& src, unsigned sf, unsigned st, + const vector& trg, unsigned tf, unsigned tt) const { + prob_t p = tclm(st - sf, tt - tf); // target len | source length ~ TCLM(source len) + assert(!"not impl"); + return p; + } + inline prob_t operator()(const vector& src, const vector& trg) const { + return (*this)(src, 0, src.size(), trg, 0, trg.size()); + } + TruncatedConditionalLengthModel tclm; +}; + +// represents transliteration phrase probabilities, e.g. +// p( a l - | A l ) , p( o | A w ) , ... +struct TransliterationChunkConditionalModel { + explicit TransliterationChunkConditionalModel(const CondBaseDist& pp0) : + d(0.0), + strength(1.0), + rp0(pp0) { + } + + void Summary() const { + std::cerr << "Number of conditioning contexts: " << r.size() << std::endl; + for (RuleModelHash::const_iterator it = r.begin(); it != r.end(); ++it) { + std::cerr << TD::GetString(it->first) << " \t(\\alpha = " << it->second.alpha() << ") --------------------------" << std::endl; + for (CCRP_NoTable::const_iterator i2 = it->second.begin(); i2 != it->second.end(); ++i2) + std::cerr << " " << i2->second << '\t' << i2->first << std::endl; + } + } + + int DecrementRule(const TRule& rule) { + RuleModelHash::iterator it = r.find(rule.f_); + assert(it != r.end()); + int count = it->second.decrement(rule); + if (count) { + if (it->second.num_customers() == 0) r.erase(it); + } + return count; + } + + int IncrementRule(const TRule& rule) { + RuleModelHash::iterator it = r.find(rule.f_); + if (it == r.end()) { + it = r.insert(make_pair(rule.f_, CCRP_NoTable(strength))).first; + } + int count = it->second.increment(rule); + return count; + } + + void IncrementRules(const std::vector& rules) { + for (int i = 0; i < rules.size(); ++i) + IncrementRule(*rules[i]); + } + + void DecrementRules(const std::vector& rules) { + for (int i = 0; i < rules.size(); ++i) + DecrementRule(*rules[i]); + } + + prob_t RuleProbability(const TRule& rule) const { + prob_t p; + RuleModelHash::const_iterator it = r.find(rule.f_); + if (it == r.end()) { + p = rp0(rule.f_, rule.e_); + } else { + p = it->second.prob(rule, rp0(rule.f_, rule.e_)); + } + return p; + } + + double LogLikelihood(const double& dd, const double& aa) const { + if (aa <= -dd) return -std::numeric_limits::infinity(); + //double llh = Md::log_beta_density(dd, 10, 3) + Md::log_gamma_density(aa, 1, 1); + double llh = //Md::log_beta_density(dd, 1, 1) + + Md::log_gamma_density(dd + aa, 1, 1); + typename std::tr1::unordered_map, CCRP_NoTable, boost::hash > >::const_iterator it; + for (it = r.begin(); it != r.end(); ++it) + llh += it->second.log_crp_prob(aa); + return llh; + } + + struct AlphaResampler { + AlphaResampler(const TransliterationChunkConditionalModel& m) : m_(m) {} + const TransliterationChunkConditionalModel& m_; + double operator()(const double& proposed_strength) const { + return m_.LogLikelihood(m_.d, proposed_strength); + } + }; + + void ResampleHyperparameters(MT19937* rng) { + typename std::tr1::unordered_map, CCRP_NoTable, boost::hash > >::iterator it; + //const unsigned nloop = 5; + const unsigned niterations = 10; + //DiscountResampler dr(*this); + AlphaResampler ar(*this); +#if 0 + for (int iter = 0; iter < nloop; ++iter) { + strength = slice_sampler1d(ar, strength, *rng, -d + std::numeric_limits::min(), + std::numeric_limits::infinity(), 0.0, niterations, 100*niterations); + double min_discount = std::numeric_limits::min(); + if (strength < 0.0) min_discount -= strength; + d = slice_sampler1d(dr, d, *rng, min_discount, + 1.0, 0.0, niterations, 100*niterations); + } +#endif + strength = slice_sampler1d(ar, strength, *rng, -d, + std::numeric_limits::infinity(), 0.0, niterations, 100*niterations); + std::cerr << "CTMModel(alpha=" << strength << ") = " << LogLikelihood(d, strength) << std::endl; + for (it = r.begin(); it != r.end(); ++it) { +#if 0 + it->second.set_discount(d); +#endif + it->second.set_alpha(strength); + } + } + + prob_t Likelihood() const { + prob_t p; p.logeq(LogLikelihood(d, strength)); + return p; + } + + const CondBaseDist& rp0; + typedef std::tr1::unordered_map, + CCRP_NoTable, + boost::hash > > RuleModelHash; + RuleModelHash r; + double d, strength; +}; + struct GraphStructure { GraphStructure() : r() {} // leak memory - these are basically static @@ -20,9 +179,9 @@ struct GraphStructure { bool IsReachable() const { return r->nodes > 0; } }; -struct BackwardEstimates { - BackwardEstimates() : gs(), backward() {} - explicit BackwardEstimates(const GraphStructure& g) : +struct ProbabilityEstimates { + ProbabilityEstimates() : gs(), backward() {} + explicit ProbabilityEstimates(const GraphStructure& g) : gs(&g), backward() { if (g.r->nodes > 0) backward = new float[g.r->nodes]; @@ -36,24 +195,32 @@ struct BackwardEstimates { } // returns an backward estimate - double operator()(int src_covered, int trg_covered) const { + double Backward(int src_covered, int trg_covered) const { if (!backward) return 0; int ind = gs->r->node_addresses[src_covered][trg_covered]; if (ind < 0) return 0; return backward[ind]; } + + prob_t estp; + float* backward; private: const GraphStructure* gs; - float* backward; }; struct TransliterationsImpl { - TransliterationsImpl(int max_src, int max_trg, double fr) : + TransliterationsImpl(int max_src, int max_trg, double sr, const BackwardEstimator& b) : + cp0(max_src, max_trg, sr), + tccm(cp0), + be(b), kMAX_SRC_CHUNK(max_src), kMAX_TRG_CHUNK(max_trg), - kFILTER_RATIO(fr), + kS2T_RATIO(sr), tot_pairs(), tot_mem() { } + const CondBaseDist cp0; + TransliterationChunkConditionalModel tccm; + const BackwardEstimator& be; void Initialize(WordID src, const vector& src_lets, WordID trg, const vector& trg_lets) { const size_t src_len = src_lets.size(); @@ -63,20 +230,29 @@ struct TransliterationsImpl { if (src_len >= graphs.size()) graphs.resize(src_len + 1); if (trg_len >= graphs[src_len].size()) graphs[src_len].resize(trg_len + 1); GraphStructure& gs = graphs[src_len][trg_len]; - if (!gs.r) - gs.r = new Reachability(src_len, trg_len, kMAX_SRC_CHUNK, kMAX_TRG_CHUNK, kFILTER_RATIO); + if (!gs.r) { + double rat = exp(fabs(log(trg_len / (src_len * kS2T_RATIO)))); + if (rat > 1.5 || (rat > 2.4 && src_len < 6)) { + cerr << " ** Forbidding transliterations of size " << src_len << "," << trg_len << ": " << rat << endl; + gs.r = new Reachability(src_len, trg_len, 0, 0); + } else { + gs.r = new Reachability(src_len, trg_len, kMAX_SRC_CHUNK, kMAX_TRG_CHUNK); + } + } + const Reachability& r = *gs.r; // init backward estimates - if (src >= bes.size()) bes.resize(src + 1); - unordered_map::iterator it = bes[src].find(trg); - if (it != bes[src].end()) return; // already initialized + if (src >= ests.size()) ests.resize(src + 1); + unordered_map::iterator it = ests[src].find(trg); + if (it != ests[src].end()) return; // already initialized - it = bes[src].insert(make_pair(trg, BackwardEstimates(gs))).first; - BackwardEstimates& b = it->second; + it = ests[src].insert(make_pair(trg, ProbabilityEstimates(gs))).first; + ProbabilityEstimates& est = it->second; if (!gs.r->nodes) return; // not derivable subject to length constraints - // TODO + be.InitializeGrid(src_lets, trg_lets, r, kS2T_RATIO, est.backward); + cerr << TD::GetString(src_lets) << " ||| " << TD::GetString(trg_lets) << " ||| " << (est.backward[0] / trg_lets.size()) << endl; tot_pairs++; tot_mem += sizeof(float) * gs.r->nodes; } @@ -92,8 +268,11 @@ struct TransliterationsImpl { const vector& tv = graphs[src.size()]; assert(trg.size() < tv.size()); const GraphStructure& gs = tv[trg.size()]; - // TODO: do prob - return prob_t::Zero(); + if (gs.r->nodes == 0) + return prob_t::Zero(); + const unordered_map::const_iterator it = ests[s].find(t); + assert(it != ests[s].end()); + return it->second.estp; } void GraphSummary() const { @@ -126,15 +305,15 @@ struct TransliterationsImpl { const int kMAX_SRC_CHUNK; const int kMAX_TRG_CHUNK; - const double kFILTER_RATIO; + const double kS2T_RATIO; unsigned tot_pairs; size_t tot_mem; vector > graphs; // graphs[src_len][trg_len] - vector > bes; // bes[src][trg] + vector > ests; // ests[src][trg] }; -Transliterations::Transliterations(int max_src, int max_trg, double fr) : - pimpl_(new TransliterationsImpl(max_src, max_trg, fr)) {} +Transliterations::Transliterations(int max_src, int max_trg, double sr, const BackwardEstimator& be) : + pimpl_(new TransliterationsImpl(max_src, max_trg, sr, be)) {} Transliterations::~Transliterations() { delete pimpl_; } void Transliterations::Initialize(WordID src, const vector& src_lets, WordID trg, const vector& trg_lets) { diff --git a/gi/pf/transliterations.h b/gi/pf/transliterations.h index ea9f9d3f..49d14684 100644 --- a/gi/pf/transliterations.h +++ b/gi/pf/transliterations.h @@ -5,11 +5,12 @@ #include "wordid.h" #include "prob.h" +struct BackwardEstimator; struct TransliterationsImpl; struct Transliterations { // max_src and max_trg indicate how big the transliteration phrases can be // see reachability.h for information about filter_ratio - explicit Transliterations(int max_src, int max_trg, double filter_ratio); + explicit Transliterations(int max_src, int max_trg, double s2t_rat, const BackwardEstimator& be); ~Transliterations(); void Initialize(WordID src, const std::vector& src_lets, WordID trg, const std::vector& trg_lets); void Forbid(WordID src, const std::vector& src_lets, WordID trg, const std::vector& trg_lets); diff --git a/utils/ccrp_nt.h b/utils/ccrp_nt.h index 79321493..6efbfc78 100644 --- a/utils/ccrp_nt.h +++ b/utils/ccrp_nt.h @@ -11,6 +11,7 @@ #include #include "sampler.h" #include "slice_sampler.h" +#include "m.h" // Chinese restaurant process (1 parameter) template > @@ -29,6 +30,7 @@ class CCRP_NoTable { alpha_prior_rate_(c_rate) {} double alpha() const { return alpha_; } + void set_alpha(const double& alpha) { alpha_ = alpha; assert(alpha_ > 0.0); } bool has_alpha_prior() const { return !std::isnan(alpha_prior_shape_); @@ -71,9 +73,10 @@ class CCRP_NoTable { return table_diff; } - double prob(const Dish& dish, const double& p0) const { + template + F prob(const Dish& dish, const F& p0) const { const unsigned at_table = num_customers(dish); - return (at_table + p0 * alpha_) / (num_customers_ + alpha_); + return (F(at_table) + p0 * F(alpha_)) / F(num_customers_ + alpha_); } double logprob(const Dish& dish, const double& logp0) const { @@ -85,20 +88,12 @@ class CCRP_NoTable { return log_crp_prob(alpha_); } - static double log_gamma_density(const double& x, const double& shape, const double& rate) { - assert(x >= 0.0); - assert(shape > 0.0); - assert(rate > 0.0); - const double lp = (shape-1)*log(x) - shape*log(rate) - x/rate - lgamma(shape); - return lp; - } - // taken from http://en.wikipedia.org/wiki/Chinese_restaurant_process // does not include P_0's double log_crp_prob(const double& alpha) const { double lp = 0.0; if (has_alpha_prior()) - lp += log_gamma_density(alpha, alpha_prior_shape_, alpha_prior_rate_); + lp += Md::log_gamma_density(alpha, alpha_prior_shape_, alpha_prior_rate_); assert(lp <= 0.0); if (num_customers_) { lp += lgamma(alpha) - lgamma(alpha + num_customers_) + -- cgit v1.2.3