author     pks <pks@pks.rocks>  2025-11-29 23:56:12 +0100
committer  pks <pks@pks.rocks>  2025-11-29 23:56:12 +0100
commit     fc6e8636014dbaf882a2fa5685497dc225ee8e29 (patch)
tree       0ecd23bd0e8b2ff8aba7b7085657919b91f3bc7c /finetuning.py
parent     abb0df2560b53ec624d8ecaf8444a679fdadb6b2 (diff)
WIP
Diffstat (limited to 'finetuning.py')
-rw-r--r--  finetuning.py  133
1 file changed, 133 insertions, 0 deletions
diff --git a/finetuning.py b/finetuning.py
new file mode 100644
index 0000000..988a027
--- /dev/null
+++ b/finetuning.py
@@ -0,0 +1,133 @@
+#!/usr/bin/env python3
+
+import argparse
+import json
+import os
+import torch
+
+from datasets import Dataset, Image
+from functools import partial
+from glob import glob
+from peft import LoraConfig
+from transformers import (
+    AutoProcessor,
+    AutoModelForCausalLM,
+    TrainingArguments,
+    BitsAndBytesConfig,
+)
+from trl import SFTTrainer
+
+
+def make_dataset(base="./baseline"):  # TODO: make an actual HF dataset
+    prompt = (
+        "You are a professional English-German translator and also a renowned "
+        "photography critic.\n\n"
+        "Write a detailed caption for this image in a single sentence. "
+        "Translate the caption into German. The output needs to be JSON, the "
+        "keys being 'English' and 'German' for the respective captions. Only "
+        "output the JSON, nothing else."
+        "<start_of_image>"
+    )
+    user_prompts = []
+    images = []
+    assistant_replies = []
+    for filename in glob(f"{base}/*.jsonl"):
+        with open(filename, "r") as f:
+            data = json.loads(f.read())  # each .jsonl file holds a single JSON object
+        # Keep the suffix handling outside the f-string: nesting the same
+        # double quotes inside an f-string is a syntax error before Python 3.12.
+        stem = os.path.basename(filename).removesuffix(".jsonl")
+        image_path = f"../d/Images/{stem}.jpg"
+        user_prompts.append(prompt)
+        assistant_replies.append(json.dumps({
+            "English": data["English"],
+            "German": data["Translation"],
+        }, ensure_ascii=False, indent=0))
+        images.append(image_path)
+
+    return Dataset.from_dict(
+        {"image": images, "user": user_prompts, "assistant": assistant_replies}
+    ).cast_column("image", Image())
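+
+# A sketch of the record shape make_dataset() expects in each baseline *.jsonl
+# file; the example values are hypothetical, only the "English" and
+# "Translation" keys are actually read:
+#
+#   {"English": "A red fox crosses a snowy field at dawn.",
+#    "Translation": "Ein Rotfuchs überquert im Morgengrauen ein verschneites Feld."}
+#
+# Each file is parsed as one JSON object and paired with ../d/Images/<basename>.jpg.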
+
+
+def add_chat_text(example, processor):
+    messages = [
+        {"role": "user", "content": example["user"]},
+        {"role": "assistant", "content": example["assistant"]},
+    ]
+    example["text"] = processor.tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=False,
+    )
+
+    return example
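+
+# For orientation: Gemma's chat template renders the two turns roughly as
+# follows (a sketch, not the verbatim template output):
+#
+#   <bos><start_of_turn>user
+#   ...caption instructions...<start_of_image><end_of_turn>
+#   <start_of_turn>model
+#   {"English": ..., "German": ...}<end_of_turn>
+#
+# The <start_of_image> marker baked into the prompt is what the processor
+# later expands into image tokens.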
+
+
+def collate(batch, processor):  # FIXME: Support batch_size > 1
+    images = [i["image"] for i in batch]
+    texts = [i["text"] for i in batch]
+    out = processor(
+        text=texts,
+        images=images,
+        padding=True,
+        truncation=True,
+        max_length=512,
+        return_tensors="pt",
+    )
+    # Mask padding out of the loss; prompt and image placeholder tokens are
+    # still scored (see the sketch after this function).
+    out["labels"] = out["input_ids"].clone()
+    out["labels"][out["attention_mask"] == 0] = -100
+
+    return out
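+
+# Sketch of an optional extra masking step so the loss also skips the image
+# placeholder tokens (the token string is an assumption; check it against the
+# checkpoint's tokenizer before relying on it):
+#
+#   image_token_id = processor.tokenizer.convert_tokens_to_ids("<image_soft_token>")
+#   out["labels"][out["labels"] == image_token_id] = -100
+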
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", default="google/gemma-3-4b-it")
+    args = parser.parse_args()
+
+    bnb_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16,
+    )
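+    # QLoRA setup: weights are stored as 4-bit NF4, double quantization
+    # compresses the quantization constants themselves, and matmuls run in
+    # bfloat16. At 4 bits per weight, a ~4B-parameter model's quantized
+    # weights come to roughly 2 GB.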
+
+    processor = AutoProcessor.from_pretrained(args.model, use_fast=True)
+
+    # NOTE: for this multimodal checkpoint, AutoModelForImageTextToText may be
+    # the more explicit class for loading the vision tower alongside the LM.
+    model = AutoModelForCausalLM.from_pretrained(
+        args.model,
+        quantization_config=bnb_config,
+        device_map="auto",
+    )
+
+    peft_config = LoraConfig(
+        lora_alpha=32,
+        lora_dropout=0.05,
+        r=16,
+        task_type="CAUSAL_LM",
+        bias="none",
+        target_modules="all-linear",
+        modules_to_save=["lm_head", "embed_tokens"],
+    )
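+    # Effective LoRA scaling is lora_alpha / r = 32 / 16 = 2. Note that
+    # modules_to_save trains full (non-LoRA) copies of lm_head and
+    # embed_tokens, which for Gemma's ~262k-token vocabulary adds a large
+    # number of trainable parameters on top of the adapters.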
+
+    dev_ds = make_dataset("./baseline/files_dev").map(
+        partial(add_chat_text, processor=processor)
+    )
+    train_ds = make_dataset("./baseline/files_train").map(
+        partial(add_chat_text, processor=processor)
+    )
+
+    # Distinct name so the parsed CLI args above stay accessible.
+    training_args = TrainingArguments(
+        output_dir="gemma3-mm-sft-lora",
+        per_device_train_batch_size=1,
+        gradient_accumulation_steps=24,
+        num_train_epochs=3,
+        learning_rate=1e-5,
+        fp16=False,
+        bf16=True,
+        gradient_checkpointing=True,
+        gradient_checkpointing_kwargs={"use_reentrant": False},
+        optim="adamw_torch_8bit",  # alternative from bitsandbytes: paged_adamw_8bit
+        remove_unused_columns=False,
+        logging_steps=10,
+        save_steps=100,
+    )
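+    # Effective batch size: per_device_train_batch_size *
+    # gradient_accumulation_steps = 1 * 24 = 24 examples per optimizer step,
+    # which keeps the single-example collator above workable.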
+
+    trainer = SFTTrainer(
+        model=model,
+        train_dataset=train_ds,
+        eval_dataset=dev_ds,
+        data_collator=partial(collate, processor=processor),
+        args=training_args,
+        peft_config=peft_config,
+    )
+
+    trainer.train()
+    trainer.save_model()
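+
+    # Sketch for reloading the adapter later (assumes peft's default save
+    # layout under output_dir):
+    #
+    #   from peft import PeftModel
+    #   base = AutoModelForCausalLM.from_pretrained("google/gemma-3-4b-it")
+    #   model = PeftModel.from_pretrained(base, "gemma3-mm-sft-lora")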
+
+
+if __name__ == "__main__":
+    main()
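+
+# Invocation sketch (the argparse default already points at
+# google/gemma-3-4b-it):
+#
+#   python3 finetuning.py --model google/gemma-3-4b-it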