summaryrefslogtreecommitdiff
path: root/perceptron/perceptron.rb
blob: 4b9f2fa857e7ee6a2517603d3cd992b08c5f4e78 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
#!/usr/bin/env ruby

require 'zlib'

# Standard output is written as UTF-8 and unbuffered (sync = true), so every
# puts reaches the terminal immediately.
STDOUT.tap do |out|
  out.set_encoding 'utf-8'
  out.sync = true
end


# Yields the word n-grams of the whitespace-tokenized string +s+.
# N-grams are ordered by start position first, then by increasing length.
#
# @param s [String] input text; split on whitespace
# @param n [Integer] maximum n-gram length
# @param fix [Boolean] when true, yield only n-grams of exactly length +n+;
#   when false (default), yield every n-gram of length 1..n
# @yieldparam gram [Array<String>] consecutive tokens
# @return [Array<String>] the token list when a block is given,
#   otherwise an Enumerator over the n-grams
def ngrams_it(s, n, fix=false)
  return enum_for(__method__, s, n, fix) unless block_given?

  toks = s.split
  toks.each_index do |i|
    # An n-gram starting at i cannot run past the end of the token list.
    max_len = [n, toks.size - i].min
    (1..max_len).each do |len|
      next if fix && len != n
      yield toks[i, len]
    end
  end
end

# A sparse vector keyed by feature name. Entries that are not stored read as
# 0.0. The arithmetic operators (+, -, *) return new vectors and leave their
# operands unchanged.
class NamedSparseVector
  attr_accessor :h

  # @param init [Hash, nil] initial name => value entries. The hash is used as
  #   is (not copied) and its default is set to 0.0, so it is shared with and
  #   modified by this vector.
  def initialize(init = nil)
    @h = init || {}
    @h.default = 0.0
  end

  # Element-wise sum. +other+ may be anything that responds to each_pair.
  def +(other)
    combine(other) { |mine, theirs| mine + theirs }
  end

  # Element-wise difference. +other+ may be anything that responds to each_pair.
  def -(other)
    combine(other) { |mine, theirs| mine - theirs }
  end

  # Scales every stored entry by +scalar+.
  # @raise [ArgumentError] unless +scalar+ is Numeric
  def *(scalar)
    raise ArgumentError, "Arg is not numeric #{scalar}" unless scalar.is_a? Numeric
    ret = NamedSparseVector.new
    @h.each_pair { |k, v| ret[k] = v * scalar }
    ret
  end

  # Dot product with +other+.
  #
  # When +other+ is also a NamedSparseVector with fewer stored entries, the
  # work is delegated to it. The loop then runs over the smaller operand, so a
  # large weight vector dotted with a small example costs O(|example|).
  # Otherwise this vector's entries are walked and looked up in +other+, which
  # must return a number for every key.
  def dot(other)
    return other.dot(self) if other.is_a?(NamedSparseVector) && other.size < size

    sum = 0.0
    @h.each_pair { |k, v| sum += v * other[k] }
    sum
  end

  def [](k)
    @h[k]
  end

  def []=(k, v)
    @h[k] = v
  end

  # Yields each stored (name, value) pair; returns an Enumerator without a block.
  def each_pair
    return enum_for(__method__) unless block_given?
    @h.each_pair { |k, v| yield k, v }
  end

  def to_s
    @h.to_s
  end

  # Number of stored entries (including any explicitly stored zeros).
  def size
    @h.size
  end

  private

  # Copies this vector's entries into a new vector. Each (k, v) of +other+ is
  # then folded in as ret[k] = yield(ret[k], v); missing keys start at 0.0.
  def combine(other)
    ret = NamedSparseVector.new({}.update(@h))
    other.each_pair { |k, v| ret[k] = yield(ret[k], v) }
    ret
  end
end

# Manual smoke test for NamedSparseVector arithmetic.
# Sets a["a"] = 1 and b["b"] = 1, then prints a, b and (a - b) * 0.1 added
# onto an empty vector.
def sparse_vector_test
  a, b = NamedSparseVector.new, NamedSparseVector.new
  a["a"] = 1
  b["b"] = 1
  scaled_diff = (a - b) * 0.1
  c = NamedSparseVector.new + scaled_diff
  puts "a=#{a}, b=#{b}, (a-b)*0.1 = #{c}"
end

# Writes the model +w+ to the gzip-compressed file +fn+ as one line: the text
# of w.to_s followed by a newline.
def write_model(fn, w)
  Zlib::GzipWriter.open(fn) do |compressed|
    line = w.to_s + "\n"
    compressed.write(line)
  end
end

# Loads a model from the gzip-compressed file +fn+. The file's text is
# evaluated as Ruby and the result is wrapped in a NamedSparseVector.
# write_model stores w.to_s, which for a NamedSparseVector is a Hash literal.
#
# NOTE(review): eval executes arbitrary Ruby code from the file. Only load
# model files from trusted sources. A non-evaluating format (e.g. JSON) would
# be safer, but switching would break existing model files.
def read_model(fn)
  serialized = Zlib::GzipReader.open(fn, &:read)
  NamedSparseVector.new(eval(serialized))
end

# Prints a one-line usage message to standard error and terminates the
# process with exit status 1.
def usage
  message = "#{__FILE__} <config file>\n"
  STDERR.write(message)
  exit(1)
end
usage unless ARGV.length == 1 # exactly one argument: the config file path

# Reads a simple "key = value" configuration file into a Hash.
#
# Blank lines are skipped, and so are lines whose first non-blank character
# is '#'. A '#' anywhere else is kept as part of the value, so there are no
# inline comments. Whitespace around the first '=' is dropped, and later '='
# characters stay in the value. A line without '=' maps its text to nil.
#
# @param fn [String] path to the configuration file
# @return [Hash{String => String, nil}] raw (unconverted) values
# If the file cannot be read, prints an error to STDERR and exits with status 1.
def read_cfg(fn)
  lines =
    begin
      # readlines opens, reads and closes the file, so no handle is leaked.
      File.readlines(fn)
    rescue StandardError
      STDERR.write "#{__FILE__}: Can't find file '#{fn}', exiting.\n"
      exit 1
    end
  lines.each_with_object({}) do |line, cfg|
    next if /^\s*$/.match?(line)
    k, v = line.strip.split(/\s*=\s*/, 2)
    cfg[k] = v unless k.start_with?('#')
  end
end

# Parses one example line of the form "<label> <name>:<value> ...".
#
# The label is converted with to_f. Every whitespace-separated token after
# the label becomes a feature, including the last one: the former
# a[1..a.size-2] slice silently dropped it. A token without ':' gets value
# 0.0. A blank line yields label 0.0 and an empty vector instead of raising.
#
# @param s [String] one line of input
# @return [Array(Float, NamedSparseVector)] the label and the feature vector
def parse_example(s)
  label_tok, *feature_toks = s.split
  fv = NamedSparseVector.new
  feature_toks.each do |tok|
    name, val = tok.split ':'
    fv[name] = val.to_f
  end
  [label_tok.to_f, fv]
end

# main
#
# Trains a binary perceptron on cfg['train'] and writes the weight vector to
# cfg['model_file']. If cfg['test'] is set, it then prints the number of
# misclassified test examples. Labels are expected to be +1/-1; TODO confirm
# against the data files.
# Config keys read here: train (required), test, model_file, max_iter
# (default 1000), silent.

# Parses every non-blank line of +fn+ with parse_example. Blank lines would
# otherwise become label-0 examples that can never be classified correctly.
def read_examples(fn)
  File.foreach(fn).reject { |line| line.strip.empty? }.map { |line| parse_example(line.strip) }
end

# True when the model (w, bias) misclassifies +example+ or places it exactly
# on the decision boundary, i.e. when label * (w.x + bias) <= 0.
def mistake?(example, w, bias)
  label, fv = example
  # fv.dot(w) walks the small example instead of the growing weight vector;
  # w reads 0.0 for unseen features.
  label * (fv.dot(w) + bias) <= 0
end

cfg = read_cfg ARGV[0]
silent = true if cfg['silent']
max_iter = cfg['max_iter'] ? cfg['max_iter'].to_i : 1000
unless cfg['train']
  STDERR.write "#{__FILE__}: no 'train' file given in config, exiting.\n"
  exit 1
end

start = Process.clock_gettime(Process::CLOCK_MONOTONIC)
w = NamedSparseVector.new
bias = 0.0

train = read_examples(cfg['train'])
test = cfg['test'] ? read_examples(cfg['test']) : []

iter = 0
loop do
  err = 0
  train.each do |example|
    next unless mistake?(example, w, bias)
    label, fv = example
    # In-place update; `w += fv * label` copied the whole weight hash per mistake.
    fv.each_pair { |k, v| w[k] += v * label }
    bias += label
    err += 1
  end
  puts "iter:#{iter} err=#{err}"
  iter += 1
  # `>=` rather than `==` so that a max_iter of 0 or less cannot loop forever.
  break if err.zero? || iter >= max_iter
end

elapsed = Process.clock_gettime(Process::CLOCK_MONOTONIC) - start
# iter already equals the number of completed passes (always >= 1 here).
puts "#{elapsed.round 2} s, #{(elapsed / iter).round 2} s per iter; model size: #{w.size}" unless silent
puts cfg['model_file']
# NOTE(review): only w is written; the learned bias is not stored in the model file.
write_model cfg['model_file'], w

if cfg['test']
  test_err = test.count { |example| mistake?(example, w, bias) }
  puts "test error=#{test_err}"
end