From c3ecdd0009b269a360b240a3a99fc6f8d8b117d3 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 2 Apr 2014 02:35:19 -0400 Subject: speed up fast align by 20% or so --- word-aligner/fast_align.cc | 19 +++++++++--------- word-aligner/ttables.h | 50 +++++++++++++++++++++++----------------------- 2 files changed, 35 insertions(+), 34 deletions(-) (limited to 'word-aligner') diff --git a/word-aligner/fast_align.cc b/word-aligner/fast_align.cc index f54233eb..73b72399 100644 --- a/word-aligner/fast_align.cc +++ b/word-aligner/fast_align.cc @@ -190,6 +190,7 @@ int main(int argc, char** argv) { cout << (max_index - 1) << '-' << j; } } + if (s2t_viterbi.size() <= static_cast(max_i)) s2t_viterbi.resize(max_i + 1); s2t_viterbi[max_i][f_j] = 1.0; } } else { @@ -308,17 +309,17 @@ int main(int argc, char** argv) { if (output_parameters) { WriteFile params_out(conf["output_parameters"].as()); - for (TTable::Word2Word2Double::iterator ei = s2t.ttable.begin(); ei != s2t.ttable.end(); ++ei) { - const TTable::Word2Double& cpd = ei->second; - const TTable::Word2Double& vit = s2t_viterbi[ei->first]; - const string& esym = TD::Convert(ei->first); + for (unsigned eind = 1; eind < s2t.ttable.size(); ++eind) { + const auto& cpd = s2t.ttable[eind]; + const TTable::Word2Double& vit = s2t_viterbi[eind]; + const string& esym = TD::Convert(eind); double max_p = -1; - for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) - if (fi->second > max_p) max_p = fi->second; + for (auto& fi : cpd) + if (fi.second > max_p) max_p = fi.second; const double threshold = max_p * BEAM_THRESHOLD; - for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) { - if (fi->second > threshold || (vit.find(fi->first) != vit.end())) { - *params_out << esym << ' ' << TD::Convert(fi->first) << ' ' << log(fi->second) << endl; + for (auto& fi : cpd) { + if (fi.second > threshold || (vit.find(fi.first) != vit.end())) { + *params_out << esym << ' ' << TD::Convert(fi.first) << ' ' << log(fi.second) << endl; } } } diff --git a/word-aligner/ttables.h b/word-aligner/ttables.h index d82aff72..b9964225 100644 --- a/word-aligner/ttables.h +++ b/word-aligner/ttables.h @@ -2,6 +2,7 @@ #define _TTABLES_H_ #include +#include #ifndef HAVE_OLD_CPP # include #else @@ -18,11 +19,10 @@ class TTable { public: TTable() {} typedef std::unordered_map Word2Double; - typedef std::unordered_map Word2Word2Double; + typedef std::vector Word2Word2Double; inline double prob(const int& e, const int& f) const { - const Word2Word2Double::const_iterator cit = ttable.find(e); - if (cit != ttable.end()) { - const Word2Double& cpd = cit->second; + if (e < static_cast(ttable.size())) { + const Word2Double& cpd = ttable[e]; const Word2Double::const_iterator it = cpd.find(f); if (it == cpd.end()) return 1e-9; return it->second; @@ -31,19 +31,21 @@ class TTable { } } inline void Increment(const int& e, const int& f) { + if (e >= static_cast(ttable.size())) counts.resize(e + 1); counts[e][f] += 1.0; } inline void Increment(const int& e, const int& f, double x) { + if (e >= static_cast(counts.size())) counts.resize(e + 1); counts[e][f] += x; } void NormalizeVB(const double alpha) { ttable.swap(counts); - for (Word2Word2Double::iterator cit = ttable.begin(); - cit != ttable.end(); ++cit) { + for (unsigned i = 0; i < ttable.size(); ++i) { double tot = 0; - Word2Double& cpd = cit->second; + Word2Double& cpd = ttable[i]; for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) tot += it->second + alpha; + if (!tot) tot = 1; for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) it->second = exp(Md::digamma(it->second + alpha) - Md::digamma(tot)); } @@ -51,12 +53,12 @@ class TTable { } void Normalize() { ttable.swap(counts); - for (Word2Word2Double::iterator cit = ttable.begin(); - cit != ttable.end(); ++cit) { + for (unsigned i = 0; i < ttable.size(); ++i) { double tot = 0; - Word2Double& cpd = cit->second; + Word2Double& cpd = ttable[i]; for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) tot += it->second; + if (!tot) tot = 1; for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) it->second /= tot; } @@ -64,29 +66,27 @@ class TTable { } // adds counts from another TTable - probabilities remain unchanged TTable& operator+=(const TTable& rhs) { - for (Word2Word2Double::const_iterator it = rhs.counts.begin(); - it != rhs.counts.end(); ++it) { - const Word2Double& cpd = it->second; - Word2Double& tgt = counts[it->first]; - for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) { - tgt[j->first] += j->second; - } + if (rhs.counts.size() > counts.size()) counts.resize(rhs.counts.size()); + for (unsigned i = 0; i < rhs.counts.size(); ++i) { + const Word2Double& cpd = rhs.counts[i]; + Word2Double& tgt = counts[i]; + for (auto p : cpd) tgt[p.first] += p.second; } return *this; } void ShowTTable() const { - for (Word2Word2Double::const_iterator it = ttable.begin(); it != ttable.end(); ++it) { - const Word2Double& cpd = it->second; - for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) { - std::cerr << "P(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl; + for (unsigned it = 0; it < ttable.size(); ++it) { + const Word2Double& cpd = ttable[it]; + for (auto& p : cpd) { + std::cerr << "c(" << TD::Convert(p.first) << '|' << TD::Convert(it) << ") = " << p.second << std::endl; } } } void ShowCounts() const { - for (Word2Word2Double::const_iterator it = counts.begin(); it != counts.end(); ++it) { - const Word2Double& cpd = it->second; - for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) { - std::cerr << "c(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl; + for (unsigned it = 0; it < counts.size(); ++it) { + const Word2Double& cpd = counts[it]; + for (auto& p : cpd) { + std::cerr << "c(" << TD::Convert(p.first) << '|' << TD::Convert(it) << ") = " << p.second << std::endl; } } } -- cgit v1.2.3 From 74401769fdb8b16f44df8911070b7ae091de5fef Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 9 Apr 2014 20:19:57 -0400 Subject: fix for loading parameters --- word-aligner/ttables.cc | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) (limited to 'word-aligner') diff --git a/word-aligner/ttables.cc b/word-aligner/ttables.cc index a56bbcef..64d54bdf 100644 --- a/word-aligner/ttables.cc +++ b/word-aligner/ttables.cc @@ -8,28 +8,32 @@ using namespace std; void TTable::DeserializeProbsFromText(std::istream* in) { int c = 0; + string e; + string f; + double p; while(*in) { - string e; - string f; - double p; (*in) >> e >> f >> p; if (e.empty()) break; ++c; - ttable[TD::Convert(e)][TD::Convert(f)] = p; + WordID ie = TD::Convert(e); + if (ie >= static_cast(ttable.size())) ttable.resize(ie + 1); + ttable[ie][TD::Convert(f)] = p; } cerr << "Loaded " << c << " translation parameters.\n"; } void TTable::DeserializeLogProbsFromText(std::istream* in) { int c = 0; + string e; + string f; + double p; while(*in) { - string e; - string f; - double p; (*in) >> e >> f >> p; if (e.empty()) break; ++c; - ttable[TD::Convert(e)][TD::Convert(f)] = exp(p); + WordID ie = TD::Convert(e); + if (ie >= static_cast(ttable.size())) ttable.resize(ie + 1); + ttable[ie][TD::Convert(f)] = exp(p); } cerr << "Loaded " << c << " translation parameters.\n"; } -- cgit v1.2.3 From 1748e9a095bcc3a1db8ab47eb7ac6a1f9568772b Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 20 Apr 2014 22:25:15 -0400 Subject: binary derivations with maximal arity-2 --- word-aligner/Makefile.am | 5 +- word-aligner/binderiv.cc | 202 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 206 insertions(+), 1 deletion(-) create mode 100644 word-aligner/binderiv.cc (limited to 'word-aligner') diff --git a/word-aligner/Makefile.am b/word-aligner/Makefile.am index 1f7f78ae..075ad009 100644 --- a/word-aligner/Makefile.am +++ b/word-aligner/Makefile.am @@ -1,8 +1,11 @@ -bin_PROGRAMS = fast_align +bin_PROGRAMS = fast_align binderiv fast_align_SOURCES = fast_align.cc ttables.cc da.h ttables.h fast_align_LDADD = ../utils/libutils.a +binderiv_SOURCES = binderiv.cc +binderiv_LDADD = ../utils/libutils.a + EXTRA_DIST = aligner.pl ortho-norm support makefiles stemmers AM_CPPFLAGS = -W -Wall -I$(top_srcdir) -I$(top_srcdir)/utils -I$(top_srcdir)/training diff --git a/word-aligner/binderiv.cc b/word-aligner/binderiv.cc new file mode 100644 index 00000000..8ebc1105 --- /dev/null +++ b/word-aligner/binderiv.cc @@ -0,0 +1,202 @@ +#include +#include +#include +#include + +#include "alignment_io.h" +#include "tdict.h" + +using namespace std; + +enum CombinationType { + kNONE = 0, + kAXIOM, + kMONO, kSWAP, kCONTAINS_L, kCONTAINS_R, kINTERLEAVE +}; + +string nm(CombinationType x) { + switch (x) { + case kNONE: return "NONE"; + case kAXIOM: return "AXIOM"; + case kMONO: return "MONO"; + case kSWAP: return "SWAP"; + case kCONTAINS_L: return "CONTAINS_L"; + case kCONTAINS_R: return "CONTAINS_R"; + case kINTERLEAVE: return "INTERLEAVE"; + } +} + +string Substring(const vector& s, unsigned i, unsigned j) { + ostringstream os; + for (unsigned k = i; k < j; ++k) { + if (k > i) os << ' '; + os << TD::Convert(s[k]); + } + return os.str(); +} + +inline int min4(int a, int b, int c, int d) { + int l = a; + if (b < a) l = b; + int l2 = c; + if (d < c) l2 = c; + return min(l, l2); +} + +inline int max4(int a, int b, int c, int d) { + int l = a; + if (b > a) l = b; + int l2 = c; + if (d > c) l2 = d; + return max(l, l2); +} + +struct State { + int s,t,u,v; + State() : s(), t(), u(), v() {} + State(int a, int b, int c, int d) : s(a), t(b), u(c), v(d) { + assert(s <= t); + assert(u <= v); + } + bool IsGood() const { + return (s != 0 || t != 0 || u != 0 || v != 0); + } + CombinationType operator&(const State& r) const { + if (r.s != t) return kNONE; + if (v <= r.u) return kMONO; + if (r.v <= u) return kSWAP; + if (v >= r.v && u <= r.u) return kCONTAINS_R; + if (r.v >= v && r.u <= u) return kCONTAINS_L; + return kINTERLEAVE; + } + State& operator*=(const State& r) { + assert(r.s == t); + t = r.t; + const int tu = min4(u, v, r.u, r.v); + v = max4(u, v, r.u, r.v); + u = tu; + return *this; + } +}; + +double score(CombinationType x) { + switch (x) { + case kNONE: return 0.0; + case kAXIOM: return 1.0; + case kMONO: return 16.0; + case kSWAP: return 8.0; + case kCONTAINS_R: return 4.0; + case kCONTAINS_L: return 2.0; + case kINTERLEAVE: return 1.0; + } +} + +State operator*(const State& l, const State& r) { + State res = l; + res *= r; + return res; +} + +ostream& operator<<(ostream& os, const State& s) { + return os << '[' << s.s << ", " << s.t << ", " << s.u << ", " << s.v << ']'; +} + +string NT(const State& s) { + bool decorate=true; + if (decorate) { + ostringstream os; + os << "[X_" << s.s << '_' << s.t << '_' << s.u << '_' << s.v << "]"; + return os.str(); + } else { + return "[X]"; + } +} + +void CreateEdge(const vector& f, const vector& e, CombinationType ct, const State& cur, const State& left, const State& right) { + switch(ct) { + case kINTERLEAVE: + case kAXIOM: + cerr << NT(cur) << " ||| " << Substring(f, cur.s, cur.t) << " ||| " << Substring(e, cur.u, cur.v) << "\n"; + break; + case kMONO: + cerr << NT(cur) << " ||| " << NT(left) << ' ' << NT(right) << " ||| [1] [2]\n"; + break; + case kSWAP: + cerr << NT(cur) << " ||| " << NT(left) << ' ' << NT(right) << " ||| [2] [1]\n"; + break; + case kCONTAINS_L: + cerr << NT(cur) << " ||| " << Substring(f, right.s, left.s) << ' ' << NT(left) << ' ' << Substring(f, left.t, right.t) << " ||| " << Substring(e, right.u, left.u) << " [1] " << Substring(e, left.v, right.v) << endl; + break; + case kCONTAINS_R: + cerr << NT(cur) << " ||| " << Substring(f, left.s, right.s) << ' ' << NT(right) << ' ' << Substring(f, right.t, left.t) << " ||| " << Substring(e, left.u, right.u) << " [1] " << Substring(e, right.v, left.v) << endl; + break; + } +} + +void BuildArity2Forest(const vector& f, const vector& e, const vector& axioms) { + const unsigned n = f.size(); + Array2D chart(n, n+1); + Array2D ctypes(n, n+1); + Array2D cscore(n, n+1); + Array2D cmids(n, n+1, -1); + for (const auto& axiom : axioms) { + chart(axiom.s, axiom.t) = axiom; + ctypes(axiom.s, axiom.t) = kAXIOM; + cscore(axiom.s, axiom.t) = 1.0; + CreateEdge(f, e, kAXIOM, axiom, axiom, axiom); + //cerr << "AXIOM " << axiom.s << ", " << axiom.t << " : " << chart(axiom.s, axiom.t) << " : " << 1 << endl; + } + for (unsigned l = 2; l <= n; ++l) { + const unsigned i_end = n + 1 - l; + for (unsigned i = 0; i < i_end; ++i) { + const unsigned j = i + l; + for (unsigned k = i + 1; k < j; ++k) { + const State& left = chart(i, k); + const State& right = chart(k, j); + if (!left.IsGood() || !right.IsGood()) continue; + CombinationType comb = left & right; + if (comb != kNONE) { + double ns = cscore(i,k) + cscore(k,j) + score(comb); + if (ns > cscore(i,j)) { + cscore(i,j) = ns; + chart(i,j) = left * right; + cmids(i,j) = k; + ctypes(i,j) = comb; + //cerr << "PROVED " << chart(i,j) << " : " << cscore(i,j) << " [" << nm(comb) << " " << left << " * " << right << "]\n"; + } else { + //cerr << "SUBOPTIMAL " << (left*right) << " : " << ns << " [" << nm(comb) << " " << left << " * " << right << "]\n"; + } + CreateEdge(f, e, comb, left * right, left, right); + } else { + //cerr << "CAN'T " << left << " * " << right << endl; + } + } + } + } +} + +int main(int argc, char** argv) { + State s; + vector e,f; + TD::ConvertSentence("B C that A", &e); + TD::ConvertSentence("A de B C", &f); + State w0(0,1,3,4), w1(1,2,2,3), w2(2,3,0,1), w3(3,4,1,2); + vector al = {w0, w1, w2, w3}; + // f cannot have any unaligned words, however, multiple overlapping axioms are possible + // so you can write code to align unaligned words in all ways to surrounding words + BuildArity2Forest(f, e, al); + + TD::ConvertSentence("A B C D", &e); + TD::ConvertSentence("A B , C D", &f); + vector al2 = {State(0,1,0,1), State(1,2,1,2), State(1,3,1,2), State(2,4,2,3), State(4,5,3,4)}; + BuildArity2Forest(f, e, al2); + + TD::ConvertSentence("A B C D", &e); + TD::ConvertSentence("C A D B", &f); + vector al3 = {State(0,1,2,3), State(1,2,0,1), State(2,3,3,4), State(3,4,1,2)}; + BuildArity2Forest(f, e, al3); + + // things to do: run EM, do posterior inference with a Dirichlet prior, etc. + return 0; +} + -- cgit v1.2.3