From c3ecdd0009b269a360b240a3a99fc6f8d8b117d3 Mon Sep 17 00:00:00 2001
From: Chris Dyer <redpony@gmail.com>
Date: Wed, 2 Apr 2014 02:35:19 -0400
Subject: speed up fast align by 20% or so

---
 word-aligner/fast_align.cc | 19 +++++++++---------
 word-aligner/ttables.h     | 50 +++++++++++++++++++++++-----------------------
 2 files changed, 35 insertions(+), 34 deletions(-)

(limited to 'word-aligner')

diff --git a/word-aligner/fast_align.cc b/word-aligner/fast_align.cc
index f54233eb..73b72399 100644
--- a/word-aligner/fast_align.cc
+++ b/word-aligner/fast_align.cc
@@ -190,6 +190,7 @@ int main(int argc, char** argv) {
                   cout << (max_index - 1) << '-' << j;
               }
             }
+            if (s2t_viterbi.size() <= static_cast<unsigned>(max_i)) s2t_viterbi.resize(max_i + 1);
             s2t_viterbi[max_i][f_j] = 1.0;
           }
         } else {
@@ -308,17 +309,17 @@ int main(int argc, char** argv) {
 
   if (output_parameters) {
     WriteFile params_out(conf["output_parameters"].as<string>());
-    for (TTable::Word2Word2Double::iterator ei = s2t.ttable.begin(); ei != s2t.ttable.end(); ++ei) {
-      const TTable::Word2Double& cpd = ei->second;
-      const TTable::Word2Double& vit = s2t_viterbi[ei->first];
-      const string& esym = TD::Convert(ei->first);
+    for (unsigned eind = 1; eind < s2t.ttable.size(); ++eind) {
+      const auto& cpd = s2t.ttable[eind];
+      const TTable::Word2Double& vit = s2t_viterbi[eind];
+      const string& esym = TD::Convert(eind);
       double max_p = -1;
-      for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi)
-        if (fi->second > max_p) max_p = fi->second;
+      for (auto& fi : cpd)
+        if (fi.second > max_p) max_p = fi.second;
       const double threshold = max_p * BEAM_THRESHOLD;
-      for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) {
-        if (fi->second > threshold || (vit.find(fi->first) != vit.end())) {
-          *params_out << esym << ' ' << TD::Convert(fi->first) << ' ' << log(fi->second) << endl;
+      for (auto& fi : cpd) {
+        if (fi.second > threshold || (vit.find(fi.first) != vit.end())) {
+          *params_out << esym << ' ' << TD::Convert(fi.first) << ' ' << log(fi.second) << endl;
         }
       } 
     }
diff --git a/word-aligner/ttables.h b/word-aligner/ttables.h
index d82aff72..b9964225 100644
--- a/word-aligner/ttables.h
+++ b/word-aligner/ttables.h
@@ -2,6 +2,7 @@
 #define _TTABLES_H_
 
 #include <iostream>
+#include <vector>
 #ifndef HAVE_OLD_CPP
 # include <unordered_map>
 #else
@@ -18,11 +19,10 @@ class TTable {
  public:
   TTable() {}
   typedef std::unordered_map<WordID, double> Word2Double;
-  typedef std::unordered_map<WordID, Word2Double> Word2Word2Double;
+  typedef std::vector<Word2Double> Word2Word2Double;
   inline double prob(const int& e, const int& f) const {
-    const Word2Word2Double::const_iterator cit = ttable.find(e);
-    if (cit != ttable.end()) {
-      const Word2Double& cpd = cit->second;
+    if (e < static_cast<int>(ttable.size())) {
+      const Word2Double& cpd = ttable[e];
       const Word2Double::const_iterator it = cpd.find(f);
       if (it == cpd.end()) return 1e-9;
       return it->second;
@@ -31,19 +31,21 @@ class TTable {
     }
   }
   inline void Increment(const int& e, const int& f) {
+    if (e >= static_cast<int>(ttable.size())) counts.resize(e + 1);
     counts[e][f] += 1.0;
   }
   inline void Increment(const int& e, const int& f, double x) {
+    if (e >= static_cast<int>(counts.size())) counts.resize(e + 1);
     counts[e][f] += x;
   }
   void NormalizeVB(const double alpha) {
     ttable.swap(counts);
-    for (Word2Word2Double::iterator cit = ttable.begin();
-         cit != ttable.end(); ++cit) {
+    for (unsigned i = 0; i < ttable.size(); ++i) {
       double tot = 0;
-      Word2Double& cpd = cit->second;
+      Word2Double& cpd = ttable[i];
       for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
         tot += it->second + alpha;
+      if (!tot) tot = 1;
       for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
         it->second = exp(Md::digamma(it->second + alpha) - Md::digamma(tot));
     }
@@ -51,12 +53,12 @@ class TTable {
   }
   void Normalize() {
     ttable.swap(counts);
-    for (Word2Word2Double::iterator cit = ttable.begin();
-         cit != ttable.end(); ++cit) {
+    for (unsigned i = 0; i < ttable.size(); ++i) {
       double tot = 0;
-      Word2Double& cpd = cit->second;
+      Word2Double& cpd = ttable[i];
       for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
         tot += it->second;
+      if (!tot) tot = 1;
       for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
         it->second /= tot;
     }
@@ -64,29 +66,27 @@ class TTable {
   }
   // adds counts from another TTable - probabilities remain unchanged
   TTable& operator+=(const TTable& rhs) {
-    for (Word2Word2Double::const_iterator it = rhs.counts.begin();
-         it != rhs.counts.end(); ++it) {
-      const Word2Double& cpd = it->second;
-      Word2Double& tgt = counts[it->first];
-      for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) {
-        tgt[j->first] += j->second;
-      }
+    if (rhs.counts.size() > counts.size()) counts.resize(rhs.counts.size());
+    for (unsigned i = 0; i < rhs.counts.size(); ++i) {
+      const Word2Double& cpd = rhs.counts[i];
+      Word2Double& tgt = counts[i];
+      for (auto p : cpd) tgt[p.first] += p.second;
     }
     return *this;
   }
   void ShowTTable() const {
-    for (Word2Word2Double::const_iterator it = ttable.begin(); it != ttable.end(); ++it) {
-      const Word2Double& cpd = it->second;
-      for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) {
-        std::cerr << "P(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl;
+    for (unsigned it = 0; it < ttable.size(); ++it) {
+      const Word2Double& cpd = ttable[it];
+      for (auto& p : cpd) {
+        std::cerr << "c(" << TD::Convert(p.first) << '|' << TD::Convert(it) << ") = " << p.second << std::endl;
       }
     }
   }
   void ShowCounts() const {
-    for (Word2Word2Double::const_iterator it = counts.begin(); it != counts.end(); ++it) {
-      const Word2Double& cpd = it->second;
-      for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) {
-        std::cerr << "c(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl;
+    for (unsigned it = 0; it < counts.size(); ++it) {
+      const Word2Double& cpd = counts[it];
+      for (auto& p : cpd) {
+        std::cerr << "c(" << TD::Convert(p.first) << '|' << TD::Convert(it) << ") = " << p.second << std::endl;
       }
     }
   }
-- 
cgit v1.2.3