summaryrefslogtreecommitdiff
path: root/format-data.py
diff options
context:
space:
mode:
Diffstat (limited to 'format-data.py')
-rw-r--r--format-data.py46
1 files changed, 46 insertions, 0 deletions
diff --git a/format-data.py b/format-data.py
new file mode 100644
index 0000000..60c13d7
--- /dev/null
+++ b/format-data.py
@@ -0,0 +1,46 @@
+import json
+
+from glob import glob
+
+from PIL import Image
+
+
+def format_example_for_gemma3_preferences(data, target_score, translation_score):
+ prompt = """<bos><start_of_turn>user
+You are a professional English-German translator and also a renowned photography critic.
+
+<start_of_image>
+Write a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else.<end_of_turn>"""
+ translation = f"""<start_of_turn>model
+{data['Translation']}<end_of_turn>"""
+ target = f"""<start_of_turn>model
+{data['German']}<end_of_turn>"""
+
+ if target_score == translation_score:
+ return None, None, None
+ elif target_score > translation_score:
+ return prompt, target, translation
+ else:
+ return prompt, translation, target
+
+
+def main():
+ with open("baseline/target.gemba-gpt4.1.scores", "r") as f:
+ scores_target = [int(line.strip()) for line in f.readlines()]
+ with open("baseline/translation.gemba-gpt4.1.scores", "r") as f:
+ scores_translation = [int(line.strip()) for line in f.readlines()]
+
+ for index, filename in enumerate(sorted(glob("baseline/*.jsonl"))):
+ with open(filename, "r") as f:
+ data = json.loads(f.read())
+ prompt, c, r = format_example_for_gemma3_preferences(data, scores_target[index], scores_translation[index])
+ print(f"{c=} {scores_target[index] > scores_translation[index]}")
+
+
+ from transformers import AutoTokenizer
+ model_id = "google/gemma-3-4b-it"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+
+if __name__ == "__main__":
+ main()