summaryrefslogtreecommitdiff
path: root/train-test-split
blob: 6aa4796f418ce5383462dc3ca1d907b4aab147b9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#!/usr/bin/env ruby

require 'zipf'
require 'optimist'

# Command-line options. --source, --target and --size are mandatory; the
# remaining options fall back to the defaults given here.
conf = Optimist.options do
  opt :source, "source file",             type: :string, required: true
  opt :target, "target file",             type: :string, required: true
  opt :size,   "one size",                type: :int,    required: true
  opt :repeat, "number of repetitions",   type: :int,    default: 1
  opt :prefix, "prefix for output files", type: :string, default: "split"
  opt :sets,   "number of sets",          type: :int,    default: 1
end

# Source side: remember the text after the last '.' of the file name (it is
# reused as the suffix of the generated files) and read every line.
source_filename = conf[:source]
source_extension = source_filename.split('.')[-1]
source_lines = ReadFile.readlines(source_filename)

# Target side, handled the same way.
target_filename = conf[:target]
target_extension = target_filename.split('.')[-1]
target_lines = ReadFile.readlines(target_filename)

# Number of lines that go into each held-out set.
size = conf[:size]

# The two files are processed line by line in parallel, so they have to
# contain the same number of lines; otherwise give up with exit status 1.
unless source_lines.size == target_lines.size
  STDERR.write "Unbalanced files (#{source_lines.size} vs. #{target_lines.size}), exiting!\n"
  exit(1)
end

# Line numbers valid for both files (they were checked to be equally long).
index = (0...source_lines.size).to_a

# Produce one independent random split per repetition. Repetition i writes
# into the directory split_<i>/:
#   <prefix>.devtest.<s>.<ext>  -- held-out set number s (source and target)
#   <prefix>.train.<ext>        -- every remaining line pair
conf[:repeat].times do |i|
  split_dir = "split_#{i}"
  # Core Dir API instead of shelling out to `mkdir`; an existing directory
  # is simply reused.
  Dir.mkdir(split_dir) unless Dir.exist?(split_dir)

  # Draw the lines for all held-out sets in one go; Array#sample picks
  # distinct elements, so the sets of one repetition cannot overlap.
  sampled = index.sample(size * conf[:sets])

  # Lower-cased held-out lines, used below to keep duplicates of held-out
  # material out of the training data.
  test_strings_source = {}
  test_strings_target = {}

  conf[:sets].times do |s|
    # Set s owns the s-th consecutive chunk of the sample. slice returns nil
    # when the sample is too short to reach this chunk; treat that as an
    # empty set instead of crashing.
    chunk = sampled.slice(s * size, size) || []

    source_file = WriteFile.new "#{split_dir}/#{conf[:prefix]}.devtest.#{s}.#{source_extension}"
    target_file = WriteFile.new "#{split_dir}/#{conf[:prefix]}.devtest.#{s}.#{target_extension}"

    chunk.each do |j|
      source_file.write source_lines[j]
      target_file.write target_lines[j]
      test_strings_source[source_lines[j].downcase] = true
      test_strings_target[target_lines[j].downcase] = true
    end
    source_file.close
    target_file.close
  end

  # Everything that was not sampled is a training candidate. Array#- keeps
  # the original order and avoids a linear scan of `sampled` per element.
  filtered_index = index - sampled
  source_file = WriteFile.new "#{split_dir}/#{conf[:prefix]}.train.#{source_extension}"
  target_file = WriteFile.new "#{split_dir}/#{conf[:prefix]}.train.#{target_extension}"
  filtered_index.each do |j|
    # Skip a pair when either side equals a held-out line, ignoring case.
    # Both lookups are lower-cased because the keys were stored lower-cased.
    next if test_strings_source.key?(source_lines[j].downcase)
    next if test_strings_target.key?(target_lines[j].downcase)

    source_file.write source_lines[j]
    target_file.write target_lines[j]
  end
  source_file.close
  target_file.close
end