import json
import os
from glob import glob

import torch
from datasets import Dataset
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor


def make_chat_data(base="./baseline"):
    dataset = []
    # Only the first file is used for now; each .jsonl file is expected to hold a
    # single JSON object with the reference "English" and "Translation" captions.
    for filename in sorted(glob(f"{base}/*.jsonl"))[0:1]:
        with open(filename, "r") as f:
            data = json.loads(f.read())
        # The matching image lives next to the annotations, named after the file stem.
        image_path = f"../Images/{os.path.basename(filename).removesuffix('.jsonl')}.jpg"
        image = Image.open(image_path).convert("RGB")
        chat = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": (
                            "You are a professional English-German translator and also a renowned photography critic.\n\n"
                            "Write a detailed caption for this image in a single sentence. "
                            "Translate the caption into German. The output needs to be JSON, "
                            "the keys being 'English' and 'German' for the respective captions. "
                            "Only output the JSON, nothing else."
                        ),
                    },
                    {"type": "image"},
                ],
            },
            # {
            #     "role": "assistant",
            #     "content": [{"type": "text", "text": '{"English": ' + json.dumps(data["English"]) + ', "German": ' + json.dumps(data["Translation"]) + '}'}],
            # },
        ]
        dataset.append({"image": image, "chat": chat})
    return Dataset.from_list(dataset)


model_id = "google/gemma-3-4b-it"

processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
# Gemma 3 4B IT is a multimodal checkpoint, so load it through the
# image-text-to-text auto class to get the vision tower as well.
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",
)
device = model.device

dataset = make_chat_data()

# Render the chat turns into prompt strings; the processor's chat template
# inserts the image placeholder tokens for us.
chat_prompt = processor.apply_chat_template(
    [item["chat"] for item in dataset],
    tokenize=False,
    add_generation_prompt=True,
)
print(dataset[0])

inputs = processor(
    text=chat_prompt,
    images=[item["image"] for item in dataset],
    padding=True,  # no-op for a single example, needed once more files are batched
    return_tensors="pt",
).to(device)

print("Keys in the output:", inputs.keys())

input_ids = inputs["input_ids"]
print("\nShape of input_ids:", input_ids.shape)
print("input_ids:", input_ids)

decoded_text = processor.decode(input_ids[0], skip_special_tokens=False)
print("\nDecoded input_ids (showing special tokens):")
print(decoded_text)

pixel_values = inputs["pixel_values"]

print("\n--- Generating a Response ---")
output = model.generate(**inputs, max_new_tokens=100)

generated_text = processor.decode(output[0], skip_special_tokens=True)
print("\nModel's response:\n", generated_text)
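
# Optional follow-up (a minimal sketch, not part of the original script): the decode
# above returns the prompt text plus the answer, so slice off the prompt tokens to
# inspect only what the model generated. Assumes a single example with no left-padding.
prompt_length = inputs["input_ids"].shape[1]
response_only = processor.decode(output[0][prompt_length:], skip_special_tokens=True)
print("\nGenerated tokens only:\n", response_only)

# If the model followed the instruction to emit bare JSON, the captions can be parsed
# directly; guard with try/except since the model may still wrap the JSON in extra text.
try:
    captions = json.loads(response_only)
    print("English:", captions.get("English"))
    print("German:", captions.get("German"))
except json.JSONDecodeError:
    print("Response was not valid JSON; inspect it manually.")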