#!/usr/bin/env python3
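"""Caption images with Gemma 3 and translate the captions into German.

Two modes, toggled in main(): translate the English captions already stored
in *.jsonl files, or write a fresh English/German caption pair for every
../Images/*.jpg. An optional command-line argument names a PEFT adapter
directory to load on top of the base model.
"""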

import json
import os
import sys
import torch

from glob import glob
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration


def captioning_prompt(image):
    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a professional English-German translator and also a renowned photography critic."}]
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Write a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else."}
            ]
        }
    ]

def translation_prompt(source):
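    """Chat-format prompt: translate an existing English caption into German, as JSON."""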
    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a professional English-German translator."}]
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": f"Translate the following caption into German. The output needs to be JSON, the only being 'Translation' for the translation. Only output the JSON, nothing else. Caption: {source}"}
            ]
        }
    ]


def make_inputs(processor,
                messages,
                device):
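    """Tokenize a chat-format message list and move the tensors to the model device.

    BatchFeature.to() casts only floating-point tensors (the image
    pixel_values) to bfloat16; integer token ids keep their dtype.
    """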
    return processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to(device, dtype=torch.bfloat16)


def main():
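    # Load the instruction-tuned 4B Gemma 3 checkpoint in bfloat16 on the GPU.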
    model_id = "google/gemma-3-4b-it"
    model = Gemma3ForConditionalGeneration.from_pretrained(
        model_id,
        device_map="cuda",
        dtype=torch.bfloat16,
    ).eval()
    processor = AutoProcessor.from_pretrained(model_id, use_fast=True)

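    # An optional command-line argument names a PEFT adapter (e.g. a fine-tuned
    # LoRA checkpoint) to layer on top of the base model.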
    if len(sys.argv) == 2:
        from peft import PeftModel
        model = PeftModel.from_pretrained(model, sys.argv[1])
    
    # Manual mode switch: True translates the captions already stored in
    # *.jsonl files; flip to False to caption the images in ../Images instead.
    translate_existing = True
    if translate_existing:
        for filename in glob("*.jsonl"):
            sys.stderr.write(f"Processing {filename=}\n")
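            # Each *.jsonl file holds a single JSON record; read it, then
            # rewind so the updated record can overwrite the file in place.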
            with open(filename, "r+") as f:
                data = json.loads(f.read())
                f.seek(0)

                inputs = make_inputs(processor,
                                     translation_prompt(data["English"]),
                                     model.device)
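                # Prompt length in tokens; used to slice the prompt off the
                # generated sequence below.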
                input_len = inputs["input_ids"].shape[-1]
                with torch.inference_mode():
                    generation = model.generate(**inputs,
                                                max_new_tokens=300,
                                                do_sample=True,
                                                top_p=1.0,
                                                top_k=50)
                    generation = generation[0][input_len:]
                
                # Strip the markdown code fence the model often wraps around its JSON.
                decoded = processor.decode(generation, skip_special_tokens=True).strip().removeprefix("```json").removesuffix("```").replace("\n", "").strip()
                try:
                    new_data = json.loads(decoded)
                except json.JSONDecodeError:
                    # Without this continue, new_data would be undefined below.
                    sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
                    continue

                data.update(new_data)
                f.write(json.dumps(data))
                f.truncate()

                sys.stderr.write(f"{decoded=}\n")
    else:
        for filename in glob("../Images/*.jpg"):
            sys.stderr.write(f"Processing {filename=}\n")
            inputs = make_inputs(processor,
                                 captioning_prompt(Image.open(filename)),
                                 model.device)
  
            input_len = inputs["input_ids"].shape[-1]
  
            with torch.inference_mode():
                generation = model.generate(**inputs,
                                            max_new_tokens=300,
                                            do_sample=True,
                                            top_p=1.0,
                                            top_k=50)
                generation = generation[0][input_len:]
            
            decoded = processor.decode(generation, skip_special_tokens=True).strip().removeprefix("```json").removesuffix("```").replace("\n", "").strip()
            try:
                # Validity check only; the raw model output is still written
                # below so malformed responses can be inspected.
                json.loads(decoded)
            except json.JSONDecodeError:
                sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
  
            sys.stderr.write(f"{decoded=}\n")
            with open(f"{os.path.basename(filename).removesuffix('.jpg')}.jsonl", "w") as f:
                f.write(f"{decoded}\n")


if __name__ == "__main__":
    main()