#!/usr/bin/env python3
import json
import os
import sys

import torch
from glob import glob
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration


def captioning_prompt(image):
    """Chat messages asking for an English caption plus its German translation, returned as JSON."""
    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a professional English-German translator and also a renowned photography critic."}]
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Write a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else."}
            ]
        }
    ]


def translation_prompt(source):
    """Chat messages asking for a German translation of an existing English caption, returned as JSON."""
    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a professional English-German translator."}]
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": f"Translate the following caption into German. The output needs to be JSON, the only key being 'Translation' for the translation. Only output the JSON, nothing else. Caption: {source}"}
            ]
        }
    ]


def make_inputs(processor, messages, device):
    """Tokenize the chat messages and move them to the model's device.

    Casting to bfloat16 only affects floating-point tensors (the pixel values);
    the integer token ids keep their dtype.
    """
    return processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to(device, dtype=torch.bfloat16)

def main():
    model_id = "google/gemma-3-4b-it"
    model = Gemma3ForConditionalGeneration.from_pretrained(
        model_id,
        device_map="cuda",
        dtype=torch.bfloat16,
    ).eval()
    processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
    # If a single argument is given, treat it as a PEFT adapter directory and
    # load the fine-tuned weights on top of the base model.
    if len(sys.argv) == 2:
        from peft import PeftModel
        model = PeftModel.from_pretrained(model, sys.argv[1])
    # Manual toggle: True re-translates the English captions stored in the
    # *.jsonl files, False captions the images under ../Images/ from scratch.
    if True:
        for filename in glob("*.jsonl"):
            sys.stderr.write(f"Processing {filename=}\n")
            with open(filename, "r+") as f:
                data = json.loads(f.read())
                f.seek(0)
                inputs = make_inputs(processor,
                                     translation_prompt(data["English"]),
                                     model.device)
                input_len = inputs["input_ids"].shape[-1]
                with torch.inference_mode():
                    generation = model.generate(**inputs,
                                                max_new_tokens=300,
                                                do_sample=True,
                                                top_p=1.0,
                                                top_k=50)
                generation = generation[0][input_len:]
                # Strip the Markdown JSON code fence the model tends to wrap its output in.
                decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
                try:
                    new_data = json.loads(decoded)
                except json.JSONDecodeError:
                    sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
                    continue
                data.update(new_data)
                f.write(json.dumps(data))
                f.truncate()
                sys.stderr.write(f"{decoded=}\n")
    else:
        for filename in glob("../Images/*.jpg"):
            sys.stderr.write(f"Processing {filename=}\n")
            inputs = make_inputs(processor,
                                 captioning_prompt(Image.open(filename)),
                                 model.device)
            input_len = inputs["input_ids"].shape[-1]
            with torch.inference_mode():
                generation = model.generate(**inputs,
                                            max_new_tokens=300,
                                            do_sample=True,
                                            top_p=1.0,
                                            top_k=50)
            generation = generation[0][input_len:]
            decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
            # Warn about unparseable output, but still write it out so it can be inspected.
            try:
                json.loads(decoded)
            except json.JSONDecodeError:
                sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
            sys.stderr.write(f"{decoded=}\n")
            with open(f"{os.path.basename(filename).removesuffix('.jpg')}.jsonl", "w") as f:
                f.write(f"{decoded}\n")


if __name__ == "__main__":
    main()
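As written, the script expects to be run from the directory holding the caption .jsonl files, with the images one level up in ../Images/. Invoked with no arguments it uses the base google/gemma-3-4b-it checkpoint; invoked with a single argument it treats that argument as a PEFT adapter directory and loads the fine-tuned weights on top of the base model. The hard-coded if True: toggle in main() selects between the two modes: translating the existing English captions versus captioning the images from scratch.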