#!/usr/bin/env python3
import argparse
import json
import os
from functools import partial
from glob import glob

import torch
from datasets import Dataset, Image  # Use PIL?
from peft import LoraConfig
from transformers import (
    AutoProcessor,
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    TrainingArguments,
    BitsAndBytesConfig,
)
from trl import SFTTrainer


def make_dataset(base="./baseline"):
    # TODO: Make actual hf dataset
    prompt = (
        "You are a professional English-German translator and also a renowned photography critic.\n\n"
        "Write a detailed caption for this image in a single sentence. Translate the caption into German. "
        "The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. "
        "Only output the JSON, nothing else."
    )
    user_prompts = []
    images = []
    assistant_replies = []
    for filename in glob(f"{base}/*.jsonl"):
        with open(filename, "r") as f:
            data = json.loads(f.read())
        image_path = f"../d/Images/{os.path.basename(filename).removesuffix('.jsonl')}.jpg"
        user_prompts.append(prompt)
        assistant_replies.append(json.dumps({
            "English": data["English"],
            "German": data["Translation"],
        }, ensure_ascii=False, indent=0))
        images.append(image_path)
    return Dataset.from_dict(
        {"image": images, "user": user_prompts, "assistant": assistant_replies}
    ).cast_column("image", Image())


def add_chat_text(example, processor):
    messages = [
        {"role": "user", "content": example["user"]},
        {"role": "assistant", "content": example["assistant"]},
    ]
    example["text"] = processor.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )
    return example


def collate(batch, processor):
    # FIXME: Support batch_size > 1
    images = [i["image"] for i in batch]
    texts = [i["text"] for i in batch]
    out = processor(
        text=texts,
        images=images,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    image_token_id = processor.tokenizer.convert_tokens_to_ids(
        processor.tokenizer.special_tokens_map["boi_token"]
    )
    # Mask padding and image tokens so the loss is only computed on text tokens.
    labels = out["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    labels[labels == 262144] = -100  # Gemma 3 image soft token id
    out["labels"] = labels
    return out


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="google/gemma-3-4b-it")
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--gradient-accumulation", type=int, default=4)
    parser.add_argument("--learning-rate", type=float, default=1e-4)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--warmup-ratio", type=float, default=0.03)
    parser.add_argument("--scheduler-type", default="constant")
    parser.add_argument("--logging-steps", type=int, default=10)
    parser.add_argument("--lora-alpha", type=int, default=32)
    parser.add_argument("--lora-dropout", type=float, default=0.05)
    parser.add_argument("--lora-r", type=int, default=16)
    args = parser.parse_args()

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    processor = AutoProcessor.from_pretrained(args.model, use_fast=True)
    # model = AutoModelForCausalLM.from_pretrained(
    model = AutoModelForImageTextToText.from_pretrained(
        args.model,
        quantization_config=bnb_config,
        device_map="auto",
        low_cpu_mem_usage=True,
    )

    peft_config = LoraConfig(
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        r=args.lora_r,
        task_type="CAUSAL_LM",
        bias="none",
        target_modules="all-linear",
        modules_to_save=["lm_head", "embed_tokens"],
    )

    dev_ds = make_dataset("./baseline/files_dev").map(partial(add_chat_text, processor=processor))
    train_ds = make_dataset("./baseline/files_train").map(partial(add_chat_text, processor=processor))

    training_args = TrainingArguments(
        output_dir="gemma3-mm-sft-lora",
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        num_train_epochs=args.epochs,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.scheduler_type,
        fp16=False,
        bf16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        optim="adamw_torch_8bit",  # Alternative from BnB: paged_adamw_8bit
        remove_unused_columns=False,
        logging_steps=args.logging_steps,
        save_strategy="epoch",
    )

    trainer = SFTTrainer(
        model=model,
        train_dataset=train_ds,
        eval_dataset=dev_ds,
        data_collator=partial(collate, processor=processor),
        args=training_args,
        peft_config=peft_config,
    )
    trainer.train()
    trainer.save_model()


if __name__ == "__main__":
    main()