| author | pks <pks@pks.rocks> | 2025-12-05 22:27:37 +0100 |
|---|---|---|
| committer | pks <pks@pks.rocks> | 2025-12-05 22:27:37 +0100 |
| commit | 64f714542dbe8ee015afa94b3418d8f51c558070 | |
| tree | af8692a3de61bbf4eafebc34f5d7576a3519e226 | /inference.py |
| parent | dc4103e2c3c882134f1cdc9f2745aa3903f9aea7 | |
WIP
Diffstat (limited to 'inference.py')
| -rwxr-xr-x | inference.py | 184 |
1 file changed, 0 insertions, 184 deletions
diff --git a/inference.py b/inference.py
deleted file mode 100755
index c8adb16..0000000
--- a/inference.py
+++ /dev/null
@@ -1,184 +0,0 @@
-#!/usr/bin/env python3
-
-import argparse
-import datasets
-import json
-import sys
-import torch
-
-from transformers import AutoProcessor, AutoModelForImageTextToText
-
-
-def clean_str(s):
-    # Strip the markdown code fence the model tends to wrap its JSON in.
-    return s.removeprefix("```json").removesuffix("```").replace("\n", "").strip()
-
-
-def captioning_prompt(image):
-    return [
-        {
-            "role": "system",
-            "content": [{"type": "text", "text": "You are a professional English-German translator and also a renowned photography critic."}]
-        },
-        {
-            "role": "user",
-            "content": [
-                {"type": "image", "image": image},
-                {"type": "text", "text": "Write a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else."}
-            ]
-        }
-    ]
-
-
-def captioning_prompt_with_source(image, source):
-    caption = captioning_prompt(image)
-    # Partial assistant reply: valid JSON up to the opening quote of the
-    # German value, so the model only has to fill in the translation.
-    prefix = json.dumps({"English": source}).removesuffix("}") + ', "German": "'
-
-    caption.append({"role": "assistant", "content": [{"type": "text", "text": prefix}]})
-
-    return caption
-
-
-def translation_prompt(source):
-    return [
-        {
-            "role": "system",
-            "content": [{"type": "text", "text": "You are a professional English-German translator."}]
-        },
-        {
-            "role": "user",
-            "content": [
-                {"type": "text", "text": f"Translate the following caption into German. The output needs to be JSON, the only key being 'Translation' for the translation. Only output the JSON, nothing else. Caption: {source}"}
-            ]
-        }
-    ]
-
-
-def make_inputs(processor, messages, device, continue_final_message=False):
-    return processor.apply_chat_template(
-        messages,
-        # When the conversation ends in a partial assistant reply, continue
-        # that reply instead of opening a fresh assistant turn.
-        add_generation_prompt=not continue_final_message,
-        continue_final_message=continue_final_message,
-        tokenize=True,
-        return_dict=True,
-        return_tensors="pt"
-    ).to(device, dtype=torch.bfloat16)
-
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model", default="google/gemma-3-4b-it", type=str)
-    parser.add_argument("--attention-implementation", default="eager", type=str)
-    parser.add_argument("--lora-adapter", default=None, type=str)
-    parser.add_argument("--mode", choices=["from_scratch", "with_prefix", "translate"], type=str, required=True)
-    parser.add_argument("--dataset", default="asdf2k/caption_translation", type=str)
-    parser.add_argument("--data-subset", choices=["train", "dev", "test"], default="test", type=str)
-    # Decoding options used by the generate() calls below.
-    parser.add_argument("--max-new-tokens", default=256, type=int)
-    parser.add_argument("--do-sample", action="store_true")
-    parser.add_argument("--temperature", default=1.0, type=float)
-    parser.add_argument("--top-p", default=0.95, type=float)
-    parser.add_argument("--top-k", default=64, type=int)
-    args = parser.parse_args()
-
-    model = AutoModelForImageTextToText.from_pretrained(
-        args.model,
-        device_map="cuda",
-        dtype=torch.bfloat16,
-        attn_implementation=args.attention_implementation,
-    ).eval()
-    processor = AutoProcessor.from_pretrained(args.model, use_fast=True)
-
-    if args.lora_adapter:
-        from peft import PeftModel
-        model = PeftModel.from_pretrained(model, args.lora_adapter)
-
-    dataset = datasets.load_dataset(args.dataset)[args.data_subset]
-
-    if args.mode == "translate":  # Generate German translation given English source
-        for x in dataset:
-            sys.stderr.write(f"Processing id={x['id']}\n")
-
-            data = json.loads(x["assistant"])
-
-            inputs = make_inputs(processor,
-                                 translation_prompt(data["English"]),
-                                 model.device)
-            input_len = inputs["input_ids"].shape[-1]
-
-            with torch.inference_mode():
-                output = model.generate(**inputs,
-                                        max_new_tokens=args.max_new_tokens,
-                                        do_sample=args.do_sample,
-                                        temperature=args.temperature,
-                                        top_p=args.top_p,
-                                        top_k=args.top_k,
-                                        disable_compile=True)
-            output = output[0][input_len:]
-
-            output = clean_str(processor.decode(output, skip_special_tokens=True))
-
-            try:
-                output = json.loads(output)
-            except json.JSONDecodeError:
-                sys.stderr.write(f"Error loading JSON from string '{output}' for id={x['id']}\n")
-
-            print(json.dumps(output))
-
-    elif args.mode == "from_scratch":  # Generate caption & translation from scratch
-        for x in dataset:
-            sys.stderr.write(f"Processing id={x['id']}\n")
-
-            inputs = make_inputs(processor,
-                                 captioning_prompt(x["image"]),
-                                 model.device)
-            input_len = inputs["input_ids"].shape[-1]
-
-            with torch.inference_mode():
-                output = model.generate(**inputs,
-                                        max_new_tokens=args.max_new_tokens,
-                                        do_sample=args.do_sample,
-                                        temperature=args.temperature,
-                                        top_p=args.top_p,
-                                        top_k=args.top_k,
-                                        disable_compile=True)
-            output = output[0][input_len:]
-
-            output = clean_str(processor.decode(output, skip_special_tokens=True))
-
-            try:
-                output = json.loads(output)
-            except json.JSONDecodeError:
-                sys.stderr.write(f"Error loading JSON from string '{output}' for id={x['id']}\n")
-
-            print(json.dumps(output))
-
-    elif args.mode == "with_prefix":  # Generate German translation given English caption and image
-        for x in dataset:
-            sys.stderr.write(f"Processing id={x['id']}\n")
-
-            data = json.loads(x["assistant"])
-            prompt = captioning_prompt_with_source(x["image"], data["English"])
-            inputs = make_inputs(processor,
-                                 prompt,
-                                 model.device,
-                                 continue_final_message=True)
-            # input_len includes the assistant prefix, so the slicing below
-            # keeps only the generated continuation.
-            input_len = inputs["input_ids"].shape[-1]
-
-            with torch.inference_mode():
-                output = model.generate(**inputs,
-                                        max_new_tokens=args.max_new_tokens,
-                                        do_sample=args.do_sample,
-                                        temperature=args.temperature,
-                                        top_p=args.top_p,
-                                        top_k=args.top_k,
-                                        disable_compile=True)
-            output = output[0][input_len:]
-
-            output = clean_str(processor.decode(output, skip_special_tokens=True))
-
-            # Re-attach the forced prefix so the full assistant reply parses as JSON.
-            prefix = prompt[-1]["content"][0]["text"]
-            try:
-                output = json.loads(prefix + output)
-            except json.JSONDecodeError:
-                sys.stderr.write(f"Error loading JSON from string '{prefix + output}' for id={x['id']}\n")
-
-            print(json.dumps(output))
-    else:
-        sys.stderr.write(f"Unknown mode '{args.mode}'\n")
-
-
-if __name__ == "__main__":
-    main()
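The interesting part of the deleted script is the `with_prefix` mode: `captioning_prompt_with_source` appends a partial assistant reply so the model is forced to continue a half-finished JSON object rather than free-form one. Below is a minimal, model-free sketch of that prefix trick; the caption and the continuation strings are made up for illustration and do not come from the dataset.

```python
import json

# Hypothetical English caption, standing in for data["English"] from the dataset.
source = "A red fox crosses a snow-covered forest road at dawn."

# Build the partial assistant reply: valid JSON up to the opening quote of the
# German value, so the model only has to fill in the translation.
prefix = json.dumps({"English": source}).removesuffix("}") + ', "German": "'
print(prefix)
# {"English": "A red fox crosses a snow-covered forest road at dawn.", "German": "

# A plausible model continuation; re-attaching the prefix yields parseable JSON.
continuation = 'Ein Rotfuchs überquert im Morgengrauen eine schneebedeckte Waldstraße."}'
print(json.loads(prefix + continuation)["German"])
```

Constraining decoding this way keeps the English caption fixed across modes, which makes the German outputs directly comparable; the trade-off is that the model can still break the JSON by emitting an unescaped quote inside the translation, which is why the script logs parse failures to stderr instead of aborting.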
