#!/usr/bin/env python3
"""LoRA supervised fine-tuning of Gemma 3 (image-text-to-text) on an image
caption translation dataset, with optional 4-bit quantization."""
import argparse
import sys
from functools import partial

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoProcessor,
    AutoModelForImageTextToText,
    TrainingArguments,
    BitsAndBytesConfig,
)
from trl import SFTTrainer


def add_chat_text(example, processor):
    """Render the user/assistant turns into a single chat-formatted training string."""
    messages = [
        {"role": "user", "content": example["user"]},
        {"role": "assistant", "content": example["assistant"]},
    ]
    example["text"] = processor.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )
    return example


def collate(batch, processor, max_length):
    """Tokenize a batch of (image, text) pairs and mask non-target tokens in the labels."""
    images = [[item["image"]] for item in batch]
    texts = [item["text"] for item in batch]
    processor_output = processor(
        text=texts,
        images=images,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    # Mask padding, the begin-of-image token, and the image soft tokens
    # (id 262144 in the Gemma 3 vocabulary) so they do not contribute to the loss.
    boi_token_id = processor.tokenizer.convert_tokens_to_ids(
        processor.tokenizer.special_tokens_map["boi_token"]
    )
    labels = processor_output["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == boi_token_id] = -100
    labels[labels == 262144] = -100
    processor_output["labels"] = labels
    return processor_output


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="google/gemma-3-4b-it")
    parser.add_argument("--optimizer", default="adamw_torch_fused")
    parser.add_argument("--gradient-checkpointing", action="store_true")
    parser.add_argument("--batch-size", default=1, type=int)
    parser.add_argument("--gradient-accumulation", default=4, type=int)
    parser.add_argument("--learning-rate", default=2e-4, type=float)
    parser.add_argument("--epochs", default=1, type=int)
    parser.add_argument("--warmup-ratio", default=0.03, type=float)
    parser.add_argument("--scheduler-type", default="constant", type=str)
    parser.add_argument("--logging-steps", default=10, type=int)
    parser.add_argument("--lora-alpha", default=32, type=int)
    parser.add_argument("--lora-dropout", default=0.05, type=float)
    parser.add_argument("--lora-r", default=16, type=int)
    parser.add_argument("--bnb-4bit", action="store_true")
    parser.add_argument("--max-grad-norm", default=1.0, type=float)
    parser.add_argument("--max-length", default=512, type=int)
    parser.add_argument("--lora-config", choices=["S", "M", "L"], default="M", type=str)
    parser.add_argument("--dataset", default="asdf2k/caption_translation", type=str)
    args = parser.parse_args()

    # Optional 4-bit NF4 quantization to reduce GPU memory (QLoRA-style setup).
    if args.bnb_4bit:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_storage=torch.bfloat16,
        )
    else:
        bnb_config = None

    processor = AutoProcessor.from_pretrained(args.model, use_fast=True)
    model = AutoModelForImageTextToText.from_pretrained(
        args.model,
        quantization_config=bnb_config,
        device_map="auto",
        low_cpu_mem_usage=True,
    )

    # LoRA presets: S targets only the attention and MLP projections, M targets
    # every linear layer, L additionally trains the LM head and embeddings in full.
    lora_kwargs = {
        "lora_alpha": args.lora_alpha,
        "lora_dropout": args.lora_dropout,
        "r": args.lora_r,
        "task_type": "CAUSAL_LM",
        "bias": "none",
    }
    if args.lora_config == "S":
        lora_kwargs["target_modules"] = [
            "q_proj", "o_proj", "k_proj", "v_proj",
            "gate_proj", "up_proj", "down_proj",
        ]
    elif args.lora_config == "M":
        lora_kwargs["target_modules"] = "all-linear"
    elif args.lora_config == "L":
        lora_kwargs["target_modules"] = "all-linear"
        lora_kwargs["modules_to_save"] = ["lm_head", "embed_tokens"]
    else:
        # Unreachable in practice: argparse already restricts --lora-config via `choices`.
        sys.stderr.write(f"Unknown LoRA config: '{args.lora_config}'\n")
        sys.exit(1)
    peft_config = LoraConfig(**lora_kwargs)

    dataset = load_dataset(args.dataset)
    dev_ds = dataset["dev"].map(partial(add_chat_text, processor=processor))
    train_ds = dataset["train"].map(partial(add_chat_text, processor=processor))

    training_args = TrainingArguments(
        output_dir="gemma3-mm-sft-lora",
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        num_train_epochs=args.epochs,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.scheduler_type,
        fp16=False,
        bf16=True,
        gradient_checkpointing=args.gradient_checkpointing,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        optim=args.optimizer,
        max_grad_norm=args.max_grad_norm,
        remove_unused_columns=False,
        logging_steps=args.logging_steps,
        save_strategy="epoch",
    )

    trainer = SFTTrainer(
        model=model,
        train_dataset=train_ds,
        eval_dataset=dev_ds,
        data_collator=partial(collate, processor=processor, max_length=args.max_length),
        args=training_args,
        peft_config=peft_config,
    )
    trainer.train()
    trainer.save_model()


if __name__ == "__main__":
    main()
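

# --- Optional: loading the trained adapter for inference ---------------------
# A minimal sketch (not called by the script) of how the adapter written to
# ./gemma3-mm-sft-lora by trainer.save_model() could be loaded back onto the base
# model and used for generation. The image path and prompt text below are
# illustrative placeholders, and the user message should mirror whatever format
# the dataset's "user" field uses at training time.
def run_adapter_inference(
    adapter_dir="gemma3-mm-sft-lora",
    base_model="google/gemma-3-4b-it",
    image_path="example.jpg",
    prompt="Translate the caption of this image.",
):
    from peft import PeftModel
    from PIL import Image

    processor = AutoProcessor.from_pretrained(base_model, use_fast=True)
    base = AutoModelForImageTextToText.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    model = PeftModel.from_pretrained(base, adapter_dir)

    # Build a chat prompt with an image placeholder; the processor expands it
    # into the model's image tokens when the PIL image is passed alongside.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=text, images=[Image.open(image_path)], return_tensors="pt"
    ).to(model.device)

    output_ids = model.generate(**inputs, max_new_tokens=128)
    # Strip the prompt tokens before decoding the generated continuation.
    generated = output_ids[0][inputs["input_ids"].shape[-1]:]
    return processor.tokenizer.decode(generated, skip_special_tokens=True)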