summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPatrick Simianer <pks@pks.rocks>2020-09-27 19:18:16 +0200
committerPatrick Simianer <pks@pks.rocks>2020-09-27 19:18:16 +0200
commit6b72035db4a17dd9675cac4cd7adc17f8b1998e2 (patch)
tree57b94e410e7a049823321a3505445f355cf8f3db
parent64e8bdba930479249b8dfbc4b5d4b659a95433f0 (diff)
train-test-split: proper implementation
-rwxr-xr-xtrain-test-split85
1 files changed, 48 insertions, 37 deletions
diff --git a/train-test-split b/train-test-split
index aa55534..6aa4796 100755
--- a/train-test-split
+++ b/train-test-split
@@ -4,53 +4,64 @@ require 'zipf'
require 'optimist'
conf = Optimist::options do
- opt :foreign, "foreign file", :type => :string, :required => true
- opt :english, "english file", :type => :string, :required => true
+ opt :source, "source file", :type => :string, :required => true
+ opt :target, "target file", :type => :string, :required => true
opt :size, "one size", :type => :int, :required => true
opt :repeat, "number of repetitions", :type => :int, :default => 1
- opt :prefix, "prefix for output files", :type => :string
+ opt :prefix, "prefix for output files", :type => :string, :default => "split"
opt :sets, "number of sets", :type => :int, :default => 1
end
-fn = conf[:foreign]
-fn_ext = fn.split('.').last
-f = ReadFile.readlines fn
-en = conf[:english]
-en_ext = en.split('.').last
-e = ReadFile.readlines en
+
+source_filename = conf[:source]
+source_extension = source_filename.split('.').last
+source_lines = ReadFile.readlines source_filename
+
+target_filename = conf[:target]
+target_extension = target_filename.split('.').last
+target_lines = ReadFile.readlines target_filename
+
size = conf[:size]
-nlines_f = `wc -l #{fn}`.split()[0].to_i
-nlines_e = `wc -l #{en}`.split()[0].to_i
-if nlines_f != nlines_e
- STDERR.write "Unbalanced files (#{nlines_f} vs. #{nlines_e}), exiting!\n"
+
+if source_lines.size != target_lines.size
+ STDERR.write "Unbalanced files (#{source_lines.size} vs. #{target_lines.size}), exiting!\n"
exit 1
end
-prefix = conf[:prefix]
-a = (0..nlines_e-1).to_a
-i = 0
-conf[:repeat].times {
- if conf[:repeat] == 1
- infix = ""
- else
- infix = ".#{i}"
- end
- b = a.sample(size)
- ax = a.reject{|j| b.include? j}
+index = (0..source_lines.size-1).to_a
+conf[:repeat].times { |i|
`mkdir split_#{i}`
- new_f = WriteFile.new "split_#{i}/#{prefix}.train#{infix}.#{fn_ext}"
- new_e = WriteFile.new "split_#{i}/#{prefix}.train#{infix}.#{en_ext}"
- ax.each { |j|
- new_f.write f[j]
- new_e.write e[j]
+
+ sampled = index.sample(size * conf[:sets])
+
+ test_strings_source = {}
+ test_strings_target = {}
+
+ conf[:sets].times { |s|
+ slice_start_index = (s-1) * size
+
+ source_file = WriteFile.new "split_#{i}/#{conf[:prefix]}.devtest.#{s}.#{source_extension}"
+ target_file = WriteFile.new "split_#{i}/#{conf[:prefix]}.devtest.#{s}.#{target_extension}"
+
+ sampled.slice(slice_start_index, size).each { |j|
+ source_file.write source_lines[j]
+ target_file.write target_lines[j]
+ test_strings_source[source_lines[j].downcase] = true
+ test_strings_target[target_lines[j].downcase] = true
+ }
+ source_file.close; target_file.close
}
- new_f.close; new_e.close
- new_f = WriteFile.new "split_#{i}/#{prefix}.devtest#{infix}.#{fn_ext}"
- new_e = WriteFile.new "split_#{i}/#{prefix}.devtest#{infix}.#{en_ext}"
- b.each { |j|
- new_f.write f[j]
- new_e.write e[j]
+
+ filtered_index = index.reject{ |j| sampled.include? j }
+ source_file = WriteFile.new "split_#{i}/#{conf[:prefix]}.train.#{source_extension}"
+ target_file = WriteFile.new "split_#{i}/#{conf[:prefix]}.train.#{target_extension}"
+ filtered_index.each { |j|
+ if not test_strings_source.include? source_lines[j].downcase \
+ and not test_strings_target.include? target_lines[j]
+ source_file.write source_lines[j]
+ target_file.write target_lines[j]
+ end
}
- new_f.close; new_e.close
+ source_file.close; target_file.close
+
i += 1
}
-