#!/usr/bin/env python3
"""Caption images and/or translate captions English -> German with an
image-text-to-text model (default: google/gemma-3-4b-it).

One JSON value is printed to stdout per dataset example; progress and parse
errors go to stderr.

Modes:
  translate     -- translate the reference English caption (text only).
  from_scratch  -- caption the image and translate the caption.
  with_prefix   -- force the reference English caption as the start of the
                   model's JSON answer and let the model complete the German.
"""
import argparse
import datasets
import json
import os
import requests
import sys
import torch
from glob import glob
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText

# NOTE(review): os, requests, glob and Image appear unused in this script.


def clean_str(s):
    """Strip an optional Markdown ```json fence and all newlines from model output.

    Surrounding whitespace is stripped *before* removing the fence so that a
    trailing newline after the closing ``` does not prevent its removal.
    """
    return s.strip().removeprefix("```json").removesuffix("```").replace("\n", "").strip()


def captioning_prompt(image):
    """Build a chat asking for a one-sentence English caption of *image* and
    its German translation, answered as JSON with keys 'English' and 'German'.

    *image* is placed verbatim in the user message's image slot.
    """
    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a professional English-German translator and also a renowned photography critic."}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Write a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else."},
            ],
        },
    ]


def captioning_prompt_with_source(image, source):
    """Return captioning_prompt(image) plus a partial assistant answer.

    The appended assistant message is the JSON object prefix
    '{"English": <source>, "German": "', so the model only has to produce
    the German caption and close the object.
    """
    caption = captioning_prompt(image)
    # json.dumps gives '{"English": "..."}'; drop the closing brace and open
    # the "German" string value.
    prefix = json.dumps({"English": source}).removesuffix("}") + ', "German": "'
    caption.append({"role": "assistant", "content": [{"type": "text", "text": prefix}]})
    return caption


def translation_prompt(source):
    """Build a text-only chat asking to translate *source* into German,
    answered as JSON with the single key 'Translation'."""
    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a professional English-German translator."}],
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": f"Translate the following caption into German. The output needs to be JSON, the only key being 'Translation' for the translation. Only output the JSON, nothing else.\nCaption: {source}"},
            ],
        },
    ]


def make_inputs(processor, messages, device, continue_final_message=False):
    """Tokenize *messages* with the processor's chat template and move them to *device*.

    Args:
        processor: the model's AutoProcessor.
        messages: chat in the list-of-role/content-dicts format.
        device: target device (e.g. ``model.device``).
        continue_final_message: if True, the final (assistant) message is left
            open so generation continues it (used for prefix forcing); if
            False, a fresh assistant turn is opened instead. The two options
            are mutually exclusive in transformers.

    Returns:
        A BatchFeature of PyTorch tensors. Its ``.to(..., dtype=...)`` only
        casts floating-point tensors (e.g. pixel values) to bfloat16; integer
        token ids are just moved to *device*.
    """
    return processor.apply_chat_template(
        messages,
        add_generation_prompt=not continue_final_message,
        continue_final_message=continue_final_message,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device, dtype=torch.bfloat16)


def _generate_text(model, processor, messages, args, continue_final_message=False):
    """Run generation for one chat and return only the newly decoded text."""
    inputs = make_inputs(processor, messages, model.device, continue_final_message=continue_final_message)
    # Tokens before input_len are the prompt; slice them off so only the
    # model's continuation is decoded.
    input_len = inputs["input_ids"].shape[-1]
    gen_kwargs = {"max_new_tokens": args.max_new_tokens, "do_sample": args.do_sample}
    # Only override sampling parameters given on the command line; otherwise
    # the model's own generation_config supplies them.
    for name in ("temperature", "top_p", "top_k"):
        value = getattr(args, name)
        if value is not None:
            gen_kwargs[name] = value
    with torch.inference_mode():
        generation = model.generate(**inputs, **gen_kwargs, disable_compile=True)
    return processor.decode(generation[0][input_len:], skip_special_tokens=True)


def _parse_output(text, example_id):
    """Parse cleaned model output as JSON.

    On a JSON error, a warning is written to stderr and the cleaned raw string
    is returned instead (best effort), so every example still yields a line.
    """
    text = clean_str(text)
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        sys.stderr.write(f"Error loading JSON from string '{text}' for id={example_id}\n")
        return text


def main():
    """Parse CLI args, load model/processor/dataset, and print one JSON line per example."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="google/gemma-3-4b-it", type=str)
    parser.add_argument("--attention-implementation", default="eager", type=str)
    parser.add_argument("--lora-adapter", default=None, type=str)
    parser.add_argument("--mode", choices=["from_scratch", "with_prefix", "translate"], type=str, required=True)
    parser.add_argument("--dataset", default="asdf2k/caption_translation", type=str)
    parser.add_argument("--data-subset", choices=["train", "dev", "test"], default="test", type=str)
    # Generation options (previously referenced but never defined).
    parser.add_argument("--max-new-tokens", default=256, type=int)
    parser.add_argument("--do-sample", action="store_true")
    parser.add_argument("--temperature", default=None, type=float)
    parser.add_argument("--top-p", default=None, type=float)
    parser.add_argument("--top-k", default=None, type=int)
    args = parser.parse_args()

    # Unreachable given argparse `choices`, kept as a defensive guard.
    if args.mode not in ("translate", "from_scratch", "with_prefix"):
        sys.stderr.write(f"Unknown mode '{args.mode}'\n")
        return

    model = AutoModelForImageTextToText.from_pretrained(
        args.model,
        device_map="cuda",
        dtype=torch.bfloat16,
        attn_implementation=args.attention_implementation,
    ).eval()
    processor = AutoProcessor.from_pretrained(args.model, use_fast=True)

    if args.lora_adapter:
        from peft import PeftModel
        model = PeftModel.from_pretrained(model, args.lora_adapter)

    dataset = datasets.load_dataset(args.dataset)[args.data_subset]

    for x in dataset:
        sys.stderr.write(f"Processing id={x['id']}\n")
        if args.mode == "translate":
            # Generate German translation given English source.
            # NOTE(review): this mode reads column "assistant" while
            # with_prefix reads "assistant_reply"; confirm the dataset schema.
            data = json.loads(x["assistant"])
            text = _generate_text(model, processor, translation_prompt(data["English"]), args)
        elif args.mode == "from_scratch":
            # Generate caption & translation from scratch.
            text = _generate_text(model, processor, captioning_prompt(x["image"]), args)
        else:  # with_prefix
            # Generate German translation given English caption and image.
            # x["image"] is passed as-is, exactly as in from_scratch.
            data = json.loads(x["assistant_reply"])
            prompt = captioning_prompt_with_source(x["image"], data["English"])
            # The forced prefix is part of the prompt, so the decoded
            # continuation excludes it; re-attach it to obtain full JSON.
            prefix = prompt[-1]["content"][0]["text"]
            text = prefix + _generate_text(model, processor, prompt, args, continue_final_message=True)
        print(json.dumps(_parse_output(text, x["id"])))


if __name__ == "__main__":
    main()