summary | refs | log | tree | commit | diff
path: root/finetuning.py
diff options
context:
space:
mode:
author: pks <pks@pks.rocks> 2025-11-30 22:01:39 +0100
committer: pks <pks@pks.rocks> 2025-11-30 22:01:39 +0100
commit: 8d49cccecfbbb3aae031b22efe26dc8d357da493 (patch)
tree: 8eddfb0f59724f147969756bd986f5efedc090fe /finetuning.py
parent: 1af7511e5fbf3c6e89f0452623eaf663387ccfc2 (diff)
WIP
Diffstat (limited to 'finetuning.py')
-rwxr-xr-x finetuning.py 39
1 files changed, 12 insertions, 27 deletions
diff --git a/finetuning.py b/finetuning.py
index f787f95..1a44345 100755
--- a/finetuning.py
+++ b/finetuning.py
@@ -5,7 +5,7 @@ import json
import os
import torch
-from datasets import Dataset, Image # Use PIL?
+from datasets import Dataset, Image, load_dataset
from functools import partial
from glob import glob
from peft import LoraConfig
@@ -19,25 +19,6 @@ from transformers import (
from trl import SFTTrainer
-def make_dataset(base="./baseline"): # TODO: Make actual hf dataset
- prompt = "You are a professional English-German translator and also a renowned photography critic.\n\nWrite a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else." + "<start_of_image>"
- user_prompts = []
- images = []
- assistant_replies = []
- for filename in glob(f"{base}/*.jsonl"):
- with open(filename, "r") as f:
- data = json.loads(f.read())
- image_path = f"../d/Images/{os.path.basename(filename).removesuffix(".jsonl")}.jpg"
- user_prompts.append(prompt)
- assistant_replies.append(json.dumps({
- "English": data["English"],
- "German": data["Translation"],
- }, ensure_ascii=False, indent=0))
- images.append(image_path)
-
- return Dataset.from_dict({"image": images, "user": user_prompts, "assistant": assistant_replies}).cast_column("image", Image())
-
-
def add_chat_text(example, processor):
messages = [
{"role": "user", "content": example["user"]},
@@ -82,6 +63,8 @@ def collate(batch, processor): # FIXME: Support batch_size > 1
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="google/gemma-3-4b-it")
+ parser.add_argument("--optimizer", default="adamw_torch_fused")
+ parser.add_argument("--gradient-checkpointing", action="store_true")
parser.add_argument("--batch-size", default=1)
parser.add_argument("--gradient-accumulation", default=4)
parser.add_argument("--learning-rate", default=1e-4)
@@ -99,11 +82,11 @@ def main():
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
+ bnb_4bit_quant_storage=torch.bfloat16,
)
processor = AutoProcessor.from_pretrained(args.model, use_fast=True)
- #model = AutoModelForCausalLM.from_pretrained(
model = AutoModelForImageTextToText.from_pretrained(
args.model,
quantization_config=bnb_config,
@@ -117,12 +100,14 @@ def main():
r=args.lora_r,
task_type="CAUSAL_LM",
bias="none",
- target_modules="all-linear",
- modules_to_save=["lm_head", "embed_tokens"],
+ #target_modules="all-linear",
+ target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
+ #modules_to_save=["lm_head", "embed_tokens"],
)
- dev_ds = make_dataset("./baseline/files_dev").map(partial(add_chat_text, processor=processor))
- train_ds = make_dataset("./baseline/files_train").map(partial(add_chat_text, processor=processor))
+ dataset = load_dataset("asdf2k/caption_translation")
+ dev_ds = dataset["dev"].map(partial(add_chat_text, processor=processor))
+ train_ds = dataset["train"].map(partial(add_chat_text, processor=processor))
args = TrainingArguments(
output_dir="gemma3-mm-sft-lora",
@@ -134,9 +119,9 @@ def main():
lr_scheduler_type=args.scheduler_type,
fp16=False,
bf16=True,
- gradient_checkpointing=True,
+ gradient_checkpointing=args.gradient_checkpointing,
gradient_checkpointing_kwargs={"use_reentrant": False},
- optim="adamw_torch_8bit", # Alternative from BnB: paged_adamw_8bit
+ optim=args.optimizer,
remove_unused_columns=False,
logging_steps=args.logging_steps,
save_strategy="epoch",