summaryrefslogtreecommitdiff
path: root/extractor/backoff_sampler.cc
blob: 891276c638eab007adb10daa30daa029e732bc45 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include "backoff_sampler.h"

#include <utility>

#include "data_array.h"
#include "phrase_location.h"

namespace extractor {

// Stores the source data array, which Sample() uses to map sampled positions
// to sentence ids, and the maximum number of samples a single call to
// Sample() may append.
BackoffSampler::BackoffSampler(
    shared_ptr<DataArray> source_data_array, int max_samples) :
    // The parameter is taken by value, so move it into the member instead of
    // paying for a second reference-count increment/decrement.
    source_data_array(std::move(source_data_array)),
    max_samples(max_samples) {}

// Constructs a sampler without initializing any member explicitly.
// NOTE(review): max_samples is not set here; presumably this constructor only
// serves subclasses/test doubles that never reach the base Sample() — confirm
// against the header.
BackoffSampler::BackoffSampler() {}

// Draws up to max_samples samples spread evenly over the index range
// [GetRangeLow(location), GetRangeHigh(location)), skipping samples whose
// sentence id is in blacklisted_sentence_ids.
//
// The range is walked with a fractional stride `step` (at least 1). For each
// stride the index nearest to the stride start `i` is tried first. If it is
// not acceptable, the sampler backs off to round(i - d) and then round(i + d)
// for d = 1, 2, ... while d < step. A stride for which no acceptable index is
// found contributes nothing but still counts towards max_samples.
//
// A candidate index is acceptable when it is strictly greater than the last
// accepted index, strictly less than the range end, and its sentence id is not
// blacklisted. Accepted indices are therefore strictly increasing.
//
// Returns a PhraseLocation built from the appended samples and
// GetNumSubpatterns(location).
PhraseLocation BackoffSampler::Sample(
    const PhraseLocation& location,
    const unordered_set<int>& blacklisted_sentence_ids) const {
  vector<int> samples;
  int low = GetRangeLow(location), high = GetRangeHigh(location);
  // Index of the most recently accepted sample; starts just below the range.
  int last = low - 1;
  double step = max(1.0, static_cast<double>(high - low) / max_samples);

  // Range bounds are checked before any lookup, so no data array lookup is
  // made for an index outside (last, high). `last` is captured by reference
  // and always reflects the latest accepted sample.
  auto is_acceptable = [&](int sample) {
    if (sample <= last || sample >= high) {
      return false;
    }
    int position = GetPosition(location, sample);
    int sentence_id = source_data_array->GetSentenceId(position);
    return blacklisted_sentence_ids.count(sentence_id) == 0;
  };

  // `i` accumulates the stride start exactly as before (repeated addition),
  // so the rounded sample indices are unchanged.
  double i = low;
  for (int num_samples = 0; num_samples < max_samples && i < high;
       ++num_samples, i += step) {
    int sample = static_cast<int>(round(i));
    bool found = is_acceptable(sample);

    // Back off to neighbouring indices, nearest first; at each distance the
    // index below the stride start is preferred over the one above it.
    for (int backoff_step = 1; !found && backoff_step < step;
         ++backoff_step) {
      sample = static_cast<int>(round(i - backoff_step));
      found = is_acceptable(sample);
      if (!found) {
        sample = static_cast<int>(round(i + backoff_step));
        found = is_acceptable(sample);
      }
    }

    if (found) {
      last = sample;
      AppendMatching(samples, sample, location);
    }
  }

  return PhraseLocation(samples, GetNumSubpatterns(location));
}

} // namespace extractor