Diffstat (limited to 'test.py')
-rw-r--r--  test.py  75
1 file changed, 0 insertions, 75 deletions
diff --git a/test.py b/test.py
deleted file mode 100644
index 35ceb5f..0000000
--- a/test.py
+++ /dev/null
@@ -1,75 +0,0 @@
-import torch
-import requests
-from PIL import Image
-from transformers import AutoProcessor, AutoModelForCausalLM
-
-from glob import glob
-import json
-import os
-from datasets import Dataset
-
-
-def make_chat_data(base="./baseline"):
-    dataset = []
-    for filename in sorted(glob(f"{base}/*.jsonl"))[0:1]:
-        with open(filename, "r") as f:
-            data = json.loads(f.read())
-        image_path = f"../Images/{os.path.basename(filename).removesuffix(".jsonl")}.jpg"
-        image = Image.open(image_path).convert("RGB")
-        chat = [{
-            "role": "user",
-            "content": [
-                {"type": "text", "text": "You are a professional English-German translator and also a renowned photography critic.\n\nWrite a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else."},
-                {"type": "image"}
-            ]
-        },
-        #{ "role": "assistant",
-        #  "content": [{"type": "text", "text": '{"English": ' + json.dumps(data["English"]) + ', "German": ' + json.dumps(data["Translation"]) + '}'}]
-        #}
-        ]
-        item = {"image": image, "chat": chat}
-        dataset.append(item)
-
-    return Dataset.from_list(dataset)
-
-
-model_id = "google/gemma-3-4b-it"
-processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    dtype=torch.bfloat16,
-    device_map="auto",
-    attn_implementation="eager",
-)
-device = model.device
-dataset = make_chat_data()
-chat_prompt = processor.tokenizer.apply_chat_template(
-    [item["chat"] for item in dataset],
-    tokenize=False,
-    add_generation_prompt=True,
-)
-
-print(dataset[0])
-
-inputs = processor(
-    text=chat_prompt,
-    images=[item["image"] for item in dataset],
-    return_tensors="pt"
-).to(device)
-
-print("Keys in the output:", inputs.keys())
-
-input_ids = inputs["input_ids"]
-print("\nShape of input_ids:", input_ids.shape)
-print("input_ids:", input_ids)
-decoded_text = processor.decode(input_ids[0], skip_special_tokens=False)
-print("\nDecoded input_ids (showing special tokens):")
-print(decoded_text)
-pixel_values = inputs["pixel_values"]
-print("\n--- Generating a Response ---")
-output = model.generate(
-    **inputs,
-    max_new_tokens=100
-)
-generated_text = processor.decode(output[0], skip_special_tokens=True)
-print("\nModel's response:\n", generated_text)
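For context: the deleted script drove the multimodal Gemma 3 checkpoint through AutoModelForCausalLM and applied the chat template via processor.tokenizer, building text and pixel inputs in separate steps. (Note also that the nested double quotes in the image_path f-string are only valid syntax on Python 3.12+.) The pattern documented on the google/gemma-3-4b-it model card instead loads Gemma3ForConditionalGeneration and lets processor.apply_chat_template produce input_ids and pixel_values in one call. A minimal standalone sketch of that path follows; the image path example.jpg is hypothetical, and a transformers release with Gemma 3 support (roughly 4.50+) is assumed.

# Minimal sketch of the model-card path for image+text generation with
# Gemma 3. Not part of this repo; assumes transformers >= ~4.50.
import torch
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

model_id = "google/gemma-3-4b-it"
processor = AutoProcessor.from_pretrained(model_id)
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()

image = Image.open("example.jpg").convert("RGB")  # hypothetical path
messages = [{
    "role": "user",
    "content": [
        # A PIL image (or a URL/path string) can be attached directly
        # to the message content.
        {"type": "image", "image": image},
        {"type": "text", "text": "Write a one-sentence caption for this image."},
    ],
}]

# The processor's chat template inserts the image tokens and returns
# input_ids together with pixel_values.
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device, dtype=torch.bfloat16)

with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=100)

# Decode only the newly generated tokens, not the echoed prompt.
new_tokens = output[0, inputs["input_ids"].shape[-1]:]
print(processor.decode(new_tokens, skip_special_tokens=True))

Slicing the output at input_ids.shape[-1] is what keeps the printed response free of the prompt text that model.generate echoes back, which the deleted script's skip_special_tokens=True decode alone did not do.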
