path: root/finetune.py
author     pks <pks@pks.rocks>    2025-11-21 23:02:16 +0100
committer  pks <pks@pks.rocks>    2025-11-21 23:02:16 +0100
commit     1cc209768ad0299c396af56231c43fded7dc3515 (patch)
tree       bc7fc3dcfd094078e052807f4e336dc2a5e0a43b /finetune.py
parent     ea7a98c4f6fc0d0df785ef601e9fc30b78e51335 (diff)
WIP
Diffstat (limited to 'finetune.py')
-rw-r--r--  finetune.py  131
1 file changed, 69 insertions(+), 62 deletions(-)
diff --git a/finetune.py b/finetune.py
index f747f99..01012c5 100644
--- a/finetune.py
+++ b/finetune.py
@@ -1,5 +1,11 @@
+import json
+import os
import torch
+
from datasets import Dataset, Image
+from functools import partial
+from glob import glob
+from peft import LoraConfig
from transformers import (
AutoProcessor,
AutoModelForCausalLM,
@@ -7,38 +13,7 @@ from transformers import (
BitsAndBytesConfig,
)
from trl import SFTTrainer
-from peft import LoraConfig
-from glob import glob
-import json
-import os
-
-
-model_name = "google/gemma-3-4b-it"
-
-bnb_config = BitsAndBytesConfig(
- load_in_4bit=True,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
- bnb_4bit_compute_dtype=torch.bfloat16,
-)
-
-processor = AutoProcessor.from_pretrained(model_name, use_fast=True)
-
-model = AutoModelForCausalLM.from_pretrained(
- model_name,
- quantization_config=bnb_config,
- device_map="auto",
-)
-peft_config = LoraConfig(
- lora_alpha=32,
- lora_dropout=0.05,
- r=16,
- task_type="CAUSAL_LM",
- bias="none",
- target_modules="all-linear",
- modules_to_save=["lm_head", "embed_tokens"],
-)
def make_dataset(base="./baseline"):
prompt = "You are a professional English-German translator and also a renowned photography critic.\n\nWrite a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else." + "<start_of_image>"
@@ -59,9 +34,8 @@ def make_dataset(base="./baseline"):
return Dataset.from_dict({"image": images, "user": user_prompts, "assistant": assistant_replies}).cast_column("image", Image())
-dataset = make_dataset()
-def add_chat_text(example):
+def add_chat_text(example, processor):
messages = [
{"role": "user", "content": example["user"]},
{"role": "assistant", "content": example["assistant"]},
@@ -71,11 +45,11 @@ def add_chat_text(example):
tokenize=False,
add_generation_prompt=False,
)
+
return example
-dataset = dataset.map(add_chat_text)
-def collate(batch):
+def collate(batch, processor):
images = [i["image"] for i in batch]
texts = [i["text"] for i in batch]
out = processor(
@@ -91,32 +65,65 @@ def collate(batch):
return out
-ds_split = dataset.train_test_split(test_size=0.2, seed=42)
-
-args = TrainingArguments(
- output_dir="gemma3-mm-sft-lora",
- per_device_train_batch_size=1,
- gradient_accumulation_steps=24,
- num_train_epochs=1,
- learning_rate=1e-5,
- fp16=False,
- bf16=True,
- gradient_checkpointing=True,
- gradient_checkpointing_kwargs={"use_reentrant": False},
- optim="paged_adamw_8bit",
- remove_unused_columns=False,
- logging_steps=10,
- save_steps=20,
-)
+def main():
+ model_name = "google/gemma-3-4b-it"
+
+ bnb_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_use_double_quant=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=torch.bfloat16,
+ )
+
+ processor = AutoProcessor.from_pretrained(model_name, use_fast=True)
+
+ model = AutoModelForCausalLM.from_pretrained(
+ model_name,
+ quantization_config=bnb_config,
+ device_map="auto",
+ )
+
+ peft_config = LoraConfig(
+ lora_alpha=32,
+ lora_dropout=0.05,
+ r=16,
+ task_type="CAUSAL_LM",
+ bias="none",
+ target_modules="all-linear",
+ modules_to_save=["lm_head", "embed_tokens"],
+ )
+
+ dev_ds = make_dataset("./baseline/files_dev").map(partial(add_chat_text, processor=processor))
+ train_ds = make_dataset("./baseline/files_train").map(partial(add_chat_text, processor=processor))
+
+ args = TrainingArguments(
+ output_dir="gemma3-mm-sft-lora",
+ per_device_train_batch_size=1,
+ gradient_accumulation_steps=24,
+ num_train_epochs=3,
+ learning_rate=1e-5,
+ fp16=False,
+ bf16=True,
+ gradient_checkpointing=True,
+ gradient_checkpointing_kwargs={"use_reentrant": False},
+ optim="paged_adamw_8bit",
+ remove_unused_columns=False,
+ logging_steps=10,
+ save_steps=100,
+ )
+
+ trainer = SFTTrainer(
+ model=model,
+ train_dataset=train_ds,
+ eval_dataset=dev_ds,
+ data_collator=partial(collate, processor=processor),
+ args=args,
+ peft_config=peft_config,
+ )
+
+ trainer.train()
+ trainer.save_model()
-trainer = SFTTrainer(
- model=model,
- train_dataset=ds_split["train"],
- eval_dataset=ds_split["test"],
- data_collator=collate,
- args=args,
- peft_config=peft_config,
-)
-trainer.train()
-trainer.save_model()
+if __name__ == "__main__":
+ main()
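Below is a minimal, self-contained sketch (not part of the commit; FakeProcessor is a made-up stand-in for AutoProcessor) illustrating the pattern this WIP change adopts: instead of reading a module-level processor, add_chat_text and collate take it as an explicit argument, and the call sites bind it with functools.partial before handing the functions to dataset.map and SFTTrainer.

from functools import partial


class FakeProcessor:
    """Hypothetical stand-in for AutoProcessor, only to show the wiring."""

    def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=False):
        # Render the chat as plain text; the real processor applies the model's template.
        return "\n".join(f"{m['role']}: {m['content']}" for m in messages)


def add_chat_text(example, processor):
    # 'processor' is passed explicitly rather than read from a global.
    example["text"] = processor.apply_chat_template(
        [
            {"role": "user", "content": example["user"]},
            {"role": "assistant", "content": example["assistant"]},
        ],
        tokenize=False,
        add_generation_prompt=False,
    )
    return example


processor = FakeProcessor()
# Same shape as dataset.map(partial(add_chat_text, processor=processor))
# and data_collator=partial(collate, processor=processor) in the diff.
fn = partial(add_chat_text, processor=processor)
print(fn({"user": "describe the image", "assistant": '{"English": "...", "German": "..."}'})["text"])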