diff options
| author | pks <pks@pks.rocks> | 2025-11-20 20:47:31 +0100 |
|---|---|---|
| committer | pks <pks@pks.rocks> | 2025-11-20 20:47:31 +0100 |
| commit | 7b8f054c80b0aecaadd8ae4fc6a6fe37cf1c749f (patch) | |
| tree | 1fe61426027c04fa7ca4b23fb8843fc20a7484b1 /test.py | |
init
Diffstat (limited to 'test.py')
| -rw-r--r-- | test.py | 75 |
1 files changed, 75 insertions, 0 deletions
@@ -0,0 +1,75 @@ +import torch +import requests +from PIL import Image +from transformers import AutoProcessor, AutoModelForCausalLM + +from glob import glob +import json +import os +from datasets import Dataset + + +def make_chat_data(base="./baseline"): + dataset = [] + for filename in sorted(glob(f"{base}/*.jsonl"))[0:1]: + with open(filename, "r") as f: + data = json.loads(f.read()) + image_path = f"../Images/{os.path.basename(filename).removesuffix(".jsonl")}.jpg" + image = Image.open(image_path).convert("RGB") + chat = [{ + "role": "user", + "content": [ + {"type": "text", "text": "You are a professional English-German translator and also a renowned photography critic.\n\nWrite a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else."}, + {"type": "image"} + ] + }, + #{ "role": "assistant", + # "content": [{"type": "text", "text": '{"English": ' + json.dumps(data["English"]) + ', "German": ' + json.dumps(data["Translation"]) + '}'}] + #} + ] + item = {"image": image, "chat": chat} + dataset.append(item) + + return Dataset.from_list(dataset) + + +model_id = "google/gemma-3-4b-it" +processor = AutoProcessor.from_pretrained(model_id, use_fast=True) +model = AutoModelForCausalLM.from_pretrained( + model_id, + dtype=torch.bfloat16, + device_map="auto", + attn_implementation="eager", +) +device = model.device +dataset = make_chat_data() +chat_prompt = processor.tokenizer.apply_chat_template( + [item["chat"] for item in dataset], + tokenize=False, + add_generation_prompt=True, +) + +print(dataset[0]) + +inputs = processor( + text=chat_prompt, + images=[item["image"] for item in dataset], + return_tensors="pt" +).to(device) + +print("Keys in the output:", inputs.keys()) + +input_ids = inputs["input_ids"] +print("\nShape of input_ids:", input_ids.shape) +print("input_ids:", input_ids) +decoded_text = processor.decode(input_ids[0], skip_special_tokens=False) +print("\nDecoded input_ids (showing special tokens):") +print(decoded_text) +pixel_values = inputs["pixel_values"] +print("\n--- Generating a Response ---") +output = model.generate( + **inputs, + max_new_tokens=100 +) +generated_text = processor.decode(output[0], skip_special_tokens=True) +print("\nModel's response:\n", generated_text) |
