Diffstat (limited to 'gi')
 -rw-r--r--  gi/pf/Makefile.am         |   7
 -rw-r--r--  gi/pf/align-tl.cc         | 334
 -rw-r--r--  gi/pf/conditional_pseg.h  |  11
 -rw-r--r--  gi/pf/nuisance_test.cc    | 161
 -rw-r--r--  gi/pf/transliterations.cc | 193
 -rw-r--r--  gi/pf/transliterations.h  |  20
6 files changed, 723 insertions(+), 3 deletions(-)
diff --git a/gi/pf/Makefile.am b/gi/pf/Makefile.am
index 7cf9c14d..5e89f02a 100644
--- a/gi/pf/Makefile.am
+++ b/gi/pf/Makefile.am
@@ -1,12 +1,17 @@
-bin_PROGRAMS = cbgi brat dpnaive pfbrat pfdist itg pfnaive condnaive align-lexonly align-lexonly-pyp learn_cfg pyp_lm
+bin_PROGRAMS = cbgi brat dpnaive pfbrat pfdist itg pfnaive condnaive align-lexonly align-lexonly-pyp learn_cfg pyp_lm nuisance_test align-tl
 
 noinst_LIBRARIES = libpf.a
+
 libpf_a_SOURCES = base_distributions.cc reachability.cc cfg_wfst_composer.cc corpus.cc unigrams.cc ngram_base.cc
+nuisance_test_SOURCES = nuisance_test.cc transliterations.cc
+
 
 align_lexonly_SOURCES = align-lexonly.cc
 
 align_lexonly_pyp_SOURCES = align-lexonly-pyp.cc
+align_tl_SOURCES = align-tl.cc transliterations.cc
+
 
 itg_SOURCES = itg.cc
 
 pyp_lm_SOURCES = pyp_lm.cc
diff --git a/gi/pf/align-tl.cc b/gi/pf/align-tl.cc
new file mode 100644
index 00000000..0e0454e5
--- /dev/null
+++ b/gi/pf/align-tl.cc
@@ -0,0 +1,334 @@
+#include <iostream>
+#include <tr1/memory>
+#include <queue>
+
+#include <boost/multi_array.hpp>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "array2d.h"
+#include "base_distributions.h"
+#include "monotonic_pseg.h"
+#include "conditional_pseg.h"
+#include "trule.h"
+#include "tdict.h"
+#include "stringlib.h"
+#include "filelib.h"
+#include "dict.h"
+#include "sampler.h"
+#include "mfcr.h"
+#include "corpus.h"
+#include "ngram_base.h"
+#include "transliterations.h"
+
+using namespace std;
+using namespace tr1;
+namespace po = boost::program_options;
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+  po::options_description opts("Configuration options");
+  opts.add_options()
+        ("samples,s",po::value<unsigned>()->default_value(1000),"Number of samples")
+        ("input,i",po::value<string>(),"Read parallel data from")
+        ("random_seed,S",po::value<uint32_t>(), "Random seed");
+  po::options_description clo("Command line options");
+  clo.add_options()
+        ("config", po::value<string>(), "Configuration file")
+        ("help,h", "Print this help message and exit");
+  po::options_description dconfig_options, dcmdline_options;
+  dconfig_options.add(opts);
+  dcmdline_options.add(opts).add(clo);
+
+  po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+  if (conf->count("config")) {
+    ifstream config((*conf)["config"].as<string>().c_str());
+    po::store(po::parse_config_file(config, dconfig_options), *conf);
+  }
+  po::notify(*conf);
+
+  if (conf->count("help") || (conf->count("input") == 0)) {
+    cerr << dcmdline_options << endl;
+    exit(1);
+  }
+}
+
+shared_ptr<MT19937> prng;
+
+struct LexicalAlignment {
+  unsigned char src_index;
+  bool is_transliteration;
+  vector<pair<short, short> > derivation;
+};
+
+struct AlignedSentencePair {
+  vector<WordID> src;
+  vector<WordID> trg;
+  vector<LexicalAlignment> a;
+  Array2D<short> posterior;
+};
+
+struct HierarchicalWordBase {
+  explicit HierarchicalWordBase(const unsigned vocab_e_size) :
+      base(prob_t::One()), r(1,1,1,1,0.66,50.0), u0(-log(vocab_e_size)), l(1,prob_t::One()), v(1, prob_t::Zero()) {}
+
+  void ResampleHyperparameters(MT19937* rng) {
+    r.resample_hyperparameters(rng);
+  }
+
+  inline double logp0(const vector<WordID>& s) const {
+    return Md::log_poisson(s.size(), 7.5) + s.size() * u0;
+  }
+
+  // return p0 of rule.e_
+  prob_t operator()(const TRule& rule) const {
+    v[0].logeq(logp0(rule.e_));
+    return r.prob(rule.e_, v.begin(), l.begin());
+  }
+
+  void Increment(const TRule& rule) {
+    v[0].logeq(logp0(rule.e_));
+    if (r.increment(rule.e_, v.begin(), l.begin(), &*prng).count) {
+      base *= v[0] * l[0];
+    }
+  }
+
+  void Decrement(const TRule& rule) {
+    if (r.decrement(rule.e_, &*prng).count) {
+      base /= prob_t(exp(logp0(rule.e_)));
+    }
+  }
+
+  prob_t Likelihood() const {
+    prob_t p; p.logeq(r.log_crp_prob());
+    p *= base;
+    return p;
+  }
+
+  void Summary() const {
+    cerr << "NUMBER OF CUSTOMERS: " << r.num_customers() << "  (d=" << r.discount() << ",s=" << r.strength() << ')' << endl;
+    for (MFCR<1,vector<WordID> >::const_iterator it = r.begin(); it != r.end(); ++it)
+      cerr << "   " << it->second.total_dish_count_ << " (on " << it->second.table_counts_.size() << " tables) " << TD::GetString(it->first) << endl;
+  }
+
+  prob_t base;
+  MFCR<1,vector<WordID> > r;
+  const double u0;
+  const vector<prob_t> l;
+  mutable vector<prob_t> v;
+};
+
+struct BasicLexicalAlignment {
+  explicit BasicLexicalAlignment(const vector<vector<WordID> >& lets,
+                                 const unsigned words_e,
+                                 const unsigned letters_e,
+                                 vector<AlignedSentencePair>* corp) :
+      letters(lets),
+      corpus(*corp),
+      //up0(words_e),
+      //up0("en.chars.1gram", letters_e),
+      //up0("en.words.1gram"),
+      up0(letters_e),
+      //up0("en.chars.2gram"),
+      tmodel(up0) {
+  }
+
+  void InstantiateRule(const WordID src,
+                       const WordID trg,
+                       TRule* rule) const {
+    static const WordID kX = TD::Convert("X") * -1;
+    rule->lhs_ = kX;
+    rule->e_ = letters[trg];
+    rule->f_ = letters[src];
+  }
+
+  void InitializeRandom() {
+    const WordID kNULL = TD::Convert("NULL");
+    cerr << "Initializing with random alignments ...\n";
+    for (unsigned i = 0; i < corpus.size(); ++i) {
+      AlignedSentencePair& asp = corpus[i];
+      asp.a.resize(asp.trg.size());
+      for (unsigned j = 0; j < asp.trg.size(); ++j) {
+        const unsigned char a_j = prng->next() * (1 + asp.src.size());
+        const WordID f_a_j = (a_j ? asp.src[a_j - 1] : kNULL);
+        TRule r;
+        InstantiateRule(f_a_j, asp.trg[j], &r);
+        asp.a[j].is_transliteration = false;
+        asp.a[j].src_index = a_j;
+        if (tmodel.IncrementRule(r, &*prng))
+          up0.Increment(r);
+      }
+    }
+    cerr << "  LLH = " << Likelihood() << endl;
+  }
+
+  prob_t Likelihood() const {
+    prob_t p = tmodel.Likelihood();
+    p *= up0.Likelihood();
+    return p;
+  }
+
+  void ResampleHyperparemeters() {
+    tmodel.ResampleHyperparameters(&*prng);
+    up0.ResampleHyperparameters(&*prng);
+    cerr << "  (base d=" << up0.r.discount() << ",s=" << up0.r.strength() << ")\n";
+  }
+
+  void ResampleCorpus();
+
+  const vector<vector<WordID> >& letters; // spelling dictionary
+  vector<AlignedSentencePair>& corpus;
+  //PhraseConditionalUninformativeBase up0;
+  //PhraseConditionalUninformativeUnigramBase up0;
+  //UnigramWordBase up0;
+  //HierarchicalUnigramBase up0;
+  HierarchicalWordBase up0;
+  //CompletelyUniformBase up0;
+  //FixedNgramBase up0;
+  //ConditionalTranslationModel<PhraseConditionalUninformativeBase> tmodel;
+  //ConditionalTranslationModel<PhraseConditionalUninformativeUnigramBase> tmodel;
+  //ConditionalTranslationModel<UnigramWordBase> tmodel;
+  //ConditionalTranslationModel<HierarchicalUnigramBase> tmodel;
+  MConditionalTranslationModel<HierarchicalWordBase> tmodel;
+  //ConditionalTranslationModel<FixedNgramBase> tmodel;
+  //ConditionalTranslationModel<CompletelyUniformBase> tmodel;
+};
+
+void BasicLexicalAlignment::ResampleCorpus() {
+  static const WordID kNULL = TD::Convert("NULL");
+  for (unsigned i = 0; i < corpus.size(); ++i) {
+    AlignedSentencePair& asp = corpus[i];
+    SampleSet<prob_t> ss; ss.resize(asp.src.size() + 1);
+    for (unsigned j = 0; j < asp.trg.size(); ++j) {
+      TRule r;
+      unsigned char& a_j = asp.a[j].src_index;
+      WordID f_a_j = (a_j ? asp.src[a_j - 1] : kNULL);
+      InstantiateRule(f_a_j, asp.trg[j], &r);
+      if (tmodel.DecrementRule(r, &*prng))
+        up0.Decrement(r);
+
+      for (unsigned prop_a_j = 0; prop_a_j <= asp.src.size(); ++prop_a_j) {
+        const WordID prop_f = (prop_a_j ? asp.src[prop_a_j - 1] : kNULL);
+        InstantiateRule(prop_f, asp.trg[j], &r);
+        ss[prop_a_j] = tmodel.RuleProbability(r);
+      }
+      a_j = prng->SelectSample(ss);
+      f_a_j = (a_j ? asp.src[a_j - 1] : kNULL);
+      InstantiateRule(f_a_j, asp.trg[j], &r);
+      if (tmodel.IncrementRule(r, &*prng))
+        up0.Increment(r);
+    }
+  }
+  cerr << "  LLH = " << Likelihood() << endl;
+}
+
+void ExtractLetters(const set<WordID>& v, vector<vector<WordID> >* l, set<WordID>* letset = NULL) {
+  for (set<WordID>::const_iterator it = v.begin(); it != v.end(); ++it) {
+    vector<WordID>& letters = (*l)[*it];
+    if (letters.size()) continue;   // if e and f have the same word
+
+    const string& w = TD::Convert(*it);
+
+    size_t cur = 0;
+    while (cur < w.size()) {
+      const size_t len = UTF8Len(w[cur]);
+      letters.push_back(TD::Convert(w.substr(cur, len)));
+      if (letset) letset->insert(letters.back());
+      cur += len;
+    }
+  }
+}
+
+void Debug(const AlignedSentencePair& asp) {
+  cerr << TD::GetString(asp.src) << endl << TD::GetString(asp.trg) << endl;
+  Array2D<bool> a(asp.src.size(), asp.trg.size());
+  for (unsigned j = 0; j < asp.trg.size(); ++j)
+    if (asp.a[j].src_index) a(asp.a[j].src_index - 1, j) = true;
+  cerr << a << endl;
+}
+
+void AddSample(AlignedSentencePair* asp) {
+  for (unsigned j = 0; j < asp->trg.size(); ++j)
+    asp->posterior(asp->a[j].src_index, j)++;
+}
+
+void WriteAlignments(const AlignedSentencePair& asp) {
+  bool first = true;
+  for (unsigned j = 0; j < asp.trg.size(); ++j) {
+    int src_index = -1;
+    int mc = -1;
+    for (unsigned i = 0; i <= asp.src.size(); ++i) {
+      if (asp.posterior(i, j) > mc) {
+        mc = asp.posterior(i, j);
+        src_index = i;
+      }
+    }
+
+    if (src_index) {
+      if (first) first = false; else cout << ' ';
+      cout << (src_index - 1) << '-' << j;
+    }
+  }
+  cout << endl;
+}
+
+int main(int argc, char** argv) {
+  po::variables_map conf;
+  InitCommandLine(argc, argv, &conf);
+
+  if (conf.count("random_seed"))
+    prng.reset(new MT19937(conf["random_seed"].as<uint32_t>()));
+  else
+    prng.reset(new MT19937);
+//  MT19937& rng = *prng;
+
+  vector<vector<int> > corpuse, corpusf;
+  set<int> vocabe, vocabf;
+  corpus::ReadParallelCorpus(conf["input"].as<string>(), &corpusf, &corpuse, &vocabf, &vocabe);
+  cerr << "f-Corpus size: " << corpusf.size() << " sentences\n";
+  cerr << "f-Vocabulary size: " << vocabf.size() << " types\n";
+  cerr << "e-Corpus size: " << corpuse.size() << " sentences\n";
+  cerr << "e-Vocabulary size: " << vocabe.size() << " types\n";
+  assert(corpusf.size() == corpuse.size());
+
+  vector<AlignedSentencePair> corpus(corpuse.size());
+  for (unsigned i = 0; i < corpuse.size(); ++i) {
+    corpus[i].src.swap(corpusf[i]);
+    corpus[i].trg.swap(corpuse[i]);
+    corpus[i].posterior.resize(corpus[i].src.size() + 1, corpus[i].trg.size());
+  }
+  corpusf.clear(); corpuse.clear();
+
+  vocabf.insert(TD::Convert("NULL"));
+  vector<vector<WordID> > letters(TD::NumWords());
+  set<WordID> letset;
+  ExtractLetters(vocabe, &letters, &letset);
+  ExtractLetters(vocabf, &letters, NULL);
+  letters[TD::Convert("NULL")].clear();
+
+  Transliterations tl;
+
+  // TODO CONFIGURE THIS
+  int min_trans_src = 4;
+
+  cerr << "Initializing transliteration DPs ...\n";
+  for (int i = 0; i < corpus.size(); ++i) {
+    const vector<int>& src = corpus[i].src;
+    const vector<int>& trg = corpus[i].trg;
+    cerr << '.' << flush;
+    if (i % 80 == 79) cerr << endl;
+    for (int j = 0; j < src.size(); ++j) {
+      const vector<int>& src_let = letters[src[j]];
+      for (int k = 0; k < trg.size(); ++k) {
+        const vector<int>& trg_let = letters[trg[k]];
+        if (src_let.size() < min_trans_src)
+          tl.Forbid(src[j], trg[k]);
+        else
+          tl.Initialize(src[j], src_let, trg[k], trg_let);
+      }
+    }
+  }
+  cerr << endl;
+  tl.GraphSummary();
+
+  return 0;
+}
diff --git a/gi/pf/conditional_pseg.h b/gi/pf/conditional_pseg.h
index 8202778b..81ddb206 100644
--- a/gi/pf/conditional_pseg.h
+++ b/gi/pf/conditional_pseg.h
@@ -56,6 +56,12 @@ struct MConditionalTranslationModel {
   };
 
   void ResampleHyperparameters(MT19937* rng) {
+    typename std::tr1::unordered_map<std::vector<WordID>, MFCR<1,TRule>, boost::hash<std::vector<WordID> > >::iterator it;
+#if 1
+    for (it = r.begin(); it != r.end(); ++it) {
+      it->second.resample_hyperparameters(rng);
+    }
+#else
     const unsigned nloop = 5;
     const unsigned niterations = 10;
     DiscountResampler dr(*this);
@@ -70,12 +76,12 @@ struct MConditionalTranslationModel {
     }
     strength = slice_sampler1d(ar, strength, *rng, -d,
                             std::numeric_limits<double>::infinity(), 0.0, niterations, 100*niterations);
-    typename std::tr1::unordered_map<std::vector<WordID>, MFCR<1,TRule>, boost::hash<std::vector<WordID> > >::iterator it;
     std::cerr << "MConditionalTranslationModel(d=" << d << ",s=" << strength << ") = " << log_likelihood(d, strength) << std::endl;
     for (it = r.begin(); it != r.end(); ++it) {
       it->second.set_discount(d);
       it->second.set_strength(strength);
     }
+#endif
   }
 
   int DecrementRule(const TRule& rule, MT19937* rng) {
@@ -91,7 +97,8 @@ struct MConditionalTranslationModel {
   int IncrementRule(const TRule& rule, MT19937* rng) {
     RuleModelHash::iterator it = r.find(rule.f_);
     if (it == r.end()) {
-      it = r.insert(make_pair(rule.f_, MFCR<1,TRule>(d, strength))).first;
+      //it = r.insert(make_pair(rule.f_, MFCR<1,TRule>(d, strength))).first;
+      it = r.insert(make_pair(rule.f_, MFCR<1,TRule>(1,1,1,1,0.6, -0.12))).first;
     }
     p0s[0] = rp0(rule);
     TableCount delta = it->second.increment(rule, p0s.begin(), lambdas.begin(), rng);
diff --git a/gi/pf/nuisance_test.cc b/gi/pf/nuisance_test.cc
new file mode 100644
index 00000000..0f44fe95
--- /dev/null
+++ b/gi/pf/nuisance_test.cc
@@ -0,0 +1,161 @@
+#include "ccrp.h"
+
+#include <vector>
+#include <iostream>
+
+#include "tdict.h"
+#include "transliterations.h"
+
+using namespace std;
+
+MT19937 rng;
+
+ostream& operator<<(ostream&os, const vector<int>& v) {
+  os << '[' << v[0];
+  if (v.size() == 2) os << ' ' << v[1];
+  return os << ']';
+}
+
+struct Base {
+  Base() : llh(), v(2), v1(1), v2(1), crp(0.25, 0.5) {}
+  inline double p0(const vector<int>& x) const {
+    double p = 0.75;
+    if (x.size() == 2) p = 0.25;
+    p *= 1.0 / 3.0;
+    if (x.size() == 2) p *= 1.0 / 3.0;
+    return p;
+  }
+  double est_deriv_prob(int a, int b, int seg) const {
+    assert(a > 0 && a < 4);  // a \in {1,2,3}
+    assert(b > 0 && b < 4);  // b \in {1,2,3}
+    assert(seg == 0 || seg == 1);   // seg \in {0,1}
+    if (seg == 0) {
+      v[0] = a;
+      v[1] = b;
+      return crp.prob(v, p0(v));
+    } else {
+      v1[0] = a;
+      v2[0] = b;
+      return crp.prob(v1, p0(v1)) * crp.prob(v2, p0(v2));
+    }
+  }
+  double est_marginal_prob(int a, int b) const {
+    return est_deriv_prob(a,b,0) + est_deriv_prob(a,b,1);
+  }
+  int increment(int a, int b, double* pw = NULL) {
+    double p1 = est_deriv_prob(a, b, 0);
+    double p2 = est_deriv_prob(a, b, 1);
+    //p1 = 0.5; p2 = 0.5;
+    int seg = rng.SelectSample(p1,p2);
+    double tmp = 0;
+    if (!pw) pw = &tmp;
+    double& w = *pw;
+    if (seg == 0) {
+      v[0] = a;
+      v[1] = b;
+      w = crp.prob(v, p0(v)) / p1;
+      if (crp.increment(v, p0(v), &rng)) {
+        llh += log(p0(v));
+      }
+    } else {
+      v1[0] = a;
+      w = crp.prob(v1, p0(v1)) / p2;
+      if (crp.increment(v1, p0(v1), &rng)) {
+        llh += log(p0(v1));
+      }
+      v2[0] = b;
+      w *= crp.prob(v2, p0(v2));
+      if (crp.increment(v2, p0(v2), &rng)) {
+        llh += log(p0(v2));
+      }
+    }
+    return seg;
+  }
+  void increment(int a, int b, int seg) {
+    if (seg == 0) {
+      v[0] = a;
+      v[1] = b;
+      if (crp.increment(v, p0(v), &rng)) {
+        llh += log(p0(v));
+      }
+    } else {
+      v1[0] = a;
+      if (crp.increment(v1, p0(v1), &rng)) {
+        llh += log(p0(v1));
+      }
+      v2[0] = b;
+      if (crp.increment(v2, p0(v2), &rng)) {
+        llh += log(p0(v2));
+      }
+    }
+  }
+  void decrement(int a, int b, int seg) {
+    if (seg == 0) {
+      v[0] = a;
+      v[1] = b;
+      if (crp.decrement(v, &rng)) {
+        llh -= log(p0(v));
+      }
+    } else {
+      v1[0] = a;
+      if (crp.decrement(v1, &rng)) {
+        llh -= log(p0(v1));
+      }
+      v2[0] = b;
+      if (crp.decrement(v2, &rng)) {
+        llh -= log(p0(v2));
+      }
+    }
+  }
+  double log_likelihood() const {
+    return llh + crp.log_crp_prob();
+  }
+  double llh;
+  mutable vector<int> v, v1, v2;
+  CCRP<vector<int> > crp;
+};
+
+int main(int argc, char** argv) {
+  double tl = 0;
+  const int ITERS = 1000;
+  const int PARTICLES = 20;
+  const int DATAPOINTS = 50;
+  WordID x = TD::Convert("souvenons");
+  WordID y = TD::Convert("remember");
+  vector<WordID> src; TD::ConvertSentence("s o u v e n o n s", &src);
+  vector<WordID> trg; TD::ConvertSentence("r e m e m b e r", &trg);
+  Transliterations xx;
+  xx.Initialize(x, src, y, trg);
+  return 1;
+
+ for (int j = 0; j < ITERS; ++j) {
+  Base b;
+  vector<int> segs(DATAPOINTS);
+  SampleSet<double> ss;
+  vector<int> sss;
+  for (int i = 0; i < DATAPOINTS; i++) {
+    ss.clear();
+    sss.clear();
+    int x = ((i / 10) % 3) + 1;
+    int y = (i % 3) + 1;
+    //double ep = b.est_marginal_prob(x,y);
+    //cerr << "est p(" << x << "," << y << ") = " << ep << endl;
+    for (int n = 0; n < PARTICLES; ++n) {
+      double w;
+      int seg = b.increment(x,y,&w);
+      //cerr << seg << " w=" << w << endl;
+      ss.add(w);
+      sss.push_back(seg);
+      b.decrement(x,y,seg);
+    }
+    int seg = sss[rng.SelectSample(ss)];
+    b.increment(x, y, seg);
+    //cerr << "Selected: " << seg << endl;
+    //return 1;
+    segs[i] = seg;
+  }
+  tl += b.log_likelihood();
+ }
+  cerr << "LLH=" << tl / ITERS << endl;
+}
+
diff --git a/gi/pf/transliterations.cc b/gi/pf/transliterations.cc
new file mode 100644
index 00000000..6e0c2e93
--- /dev/null
+++ b/gi/pf/transliterations.cc
@@ -0,0 +1,193 @@
+#include "transliterations.h"
+
+#include <iostream>
+#include <vector>
+#include <tr1/unordered_map>
+
+#include "grammar.h"
+#include "bottom_up_parser.h"
+#include "hg.h"
+#include "hg_intersect.h"
+#include "filelib.h"
+#include "ccrp.h"
+#include "m.h"
+#include "lattice.h"
+#include "verbose.h"
+
+using namespace std;
+using namespace std::tr1;
+
+static WordID kX;
+static int kMAX_SRC_SIZE = 0;
+static vector<vector<WordID> > cur_trg_chunks;
+
+vector<GrammarIter*> tlttofreelist;
+
+static void InitTargetChunks(int max_size, const vector<WordID>& trg) {
+  cur_trg_chunks.clear();
+  vector<WordID> tmp;
+  unordered_set<vector<WordID>, boost::hash<vector<WordID> > > u;
+  for (int len = 1; len <= max_size; ++len) {
+    int end = trg.size() + 1;
+    end -= len;
+    for (int i = 0; i < end; ++i) {
+      tmp.clear();
+      for (int j = 0; j < len; ++j)
+        tmp.push_back(trg[i + j]);
+      if (u.insert(tmp).second) cur_trg_chunks.push_back(tmp);
+    }
+  }
+}
+
+struct TransliterationGrammarIter : public GrammarIter, public RuleBin {
+  TransliterationGrammarIter() { tlttofreelist.push_back(this); }
+  TransliterationGrammarIter(const TRulePtr& inr, int symbol) {
+    if (inr) {
+      r.reset(new TRule(*inr));
+    } else {
+      r.reset(new TRule);
+    }
+    TRule& rr = *r;
+    rr.lhs_ = kX;
+    rr.f_.push_back(symbol);
+    tlttofreelist.push_back(this);
+  }
+  virtual int GetNumRules() const {
+    if (!r) return 0;
+    return cur_trg_chunks.size();
+  }
+  virtual TRulePtr GetIthRule(int i) const {
+    TRulePtr nr(new TRule(*r));
+    nr->e_ = cur_trg_chunks[i];
+    //cerr << nr->AsString() << endl;
+    return nr;
+  }
+  virtual int Arity() const {
+    return 0;
+  }
+  virtual const RuleBin* GetRules() const {
+    if (!r) return NULL; else return this;
+  }
+  virtual const GrammarIter* Extend(int symbol) const {
+    if (symbol <= 0) return NULL;
+    if (!r || !kMAX_SRC_SIZE || r->f_.size() < kMAX_SRC_SIZE)
+      return new TransliterationGrammarIter(r, symbol);
+    else
+      return NULL;
+  }
+  TRulePtr r;
+};
+
+struct TransliterationGrammar : public Grammar {
+  virtual const GrammarIter* GetRoot() const {
+    return new TransliterationGrammarIter;
+  }
+  virtual bool HasRuleForSpan(int, int, int distance) const {
+    return (distance < kMAX_SRC_SIZE);
+  }
+};
+
+struct TInfo {
+  TInfo() : initialized(false) {}
+  bool initialized;
+  Hypergraph lattice;   // may be empty if transliteration is not possible
+  prob_t est_prob;      // will be zero if not possible
+};
+
+struct TransliterationsImpl {
+  TransliterationsImpl() {
+    kX = TD::Convert("X")*-1;
+    kMAX_SRC_SIZE = 4;
+    grammars.push_back(GrammarPtr(new TransliterationGrammar));
+    grammars.push_back(GrammarPtr(new GlueGrammar("S", "X")));
+    SetSilent(true);
+  }
+
+  void Initialize(WordID src, const vector<WordID>& src_lets, WordID trg, const vector<WordID>& trg_lets) {
+    if (src >= graphs.size()) graphs.resize(src + 1);
+    if (graphs[src][trg].initialized) return;
+    int kMAX_TRG_SIZE = 4;
+    InitTargetChunks(kMAX_TRG_SIZE, trg_lets);
+    ExhaustiveBottomUpParser parser("S", grammars);
+    Lattice lat(src_lets.size()), tlat(trg_lets.size());
+    for (unsigned i = 0; i < src_lets.size(); ++i)
+      lat[i].push_back(LatticeArc(src_lets[i], 0.0, 1));
+    for (unsigned i = 0; i < trg_lets.size(); ++i)
+      tlat[i].push_back(LatticeArc(trg_lets[i], 0.0, 1));
+    //cerr << "Creating lattice for: " << TD::Convert(src) << " --> " << TD::Convert(trg) << endl;
+    //cerr << "'" << TD::GetString(src_lets) << "' --> " << TD::GetString(trg_lets) << endl;
+    if (!parser.Parse(lat, &graphs[src][trg].lattice)) {
+      //cerr << "Failed to parse " << TD::GetString(src_lets) << endl;
+      abort();
+    }
+    if (HG::Intersect(tlat, &graphs[src][trg].lattice)) {
+      graphs[src][trg].est_prob = prob_t(1e-4);
+    } else {
+      graphs[src][trg].lattice.clear();
+      //cerr << "Failed to intersect " << TD::GetString(src_lets) << " ||| " << TD::GetString(trg_lets) << endl;
+      graphs[src][trg].est_prob = prob_t::Zero();
+    }
+    for (unsigned i = 0; i < tlttofreelist.size(); ++i)
+      delete tlttofreelist[i];
+    tlttofreelist.clear();
+    //cerr << "Number of paths: " << graphs[src][trg].lattice.NumberOfPaths() << endl;
+    graphs[src][trg].initialized = true;
+  }
+
+  const prob_t& EstimateProbability(WordID src, WordID trg) const {
+    assert(src < graphs.size());
+    const unordered_map<WordID, TInfo>& um = graphs[src];
+    const unordered_map<WordID, TInfo>::const_iterator it = um.find(trg);
+    assert(it != um.end());
+    assert(it->second.initialized);
+    return it->second.est_prob;
+  }
+
+  void Forbid(WordID src, WordID trg) {
+    if (src >= graphs.size()) graphs.resize(src + 1);
+    graphs[src][trg].est_prob = prob_t::Zero();
+    graphs[src][trg].initialized = true;
+  }
+
+  void GraphSummary() const {
+    double tlp = 0;
+    int tt = 0;
+    for (int i = 0; i < graphs.size(); ++i) {
+      const unordered_map<WordID, TInfo>& um = graphs[i];
+      unordered_map<WordID, TInfo>::const_iterator it;
+      for (it = um.begin(); it != um.end(); ++it) {
+        if (it->second.lattice.empty()) continue;
+        //cerr << TD::Convert(i) << " --> " << TD::Convert(it->first) << ": " << it->second.lattice.NumberOfPaths() << endl;
+        tlp += log(it->second.lattice.NumberOfPaths());
+        tt++;
+      }
+    }
+    tlp /= tt;
+    cerr << "E[log paths] = " << tlp << endl;
+    cerr << "exp(E[log paths]) = " << exp(tlp) << endl;
+  }
+
+  vector<unordered_map<WordID, TInfo> > graphs;
+  vector<GrammarPtr> grammars;
+};
+
+Transliterations::Transliterations() : pimpl_(new TransliterationsImpl) {}
+Transliterations::~Transliterations() { delete pimpl_; }
+
+void Transliterations::Initialize(WordID src, const vector<WordID>& src_lets, WordID trg, const vector<WordID>& trg_lets) {
+  pimpl_->Initialize(src, src_lets, trg, trg_lets);
+}
+
+prob_t Transliterations::EstimateProbability(WordID src, WordID trg) const {
+  return pimpl_->EstimateProbability(src,trg);
+}
+
+void Transliterations::Forbid(WordID src, WordID trg) {
+  pimpl_->Forbid(src, trg);
+}
+
+void Transliterations::GraphSummary() const {
+  pimpl_->GraphSummary();
+}
+
+
diff --git a/gi/pf/transliterations.h b/gi/pf/transliterations.h
new file mode 100644
index 00000000..a548aacf
--- /dev/null
+++ b/gi/pf/transliterations.h
@@ -0,0 +1,20 @@
+#ifndef _TRANSLITERATIONS_H_
+#define _TRANSLITERATIONS_H_
+
+#include <vector>
+#include "wordid.h"
+#include "prob.h"
+
+struct TransliterationsImpl;
+struct Transliterations {
+  explicit Transliterations();
+  ~Transliterations();
+  void Initialize(WordID src, const std::vector<WordID>& src_lets, WordID trg, const std::vector<WordID>& trg_lets);
+  void Forbid(WordID src, WordID trg);
+  void GraphSummary() const;
+  prob_t EstimateProbability(WordID src, WordID trg) const;
+  TransliterationsImpl* pimpl_;
+};
+
+#endif
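
The calling pattern for the new Transliterations interface, as exercised by nuisance_test.cc and align-tl.cc above, is sketched below. This is an illustrative fragment assembled from the calls in this commit, not an additional file in it; it assumes cdec's TD dictionary helpers (TD::Convert, TD::ConvertSentence) exactly as the new files use them.

#include <vector>
#include "tdict.h"
#include "transliterations.h"

using namespace std;

int main() {
  // A source/target word pair and its letter sequences (the example
  // used in nuisance_test.cc).
  WordID x = TD::Convert("souvenons");
  WordID y = TD::Convert("remember");
  vector<WordID> src_lets; TD::ConvertSentence("s o u v e n o n s", &src_lets);
  vector<WordID> trg_lets; TD::ConvertSentence("r e m e m b e r", &trg_lets);

  Transliterations tl;
  // Build the per-pair lattice: exhaustively parse the source letters with
  // the transliteration grammar, then intersect the forest with the target
  // letter lattice. align-tl.cc calls tl.Forbid(src, trg) instead whenever
  // the source word has fewer than min_trans_src (4) letters.
  tl.Initialize(x, src_lets, y, trg_lets);
  prob_t p = tl.EstimateProbability(x, y);  // prob_t(1e-4) if a lattice survived, prob_t::Zero() otherwise
  (void)p;
  tl.GraphSummary();  // prints E[log paths] over all initialized pairs
  return 0;
}

Note that at this stage EstimateProbability returns a constant prob_t(1e-4) whenever the parse/intersection succeeds, so the class currently only distinguishes "transliteration structurally possible" from "forbidden"; the CRP machinery for learning real transliteration probabilities is not wired in yet.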
