diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/Makefile.am | 30 | ||||
| -rw-r--r-- | training/candidate_set.cc | 168 | ||||
| -rw-r--r-- | training/candidate_set.h | 60 | ||||
| -rw-r--r-- | training/mpi_flex_optimize.cc | 10 | 
4 files changed, 250 insertions, 18 deletions
| diff --git a/training/Makefile.am b/training/Makefile.am index 991ac210..8124b107 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -23,11 +23,17 @@ noinst_PROGRAMS = \  TESTS = lbfgs_test optimize_test -mpi_online_optimize_SOURCES = mpi_online_optimize.cc online_optimizer.cc -mpi_online_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +noinst_LIBRARIES = libtraining.a +libtraining_a_SOURCES = \ +  candidate_set.cc \ +  optimize.cc \ +  online_optimizer.cc -mpi_flex_optimize_SOURCES = mpi_flex_optimize.cc online_optimizer.cc optimize.cc -mpi_flex_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_online_optimize_SOURCES = mpi_online_optimize.cc +mpi_online_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz + +mpi_flex_optimize_SOURCES = mpi_flex_optimize.cc +mpi_flex_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz  mpi_extract_reachable_SOURCES = mpi_extract_reachable.cc  mpi_extract_reachable_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz @@ -35,8 +41,8 @@ mpi_extract_reachable_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mtev  mpi_extract_features_SOURCES = mpi_extract_features.cc  mpi_extract_features_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz -mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc optimize.cc -mpi_batch_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc +mpi_batch_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz  mpi_compute_cllh_SOURCES = mpi_compute_cllh.cc  mpi_compute_cllh_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz @@ -50,14 +56,14 @@ test_ngram_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteva  model1_SOURCES = model1.cc ttables.cc  model1_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz -lbl_model_SOURCES = lbl_model.cc optimize.cc -lbl_model_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +lbl_model_SOURCES = lbl_model.cc +lbl_model_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz  grammar_convert_SOURCES = grammar_convert.cc  grammar_convert_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz -optimize_test_SOURCES = optimize_test.cc optimize.cc online_optimizer.cc -optimize_test_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +optimize_test_SOURCES = optimize_test.cc +optimize_test_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz  collapse_weights_SOURCES = collapse_weights.cc  collapse_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz @@ -65,8 +71,8 @@ collapse_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/lib  lbfgs_test_SOURCES = lbfgs_test.cc  lbfgs_test_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz -mr_optimize_reduce_SOURCES = mr_optimize_reduce.cc optimize.cc -mr_optimize_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +mr_optimize_reduce_SOURCES = mr_optimize_reduce.cc +mr_optimize_reduce_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz  mr_em_map_adapter_SOURCES = mr_em_map_adapter.cc  mr_em_map_adapter_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz diff --git a/training/candidate_set.cc b/training/candidate_set.cc new file mode 100644 index 00000000..8c086ece --- /dev/null +++ b/training/candidate_set.cc @@ -0,0 +1,168 @@ +#include "candidate_set.h" + +#include <tr1/unordered_set> + +#include <boost/functional/hash.hpp> + +#include "ns.h" +#include "filelib.h" +#include "wordid.h" +#include "tdict.h" +#include "hg.h" +#include "kbest.h" +#include "viterbi.h" + +using namespace std; + +namespace training { + +struct ApproxVectorHasher { +  static const size_t MASK = 0xFFFFFFFFull; +  union UType { +    double f;   // leave as double +    size_t i; +  }; +  static inline double round(const double x) { +    UType t; +    t.f = x; +    size_t r = t.i & MASK; +    if ((r << 1) > MASK) +      t.i += MASK - r + 1; +    else +      t.i &= (1ull - MASK); +    return t.f; +  } +  size_t operator()(const SparseVector<double>& x) const { +    size_t h = 0x573915839; +    for (SparseVector<double>::const_iterator it = x.begin(); it != x.end(); ++it) { +      UType t; +      t.f = it->second; +      if (t.f) { +        size_t z = (t.i >> 32); +        boost::hash_combine(h, it->first); +        boost::hash_combine(h, z); +      } +    } +    return h; +  } +}; + +struct ApproxVectorEquals { +  bool operator()(const SparseVector<double>& a, const SparseVector<double>& b) const { +    SparseVector<double>::const_iterator bit = b.begin(); +    for (SparseVector<double>::const_iterator ait = a.begin(); ait != a.end(); ++ait) { +      if (bit == b.end() || +          ait->first != bit->first || +          ApproxVectorHasher::round(ait->second) != ApproxVectorHasher::round(bit->second)) +        return false; +      ++bit; +    } +    if (bit != b.end()) return false; +    return true; +  } +}; + +struct CandidateCompare { +  bool operator()(const Candidate& a, const Candidate& b) const { +    ApproxVectorEquals eq; +    return (a.ewords == b.ewords && eq(a.fmap,b.fmap)); +  } +}; + +struct CandidateHasher { +  size_t operator()(const Candidate& x) const { +    boost::hash<vector<WordID> > hhasher; +    ApproxVectorHasher vhasher; +    size_t ha = hhasher(x.ewords); +    boost::hash_combine(ha, vhasher(x.fmap)); +    return ha; +  } +}; + +static void ParseSparseVector(string& line, size_t cur, SparseVector<double>* out) { +  SparseVector<double>& x = *out; +  size_t last_start = cur; +  size_t last_comma = string::npos; +  while(cur <= line.size()) { +    if (line[cur] == ' ' || cur == line.size()) { +      if (!(cur > last_start && last_comma != string::npos && cur > last_comma)) { +        cerr << "[ERROR] " << line << endl << "  position = " << cur << endl; +        exit(1); +      } +      const int fid = FD::Convert(line.substr(last_start, last_comma - last_start)); +      if (cur < line.size()) line[cur] = 0; +      const double val = strtod(&line[last_comma + 1], NULL); +      x.set_value(fid, val); + +      last_comma = string::npos; +      last_start = cur+1; +    } else { +      if (line[cur] == '=') +        last_comma = cur; +    } +    ++cur; +  } +} + +void CandidateSet::WriteToFile(const string& file) const { +  WriteFile wf(file); +  ostream& out = *wf.stream(); +  out.precision(10); +  string ss; +  for (unsigned i = 0; i < cs.size(); ++i) { +    out << TD::GetString(cs[i].ewords) << endl; +    out << cs[i].fmap << endl; +    cs[i].eval_feats.Encode(&ss); +    out << ss << endl; +  } +} + +void CandidateSet::ReadFromFile(const string& file) { +  cerr << "Reading candidates from " << file << endl; +  ReadFile rf(file); +  istream& in = *rf.stream(); +  string cand; +  string feats; +  string ss; +  while(getline(in, cand)) { +    getline(in, feats); +    getline(in, ss); +    assert(in); +    cs.push_back(Candidate()); +    TD::ConvertSentence(cand, &cs.back().ewords); +    ParseSparseVector(feats, 0, &cs.back().fmap); +    cs.back().eval_feats = SufficientStats(ss); +  } +  cerr << "  read " << cs.size() << " candidates\n"; +} + +void CandidateSet::Dedup() { +  cerr << "Dedup in=" << cs.size(); +  tr1::unordered_set<Candidate, CandidateHasher, CandidateCompare> u; +  while(cs.size() > 0) { +    u.insert(cs.back()); +    cs.pop_back(); +  } +  tr1::unordered_set<Candidate, CandidateHasher, CandidateCompare>::iterator it = u.begin(); +  while (it != u.end()) { +    cs.push_back(*it); +    it = u.erase(it); +  } +  cerr << "  out=" << cs.size() << endl; +} + +void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer) { +  KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg, kbest_size); + +  for (unsigned i = 0; i < kbest_size; ++i) { +    const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d = +      kbest.LazyKthBest(hg.nodes_.size() - 1, i); +    if (!d) break; +    cs.push_back(Candidate(d->yield, d->feature_values)); +    if (scorer) +      scorer->Evaluate(d->yield, &cs.back().eval_feats); +  } +  Dedup(); +} + +} diff --git a/training/candidate_set.h b/training/candidate_set.h new file mode 100644 index 00000000..9d326ed0 --- /dev/null +++ b/training/candidate_set.h @@ -0,0 +1,60 @@ +#ifndef _CANDIDATE_SET_H_ +#define _CANDIDATE_SET_H_ + +#include <vector> +#include <algorithm> + +#include "ns.h" +#include "wordid.h" +#include "sparse_vector.h" + +class Hypergraph; + +namespace training { + +struct Candidate { +  Candidate() {} +  Candidate(const std::vector<WordID>& e, const SparseVector<double>& fm) : +      ewords(e), +      fmap(fm) {} +  Candidate(const std::vector<WordID>& e, +            const SparseVector<double>& fm, +            const SegmentEvaluator& se) : +      ewords(e), +      fmap(fm) { +    se.Evaluate(ewords, &eval_feats); +  } + +  void swap(Candidate& other) { +    eval_feats.swap(other.eval_feats); +    ewords.swap(other.ewords); +    fmap.swap(other.fmap); +  } + +  std::vector<WordID> ewords; +  SparseVector<double> fmap; +  SufficientStats eval_feats; +}; + +// represents some kind of collection of translation candidates, e.g. +// aggregated k-best lists, sample lists, etc. +class CandidateSet { + public: +  CandidateSet() {} +  inline size_t size() const { return cs.size(); } +  const Candidate& operator[](size_t i) const { return cs[i]; } + +  void ReadFromFile(const std::string& file); +  void WriteToFile(const std::string& file) const; +  void AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer = NULL); +  // TODO add code to do unique k-best +  // TODO add code to draw k samples + + private: +  void Dedup(); +  std::vector<Candidate> cs; +}; + +} + +#endif diff --git a/training/mpi_flex_optimize.cc b/training/mpi_flex_optimize.cc index a9197208..a9ead018 100644 --- a/training/mpi_flex_optimize.cc +++ b/training/mpi_flex_optimize.cc @@ -179,18 +179,16 @@ double ApplyRegularizationTerms(const double C,                                  const double T,                                  const vector<double>& weights,                                  const vector<double>& prev_weights, -                                vector<double>* g) { -  assert(weights.size() == g->size()); +                                double* g) {    double reg = 0;    for (size_t i = 0; i < weights.size(); ++i) {      const double prev_w_i = (i < prev_weights.size() ? prev_weights[i] : 0.0);      const double& w_i = weights[i]; -    double& g_i = (*g)[i];      reg += C * w_i * w_i; -    g_i += 2 * C * w_i; +    g[i] += 2 * C * w_i;      reg += T * (w_i - prev_w_i) * (w_i - prev_w_i); -    g_i += 2 * T * (w_i - prev_w_i); +    g[i] += 2 * T * (w_i - prev_w_i);    }    return reg;  } @@ -365,7 +363,7 @@ int main(int argc, char** argv) {                                  time_series_strength, // * (iter == 0 ? 0.0 : 1.0),                                  cur_weights,                                  prev_weights, -                                &gg); +                                &gg[0]);            obj += r;            if (mi == 0 || mi == (minibatch_iterations - 1)) {              if (!mi) cerr << iter << ' '; else cerr << ' '; | 
