From 95e9ea690b87f4648215782e820e177cbe17f18b Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 20 Oct 2011 15:21:54 +0100 Subject: bidir model1 base measure --- gi/pf/base_measures.cc | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++ gi/pf/base_measures.h | 31 ++++++++++++++++++++++++++++ gi/pf/dpnaive.cc | 17 ++++++++++----- gi/pf/monotonic_pseg.h | 5 +++-- gi/pf/pfnaive.cc | 12 ++++++++++- 5 files changed, 113 insertions(+), 8 deletions(-) (limited to 'gi') diff --git a/gi/pf/base_measures.cc b/gi/pf/base_measures.cc index f8ddfd32..8adb37d7 100644 --- a/gi/pf/base_measures.cc +++ b/gi/pf/base_measures.cc @@ -89,6 +89,62 @@ prob_t PhraseJointBase::p0(const vector& vsrc, return p; } +prob_t PhraseJointBase_BiDir::p0(const vector& vsrc, + const vector& vtrg, + int start_src, int start_trg) const { + const int flen = vsrc.size() - start_src; + const int elen = vtrg.size() - start_trg; + prob_t uniform_src_alignment; uniform_src_alignment.logeq(-log(flen + 1)); + prob_t uniform_trg_alignment; uniform_trg_alignment.logeq(-log(elen + 1)); + + prob_t p1; + p1.logeq(log_poisson(flen, 1.0)); // flen ~Pois(1) + // elen | flen ~Pois(flen + 0.01) + prob_t ptrglen; ptrglen.logeq(log_poisson(elen, flen + 0.01)); + p1 *= ptrglen; + p1 *= kUNIFORM_SOURCE.pow(flen); // each f in F ~Uniform + for (int i = 0; i < elen; ++i) { // for each position i in E + const WordID trg = vtrg[i + start_trg]; + prob_t tp = prob_t::Zero(); + for (int j = -1; j < flen; ++j) { + const WordID src = j < 0 ? 0 : vsrc[j + start_src]; + tp += kM1MIXTURE * model1(src, trg); + tp += kUNIFORM_MIXTURE * kUNIFORM_TARGET; + } + tp *= uniform_src_alignment; // draw a_i ~uniform + p1 *= tp; // draw e_i ~Model1(f_a_i) / uniform + } + if (p1.is_0()) { + cerr << "Zero! " << vsrc << "\nTRG=" << vtrg << endl; + abort(); + } + + prob_t p2; + p2.logeq(log_poisson(elen, 1.0)); // elen ~Pois(1) + // flen | elen ~Pois(flen + 0.01) + prob_t psrclen; psrclen.logeq(log_poisson(flen, elen + 0.01)); + p2 *= psrclen; + p2 *= kUNIFORM_TARGET.pow(elen); // each f in F ~Uniform + for (int i = 0; i < flen; ++i) { // for each position i in E + const WordID src = vsrc[i + start_src]; + prob_t tp = prob_t::Zero(); + for (int j = -1; j < elen; ++j) { + const WordID trg = j < 0 ? 0 : vtrg[j + start_trg]; + tp += kM1MIXTURE * invmodel1(trg, src); + tp += kUNIFORM_MIXTURE * kUNIFORM_SOURCE; + } + tp *= uniform_trg_alignment; // draw a_i ~uniform + p2 *= tp; // draw e_i ~Model1(f_a_i) / uniform + } + if (p2.is_0()) { + cerr << "Zero! " << vsrc << "\nTRG=" << vtrg << endl; + abort(); + } + + static const prob_t kHALF(0.5); + return (p1 + p2) * kHALF; +} + JumpBase::JumpBase() : p(200) { for (unsigned src_len = 1; src_len < 200; ++src_len) { map& cpd = p[src_len]; diff --git a/gi/pf/base_measures.h b/gi/pf/base_measures.h index df17aa62..7ce7e2e6 100644 --- a/gi/pf/base_measures.h +++ b/gi/pf/base_measures.h @@ -97,6 +97,37 @@ struct PhraseJointBase { const prob_t kUNIFORM_TARGET; }; +struct PhraseJointBase_BiDir { + explicit PhraseJointBase_BiDir(const Model1& m1, + const Model1& im1, + const double m1mixture, + const unsigned vocab_e_size, + const unsigned vocab_f_size) : + model1(m1), + invmodel1(im1), + kM1MIXTURE(m1mixture), + kUNIFORM_MIXTURE(1.0 - m1mixture), + kUNIFORM_SOURCE(1.0 / vocab_f_size), + kUNIFORM_TARGET(1.0 / vocab_e_size) { + assert(m1mixture >= 0.0 && m1mixture <= 1.0); + assert(vocab_e_size > 0); + } + + // return p0 of rule.e_ | rule.f_ + prob_t operator()(const TRule& rule) const { + return p0(rule.f_, rule.e_, 0, 0); + } + + prob_t p0(const std::vector& vsrc, const std::vector& vtrg, int start_src, int start_trg) const; + + const Model1& model1; + const Model1& invmodel1; + const prob_t kM1MIXTURE; // Model 1 mixture component + const prob_t kUNIFORM_MIXTURE; // uniform mixture component + const prob_t kUNIFORM_SOURCE; + const prob_t kUNIFORM_TARGET; +}; + // base distribution for jump size multinomials // basically p(0) = 0 and then, p(1) is max, and then // you drop as you move to the max jump distance diff --git a/gi/pf/dpnaive.cc b/gi/pf/dpnaive.cc index c926487b..db1c43c7 100644 --- a/gi/pf/dpnaive.cc +++ b/gi/pf/dpnaive.cc @@ -31,6 +31,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("max_src_phrase",po::value()->default_value(4),"Maximum length of source language phrases") ("max_trg_phrase",po::value()->default_value(4),"Maximum length of target language phrases") ("model1,m",po::value(),"Model 1 parameters (used in base distribution)") + ("inverse_model1,M",po::value(),"Inverse Model 1 parameters (used in base distribution)") ("model1_interpolation_weight",po::value()->default_value(0.95),"Mixing proportion of model 1 with uniform target distribution") ("random_seed,S",po::value(), "Random seed"); po::options_description clo("Command line options"); @@ -58,7 +59,7 @@ shared_ptr prng; template struct ModelAndData { - explicit ModelAndData(MonotonicParallelSegementationModel& m, const Base& b, const vector >& ce, const vector >& cf, const set& ve, const set& vf) : + explicit ModelAndData(MonotonicParallelSegementationModel& m, const Base& b, const vector >& ce, const vector >& cf, const set& ve, const set& vf) : model(m), rng(&*prng), p0(b), @@ -139,7 +140,7 @@ struct ModelAndData { void Sample(); - MonotonicParallelSegementationModel& model; + MonotonicParallelSegementationModel& model; MT19937* rng; const Base& p0; prob_t baseprob; // cached value of generating the table table labels from p0 @@ -267,6 +268,10 @@ int main(int argc, char** argv) { cerr << argv[0] << "Please use --model1 to specify model 1 parameters\n"; return 1; } + if (!conf.count("inverse_model1")) { + cerr << argv[0] << "Please use --inverse_model1 to specify inverse model 1 parameters\n"; + return 1; + } if (conf.count("random_seed")) prng.reset(new MT19937(conf["random_seed"].as())); else @@ -283,10 +288,12 @@ int main(int argc, char** argv) { assert(corpusf.size() == corpuse.size()); Model1 m1(conf["model1"].as()); - PhraseJointBase lp0(m1, conf["model1_interpolation_weight"].as(), vocabe.size(), vocabf.size()); - MonotonicParallelSegementationModel m(lp0); + Model1 invm1(conf["inverse_model1"].as()); +// PhraseJointBase lp0(m1, conf["model1_interpolation_weight"].as(), vocabe.size(), vocabf.size()); + PhraseJointBase_BiDir alp0(m1, invm1, conf["model1_interpolation_weight"].as(), vocabe.size(), vocabf.size()); + MonotonicParallelSegementationModel m(alp0); - ModelAndData posterior(m, lp0, corpuse, corpusf, vocabe, vocabf); + ModelAndData posterior(m, alp0, corpuse, corpusf, vocabe, vocabf); posterior.Sample(); return 0; diff --git a/gi/pf/monotonic_pseg.h b/gi/pf/monotonic_pseg.h index 7e6af3fc..301aa6d8 100644 --- a/gi/pf/monotonic_pseg.h +++ b/gi/pf/monotonic_pseg.h @@ -8,8 +8,9 @@ #include "trule.h" #include "base_measures.h" +template struct MonotonicParallelSegementationModel { - explicit MonotonicParallelSegementationModel(PhraseJointBase& rcp0) : + explicit MonotonicParallelSegementationModel(BaseMeasure& rcp0) : rp0(rcp0), base(prob_t::One()), rules(1,1), stop(1.0) {} void DecrementRule(const TRule& rule) { @@ -78,7 +79,7 @@ struct MonotonicParallelSegementationModel { return prob_t(stop.prob(false, 0.5)); } - const PhraseJointBase& rp0; + const BaseMeasure& rp0; prob_t base; CCRP_NoTable rules; CCRP_NoTable stop; diff --git a/gi/pf/pfnaive.cc b/gi/pf/pfnaive.cc index 33dc08c3..d967958c 100644 --- a/gi/pf/pfnaive.cc +++ b/gi/pf/pfnaive.cc @@ -181,7 +181,17 @@ int main(int argc, char** argv) { Model1 invm1(conf["inverse_model1"].as()); PhraseJointBase lp0(m1, conf["model1_interpolation_weight"].as(), vocabe.size(), vocabf.size()); - MonotonicParallelSegementationModel m(lp0); + PhraseJointBase_BiDir alp0(m1, invm1, conf["model1_interpolation_weight"].as(), vocabe.size(), vocabf.size()); + MonotonicParallelSegementationModel m(alp0); + TRule xx("[X] ||| ms. kimura ||| MS. KIMURA ||| X=0"); + cerr << xx << endl << lp0(xx) << " " << alp0(xx) << endl; + TRule xx12("[X] ||| . ||| PHARMACY . ||| X=0"); + TRule xx21("[X] ||| pharmacy . ||| . ||| X=0"); +// TRule xx22("[X] ||| . ||| . ||| X=0"); + TRule xx22("[X] ||| . ||| THE . ||| X=0"); + cerr << xx12 << "\t" << lp0(xx12) << " " << alp0(xx12) << endl; + cerr << xx21 << "\t" << lp0(xx21) << " " << alp0(xx21) << endl; + cerr << xx22 << "\t" << lp0(xx22) << " " << alp0(xx22) << endl; cerr << "Initializing reachability limits...\n"; vector ps(corpusf.size()); -- cgit v1.2.3