From a0c40ec491674c8826d7be2fbd46eaaa78ad3ed6 Mon Sep 17 00:00:00 2001 From: Felix Hieber Date: Tue, 18 Jun 2013 10:23:55 -0700 Subject: forced alignment --- word-aligner/fast_align.cc | 56 +++++++++++++++++++++++++++++----------------- word-aligner/ttables.cc | 14 ++++++++++++ word-aligner/ttables.h | 1 + 3 files changed, 50 insertions(+), 21 deletions(-) diff --git a/word-aligner/fast_align.cc b/word-aligner/fast_align.cc index 9eb1dbc6..fddcba9c 100644 --- a/word-aligner/fast_align.cc +++ b/word-aligner/fast_align.cc @@ -33,11 +33,13 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("variational_bayes,v","Infer VB estimate of parameters under a symmetric Dirichlet prior") ("alpha,a", po::value()->default_value(0.01), "Hyperparameter for optional Dirichlet prior") ("no_null_word,N","Do not generate from a null token") - ("output_parameters,p", "Write model parameters instead of alignments") + ("output_parameters,p", po::value(), "Write model parameters to file") ("beam_threshold,t",po::value()->default_value(-4),"When writing parameters, log_10 of beam threshold for writing parameter (-10000 to include everything, 0 max parameter only)") ("hide_training_alignments,H", "Hide training alignments (only useful if you want to use -x option and just compute testset statistics)") ("testset,x", po::value(), "After training completes, compute the log likelihood of this set of sentence pairs under the learned model") - ("no_add_viterbi,V","When writing model parameters, do not add Viterbi alignment points (may generate a grammar where some training sentence pairs are unreachable)"); + ("no_add_viterbi,V","When writing model parameters, do not add Viterbi alignment points (may generate a grammar where some training sentence pairs are unreachable)") + ("force_align,f",po::value(), "Load previously written parameters to 'force align' input. Set --diagonal_tension and --mean_srclen_multiplier as estimated during training.") + ("mean_srclen_multiplier,m",po::value()->default_value(1), "When --force_align, use this source length multiplier"); po::options_description clo("Command line options"); clo.add_options() ("config", po::value(), "Configuration file") @@ -66,18 +68,20 @@ int main(int argc, char** argv) { if (!InitCommandLine(argc, argv, &conf)) return 1; const string fname = conf["input"].as(); const bool reverse = conf.count("reverse") > 0; - const int ITERATIONS = conf["iterations"].as(); + const int ITERATIONS = (conf.count("force_align")) ? 0 : conf["iterations"].as(); const double BEAM_THRESHOLD = pow(10.0, conf["beam_threshold"].as()); const bool use_null = (conf.count("no_null_word") == 0); const WordID kNULL = TD::Convert(""); const bool add_viterbi = (conf.count("no_add_viterbi") == 0); const bool variational_bayes = (conf.count("variational_bayes") > 0); - const bool write_alignments = (conf.count("output_parameters") == 0); + const bool output_parameters = (conf.count("force_align")) ? false : conf.count("output_parameters"); double diagonal_tension = conf["diagonal_tension"].as(); bool optimize_tension = conf.count("optimize_tension"); - const bool hide_training_alignments = (conf.count("hide_training_alignments") > 0); + bool hide_training_alignments = (conf.count("hide_training_alignments") > 0); + const bool write_alignments = (conf.count("force_align")) ? true : !hide_training_alignments; string testset; if (conf.count("testset")) testset = conf["testset"].as(); + if (conf.count("force_align")) testset = fname; double prob_align_null = conf["prob_align_null"].as(); double prob_align_not_null = 1.0 - prob_align_null; const double alpha = conf["alpha"].as(); @@ -86,13 +90,22 @@ int main(int argc, char** argv) { cerr << "--alpha must be > 0\n"; return 1; } - + + TTable s2t, t2s; TTable::Word2Word2Double s2t_viterbi; unordered_map, unsigned, boost::hash > > size_counts; double tot_len_ratio = 0; double mean_srclen_multiplier = 0; vector probs; + + if (conf.count("force_align")) { + // load model parameters + ReadFile s2t_f(conf["force_align"].as()); + s2t.DeserializeLogProbsFromText(s2t_f.stream()); + mean_srclen_multiplier = conf["mean_srclen_multiplier"].as(); + } + for (int iter = 0; iter < ITERATIONS; ++iter) { const bool final_iteration = (iter == (ITERATIONS - 1)); cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl; @@ -289,21 +302,22 @@ int main(int argc, char** argv) { cerr << "TOTAL LOG PROB " << tlp << endl; } - if (write_alignments) return 0; - - for (TTable::Word2Word2Double::iterator ei = s2t.ttable.begin(); ei != s2t.ttable.end(); ++ei) { - const TTable::Word2Double& cpd = ei->second; - const TTable::Word2Double& vit = s2t_viterbi[ei->first]; - const string& esym = TD::Convert(ei->first); - double max_p = -1; - for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) - if (fi->second > max_p) max_p = fi->second; - const double threshold = max_p * BEAM_THRESHOLD; - for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) { - if (fi->second > threshold || (vit.find(fi->first) != vit.end())) { - cout << esym << ' ' << TD::Convert(fi->first) << ' ' << log(fi->second) << endl; - } - } + if (output_parameters) { + WriteFile params_out(conf["output_parameters"].as()); + for (TTable::Word2Word2Double::iterator ei = s2t.ttable.begin(); ei != s2t.ttable.end(); ++ei) { + const TTable::Word2Double& cpd = ei->second; + const TTable::Word2Double& vit = s2t_viterbi[ei->first]; + const string& esym = TD::Convert(ei->first); + double max_p = -1; + for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) + if (fi->second > max_p) max_p = fi->second; + const double threshold = max_p * BEAM_THRESHOLD; + for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) { + if (fi->second > threshold || (vit.find(fi->first) != vit.end())) { + *params_out << esym << ' ' << TD::Convert(fi->first) << ' ' << log(fi->second) << endl; + } + } + } } return 0; } diff --git a/word-aligner/ttables.cc b/word-aligner/ttables.cc index 45bf14c5..c177aa30 100644 --- a/word-aligner/ttables.cc +++ b/word-aligner/ttables.cc @@ -21,6 +21,20 @@ void TTable::DeserializeProbsFromText(std::istream* in) { cerr << "Loaded " << c << " translation parameters.\n"; } +void TTable::DeserializeLogProbsFromText(std::istream* in) { + int c = 0; + while(*in) { + string e; + string f; + double p; + (*in) >> e >> f >> p; + if (e.empty()) break; + ++c; + ttable[TD::Convert(e)][TD::Convert(f)] = exp(p); + } + cerr << "Loaded " << c << " translation parameters.\n"; +} + void TTable::SerializeHelper(string* out, const Word2Word2Double& o) { assert(!"not implemented"); } diff --git a/word-aligner/ttables.h b/word-aligner/ttables.h index 9baa13ca..507f591a 100644 --- a/word-aligner/ttables.h +++ b/word-aligner/ttables.h @@ -86,6 +86,7 @@ class TTable { } } void DeserializeProbsFromText(std::istream* in); + void DeserializeLogProbsFromText(std::istream* in); void SerializeCounts(std::string* out) const { SerializeHelper(out, counts); } void DeserializeCounts(const std::string& in) { DeserializeHelper(in, &counts); } void SerializeProbs(std::string* out) const { SerializeHelper(out, ttable); } -- cgit v1.2.3