#!/usr/bin/env python3
"""Caption images with Gemma 3 and translate English captions into German."""
import json
import os
import sys
from glob import glob

import torch
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration


def captioning_prompt(image):
    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a professional English-German translator and also a renowned photography critic."}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Write a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else."},
            ],
        },
    ]


def captioning_prompt_with_source(image, source):
    # Prefill the assistant turn with the English caption so the model only has
    # to complete the German half of the JSON object.
    caption = captioning_prompt(image)
    prefix = json.dumps({"English": source}).removesuffix("}") + ', "German": "'
    caption.append({"role": "assistant", "content": [{"type": "text", "text": prefix}]})
    return caption


def translation_prompt(source):
    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a professional English-German translator."}],
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": f"Translate the following caption into German. The output needs to be JSON, the only key being 'Translation' for the translation. Only output the JSON, nothing else. Caption: {source}"},
            ],
        },
    ]


def make_inputs(processor, messages, device):
    # Tokenize the chat messages; floating-point tensors (pixel values) are cast
    # to bfloat16 to match the model weights.
    return processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device, dtype=torch.bfloat16)


def main():
    model_id = "google/gemma-3-4b-it"
    model = Gemma3ForConditionalGeneration.from_pretrained(
        model_id,
        device_map="cuda",
        dtype=torch.bfloat16,
    ).eval()
    processor = AutoProcessor.from_pretrained(model_id, use_fast=True)

    # Optionally load a fine-tuned PEFT/LoRA adapter passed as the first argument.
    if len(sys.argv) == 2:
        from peft import PeftModel
        model = PeftModel.from_pretrained(model, sys.argv[1])

    # Mode toggle: exactly one of the branches below should be enabled.
    if True:
        # Generate German translation given English source
        for filename in glob("*.jsonl"):
            sys.stderr.write(f"Processing {filename=}\n")
            with open(filename, "r+") as f:
                data = json.loads(f.read())
                f.seek(0)
                inputs = make_inputs(processor, translation_prompt(data["English"]), model.device)
                input_len = inputs["input_ids"].shape[-1]
                with torch.inference_mode():
                    generation = model.generate(**inputs, max_new_tokens=300, do_sample=True, top_p=1.0, top_k=50)
                    generation = generation[0][input_len:]
                # Strip the Markdown code fences the model tends to wrap around its JSON output.
                decoded = (
                    processor.decode(generation, skip_special_tokens=True)
                    .removeprefix("```json")
                    .removesuffix("```")
                    .replace("\n", "")
                    .strip()
                )
                try:
                    new_data = json.loads(decoded)
                except json.JSONDecodeError:
                    sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
                    continue  # leave the file untouched if the model output is not valid JSON
                data.update(new_data)
                f.write(json.dumps(data))
                f.truncate()
            sys.stderr.write(f"{decoded=}\n")
    elif False:
        # Generate caption & translation from scratch
        for filename in glob("../Images/*.jpg"):
            sys.stderr.write(f"Processing {filename=}\n")
            inputs = make_inputs(processor, captioning_prompt(Image.open(filename)), model.device)
            input_len = inputs["input_ids"].shape[-1]
            with torch.inference_mode():
                generation = model.generate(**inputs, max_new_tokens=300, do_sample=True, top_p=1.0, top_k=50)
                generation = generation[0][input_len:]
            decoded = (
                processor.decode(generation, skip_special_tokens=True)
                .removeprefix("```json")
                .removesuffix("```")
                .replace("\n", "")
                .strip()
            )
            try:
                _ = json.loads(decoded)
            except json.JSONDecodeError:
                sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
            sys.stderr.write(f"{decoded=}\n")
            with open(f"{os.path.basename(filename).removesuffix('.jpg')}.jsonl", "w") as f:
                f.write(f"{decoded}\n")
    elif False:
        # Generate German translation given English caption and image
        for filename in glob("./baseline/files_test/*.jsonl"):
            image = "../Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
            sys.stderr.write(f"Processing {filename=}\n")
            with open(filename) as f:
                data = json.loads(f.read())
            prompt = captioning_prompt_with_source(Image.open(image), data["English"])
            inputs = make_inputs(processor, prompt, model.device)
            input_len = inputs["input_ids"].shape[-1]  # Will not cut off assistant prefix
            with torch.inference_mode():
                generation = model.generate(**inputs, max_new_tokens=300, do_sample=True, top_p=1.0, top_k=50)
                generation = generation[0]  # batch size 1
                truncated_generation = generation[input_len:]
            decoded = (
                processor.decode(truncated_generation, skip_special_tokens=True)
                .removeprefix("```json")
                .removesuffix("```")
                .replace("\n", "")
                .strip()
            )
            try:
                _ = json.loads(decoded)
            except json.JSONDecodeError:
                sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
            sys.stderr.write(f"{decoded=}\n")
            with open(os.path.basename(filename), "w") as f:
                f.write(f"{decoded}\n")


if __name__ == "__main__":
    main()
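# ---------------------------------------------------------------------------
# Usage sketch (the script filename and directory layout below are assumptions,
# not part of the original code):
#
#   python3 caption_translate.py                    # base model only
#   python3 caption_translate.py ./my-lora-adapter  # additionally load a PEFT adapter
#
# The enabled branch reads every *.jsonl file in the current directory, expects
# a single JSON object such as {"English": "A fox jumps over a log."}, asks the
# model for a German translation, and merges the model's JSON reply (e.g.
# {"Translation": "..."}) back into the same file. The other branches are
# selected by editing the if/elif toggles in main().
# ---------------------------------------------------------------------------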