#!/usr/bin/env python3
"""Caption images and/or translate captions English -> German with a chat VLM.

Modes (``--mode``):

* ``from_scratch``: show the image and ask for an English caption plus its
  German translation, as JSON with keys ``English`` and ``German``.
* ``translate``: text-only; translate the dataset's reference English caption.
* ``with_prefix``: show the image, pre-fill the assistant turn with the
  reference English caption and let the model complete the ``German`` field.

Each result is printed to stdout as ``<id><TAB><json>``; progress messages and
JSON parse errors go to stderr.
"""
import argparse
import codecs  # NOTE(review): unused in this script.
import json
import os  # NOTE(review): unused in this script.
import sys
from glob import glob  # NOTE(review): unused in this script.

import datasets
import requests  # NOTE(review): unused in this script.
import torch
from PIL import Image  # NOTE(review): unused in this script.
from transformers import AutoModelForImageTextToText, AutoProcessor

# Turn terminator used by the Gemma chat template (the default --model is a
# Gemma 3 checkpoint). NOTE(review): other model families use other markers.
END_OF_TURN = "<end_of_turn>"


def clean_str(s):
    """Strip an optional Markdown code fence and all newlines from model output.

    Escape sequences are deliberately NOT decoded here: ``json.loads`` already
    understands ``\\uXXXX`` / ``\\"`` escapes, whereas decoding with the
    ``unicode_escape`` codec re-reads non-ASCII text as Latin-1 (turning e.g.
    "Mädchen" into "MÃ¤dchen") and un-escapes quotes that JSON still needs.
    """
    # Strip first so fences surrounded by whitespace/newlines are still removed.
    s = s.strip()
    s = s.removeprefix("```json").removeprefix("```").removesuffix("```")
    # Literal newlines are not allowed inside JSON strings; drop them.
    return s.replace("\n", "").strip()


def captioning_prompt(image):
    """Build a chat prompt asking for an English caption of *image* and its German translation.

    The model is instructed to answer with JSON keyed by 'English' and 'German'.
    """
    return [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are a professional English-German translator and also a renowned photography critic.",
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {
                    "type": "text",
                    "text": "Write a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else.",
                },
            ],
        },
    ]


def captioning_prompt_with_source(image, source):
    """Captioning prompt whose assistant turn is pre-filled with the English caption *source*.

    The appended assistant message is an unterminated JSON object,
    ``{"English": <source>, "German": "``, so the model only has to complete
    the German string.
    """
    prompt = captioning_prompt(image)
    # json.dumps takes care of quoting/escaping the caption; drop its closing
    # brace so the object can be continued with the "German" field.
    prefix = json.dumps({"English": source}, ensure_ascii=False).removesuffix("}") + ', "German": "'
    prompt.append({"role": "assistant", "content": [{"type": "text", "text": prefix}]})
    return prompt


def translation_prompt(source):
    """Build a text-only chat prompt asking for a German translation of *source*.

    The model is instructed to answer with JSON of the form {"Translation": ...}.
    """
    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a professional English-German translator."}],
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": (
                        "Translate the following caption into German. The output needs to be JSON, "
                        "the only key being 'Translation' for the translation. "
                        "Only output the JSON, nothing else.\n"
                        f"Caption: {source}"
                    ),
                }
            ],
        },
    ]


def make_inputs(processor, messages, device):
    """Render *messages* with the chat template (plus generation prompt) and tokenize them.

    Returns the processor's tensor dict moved to *device*; the ``dtype``
    argument is presumably applied only to floating-point tensors such as
    pixel values (transformers ``BatchFeature.to``) -- token ids stay integer.
    """
    return processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device, dtype=torch.bfloat16)


def _first_image(messages):
    """Return the first ``{"type": "image"}`` payload found in *messages*, or None."""
    for message in messages:
        for item in message["content"]:
            if item.get("type") == "image":
                return item["image"]
    return None


def make_inputs_with_prefix(processor, messages, device):
    """Tokenize *messages* so the model continues the final (assistant) message.

    Returns ``(inputs, prefix)`` where *prefix* is the rendered prompt text
    with the trailing end-of-turn marker removed.
    """
    rendered = processor.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
    # The template closes the pre-filled assistant turn with
    # "<end_of_turn>\n"; remove both so generation continues inside the turn.
    prefix = rendered.removesuffix("\n").removesuffix(END_OF_TURN)

    # The rendered template may already begin with BOS; letting the tokenizer
    # add special tokens again would produce a duplicated BOS.
    bos = processor.tokenizer.bos_token
    add_special_tokens = not (bos and prefix.startswith(bos))

    ret = processor(
        text=prefix,
        images=_first_image(messages),
        add_special_tokens=add_special_tokens,
        return_tensors="pt",
    ).to(device, dtype=torch.bfloat16)
    return ret, prefix


def generate_and_parse(model, processor, messages, args, example_id):
    """Generate a reply to *messages* and parse it as JSON.

    Returns the parsed JSON value, or the cleaned raw text if it is not valid
    JSON (the failure is reported on stderr). In ``with_prefix`` mode the
    pre-filled JSON head from the prompt is prepended before parsing.
    """
    sys.stderr.write(f"Processing {example_id=}\n")
    prefix = None
    if args.mode == "with_prefix":
        inputs, prefix = make_inputs_with_prefix(processor, messages, model.device)
    else:
        inputs = make_inputs(processor, messages, model.device)
    input_len = inputs["input_ids"].shape[-1]

    # Stop on the tokenizer's EOS or on the chat template's end-of-turn marker.
    stop_token_ids = [
        processor.tokenizer.eos_token_id,
        processor.tokenizer.convert_tokens_to_ids(END_OF_TURN),
    ]
    with torch.inference_mode():
        generation = model.generate(
            **inputs,
            max_new_tokens=args.max_new_tokens,
            do_sample=not args.do_not_sample,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            eos_token_id=stop_token_ids,
            disable_compile=True,
        )
    # Keep only the newly generated tokens, not the echoed prompt.
    output_tokens = generation[0][input_len:]
    output_text = clean_str(processor.decode(output_tokens, skip_special_tokens=True))

    if prefix is not None:
        # Re-attach the pre-filled JSON head ({"English": ..., "German": ")
        # that the model was continuing.
        output_text = prefix[prefix.index("{"):] + output_text

    try:
        return json.loads(output_text)
    except json.JSONDecodeError:
        sys.stderr.write(f"Error loading JSON from string '{output_text}' for id={example_id}\n")
        return output_text


def _parse_args():
    """Parse the command-line options."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="google/gemma-3-4b-it", type=str)
    parser.add_argument("--attention-implementation", default="eager", type=str)
    parser.add_argument("--lora-adapter", default=None, type=str)
    parser.add_argument("--mode", choices=["from_scratch", "with_prefix", "translate"], type=str, required=True)
    parser.add_argument("--dataset", default="asdf2k/caption_translation", type=str)
    parser.add_argument("--data-subset", choices=["train", "dev", "test"], default="test", type=str)
    parser.add_argument("--max-new-tokens", default=300, type=int)
    # Sampling parameters are real-valued; type=int rejected values like 0.9.
    parser.add_argument("--top-p", default=1.0, type=float)
    parser.add_argument("--top-k", default=50, type=int)
    parser.add_argument("--temperature", default=0.8, type=float)
    parser.add_argument("--do-not-sample", action="store_true")
    return parser.parse_args()


def _print_result(example_id, output):
    """Write one ``<id><TAB><json>`` result line to stdout."""
    print(f"{example_id}\t{json.dumps(output, ensure_ascii=False)}")


def _extract_translation(output, example_id):
    """Return ``output["Translation"]``; fall back to *output* itself (and warn) if absent."""
    if isinstance(output, dict) and "Translation" in output:
        return output["Translation"]
    sys.stderr.write(f"No 'Translation' key in model output for id={example_id}\n")
    return output


def main():
    """Load model, processor and dataset, then run the selected ``--mode``."""
    args = _parse_args()

    model = AutoModelForImageTextToText.from_pretrained(
        args.model,
        device_map="cuda",
        dtype=torch.bfloat16,
        attn_implementation=args.attention_implementation,
    ).eval()
    processor = AutoProcessor.from_pretrained(args.model, use_fast=True)
    if args.lora_adapter:
        # Imported lazily so peft is only needed when an adapter is given.
        from peft import PeftModel

        model = PeftModel.from_pretrained(model, args.lora_adapter)

    # Rows are read via "id", "image" and "assistant" (a JSON string whose
    # "English" key holds the reference caption).
    dataset = datasets.load_dataset(args.dataset)[args.data_subset]

    if args.mode == "from_scratch":
        # Generate caption & translation from scratch.
        for x in dataset:
            output = generate_and_parse(
                model,
                processor,
                captioning_prompt(x["image"]),
                args,
                example_id=x["id"],
            )
            _print_result(x["id"], output)
    elif args.mode == "translate":
        # Generate German translation given the English source only.
        for x in dataset:
            english = json.loads(x["assistant"])["English"]
            output = generate_and_parse(
                model,
                processor,
                translation_prompt(english),
                args,
                example_id=x["id"],
            )
            _print_result(x["id"], {"English": english, "German": _extract_translation(output, x["id"])})
    elif args.mode == "with_prefix":
        # Generate German translation given the English caption and the image.
        for x in dataset:
            english = json.loads(x["assistant"])["English"]
            output = generate_and_parse(
                model,
                processor,
                captioning_prompt_with_source(x["image"], english),
                args,
                example_id=x["id"],
            )
            _print_result(x["id"], output)
    else:
        # Unreachable while argparse restricts --mode through `choices`.
        sys.stderr.write(f"Unknown mode '{args.mode}'\n")


if __name__ == "__main__":
    main()