| author | pks <pks@pks.rocks> | 2025-11-30 20:43:50 +0100 |
|---|---|---|
| committer | pks <pks@pks.rocks> | 2025-11-30 20:43:50 +0100 |
| commit | 4180040332fb8a9864fdfc4499c5d856b17bfd9d | |
| tree | b14a18a9297bf229485439db6ae51c05977972c8 | |
| parent | 6d990329f807a3606c95a5295bec7b7f95ab9770 | |
args for finetuning.py
| -rw-r--r-- | finetuning.py | 49 |
1 file changed, 37 insertions, 12 deletions
diff --git a/finetuning.py b/finetuning.py
index 988a027..1de68c7 100644
--- a/finetuning.py
+++ b/finetuning.py
@@ -5,13 +5,14 @@
 import json
 import os
 import torch
-from datasets import Dataset, Image
+from datasets import Dataset, Image  # Use PIL?
 from functools import partial
 from glob import glob
 from peft import LoraConfig
 from transformers import (
     AutoProcessor,
     AutoModelForCausalLM,
+    AutoModelForImageTextToText,
     TrainingArguments,
     BitsAndBytesConfig,
 )
@@ -62,14 +63,34 @@ def collate(batch, processor):  # FIXME: Support batch_size > 1
         max_length=512,
         return_tensors="pt",
     )
+
+    image_token_id = [
+        processor.tokenizer.convert_tokens_to_ids(
+            processor.tokenizer.special_tokens_map["boi_token"]
+        )
+    ]
+
     out["labels"] = out["input_ids"].clone()
-    out["labels"][out["attention_mask"] == 0] = -100
+
+    out["labels"][out["input_ids"] == processor.tokenizer.pad_token_id] = -100
+    out["labels"][out["input_ids"] == image_token_id] = -100
+    out["labels"][out["input_ids"] == 262144] = -100  # Gemma 3 image soft token
     return out


 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", default="google/gemma-3-4b-it")
+    parser.add_argument("--batch-size", type=int, default=1)
+    parser.add_argument("--gradient-accumulation", type=int, default=4)
+    parser.add_argument("--learning-rate", type=float, default=1e-4)
+    parser.add_argument("--epochs", type=int, default=1)
+    parser.add_argument("--warmup-ratio", type=float, default=0.03)
+    parser.add_argument("--scheduler-type", default="constant")
+    parser.add_argument("--logging-steps", type=int, default=10)
+    parser.add_argument("--lora-alpha", type=int, default=32)
+    parser.add_argument("--lora-dropout", type=float, default=0.05)
+    parser.add_argument("--lora-r", type=int, default=16)
     args = parser.parse_args()

     bnb_config = BitsAndBytesConfig(
@@ -81,16 +102,18 @@ def main():
     processor = AutoProcessor.from_pretrained(args.model, use_fast=True)

-    model = AutoModelForCausalLM.from_pretrained(
+    #model = AutoModelForCausalLM.from_pretrained(
+    model = AutoModelForImageTextToText.from_pretrained(
         args.model,
         quantization_config=bnb_config,
         device_map="auto",
+        low_cpu_mem_usage=True,
     )

     peft_config = LoraConfig(
-        lora_alpha=32,
-        lora_dropout=0.05,
-        r=16,
+        lora_alpha=args.lora_alpha,
+        lora_dropout=args.lora_dropout,
+        r=args.lora_r,
         task_type="CAUSAL_LM",
         bias="none",
         target_modules="all-linear",
@@ -102,18 +125,20 @@ def main():
     args = TrainingArguments(
         output_dir="gemma3-mm-sft-lora",
-        per_device_train_batch_size=1,
-        gradient_accumulation_steps=24,
-        num_train_epochs=3,
-        learning_rate=1e-5,
+        per_device_train_batch_size=args.batch_size,
+        gradient_accumulation_steps=args.gradient_accumulation,
+        num_train_epochs=args.epochs,
+        learning_rate=args.learning_rate,
+        warmup_ratio=args.warmup_ratio,
+        lr_scheduler_type=args.scheduler_type,
         fp16=False,
         bf16=True,
         gradient_checkpointing=True,
         gradient_checkpointing_kwargs={"use_reentrant": False},
         optim="adamw_torch_8bit",  # Alternative from BnB: paged_adamw_8bit
         remove_unused_columns=False,
-        logging_steps=10,
-        save_steps=100,
+        logging_steps=args.logging_steps,
+        save_strategy="epoch",
     )

     trainer = SFTTrainer(
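The new label handling in `collate()` replaces the blanket attention-mask rule with per-token masking: positions holding padding or image tokens are set to -100, the ignore index used by `torch.nn.CrossEntropyLoss` and by Hugging Face causal LMs, so they contribute nothing to the loss. A minimal, self-contained sketch of that idea, using made-up token ids rather than the real Gemma 3 tokenizer values (the script derives them from `processor.tokenizer`):

```python
import torch

# Hypothetical ids for illustration only; finetuning.py uses pad_token_id,
# the boi_token id, and 262144 from the Gemma 3 tokenizer instead.
PAD_ID = 0
IMAGE_IDS = [4, 5]

input_ids = torch.tensor([[4, 5, 17, 23, 42, 0, 0]])  # one padded sample
labels = input_ids.clone()

# Mask positions that should not contribute to the loss.
labels[input_ids == PAD_ID] = -100
for tok in IMAGE_IDS:
    labels[input_ids == tok] = -100

print(labels)  # tensor([[-100, -100,   17,   23,   42, -100, -100]])
```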

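One detail worth keeping in mind with the new flags: `argparse` only converts command-line values when a `type=` is given, so without it `--learning-rate 5e-5` would reach `TrainingArguments` as the string `"5e-5"` while the numeric defaults stay numeric. A small, self-contained sketch of that behaviour; the parser below is a trimmed-down stand-in, not the one in `finetuning.py`:

```python
import argparse

# Stand-in parser showing how type= affects values passed on the command line.
parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, default=1)
parser.add_argument("--learning-rate", type=float, default=1e-4)
parser.add_argument("--warmup-ratio", type=float, default=0.03)

args = parser.parse_args(["--batch-size", "2", "--learning-rate", "5e-5"])
assert args.batch_size == 2 and isinstance(args.batch_size, int)
assert args.learning_rate == 5e-5 and isinstance(args.learning_rate, float)
assert args.warmup_ratio == 0.03  # untouched default keeps its Python type
print(args)  # Namespace(batch_size=2, learning_rate=5e-05, warmup_ratio=0.03)
```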