summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2013-06-18 21:04:07 -0700
committerChris Dyer <redpony@gmail.com>2013-06-18 21:04:07 -0700
commitf1ce46ec9b1b8efcc4a91a149454acf03c01db02 (patch)
treea18fe97c06f725e9c492907351c1c669d73b9c5f
parent535d4016ec5179cb673b697c2e81500a2097924c (diff)
parenta0c40ec491674c8826d7be2fbd46eaaa78ad3ed6 (diff)
Merge pull request #23 from felleh/upstream
forced alignment for fast_align
-rw-r--r--word-aligner/fast_align.cc56
-rw-r--r--word-aligner/ttables.cc14
-rw-r--r--word-aligner/ttables.h1
3 files changed, 50 insertions, 21 deletions
diff --git a/word-aligner/fast_align.cc b/word-aligner/fast_align.cc
index 9eb1dbc6..fddcba9c 100644
--- a/word-aligner/fast_align.cc
+++ b/word-aligner/fast_align.cc
@@ -33,11 +33,13 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("variational_bayes,v","Infer VB estimate of parameters under a symmetric Dirichlet prior")
("alpha,a", po::value<double>()->default_value(0.01), "Hyperparameter for optional Dirichlet prior")
("no_null_word,N","Do not generate from a null token")
- ("output_parameters,p", "Write model parameters instead of alignments")
+ ("output_parameters,p", po::value<string>(), "Write model parameters to file")
("beam_threshold,t",po::value<double>()->default_value(-4),"When writing parameters, log_10 of beam threshold for writing parameter (-10000 to include everything, 0 max parameter only)")
("hide_training_alignments,H", "Hide training alignments (only useful if you want to use -x option and just compute testset statistics)")
("testset,x", po::value<string>(), "After training completes, compute the log likelihood of this set of sentence pairs under the learned model")
- ("no_add_viterbi,V","When writing model parameters, do not add Viterbi alignment points (may generate a grammar where some training sentence pairs are unreachable)");
+ ("no_add_viterbi,V","When writing model parameters, do not add Viterbi alignment points (may generate a grammar where some training sentence pairs are unreachable)")
+ ("force_align,f",po::value<string>(), "Load previously written parameters to 'force align' input. Set --diagonal_tension and --mean_srclen_multiplier as estimated during training.")
+ ("mean_srclen_multiplier,m",po::value<double>()->default_value(1), "When --force_align, use this source length multiplier");
po::options_description clo("Command line options");
clo.add_options()
("config", po::value<string>(), "Configuration file")
@@ -66,18 +68,20 @@ int main(int argc, char** argv) {
if (!InitCommandLine(argc, argv, &conf)) return 1;
const string fname = conf["input"].as<string>();
const bool reverse = conf.count("reverse") > 0;
- const int ITERATIONS = conf["iterations"].as<unsigned>();
+ const int ITERATIONS = (conf.count("force_align")) ? 0 : conf["iterations"].as<unsigned>();
const double BEAM_THRESHOLD = pow(10.0, conf["beam_threshold"].as<double>());
const bool use_null = (conf.count("no_null_word") == 0);
const WordID kNULL = TD::Convert("<eps>");
const bool add_viterbi = (conf.count("no_add_viterbi") == 0);
const bool variational_bayes = (conf.count("variational_bayes") > 0);
- const bool write_alignments = (conf.count("output_parameters") == 0);
+ const bool output_parameters = (conf.count("force_align")) ? false : conf.count("output_parameters");
double diagonal_tension = conf["diagonal_tension"].as<double>();
bool optimize_tension = conf.count("optimize_tension");
- const bool hide_training_alignments = (conf.count("hide_training_alignments") > 0);
+ bool hide_training_alignments = (conf.count("hide_training_alignments") > 0);
+ const bool write_alignments = (conf.count("force_align")) ? true : !hide_training_alignments;
string testset;
if (conf.count("testset")) testset = conf["testset"].as<string>();
+ if (conf.count("force_align")) testset = fname;
double prob_align_null = conf["prob_align_null"].as<double>();
double prob_align_not_null = 1.0 - prob_align_null;
const double alpha = conf["alpha"].as<double>();
@@ -86,13 +90,22 @@ int main(int argc, char** argv) {
cerr << "--alpha must be > 0\n";
return 1;
}
-
+
+
TTable s2t, t2s;
TTable::Word2Word2Double s2t_viterbi;
unordered_map<pair<short, short>, unsigned, boost::hash<pair<short, short> > > size_counts;
double tot_len_ratio = 0;
double mean_srclen_multiplier = 0;
vector<double> probs;
+
+ if (conf.count("force_align")) {
+ // load model parameters
+ ReadFile s2t_f(conf["force_align"].as<string>());
+ s2t.DeserializeLogProbsFromText(s2t_f.stream());
+ mean_srclen_multiplier = conf["mean_srclen_multiplier"].as<double>();
+ }
+
for (int iter = 0; iter < ITERATIONS; ++iter) {
const bool final_iteration = (iter == (ITERATIONS - 1));
cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl;
@@ -289,21 +302,22 @@ int main(int argc, char** argv) {
cerr << "TOTAL LOG PROB " << tlp << endl;
}
- if (write_alignments) return 0;
-
- for (TTable::Word2Word2Double::iterator ei = s2t.ttable.begin(); ei != s2t.ttable.end(); ++ei) {
- const TTable::Word2Double& cpd = ei->second;
- const TTable::Word2Double& vit = s2t_viterbi[ei->first];
- const string& esym = TD::Convert(ei->first);
- double max_p = -1;
- for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi)
- if (fi->second > max_p) max_p = fi->second;
- const double threshold = max_p * BEAM_THRESHOLD;
- for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) {
- if (fi->second > threshold || (vit.find(fi->first) != vit.end())) {
- cout << esym << ' ' << TD::Convert(fi->first) << ' ' << log(fi->second) << endl;
- }
- }
+ if (output_parameters) {
+ WriteFile params_out(conf["output_parameters"].as<string>());
+ for (TTable::Word2Word2Double::iterator ei = s2t.ttable.begin(); ei != s2t.ttable.end(); ++ei) {
+ const TTable::Word2Double& cpd = ei->second;
+ const TTable::Word2Double& vit = s2t_viterbi[ei->first];
+ const string& esym = TD::Convert(ei->first);
+ double max_p = -1;
+ for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi)
+ if (fi->second > max_p) max_p = fi->second;
+ const double threshold = max_p * BEAM_THRESHOLD;
+ for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) {
+ if (fi->second > threshold || (vit.find(fi->first) != vit.end())) {
+ *params_out << esym << ' ' << TD::Convert(fi->first) << ' ' << log(fi->second) << endl;
+ }
+ }
+ }
}
return 0;
}
diff --git a/word-aligner/ttables.cc b/word-aligner/ttables.cc
index 45bf14c5..c177aa30 100644
--- a/word-aligner/ttables.cc
+++ b/word-aligner/ttables.cc
@@ -21,6 +21,20 @@ void TTable::DeserializeProbsFromText(std::istream* in) {
cerr << "Loaded " << c << " translation parameters.\n";
}
+void TTable::DeserializeLogProbsFromText(std::istream* in) {
+ int c = 0;
+ while(*in) {
+ string e;
+ string f;
+ double p;
+ (*in) >> e >> f >> p;
+ if (e.empty()) break;
+ ++c;
+ ttable[TD::Convert(e)][TD::Convert(f)] = exp(p);
+ }
+ cerr << "Loaded " << c << " translation parameters.\n";
+}
+
void TTable::SerializeHelper(string* out, const Word2Word2Double& o) {
assert(!"not implemented");
}
diff --git a/word-aligner/ttables.h b/word-aligner/ttables.h
index 9baa13ca..507f591a 100644
--- a/word-aligner/ttables.h
+++ b/word-aligner/ttables.h
@@ -86,6 +86,7 @@ class TTable {
}
}
void DeserializeProbsFromText(std::istream* in);
+ void DeserializeLogProbsFromText(std::istream* in);
void SerializeCounts(std::string* out) const { SerializeHelper(out, counts); }
void DeserializeCounts(const std::string& in) { DeserializeHelper(in, &counts); }
void SerializeProbs(std::string* out) const { SerializeHelper(out, ttable); }