summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-01-24 22:26:44 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2012-01-24 22:26:44 -0500
commit4c2360119def2fb624d2691b355b1908c511f004 (patch)
tree9dd7ce4b2884750822b433e0c2254a1f99dc3cc5
parent26d9ad04bd81508163d75c99726f970dd75f5127 (diff)
more models
-rw-r--r--gi/pf/align-lexonly.cc14
-rw-r--r--gi/pf/base_measures.cc2
-rw-r--r--gi/pf/base_measures.h27
-rw-r--r--training/model1.cc64
4 files changed, 98 insertions, 9 deletions
diff --git a/gi/pf/align-lexonly.cc b/gi/pf/align-lexonly.cc
index e9f1e7b6..76e2e009 100644
--- a/gi/pf/align-lexonly.cc
+++ b/gi/pf/align-lexonly.cc
@@ -122,10 +122,11 @@ struct BasicLexicalAlignment {
vector<AlignedSentencePair>* corp) :
letters(lets),
corpus(*corp),
+ up0("fr-en.10k.translit-base.txt.gz"),
//up0(words_e),
//up0("en.chars.1gram", letters_e),
//up0("en.words.1gram"),
- up0(letters_e),
+ //up0(letters_e),
//up0("en.chars.2gram"),
tmodel(up0) {
}
@@ -180,14 +181,18 @@ struct BasicLexicalAlignment {
//PhraseConditionalUninformativeUnigramBase up0;
//UnigramWordBase up0;
//HierarchicalUnigramBase up0;
- HierarchicalWordBase up0;
+ TableLookupBase up0;
+ //HierarchicalWordBase up0;
+ //PoissonUniformUninformativeBase up0;
//CompletelyUniformBase up0;
//FixedNgramBase up0;
//ConditionalTranslationModel<PhraseConditionalUninformativeBase> tmodel;
//ConditionalTranslationModel<PhraseConditionalUninformativeUnigramBase> tmodel;
//ConditionalTranslationModel<UnigramWordBase> tmodel;
//ConditionalTranslationModel<HierarchicalUnigramBase> tmodel;
- ConditionalTranslationModel<HierarchicalWordBase> tmodel;
+ //ConditionalTranslationModel<HierarchicalWordBase> tmodel;
+ //ConditionalTranslationModel<PoissonUniformUninformativeBase> tmodel;
+ ConditionalTranslationModel<TableLookupBase> tmodel;
//ConditionalTranslationModel<FixedNgramBase> tmodel;
//ConditionalTranslationModel<CompletelyUniformBase> tmodel;
};
@@ -222,6 +227,7 @@ void BasicLexicalAlignment::ResampleCorpus() {
void ExtractLetters(const set<WordID>& v, vector<vector<WordID> >* l, set<WordID>* letset = NULL) {
for (set<WordID>::const_iterator it = v.begin(); it != v.end(); ++it) {
+ if (*it >= l->size()) { l->resize(*it + 1); }
vector<WordID>& letters = (*l)[*it];
if (letters.size()) continue; // if e and f have the same word
@@ -308,7 +314,7 @@ int main(int argc, char** argv) {
x.InitializeRandom();
const unsigned samples = conf["samples"].as<unsigned>();
for (int i = 0; i < samples; ++i) {
- for (int j = 4995; j < 4997; ++j) Debug(corpus[j]);
+ for (int j = 395; j < 397; ++j) Debug(corpus[j]);
cerr << i << "\t" << x.tmodel.r.size() << "\t";
if (i % 10 == 0) x.ResampleHyperparemeters();
x.ResampleCorpus();
diff --git a/gi/pf/base_measures.cc b/gi/pf/base_measures.cc
index 7894d3e7..4b1863fa 100644
--- a/gi/pf/base_measures.cc
+++ b/gi/pf/base_measures.cc
@@ -37,7 +37,7 @@ TableLookupBase::TableLookupBase(const string& fname) {
} else if (cc == 1) {
x.e_.push_back(cur);
} else if (cc == 2) {
- table[x] = atof(TD::Convert(cur));
+ table[x].logeq(atof(TD::Convert(cur)));
++cc;
} else {
if (flag) cerr << endl;
diff --git a/gi/pf/base_measures.h b/gi/pf/base_measures.h
index 7214aa22..b0495bfd 100644
--- a/gi/pf/base_measures.h
+++ b/gi/pf/base_measures.h
@@ -51,6 +51,22 @@ struct Model1 {
std::vector<std::map<WordID, prob_t> > ttable;
};
+struct PoissonUniformUninformativeBase {
+ explicit PoissonUniformUninformativeBase(const unsigned ves) : kUNIFORM(1.0 / ves) {}
+ prob_t operator()(const TRule& r) const {
+ prob_t p; p.logeq(log_poisson(r.e_.size(), 1.0));
+ prob_t q = kUNIFORM; q.poweq(r.e_.size());
+ p *= q;
+ return p;
+ }
+ void Summary() const {}
+ void ResampleHyperparameters(MT19937*) {}
+ void Increment(const TRule&) {}
+ void Decrement(const TRule&) {}
+ prob_t Likelihood() const { return prob_t::One(); }
+ const prob_t kUNIFORM;
+};
+
struct CompletelyUniformBase {
explicit CompletelyUniformBase(const unsigned ves) : kUNIFORM(1.0 / ves) {}
prob_t operator()(const TRule&) const {
@@ -83,10 +99,19 @@ struct TableLookupBase {
prob_t operator()(const TRule& rule) const {
const std::tr1::unordered_map<TRule,prob_t>::const_iterator it = table.find(rule);
- assert(it != table.end());
+ if (it == table.end()) {
+ std::cerr << rule << " not found\n";
+ abort();
+ }
return it->second;
}
+ void ResampleHyperparameters(MT19937*) {}
+ void Increment(const TRule&) {}
+ void Decrement(const TRule&) {}
+ prob_t Likelihood() const { return prob_t::One(); }
+ void Summary() const {}
+
std::tr1::unordered_map<TRule,prob_t,RuleHasher> table;
};
diff --git a/training/model1.cc b/training/model1.cc
index 346c0033..40249aa3 100644
--- a/training/model1.cc
+++ b/training/model1.cc
@@ -14,6 +14,11 @@
namespace po = boost::program_options;
using namespace std;
+inline double log_poisson(unsigned x, const double& lambda) {
+ assert(lambda > 0.0);
+ return log(lambda) * x - lgamma(x + 1) - lambda;
+}
+
bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
po::options_description opts("Configuration options");
opts.add_options()
@@ -25,6 +30,7 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("diagonal_tension,T", po::value<double>()->default_value(4.0), "How sharp or flat around the diagonal is the alignment distribution (<1 = flat >1 = sharp)")
("prob_align_null", po::value<double>()->default_value(0.08), "When --favor_diagonal is set, what's the probability of a null alignment?")
("variational_bayes,v","Add a symmetric Dirichlet prior and infer VB estimate of weights")
+ ("testset,x", po::value<string>(), "After training completes, compute the log likelihood of this set of sentence pairs under the learned model")
("alpha,a", po::value<double>()->default_value(0.01), "Hyperparameter for optional Dirichlet prior")
("no_add_viterbi,V","Do not add Viterbi alignment points (may generate a grammar where some training sentence pairs are unreachable)");
po::options_description clo("Command line options");
@@ -63,6 +69,8 @@ int main(int argc, char** argv) {
const bool write_alignments = (conf.count("write_alignments") > 0);
const double diagonal_tension = conf["diagonal_tension"].as<double>();
const double prob_align_null = conf["prob_align_null"].as<double>();
+ string testset;
+ if (conf.count("testset")) testset = conf["testset"].as<string>();
const double prob_align_not_null = 1.0 - prob_align_null;
const double alpha = conf["alpha"].as<double>();
const bool favor_diagonal = conf.count("favor_diagonal");
@@ -73,6 +81,8 @@ int main(int argc, char** argv) {
TTable tt;
TTable::Word2Word2Double was_viterbi;
+ double tot_len_ratio = 0;
+ double mean_srclen_multiplier = 0;
for (int iter = 0; iter < ITERATIONS; ++iter) {
const bool final_iteration = (iter == (ITERATIONS - 1));
cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl;
@@ -83,13 +93,13 @@ int main(int argc, char** argv) {
int lc = 0;
bool flag = false;
string line;
+ string ssrc, strg;
while(true) {
getline(in, line);
if (!in) break;
++lc;
if (lc % 1000 == 0) { cerr << '.'; flag = true; }
if (lc %50000 == 0) { cerr << " [" << lc << "]\n" << flush; flag = false; }
- string ssrc, strg;
ParseTranslatorInput(line, &ssrc, &strg);
Lattice src, trg;
LatticeTools::ConvertTextToLattice(ssrc, &src);
@@ -99,9 +109,10 @@ int main(int argc, char** argv) {
assert(src.size() > 0);
assert(trg.size() > 0);
}
+ if (iter == 0)
+ tot_len_ratio += static_cast<double>(trg.size()) / static_cast<double>(src.size());
denom += trg.size();
vector<double> probs(src.size() + 1);
- const double src_logprob = -log(src.size() + 1);
bool first_al = true; // used for write_alignments
for (int j = 0; j < trg.size(); ++j) {
const WordID& f_j = trg[j][0].label;
@@ -156,7 +167,7 @@ int main(int argc, char** argv) {
for (int i = 1; i <= src.size(); ++i)
tt.Increment(src[i-1][0].label, f_j, probs[i] / sum);
}
- likelihood += log(sum) + src_logprob;
+ likelihood += log(sum);
}
if (write_alignments && final_iteration) cout << endl;
}
@@ -165,6 +176,10 @@ int main(int argc, char** argv) {
double base2_likelihood = likelihood / log(2);
if (flag) { cerr << endl; }
+ if (iter == 0) {
+ mean_srclen_multiplier = tot_len_ratio / lc;
+ cerr << "expected target length = source length * " << mean_srclen_multiplier << endl;
+ }
cerr << " log_e likelihood: " << likelihood << endl;
cerr << " log_2 likelihood: " << base2_likelihood << endl;
cerr << " cross entropy: " << (-base2_likelihood / denom) << endl;
@@ -176,6 +191,49 @@ int main(int argc, char** argv) {
tt.Normalize();
}
}
+ if (testset.size()) {
+ ReadFile rf(testset);
+ istream& in = *rf.stream();
+ int lc = 0;
+ double tlp = 0;
+ string ssrc, strg, line;
+ while (getline(in, line)) {
+ ++lc;
+ ParseTranslatorInput(line, &ssrc, &strg);
+ Lattice src, trg;
+ LatticeTools::ConvertTextToLattice(ssrc, &src);
+ LatticeTools::ConvertTextToLattice(strg, &trg);
+ double log_prob = log_poisson(trg.size(), 0.05 + src.size() * mean_srclen_multiplier);
+
+ // compute likelihood
+ for (int j = 0; j < trg.size(); ++j) {
+ const WordID& f_j = trg[j][0].label;
+ double sum = 0;
+ const double j_over_ts = double(j) / trg.size();
+ double prob_a_i = 1.0 / (src.size() + use_null); // uniform (model 1)
+ if (use_null) {
+ if (favor_diagonal) prob_a_i = prob_align_null;
+ sum += tt.prob(kNULL, f_j) * prob_a_i;
+ }
+ double az = 0;
+ if (favor_diagonal) {
+ for (int ta = 0; ta < src.size(); ++ta)
+ az += exp(-fabs(double(ta) / src.size() - j_over_ts) * diagonal_tension);
+ az /= prob_align_not_null;
+ }
+ for (int i = 1; i <= src.size(); ++i) {
+ if (favor_diagonal)
+ prob_a_i = exp(-fabs(double(i) / src.size() - j_over_ts) * diagonal_tension) / az;
+ sum += tt.prob(src[i-1][0].label, f_j) * prob_a_i;
+ }
+ log_prob += log(sum);
+ }
+ tlp += log_prob;
+ cerr << ssrc << " ||| " << strg << " ||| " << log_prob << endl;
+ }
+ cerr << "TOTAL LOG PROB " << tlp << endl;
+ }
+
if (write_alignments) return 0;
for (TTable::Word2Word2Double::iterator ei = tt.ttable.begin(); ei != tt.ttable.end(); ++ei) {