summaryrefslogtreecommitdiff
path: root/word-aligner/fast_align.cc
diff options
context:
space:
mode:
authorFelix Hieber <fhieber@sdl.com>2013-06-18 10:23:55 -0700
committerFelix Hieber <fhieber@sdl.com>2013-06-18 10:23:55 -0700
commita0c40ec491674c8826d7be2fbd46eaaa78ad3ed6 (patch)
treeb2872c724824b9de3e077f1a466c3a06b6462aec /word-aligner/fast_align.cc
parent3ed5b91d299841eed678d8989fb7d7b90888c3be (diff)
forced alignment
Diffstat (limited to 'word-aligner/fast_align.cc')
-rw-r--r--word-aligner/fast_align.cc56
1 files changed, 35 insertions, 21 deletions
diff --git a/word-aligner/fast_align.cc b/word-aligner/fast_align.cc
index 9eb1dbc6..fddcba9c 100644
--- a/word-aligner/fast_align.cc
+++ b/word-aligner/fast_align.cc
@@ -33,11 +33,13 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("variational_bayes,v","Infer VB estimate of parameters under a symmetric Dirichlet prior")
("alpha,a", po::value<double>()->default_value(0.01), "Hyperparameter for optional Dirichlet prior")
("no_null_word,N","Do not generate from a null token")
- ("output_parameters,p", "Write model parameters instead of alignments")
+ ("output_parameters,p", po::value<string>(), "Write model parameters to file")
("beam_threshold,t",po::value<double>()->default_value(-4),"When writing parameters, log_10 of beam threshold for writing parameter (-10000 to include everything, 0 max parameter only)")
("hide_training_alignments,H", "Hide training alignments (only useful if you want to use -x option and just compute testset statistics)")
("testset,x", po::value<string>(), "After training completes, compute the log likelihood of this set of sentence pairs under the learned model")
- ("no_add_viterbi,V","When writing model parameters, do not add Viterbi alignment points (may generate a grammar where some training sentence pairs are unreachable)");
+ ("no_add_viterbi,V","When writing model parameters, do not add Viterbi alignment points (may generate a grammar where some training sentence pairs are unreachable)")
+ ("force_align,f",po::value<string>(), "Load previously written parameters to 'force align' input. Set --diagonal_tension and --mean_srclen_multiplier as estimated during training.")
+ ("mean_srclen_multiplier,m",po::value<double>()->default_value(1), "When --force_align, use this source length multiplier");
po::options_description clo("Command line options");
clo.add_options()
("config", po::value<string>(), "Configuration file")
@@ -66,18 +68,20 @@ int main(int argc, char** argv) {
if (!InitCommandLine(argc, argv, &conf)) return 1;
const string fname = conf["input"].as<string>();
const bool reverse = conf.count("reverse") > 0;
- const int ITERATIONS = conf["iterations"].as<unsigned>();
+ const int ITERATIONS = (conf.count("force_align")) ? 0 : conf["iterations"].as<unsigned>();
const double BEAM_THRESHOLD = pow(10.0, conf["beam_threshold"].as<double>());
const bool use_null = (conf.count("no_null_word") == 0);
const WordID kNULL = TD::Convert("<eps>");
const bool add_viterbi = (conf.count("no_add_viterbi") == 0);
const bool variational_bayes = (conf.count("variational_bayes") > 0);
- const bool write_alignments = (conf.count("output_parameters") == 0);
+ const bool output_parameters = (conf.count("force_align")) ? false : conf.count("output_parameters");
double diagonal_tension = conf["diagonal_tension"].as<double>();
bool optimize_tension = conf.count("optimize_tension");
- const bool hide_training_alignments = (conf.count("hide_training_alignments") > 0);
+ bool hide_training_alignments = (conf.count("hide_training_alignments") > 0);
+ const bool write_alignments = (conf.count("force_align")) ? true : !hide_training_alignments;
string testset;
if (conf.count("testset")) testset = conf["testset"].as<string>();
+ if (conf.count("force_align")) testset = fname;
double prob_align_null = conf["prob_align_null"].as<double>();
double prob_align_not_null = 1.0 - prob_align_null;
const double alpha = conf["alpha"].as<double>();
@@ -86,13 +90,22 @@ int main(int argc, char** argv) {
cerr << "--alpha must be > 0\n";
return 1;
}
-
+
+
TTable s2t, t2s;
TTable::Word2Word2Double s2t_viterbi;
unordered_map<pair<short, short>, unsigned, boost::hash<pair<short, short> > > size_counts;
double tot_len_ratio = 0;
double mean_srclen_multiplier = 0;
vector<double> probs;
+
+ if (conf.count("force_align")) {
+ // load model parameters
+ ReadFile s2t_f(conf["force_align"].as<string>());
+ s2t.DeserializeLogProbsFromText(s2t_f.stream());
+ mean_srclen_multiplier = conf["mean_srclen_multiplier"].as<double>();
+ }
+
for (int iter = 0; iter < ITERATIONS; ++iter) {
const bool final_iteration = (iter == (ITERATIONS - 1));
cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl;
@@ -289,21 +302,22 @@ int main(int argc, char** argv) {
cerr << "TOTAL LOG PROB " << tlp << endl;
}
- if (write_alignments) return 0;
-
- for (TTable::Word2Word2Double::iterator ei = s2t.ttable.begin(); ei != s2t.ttable.end(); ++ei) {
- const TTable::Word2Double& cpd = ei->second;
- const TTable::Word2Double& vit = s2t_viterbi[ei->first];
- const string& esym = TD::Convert(ei->first);
- double max_p = -1;
- for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi)
- if (fi->second > max_p) max_p = fi->second;
- const double threshold = max_p * BEAM_THRESHOLD;
- for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) {
- if (fi->second > threshold || (vit.find(fi->first) != vit.end())) {
- cout << esym << ' ' << TD::Convert(fi->first) << ' ' << log(fi->second) << endl;
- }
- }
+ if (output_parameters) {
+ WriteFile params_out(conf["output_parameters"].as<string>());
+ for (TTable::Word2Word2Double::iterator ei = s2t.ttable.begin(); ei != s2t.ttable.end(); ++ei) {
+ const TTable::Word2Double& cpd = ei->second;
+ const TTable::Word2Double& vit = s2t_viterbi[ei->first];
+ const string& esym = TD::Convert(ei->first);
+ double max_p = -1;
+ for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi)
+ if (fi->second > max_p) max_p = fi->second;
+ const double threshold = max_p * BEAM_THRESHOLD;
+ for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) {
+ if (fi->second > threshold || (vit.find(fi->first) != vit.end())) {
+ *params_out << esym << ' ' << TD::Convert(fi->first) << ' ' << log(fi->second) << endl;
+ }
+ }
+ }
}
return 0;
}