import json
import os
from functools import partial
from glob import glob

import torch
from datasets import Dataset, Image
from peft import LoraConfig
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
    BitsAndBytesConfig,
)
from trl import SFTConfig, SFTTrainer


def make_dataset(base="./baseline"):
    """Build an image/prompt/reply dataset from per-image JSON files.

    Each *.jsonl file under `base` is expected to hold a single JSON object
    with (at least) the keys "English" and "Translation", e.g.
    {"English": "A red fox ...", "Translation": "Ein Rotfuchs ..."},
    and to share its basename with a .jpg in ../Images/.
    """
    prompt = (
        "You are a professional English-German translator and also a renowned "
        "photography critic.\n\n"
        "Write a detailed caption for this image in a single sentence. "
        "Translate the caption into German. The output needs to be JSON, the "
        "keys being 'English' and 'German' for the respective captions. "
        "Only output the JSON, nothing else."
    )
    user_prompts = []
    images = []
    assistant_replies = []
    for filename in glob(f"{base}/*.jsonl"):
        # Despite the .jsonl extension, each file contains one JSON object.
        with open(filename, "r") as f:
            data = json.loads(f.read())
        print(f"{data=}")  # debug output
        stem = os.path.basename(filename).removesuffix(".jsonl")
        image_path = f"../Images/{stem}.jpg"
        user_prompts.append(prompt)
        # The target the model should learn to emit: the JSON caption pair.
        assistant_replies.append(json.dumps({
            "English": data["English"],
            "German": data["Translation"],
        }, ensure_ascii=False, indent=0))
        images.append(image_path)
    return Dataset.from_dict(
        {"image": images, "user": user_prompts, "assistant": assistant_replies}
    ).cast_column("image", Image())


def add_chat_text(example, processor):
    # Use structured content so the chat template inserts the <start_of_image>
    # placeholder; with a plain string the rendered text contains no image
    # token and the processor cannot match it to the supplied image.
    messages = [
        {"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": example["user"]},
        ]},
        {"role": "assistant", "content": [
            {"type": "text", "text": example["assistant"]},
        ]},
    ]
    example["text"] = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )
    return example


def collate(batch, processor):
    images = [i["image"] for i in batch]
    texts = [i["text"] for i in batch]
    out = processor(
        text=texts,
        images=images,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    labels = out["input_ids"].clone()
    labels[out["attention_mask"] == 0] = -100  # ignore padding in the loss
    # Also ignore the image placeholder tokens; the model should not be
    # trained to predict them. (Prompt tokens are deliberately left in the
    # loss here; masking them as well is a common variation.)
    image_token_id = processor.tokenizer.convert_tokens_to_ids("<image_soft_token>")
    labels[labels == image_token_id] = -100
    out["labels"] = labels
    return out


def main():
    model_name = "google/gemma-3-4b-it"
    # QLoRA setup: 4-bit NF4 base weights with bf16 compute.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    processor = AutoProcessor.from_pretrained(model_name, use_fast=True)
    # AutoModelForImageTextToText is the explicit class for multimodal Gemma 3.
    model = AutoModelForImageTextToText.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
    )
    peft_config = LoraConfig(
        lora_alpha=32,
        lora_dropout=0.05,
        r=16,
        task_type="CAUSAL_LM",
        bias="none",
        target_modules="all-linear",
        # Note: this trains full copies of lm_head and embed_tokens, which is
        # memory-heavy; drop it if adapters alone are enough.
        modules_to_save=["lm_head", "embed_tokens"],
    )
    dev_ds = make_dataset("./baseline/files_dev").map(partial(add_chat_text, processor=processor))
    train_ds = make_dataset("./baseline/files_train").map(partial(add_chat_text, processor=processor))
    args = SFTConfig(
        output_dir="gemma3-mm-sft-lora",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=24,
        num_train_epochs=3,
        learning_rate=1e-5,
        bf16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        optim="paged_adamw_8bit",
        remove_unused_columns=False,
        # Tokenization happens in the collator, so skip TRL's own dataset prep.
        dataset_kwargs={"skip_prepare_dataset": True},
        eval_strategy="steps",  # actually evaluate the dev split
        eval_steps=100,
        logging_steps=10,
        save_steps=100,
    )
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_ds,
        eval_dataset=dev_ds,
        data_collator=partial(collate, processor=processor),
        args=args,
        peft_config=peft_config,
    )
    trainer.train()
    trainer.save_model()  # writes the LoRA adapter to output_dir


if __name__ == "__main__":
    main()
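# ---------------------------------------------------------------------------
# Inference sketch (an illustration, not part of the training run): load the
# adapter saved above and caption one image. The image path, prompt text, and
# generation settings below are placeholder assumptions.
#
#   import torch
#   from PIL import Image as PILImage
#   from peft import PeftModel
#   from transformers import AutoProcessor, AutoModelForImageTextToText
#
#   base = AutoModelForImageTextToText.from_pretrained(
#       "google/gemma-3-4b-it", torch_dtype=torch.bfloat16, device_map="auto"
#   )
#   model = PeftModel.from_pretrained(base, "gemma3-mm-sft-lora")
#   processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
#   img = PILImage.open("../Images/example.jpg")  # hypothetical image
#   messages = [{"role": "user", "content": [
#       {"type": "image", "image": img},
#       {"type": "text", "text": "Write a detailed caption ..."},
#   ]}]
#   inputs = processor.apply_chat_template(
#       messages, add_generation_prompt=True, tokenize=True,
#       return_dict=True, return_tensors="pt",
#   ).to(model.device)
#   out = model.generate(**inputs, max_new_tokens=128)
#   print(processor.decode(out[0][inputs["input_ids"].shape[-1]:],
#                          skip_special_tokens=True))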