summaryrefslogtreecommitdiff
path: root/test.py
blob: 35ceb5f189506e0e2a95d029c6ee7819112af52c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
import requests
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

from glob import glob
import json
import os
from datasets import Dataset


def make_chat_data(base="./baseline", limit=1, image_dir="../Images",
                   include_answer=False):
    """Build a ``Dataset`` of image/chat pairs for bilingual captioning.

    Each ``*.jsonl`` file under *base* (visited in sorted order) is paired
    with the JPEG of the same stem in *image_dir*. Every row holds the RGB
    image and a user turn asking the model to caption the image in English
    and German and to answer with JSON only.

    Args:
        base: Directory containing the ``*.jsonl`` caption files.
        limit: Maximum number of files to use (the first ``limit`` in sorted
            order); ``None`` uses all of them. Defaults to 1.
        image_dir: Directory holding ``<stem>.jpg`` for each caption file.
        include_answer: If true, append an assistant turn whose text is a
            JSON object built from the file's ``"English"`` and
            ``"Translation"`` values (under keys ``"English"``/``"German"``).

    Returns:
        A ``datasets.Dataset`` with columns ``"image"`` and ``"chat"``.

    Raises:
        FileNotFoundError: If a caption file has no matching image.
        json.JSONDecodeError: If a caption file is not one JSON document.
    """
    prompt = (
        "You are a professional English-German translator and also a "
        "renowned photography critic.\n\n"
        "Write a detailed caption for this image in a single sentence. "
        "Translate the caption into German. The output needs to be JSON, "
        "the keys being 'English' and 'German' for the respective captions. "
        "Only output the JSON, nothing else."
    )
    rows = []
    for filename in sorted(glob(f"{base}/*.jsonl"))[:limit]:
        # NOTE(review): despite the .jsonl suffix, each file is parsed as a
        # single JSON object; a multi-record JSON Lines file would fail here.
        with open(filename, "r", encoding="utf-8") as f:
            data = json.loads(f.read())
        # Compute the stem outside the f-string: reusing the same quote type
        # inside an f-string replacement field only parses on Python 3.12+.
        stem = os.path.basename(filename).removesuffix(".jsonl")
        image_path = f"{image_dir}/{stem}.jpg"
        # Close the underlying file right away; convert() returns an
        # independent in-memory image.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        chat = [{
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image"},
            ],
        }]
        if include_answer:
            answer = json.dumps(
                {"English": data["English"], "German": data["Translation"]}
            )
            chat.append({
                "role": "assistant",
                "content": [{"type": "text", "text": answer}],
            })
        rows.append({"image": image, "chat": chat})

    return Dataset.from_list(rows)


# Load the Gemma 3 instruction-tuned checkpoint and its processor in bfloat16.
model_id = "google/gemma-3-4b-it"
processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",
)
device = model.device
dataset = make_chat_data()

# Render each chat into a prompt string; add_generation_prompt=True appends
# the header for the model's turn so generation starts with the answer.
chat_prompt = processor.tokenizer.apply_chat_template(
    [item["chat"] for item in dataset],
    tokenize=False,
    add_generation_prompt=True,
)

print(dataset[0])

# The processor takes one list of images per prompt; a flat list would be
# read as a single prompt owning every image once the dataset has more than
# one row. padding=True lets prompts of different lengths share one tensor.
# NOTE(review): batched generation needs left padding — confirm
# processor.tokenizer.padding_side before raising make_chat_data's limit.
inputs = processor(
    text=chat_prompt,
    images=[[item["image"]] for item in dataset],
    padding=True,
    return_tensors="pt",
).to(device)

print("Keys in the output:", inputs.keys())

input_ids = inputs["input_ids"]
print("\nShape of input_ids:", input_ids.shape)
print("input_ids:", input_ids)
decoded_text = processor.decode(input_ids[0], skip_special_tokens=False)
print("\nDecoded input_ids (showing special tokens):")
print(decoded_text)

print("\n--- Generating a Response ---")
output = model.generate(
    **inputs,
    max_new_tokens=100,
)
# generate() returns the prompt tokens followed by the new ones; slice the
# prompt off so only the model's answer is printed.
prompt_length = input_ids.shape[-1]
generated_text = processor.decode(
    output[0][prompt_length:], skip_special_tokens=True
)
print("\nModel's response:\n", generated_text)