author    pks <pks@pks.rocks>    2025-12-05 22:27:37 +0100
committer pks <pks@pks.rocks>    2025-12-05 22:27:37 +0100
commit    64f714542dbe8ee015afa94b3418d8f51c558070 (patch)
tree      af8692a3de61bbf4eafebc34f5d7576a3519e226 /inference.py
parent    dc4103e2c3c882134f1cdc9f2745aa3903f9aea7 (diff)
WIP
Diffstat (limited to 'inference.py')
-rwxr-xr-x    inference.py    184
1 file changed, 0 insertions(+), 184 deletions(-)
diff --git a/inference.py b/inference.py
deleted file mode 100755
index c8adb16..0000000
--- a/inference.py
+++ /dev/null
@@ -1,184 +0,0 @@
-#!/usr/bin/env python3
-
-import argparse
-import datasets
-import json
-import sys
-import torch
-
-from transformers import AutoProcessor, AutoModelForImageTextToText
-
-
-def clean_str(s):
-    return s.removeprefix("```json").removesuffix("```").replace("\n", "").strip()
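-# Illustrative sketch of the input clean_str is meant to handle: model output
-# fenced as a Markdown JSON block, e.g.
-#   clean_str('```json\n{"English": "A cat.", "German": "Eine Katze."}\n```')
-#   -> '{"English": "A cat.", "German": "Eine Katze."}'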
-
-
-def captioning_prompt(image):
-    return [
-        {
-            "role": "system",
-            "content": [{"type": "text", "text": "You are a professional English-German translator and also a renowned photography critic."}]
-        },
-        {
-            "role": "user",
-            "content": [
-                {"type": "image", "image": image},
-                {"type": "text", "text": "Write a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else."}
-            ]
-        }
-    ]
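-# With this prompt the model is expected to answer with a single JSON object
-# such as {"English": "A dog runs.", "German": "Ein Hund rennt."} (values
-# illustrative); the parsing code further down assumes exactly these keys.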
-
-
-def captioning_prompt_with_source(image, source):
-    caption = captioning_prompt(image)
-    prefix = json.dumps({"English": source}).removesuffix("}") + ', "German": "'
-
-    caption.append({"role": "assistant", "content": [{"type": "text", "text": prefix}]})
-
-    return caption
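-
-# Prefilling sketch: the partial assistant turn appended above constrains
-# decoding to continue the JSON after the given English caption. For
-# source = "A dog runs." (illustrative) the prefix is
-#   '{"English": "A dog runs.", "German": "'
-# so the model only needs to emit the German caption and the closing '"}'.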
-
-
-def translation_prompt(source):
-    return [
-        {
-            "role": "system",
-            "content": [{"type": "text", "text": "You are a professional English-German translator."}]
-        },
-        {
-            "role": "user",
-            "content": [
-                {"type": "text", "text": f"Translate the following caption into German. The output needs to be JSON, the only key being 'Translation' for the translation. Only output the JSON, nothing else. Caption: {source}"}
-            ]
-        }
-    ]
-
-
-def make_inputs(processor,
-                messages,
-                device):
-    # BatchFeature.to casts only floating-point tensors to the dtype, so
-    # input_ids stay integral while pixel_values become bfloat16.
-    return processor.apply_chat_template(
-        messages,
-        add_generation_prompt=True,
-        tokenize=True,
-        return_dict=True,
-        return_tensors="pt"
-    ).to(device, dtype=torch.bfloat16)
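-
-# Usage sketch (illustrative): returns the tokenized chat as a dict-like batch
-# with "input_ids" and "attention_mask" (plus "pixel_values" for image
-# prompts), already on the target device, e.g.
-#   inputs = make_inputs(processor, translation_prompt("A dog runs."), "cuda")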
-
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model", default="google/gemma-3-4b-it", type=str)
-    parser.add_argument("--attention-implementation", default="eager", type=str)
-    parser.add_argument("--lora-adapter", default=None, type=str)
-    parser.add_argument("--mode", choices=["from_scratch", "with_prefix", "translate"], type=str, required=True)
-    parser.add_argument("--dataset", default="asdf2k/caption_translation", type=str)
-    parser.add_argument("--data-subset", choices=["train", "dev", "test"], default="test", type=str)
-    # Sampling settings consumed by model.generate below (defaults here are illustrative).
-    parser.add_argument("--max-new-tokens", default=256, type=int)
-    parser.add_argument("--do-sample", action="store_true")
-    parser.add_argument("--temperature", default=1.0, type=float)
-    parser.add_argument("--top-p", default=0.95, type=float)
-    parser.add_argument("--top-k", default=64, type=int)
-    args = parser.parse_args()
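-
-    # Example invocation (output path illustrative); one JSON object is
-    # printed per example, so redirecting stdout yields a JSONL file:
-    #   ./inference.py --mode translate --data-subset test > translations.jsonl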
-
-    model = AutoModelForImageTextToText.from_pretrained(
-        args.model,
-        device_map="cuda",
-        dtype=torch.bfloat16,
-        attn_implementation=args.attention_implementation,
-    ).eval()
-    processor = AutoProcessor.from_pretrained(args.model, use_fast=True)
-
-    if args.lora_adapter:
-        from peft import PeftModel
-        model = PeftModel.from_pretrained(model, args.lora_adapter)
-
-    dataset = datasets.load_dataset(args.dataset)[args.data_subset]
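-    # Assumed dataset schema (not verified here): each example carries "id",
-    # "image" (a PIL image) and "assistant" (a JSON string with the keys
-    # "English" and "German").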
-
- if args.mode == "translate": # Generate German translation given English source
- for x in dataset:
- sys.stderr.write(f"Processing id={x['id']}\n")
-
- data = json.loads(x["assistant"])
-
- inputs = make_inputs(processor,
- translation_prompt(data["English"]),
- model.device)
- input_len = inputs["input_ids"].shape[-1]
-
- with torch.inference_mode():
- output = model.generate(**inputs,
- max_new_tokens=args.max_new_tokens,
- do_sample=args.do_sample,
- temperature=args.temperature,
- top_p=args.top_p,
- top_k=args.top_k,
- disable_compile=True)
- output = generation[0][input_len:]
-
- output = clean_str(processor.decode(output, skip_special_tokens=True))
-
- try:
- output = json.loads(output)
- except:
- sys.stderr.write(f"Error loading JSON from string '{output}' for id{x['id']}\n")
-
- print(json.dumps(output))
-
- elif args.mode == "from_scratch": # Generate caption & translation from scratch
- for x in dataset:
- image = x["image"]
- inputs = make_inputs(processor,
- captioning_prompt(image),
- model.device)
- input_len = inputs["input_ids"].shape[-1]
-
- with torch.inference_mode():
- output = model.generate(**inputs,
- max_new_tokens=args.max_new_tokens,
- do_sample=args.do_sample,
- temperature=args.temperature,
- top_p=args.top_p,
- top_k=args.top_k,
- eos_token_id=stop_token_ids,
- disable_compile=True)
- output = output[0][input_len:]
-
- output = clean_str(processor.decode(output, skip_special_tokens=True))
- try:
- output = json.loads(output)
- except:
- sys.stderr.write(f"Error loading JSON from string '{output}' for id{x['id']}\n")
-
- print(json.dumps(output))
-
- elif args.mode == "with_prefix": # Generate German translation given English caption and image
- for x in dataset:
- sys.stderr.write(f"Processing id={x['id']}\n")
- data = json.loads(x['assistant_reply'])
- prompt = captioning_prompt_with_source(Image.open(image), data["English"])
- inputs = make_inputs(processor,
- prompt,
- model.device)
- input_len = inputs["input_ids"].shape[-1] # Will not cut off assistant prefix
-
- with torch.inference_mode():
- output = model.generate(**inputs,
- max_new_tokens=args.max_new_tokens,
- do_sample=args.do_sample,
- args.temperature,
- top_p=args.top_p,
- top_k=args.top_k)
- output = generation[0][input_len:]
-
- output = clean_str(processor.decode(output, skip_special_tokens=True))
- try:
- output = json.loads(output)
- except:
- sys.stderr.write(f"Error loading JSON from string '{output}' for id{x['id']}\n")
-
- print(json.dumps(output))
- else:
- sys.stderr.write(f"Unkown mode '{args.mode}'")
-
-
-if __name__ == "__main__":
-    main()