summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2009-12-18 22:51:11 -0500
committerChris Dyer <redpony@gmail.com>2009-12-18 22:51:11 -0500
commit544da4d8e42858b19e6229936df56d44d61b1f38 (patch)
tree703d6b8ef746227667a5110a5c92e0c16e9b8f66 /training
parentb0def329260a83da70ceffd503207fc590502a7b (diff)
add symmetrization heuristics to atools, add null word configuration
Diffstat (limited to 'training')
-rw-r--r--training/atools.cc130
1 files changed, 93 insertions, 37 deletions
diff --git a/training/atools.cc b/training/atools.cc
index a18250f7..cf238371 100644
--- a/training/atools.cc
+++ b/training/atools.cc
@@ -28,7 +28,7 @@ struct Command {
x->resize(max(a.width(), b.width()), max(a.height(), b.width()));
}
bool Safe(const Array2D<bool>& a, int i, int j) const {
- if (i < a.width() && j < a.height())
+ if (i >= 0 && j >= 0 && i < a.width() && j < a.height())
return a(i,j);
else
return false;
@@ -40,9 +40,10 @@ struct Command {
struct FMeasureCommand : public Command {
FMeasureCommand() : matches(), num_predicted(), num_in_ref() {}
int Result() const { return 2; }
- string Name() const { return "f"; }
+ string Name() const { return "fmeasure"; }
bool RequiresTwoOperands() const { return true; }
void Apply(const Array2D<bool>& hyp, const Array2D<bool>& ref, Array2D<bool>* x) {
+ (void) x; // AER just computes statistics, not an alignment
int i_len = ref.width();
int j_len = ref.height();
for (int i = 0; i < i_len; ++i) {
@@ -126,54 +127,85 @@ struct RefineCommand : public Command {
neighbors_.push_back(make_pair(0,-1));
}
bool RequiresTwoOperands() const { return true; }
+
+ void Align(int i, int j) {
+ res_(i, j) = true;
+ is_i_aligned_[i] = true;
+ is_j_aligned_[j] = true;
+ }
+
+ bool IsNeighborAligned(int i, int j) const {
+ for (int k = 0; k < neighbors_.size(); ++k) {
+ const int di = neighbors_[k].first;
+ const int dj = neighbors_[k].second;
+ if (Safe(res_, i + di, j + dj))
+ return true;
+ }
+ return false;
+ }
+
+ bool IsNeitherAligned(int i, int j) const {
+ return !(is_i_aligned_[i] || is_j_aligned_[j]);
+ }
+
+ bool IsOneOrBothUnaligned(int i, int j) const {
+ return !(is_i_aligned_[i] && is_j_aligned_[j]);
+ }
+
+ bool KoehnAligned(int i, int j) const {
+ return IsOneOrBothUnaligned(i, j) && IsNeighborAligned(i, j);
+ }
+
+ typedef bool (RefineCommand::*Predicate)(int i, int j) const;
+
protected:
void InitRefine(
const Array2D<bool>& a,
- const Array2D<bool>& b,
- Array2D<bool>* x) {
- EnsureSize(a, b, x);
+ const Array2D<bool>& b) {
+ res_.clear();
+ EnsureSize(a, b, &res_);
in_.clear(); un_.clear(); is_i_aligned_.clear(); is_j_aligned_.clear();
EnsureSize(a, b, &in_);
EnsureSize(a, b, &un_);
- is_i_aligned_.resize(x->width(), false);
- is_j_aligned_.resize(x->height(), false);
+ is_i_aligned_.resize(res_.width(), false);
+ is_j_aligned_.resize(res_.height(), false);
for (int i = 0; i < in_.width(); ++i)
for (int j = 0; j < in_.height(); ++j) {
un_(i, j) = Safe(a, i, j) || Safe(b, i, j);
in_(i, j) = Safe(a, i, j) && Safe(b, i, j);
+ if (in_(i, j)) Align(i, j);
}
}
- // "grow" the intersection alignment with neighboring points
- // from the union alignment
- void Grow(Array2D<bool>* x) {
- Array2D<bool>& res = *x;
- queue<pair<int, int> > q;
- for (int i = 0; i < in_.width(); ++i)
- for (int j = 0; j < in_.height(); ++j)
- if (in_(i, j)) {
- Align(i, j, x);
- q.push(make_pair(i, j));
+ // "grow" the resulting alignment using the points in adds
+ // if they match the constraints determined by pred
+ void Grow(Predicate pred, bool idempotent, const Array2D<bool>& adds) {
+ if (idempotent) {
+ for (int i = 0; i < adds.width(); ++i)
+ for (int j = 0; j < adds.height(); ++j) {
+ if (adds(i, j) && !res_(i, j) &&
+ (this->*pred)(i, j)) Align(i, j);
}
- while(!q.empty()) {
- const pair<int,int> point = q.front();
- q.pop();
- for (int k = 0; k < neighbors_.size(); ++k) {
- const int test_i = neighbors_[k].first + point.first;
- const int test_j = neighbors_[k].second + point.second;
- if (Safe(un_, test_i, test_j) && !res(test_i, test_j)) {
- Align(test_i, test_j, x);
- q.push(make_pair(test_i, test_j));
+ return;
+ }
+ set<pair<int, int> > p;
+ for (int i = 0; i < adds.width(); ++i)
+ for (int j = 0; j < adds.height(); ++j)
+ if (adds(i, j) && !res_(i, j))
+ p.insert(make_pair(i, j));
+ bool keep_going = !p.empty();
+ while (keep_going) {
+ keep_going = false;
+ for (set<pair<int, int> >::iterator pi = p.begin();
+ pi != p.end(); ++pi) {
+ if ((this->*pred)(pi->first, pi->second)) {
+ Align(pi->first, pi->second);
+ p.erase(pi);
+ keep_going = true;
}
}
}
}
- void Final(bool do_and, Array2D<bool>* x) {
- }
- void Align(int i, int j, Array2D<bool>* x) {
- (*x)(i, j) = true;
- is_i_aligned_[i] = true;
- is_j_aligned_[j] = true;
- }
+ Array2D<bool> res_; // refined alignment
Array2D<bool> in_; // intersection alignment
Array2D<bool> un_; // union alignment
vector<bool> is_i_aligned_;
@@ -190,12 +222,34 @@ struct DiagCommand : public RefineCommand {
}
};
+struct GDCommand : public DiagCommand {
+ string Name() const { return "grow-diag"; }
+ void Apply(const Array2D<bool>& a, const Array2D<bool>& b, Array2D<bool>* x) {
+ InitRefine(a, b);
+ Grow(&RefineCommand::KoehnAligned, false, un_);
+ *x = res_;
+ }
+};
+
struct GDFCommand : public DiagCommand {
- string Name() const { return "gdf"; }
+ string Name() const { return "grow-diag-final"; }
+ void Apply(const Array2D<bool>& a, const Array2D<bool>& b, Array2D<bool>* x) {
+ InitRefine(a, b);
+ Grow(&RefineCommand::KoehnAligned, false, un_);
+ Grow(&RefineCommand::IsOneOrBothUnaligned, true, a);
+ Grow(&RefineCommand::IsOneOrBothUnaligned, true, b);
+ *x = res_;
+ }
+};
+
+struct GDFACommand : public DiagCommand {
+ string Name() const { return "grow-diag-final-and"; }
void Apply(const Array2D<bool>& a, const Array2D<bool>& b, Array2D<bool>* x) {
- InitRefine(a, b, x);
- Grow(x);
- Final(false, x);
+ InitRefine(a, b);
+ Grow(&RefineCommand::KoehnAligned, false, un_);
+ Grow(&RefineCommand::IsNeitherAligned, true, a);
+ Grow(&RefineCommand::IsNeitherAligned, true, b);
+ *x = res_;
}
};
@@ -258,7 +312,9 @@ int main(int argc, char **argv) {
AddCommand<InvertCommand>();
AddCommand<IntersectCommand>();
AddCommand<UnionCommand>();
+ AddCommand<GDCommand>();
AddCommand<GDFCommand>();
+ AddCommand<GDFACommand>();
AddCommand<FMeasureCommand>();
po::variables_map conf;
InitCommandLine(argc, argv, &conf);