author    pks <pks@pks.rocks>  2025-11-29 23:53:52 +0100
committer pks <pks@pks.rocks>  2025-11-29 23:53:52 +0100
commit    abb0df2560b53ec624d8ecaf8444a679fdadb6b2
tree      2ed1d3c70a309b6072395d72b8af9bf686a7889e
parent    8fd25e95dc733d38db219083e28396045090a556
WIP
-rw-r--r--  finetune.py  |  19 +++++++++++--------
1 file changed, 11 insertions(+), 8 deletions(-)
diff --git a/finetune.py b/finetune.py
index ba2a2c7..988a027 100644
--- a/finetune.py
+++ b/finetune.py
@@ -1,3 +1,6 @@
+#!/usr/bin/env python3
+
+import argparse
 import json
 import os
 import torch
@@ -15,7 +18,7 @@ from transformers import (
 from trl import SFTTrainer
 
 
-def make_dataset(base="./baseline"):
+def make_dataset(base="./baseline"):  # TODO: Make actual hf dataset
     prompt = "You are a professional English-German translator and also a renowned photography critic.\n\nWrite a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else." + "<start_of_image>"
     user_prompts = []
     images = []
@@ -23,7 +26,6 @@ def make_dataset(base="./baseline"):
     for filename in glob(f"{base}/*.jsonl"):
         with open(filename, "r") as f:
             data = json.loads(f.read())
-            print(f"{data=}")
         image_path = f"../d/Images/{os.path.basename(filename).removesuffix('.jsonl')}.jpg"
         user_prompts.append(prompt)
         assistant_replies.append(json.dumps({
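Note: the TODO above hints at turning the three parallel lists into a real
Hugging Face dataset. A minimal sketch of what that could look like, assuming
the "datasets" package and the list names from this hunk; the helper name and
column names are illustrative, not taken from the diff:

    from datasets import Dataset

    def make_hf_dataset(user_prompts, images, assistant_replies):
        # Build an in-memory datasets.Dataset from the parallel lists.
        return Dataset.from_dict({
            "prompt": user_prompts,
            "image": images,
            "reply": assistant_replies,
        })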
@@ -49,7 +51,7 @@ def add_chat_text(example, processor):
     return example
 
 
-def collate(batch, processor):
+def collate(batch, processor):  # FIXME: Support batch_size > 1
     images = [i["image"] for i in batch]
     texts = [i["text"] for i in batch]
     out = processor(
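Note: the FIXME marks that collation currently only works for batch_size=1.
A minimal sketch of a batch-capable collator, assuming the processor pads
like a standard Hugging Face multimodal processor and that padded positions
should be masked out of the loss; the label handling is an assumption, not
taken from this diff:

    def collate(batch, processor):
        images = [i["image"] for i in batch]
        texts = [i["text"] for i in batch]
        out = processor(
            text=texts,
            images=images,
            padding=True,  # pad to the longest sample so batches line up
            return_tensors="pt",
        )
        labels = out["input_ids"].clone()
        labels[out["attention_mask"] == 0] = -100  # ignore padding in the loss
        out["labels"] = labels
        return out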
@@ -66,7 +68,9 @@ def collate(batch, processor):
     return out
 
 def main():
-    model_name = "google/gemma-3-4b-it"
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", default="google/gemma-3-4b-it")
+    args = parser.parse_args()
 
     bnb_config = BitsAndBytesConfig(
         load_in_4bit=True,
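Note: with the new argparse flag, the default checkpoint can be overridden
from the command line, e.g. (hypothetical invocation):

    python3 finetune.py --model google/gemma-3-12b-it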
@@ -75,10 +79,10 @@
         bnb_4bit_compute_dtype=torch.bfloat16,
     )
 
-    processor = AutoProcessor.from_pretrained(model_name, use_fast=True)
+    processor = AutoProcessor.from_pretrained(args.model, use_fast=True)
 
     model = AutoModelForCausalLM.from_pretrained(
-        model_name,
+        args.model,
         quantization_config=bnb_config,
         device_map="auto",
    )
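Note: the hunk shows only load_in_4bit and the bf16 compute dtype; the other
BitsAndBytesConfig fields sit in the elided lines. For orientation, a typical
4-bit (QLoRA-style) config looks like the following; this is an assumption
about the full config, not a reconstruction of the hidden lines:

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,                      # quantize weights to 4 bit on load
        bnb_4bit_quant_type="nf4",              # NormalFloat4 data type
        bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
        bnb_4bit_compute_dtype=torch.bfloat16,  # run matmuls in bf16
    )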
@@ -106,8 +110,7 @@ def main():
         bf16=True,
         gradient_checkpointing=True,
         gradient_checkpointing_kwargs={"use_reentrant": False},
-        #optim="paged_adamw_8bit",
-        optim="adamw_torch_8bit",
+        optim="adamw_torch_8bit",  # Alternative from BnB: paged_adamw_8bit
         remove_unused_columns=False,
         logging_steps=10,
         save_steps=100,
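Note on the optimizer change: "paged_adamw_8bit" is the bitsandbytes paged
8-bit AdamW, which pages optimizer state between GPU and CPU to ride out
memory spikes, whereas "adamw_torch_8bit" is the torchao-backed 8-bit AdamW
available in recent transformers releases; both are accepted string values
for the optim training argument, as the new inline comment notes.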