| author    | Chris Dyer <cdyer@cs.cmu.edu>                      | 2012-05-29 21:39:22 -0400 |
| committer | Chris Dyer <cdyer@cs.cmu.edu>                      | 2012-05-29 21:39:22 -0400 |
| commit    | 090a64e73f94a6a35e5364a9d416dcf75c0a2938 (patch)   |                           |
| tree      | cd65973c232103b6fba00d17653c9c343b5fa99e /training |                           |
| parent    | 7d1a63de4894de55f152bb806c85d42b745b9661 (diff)    |                           |
add support to rampion for accumulating k-best lists
Diffstat (limited to 'training')
| -rw-r--r-- | training/candidate_set.cc     |  6 |
| -rw-r--r-- | training/candidate_set.h      | 21 |
| -rw-r--r-- | training/mpi_flex_optimize.cc | 10 |
3 files changed, 22 insertions, 15 deletions
```diff
diff --git a/training/candidate_set.cc b/training/candidate_set.cc
index e2ca9ad2..8c086ece 100644
--- a/training/candidate_set.cc
+++ b/training/candidate_set.cc
@@ -112,7 +112,7 @@ void CandidateSet::WriteToFile(const string& file) const {
   for (unsigned i = 0; i < cs.size(); ++i) {
     out << TD::GetString(cs[i].ewords) << endl;
     out << cs[i].fmap << endl;
-    cs[i].score_stats.Encode(&ss);
+    cs[i].eval_feats.Encode(&ss);
     out << ss << endl;
   }
 }
@@ -131,7 +131,7 @@ void CandidateSet::ReadFromFile(const string& file) {
     cs.push_back(Candidate());
     TD::ConvertSentence(cand, &cs.back().ewords);
     ParseSparseVector(feats, 0, &cs.back().fmap);
-    cs.back().score_stats = SufficientStats(ss);
+    cs.back().eval_feats = SufficientStats(ss);
   }
   cerr << "  read " << cs.size() << " candidates\n";
 }
@@ -160,7 +160,7 @@ void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, c
     if (!d) break;
     cs.push_back(Candidate(d->yield, d->feature_values));
     if (scorer)
-      scorer->Evaluate(d->yield, &cs.back().score_stats);
+      scorer->Evaluate(d->yield, &cs.back().eval_feats);
   }
   Dedup();
 }
diff --git a/training/candidate_set.h b/training/candidate_set.h
index 824a4de2..9d326ed0 100644
--- a/training/candidate_set.h
+++ b/training/candidate_set.h
@@ -15,16 +15,25 @@ namespace training {
 struct Candidate {
   Candidate() {}
   Candidate(const std::vector<WordID>& e, const SparseVector<double>& fm) :
-    ewords(e),
-    fmap(fm) {}
-  std::vector<WordID> ewords;
-  SparseVector<double> fmap;
-  SufficientStats score_stats;
+      ewords(e),
+      fmap(fm) {}
+  Candidate(const std::vector<WordID>& e,
+            const SparseVector<double>& fm,
+            const SegmentEvaluator& se) :
+      ewords(e),
+      fmap(fm) {
+    se.Evaluate(ewords, &eval_feats);
+  }
+
   void swap(Candidate& other) {
-    score_stats.swap(other.score_stats);
+    eval_feats.swap(other.eval_feats);
     ewords.swap(other.ewords);
     fmap.swap(other.fmap);
   }
+
+  std::vector<WordID> ewords;
+  SparseVector<double> fmap;
+  SufficientStats eval_feats;
 };
 
 // represents some kind of collection of translation candidates, e.g.
diff --git a/training/mpi_flex_optimize.cc b/training/mpi_flex_optimize.cc
index a9197208..a9ead018 100644
--- a/training/mpi_flex_optimize.cc
+++ b/training/mpi_flex_optimize.cc
@@ -179,18 +179,16 @@ double ApplyRegularizationTerms(const double C,
                                 const double T,
                                 const vector<double>& weights,
                                 const vector<double>& prev_weights,
-                                vector<double>* g) {
-  assert(weights.size() == g->size());
+                                double* g) {
   double reg = 0;
   for (size_t i = 0; i < weights.size(); ++i) {
     const double prev_w_i = (i < prev_weights.size() ? prev_weights[i] : 0.0);
     const double& w_i = weights[i];
-    double& g_i = (*g)[i];
     reg += C * w_i * w_i;
-    g_i += 2 * C * w_i;
+    g[i] += 2 * C * w_i;
     reg += T * (w_i - prev_w_i) * (w_i - prev_w_i);
-    g_i += 2 * T * (w_i - prev_w_i);
+    g[i] += 2 * T * (w_i - prev_w_i);
   }
   return reg;
 }
@@ -365,7 +363,7 @@ int main(int argc, char** argv) {
                                 time_series_strength, // * (iter == 0 ? 0.0 : 1.0),
                                 cur_weights,
                                 prev_weights,
-                                &gg);
+                                &gg[0]);
           obj += r;
           if (mi == 0 || mi == (minibatch_iterations - 1)) {
             if (!mi) cerr << iter << ' '; else cerr << ' ';
```
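The candidate_set changes rename the per-candidate `score_stats` field to `eval_feats` and add a `Candidate` constructor that fills it in through a `SegmentEvaluator`, so a candidate set can be scored as it is accumulated and then written out and reloaded between iterations. Below is a minimal sketch of that accumulation loop, assuming the hypergraph `hg`, k-best size, evaluator pointer, and cache path come from the caller and that the cache file already exists from an earlier iteration; the helper `AccumulateKBest` is hypothetical, not code from the repository.

```cpp
#include <cstddef>
#include <string>

#include "candidate_set.h"  // training::CandidateSet (this repository's training/ directory)

// Sketch only: restore the candidates kept so far, merge the current
// hypergraph's k-best entries (AddKBestCandidates scores and dedups them),
// and save the accumulated set again for the next iteration.
void AccumulateKBest(const Hypergraph& hg,
                     std::size_t kbest_size,
                     const SegmentEvaluator* scorer,
                     const std::string& cache_path,
                     training::CandidateSet* cs) {
  cs->ReadFromFile(cache_path);                    // candidates from earlier iterations
  cs->AddKBestCandidates(hg, kbest_size, scorer);  // add this iteration's k-best, filling eval_feats
  cs->WriteToFile(cache_path);                     // persist the accumulated list
}
```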
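In `mpi_flex_optimize.cc`, `ApplyRegularizationTerms` now takes the gradient as a bare `double*` instead of a `vector<double>*`, and the call site passes `&gg[0]`, a pointer to the vector's contiguous storage. Since the old `assert(weights.size() == g->size())` is gone, the caller must now guarantee the buffer is non-empty and at least `weights.size()` long. The following is a minimal, self-contained sketch of that calling convention with made-up names and numbers, not the repository's code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Same shape as the revised regularizer: the gradient is updated through a
// raw pointer, so any contiguous double buffer of sufficient length works.
double AddL2Penalty(double C, const std::vector<double>& w, double* g) {
  double reg = 0;
  for (std::size_t i = 0; i < w.size(); ++i) {
    reg += C * w[i] * w[i];
    g[i] += 2 * C * w[i];  // caller guarantees g has at least w.size() slots
  }
  return reg;
}

int main() {
  std::vector<double> weights(3, 0.5), grad(3, 0.0);
  assert(!grad.empty() && grad.size() >= weights.size());
  // &grad[0] points at the vector's contiguous storage, matching the
  // &gg[0] idiom at the call site in the diff.
  const double reg = AddL2Penalty(1.0, weights, &grad[0]);
  return reg > 0 ? 0 : 1;
}
```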
