author    Patrick Simianer <p@simianer.de>  2011-09-25 21:43:57 +0200
committer Patrick Simianer <p@simianer.de>  2011-09-25 21:43:57 +0200
commit    4d8c300734c441821141f4bff044c439e004ff84 (patch)
tree      5b9e2b7f9994d9a71e0e2d17f33ba2ff4a1145a1
parent    fe471bb707226052551d75b043295ca5f57261c0 (diff)
kbest, ksampler refactoring
-rw-r--r-- dtrain/dtrain.cc               | 48
-rw-r--r-- dtrain/kbestget.h              | 55
-rw-r--r-- dtrain/ksampler.h              | 33
-rw-r--r-- dtrain/pairsampling.h          | 48
-rw-r--r-- dtrain/test/example/dtrain.ini |  8
5 files changed, 87 insertions, 105 deletions
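
The commit collapses the old Samples container (four parallel vectors plus a size accessor) into a single vector of per-hypothesis ScoredHyp records. A minimal sketch of the new access pattern, using the ScoredHyp fields defined in dtrain/kbestget.h below (the loop body is illustrative, not part of the commit):

    // Each hypothesis now travels as one record instead of four
    // parallel vector entries (sents/feats/model_scores/scores).
    vector<ScoredHyp>* samples = observer->GetSamples();
    for (size_t i = 0; i < samples->size(); i++) {
      const ScoredHyp& h = (*samples)[i];
      // h.w     : vector<WordID>        -- hypothesis yield
      // h.f     : SparseVector<double>  -- feature values
      // h.model : score_t               -- log model score
      // h.score : score_t               -- metric score, filled during scoring
    }
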
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index a70ca2f1..ad1ab7b7 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -10,11 +10,11 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("output", po::value<string>()->default_value("-"), "output weights file (or VOID)")
("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)")
("decoder_config", po::value<string>(), "configuration file for cdec")
- ("ksamples", po::value<size_t>()->default_value(100), "size of kbest or sample from forest")
+ ("k", po::value<size_t>()->default_value(100), "size of kbest or sample from forest")
("sample_from", po::value<string>()->default_value("kbest"), "where to get translations from")
("filter", po::value<string>()->default_value("unique"), "filter kbest list")
("pair_sampling", po::value<string>()->default_value("all"), "how to sample pairs: all, rand")
- ("ngrams", po::value<size_t>()->default_value(3), "N for Ngrams")
+ ("N", po::value<size_t>()->default_value(3), "N for Ngrams")
("epochs", po::value<size_t>()->default_value(2), "# of iterations T")
("scorer", po::value<string>()->default_value("stupid_bleu"), "scoring metric")
("stop_after", po::value<size_t>()->default_value(0), "stop after X input sentences")
@@ -75,8 +75,8 @@ main(int argc, char** argv)
hstreaming = true;
quiet = true;
}
- const size_t k = cfg["ksamples"].as<size_t>();
- const size_t N = cfg["ngrams"].as<size_t>();
+ const size_t k = cfg["k"].as<size_t>();
+ const size_t N = cfg["N"].as<size_t>();
const size_t T = cfg["epochs"].as<size_t>();
const size_t stop_after = cfg["stop_after"].as<size_t>();
const string filter_type = cfg["filter"].as<string>();
@@ -96,7 +96,7 @@ main(int argc, char** argv)
MT19937 rng; // random number generator
// setup decoder observer
- HypoSampler* observer;
+ HypSampler* observer;
if (sample_from == "kbest") {
observer = dynamic_cast<KBestGetter*>(new KBestGetter(k, filter_type));
} else {
@@ -274,45 +274,45 @@ main(int argc, char** argv)
decoder.Decode(src_str_buf[ii], observer);
}
- Samples* samples = observer->GetSamples();
+ vector<ScoredHyp>* samples = observer->GetSamples();
// (local) scoring
if (t > 0) ref_ids = ref_ids_buf[ii];
score_t score = 0.;
- for (size_t i = 0; i < samples->GetSize(); i++) {
- NgramCounts counts = make_ngram_counts(ref_ids, samples->sents[i], N);
+ for (size_t i = 0; i < samples->size(); i++) {
+ NgramCounts counts = make_ngram_counts(ref_ids, (*samples)[i].w, N);
if (scorer_str == "approx_bleu") {
size_t hyp_len = 0;
if (i == 0) { // 'context of 1best translations'
global_counts += counts;
- global_hyp_len += samples->sents[i].size();
+ global_hyp_len += (*samples)[i].w.size();
global_ref_len += ref_ids.size();
counts.reset();
} else {
- hyp_len = samples->sents[i].size();
+ hyp_len = (*samples)[i].w.size();
}
- NgramCounts counts_tmp = global_counts + counts;
- score = .9 * scorer(counts_tmp,
+ NgramCounts _c = global_counts + counts;
+ score = .9 * scorer(_c,
global_ref_len,
global_hyp_len + hyp_len, N, bleu_weights);
} else {
score = scorer(counts,
ref_ids.size(),
- samples->sents[i].size(), N, bleu_weights);
+ (*samples)[i].w.size(), N, bleu_weights);
}
- samples->scores.push_back(score);
+      (*samples)[i].score = score;
if (i == 0) {
score_sum += score;
- model_sum += samples->model_scores[i];
+ model_sum += (*samples)[i].model;
}
if (verbose) {
if (i == 0) cout << "'" << TD::GetString(ref_ids) << "' [ref]" << endl;
- cout << _p5 << _np << "[hyp " << i << "] " << "'" << TD::GetString(samples->sents[i]) << "'";
- cout << " [SCORE=" << score << ",model="<< samples->model_scores[i] << "]" << endl;
- cout << samples->feats[i] << endl;
+ cout << _p5 << _np << "[hyp " << i << "] " << "'" << TD::GetString((*samples)[i].w) << "'";
+ cout << " [SCORE=" << score << ",model="<< (*samples)[i].model << "]" << endl;
+ cout << (*samples)[i].f << endl;
}
} // sample/scoring loop
@@ -321,18 +321,18 @@ main(int argc, char** argv)
//////////////////////////////////////////////////////////
// UPDATE WEIGHTS
if (!noup) {
- vector<Pair> pairs;
+ vector<pair<ScoredHyp,ScoredHyp> > pairs;
if (pair_sampling == "all")
sample_all_pairs(samples, pairs);
if (pair_sampling == "rand")
sample_rand_pairs(samples, pairs, &rng);
- for (vector<Pair>::iterator ti = pairs.begin();
+ for (vector<pair<ScoredHyp,ScoredHyp> >::iterator ti = pairs.begin();
ti != pairs.end(); ti++) {
SparseVector<double> dv;
- if (ti->first_score - ti->second_score < 0) {
- dv = ti->second - ti->first;
+ if (ti->first.score - ti->second.score < 0) {
+ dv = ti->second.f - ti->first.f;
//} else {
//dv = ti->first - ti->second;
//}
@@ -344,14 +344,14 @@ main(int argc, char** argv)
lambdas += dv * eta;
if (verbose) {
- cout << "{{ f("<< ti->first_rank <<") > f(" << ti->second_rank << ") but g(i)="<< ti->first_score <<" < g(j)="<< ti->second_score << " so update" << endl;
+ /*cout << "{{ f("<< ti->first_rank <<") > f(" << ti->second_rank << ") but g(i)="<< ti->first_score <<" < g(j)="<< ti->second_score << " so update" << endl;
cout << " i " << TD::GetString(samples->sents[ti->first_rank]) << endl;
cout << " " << samples->feats[ti->first_rank] << endl;
cout << " j " << TD::GetString(samples->sents[ti->second_rank]) << endl;
cout << " " << samples->feats[ti->second_rank] << endl;
cout << " diff vec: " << dv << endl;
cout << " lambdas after update: " << lambdas << endl;
- cout << "}}" << endl;
+ cout << "}}" << endl;*/
}
} else {
//SparseVector<double> reg;
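
For reference, the weight update that the rewritten loop above performs is a perceptron-style step on each sampled pair: when the metric prefers the lower-ranked hypothesis, the weights move toward its features. A sketch under the types used in the diff (ScoredHyp, SparseVector<double>; eta and lambdas as in dtrain.cc):

    // Pairwise update as in the hunk above: 'first' is ranked above
    // 'second' by the model; update only when the metric disagrees.
    void update_on_pair(const ScoredHyp& first, const ScoredHyp& second,
                        SparseVector<double>& lambdas, double eta)
    {
      if (first.score - second.score < 0) {       // metric: second is better
        SparseVector<double> dv = second.f - first.f;
        lambdas += dv * eta;                      // move toward the better hyp
      }
    }
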
diff --git a/dtrain/kbestget.h b/dtrain/kbestget.h
index 79201182..403384de 100644
--- a/dtrain/kbestget.h
+++ b/dtrain/kbestget.h
@@ -7,28 +7,27 @@ namespace dtrain
{
-struct Samples
+struct ScoredHyp
{
- vector<SparseVector<double> > feats;
- vector<vector<WordID> > sents;
- vector<double> model_scores;
- vector<double> scores;
- size_t GetSize() { return sents.size(); }
+ vector<WordID> w;
+ SparseVector<double> f;
+ score_t model;
+ score_t score;
};
-struct HypoSampler : public DecoderObserver
+struct HypSampler : public DecoderObserver
{
- virtual Samples* GetSamples() {}
+  virtual vector<ScoredHyp>* GetSamples() { return NULL; }
};
-struct KBestGetter : public HypoSampler
+struct KBestGetter : public HypSampler
{
const size_t k_;
- const string filter_type;
- Samples s;
+ const string filter_type_;
+ vector<ScoredHyp> s_;
KBestGetter(const size_t k, const string filter_type) :
- k_(k), filter_type(filter_type) {}
+ k_(k), filter_type_(filter_type) {}
virtual void
NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg)
@@ -36,14 +35,14 @@ struct KBestGetter : public HypoSampler
KBest(*hg);
}
- Samples* GetSamples() { return &s; }
+ vector<ScoredHyp>* GetSamples() { return &s_; }
void
KBest(const Hypergraph& forest)
{
- if (filter_type == "unique") {
+ if (filter_type_ == "unique") {
KBestUnique(forest);
- } else if (filter_type == "no") {
+ } else if (filter_type_ == "no") {
KBestNoFilter(forest);
}
}
@@ -51,36 +50,34 @@ struct KBestGetter : public HypoSampler
void
KBestUnique(const Hypergraph& forest)
{
- s.sents.clear();
- s.feats.clear();
- s.model_scores.clear();
- s.scores.clear();
+ s_.clear();
KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique, prob_t, EdgeProb> kbest(forest, k_);
for (size_t i = 0; i < k_; ++i) {
const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d =
kbest.LazyKthBest(forest.nodes_.size() - 1, i);
if (!d) break;
- s.sents.push_back(d->yield);
- s.feats.push_back(d->feature_values);
- s.model_scores.push_back(log(d->score));
+ ScoredHyp h;
+ h.w = d->yield;
+ h.f = d->feature_values;
+ h.model = log(d->score);
+ s_.push_back(h);
}
}
void
KBestNoFilter(const Hypergraph& forest)
{
- s.sents.clear();
- s.feats.clear();
- s.model_scores.clear();
- s.scores.clear();
+ s_.clear();
KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k_);
for (size_t i = 0; i < k_; ++i) {
const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
kbest.LazyKthBest(forest.nodes_.size() - 1, i);
if (!d) break;
- s.sents.push_back(d->yield);
- s.feats.push_back(d->feature_values);
- s.model_scores.push_back(log(d->score));
+ ScoredHyp h;
+ h.w = d->yield;
+ h.f = d->feature_values;
+ h.model = log(d->score);
+ s_.push_back(h);
}
}
};
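
Since both samplers now implement the common HypSampler interface, the call site in dtrain.cc (first file above) can pick one at runtime. A simplified sketch, where decoder, src, and rng stand in for the surrounding state in main():

    // Runtime selection of the observer, mirroring dtrain.cc; the names
    // 'decoder', 'src', and 'rng' are placeholders for main()'s state.
    HypSampler* observer;
    if (sample_from == "kbest")
      observer = new KBestGetter(k, filter_type); // k-best list (this file)
    else
      observer = new KSampler(k, &rng);           // forest samples (ksampler.h)
    decoder.Decode(src, observer);
    vector<ScoredHyp>* samples = observer->GetSamples();
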
diff --git a/dtrain/ksampler.h b/dtrain/ksampler.h
index ac88b643..bbe2b402 100644
--- a/dtrain/ksampler.h
+++ b/dtrain/ksampler.h
@@ -13,34 +13,33 @@ namespace dtrain
* KSampler
*
*/
-struct KSampler : public HypoSampler
+struct KSampler : public HypSampler
{
const size_t k_;
- Samples s;
- MT19937* rng;
+ vector<ScoredHyp> s_;
+ MT19937* prng_;
- explicit KSampler( const size_t k, MT19937* prng ) :
- k_(k), rng(prng) {}
+ explicit KSampler(const size_t k, MT19937* prng) :
+ k_(k), prng_(prng) {}
virtual void
- NotifyTranslationForest( const SentenceMetadata& smeta, Hypergraph* hg )
+ NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg)
{
- Sample( *hg );
+ Sample(*hg);
}
- Samples* GetSamples() { return &s; }
+ vector<ScoredHyp>* GetSamples() { return &s_; }
- void Sample( const Hypergraph& forest ) {
- s.sents.clear();
- s.feats.clear();
- s.model_scores.clear();
- s.scores.clear();
+ void Sample(const Hypergraph& forest) {
+ s_.clear();
std::vector<HypergraphSampler::Hypothesis> samples;
- HypergraphSampler::sample_hypotheses(forest, k_, rng, &samples);
+ HypergraphSampler::sample_hypotheses(forest, k_, prng_, &samples);
for ( size_t i = 0; i < k_; ++i ) {
- s.sents.push_back( samples[i].words );
- s.feats.push_back( samples[i].fmap );
- s.model_scores.push_back( log(samples[i].model_score) );
+ ScoredHyp h;
+ h.w = samples[i].words;
+ h.f = samples[i].fmap;
+ h.model = log(samples[i].model_score);
+ s_.push_back(h);
}
}
};
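
One hedged note on Sample() above: the conversion loop runs to k_ and indexes samples[i], which assumes HypergraphSampler::sample_hypotheses always returns exactly k_ hypotheses. A defensive variant would bound the loop by the vector actually returned:

    // Defensive rewrite of the loop in Sample(), in case fewer than k_
    // hypotheses come back; otherwise identical to the committed code.
    for (size_t i = 0; i < samples.size(); ++i) {
      ScoredHyp h;
      h.w = samples[i].words;
      h.f = samples[i].fmap;
      h.model = log(samples[i].model_score);
      s_.push_back(h);
    }
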
diff --git a/dtrain/pairsampling.h b/dtrain/pairsampling.h
index a8521485..2e4ab155 100644
--- a/dtrain/pairsampling.h
+++ b/dtrain/pairsampling.h
@@ -8,47 +8,33 @@ namespace dtrain
{
-struct Pair
-{
- SparseVector<double> first, second;
- size_t first_rank, second_rank;
- double first_score, second_score;
-};
-
inline void
-sample_all_pairs(Samples* kb, vector<Pair> &training)
+sample_all_pairs(vector<ScoredHyp>* s, vector<pair<ScoredHyp,ScoredHyp> > &training)
{
- for (size_t i = 0; i < kb->GetSize()-1; i++) {
- for (size_t j = i+1; j < kb->GetSize(); j++) {
- Pair p;
- p.first = kb->feats[i];
- p.second = kb->feats[j];
- p.first_rank = i;
- p.second_rank = j;
- p.first_score = kb->scores[i];
- p.second_score = kb->scores[j];
+ for (size_t i = 0; i < s->size()-1; i++) {
+ for (size_t j = i+1; j < s->size(); j++) {
+ pair<ScoredHyp,ScoredHyp> p;
+ p.first = (*s)[i];
+ p.second = (*s)[j];
training.push_back(p);
- } // j
- } // i
+ }
+ }
}
inline void
-sample_rand_pairs(Samples* kb, vector<Pair> &training, MT19937* prng)
+sample_rand_pairs(vector<ScoredHyp>* s, vector<pair<ScoredHyp,ScoredHyp> > &training,
+ MT19937* prng)
{
- for (size_t i = 0; i < kb->GetSize()-1; i++) {
- for (size_t j = i+1; j < kb->GetSize(); j++) {
+ for (size_t i = 0; i < s->size()-1; i++) {
+ for (size_t j = i+1; j < s->size(); j++) {
if (prng->next() < .5) {
- Pair p;
- p.first = kb->feats[i];
- p.second = kb->feats[j];
- p.first_rank = i;
- p.second_rank = j;
- p.first_score = kb->scores[i];
- p.second_score = kb->scores[j];
+ pair<ScoredHyp,ScoredHyp> p;
+ p.first = (*s)[i];
+ p.second = (*s)[j];
training.push_back(p);
}
- } // j
- } // i
+ }
+ }
}
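
The two samplers differ only in how many pairs they emit: sample_all_pairs produces every rank-ordered pair, k*(k-1)/2 of them, while sample_rand_pairs keeps each pair with probability .5 (prng->next() < .5), so roughly half as many in expectation. For the example config's k=100 that is 4950 versus about 2475 pairs per sentence:

    // Pair counts per sentence for k sampled hypotheses:
    //   all  : k*(k-1)/2            (k=100 -> 4950)
    //   rand : ~k*(k-1)/4 expected  (k=100 -> ~2475)
    size_t num_all_pairs(size_t k) { return k * (k - 1) / 2; }
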
diff --git a/dtrain/test/example/dtrain.ini b/dtrain/test/example/dtrain.ini
index aee3c89e..00ba72f9 100644
--- a/dtrain/test/example/dtrain.ini
+++ b/dtrain/test/example/dtrain.ini
@@ -1,11 +1,11 @@
decoder_config=test/example/cdec.ini
-ksamples=100
-ngrams=3
+k=100
+N=3
epochs=1000
input=test/example/nc-1k.gz
scorer=stupid_bleu
output=test/example/weights.gz
-stop_after=10
-sample_from=kbest
+stop_after=100
+sample_from=forest
pair_sampling=all
print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 PhraseModel_1 PhraseModel_2 PhraseModel_3 PhraseModel_4 PassThrough