From 9d693723ba7bcf380182e8bd4d622f6d8eff4e3a Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Sun, 2 Mar 2014 16:36:13 +0100 Subject: logistic regression --- lin_reg | 8 ++--- log_reg | 73 ++++++++++++++++++++++++++++++++++++++++++ test/lin_reg/exptected.txt | 3 ++ test/lin_reg/input.dat | 50 +++++++++++++++++++++++++++++ test/lin_reg/output.dat | 50 +++++++++++++++++++++++++++++ test/lin_reg/x.dat | 50 ----------------------------- test/lin_reg/y.dat | 50 ----------------------------- test/log_reg/expected.txt | 2 ++ test/log_reg/input.dat | 80 ++++++++++++++++++++++++++++++++++++++++++++++ test/log_reg/output.dat | 80 ++++++++++++++++++++++++++++++++++++++++++++++ 10 files changed, 342 insertions(+), 104 deletions(-) create mode 100755 log_reg create mode 100644 test/lin_reg/exptected.txt create mode 100644 test/lin_reg/input.dat create mode 100644 test/lin_reg/output.dat delete mode 100644 test/lin_reg/x.dat delete mode 100644 test/lin_reg/y.dat create mode 100644 test/log_reg/expected.txt create mode 100644 test/log_reg/input.dat create mode 100644 test/log_reg/output.dat diff --git a/lin_reg b/lin_reg index 3546c3e..d512104 100755 --- a/lin_reg +++ b/lin_reg @@ -44,13 +44,13 @@ def main i += 1 u = SparseVector.new zeros overall_loss = 0.0 - data.each_with_index { |d,j| - loss = model.dot(d) - t[j] + data.each_with_index { |x,j| + loss = model.dot(x) - t[j] overall_loss += loss**2 - u += d * (loss * (1.0/t.size)) + u += x * loss } STDERR.write "#{i} #{overall_loss/data.size}\n" if cfg[:show_loss] - u *= cfg[:learning_rate] + u *= cfg[:learning_rate]*(1.0/t.size) model -= u if model.approx_eql? prev_model stop += 1 diff --git a/log_reg b/log_reg new file mode 100755 index 0000000..c0a95d4 --- /dev/null +++ b/log_reg @@ -0,0 +1,73 @@ +#!/usr/bin/env ruby + +require 'nlp_ruby' +require 'matrix' +require 'trollop' + + +def read_data fn + f = ReadFile.new fn + data = [] + while line = f.gets + line.strip! + a = [] + a << 1.0 + tokenize(line).each { |i| a << i.to_f } + v = Vector.elements a + data << v + end + return data +end + +def dot x, y + r = 0.0 + x.each_with_index { |_,j| + r += x[j] * y[j] + } + return r +end + +def approx_eql x, y, eps=10**-10 + return false if !x||!y + return false if x.size!=y.size + x.each_with_index { |_,i| + return false if (x[i]-y[i]).abs>eps + } + return true +end + +def main + cfg = Trollop::options do + opt :input, "input data", :type => :string, :required => true + opt :output, "1/0 output data", :type => :string, :required => true + end + data = read_data cfg[:input] + dim = data[0].size + zeros = [0.0]*dim + t = ReadFile.readlines(cfg[:output]).map{ |i| i.to_f } + model = Vector.elements zeros + prev_model = nil + gradient = Vector.elements zeros + hessian = Matrix.build(dim,dim) { |i,j| 0.0 } + i = 0 + while true + i += 1 + data.each_with_index { |x,j| + m = 1.0/(1+Math.exp(-dot(model, x))) + gradient += (m-t[j]) * x + hup = Matrix.column_vector(x) * Matrix.row_vector(x) + hessian += m*(1.0-m) * hup + } + gradient /= data.size + hessian /= data.size + model -= hessian.inverse * gradient + break if approx_eql model, prev_model + prev_model = model + end + STDERR.write "ran for #{i} iterations\n" + puts model.to_s +end + + +main + diff --git a/test/lin_reg/exptected.txt b/test/lin_reg/exptected.txt new file mode 100644 index 0000000..13de1fc --- /dev/null +++ b/test/lin_reg/exptected.txt @@ -0,0 +1,3 @@ +ran for 2527 iterations + R^2=0.858063223720823 +{0=>0.7501625304145768, 1=>0.06388116702419537} diff --git a/test/lin_reg/input.dat b/test/lin_reg/input.dat new file mode 100644 index 0000000..3d93394 --- /dev/null +++ b/test/lin_reg/input.dat @@ -0,0 +1,50 @@ + 2.0658746e+00 + 2.3684087e+00 + 2.5399929e+00 + 2.5420804e+00 + 2.5490790e+00 + 2.7866882e+00 + 2.9116825e+00 + 3.0356270e+00 + 3.1146696e+00 + 3.1582389e+00 + 3.3275944e+00 + 3.3793165e+00 + 3.4122006e+00 + 3.4215823e+00 + 3.5315732e+00 + 3.6393002e+00 + 3.6732537e+00 + 3.9256462e+00 + 4.0498646e+00 + 4.2483348e+00 + 4.3440052e+00 + 4.3826531e+00 + 4.4230602e+00 + 4.6102443e+00 + 4.6881183e+00 + 4.9777333e+00 + 5.0359967e+00 + 5.0684536e+00 + 5.4161491e+00 + 5.4395623e+00 + 5.4563207e+00 + 5.5698458e+00 + 5.6015729e+00 + 5.6877617e+00 + 5.7215602e+00 + 5.8538914e+00 + 6.1978026e+00 + 6.3510941e+00 + 6.4797033e+00 + 6.7383791e+00 + 6.8637686e+00 + 7.0223387e+00 + 7.0782373e+00 + 7.1514232e+00 + 7.4664023e+00 + 7.5973874e+00 + 7.7440717e+00 + 7.7729662e+00 + 7.8264514e+00 + 7.9306356e+00 diff --git a/test/lin_reg/output.dat b/test/lin_reg/output.dat new file mode 100644 index 0000000..1f4f963 --- /dev/null +++ b/test/lin_reg/output.dat @@ -0,0 +1,50 @@ + 7.7918926e-01 + 9.1596757e-01 + 9.0538354e-01 + 9.0566138e-01 + 9.3898890e-01 + 9.6684740e-01 + 9.6436824e-01 + 9.1445939e-01 + 9.3933944e-01 + 9.6074971e-01 + 8.9837094e-01 + 9.1209739e-01 + 9.4238499e-01 + 9.6624578e-01 + 1.0526500e+00 + 1.0143791e+00 + 9.5969426e-01 + 9.6853716e-01 + 1.0766065e+00 + 1.1454978e+00 + 1.0340625e+00 + 1.0070009e+00 + 9.6683648e-01 + 1.0895919e+00 + 1.0634462e+00 + 1.1237239e+00 + 1.0323374e+00 + 1.0874452e+00 + 1.0702988e+00 + 1.1606493e+00 + 1.0778037e+00 + 1.1069758e+00 + 1.0971875e+00 + 1.1648603e+00 + 1.1411796e+00 + 1.0844156e+00 + 1.1252493e+00 + 1.1168341e+00 + 1.1970789e+00 + 1.2069462e+00 + 1.1251046e+00 + 1.1235672e+00 + 1.2132829e+00 + 1.2522652e+00 + 1.2497065e+00 + 1.1799706e+00 + 1.1897299e+00 + 1.3029934e+00 + 1.2601134e+00 + 1.2562267e+00 diff --git a/test/lin_reg/x.dat b/test/lin_reg/x.dat deleted file mode 100644 index 3d93394..0000000 --- a/test/lin_reg/x.dat +++ /dev/null @@ -1,50 +0,0 @@ - 2.0658746e+00 - 2.3684087e+00 - 2.5399929e+00 - 2.5420804e+00 - 2.5490790e+00 - 2.7866882e+00 - 2.9116825e+00 - 3.0356270e+00 - 3.1146696e+00 - 3.1582389e+00 - 3.3275944e+00 - 3.3793165e+00 - 3.4122006e+00 - 3.4215823e+00 - 3.5315732e+00 - 3.6393002e+00 - 3.6732537e+00 - 3.9256462e+00 - 4.0498646e+00 - 4.2483348e+00 - 4.3440052e+00 - 4.3826531e+00 - 4.4230602e+00 - 4.6102443e+00 - 4.6881183e+00 - 4.9777333e+00 - 5.0359967e+00 - 5.0684536e+00 - 5.4161491e+00 - 5.4395623e+00 - 5.4563207e+00 - 5.5698458e+00 - 5.6015729e+00 - 5.6877617e+00 - 5.7215602e+00 - 5.8538914e+00 - 6.1978026e+00 - 6.3510941e+00 - 6.4797033e+00 - 6.7383791e+00 - 6.8637686e+00 - 7.0223387e+00 - 7.0782373e+00 - 7.1514232e+00 - 7.4664023e+00 - 7.5973874e+00 - 7.7440717e+00 - 7.7729662e+00 - 7.8264514e+00 - 7.9306356e+00 diff --git a/test/lin_reg/y.dat b/test/lin_reg/y.dat deleted file mode 100644 index 1f4f963..0000000 --- a/test/lin_reg/y.dat +++ /dev/null @@ -1,50 +0,0 @@ - 7.7918926e-01 - 9.1596757e-01 - 9.0538354e-01 - 9.0566138e-01 - 9.3898890e-01 - 9.6684740e-01 - 9.6436824e-01 - 9.1445939e-01 - 9.3933944e-01 - 9.6074971e-01 - 8.9837094e-01 - 9.1209739e-01 - 9.4238499e-01 - 9.6624578e-01 - 1.0526500e+00 - 1.0143791e+00 - 9.5969426e-01 - 9.6853716e-01 - 1.0766065e+00 - 1.1454978e+00 - 1.0340625e+00 - 1.0070009e+00 - 9.6683648e-01 - 1.0895919e+00 - 1.0634462e+00 - 1.1237239e+00 - 1.0323374e+00 - 1.0874452e+00 - 1.0702988e+00 - 1.1606493e+00 - 1.0778037e+00 - 1.1069758e+00 - 1.0971875e+00 - 1.1648603e+00 - 1.1411796e+00 - 1.0844156e+00 - 1.1252493e+00 - 1.1168341e+00 - 1.1970789e+00 - 1.2069462e+00 - 1.1251046e+00 - 1.1235672e+00 - 1.2132829e+00 - 1.2522652e+00 - 1.2497065e+00 - 1.1799706e+00 - 1.1897299e+00 - 1.3029934e+00 - 1.2601134e+00 - 1.2562267e+00 diff --git a/test/log_reg/expected.txt b/test/log_reg/expected.txt new file mode 100644 index 0000000..46a03ef --- /dev/null +++ b/test/log_reg/expected.txt @@ -0,0 +1,2 @@ +ran for 15 iterations +Vector[-16.378743410287445, 0.1483407737248737, 0.1589084517934473] diff --git a/test/log_reg/input.dat b/test/log_reg/input.dat new file mode 100644 index 0000000..eed0ab1 --- /dev/null +++ b/test/log_reg/input.dat @@ -0,0 +1,80 @@ + 5.5500000e+01 6.9500000e+01 + 4.1000000e+01 8.1500000e+01 + 5.3500000e+01 8.6000000e+01 + 4.6000000e+01 8.4000000e+01 + 4.1000000e+01 7.3500000e+01 + 5.1500000e+01 6.9000000e+01 + 5.1000000e+01 6.2500000e+01 + 4.2000000e+01 7.5000000e+01 + 5.3500000e+01 8.3000000e+01 + 5.7500000e+01 7.1000000e+01 + 4.2500000e+01 7.2500000e+01 + 4.1000000e+01 8.0000000e+01 + 4.6000000e+01 8.2000000e+01 + 4.6000000e+01 6.0500000e+01 + 4.9500000e+01 7.6000000e+01 + 4.1000000e+01 7.6000000e+01 + 4.8500000e+01 7.2500000e+01 + 5.1500000e+01 8.2500000e+01 + 4.4500000e+01 7.0500000e+01 + 4.4000000e+01 6.6000000e+01 + 3.3000000e+01 7.6500000e+01 + 3.3500000e+01 7.8500000e+01 + 3.1500000e+01 7.2000000e+01 + 3.3000000e+01 8.1500000e+01 + 4.2000000e+01 5.9500000e+01 + 3.0000000e+01 6.4000000e+01 + 6.1000000e+01 4.5000000e+01 + 4.9000000e+01 7.9000000e+01 + 2.6500000e+01 6.4500000e+01 + 3.4000000e+01 7.1500000e+01 + 4.2000000e+01 8.3500000e+01 + 2.9500000e+01 7.4500000e+01 + 3.9500000e+01 7.0000000e+01 + 5.1500000e+01 6.6000000e+01 + 4.1500000e+01 7.1500000e+01 + 4.2500000e+01 7.9500000e+01 + 3.5000000e+01 5.9500000e+01 + 3.8500000e+01 7.3500000e+01 + 3.2000000e+01 8.1500000e+01 + 4.6000000e+01 6.0500000e+01 + 3.6500000e+01 5.3000000e+01 + 3.6500000e+01 5.3500000e+01 + 2.4000000e+01 6.0500000e+01 + 1.9000000e+01 5.7500000e+01 + 3.4500000e+01 6.0000000e+01 + 3.7500000e+01 6.4500000e+01 + 3.5500000e+01 5.1000000e+01 + 3.7000000e+01 5.0500000e+01 + 2.1500000e+01 4.2000000e+01 + 3.5500000e+01 5.8500000e+01 + 2.6500000e+01 6.8500000e+01 + 2.6500000e+01 5.5500000e+01 + 1.8500000e+01 6.7000000e+01 + 4.0000000e+01 6.7000000e+01 + 3.2500000e+01 7.1500000e+01 + 3.9000000e+01 7.1500000e+01 + 4.3000000e+01 5.5500000e+01 + 2.2000000e+01 5.4000000e+01 + 3.6000000e+01 6.2500000e+01 + 3.1000000e+01 5.5500000e+01 + 3.8500000e+01 7.6000000e+01 + 4.0000000e+01 7.5000000e+01 + 3.7500000e+01 6.3000000e+01 + 2.4500000e+01 5.8000000e+01 + 3.0000000e+01 6.7000000e+01 + 3.3000000e+01 5.6000000e+01 + 5.6500000e+01 6.1000000e+01 + 4.1000000e+01 5.7000000e+01 + 4.9500000e+01 6.3000000e+01 + 3.4500000e+01 7.2500000e+01 + 3.2500000e+01 6.9000000e+01 + 3.6000000e+01 7.3000000e+01 + 2.7000000e+01 5.3500000e+01 + 4.1000000e+01 6.3500000e+01 + 2.9500000e+01 5.2500000e+01 + 2.0000000e+01 6.5500000e+01 + 3.8000000e+01 6.5000000e+01 + 1.8500000e+01 7.4500000e+01 + 1.6000000e+01 7.2500000e+01 + 3.3500000e+01 6.8000000e+01 diff --git a/test/log_reg/output.dat b/test/log_reg/output.dat new file mode 100644 index 0000000..51283c0 --- /dev/null +++ b/test/log_reg/output.dat @@ -0,0 +1,80 @@ + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 1.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 + 0.0000000e+00 -- cgit v1.2.3