summaryrefslogtreecommitdiff
path: root/train-test-split
blob: ab2f9167a09afdc34ef584fdae351c2bc35b24b0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#!/usr/bin/env ruby

require 'zipf'
require 'optimist'

# Command-line interface: two line-aligned input files, the number of line
# pairs to hold out per split, how many random splits to draw, and an optional
# prefix for the output file names.
conf = Optimist::options do
  opt :foreign, "foreign file", :type => :string, :required => true
  opt :english, "english file", :type => :string, :required => true
  opt :size, "one size", :type => :int, :required => true
  opt :repeat, "number of repetitions", :type => :int, :default => 1
  opt :prefix, "prefix for output files", :type => :string
end

# Load both inputs fully into memory. The text after the last '.' of each
# input name is reused later as the extension of the generated files.
fn = conf[:foreign]
fn_ext = fn.split('.').last
f = ReadFile.readlines fn
en = conf[:english]
en_ext = en.split('.').last
e = ReadFile.readlines en
size = conf[:size]

# Count the lines that were actually read instead of shelling out to `wc -l`:
# no unquoted filename reaches a shell, and a final line without a trailing
# newline is counted too (wc counts newline characters, so it would be missed).
# NOTE(review): assumes ReadFile.readlines returns an Array of lines, as the
# f[j] / e[j] indexing further down suggests.
nlines_f = f.size
nlines_e = e.size

# The two files must pair up line by line; otherwise bail out with status 1.
abort "Unbalanced files (#{nlines_f} vs. #{nlines_e}), exiting!" if nlines_f != nlines_e

# Draw conf[:repeat] independent random splits. Split number i goes to
# directory split_<i> and consists of four files:
#   <prefix>.train<infix>.<ext>    all line pairs not held out
#   <prefix>.devtest<infix>.<ext>  the `size` randomly sampled line pairs
# once per input extension. <infix> is empty when only one split is requested
# and ".<i>" otherwise.
prefix = conf[:prefix]
a = (0...nlines_e).to_a # indices of all line pairs
conf[:repeat].times do |i|
  infix = conf[:repeat] == 1 ? "" : ".#{i}"

  b = a.sample(size) # held-out indices, in random order
  ax = a - b         # remaining indices, original order preserved

  dir = "split_#{i}"
  Dir.mkdir(dir) unless Dir.exist?(dir)

  # Same write-out for both parts; train is written first, then devtest.
  { "train" => ax, "devtest" => b }.each do |part, ids|
    new_f = WriteFile.new "#{dir}/#{prefix}.#{part}#{infix}.#{fn_ext}"
    new_e = WriteFile.new "#{dir}/#{prefix}.#{part}#{infix}.#{en_ext}"
    ids.each do |j|
      new_f.write f[j]
      new_e.write e[j]
    end
    new_f.close
    new_e.close
  end
end