| author    | pks <pks@pks.rocks> | 2025-11-29 23:56:12 +0100 |
|-----------|---------------------|---------------------------|
| committer | pks <pks@pks.rocks> | 2025-11-29 23:56:12 +0100 |
| commit    | fc6e8636014dbaf882a2fa5685497dc225ee8e29 | |
| tree      | 0ecd23bd0e8b2ff8aba7b7085657919b91f3bc7c | |
| parent    | abb0df2560b53ec624d8ecaf8444a679fdadb6b2 | |
WIP
Diffstat (limited to 'finetuning.py')
| -rw-r--r-- | finetuning.py | 133 |
1 file changed, 133 insertions, 0 deletions
```diff
diff --git a/finetuning.py b/finetuning.py
new file mode 100644
index 0000000..988a027
--- /dev/null
+++ b/finetuning.py
@@ -0,0 +1,133 @@
+#!/usr/bin/env python3
+
+import argparse
+import json
+import os
+import torch
+
+from datasets import Dataset, Image
+from functools import partial
+from glob import glob
+from peft import LoraConfig
+from transformers import (
+    AutoProcessor,
+    AutoModelForCausalLM,
+    TrainingArguments,
+    BitsAndBytesConfig,
+)
+from trl import SFTTrainer
+
+
+def make_dataset(base="./baseline"):  # TODO: Make actual hf dataset
+    prompt = "You are a professional English-German translator and also a renowned photography critic.\n\nWrite a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else." + "<start_of_image>"
+    user_prompts = []
+    images = []
+    assistant_replies = []
+    for filename in glob(f"{base}/*.jsonl"):
+        with open(filename, "r") as f:
+            data = json.loads(f.read())  # each .jsonl file holds a single JSON object
+        image_path = f"../d/Images/{os.path.basename(filename).removesuffix('.jsonl')}.jpg"
+        user_prompts.append(prompt)
+        assistant_replies.append(json.dumps({
+            "English": data["English"],
+            "German": data["Translation"],
+        }, ensure_ascii=False, indent=0))
+        images.append(image_path)
+
+    return Dataset.from_dict({"image": images, "user": user_prompts, "assistant": assistant_replies}).cast_column("image", Image())
+
+
+def add_chat_text(example, processor):
+    messages = [
+        {"role": "user", "content": example["user"]},
+        {"role": "assistant", "content": example["assistant"]},
+    ]
+    example["text"] = processor.tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=False,
+    )
+
+    return example
+
+
+def collate(batch, processor):  # FIXME: Support batch_size > 1
+    images = [i["image"] for i in batch]
+    texts = [i["text"] for i in batch]
+    out = processor(
+        text=texts,
+        images=images,
+        padding=True,
+        truncation=True,
+        max_length=512,
+        return_tensors="pt",
+    )
+    out["labels"] = out["input_ids"].clone()
+    out["labels"][out["attention_mask"] == 0] = -100  # ignore padding in the loss
+
+    return out
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", default="google/gemma-3-4b-it")
+    args = parser.parse_args()
+
+    bnb_config = BitsAndBytesConfig(  # 4-bit NF4 quantization (QLoRA-style)
+        load_in_4bit=True,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16,
+    )
+
+    processor = AutoProcessor.from_pretrained(args.model, use_fast=True)
+
+    model = AutoModelForCausalLM.from_pretrained(
+        args.model,
+        quantization_config=bnb_config,
+        device_map="auto",
+    )
+
+    peft_config = LoraConfig(  # LoRA adapters on all linear layers
+        lora_alpha=32,
+        lora_dropout=0.05,
+        r=16,
+        task_type="CAUSAL_LM",
+        bias="none",
+        target_modules="all-linear",
+        modules_to_save=["lm_head", "embed_tokens"],
+    )
+
+    dev_ds = make_dataset("./baseline/files_dev").map(partial(add_chat_text, processor=processor))
+    train_ds = make_dataset("./baseline/files_train").map(partial(add_chat_text, processor=processor))
+
+    training_args = TrainingArguments(
+        output_dir="gemma3-mm-sft-lora",
+        per_device_train_batch_size=1,
+        gradient_accumulation_steps=24,
+        num_train_epochs=3,
+        learning_rate=1e-5,
+        fp16=False,
+        bf16=True,
+        gradient_checkpointing=True,
+        gradient_checkpointing_kwargs={"use_reentrant": False},
+        optim="adamw_torch_8bit",  # Alternative from BnB: paged_adamw_8bit
+        remove_unused_columns=False,
+        logging_steps=10,
+        save_steps=100,
+    )
+
+    trainer = SFTTrainer(
+        model=model,
+        train_dataset=train_ds,
+        eval_dataset=dev_ds,
+        data_collator=partial(collate, processor=processor),
+        args=training_args,
+        peft_config=peft_config,
+    )
+
+    trainer.train()
+    trainer.save_model()
+
+
+if __name__ == "__main__":
+    main()
```
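For reference, `make_dataset` expects each `baseline/files_{train,dev}/*.jsonl` file to contain a single JSON object with `"English"` and `"Translation"` keys, plus a matching image at `../d/Images/<basename>.jpg`.

As a quick post-training sanity check, a sketch along the following lines could load the saved adapter and caption one image. This is not part of the commit: the test image path is a made-up example, while the prompt, base model, and adapter directory are taken from the script above.

```python
#!/usr/bin/env python3
# Hedged sketch (not in this commit): load the LoRA adapter written by
# finetuning.py and generate a caption for a single image.

import torch
from peft import PeftModel
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

BASE = "google/gemma-3-4b-it"
ADAPTER = "gemma3-mm-sft-lora"  # output_dir of the training run above

PROMPT = (
    "You are a professional English-German translator and also a renowned photography critic.\n\n"
    "Write a detailed caption for this image in a single sentence. Translate the caption into German. "
    "The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. "
    "Only output the JSON, nothing else."
    "<start_of_image>"
)

processor = AutoProcessor.from_pretrained(BASE, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    BASE,
    torch_dtype=torch.bfloat16,  # assumption: plain bf16 weights for inference instead of 4-bit
    device_map="auto",
)
model = PeftModel.from_pretrained(model, ADAPTER)  # attach the fine-tuned LoRA weights

# Mirror the training-time formatting, this time with the generation prompt appended.
text = processor.tokenizer.apply_chat_template(
    [{"role": "user", "content": PROMPT}],
    tokenize=False,
    add_generation_prompt=True,
)

image = Image.open("../d/Images/example.jpg")  # hypothetical test image
inputs = processor(text=text, images=image, return_tensors="pt").to(model.device)

with torch.no_grad():
    generated = model.generate(**inputs, max_new_tokens=128)

# Decode only the tokens generated after the prompt.
reply = generated[0][inputs["input_ids"].shape[-1]:]
print(processor.tokenizer.decode(reply, skip_special_tokens=True))
```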
