#!/usr/bin/env python3
import argparse
import json
import os
import sys
from glob import glob

import datasets
import torch
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration


def clean_str(s):
    """Strip Markdown code fences and newlines from a model response."""
    return s.removeprefix("```json").removesuffix("```").replace("\n", "").strip()


def captioning_prompt(image):
    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a professional English-German translator and also a renowned photography critic."}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Write a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else."},
            ],
        },
    ]


def captioning_prompt_with_source(image, source):
    """Captioning prompt plus a partial assistant turn that already contains the
    English caption, so the model only has to complete the German part.

    Returns both the messages and the literal prefix string; the caller needs the
    prefix to reassemble a full JSON object from the generated continuation.
    """
    messages = captioning_prompt(image)
    prefix = json.dumps({"English": source}).removesuffix("}") + ', "German": "'
    messages.append({"role": "assistant", "content": [{"type": "text", "text": prefix}]})
    return messages, prefix


def translation_prompt(source):
    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a professional English-German translator."}],
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": f"Translate the following caption into German. The output needs to be JSON, the only key being 'Translation' for the translation. Only output the JSON, nothing else. Caption: {source}"}
            ],
        },
    ]


def make_inputs(processor, messages, device, continue_final_message=False):
    # continue_final_message resumes a partial assistant turn instead of opening
    # a fresh one; it is mutually exclusive with add_generation_prompt in
    # transformers' chat templating.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=not continue_final_message,
        continue_final_message=continue_final_message,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    # The dtype cast only affects floating-point tensors (pixel values);
    # integer token ids are left untouched.
    return inputs.to(device, dtype=torch.bfloat16)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="google/gemma-3-4b-it")
    parser.add_argument("--lora-adapter", default=None)
    parser.add_argument("--mode", choices=["from_scratch", "with_prefix", "translate"])
    parser.add_argument("--dataset", default="asdf2k/caption_translation", type=str)
    parser.add_argument("--data-subset", choices=["train", "dev", "test"], default="test")
    args = parser.parse_args()

    model = Gemma3ForConditionalGeneration.from_pretrained(
        args.model,
        device_map="cuda",
        dtype=torch.bfloat16,
        attn_implementation="eager",
    ).eval()
    processor = AutoProcessor.from_pretrained(args.model, use_fast=True)

    if args.lora_adapter:
        from peft import PeftModel
        model = PeftModel.from_pretrained(model, args.lora_adapter)

    dataset = datasets.load_dataset(args.dataset)[args.data_subset]

    if args.mode == "translate":
        # Generate a German translation given only the English source caption.
        for x in dataset:
            sys.stderr.write(f"Processing {x['id']=}\n")
            data = json.loads(x["assistant"])
            inputs = make_inputs(processor, translation_prompt(data["English"]), model.device)
            input_len = inputs["input_ids"].shape[-1]
            with torch.inference_mode():
                generation = model.generate(**inputs, max_new_tokens=300, do_sample=True, top_p=1.0, top_k=50)
            generation = generation[0][input_len:]
            decoded = clean_str(processor.decode(generation, skip_special_tokens=True))
            try:
                new_data = json.loads(decoded)
            except json.JSONDecodeError:
                sys.stderr.write(f"Error loading JSON from string '{decoded}' for {x['id']=}\n")
                continue
            data.update(new_data)
            print(json.dumps(data))
    elif args.mode == "from_scratch":
        # Generate caption & translation from scratch.
        for x in dataset:
            image = x["image"]
            inputs = make_inputs(processor, captioning_prompt(image), model.device)
            input_len = inputs["input_ids"].shape[-1]
            with torch.inference_mode():
                generation = model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.8, top_p=1.0, top_k=50, disable_compile=True)
            generation = generation[0][input_len:]
            decoded = clean_str(processor.decode(generation, skip_special_tokens=True))
            try:
                _ = json.loads(decoded)  # validate only; the raw string is written out below
            except json.JSONDecodeError:
                sys.stderr.write(f"Error loading JSON from string '{decoded}' for {x['id']=}\n")
            sys.stderr.write(f"{decoded=}\n")
            # One JSONL file per example, named after the dataset id.
            with open(f"{x['id']}.jsonl", "w") as f:
                f.write(f"{decoded}\n")
    elif args.mode == "with_prefix":
        # Generate the German translation given the English caption and the image.
        for filename in glob("./baseline/files_test/*.jsonl"):
            image = "../d/Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
            sys.stderr.write(f"Processing {filename=}\n")
            with open(filename, "r") as f:
                data = json.loads(f.read())
            messages, prefix = captioning_prompt_with_source(Image.open(image), data["English"])
            inputs = make_inputs(processor, messages, model.device, continue_final_message=True)
            input_len = inputs["input_ids"].shape[-1]
            with torch.inference_mode():
                generation = model.generate(**inputs, max_new_tokens=300, do_sample=True, top_p=1.0, top_k=50)
            generation = generation[0]  # batch size 1
            truncated_generation = generation[input_len:]
            # Slicing at input_len cuts off the assistant prefix along with the
            # prompt, so glue the prefix back on to obtain a complete JSON object.
            decoded = clean_str(prefix + processor.decode(truncated_generation, skip_special_tokens=True))
            try:
                _ = json.loads(decoded)
            except json.JSONDecodeError:
                sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
            sys.stderr.write(f"{decoded=}\n")
            # Written to the current directory, not back to the input path.
            with open(os.path.basename(filename), "w") as f:
                f.write(f"{decoded}\n")
    else:
        sys.stderr.write(f"Unknown mode '{args.mode}'\n")


if __name__ == "__main__":
    main()
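# Example invocations (the script name is illustrative; the flags match the
# argparse definition above, paths must be adjusted to your setup):
#
#   python caption_eval.py --mode translate > translations.jsonl
#   python caption_eval.py --mode from_scratch
#   python caption_eval.py --mode with_prefix --lora-adapter ./my-adapter
#
# translate streams merged JSON records to stdout; from_scratch and with_prefix
# write one .jsonl file per example into the current directory.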