summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--extractor/backoff_sampler.cc26
-rw-r--r--extractor/matchings_sampler.cc5
2 files changed, 16 insertions, 15 deletions
diff --git a/extractor/backoff_sampler.cc b/extractor/backoff_sampler.cc
index 28b12909..891276c6 100644
--- a/extractor/backoff_sampler.cc
+++ b/extractor/backoff_sampler.cc
@@ -16,47 +16,47 @@ PhraseLocation BackoffSampler::Sample(
const unordered_set<int>& blacklisted_sentence_ids) const {
vector<int> samples;
int low = GetRangeLow(location), high = GetRangeHigh(location);
- int last_position = low - 1;
+ int last = low - 1;
double step = max(1.0, (double) (high - low) / max_samples);
for (double num_samples = 0, i = low;
num_samples < max_samples && i < high;
++num_samples, i += step) {
- int position = GetPosition(location, round(i));
+ int sample = round(i);
+ int position = GetPosition(location, sample);
int sentence_id = source_data_array->GetSentenceId(position);
bool found = false;
- if (last_position >= position ||
+ if (last >= sample ||
blacklisted_sentence_ids.count(sentence_id)) {
for (double backoff_step = 1; backoff_step < step; ++backoff_step) {
double j = i - backoff_step;
- if (round(j) >= 0) {
- position = GetPosition(location, round(j));
+ sample = round(j);
+ if (sample >= 0) {
+ position = GetPosition(location, sample);
sentence_id = source_data_array->GetSentenceId(position);
- if (position > last_position &&
- !blacklisted_sentence_ids.count(sentence_id)) {
+ if (sample > last && !blacklisted_sentence_ids.count(sentence_id)) {
found = true;
- last_position = position;
break;
}
}
double k = i + backoff_step;
- if (round(k) < high) {
- position = GetPosition(location, round(k));
+ sample = round(k);
+ if (sample < high) {
+ position = GetPosition(location, sample);
sentence_id = source_data_array->GetSentenceId(position);
if (!blacklisted_sentence_ids.count(sentence_id)) {
found = true;
- last_position = position;
break;
}
}
}
} else {
found = true;
- last_position = position;
}
if (found) {
- AppendMatching(samples, position, location);
+ last = sample;
+ AppendMatching(samples, sample, location);
}
}
diff --git a/extractor/matchings_sampler.cc b/extractor/matchings_sampler.cc
index bb916e49..75a62366 100644
--- a/extractor/matchings_sampler.cc
+++ b/extractor/matchings_sampler.cc
@@ -30,8 +30,9 @@ int MatchingsSampler::GetPosition(const PhraseLocation& location,
void MatchingsSampler::AppendMatching(vector<int>& samples, int index,
const PhraseLocation& location) const {
- copy(location.matchings->begin() + index,
- location.matchings->begin() + index + location.num_subpatterns,
+ int start = index * location.num_subpatterns;
+ copy(location.matchings->begin() + start,
+ location.matchings->begin() + start + location.num_subpatterns,
back_inserter(samples));
}