Diffstat (limited to 'finetune.py')
-rw-r--r--  finetune.py  122
1 file changed, 122 insertions, 0 deletions
diff --git a/finetune.py b/finetune.py
new file mode 100644
index 0000000..f747f99
--- /dev/null
+++ b/finetune.py
@@ -0,0 +1,122 @@
+import torch
+from datasets import Dataset, Image
+from transformers import (
+    AutoProcessor,
+    AutoModelForCausalLM,
+    TrainingArguments,
+    BitsAndBytesConfig,
+)
+from trl import SFTTrainer
+from peft import LoraConfig
+from glob import glob
+import json
+import os
+
+
+model_name = "google/gemma-3-4b-it"
+
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.bfloat16,
+)
+
+processor = AutoProcessor.from_pretrained(model_name, use_fast=True)
+
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    quantization_config=bnb_config,
+    device_map="auto",
+)
+
+peft_config = LoraConfig(
+    lora_alpha=32,
+    lora_dropout=0.05,
+    r=16,
+    task_type="CAUSAL_LM",
+    bias="none",
+    target_modules="all-linear",
+    modules_to_save=["lm_head", "embed_tokens"],
+)
+
+def make_dataset(base="./baseline"):
+    prompt = "You are a professional English-German translator and also a renowned photography critic.\n\nWrite a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else." + "<start_of_image>"
+    user_prompts = []
+    images = []
+    assistant_replies = []
+    for filename in glob(f"{base}/*.jsonl"):
+        with open(filename, "r") as f:
+            data = json.loads(f.read())
+            print(f"{data=}")
+            image_path = f"../Images/{os.path.basename(filename).removesuffix('.jsonl')}.jpg"
+            user_prompts.append(prompt)
+            assistant_replies.append(json.dumps({
+                "English": data["English"],
+                "German": data["Translation"],
+            }, ensure_ascii=False, indent=0))
+            images.append(image_path)
+
+    return Dataset.from_dict({"image": images, "user": user_prompts, "assistant": assistant_replies}).cast_column("image", Image())
+
+dataset = make_dataset()
+
+def add_chat_text(example):
+    messages = [
+        {"role": "user", "content": example["user"]},
+        {"role": "assistant", "content": example["assistant"]},
+    ]
+    example["text"] = processor.tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=False,
+    )
+    return example
+
+dataset = dataset.map(add_chat_text)
+
+def collate(batch):
+    images = [i["image"] for i in batch]
+    texts = [i["text"] for i in batch]
+    out = processor(
+        text=texts,
+        images=images,
+        padding=True,
+        truncation=True,
+        max_length=512,
+        return_tensors="pt",
+    )
+    out["labels"] = out["input_ids"].clone()
+    out["labels"][out["attention_mask"] == 0] = -100
+
+    return out
+
+ds_split = dataset.train_test_split(test_size=0.2, seed=42)
+
+args = TrainingArguments(
+    output_dir="gemma3-mm-sft-lora",
+    per_device_train_batch_size=1,
+    gradient_accumulation_steps=24,
+    num_train_epochs=1,
+    learning_rate=1e-5,
+    fp16=False,
+    bf16=True,
+    gradient_checkpointing=True,
+    gradient_checkpointing_kwargs={"use_reentrant": False},
+    optim="paged_adamw_8bit",
+    remove_unused_columns=False,
+    logging_steps=10,
+    save_steps=20,
+)
+
+trainer = SFTTrainer(
+    model=model,
+    train_dataset=ds_split["train"],
+    eval_dataset=ds_split["test"],
+    data_collator=collate,
+    args=args,
+    peft_config=peft_config,
+)
+
+trainer.train()
+trainer.save_model()
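
For reference, make_dataset() expects one JSON object per file under ./baseline (read with a single json.loads despite the .jsonl extension), with keys "English" and "Translation", plus a matching JPEG under ../Images/ named after the file's stem. A hypothetical example, with made-up filename and captions:

    baseline/IMG_0001.jsonl:

        {"English": "A lone oak tree stands in a fog-covered meadow at sunrise.",
         "Translation": "Eine einzelne Eiche steht bei Sonnenaufgang auf einer nebelverhangenen Wiese."}

    paired image: ../Images/IMG_0001.jpg

Note that collate() masks only padding positions in the labels, so the loss is also computed over the prompt and image tokens; additionally masking those is a common refinement.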

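Not part of the commit, but for context, a minimal inference sketch that loads the saved adapter. It assumes the same AutoModelForCausalLM class used above handles the multimodal checkpoint, that the prompt matches the training prompt (including the trailing <start_of_image> token), and that ../Images/example.jpg is a stand-in path:

    import torch
    from peft import PeftModel
    from PIL import Image
    from transformers import AutoModelForCausalLM, AutoProcessor

    model_name = "google/gemma-3-4b-it"
    adapter_dir = "gemma3-mm-sft-lora"  # output_dir used by the trainer above

    processor = AutoProcessor.from_pretrained(model_name, use_fast=True)
    base = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, device_map="auto"
    )
    # Attach the LoRA adapter saved by trainer.save_model().
    model = PeftModel.from_pretrained(base, adapter_dir)
    model.eval()

    prompt = "You are a professional English-German translator ..." + "<start_of_image>"  # same prompt as in training
    text = processor.tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,  # end with a generation prompt instead of a gold reply
    )
    inputs = processor(
        text=[text],
        images=[Image.open("../Images/example.jpg")],  # hypothetical image path
        return_tensors="pt",
    ).to(model.device)

    with torch.inference_mode():
        out = model.generate(**inputs, max_new_tokens=128)
    # Decode only the newly generated tokens, skipping the echoed prompt.
    print(processor.decode(out[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))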