summaryrefslogtreecommitdiff
path: root/log_reg
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2014-03-02 16:36:13 +0100
committerPatrick Simianer <p@simianer.de>2014-03-02 16:36:13 +0100
commit9d693723ba7bcf380182e8bd4d622f6d8eff4e3a (patch)
treee7c489836a41bdbcc30c9690108ed9e3a77b965d /log_reg
parent38862b7e0cde7ac7285169c10e1377357ea24488 (diff)
logistic regression
Diffstat (limited to 'log_reg')
-rwxr-xr-xlog_reg73
1 files changed, 73 insertions, 0 deletions
diff --git a/log_reg b/log_reg
new file mode 100755
index 0000000..c0a95d4
--- /dev/null
+++ b/log_reg
@@ -0,0 +1,73 @@
+#!/usr/bin/env ruby
+
+require 'nlp_ruby'
+require 'matrix'
+require 'trollop'
+
+
+def read_data fn
+ f = ReadFile.new fn
+ data = []
+ while line = f.gets
+ line.strip!
+ a = []
+ a << 1.0
+ tokenize(line).each { |i| a << i.to_f }
+ v = Vector.elements a
+ data << v
+ end
+ return data
+end
+
+def dot x, y
+ r = 0.0
+ x.each_with_index { |_,j|
+ r += x[j] * y[j]
+ }
+ return r
+end
+
+def approx_eql x, y, eps=10**-10
+ return false if !x||!y
+ return false if x.size!=y.size
+ x.each_with_index { |_,i|
+ return false if (x[i]-y[i]).abs>eps
+ }
+ return true
+end
+
+def main
+ cfg = Trollop::options do
+ opt :input, "input data", :type => :string, :required => true
+ opt :output, "1/0 output data", :type => :string, :required => true
+ end
+ data = read_data cfg[:input]
+ dim = data[0].size
+ zeros = [0.0]*dim
+ t = ReadFile.readlines(cfg[:output]).map{ |i| i.to_f }
+ model = Vector.elements zeros
+ prev_model = nil
+ gradient = Vector.elements zeros
+ hessian = Matrix.build(dim,dim) { |i,j| 0.0 }
+ i = 0
+ while true
+ i += 1
+ data.each_with_index { |x,j|
+ m = 1.0/(1+Math.exp(-dot(model, x)))
+ gradient += (m-t[j]) * x
+ hup = Matrix.column_vector(x) * Matrix.row_vector(x)
+ hessian += m*(1.0-m) * hup
+ }
+ gradient /= data.size
+ hessian /= data.size
+ model -= hessian.inverse * gradient
+ break if approx_eql model, prev_model
+ prev_model = model
+ end
+ STDERR.write "ran for #{i} iterations\n"
+ puts model.to_s
+end
+
+
+main
+