From d389d25b78e5c99366f49cdcaf788693f3c01c40 Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Wed, 27 Nov 2013 14:33:36 +0000 Subject: Unify sampling backoff strategy. --- extractor/backoff_sampler.cc | 66 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 extractor/backoff_sampler.cc (limited to 'extractor/backoff_sampler.cc') diff --git a/extractor/backoff_sampler.cc b/extractor/backoff_sampler.cc new file mode 100644 index 00000000..28b12909 --- /dev/null +++ b/extractor/backoff_sampler.cc @@ -0,0 +1,66 @@ +#include "backoff_sampler.h" + +#include "data_array.h" +#include "phrase_location.h" + +namespace extractor { + +BackoffSampler::BackoffSampler( + shared_ptr source_data_array, int max_samples) : + source_data_array(source_data_array), max_samples(max_samples) {} + +BackoffSampler::BackoffSampler() {} + +PhraseLocation BackoffSampler::Sample( + const PhraseLocation& location, + const unordered_set& blacklisted_sentence_ids) const { + vector samples; + int low = GetRangeLow(location), high = GetRangeHigh(location); + int last_position = low - 1; + double step = max(1.0, (double) (high - low) / max_samples); + for (double num_samples = 0, i = low; + num_samples < max_samples && i < high; + ++num_samples, i += step) { + int position = GetPosition(location, round(i)); + int sentence_id = source_data_array->GetSentenceId(position); + bool found = false; + if (last_position >= position || + blacklisted_sentence_ids.count(sentence_id)) { + for (double backoff_step = 1; backoff_step < step; ++backoff_step) { + double j = i - backoff_step; + if (round(j) >= 0) { + position = GetPosition(location, round(j)); + sentence_id = source_data_array->GetSentenceId(position); + if (position > last_position && + !blacklisted_sentence_ids.count(sentence_id)) { + found = true; + last_position = position; + break; + } + } + + double k = i + backoff_step; + if (round(k) < high) { + position = GetPosition(location, round(k)); + sentence_id = source_data_array->GetSentenceId(position); + if (!blacklisted_sentence_ids.count(sentence_id)) { + found = true; + last_position = position; + break; + } + } + } + } else { + found = true; + last_position = position; + } + + if (found) { + AppendMatching(samples, position, location); + } + } + + return PhraseLocation(samples, GetNumSubpatterns(location)); +} + +} // namespace extractor -- cgit v1.2.3 From 9e742b90007c32e9f0cec8940c73bd50e33b8182 Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Thu, 28 Nov 2013 01:39:14 +0000 Subject: Fixes. --- extractor/backoff_sampler.cc | 26 +++++++++++++------------- extractor/matchings_sampler.cc | 5 +++-- 2 files changed, 16 insertions(+), 15 deletions(-) (limited to 'extractor/backoff_sampler.cc') diff --git a/extractor/backoff_sampler.cc b/extractor/backoff_sampler.cc index 28b12909..891276c6 100644 --- a/extractor/backoff_sampler.cc +++ b/extractor/backoff_sampler.cc @@ -16,47 +16,47 @@ PhraseLocation BackoffSampler::Sample( const unordered_set& blacklisted_sentence_ids) const { vector samples; int low = GetRangeLow(location), high = GetRangeHigh(location); - int last_position = low - 1; + int last = low - 1; double step = max(1.0, (double) (high - low) / max_samples); for (double num_samples = 0, i = low; num_samples < max_samples && i < high; ++num_samples, i += step) { - int position = GetPosition(location, round(i)); + int sample = round(i); + int position = GetPosition(location, sample); int sentence_id = source_data_array->GetSentenceId(position); bool found = false; - if (last_position >= position || + if (last >= sample || blacklisted_sentence_ids.count(sentence_id)) { for (double backoff_step = 1; backoff_step < step; ++backoff_step) { double j = i - backoff_step; - if (round(j) >= 0) { - position = GetPosition(location, round(j)); + sample = round(j); + if (sample >= 0) { + position = GetPosition(location, sample); sentence_id = source_data_array->GetSentenceId(position); - if (position > last_position && - !blacklisted_sentence_ids.count(sentence_id)) { + if (sample > last && !blacklisted_sentence_ids.count(sentence_id)) { found = true; - last_position = position; break; } } double k = i + backoff_step; - if (round(k) < high) { - position = GetPosition(location, round(k)); + sample = round(k); + if (sample < high) { + position = GetPosition(location, sample); sentence_id = source_data_array->GetSentenceId(position); if (!blacklisted_sentence_ids.count(sentence_id)) { found = true; - last_position = position; break; } } } } else { found = true; - last_position = position; } if (found) { - AppendMatching(samples, position, location); + last = sample; + AppendMatching(samples, sample, location); } } diff --git a/extractor/matchings_sampler.cc b/extractor/matchings_sampler.cc index bb916e49..75a62366 100644 --- a/extractor/matchings_sampler.cc +++ b/extractor/matchings_sampler.cc @@ -30,8 +30,9 @@ int MatchingsSampler::GetPosition(const PhraseLocation& location, void MatchingsSampler::AppendMatching(vector& samples, int index, const PhraseLocation& location) const { - copy(location.matchings->begin() + index, - location.matchings->begin() + index + location.num_subpatterns, + int start = index * location.num_subpatterns; + copy(location.matchings->begin() + start, + location.matchings->begin() + start + location.num_subpatterns, back_inserter(samples)); } -- cgit v1.2.3