From 6b72035db4a17dd9675cac4cd7adc17f8b1998e2 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Sun, 27 Sep 2020 19:18:16 +0200 Subject: train-test-split: proper implementation --- train-test-split | 85 ++++++++++++++++++++++++++++++++------------------------ 1 file changed, 48 insertions(+), 37 deletions(-) (limited to 'train-test-split') diff --git a/train-test-split b/train-test-split index aa55534..6aa4796 100755 --- a/train-test-split +++ b/train-test-split @@ -4,53 +4,64 @@ require 'zipf' require 'optimist' conf = Optimist::options do - opt :foreign, "foreign file", :type => :string, :required => true - opt :english, "english file", :type => :string, :required => true + opt :source, "source file", :type => :string, :required => true + opt :target, "target file", :type => :string, :required => true opt :size, "one size", :type => :int, :required => true opt :repeat, "number of repetitions", :type => :int, :default => 1 - opt :prefix, "prefix for output files", :type => :string + opt :prefix, "prefix for output files", :type => :string, :default => "split" opt :sets, "number of sets", :type => :int, :default => 1 end -fn = conf[:foreign] -fn_ext = fn.split('.').last -f = ReadFile.readlines fn -en = conf[:english] -en_ext = en.split('.').last -e = ReadFile.readlines en + +source_filename = conf[:source] +source_extension = source_filename.split('.').last +source_lines = ReadFile.readlines source_filename + +target_filename = conf[:target] +target_extension = target_filename.split('.').last +target_lines = ReadFile.readlines target_filename + size = conf[:size] -nlines_f = `wc -l #{fn}`.split()[0].to_i -nlines_e = `wc -l #{en}`.split()[0].to_i -if nlines_f != nlines_e - STDERR.write "Unbalanced files (#{nlines_f} vs. #{nlines_e}), exiting!\n" + +if source_lines.size != target_lines.size + STDERR.write "Unbalanced files (#{source_lines.size} vs. #{target_lines.size}), exiting!\n" exit 1 end -prefix = conf[:prefix] -a = (0..nlines_e-1).to_a -i = 0 -conf[:repeat].times { - if conf[:repeat] == 1 - infix = "" - else - infix = ".#{i}" - end - b = a.sample(size) - ax = a.reject{|j| b.include? j} +index = (0..source_lines.size-1).to_a +conf[:repeat].times { |i| `mkdir split_#{i}` - new_f = WriteFile.new "split_#{i}/#{prefix}.train#{infix}.#{fn_ext}" - new_e = WriteFile.new "split_#{i}/#{prefix}.train#{infix}.#{en_ext}" - ax.each { |j| - new_f.write f[j] - new_e.write e[j] + + sampled = index.sample(size * conf[:sets]) + + test_strings_source = {} + test_strings_target = {} + + conf[:sets].times { |s| + slice_start_index = (s-1) * size + + source_file = WriteFile.new "split_#{i}/#{conf[:prefix]}.devtest.#{s}.#{source_extension}" + target_file = WriteFile.new "split_#{i}/#{conf[:prefix]}.devtest.#{s}.#{target_extension}" + + sampled.slice(slice_start_index, size).each { |j| + source_file.write source_lines[j] + target_file.write target_lines[j] + test_strings_source[source_lines[j].downcase] = true + test_strings_target[target_lines[j].downcase] = true + } + source_file.close; target_file.close } - new_f.close; new_e.close - new_f = WriteFile.new "split_#{i}/#{prefix}.devtest#{infix}.#{fn_ext}" - new_e = WriteFile.new "split_#{i}/#{prefix}.devtest#{infix}.#{en_ext}" - b.each { |j| - new_f.write f[j] - new_e.write e[j] + + filtered_index = index.reject{ |j| sampled.include? j } + source_file = WriteFile.new "split_#{i}/#{conf[:prefix]}.train.#{source_extension}" + target_file = WriteFile.new "split_#{i}/#{conf[:prefix]}.train.#{target_extension}" + filtered_index.each { |j| + if not test_strings_source.include? source_lines[j].downcase \ + and not test_strings_target.include? target_lines[j] + source_file.write source_lines[j] + target_file.write target_lines[j] + end } - new_f.close; new_e.close + source_file.close; target_file.close + i += 1 } - -- cgit v1.2.3