summaryrefslogtreecommitdiff
path: root/training/fast_align.cc
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-10-11 14:06:12 -0400
committerChris Dyer <cdyer@cs.cmu.edu>2012-10-11 14:06:12 -0400
commit438dac41810b7c69fa10203ac5130d20efa2da9f (patch)
treef0054c69e5e14ce6d3f8f4f441661a3ee163234d /training/fast_align.cc
parentea79e535d69f6854d01c62e3752971fb6730d8e7 (diff)
produce alignments for a test set (which can be stdin)
Diffstat (limited to 'training/fast_align.cc')
-rw-r--r--training/fast_align.cc40
1 files changed, 25 insertions, 15 deletions
diff --git a/training/fast_align.cc b/training/fast_align.cc
index 0d7b0202..7492d26f 100644
--- a/training/fast_align.cc
+++ b/training/fast_align.cc
@@ -29,6 +29,7 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("no_null_word,N","Do not generate from a null token")
("output_parameters,p", "Write model parameters instead of alignments")
("beam_threshold,t",po::value<double>()->default_value(-4),"When writing parameters, log_10 of beam threshold for writing parameter (-10000 to include everything, 0 max parameter only)")
+ ("hide_training_alignments,H", "Hide training alignments (only useful if you want to use -x option and just compute testset statistics)")
("testset,x", po::value<string>(), "After training completes, compute the log likelihood of this set of sentence pairs under the learned model")
("no_add_viterbi,V","When writing model parameters, do not add Viterbi alignment points (may generate a grammar where some training sentence pairs are unreachable)");
po::options_description clo("Command line options");
@@ -54,14 +55,6 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
return true;
}
-double PosteriorInference(const vector<WordID>& src, const vector<WordID>& trg) {
- double llh = 0;
- static vector<double> unnormed_a_i;
- if (src.size() > unnormed_a_i.size())
- unnormed_a_i.resize(src.size());
- return llh;
-}
-
int main(int argc, char** argv) {
po::variables_map conf;
if (!InitCommandLine(argc, argv, &conf)) return 1;
@@ -76,6 +69,7 @@ int main(int argc, char** argv) {
const bool write_alignments = (conf.count("output_parameters") == 0);
const double diagonal_tension = conf["diagonal_tension"].as<double>();
const double prob_align_null = conf["prob_align_null"].as<double>();
+ const bool hide_training_alignments = (conf.count("hide_training_alignments") > 0);
string testset;
if (conf.count("testset")) testset = conf["testset"].as<string>();
const double prob_align_not_null = 1.0 - prob_align_null;
@@ -164,7 +158,7 @@ int main(int argc, char** argv) {
max_i = src[i-1];
}
}
- if (write_alignments) {
+ if (!hide_training_alignments && write_alignments) {
if (max_index > 0) {
if (first_al) first_al = false; else cout << ' ';
if (reverse)
@@ -183,7 +177,7 @@ int main(int argc, char** argv) {
}
likelihood += log(sum);
}
- if (write_alignments && final_iteration) cout << endl;
+ if (write_alignments && final_iteration && !hide_training_alignments) cout << endl;
}
// log(e) = 1.0
@@ -210,11 +204,13 @@ int main(int argc, char** argv) {
istream& in = *rf.stream();
int lc = 0;
double tlp = 0;
- string ssrc, strg, line;
+ string line;
while (getline(in, line)) {
++lc;
vector<WordID> src, trg;
CorpusTools::ReadLine(line, &src, &trg);
+ cout << TD::GetString(src) << " ||| " << TD::GetString(trg) << " |||";
+ if (reverse) swap(src, trg);
double log_prob = Md::log_poisson(trg.size(), 0.05 + src.size() * mean_srclen_multiplier);
if (src.size() > unnormed_a_i.size())
unnormed_a_i.resize(src.size());
@@ -223,11 +219,14 @@ int main(int argc, char** argv) {
for (int j = 0; j < trg.size(); ++j) {
const WordID& f_j = trg[j];
double sum = 0;
+ int a_j = 0;
+ double max_pat = 0;
const double j_over_ts = double(j) / trg.size();
double prob_a_i = 1.0 / (src.size() + use_null); // uniform (model 1)
if (use_null) {
if (favor_diagonal) prob_a_i = prob_align_null;
- sum += s2t.prob(kNULL, f_j) * prob_a_i;
+ max_pat = s2t.prob(kNULL, f_j) * prob_a_i;
+ sum += max_pat;
}
double az = 0;
if (favor_diagonal) {
@@ -240,13 +239,24 @@ int main(int argc, char** argv) {
for (int i = 1; i <= src.size(); ++i) {
if (favor_diagonal)
prob_a_i = unnormed_a_i[i-1] / az;
- sum += s2t.prob(src[i-1], f_j) * prob_a_i;
+ double pat = s2t.prob(src[i-1], f_j) * prob_a_i;
+ if (pat > max_pat) { max_pat = pat; a_j = i; }
+ sum += pat;
}
log_prob += log(sum);
+ if (write_alignments) {
+ if (a_j > 0) {
+ cout << ' ';
+ if (reverse)
+ cout << j << '-' << (a_j - 1);
+ else
+ cout << (a_j - 1) << '-' << j;
+ }
+ }
}
tlp += log_prob;
- cerr << ssrc << " ||| " << strg << " ||| " << log_prob << endl;
- }
+ cout << " ||| " << log_prob << endl << flush;
+ } // loop over test set sentences
cerr << "TOTAL LOG PROB " << tlp << endl;
}