#!/usr/bin/env python3
import argparse
import json
import os
import sys
from glob import glob

import torch
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration


def captioning_prompt(image):
    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a professional English-German translator and also a renowned photography critic."}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Write a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else."},
            ],
        },
    ]


def captioning_prompt_with_source(image, source):
    # Seed the assistant turn with partial JSON so the model only has to
    # complete the German value. Return the prefix as well so the caller can
    # reassemble the full JSON object after generation.
    messages = captioning_prompt(image)
    prefix = json.dumps({"English": source}).removesuffix("}") + ', "German": "'
    messages.append({"role": "assistant", "content": [{"type": "text", "text": prefix}]})
    return messages, prefix


def translation_prompt(source):
    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a professional English-German translator."}],
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": f"Translate the following caption into German. The output needs to be JSON, the only key being 'Translation' for the translation. Only output the JSON, nothing else. Caption: {source}"},
            ],
        },
    ]


def make_inputs(processor, messages, device, continue_final_message=False):
    # BatchFeature.to() only casts floating-point tensors (the pixel values),
    # so the integer input_ids are unaffected by the bfloat16 cast.
    return processor.apply_chat_template(
        messages,
        add_generation_prompt=not continue_final_message,
        continue_final_message=continue_final_message,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device, dtype=torch.bfloat16)


def extract_json(text):
    # The model tends to wrap its answer in a Markdown ```json fence; strip it
    # and flatten newlines so the result is a single JSON line.
    return text.strip().removeprefix("```json").removesuffix("```").replace("\n", "").strip()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="google/gemma-3-4b-it")
    parser.add_argument("--lora-adapter", default=None)
    parser.add_argument("--mode", choices=["from_scratch", "with_prefix", "translate"], required=True)
    args = parser.parse_args()

    model = Gemma3ForConditionalGeneration.from_pretrained(
        args.model,
        device_map="cuda",
        dtype=torch.bfloat16,
    ).eval()
    processor = AutoProcessor.from_pretrained(args.model, use_fast=True)

    if args.lora_adapter:
        from peft import PeftModel
        model = PeftModel.from_pretrained(model, args.lora_adapter)

    if args.mode == "translate":
        # Generate the German translation given only the English source caption.
        for filename in glob("*.jsonl"):
            sys.stderr.write(f"Processing {filename=}\n")
            try:
                with open(filename, "r+") as f:
                    data = json.loads(f.read())
                    inputs = make_inputs(processor, translation_prompt(data["English"]), model.device)
                    input_len = inputs["input_ids"].shape[-1]
                    with torch.inference_mode():
                        generation = model.generate(**inputs, max_new_tokens=300, do_sample=True, top_p=1.0, top_k=50)
                    generation = generation[0][input_len:]
                    decoded = extract_json(processor.decode(generation, skip_special_tokens=True))
                    try:
                        new_data = json.loads(decoded)
                    except json.JSONDecodeError:
                        sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
                        continue
                    data.update(new_data)
                    f.seek(0)
                    f.write(json.dumps(data))
                    f.truncate()
                    sys.stderr.write(f"{decoded=}\n")
            except (OSError, KeyError, json.JSONDecodeError) as e:
                sys.stderr.write(f"Skipping {filename=}: {e}\n")
    elif args.mode == "from_scratch":
        # Generate caption & translation from scratch.
        for filename in glob("../d/Images/*.jpg"):
            sys.stderr.write(f"Processing {filename=}\n")
            inputs = make_inputs(processor, captioning_prompt(Image.open(filename)), model.device)
            input_len = inputs["input_ids"].shape[-1]
            with torch.inference_mode():
                generation = model.generate(
                    **inputs, max_new_tokens=300,
                    do_sample=True, top_p=1.0, top_k=50,
                )
            generation = generation[0][input_len:]
            decoded = extract_json(processor.decode(generation, skip_special_tokens=True))
            try:
                _ = json.loads(decoded)
            except json.JSONDecodeError:
                sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
            sys.stderr.write(f"{decoded=}\n")
            with open(f"{os.path.basename(filename).removesuffix('.jpg')}.jsonl", "w") as f:
                f.write(f"{decoded}\n")
    elif args.mode == "with_prefix":
        # Generate the German translation given the English caption and the
        # image, forcing the assistant turn to start with the partial JSON.
        for filename in glob("./baseline/files_test/*.jsonl"):
            image = "../d/Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
            sys.stderr.write(f"Processing {filename=}\n")
            with open(filename) as f:
                data = json.loads(f.read())
            messages, prefix = captioning_prompt_with_source(Image.open(image), data["English"])
            # continue_final_message=True keeps the chat template open after
            # the assistant prefix so the model continues it rather than
            # starting a fresh turn.
            inputs = make_inputs(processor, messages, model.device, continue_final_message=True)
            input_len = inputs["input_ids"].shape[-1]
            with torch.inference_mode():
                generation = model.generate(**inputs, max_new_tokens=300, do_sample=True, top_p=1.0, top_k=50)
            generation = generation[0]  # batch size 1
            # Slicing at input_len also cuts off the forced assistant prefix,
            # so prepend it again to recover the complete JSON object.
            completion = processor.decode(generation[input_len:], skip_special_tokens=True)
            decoded = extract_json(prefix + completion)
            try:
                _ = json.loads(decoded)
            except json.JSONDecodeError:
                sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
            sys.stderr.write(f"{decoded=}\n")
            with open(os.path.basename(filename), "w") as f:
                f.write(f"{decoded}\n")
    else:
        sys.stderr.write(f"Unknown mode '{args.mode}'\n")


if __name__ == "__main__":
    main()
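# Example invocations (a sketch: the script name "caption.py" and the adapter
# path "./gemma3-mt-lora" are hypothetical; the data layout, per-image *.jsonl
# files alongside ../d/Images/*.jpg, is inferred from the globs above):
#
#   python caption.py --mode from_scratch
#   python caption.py --mode with_prefix
#   python caption.py --mode translate --lora-adapter ./gemma3-mt-lora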