diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/model1.cc | 39 | 
1 files changed, 36 insertions, 3 deletions
diff --git a/training/model1.cc b/training/model1.cc index b9590ece..346c0033 100644 --- a/training/model1.cc +++ b/training/model1.cc @@ -20,6 +20,10 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {          ("iterations,i",po::value<unsigned>()->default_value(5),"Number of iterations of EM training")          ("beam_threshold,t",po::value<double>()->default_value(-4),"log_10 of beam threshold (-10000 to include everything, 0 max)")          ("no_null_word,N","Do not generate from the null token") +        ("write_alignments,A", "Write alignments instead of parameters") +        ("favor_diagonal,d", "Use a static alignment distribution that assigns higher probabilities to alignments near the diagonal") +        ("diagonal_tension,T", po::value<double>()->default_value(4.0), "How sharp or flat around the diagonal is the alignment distribution (<1 = flat >1 = sharp)") +        ("prob_align_null", po::value<double>()->default_value(0.08), "When --favor_diagonal is set, what's the probability of a null alignment?")          ("variational_bayes,v","Add a symmetric Dirichlet prior and infer VB estimate of weights")          ("alpha,a", po::value<double>()->default_value(0.01), "Hyperparameter for optional Dirichlet prior")          ("no_add_viterbi,V","Do not add Viterbi alignment points (may generate a grammar where some training sentence pairs are unreachable)"); @@ -56,7 +60,12 @@ int main(int argc, char** argv) {    const WordID kNULL = TD::Convert("<eps>");    const bool add_viterbi = (conf.count("no_add_viterbi") == 0);    const bool variational_bayes = (conf.count("variational_bayes") > 0); +  const bool write_alignments = (conf.count("write_alignments") > 0); +  const double diagonal_tension = conf["diagonal_tension"].as<double>(); +  const double prob_align_null = conf["prob_align_null"].as<double>(); +  const double prob_align_not_null = 1.0 - prob_align_null;    const double alpha = conf["alpha"].as<double>(); +  const bool favor_diagonal = conf.count("favor_diagonal");    if (variational_bayes && alpha <= 0.0) {      cerr << "--alpha must be > 0\n";      return 1; @@ -93,31 +102,52 @@ int main(int argc, char** argv) {        denom += trg.size();        vector<double> probs(src.size() + 1);        const double src_logprob = -log(src.size() + 1); +      bool first_al = true;  // used for write_alignments        for (int j = 0; j < trg.size(); ++j) {          const WordID& f_j = trg[j][0].label;          double sum = 0; +        const double j_over_ts = double(j) / trg.size(); +        double prob_a_i = 1.0 / (src.size() + use_null);  // uniform (model 1)          if (use_null) { -          probs[0] = tt.prob(kNULL, f_j); +          if (favor_diagonal) prob_a_i = prob_align_null; +          probs[0] = tt.prob(kNULL, f_j) * prob_a_i;            sum += probs[0];          } +        double az = 0; +        if (favor_diagonal) { +          for (int ta = 0; ta < src.size(); ++ta) +            az += exp(-fabs(double(ta) / src.size() - j_over_ts) * diagonal_tension); +          az /= prob_align_not_null; +        }          for (int i = 1; i <= src.size(); ++i) { -          probs[i] = tt.prob(src[i-1][0].label, f_j); +          if (favor_diagonal) +            prob_a_i = exp(-fabs(double(i) / src.size() - j_over_ts) * diagonal_tension) / az; +          probs[i] = tt.prob(src[i-1][0].label, f_j) * prob_a_i;            sum += probs[i];          }          if (final_iteration) { -          if (add_viterbi) { +          if (add_viterbi || write_alignments) {              WordID max_i = 0;              double max_p = -1; +            int max_index = -1;              if (use_null) {                max_i = kNULL; +              max_index = 0;                max_p = probs[0];              }              for (int i = 1; i <= src.size(); ++i) {                if (probs[i] > max_p) { +                max_index = i;                  max_p = probs[i];                  max_i = src[i-1][0].label;                }              } +            if (write_alignments) { +              if (max_index > 0) { +                if (first_al) first_al = false; else cout << ' '; +                cout << (max_index - 1) << "-" << j; +              } +            }              was_viterbi[max_i][f_j] = 1.0;            }          } else { @@ -128,6 +158,7 @@ int main(int argc, char** argv) {          }          likelihood += log(sum) + src_logprob;        } +      if (write_alignments && final_iteration) cout << endl;      }      // log(e) = 1.0 @@ -145,6 +176,8 @@ int main(int argc, char** argv) {          tt.Normalize();      }    } +  if (write_alignments) return 0; +    for (TTable::Word2Word2Double::iterator ei = tt.ttable.begin(); ei != tt.ttable.end(); ++ei) {      const TTable::Word2Double& cpd = ei->second;      const TTable::Word2Double& vit = was_viterbi[ei->first];  | 
