summaryrefslogtreecommitdiff
path: root/test.py
diff options
context:
space:
mode:
Diffstat (limited to 'test.py')
-rw-r--r--test.py75
1 files changed, 0 insertions, 75 deletions
diff --git a/test.py b/test.py
deleted file mode 100644
index 35ceb5f..0000000
--- a/test.py
+++ /dev/null
@@ -1,75 +0,0 @@
-import torch
-import requests
-from PIL import Image
-from transformers import AutoProcessor, AutoModelForCausalLM
-
-from glob import glob
-import json
-import os
-from datasets import Dataset
-
-
-def make_chat_data(base="./baseline"):
- dataset = []
- for filename in sorted(glob(f"{base}/*.jsonl"))[0:1]:
- with open(filename, "r") as f:
- data = json.loads(f.read())
- image_path = f"../Images/{os.path.basename(filename).removesuffix(".jsonl")}.jpg"
- image = Image.open(image_path).convert("RGB")
- chat = [{
- "role": "user",
- "content": [
- {"type": "text", "text": "You are a professional English-German translator and also a renowned photography critic.\n\nWrite a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else."},
- {"type": "image"}
- ]
- },
- #{ "role": "assistant",
- # "content": [{"type": "text", "text": '{"English": ' + json.dumps(data["English"]) + ', "German": ' + json.dumps(data["Translation"]) + '}'}]
- #}
- ]
- item = {"image": image, "chat": chat}
- dataset.append(item)
-
- return Dataset.from_list(dataset)
-
-
-model_id = "google/gemma-3-4b-it"
-processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
-model = AutoModelForCausalLM.from_pretrained(
- model_id,
- dtype=torch.bfloat16,
- device_map="auto",
- attn_implementation="eager",
-)
-device = model.device
-dataset = make_chat_data()
-chat_prompt = processor.tokenizer.apply_chat_template(
- [item["chat"] for item in dataset],
- tokenize=False,
- add_generation_prompt=True,
-)
-
-print(dataset[0])
-
-inputs = processor(
- text=chat_prompt,
- images=[item["image"] for item in dataset],
- return_tensors="pt"
-).to(device)
-
-print("Keys in the output:", inputs.keys())
-
-input_ids = inputs["input_ids"]
-print("\nShape of input_ids:", input_ids.shape)
-print("input_ids:", input_ids)
-decoded_text = processor.decode(input_ids[0], skip_special_tokens=False)
-print("\nDecoded input_ids (showing special tokens):")
-print(decoded_text)
-pixel_values = inputs["pixel_values"]
-print("\n--- Generating a Response ---")
-output = model.generate(
- **inputs,
- max_new_tokens=100
-)
-generated_text = processor.decode(output[0], skip_special_tokens=True)
-print("\nModel's response:\n", generated_text)