# lib/nlp_ruby/bleu.rb

module BLEU


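# Per-order BLEU sufficient statistics: total and clipped n-gram counts for
# orders 1..n, plus accumulated reference and hypothesis token lengths.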
class BLEU::NgramCounts
  attr_accessor :sum, :clipped, :ref_len, :hyp_len, :n

  def initialize(n)
    @n = 0
    @sum = []
    @clipped = []
    @ref_len = 0.0
    @hyp_len = 0.0
    grow(n)
  end

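  # Extends the count arrays so that statistics up to order n can be stored.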
  def grow(n)
    (n-@n).times {
      @sum << 0.0
      @clipped << 0.0
    }
    @n = n
  end

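  # Accumulates another NgramCounts into this one, growing to its order if needed.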
  def plus_eq(other)
    if other.n > @n then grow(other.n) end
    0.upto(other.n-1) { |m|
      @sum[m] += other.sum[m]
      @clipped[m] += other.clipped[m]
    }
    @ref_len += other.ref_len
    @hyp_len += other.hyp_len
  end

  def to_s
    return "n=#{n} sum=#{sum} clipped=#{clipped} ref_len=#{ref_len} hyp_len=#{hyp_len}"
  end
end

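# Counts n-gram occurrences, keyed internally by the space-joined token string.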
class BLEU::Ngrams
  def initialize
    @h_ = {}
    @h_.default = 0
  end

  def add(k)
    if k.class == Array then k = k.join ' ' end
    @h_[k] += 1
  end

  def get_count(k)
    if k.class == Array then k = k.join ' ' end
    return @h_[k]
  end

  def each
    @h_.each_pair { |k,v|
      yield k.split, v
    }
  end

  def to_s
    @h_.to_s
  end
end

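# Collects sentence-level BLEU statistics (total and reference-clipped n-gram
# counts per order, plus token lengths) for one hypothesis/reference pair;
# `times` scales all statistics. The `ngrams` and `tokenize` helpers are
# assumed to be provided elsewhere in nlp_ruby.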
def BLEU::get_counts hypothesis, reference, n, times=1
  p = NgramCounts.new n
  r = Ngrams.new
  ngrams(reference, n) { |ng| r.add ng }
  h = Ngrams.new
  ngrams(hypothesis, n) { |ng| h.add ng }
  h.each { |ng,count|
    sz = ng.size-1
    p.sum[sz] += count * times
    p.clipped[sz] += [r.get_count(ng), count].min * times
  }
  p.ref_len = tokenize(reference.strip).size * times
  p.hyp_len = tokenize(hypothesis.strip).size * times
  return p
end

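# Brevity penalty: 1.0 if the hypothesis length c exceeds the reference
# length r, otherwise e^(1 - r/c).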
def BLEU::brevity_penalty(c, r)
  if c > r then return 1.0 end
  return Math.exp(1.0 - r.to_f/c)
end

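# Corpus-level BLEU over an array of per-sentence NgramCounts: the brevity
# penalty times the geometric mean (uniform weights 1/n) of the modified
# n-gram precisions; returns 0.0 if any order has no clipped matches.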
def BLEU::bleu(counts, n, debug=false)
  corpus_stats = NgramCounts.new n
  counts.each { |i| corpus_stats.plus_eq i }
  sum = 0.0
  w = 1.0/n
  0.upto(n-1) { |m|
    STDERR.write "#{m+1} #{corpus_stats.clipped[m]} / #{corpus_stats.sum[m]}\n" if debug
    return 0.0 if corpus_stats.clipped[m] == 0 or corpus_stats.sum[m] == 0
    sum += w * Math.log(corpus_stats.clipped[m] / corpus_stats.sum[m])
  }
  if debug
    STDERR.write "BP #{brevity_penalty(corpus_stats.hyp_len, corpus_stats.ref_len)}\n"
    STDERR.write "sum #{Math.exp(sum)}\n"
  end
  return brevity_penalty(corpus_stats.hyp_len, corpus_stats.ref_len) * Math.exp(sum)
end

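# BLEU scaled to the 0..100 range, rounded to three decimal places.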
def BLEU::hbleu counts, n, debug=false
  (100*bleu(counts, n, debug)).round(3)
end


end
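
# Usage sketch (illustrative only, not part of the library's documented
# interface): assumes the rest of nlp_ruby is loaded so that the `ngrams`
# and `tokenize` helpers used by get_counts are in scope; the sentences
# below are made up.
#
#   hyps = ['the cat sat on the mat']
#   refs = ['the cat is on the mat']
#   counts = hyps.zip(refs).map { |hyp,ref| BLEU::get_counts(hyp, ref, 2) }
#   puts BLEU::bleu(counts, 2)   # score in [0,1]
#   puts BLEU::hbleu(counts, 2)  # same score scaled to 0..100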