| author | pks <pks@pks.rocks> | 2025-11-20 20:47:31 +0100 |
|---|---|---|
| committer | pks <pks@pks.rocks> | 2025-11-20 20:47:31 +0100 |
| commit | 7b8f054c80b0aecaadd8ae4fc6a6fe37cf1c749f (patch) | |
| tree | 1fe61426027c04fa7ca4b23fb8843fc20a7484b1 /inference.py | |
init
Diffstat (limited to 'inference.py')
| -rw-r--r-- | inference.py | 129 |
1 files changed, 129 insertions, 0 deletions
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..fc8c1cc
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,129 @@
+#!/usr/bin/env python3
+
+import json
+import os
+import requests
+import sys
+import torch
+
+from glob import glob
+from PIL import Image
+from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+
+
+def captioning_prompt(image):
+    return [
+        {
+            "role": "system",
+            "content": [{"type": "text", "text": "You are a professional English-German translator and also a renowned photography critic."}]
+        },
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "image": image},
+                {"type": "text", "text": "Write a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else."}
+            ]
+        }
+    ]
+
+
+def translation_prompt(source):
+    return [
+        {
+            "role": "system",
+            "content": [{"type": "text", "text": "You are a professional English-German translator."}]
+        },
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": f"Translate the following caption into German. The output needs to be JSON, the only key being 'Translation' for the translation. Only output the JSON, nothing else. Caption: {source}"}
+            ]
+        }
+    ]
+
+
+def make_inputs(processor,
+                messages,
+                device):
+    return processor.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt"
+    ).to(device, dtype=torch.bfloat16)
+
+
+def main():
+    model_id = "google/gemma-3-4b-it"
+    model = Gemma3ForConditionalGeneration.from_pretrained(
+        model_id,
+        device_map="cuda",
+        dtype=torch.bfloat16,
+    ).eval()
+    processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+
+    # Optionally stack a fine-tuned PEFT adapter on top of the base model.
+    if len(sys.argv) == 2:
+        from peft import PeftModel
+        model = PeftModel.from_pretrained(model, sys.argv[1])
+
+    # Toggle between the translation pass (True) and the captioning pass (False).
+    if True:
+        # Translation pass: re-translate the English caption stored in each
+        # .jsonl file and merge the result back into the same file.
+        for filename in glob("*.jsonl"):
+            sys.stderr.write(f"Processing {filename=}\n")
+            with open(filename, "r+") as f:
+                data = json.loads(f.read())
+                f.seek(0)
+
+                inputs = make_inputs(processor,
+                                     translation_prompt(data["English"]),
+                                     model.device)
+                input_len = inputs["input_ids"].shape[-1]
+
+                with torch.inference_mode():
+                    generation = model.generate(**inputs,
+                                                max_new_tokens=300,
+                                                do_sample=True,
+                                                top_p=1.0,
+                                                top_k=50)
+                    generation = generation[0][input_len:]
+
+                decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
+                try:
+                    new_data = json.loads(decoded)
+                except json.JSONDecodeError:
+                    sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
+                    continue
+
+                data.update(new_data)
+                f.write(json.dumps(data))
+                f.truncate()
+
+                sys.stderr.write(f"{decoded=}\n")
+    else:
+        # Captioning pass: caption every image and write one .jsonl file per image.
+        for filename in glob("../Images/*.jpg"):
+            sys.stderr.write(f"Processing {filename=}\n")
+            inputs = make_inputs(processor,
+                                 captioning_prompt(Image.open(filename)),
+                                 model.device)
+
+            input_len = inputs["input_ids"].shape[-1]
+
+            with torch.inference_mode():
+                generation = model.generate(**inputs,
+                                            max_new_tokens=300,
+                                            do_sample=True,
+                                            top_p=1.0,
+                                            top_k=50)
+                generation = generation[0][input_len:]
+
+            decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
+            try:
+                _ = json.loads(decoded)
+            except json.JSONDecodeError:
+                sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
+
+            sys.stderr.write(f"{decoded=}\n")
+            with open(f"{os.path.basename(filename).removesuffix('.jpg')}.jsonl", "w") as f:
+                f.write(f"{decoded}\n")
+
+
+if __name__ == "__main__":
+    main()
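For reference, each run of the captioning pass is expected to leave behind one `<image>.jsonl` file holding a single JSON object with `English` and `German` keys, which the translation pass then augments with a `Translation` key. A minimal sketch (not part of this commit) for sanity-checking those files might look like this:

```python
#!/usr/bin/env python3
# Minimal sketch, not part of this commit: sanity-check the .jsonl files
# written by inference.py. Assumes each file holds exactly one JSON object.

import json
import sys
from glob import glob

REQUIRED_KEYS = {"English", "German"}   # written by the captioning pass
OPTIONAL_KEYS = {"Translation"}         # added later by the translation pass

for filename in glob("*.jsonl"):
    with open(filename) as f:
        try:
            data = json.loads(f.read())
        except json.JSONDecodeError as e:
            sys.stderr.write(f"{filename}: invalid JSON ({e})\n")
            continue

    missing = REQUIRED_KEYS - data.keys()
    extra = data.keys() - REQUIRED_KEYS - OPTIONAL_KEYS
    if missing:
        sys.stderr.write(f"{filename}: missing keys {sorted(missing)}\n")
    if extra:
        sys.stderr.write(f"{filename}: unexpected keys {sorted(extra)}\n")
```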
