diff options
| author | pks <pks@pks.rocks> | 2025-11-29 23:53:52 +0100 |
|---|---|---|
| committer | pks <pks@pks.rocks> | 2025-11-29 23:53:52 +0100 |
| commit | abb0df2560b53ec624d8ecaf8444a679fdadb6b2 (patch) | |
| tree | 2ed1d3c70a309b6072395d72b8af9bf686a7889e | |
| parent | 8fd25e95dc733d38db219083e28396045090a556 (diff) | |
WIP
| -rw-r--r-- | finetune.py | 19 |
1 file changed, 11 insertions, 8 deletions
diff --git a/finetune.py b/finetune.py index ba2a2c7..988a027 100644 --- a/finetune.py +++ b/finetune.py @@ -1,3 +1,6 @@ +#!/usr/bin/env python3 + +import argparse import json import os import torch @@ -15,7 +18,7 @@ from transformers import ( from trl import SFTTrainer -def make_dataset(base="./baseline"): +def make_dataset(base="./baseline"): # TODO: Make actual hf dataset prompt = "You are a professional English-German translator and also a renowned photography critic.\n\nWrite a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else." + "<start_of_image>" user_prompts = [] images = [] @@ -23,7 +26,6 @@ def make_dataset(base="./baseline"): for filename in glob(f"{base}/*.jsonl"): with open(filename, "r") as f: data = json.loads(f.read()) - print(f"{data=}") image_path = f"../d/Images/{os.path.basename(filename).removesuffix(".jsonl")}.jpg" user_prompts.append(prompt) assistant_replies.append(json.dumps({ @@ -49,7 +51,7 @@ def add_chat_text(example, processor): return example -def collate(batch, processor): +def collate(batch, processor): # FIXME: Support batch_size > 1 images = [i["image"] for i in batch] texts = [i["text"] for i in batch] out = processor( @@ -66,7 +68,9 @@ def collate(batch, processor): return out def main(): - model_name = "google/gemma-3-4b-it" + parser = argparse.ArgumentParser() + parser.add_argument("--model", default="google/gemma-3-4b-it") + args = parser.parse_args() bnb_config = BitsAndBytesConfig( load_in_4bit=True, @@ -75,10 +79,10 @@ def main(): bnb_4bit_compute_dtype=torch.bfloat16, ) - processor = AutoProcessor.from_pretrained(model_name, use_fast=True) + processor = AutoProcessor.from_pretrained(args.model, use_fast=True) model = AutoModelForCausalLM.from_pretrained( - model_name, + args.model, quantization_config=bnb_config, device_map="auto", ) @@ -106,8 +110,7 @@ 
def main(): bf16=True, gradient_checkpointing=True, gradient_checkpointing_kwargs={"use_reentrant": False}, - #optim="paged_adamw_8bit", - optim="adamw_torch_8bit", + optim="adamw_torch_8bit", # Alternative from BnB: paged_adamw_8bit remove_unused_columns=False, logging_steps=10, save_steps=100, |
