| author | pks <pks@pks.rocks> | 2025-11-30 22:01:39 +0100 |
|---|---|---|
| committer | pks <pks@pks.rocks> | 2025-11-30 22:01:39 +0100 |
| commit | 8d49cccecfbbb3aae031b22efe26dc8d357da493 (patch) | |
| tree | 8eddfb0f59724f147969756bd986f5efedc090fe /finetuning.py | |
| parent | 1af7511e5fbf3c6e89f0452623eaf663387ccfc2 (diff) | |
WIP
Diffstat (limited to 'finetuning.py')
| -rwxr-xr-x | finetuning.py | 39 |
1 file changed, 12 insertions, 27 deletions
diff --git a/finetuning.py b/finetuning.py
index f787f95..1a44345 100755
--- a/finetuning.py
+++ b/finetuning.py
@@ -5,7 +5,7 @@
 import json
 import os
 import torch
-from datasets import Dataset, Image # Use PIL?
+from datasets import Dataset, Image, load_dataset
 from functools import partial
 from glob import glob
 from peft import LoraConfig
@@ -19,25 +19,6 @@ from transformers import (
 )
 from trl import SFTTrainer

-def make_dataset(base="./baseline"): # TODO: Make actual hf dataset
-    prompt = "You are a professional English-German translator and also a renowned photography critic.\n\nWrite a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else." + "<start_of_image>"
-    user_prompts = []
-    images = []
-    assistant_replies = []
-    for filename in glob(f"{base}/*.jsonl"):
-        with open(filename, "r") as f:
-            data = json.loads(f.read())
-        image_path = f"../d/Images/{os.path.basename(filename).removesuffix(".jsonl")}.jpg"
-        user_prompts.append(prompt)
-        assistant_replies.append(json.dumps({
-            "English": data["English"],
-            "German": data["Translation"],
-        }, ensure_ascii=False, indent=0))
-        images.append(image_path)
-
-    return Dataset.from_dict({"image": images, "user": user_prompts, "assistant": assistant_replies}).cast_column("image", Image())
-
-
 def add_chat_text(example, processor):
     messages = [
         {"role": "user", "content": example["user"]},
@@ -82,6 +63,8 @@ def collate(batch, processor): # FIXME: Support batch_size > 1
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", default="google/gemma-3-4b-it")
+    parser.add_argument("--optimizer", default="adamw_torch_fused")
+    parser.add_argument("--gradient-checkpointing", action="store_true")
     parser.add_argument("--batch-size", default=1)
     parser.add_argument("--gradient-accumulation", default=4)
     parser.add_argument("--learning-rate", default=1e-4)
@@ -99,11 +82,11 @@ def main():
         bnb_4bit_use_double_quant=True,
         bnb_4bit_quant_type="nf4",
         bnb_4bit_compute_dtype=torch.bfloat16,
+        bnb_4bit_quant_storage=torch.bfloat16,
     )

     processor = AutoProcessor.from_pretrained(args.model, use_fast=True)

-    #model = AutoModelForCausalLM.from_pretrained(
     model = AutoModelForImageTextToText.from_pretrained(
         args.model,
         quantization_config=bnb_config,
@@ -117,12 +100,14 @@ def main():
         r=args.lora_r,
         task_type="CAUSAL_LM",
         bias="none",
-        target_modules="all-linear",
-        modules_to_save=["lm_head", "embed_tokens"],
+        #target_modules="all-linear",
+        target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
+        #modules_to_save=["lm_head", "embed_tokens"],
     )

-    dev_ds = make_dataset("./baseline/files_dev").map(partial(add_chat_text, processor=processor))
-    train_ds = make_dataset("./baseline/files_train").map(partial(add_chat_text, processor=processor))
+    dataset = load_dataset("asdf2k/caption_translation")
+    dev_ds = dataset["dev"].map(partial(add_chat_text, processor=processor))
+    train_ds = dataset["train"].map(partial(add_chat_text, processor=processor))

     args = TrainingArguments(
         output_dir="gemma3-mm-sft-lora",
@@ -134,9 +119,9 @@ def main():
         lr_scheduler_type=args.scheduler_type,
         fp16=False,
         bf16=True,
-        gradient_checkpointing=True,
+        gradient_checkpointing=args.gradient_checkpointing,
         gradient_checkpointing_kwargs={"use_reentrant": False},
-        optim="adamw_torch_8bit", # Alternative from BnB: paged_adamw_8bit
+        optim=args.optimizer,
         remove_unused_columns=False,
         logging_steps=args.logging_steps,
         save_strategy="epoch",
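The deleted `make_dataset` helper (with its "TODO: Make actual hf dataset") is replaced by loading the hosted `asdf2k/caption_translation` dataset. Below is a minimal sketch of how such a Hub dataset could be produced from the same local JSONL/image files, reusing the prompt, paths, and split names from the deleted code and the new `load_dataset` call; the actual upload procedure used for that repo is an assumption, not part of this commit.

```python
# Hypothetical one-off script: rebuild the caption/translation splits from the
# local JSONL files and push them to the Hub so finetuning.py can load_dataset() them.
import json
import os
from glob import glob

from datasets import Dataset, DatasetDict, Image

PROMPT = (
    "You are a professional English-German translator and also a renowned photography critic.\n\n"
    "Write a detailed caption for this image in a single sentence. Translate the caption into German. "
    "The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. "
    "Only output the JSON, nothing else."
    "<start_of_image>"
)


def build_split(base):
    user_prompts, images, assistant_replies = [], [], []
    for filename in glob(f"{base}/*.jsonl"):
        with open(filename, "r") as f:
            data = json.loads(f.read())
        stem = os.path.basename(filename).removesuffix(".jsonl")
        images.append(f"../d/Images/{stem}.jpg")
        user_prompts.append(PROMPT)
        assistant_replies.append(json.dumps(
            {"English": data["English"], "German": data["Translation"]},
            ensure_ascii=False, indent=0,
        ))
    # cast_column(..., Image()) stores the paths as an image feature that is
    # decoded to PIL images on access.
    return Dataset.from_dict(
        {"image": images, "user": user_prompts, "assistant": assistant_replies}
    ).cast_column("image", Image())


if __name__ == "__main__":
    DatasetDict({
        "train": build_split("./baseline/files_train"),
        "dev": build_split("./baseline/files_dev"),
    }).push_to_hub("asdf2k/caption_translation")
```

With the splits on the Hub, the training script no longer needs the local `baseline/` layout, and `dataset["train"]` / `dataset["dev"]` map directly onto the keys used in the new code.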
