summaryrefslogtreecommitdiff
path: root/test.py
diff options
context:
space:
mode:
authorpks <pks@pks.rocks>2025-11-20 20:47:31 +0100
committerpks <pks@pks.rocks>2025-11-20 20:47:31 +0100
commit7b8f054c80b0aecaadd8ae4fc6a6fe37cf1c749f (patch)
tree1fe61426027c04fa7ca4b23fb8843fc20a7484b1 /test.py
init
Diffstat (limited to 'test.py')
-rw-r--r--test.py75
1 files changed, 75 insertions, 0 deletions
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..35ceb5f
--- /dev/null
+++ b/test.py
@@ -0,0 +1,75 @@
+import torch
+import requests
+from PIL import Image
+from transformers import AutoProcessor, AutoModelForCausalLM
+
+from glob import glob
+import json
+import os
+from datasets import Dataset
+
+
+def make_chat_data(base="./baseline"):
+ dataset = []
+ for filename in sorted(glob(f"{base}/*.jsonl"))[0:1]:
+ with open(filename, "r") as f:
+ data = json.loads(f.read())
+ image_path = f"../Images/{os.path.basename(filename).removesuffix(".jsonl")}.jpg"
+ image = Image.open(image_path).convert("RGB")
+ chat = [{
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "You are a professional English-German translator and also a renowned photography critic.\n\nWrite a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else."},
+ {"type": "image"}
+ ]
+ },
+ #{ "role": "assistant",
+ # "content": [{"type": "text", "text": '{"English": ' + json.dumps(data["English"]) + ', "German": ' + json.dumps(data["Translation"]) + '}'}]
+ #}
+ ]
+ item = {"image": image, "chat": chat}
+ dataset.append(item)
+
+ return Dataset.from_list(dataset)
+
+
+model_id = "google/gemma-3-4b-it"
+processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ dtype=torch.bfloat16,
+ device_map="auto",
+ attn_implementation="eager",
+)
+device = model.device
+dataset = make_chat_data()
+chat_prompt = processor.tokenizer.apply_chat_template(
+ [item["chat"] for item in dataset],
+ tokenize=False,
+ add_generation_prompt=True,
+)
+
+print(dataset[0])
+
+inputs = processor(
+ text=chat_prompt,
+ images=[item["image"] for item in dataset],
+ return_tensors="pt"
+).to(device)
+
+print("Keys in the output:", inputs.keys())
+
+input_ids = inputs["input_ids"]
+print("\nShape of input_ids:", input_ids.shape)
+print("input_ids:", input_ids)
+decoded_text = processor.decode(input_ids[0], skip_special_tokens=False)
+print("\nDecoded input_ids (showing special tokens):")
+print(decoded_text)
+pixel_values = inputs["pixel_values"]
+print("\n--- Generating a Response ---")
+output = model.generate(
+ **inputs,
+ max_new_tokens=100
+)
+generated_text = processor.decode(output[0], skip_special_tokens=True)
+print("\nModel's response:\n", generated_text)