| author | pks <pks@pks.rocks> | 2025-11-21 23:02:16 +0100 |
|---|---|---|
| committer | pks <pks@pks.rocks> | 2025-11-21 23:02:16 +0100 |
| commit | 1cc209768ad0299c396af56231c43fded7dc3515 (patch) | |
| tree | bc7fc3dcfd094078e052807f4e336dc2a5e0a43b /finetune.py | |
| parent | ea7a98c4f6fc0d0df785ef601e9fc30b78e51335 (diff) | |
WIP
Diffstat (limited to 'finetune.py')
| -rw-r--r-- | finetune.py | 131 |
1 file changed, 69 insertions, 62 deletions
diff --git a/finetune.py b/finetune.py
index f747f99..01012c5 100644
--- a/finetune.py
+++ b/finetune.py
@@ -1,5 +1,11 @@
+import json
+import os
 import torch
+
 from datasets import Dataset, Image
+from functools import partial
+from glob import glob
+from peft import LoraConfig
 from transformers import (
     AutoProcessor,
     AutoModelForCausalLM,
@@ -7,38 +13,7 @@ from transformers import (
     BitsAndBytesConfig,
 )
 from trl import SFTTrainer
-from peft import LoraConfig
-from glob import glob
-import json
-import os
-
-
-model_name = "google/gemma-3-4b-it"
-
-bnb_config = BitsAndBytesConfig(
-    load_in_4bit=True,
-    bnb_4bit_use_double_quant=True,
-    bnb_4bit_quant_type="nf4",
-    bnb_4bit_compute_dtype=torch.bfloat16,
-)
-
-processor = AutoProcessor.from_pretrained(model_name, use_fast=True)
-
-model = AutoModelForCausalLM.from_pretrained(
-    model_name,
-    quantization_config=bnb_config,
-    device_map="auto",
-)
-peft_config = LoraConfig(
-    lora_alpha=32,
-    lora_dropout=0.05,
-    r=16,
-    task_type="CAUSAL_LM",
-    bias="none",
-    target_modules="all-linear",
-    modules_to_save=["lm_head", "embed_tokens"],
-)
 
 
 def make_dataset(base="./baseline"):
     prompt = "You are a professional English-German translator and also a renowned photography critic.\n\nWrite a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else." + "<start_of_image>"
@@ -59,9 +34,8 @@ def make_dataset(base="./baseline"):
     return Dataset.from_dict({"image": images, "user": user_prompts, "assistant": assistant_replies}).cast_column("image", Image())
 
 
-dataset = make_dataset()
 
-def add_chat_text(example):
+def add_chat_text(example, processor):
     messages = [
         {"role": "user", "content": example["user"]},
         {"role": "assistant", "content": example["assistant"]},
@@ -71,11 +45,11 @@
         tokenize=False,
         add_generation_prompt=False,
     )
+
     return example
 
-dataset = dataset.map(add_chat_text)
 
-def collate(batch):
+def collate(batch, processor):
     images = [i["image"] for i in batch]
     texts = [i["text"] for i in batch]
     out = processor(
@@ -91,32 +65,65 @@
     return out
 
 
-ds_split = dataset.train_test_split(test_size=0.2, seed=42)
-
-args = TrainingArguments(
-    output_dir="gemma3-mm-sft-lora",
-    per_device_train_batch_size=1,
-    gradient_accumulation_steps=24,
-    num_train_epochs=1,
-    learning_rate=1e-5,
-    fp16=False,
-    bf16=True,
-    gradient_checkpointing=True,
-    gradient_checkpointing_kwargs={"use_reentrant": False},
-    optim="paged_adamw_8bit",
-    remove_unused_columns=False,
-    logging_steps=10,
-    save_steps=20,
-)
+def main():
+    model_name = "google/gemma-3-4b-it"
+
+    bnb_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16,
+    )
+
+    processor = AutoProcessor.from_pretrained(model_name, use_fast=True)
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        quantization_config=bnb_config,
+        device_map="auto",
+    )
+
+    peft_config = LoraConfig(
+        lora_alpha=32,
+        lora_dropout=0.05,
+        r=16,
+        task_type="CAUSAL_LM",
+        bias="none",
+        target_modules="all-linear",
+        modules_to_save=["lm_head", "embed_tokens"],
+    )
+
+    dev_ds = make_dataset("./baseline/files_dev").map(partial(add_chat_text, processor=processor))
+    train_ds = make_dataset("./baseline/files_train").map(partial(add_chat_text, processor=processor))
+
+    args = TrainingArguments(
+        output_dir="gemma3-mm-sft-lora",
+        per_device_train_batch_size=1,
+        gradient_accumulation_steps=24,
+        num_train_epochs=3,
+        learning_rate=1e-5,
+        fp16=False,
+        bf16=True,
+        gradient_checkpointing=True,
+        gradient_checkpointing_kwargs={"use_reentrant": False},
+        optim="paged_adamw_8bit",
+        remove_unused_columns=False,
+        logging_steps=10,
+        save_steps=100,
+    )
+
+    trainer = SFTTrainer(
+        model=model,
+        train_dataset=train_ds,
+        eval_dataset=dev_ds,
+        data_collator=partial(collate, processor=processor),
+        args=args,
+        peft_config=peft_config,
+    )
+
+    trainer.train()
+    trainer.save_model()
 
-trainer = SFTTrainer(
-    model=model,
-    train_dataset=ds_split["train"],
-    eval_dataset=ds_split["test"],
-    data_collator=collate,
-    args=args,
-    peft_config=peft_config,
-)
 
-trainer.train()
-trainer.save_model()
+if __name__ == "__main__":
+    main()
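The core of this refactor is dropping the module-level globals and instead binding `processor` into `add_chat_text` and `collate` with `functools.partial`. Below is a minimal sketch of that pattern, not part of the commit: the `add_suffix` helper and the toy dataset are made up for illustration. `datasets.Dataset.map` calls the function with only the example, while `partial` pre-binds any extra keyword arguments.

```python
from functools import partial

from datasets import Dataset


def add_suffix(example, suffix):
    # `map` supplies each example; `suffix` is pre-bound via functools.partial.
    example["text"] = example["text"] + suffix
    return example


ds = Dataset.from_dict({"text": ["a", "b"]})
ds = ds.map(partial(add_suffix, suffix="!"))
print(ds["text"])  # ['a!', 'b!']
```

The same trick should work for the collator: the trainer calls `data_collator` with just the batch, so `partial(collate, processor=processor)` keeps the processor out of global scope.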
