#!/usr/bin/env python3
"""Caption images and/or translate captions into German with an image-text-to-text model.

Three modes are supported (``--mode``):

* ``from_scratch`` -- the model sees the image and writes an English caption
  plus its German translation.
* ``with_prefix``  -- the model sees the image and a prefilled assistant
  message that already contains the English caption.
* ``translate``    -- the model sees only the English caption (no image) and
  translates it.

For every dataset example one line ``<id>\\t<JSON>`` is printed to stdout;
progress and parse errors go to stderr.
"""
import argparse
import json
import os
import sys
from glob import glob

import datasets
import requests
import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor

# Turn terminator used by Gemma chat models; generation stops on it in
# addition to the tokenizer's regular EOS token.
END_OF_TURN_TOKEN = "<end_of_turn>"


def clean_str(s: str) -> str:
    """Strip Markdown code fences and newlines from a model reply.

    Surrounding whitespace is removed *before* the fences are looked for, so
    replies such as ``"```json\\n{...}\\n```\\n"`` are unwrapped correctly.
    """
    s = s.strip()
    s = s.removeprefix("```json").removeprefix("```").removesuffix("```")
    return s.replace("\n", "").strip()


def captioning_prompt(image) -> list:
    """Build the chat messages asking for an English caption and its German translation."""
    return [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are a professional English-German translator and also a renowned photography critic.",
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {
                    "type": "text",
                    "text": "Write a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else.",
                },
            ],
        },
    ]


def captioning_prompt_with_source(image, source: str) -> list:
    """Build the captioning prompt plus an assistant message prefilled with the English caption.

    The assistant message holds an unfinished JSON object of the form
    ``{"English": "<source>", "German": "``.
    """
    prompt = captioning_prompt(image)
    # Serialise via json.dumps so the caption is correctly escaped, then
    # re-open the object to leave room for the German value.
    prefix = json.dumps({"English": source}).removesuffix("}") + ', "German": "'
    prompt.append({"role": "assistant", "content": [{"type": "text", "text": prefix}]})
    return prompt


def translation_prompt(source: str) -> list:
    """Build the text-only chat messages asking for a German translation of ``source``."""
    # NOTE(review): the wording "the only being 'Translation'" looks like a
    # typo for "the only key being"; it is kept verbatim because changing the
    # prompt can change model (or adapter) behaviour -- confirm before editing.
    return [
        {
            "role": "system",
            "content": [
                {"type": "text", "text": "You are a professional English-German translator."}
            ],
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": (
                        "Translate the following caption into German. The output needs to be JSON, the only being 'Translation' for the translation. Only output the JSON, nothing else. \n"
                        f"Caption: {source}"
                    ),
                }
            ],
        },
    ]


def make_inputs(processor, messages, device):
    """Render ``messages`` with the chat template and move the tensors to ``device``."""
    # NOTE(review): add_generation_prompt=True opens a fresh model turn even
    # when the last message is a prefilled assistant message (with_prefix
    # mode). If a true continuation of that prefix is intended,
    # continue_final_message=True would be needed instead -- confirm.
    return processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device, dtype=torch.bfloat16)


def _stop_token_ids(tokenizer) -> list:
    """Return the token ids generation should stop on (EOS plus end-of-turn if known)."""
    stop_ids = [tokenizer.eos_token_id]
    end_of_turn_id = tokenizer.convert_tokens_to_ids(END_OF_TURN_TOKEN)
    # Tokenizers that do not know the token answer with None or the unknown-token
    # id; neither must be used as a stop criterion.
    if end_of_turn_id is not None and end_of_turn_id != tokenizer.unk_token_id:
        stop_ids.append(end_of_turn_id)
    return stop_ids


def generate_and_parse(model, processor, messages, args, example_id=None):
    """Run generation for ``messages`` and try to parse the reply as JSON.

    Args:
        model: Model exposing ``generate`` and ``device``.
        processor: Processor providing the chat template, tokenizer and ``decode``.
        messages: Chat messages as built by the ``*_prompt`` helpers.
        args: Parsed command-line arguments (sampling settings are read from it).
        example_id: Optional identifier used in log messages only.

    Returns:
        The parsed JSON value, or the cleaned reply text when it is not valid JSON.
    """
    sys.stderr.write(f"Processing {example_id=}\n")
    inputs = make_inputs(processor, messages, model.device)
    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(
            **inputs,
            max_new_tokens=args.max_new_tokens,
            do_sample=not args.do_not_sample,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            eos_token_id=_stop_token_ids(processor.tokenizer),
            disable_compile=True,
        )

    # Drop the prompt tokens; only the newly generated part is decoded.
    output_tokens = generation[0][input_len:]
    output_text = clean_str(processor.decode(output_tokens, skip_special_tokens=True))

    try:
        return json.loads(output_text)
    except json.JSONDecodeError:
        where = f" for id={example_id}" if example_id is not None else ""
        sys.stderr.write(f"Error loading JSON from string '{output_text}'{where}\n")
        return output_text


def _emit(example_id, output) -> None:
    """Print one result line: the example id, a tab, and the JSON-encoded output."""
    print(f"{example_id}\t{json.dumps(output)}")


def main() -> None:
    """Parse the command line, load model and data, and run the selected mode."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="google/gemma-3-4b-it", type=str)
    parser.add_argument("--attention-implementation", default="eager", type=str)
    parser.add_argument("--lora-adapter", default=None, type=str)
    parser.add_argument(
        "--mode",
        choices=["from_scratch", "with_prefix", "translate"],
        type=str,
        required=True,
    )
    parser.add_argument("--dataset", default="asdf2k/caption_translation", type=str)
    parser.add_argument(
        "--data-subset", choices=["train", "dev", "test"], default="test", type=str
    )
    parser.add_argument("--max-new-tokens", default=300, type=int)
    # top-p and temperature are fractional; parsing them as int would reject
    # values such as 0.9.
    parser.add_argument("--top-p", default=1.0, type=float)
    parser.add_argument("--top-k", default=50, type=int)
    parser.add_argument("--temperature", default=0.8, type=float)
    parser.add_argument("--do-not-sample", action="store_true")
    args = parser.parse_args()

    model = AutoModelForImageTextToText.from_pretrained(
        args.model,
        device_map="cuda",
        dtype=torch.bfloat16,
        attn_implementation=args.attention_implementation,
    ).eval()
    processor = AutoProcessor.from_pretrained(args.model, use_fast=True)

    if args.lora_adapter:
        # Imported lazily so peft is only required when an adapter is used.
        from peft import PeftModel

        model = PeftModel.from_pretrained(model, args.lora_adapter)

    dataset = datasets.load_dataset(args.dataset)[args.data_subset]

    if args.mode == "from_scratch":
        # Generate caption & translation from scratch
        for x in dataset:
            output = generate_and_parse(
                model,
                processor,
                captioning_prompt(x["image"]),
                args,
                example_id=x["id"],
            )
            _emit(x["id"], output)
    elif args.mode == "translate":
        # Generate German translation given English source
        for x in dataset:
            english = json.loads(x["assistant"])["English"]
            parsed = generate_and_parse(
                model,
                processor,
                translation_prompt(english),
                args,
                example_id=x["id"],
            )
            if isinstance(parsed, dict) and "Translation" in parsed:
                output = {"English": english, "German": parsed["Translation"]}
            else:
                # The reply was not the requested JSON object; pass it through
                # unchanged (as the other modes do) instead of crashing.
                sys.stderr.write(
                    f"No 'Translation' key in model output for id={x['id']}; "
                    "emitting raw output\n"
                )
                output = parsed
            _emit(x["id"], output)
    elif args.mode == "with_prefix":
        # Generate German translation given English caption and image
        for x in dataset:
            english = json.loads(x["assistant"])["English"]
            output = generate_and_parse(
                model,
                processor,
                captioning_prompt_with_source(x["image"], english),
                args,
                example_id=x["id"],
            )
            _emit(x["id"], output)
    else:
        # Unreachable while argparse enforces ``choices``; kept as a safeguard.
        sys.stderr.write(f"Unknown mode '{args.mode}'\n")


if __name__ == "__main__":
    main()