summaryrefslogtreecommitdiff
path: root/word-aligner
diff options
context:
space:
mode:
Diffstat (limited to 'word-aligner')
-rw-r--r--word-aligner/Makefile.am5
-rw-r--r--word-aligner/binderiv.cc202
-rw-r--r--word-aligner/fast_align.cc19
-rw-r--r--word-aligner/ttables.cc20
-rw-r--r--word-aligner/ttables.h50
5 files changed, 253 insertions, 43 deletions
diff --git a/word-aligner/Makefile.am b/word-aligner/Makefile.am
index 1f7f78ae..075ad009 100644
--- a/word-aligner/Makefile.am
+++ b/word-aligner/Makefile.am
@@ -1,8 +1,11 @@
-bin_PROGRAMS = fast_align
+bin_PROGRAMS = fast_align binderiv
fast_align_SOURCES = fast_align.cc ttables.cc da.h ttables.h
fast_align_LDADD = ../utils/libutils.a
+binderiv_SOURCES = binderiv.cc
+binderiv_LDADD = ../utils/libutils.a
+
EXTRA_DIST = aligner.pl ortho-norm support makefiles stemmers
AM_CPPFLAGS = -W -Wall -I$(top_srcdir) -I$(top_srcdir)/utils -I$(top_srcdir)/training
diff --git a/word-aligner/binderiv.cc b/word-aligner/binderiv.cc
new file mode 100644
index 00000000..8ebc1105
--- /dev/null
+++ b/word-aligner/binderiv.cc
@@ -0,0 +1,202 @@
+#include <iostream>
+#include <string>
+#include <queue>
+#include <sstream>
+
+#include "alignment_io.h"
+#include "tdict.h"
+
+using namespace std;
+
+enum CombinationType {
+ kNONE = 0,
+ kAXIOM,
+ kMONO, kSWAP, kCONTAINS_L, kCONTAINS_R, kINTERLEAVE
+};
+
+string nm(CombinationType x) {
+ switch (x) {
+ case kNONE: return "NONE";
+ case kAXIOM: return "AXIOM";
+ case kMONO: return "MONO";
+ case kSWAP: return "SWAP";
+ case kCONTAINS_L: return "CONTAINS_L";
+ case kCONTAINS_R: return "CONTAINS_R";
+ case kINTERLEAVE: return "INTERLEAVE";
+ }
+}
+
+string Substring(const vector<WordID>& s, unsigned i, unsigned j) {
+ ostringstream os;
+ for (unsigned k = i; k < j; ++k) {
+ if (k > i) os << ' ';
+ os << TD::Convert(s[k]);
+ }
+ return os.str();
+}
+
+inline int min4(int a, int b, int c, int d) {
+ int l = a;
+ if (b < a) l = b;
+ int l2 = c;
+ if (d < c) l2 = c;
+ return min(l, l2);
+}
+
+inline int max4(int a, int b, int c, int d) {
+ int l = a;
+ if (b > a) l = b;
+ int l2 = c;
+ if (d > c) l2 = d;
+ return max(l, l2);
+}
+
+struct State {
+ int s,t,u,v;
+ State() : s(), t(), u(), v() {}
+ State(int a, int b, int c, int d) : s(a), t(b), u(c), v(d) {
+ assert(s <= t);
+ assert(u <= v);
+ }
+ bool IsGood() const {
+ return (s != 0 || t != 0 || u != 0 || v != 0);
+ }
+ CombinationType operator&(const State& r) const {
+ if (r.s != t) return kNONE;
+ if (v <= r.u) return kMONO;
+ if (r.v <= u) return kSWAP;
+ if (v >= r.v && u <= r.u) return kCONTAINS_R;
+ if (r.v >= v && r.u <= u) return kCONTAINS_L;
+ return kINTERLEAVE;
+ }
+ State& operator*=(const State& r) {
+ assert(r.s == t);
+ t = r.t;
+ const int tu = min4(u, v, r.u, r.v);
+ v = max4(u, v, r.u, r.v);
+ u = tu;
+ return *this;
+ }
+};
+
+double score(CombinationType x) {
+ switch (x) {
+ case kNONE: return 0.0;
+ case kAXIOM: return 1.0;
+ case kMONO: return 16.0;
+ case kSWAP: return 8.0;
+ case kCONTAINS_R: return 4.0;
+ case kCONTAINS_L: return 2.0;
+ case kINTERLEAVE: return 1.0;
+ }
+}
+
+State operator*(const State& l, const State& r) {
+ State res = l;
+ res *= r;
+ return res;
+}
+
+ostream& operator<<(ostream& os, const State& s) {
+ return os << '[' << s.s << ", " << s.t << ", " << s.u << ", " << s.v << ']';
+}
+
+string NT(const State& s) {
+ bool decorate=true;
+ if (decorate) {
+ ostringstream os;
+ os << "[X_" << s.s << '_' << s.t << '_' << s.u << '_' << s.v << "]";
+ return os.str();
+ } else {
+ return "[X]";
+ }
+}
+
+void CreateEdge(const vector<WordID>& f, const vector<WordID>& e, CombinationType ct, const State& cur, const State& left, const State& right) {
+ switch(ct) {
+ case kINTERLEAVE:
+ case kAXIOM:
+ cerr << NT(cur) << " ||| " << Substring(f, cur.s, cur.t) << " ||| " << Substring(e, cur.u, cur.v) << "\n";
+ break;
+ case kMONO:
+ cerr << NT(cur) << " ||| " << NT(left) << ' ' << NT(right) << " ||| [1] [2]\n";
+ break;
+ case kSWAP:
+ cerr << NT(cur) << " ||| " << NT(left) << ' ' << NT(right) << " ||| [2] [1]\n";
+ break;
+ case kCONTAINS_L:
+ cerr << NT(cur) << " ||| " << Substring(f, right.s, left.s) << ' ' << NT(left) << ' ' << Substring(f, left.t, right.t) << " ||| " << Substring(e, right.u, left.u) << " [1] " << Substring(e, left.v, right.v) << endl;
+ break;
+ case kCONTAINS_R:
+ cerr << NT(cur) << " ||| " << Substring(f, left.s, right.s) << ' ' << NT(right) << ' ' << Substring(f, right.t, left.t) << " ||| " << Substring(e, left.u, right.u) << " [1] " << Substring(e, right.v, left.v) << endl;
+ break;
+ }
+}
+
+void BuildArity2Forest(const vector<WordID>& f, const vector<WordID>& e, const vector<State>& axioms) {
+ const unsigned n = f.size();
+ Array2D<State> chart(n, n+1);
+ Array2D<CombinationType> ctypes(n, n+1);
+ Array2D<double> cscore(n, n+1);
+ Array2D<int> cmids(n, n+1, -1);
+ for (const auto& axiom : axioms) {
+ chart(axiom.s, axiom.t) = axiom;
+ ctypes(axiom.s, axiom.t) = kAXIOM;
+ cscore(axiom.s, axiom.t) = 1.0;
+ CreateEdge(f, e, kAXIOM, axiom, axiom, axiom);
+ //cerr << "AXIOM " << axiom.s << ", " << axiom.t << " : " << chart(axiom.s, axiom.t) << " : " << 1 << endl;
+ }
+ for (unsigned l = 2; l <= n; ++l) {
+ const unsigned i_end = n + 1 - l;
+ for (unsigned i = 0; i < i_end; ++i) {
+ const unsigned j = i + l;
+ for (unsigned k = i + 1; k < j; ++k) {
+ const State& left = chart(i, k);
+ const State& right = chart(k, j);
+ if (!left.IsGood() || !right.IsGood()) continue;
+ CombinationType comb = left & right;
+ if (comb != kNONE) {
+ double ns = cscore(i,k) + cscore(k,j) + score(comb);
+ if (ns > cscore(i,j)) {
+ cscore(i,j) = ns;
+ chart(i,j) = left * right;
+ cmids(i,j) = k;
+ ctypes(i,j) = comb;
+ //cerr << "PROVED " << chart(i,j) << " : " << cscore(i,j) << " [" << nm(comb) << " " << left << " * " << right << "]\n";
+ } else {
+ //cerr << "SUBOPTIMAL " << (left*right) << " : " << ns << " [" << nm(comb) << " " << left << " * " << right << "]\n";
+ }
+ CreateEdge(f, e, comb, left * right, left, right);
+ } else {
+ //cerr << "CAN'T " << left << " * " << right << endl;
+ }
+ }
+ }
+ }
+}
+
+int main(int argc, char** argv) {
+ State s;
+ vector<WordID> e,f;
+ TD::ConvertSentence("B C that A", &e);
+ TD::ConvertSentence("A de B C", &f);
+ State w0(0,1,3,4), w1(1,2,2,3), w2(2,3,0,1), w3(3,4,1,2);
+ vector<State> al = {w0, w1, w2, w3};
+ // f cannot have any unaligned words, however, multiple overlapping axioms are possible
+ // so you can write code to align unaligned words in all ways to surrounding words
+ BuildArity2Forest(f, e, al);
+
+ TD::ConvertSentence("A B C D", &e);
+ TD::ConvertSentence("A B , C D", &f);
+ vector<State> al2 = {State(0,1,0,1), State(1,2,1,2), State(1,3,1,2), State(2,4,2,3), State(4,5,3,4)};
+ BuildArity2Forest(f, e, al2);
+
+ TD::ConvertSentence("A B C D", &e);
+ TD::ConvertSentence("C A D B", &f);
+ vector<State> al3 = {State(0,1,2,3), State(1,2,0,1), State(2,3,3,4), State(3,4,1,2)};
+ BuildArity2Forest(f, e, al3);
+
+ // things to do: run EM, do posterior inference with a Dirichlet prior, etc.
+ return 0;
+}
+
diff --git a/word-aligner/fast_align.cc b/word-aligner/fast_align.cc
index f54233eb..73b72399 100644
--- a/word-aligner/fast_align.cc
+++ b/word-aligner/fast_align.cc
@@ -190,6 +190,7 @@ int main(int argc, char** argv) {
cout << (max_index - 1) << '-' << j;
}
}
+ if (s2t_viterbi.size() <= static_cast<unsigned>(max_i)) s2t_viterbi.resize(max_i + 1);
s2t_viterbi[max_i][f_j] = 1.0;
}
} else {
@@ -308,17 +309,17 @@ int main(int argc, char** argv) {
if (output_parameters) {
WriteFile params_out(conf["output_parameters"].as<string>());
- for (TTable::Word2Word2Double::iterator ei = s2t.ttable.begin(); ei != s2t.ttable.end(); ++ei) {
- const TTable::Word2Double& cpd = ei->second;
- const TTable::Word2Double& vit = s2t_viterbi[ei->first];
- const string& esym = TD::Convert(ei->first);
+ for (unsigned eind = 1; eind < s2t.ttable.size(); ++eind) {
+ const auto& cpd = s2t.ttable[eind];
+ const TTable::Word2Double& vit = s2t_viterbi[eind];
+ const string& esym = TD::Convert(eind);
double max_p = -1;
- for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi)
- if (fi->second > max_p) max_p = fi->second;
+ for (auto& fi : cpd)
+ if (fi.second > max_p) max_p = fi.second;
const double threshold = max_p * BEAM_THRESHOLD;
- for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) {
- if (fi->second > threshold || (vit.find(fi->first) != vit.end())) {
- *params_out << esym << ' ' << TD::Convert(fi->first) << ' ' << log(fi->second) << endl;
+ for (auto& fi : cpd) {
+ if (fi.second > threshold || (vit.find(fi.first) != vit.end())) {
+ *params_out << esym << ' ' << TD::Convert(fi.first) << ' ' << log(fi.second) << endl;
}
}
}
diff --git a/word-aligner/ttables.cc b/word-aligner/ttables.cc
index a56bbcef..64d54bdf 100644
--- a/word-aligner/ttables.cc
+++ b/word-aligner/ttables.cc
@@ -8,28 +8,32 @@ using namespace std;
void TTable::DeserializeProbsFromText(std::istream* in) {
int c = 0;
+ string e;
+ string f;
+ double p;
while(*in) {
- string e;
- string f;
- double p;
(*in) >> e >> f >> p;
if (e.empty()) break;
++c;
- ttable[TD::Convert(e)][TD::Convert(f)] = p;
+ WordID ie = TD::Convert(e);
+ if (ie >= static_cast<int>(ttable.size())) ttable.resize(ie + 1);
+ ttable[ie][TD::Convert(f)] = p;
}
cerr << "Loaded " << c << " translation parameters.\n";
}
void TTable::DeserializeLogProbsFromText(std::istream* in) {
int c = 0;
+ string e;
+ string f;
+ double p;
while(*in) {
- string e;
- string f;
- double p;
(*in) >> e >> f >> p;
if (e.empty()) break;
++c;
- ttable[TD::Convert(e)][TD::Convert(f)] = exp(p);
+ WordID ie = TD::Convert(e);
+ if (ie >= static_cast<int>(ttable.size())) ttable.resize(ie + 1);
+ ttable[ie][TD::Convert(f)] = exp(p);
}
cerr << "Loaded " << c << " translation parameters.\n";
}
diff --git a/word-aligner/ttables.h b/word-aligner/ttables.h
index d82aff72..b9964225 100644
--- a/word-aligner/ttables.h
+++ b/word-aligner/ttables.h
@@ -2,6 +2,7 @@
#define _TTABLES_H_
#include <iostream>
+#include <vector>
#ifndef HAVE_OLD_CPP
# include <unordered_map>
#else
@@ -18,11 +19,10 @@ class TTable {
public:
TTable() {}
typedef std::unordered_map<WordID, double> Word2Double;
- typedef std::unordered_map<WordID, Word2Double> Word2Word2Double;
+ typedef std::vector<Word2Double> Word2Word2Double;
inline double prob(const int& e, const int& f) const {
- const Word2Word2Double::const_iterator cit = ttable.find(e);
- if (cit != ttable.end()) {
- const Word2Double& cpd = cit->second;
+ if (e < static_cast<int>(ttable.size())) {
+ const Word2Double& cpd = ttable[e];
const Word2Double::const_iterator it = cpd.find(f);
if (it == cpd.end()) return 1e-9;
return it->second;
@@ -31,19 +31,21 @@ class TTable {
}
}
inline void Increment(const int& e, const int& f) {
+ if (e >= static_cast<int>(ttable.size())) counts.resize(e + 1);
counts[e][f] += 1.0;
}
inline void Increment(const int& e, const int& f, double x) {
+ if (e >= static_cast<int>(counts.size())) counts.resize(e + 1);
counts[e][f] += x;
}
void NormalizeVB(const double alpha) {
ttable.swap(counts);
- for (Word2Word2Double::iterator cit = ttable.begin();
- cit != ttable.end(); ++cit) {
+ for (unsigned i = 0; i < ttable.size(); ++i) {
double tot = 0;
- Word2Double& cpd = cit->second;
+ Word2Double& cpd = ttable[i];
for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
tot += it->second + alpha;
+ if (!tot) tot = 1;
for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
it->second = exp(Md::digamma(it->second + alpha) - Md::digamma(tot));
}
@@ -51,12 +53,12 @@ class TTable {
}
void Normalize() {
ttable.swap(counts);
- for (Word2Word2Double::iterator cit = ttable.begin();
- cit != ttable.end(); ++cit) {
+ for (unsigned i = 0; i < ttable.size(); ++i) {
double tot = 0;
- Word2Double& cpd = cit->second;
+ Word2Double& cpd = ttable[i];
for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
tot += it->second;
+ if (!tot) tot = 1;
for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
it->second /= tot;
}
@@ -64,29 +66,27 @@ class TTable {
}
// adds counts from another TTable - probabilities remain unchanged
TTable& operator+=(const TTable& rhs) {
- for (Word2Word2Double::const_iterator it = rhs.counts.begin();
- it != rhs.counts.end(); ++it) {
- const Word2Double& cpd = it->second;
- Word2Double& tgt = counts[it->first];
- for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) {
- tgt[j->first] += j->second;
- }
+ if (rhs.counts.size() > counts.size()) counts.resize(rhs.counts.size());
+ for (unsigned i = 0; i < rhs.counts.size(); ++i) {
+ const Word2Double& cpd = rhs.counts[i];
+ Word2Double& tgt = counts[i];
+ for (auto p : cpd) tgt[p.first] += p.second;
}
return *this;
}
void ShowTTable() const {
- for (Word2Word2Double::const_iterator it = ttable.begin(); it != ttable.end(); ++it) {
- const Word2Double& cpd = it->second;
- for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) {
- std::cerr << "P(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl;
+ for (unsigned it = 0; it < ttable.size(); ++it) {
+ const Word2Double& cpd = ttable[it];
+ for (auto& p : cpd) {
+ std::cerr << "c(" << TD::Convert(p.first) << '|' << TD::Convert(it) << ") = " << p.second << std::endl;
}
}
}
void ShowCounts() const {
- for (Word2Word2Double::const_iterator it = counts.begin(); it != counts.end(); ++it) {
- const Word2Double& cpd = it->second;
- for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) {
- std::cerr << "c(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl;
+ for (unsigned it = 0; it < counts.size(); ++it) {
+ const Word2Double& cpd = counts[it];
+ for (auto& p : cpd) {
+ std::cerr << "c(" << TD::Convert(p.first) << '|' << TD::Convert(it) << ") = " << p.second << std::endl;
}
}
}