Diffstat (limited to 'finetune.py')
-rw-r--r--  finetune.py  122
1 files changed, 122 insertions, 0 deletions
diff --git a/finetune.py b/finetune.py
new file mode 100644
index 0000000..f747f99
--- /dev/null
+++ b/finetune.py
@@ -0,0 +1,122 @@
+import torch
+from datasets import Dataset, Image
+from transformers import (
+    AutoProcessor,
+    AutoModelForCausalLM,
+    TrainingArguments,
+    BitsAndBytesConfig,
+)
+from trl import SFTTrainer
+from peft import LoraConfig
+from glob import glob
+import json
+import os
+
+
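+# Base model: Gemma 3 4B instruction-tuned checkpoint (accepts text and image input).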
+model_name = "google/gemma-3-4b-it"
+
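+# Quantization config: load weights in 4-bit NF4 with double quantization, compute in bfloat16.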
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.bfloat16,
+)
+
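+# The processor bundles the tokenizer and the image preprocessing for Gemma 3.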
+processor = AutoProcessor.from_pretrained(model_name, use_fast=True)
+
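+# Load the base model with the 4-bit config; device_map="auto" places layers on available devices.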
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    quantization_config=bnb_config,
+    device_map="auto",
+)
+
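+# LoRA adapter: rank-16 adapters on all linear layers; lm_head and embed_tokens are trained and saved in full.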
+peft_config = LoraConfig(
+    lora_alpha=32,
+    lora_dropout=0.05,
+    r=16,
+    task_type="CAUSAL_LM",
+    bias="none",
+    target_modules="all-linear",
+    modules_to_save=["lm_head", "embed_tokens"],
+)
+
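+# Build a Hugging Face Dataset pairing each caption/translation JSON in ./baseline with its image in ../Images.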
+def make_dataset(base="./baseline"):
+    prompt = "You are a professional English-German translator and also a renowned photography critic.\n\nWrite a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else." + "<start_of_image>"
+    user_prompts = []
+    images = []
+    assistant_replies = []
+    for filename in glob(f"{base}/*.jsonl"):
+        # Each baseline file holds one JSON object with the caption and its translation.
+        with open(filename, "r") as f:
+            data = json.loads(f.read())
+        print(f"{data=}")
+        # The matching image lives in ../Images and shares the file's base name.
+        image_path = f"../Images/{os.path.basename(filename).removesuffix('.jsonl')}.jpg"
+        user_prompts.append(prompt)
+        assistant_replies.append(json.dumps({
+            "English": data["English"],
+            "German": data["Translation"],
+        }, ensure_ascii=False, indent=0))
+        images.append(image_path)
+
+    return Dataset.from_dict({"image": images, "user": user_prompts, "assistant": assistant_replies}).cast_column("image", Image())
+
+dataset = make_dataset()
+
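+# Render each example through the chat template into a single "text" field for training.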
+def add_chat_text(example):
+    messages = [
+        {"role": "user", "content": example["user"]},
+        {"role": "assistant", "content": example["assistant"]},
+    ]
+    example["text"] = processor.tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=False,
+    )
+    return example
+
+dataset = dataset.map(add_chat_text)
+
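+# Collator: the processor turns each batch of text/image pairs into padded model inputs.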
+def collate(batch):
+    images = [i["image"] for i in batch]
+    texts = [i["text"] for i in batch]
+    out = processor(
+        text=texts,
+        images=images,
+        padding=True,
+        truncation=True,
+        max_length=512,
+        return_tensors="pt",
+    )
+    # Labels are a copy of the input ids; padding positions are excluded from the loss.
+    out["labels"] = out["input_ids"].clone()
+    out["labels"][out["attention_mask"] == 0] = -100
+
+    return out
+
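+# Hold out 20% of the examples for evaluation.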
+ds_split = dataset.train_test_split(test_size=0.2, seed=42)
+
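+# Effective batch size is 24 (1 per device x 24 accumulation steps); bf16, gradient checkpointing and paged 8-bit AdamW keep memory low.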
+args = TrainingArguments(
+    output_dir="gemma3-mm-sft-lora",
+    per_device_train_batch_size=1,
+    gradient_accumulation_steps=24,
+    num_train_epochs=1,
+    learning_rate=1e-5,
+    fp16=False,
+    bf16=True,
+    gradient_checkpointing=True,
+    gradient_checkpointing_kwargs={"use_reentrant": False},
+    optim="paged_adamw_8bit",
+    remove_unused_columns=False,
+    logging_steps=10,
+    save_steps=20,
+)
+
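+# SFTTrainer applies the LoRA config to the quantized model and runs supervised fine-tuning with the custom collator.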
+trainer = SFTTrainer(
+    model=model,
+    train_dataset=ds_split["train"],
+    eval_dataset=ds_split["test"],
+    data_collator=collate,
+    args=args,
+    peft_config=peft_config,
+)
+
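+# Train, then save the fine-tuned adapter to output_dir.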
+trainer.train()
+trainer.save_model()