summaryrefslogtreecommitdiff
path: root/format-data.py
blob: 60c13d7cdb4b8ece547d3fafa7a5340ebc1cfea9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import json

from glob import glob

from PIL import Image


def format_example_for_gemma3_preferences(data, target_score, translation_score):
    """Build a Gemma 3 chat-formatted preference triple from one example.

    Two candidate model turns are rendered, one from ``data['Translation']``
    and one from ``data['German']``, and ordered by their scores.

    Args:
        data: Mapping that provides the ``'Translation'`` and ``'German'``
            entries. Both are looked up before the scores are compared, so
            a missing key raises ``KeyError`` even when the scores are tied.
        target_score: Score belonging to the ``'German'`` text.
        translation_score: Score belonging to the ``'Translation'`` text.

    Returns:
        ``(prompt, higher_scored_turn, lower_scored_turn)``, or
        ``(None, None, None)`` when both scores are equal.
    """
    user_turn = (
        "<bos><start_of_turn>user\n"
        "You are a professional English-German translator and also a renowned photography critic.\n"
        "\n"
        "<start_of_image>\n"
        "Write a detailed caption for this image in a single sentence. "
        "Translate the caption into German. "
        "The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. "
        "Only output the JSON, nothing else.<end_of_turn>"
    )

    def as_model_turn(text):
        # Wrap a response in the model-turn markers of the chat template.
        return f"<start_of_turn>model\n{text}<end_of_turn>"

    translation_turn = as_model_turn(data["Translation"])
    target_turn = as_model_turn(data["German"])

    # A tie carries no preference signal, so no example is produced.
    if target_score == translation_score:
        return None, None, None
    if target_score > translation_score:
        return user_turn, target_turn, translation_turn
    return user_turn, translation_turn, target_turn


def _read_scores(path):
    """Return the integer scores stored one per line in the file at *path*."""
    with open(path, "r", encoding="utf-8") as f:
        return [int(line.strip()) for line in f]


def main():
    """Print the preferred response for every baseline example.

    Reads the two score files under ``baseline/``, walks the
    ``baseline/*.jsonl`` files in sorted order, parses each file as a single
    JSON document and pairs it with the scores at the same position. For each
    example it prints the higher-scored model turn (``None`` on a tie) and
    whether the target score is greater than the translation score.
    Afterwards the ``google/gemma-3-4b-it`` tokenizer is loaded.

    Raises:
        ValueError: If a score line is not an integer, or if either score
            file has fewer entries than there are ``.jsonl`` files.
    """
    scores_target = _read_scores("baseline/target.gemba-gpt4.1.scores")
    scores_translation = _read_scores("baseline/translation.gemba-gpt4.1.scores")

    filenames = sorted(glob("baseline/*.jsonl"))

    # Fail up front with a readable message instead of an IndexError
    # part-way through the loop.
    available = min(len(scores_target), len(scores_translation))
    if len(filenames) > available:
        raise ValueError(
            f"found {len(filenames)} example files but only {available} scores "
            f"(target: {len(scores_target)}, translation: {len(scores_translation)})"
        )

    for filename, target_score, translation_score in zip(
        filenames, scores_target, scores_translation
    ):
        # Explicit UTF-8: the captions contain German text and the platform
        # default encoding is not guaranteed to be UTF-8.
        with open(filename, "r", encoding="utf-8") as f:
            data = json.load(f)
        _prompt, chosen, _rejected = format_example_for_gemma3_preferences(
            data, target_score, translation_score
        )
        print(f"c={chosen!r} {target_score > translation_score}")

    # Imported lazily so the import cost is only paid once the loop above
    # has completed.
    from transformers import AutoTokenizer
    model_id = "google/gemma-3-4b-it"
    # NOTE(review): the tokenizer is loaded but not used anywhere yet.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
        

# Call main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()