author    pks <pks@pks.rocks>  2025-11-20 20:47:31 +0100
committer pks <pks@pks.rocks>  2025-11-20 20:47:31 +0100
commit    7b8f054c80b0aecaadd8ae4fc6a6fe37cf1c749f (patch)
tree      1fe61426027c04fa7ca4b23fb8843fc20a7484b1 /inference.py
init
Diffstat (limited to 'inference.py')
-rw-r--r--  inference.py  129
1 files changed, 129 insertions, 0 deletions
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..fc8c1cc
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,129 @@
+#!/usr/bin/env python3
+
+import json
+import os
+import requests
+import sys
+import torch
+
+from glob import glob
+from PIL import Image
+from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+
+
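+# Build a multimodal chat prompt that asks the model to caption an image in a single
+# English sentence and translate the caption into German, answering as JSON.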
+def captioning_prompt(image):
+    return [
+        {
+            "role": "system",
+            "content": [{"type": "text", "text": "You are a professional English-German translator and also a renowned photography critic."}]
+        },
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "image": image},
+                {"type": "text", "text": "Write a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else."}
+            ]
+        }
+    ]
+
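+# Build a text-only chat prompt that asks the model to translate an existing English
+# caption into German, answering as JSON.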
+def translation_prompt(source):
+    return [
+        {
+            "role": "system",
+            "content": [{"type": "text", "text": "You are a professional English-German translator."}]
+        },
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": f"Translate the following caption into German. The output needs to be JSON, the only key being 'Translation' for the translation. Only output the JSON, nothing else. Caption: {source}"}
+            ]
+        }
+    ]
+
+
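+# Apply the processor's chat template to a prompt, tokenize it and move the resulting
+# tensors to the target device (floating-point tensors are cast to bfloat16).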
+def make_inputs(processor,
+                messages,
+                device):
+    return processor.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt"
+    ).to(device, dtype=torch.bfloat16)
+
+
+def main():
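+    # Load the base Gemma 3 model and its processor.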
+    model_id = "google/gemma-3-4b-it"
+    model = Gemma3ForConditionalGeneration.from_pretrained(
+        model_id,
+        device_map="cuda",
+        dtype=torch.bfloat16,
+    ).eval()
+    processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+
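+    # Optionally wrap the base model with a PEFT (e.g. LoRA) adapter passed as the
+    # first command-line argument.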
+    if len(sys.argv) == 2:
+        from peft import PeftModel
+        model = PeftModel.from_pretrained(model, sys.argv[1])
+
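+    # Hard-coded mode switch: the first branch re-translates the English captions stored
+    # in existing *.jsonl files, the second branch captions the images from scratch.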
+    if True:
+        for filename in glob("*.jsonl"):
+            sys.stderr.write(f"Processing {filename=}\n")
+            with open(filename, "r+") as f:
+                data = json.loads(f.read())
+                f.seek(0)
+
+                inputs = make_inputs(processor,
+                                     translation_prompt(data["English"]),
+                                     model.device)
+                input_len = inputs["input_ids"].shape[-1]
+
+                with torch.inference_mode():
+                    generation = model.generate(**inputs,
+                                                max_new_tokens=300,
+                                                do_sample=True,
+                                                top_p=1.0,
+                                                top_k=50)
+                    generation = generation[0][input_len:]
+
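+                # Strip the Markdown code fences the model tends to wrap around its JSON output.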
+                decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
+                try:
+                    new_data = json.loads(decoded)
+                except json.JSONDecodeError:
+                    sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
+                    # Skip this file instead of crashing on an undefined new_data below.
+                    continue
+
+                data.update(new_data)
+                f.write(json.dumps(data))
+                f.truncate()
+
+ sys.stderr.write(f"{decoded=}\n")
+    else:
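+        # Caption each image and write the model's raw JSON reply to a matching .jsonl file.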
+        for filename in glob("../Images/*.jpg"):
+            sys.stderr.write(f"Processing {filename=}\n")
+            inputs = make_inputs(processor,
+                                 captioning_prompt(Image.open(filename)),
+                                 model.device)
+
+            input_len = inputs["input_ids"].shape[-1]
+
+            with torch.inference_mode():
+                generation = model.generate(**inputs,
+                                            max_new_tokens=300,
+                                            do_sample=True,
+                                            top_p=1.0,
+                                            top_k=50)
+                generation = generation[0][input_len:]
+
+            decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
+            try:
+                _ = json.loads(decoded)
+            except json.JSONDecodeError:
+                sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
+
+ sys.stderr.write(f"{decoded=}\n")
+ with open(f"{os.path.basename(filename).removesuffix('.jpg')}.jsonl", "w") as f:
+ f.write(f"{decoded}\n")
+
+
+if __name__ == "__main__":
+    main()