diff options
| author | Paul Baltescu <pauldb89@gmail.com> | 2013-06-19 16:46:52 +0100 | 
|---|---|---|
| committer | Paul Baltescu <pauldb89@gmail.com> | 2013-06-19 16:46:52 +0100 | 
| commit | 22e6ab01aebca3e9012b07f9600153c7b593996e (patch) | |
| tree | 844d1a650a302114ae619d37b8778ab66207a834 /word-aligner | |
| parent | 02099a01350a41a99ec400e9b29df08a01d88979 (diff) | |
| parent | 0dc7755f7fb1ef15db5a60c70866aa61b6367898 (diff) | |
Merge branch 'master' of https://github.com/redpony/cdec
Diffstat (limited to 'word-aligner')
| -rw-r--r-- | word-aligner/fast_align.cc | 56 | ||||
| -rw-r--r-- | word-aligner/ttables.cc | 14 | ||||
| -rw-r--r-- | word-aligner/ttables.h | 1 | 
3 files changed, 50 insertions, 21 deletions
| diff --git a/word-aligner/fast_align.cc b/word-aligner/fast_align.cc index 9eb1dbc6..fddcba9c 100644 --- a/word-aligner/fast_align.cc +++ b/word-aligner/fast_align.cc @@ -33,11 +33,13 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {          ("variational_bayes,v","Infer VB estimate of parameters under a symmetric Dirichlet prior")          ("alpha,a", po::value<double>()->default_value(0.01), "Hyperparameter for optional Dirichlet prior")          ("no_null_word,N","Do not generate from a null token") -        ("output_parameters,p", "Write model parameters instead of alignments") +        ("output_parameters,p", po::value<string>(), "Write model parameters to file")          ("beam_threshold,t",po::value<double>()->default_value(-4),"When writing parameters, log_10 of beam threshold for writing parameter (-10000 to include everything, 0 max parameter only)")          ("hide_training_alignments,H", "Hide training alignments (only useful if you want to use -x option and just compute testset statistics)")          ("testset,x", po::value<string>(), "After training completes, compute the log likelihood of this set of sentence pairs under the learned model") -        ("no_add_viterbi,V","When writing model parameters, do not add Viterbi alignment points (may generate a grammar where some training sentence pairs are unreachable)"); +        ("no_add_viterbi,V","When writing model parameters, do not add Viterbi alignment points (may generate a grammar where some training sentence pairs are unreachable)") +		("force_align,f",po::value<string>(), "Load previously written parameters to 'force align' input. Set --diagonal_tension and --mean_srclen_multiplier as estimated during training.") +		("mean_srclen_multiplier,m",po::value<double>()->default_value(1), "When --force_align, use this source length multiplier");    po::options_description clo("Command line options");    clo.add_options()          ("config", po::value<string>(), "Configuration file") @@ -66,18 +68,20 @@ int main(int argc, char** argv) {    if (!InitCommandLine(argc, argv, &conf)) return 1;    const string fname = conf["input"].as<string>();    const bool reverse = conf.count("reverse") > 0; -  const int ITERATIONS = conf["iterations"].as<unsigned>(); +  const int ITERATIONS = (conf.count("force_align")) ? 0 : conf["iterations"].as<unsigned>();    const double BEAM_THRESHOLD = pow(10.0, conf["beam_threshold"].as<double>());    const bool use_null = (conf.count("no_null_word") == 0);    const WordID kNULL = TD::Convert("<eps>");    const bool add_viterbi = (conf.count("no_add_viterbi") == 0);    const bool variational_bayes = (conf.count("variational_bayes") > 0); -  const bool write_alignments = (conf.count("output_parameters") == 0); +  const bool output_parameters = (conf.count("force_align")) ? false : conf.count("output_parameters");    double diagonal_tension = conf["diagonal_tension"].as<double>();    bool optimize_tension = conf.count("optimize_tension"); -  const bool hide_training_alignments = (conf.count("hide_training_alignments") > 0); +  bool hide_training_alignments = (conf.count("hide_training_alignments") > 0); +  const bool write_alignments = (conf.count("force_align")) ? true : !hide_training_alignments;    string testset;    if (conf.count("testset")) testset = conf["testset"].as<string>(); +  if (conf.count("force_align")) testset = fname;    double prob_align_null = conf["prob_align_null"].as<double>();    double prob_align_not_null = 1.0 - prob_align_null;    const double alpha = conf["alpha"].as<double>(); @@ -86,13 +90,22 @@ int main(int argc, char** argv) {      cerr << "--alpha must be > 0\n";      return 1;    } - +   +      TTable s2t, t2s;    TTable::Word2Word2Double s2t_viterbi;    unordered_map<pair<short, short>, unsigned, boost::hash<pair<short, short> > > size_counts;    double tot_len_ratio = 0;    double mean_srclen_multiplier = 0;    vector<double> probs; +   +  if (conf.count("force_align")) { +	// load model parameters +	ReadFile s2t_f(conf["force_align"].as<string>()); +	s2t.DeserializeLogProbsFromText(s2t_f.stream()); +	mean_srclen_multiplier = conf["mean_srclen_multiplier"].as<double>(); +  } +      for (int iter = 0; iter < ITERATIONS; ++iter) {      const bool final_iteration = (iter == (ITERATIONS - 1));      cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl; @@ -289,21 +302,22 @@ int main(int argc, char** argv) {      cerr << "TOTAL LOG PROB " << tlp << endl;    } -  if (write_alignments) return 0; - -  for (TTable::Word2Word2Double::iterator ei = s2t.ttable.begin(); ei != s2t.ttable.end(); ++ei) { -    const TTable::Word2Double& cpd = ei->second; -    const TTable::Word2Double& vit = s2t_viterbi[ei->first]; -    const string& esym = TD::Convert(ei->first); -    double max_p = -1; -    for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) -      if (fi->second > max_p) max_p = fi->second; -    const double threshold = max_p * BEAM_THRESHOLD; -    for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) { -      if (fi->second > threshold || (vit.find(fi->first) != vit.end())) { -        cout << esym << ' ' << TD::Convert(fi->first) << ' ' << log(fi->second) << endl; -      } -    }  +  if (output_parameters) { +    WriteFile params_out(conf["output_parameters"].as<string>()); +    for (TTable::Word2Word2Double::iterator ei = s2t.ttable.begin(); ei != s2t.ttable.end(); ++ei) { +      const TTable::Word2Double& cpd = ei->second; +      const TTable::Word2Double& vit = s2t_viterbi[ei->first]; +      const string& esym = TD::Convert(ei->first); +      double max_p = -1; +      for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) +        if (fi->second > max_p) max_p = fi->second; +      const double threshold = max_p * BEAM_THRESHOLD; +      for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) { +        if (fi->second > threshold || (vit.find(fi->first) != vit.end())) { +          *params_out << esym << ' ' << TD::Convert(fi->first) << ' ' << log(fi->second) << endl; +        } +      }  +    }    }    return 0;  } diff --git a/word-aligner/ttables.cc b/word-aligner/ttables.cc index 45bf14c5..c177aa30 100644 --- a/word-aligner/ttables.cc +++ b/word-aligner/ttables.cc @@ -21,6 +21,20 @@ void TTable::DeserializeProbsFromText(std::istream* in) {    cerr << "Loaded " << c << " translation parameters.\n";  } +void TTable::DeserializeLogProbsFromText(std::istream* in) { +  int c = 0; +  while(*in) { +    string e; +    string f; +    double p; +    (*in) >> e >> f >> p; +    if (e.empty()) break; +    ++c; +    ttable[TD::Convert(e)][TD::Convert(f)] = exp(p); +  } +  cerr << "Loaded " << c << " translation parameters.\n"; +} +  void TTable::SerializeHelper(string* out, const Word2Word2Double& o) {    assert(!"not implemented");  } diff --git a/word-aligner/ttables.h b/word-aligner/ttables.h index 9baa13ca..507f591a 100644 --- a/word-aligner/ttables.h +++ b/word-aligner/ttables.h @@ -86,6 +86,7 @@ class TTable {      }    }    void DeserializeProbsFromText(std::istream* in); +  void DeserializeLogProbsFromText(std::istream* in);    void SerializeCounts(std::string* out) const { SerializeHelper(out, counts); }    void DeserializeCounts(const std::string& in) { DeserializeHelper(in, &counts); }    void SerializeProbs(std::string* out) const { SerializeHelper(out, ttable); } | 
