author     pks <pks@pks.rocks>  2025-11-30 20:43:50 +0100
committer  pks <pks@pks.rocks>  2025-11-30 20:43:50 +0100
commit     4180040332fb8a9864fdfc4499c5d856b17bfd9d (patch)
tree       b14a18a9297bf229485439db6ae51c05977972c8 /finetuning.py
parent     6d990329f807a3606c95a5295bec7b7f95ab9770 (diff)
args for finetuning.py
Diffstat (limited to 'finetuning.py')
-rw-r--r--  finetuning.py  49
1 file changed, 37 insertions(+), 12 deletions(-)
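
The commit exposes the training and LoRA hyperparameters of finetuning.py as command-line flags and reworks the label masking in collate() so that padding and image tokens are excluded from the loss. Below is a minimal, self-contained sketch of that masking idea; the ids are toy values chosen for illustration, except 262144, which the diff hard-codes as Gemma 3's image soft token id (-100 is the ignore index used by the loss).

import torch

# Toy ids for illustration; only 262144 matches the id hard-coded in the diff.
PAD_ID, BOI_ID, IMG_SOFT_ID = 0, 7, 262144

input_ids = torch.tensor([[BOI_ID, IMG_SOFT_ID, IMG_SOFT_ID, 11, 12, PAD_ID]])
labels = input_ids.clone()
labels[labels == PAD_ID] = -100       # padding contributes no loss
labels[labels == BOI_ID] = -100       # begin-of-image marker contributes no loss
labels[labels == IMG_SOFT_ID] = -100  # image placeholder tokens contribute no loss

print(labels)  # tensor([[-100, -100, -100, 11, 12, -100]])
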
diff --git a/finetuning.py b/finetuning.py
index 988a027..1de68c7 100644
--- a/finetuning.py
+++ b/finetuning.py
@@ -5,13 +5,14 @@ import json
import os
import torch
-from datasets import Dataset, Image
+from datasets import Dataset, Image # Use PIL?
from functools import partial
from glob import glob
from peft import LoraConfig
from transformers import (
AutoProcessor,
AutoModelForCausalLM,
+ AutoModelForImageTextToText,
TrainingArguments,
BitsAndBytesConfig,
)
@@ -62,14 +63,34 @@ def collate(batch, processor): # FIXME: Support batch_size > 1
max_length=512,
return_tensors="pt",
)
+
+ image_token_id = [
+ processor.tokenizer.convert_tokens_to_ids(
+ processor.tokenizer.special_tokens_map["boi_token"]
+ )
+ ]
+
out["labels"] = out["input_ids"].clone()
- out["labels"][out["attention_mask"] == 0] = -100
+
+ out["labels"][labels == processor.tokenizer.pad_token_id] = -100
+ out["labels"][labels == image_token_id] = -100
+ out["labels"][labels == 262144] = -100
return out
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="google/gemma-3-4b-it")
+ parser.add_argument("--batch-size", default=1)
+ parser.add_argument("--gradient-accumulation", default=4)
+ parser.add_argument("--learning-rate", default=1e-4)
+ parser.add_argument("--epochs", default=1)
+ parser.add_argument("--warump-ratio", default=0.03)
+ parser.add_argument("--scheduler-type", default="constant")
+ parser.add_argument("--logging-steps", default=10)
+ parser.add_argument("--lora-alpha", default=32)
+ parser.add_argument("--lora-dropout", default=0.05)
+ parser.add_argument("--lora-r", default=16)
args = parser.parse_args()
bnb_config = BitsAndBytesConfig(
@@ -81,16 +102,18 @@ def main():
processor = AutoProcessor.from_pretrained(args.model, use_fast=True)
- model = AutoModelForCausalLM.from_pretrained(
+ #model = AutoModelForCausalLM.from_pretrained(
+ model = AutoModelForImageTextToText.from_pretrained(
args.model,
quantization_config=bnb_config,
device_map="auto",
+ low_cpu_mem_usage=True,
)
peft_config = LoraConfig(
- lora_alpha=32,
- lora_dropout=0.05,
- r=16,
+ lora_alpha=args.lora_alpha,
+ lora_dropout=args.lora_dropout,
+ r=args.lora_r,
task_type="CAUSAL_LM",
bias="none",
target_modules="all-linear",
@@ -102,18 +125,20 @@ def main():
args = TrainingArguments(
output_dir="gemma3-mm-sft-lora",
- per_device_train_batch_size=1,
- gradient_accumulation_steps=24,
- num_train_epochs=3,
- learning_rate=1e-5,
+ per_device_train_batch_size=args.batch_size,
+ gradient_accumulation_steps=args.gradient_accumulation,
+ num_train_epochs=args.epochs,
+ learning_rate=args.learning_rate,
+ warmup_ratio=args.warmup_ratio,
+ lr_scheduler_type=args.scheduler_type,
fp16=False,
bf16=True,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
optim="adamw_torch_8bit", # Alternative from BnB: paged_adamw_8bit
remove_unused_columns=False,
- logging_steps=10,
- save_steps=100,
+ logging_steps=args.logging_steps,
+ save_strategy="epochs",
)
trainer = SFTTrainer(
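
For reference, a minimal sketch of how the new flags feed TrainingArguments; this is not the full script, and the explicit type= keywords ensure that values supplied on the command line arrive as numbers rather than strings (argparse maps --batch-size to args.batch_size, and so on).

import argparse
from transformers import TrainingArguments

parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, default=1)
parser.add_argument("--gradient-accumulation", type=int, default=4)
parser.add_argument("--learning-rate", type=float, default=1e-4)
parser.add_argument("--epochs", type=int, default=1)
cli = parser.parse_args([])  # [] keeps the defaults, e.g. for a quick smoke test

training_args = TrainingArguments(
    output_dir="gemma3-mm-sft-lora",
    per_device_train_batch_size=cli.batch_size,
    gradient_accumulation_steps=cli.gradient_accumulation,
    num_train_epochs=cli.epochs,
    learning_rate=cli.learning_rate,
    save_strategy="epoch",  # valid values: "no", "steps", "epoch"
)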