# Fine-tune google/gemma-3-4b-it on image-captioning data (an English caption
# plus its German translation per image) using 4-bit quantization and LoRA
# adapters, trained through TRL's SFTTrainer.
import json
import os
import torch
from datasets import Dataset, Image
from functools import partial
from glob import glob
from peft import LoraConfig
from transformers import (
AutoProcessor,
AutoModelForCausalLM,
TrainingArguments,
BitsAndBytesConfig,
)
from trl import SFTTrainer
def make_dataset(base="./baseline", image_dir="../Images"):
    """Build the captioning dataset from a directory of annotation files.

    Every ``*.jsonl`` file directly inside *base* is parsed as ONE JSON
    document (the whole file, not line by line) that must provide the keys
    ``"English"`` and ``"Translation"``. Each file yields one row whose image
    path is ``<image_dir>/<file name without .jsonl>.jpg``.

    Args:
        base: Directory searched (non-recursively) for ``*.jsonl`` files.
        image_dir: Directory containing the ``.jpg`` images. The default is
            the location that used to be hard-coded.

    Returns:
        A ``Dataset`` with three columns:
        ``image`` (path cast to the ``Image`` feature), ``user`` (the fixed
        instruction prompt) and ``assistant`` (the target JSON string with
        the keys ``"English"`` and ``"German"``).

    Raises:
        json.JSONDecodeError: If an annotation file is not a single valid
            JSON document.
        KeyError: If a document lacks ``"English"`` or ``"Translation"``.
    """
    # The prompt ends with the literal "<start_of_image>" marker.
    # NOTE(review): presumably the processor expands this marker into the
    # image tokens - confirm against the processor's documentation.
    prompt = (
        "You are a professional English-German translator and also a "
        "renowned photography critic.\n\n"
        "Write a detailed caption for this image in a single sentence. "
        "Translate the caption into German. "
        "The output needs to be JSON, the keys being 'English' and 'German' "
        "for the respective captions. "
        "Only output the JSON, nothing else."
        "<start_of_image>"
    )

    images = []
    assistant_replies = []
    # glob() returns files in filesystem order; sort so that the row order
    # is the same on every machine and every run.
    for filename in sorted(glob(f"{base}/*.jsonl")):
        # Explicit UTF-8: the annotations contain German text and must not
        # depend on the platform's default encoding.
        with open(filename, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Computed outside the f-string: reusing the enclosing quote
        # character inside an f-string expression only parses on 3.12+.
        stem = os.path.basename(filename).removesuffix(".jsonl")
        images.append(f"{image_dir}/{stem}.jpg")

        # indent=0 makes json.dumps insert newlines without indentation, so
        # each key of the target sits on its own line.
        assistant_replies.append(
            json.dumps(
                {
                    "English": data["English"],
                    "German": data["Translation"],
                },
                ensure_ascii=False,
                indent=0,
            )
        )

    columns = {
        "image": images,
        # Every row shares the same instruction prompt.
        "user": [prompt] * len(images),
        "assistant": assistant_replies,
    }
    return Dataset.from_dict(columns).cast_column("image", Image())
def add_chat_text(example, processor):
    """Render one dataset row as chat text and store it under ``"text"``.

    Args:
        example: Mapping with the keys ``"user"`` and ``"assistant"``. It is
            modified in place.
        processor: Object whose ``tokenizer.apply_chat_template`` renders a
            list of ``{"role", "content"}`` messages.

    Returns:
        The same *example* object, now carrying the rendered conversation in
        ``example["text"]``.
    """
    # One message per role, user turn first, each taken from the column of
    # the same name.
    turns = [
        {"role": role, "content": example[role]}
        for role in ("user", "assistant")
    ]
    # tokenize=False requests the rendered template rather than token ids;
    # no generation prompt is appended because the assistant turn is already
    # part of the conversation.
    rendered = processor.tokenizer.apply_chat_template(
        turns,
        tokenize=False,
        add_generation_prompt=False,
    )
    example["text"] = rendered
    return example
def collate(batch, processor):
    """Turn a list of dataset rows into one batch with training labels.

    Args:
        batch: Sequence of rows, each providing ``"image"`` and ``"text"``.
        processor: Callable invoked with ``text=``, ``images=`` and the
            padding/truncation options below; its result must expose
            ``"input_ids"`` and ``"attention_mask"`` tensors.

    Returns:
        The processor output with an added ``"labels"`` entry: a copy of
        ``input_ids`` in which every position whose attention mask is 0 is
        replaced by -100.
    """
    pictures = [row["image"] for row in batch]
    prompts = [row["text"] for row in batch]

    encoded = processor(
        text=prompts,
        images=pictures,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )

    # -100 is the default ignore_index of PyTorch's cross-entropy loss, so
    # the masked positions do not contribute to the loss.
    # NOTE(review): only attention_mask == 0 positions are masked; the
    # prompt and any image-placeholder tokens stay in the labels - confirm
    # this is intended.
    unattended = encoded["attention_mask"] == 0
    encoded["labels"] = encoded["input_ids"].masked_fill(unattended, -100)
    return encoded
def main():
    """Fine-tune ``google/gemma-3-4b-it`` with LoRA adapters on the caption data.

    Loads the processor and a 4-bit quantized model, builds the dev and train
    datasets from ``./baseline/files_dev`` and ``./baseline/files_train``,
    runs ``SFTTrainer.train()`` and finally calls ``save_model()``. The
    trainer's output directory is ``gemma3-mm-sft-lora``.
    """
    checkpoint = "google/gemma-3-4b-it"

    # 4-bit NF4 weights with double quantization; compute dtype is bfloat16.
    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    processor = AutoProcessor.from_pretrained(checkpoint, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        quantization_config=quantization,
        device_map="auto",
    )

    # LoRA on every linear layer; lm_head and embed_tokens are listed in
    # modules_to_save in addition to the adapters.
    lora = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        target_modules="all-linear",
        modules_to_save=["lm_head", "embed_tokens"],
        task_type="CAUSAL_LM",
    )

    # Both splits get the chat-formatted "text" column added row by row.
    attach_text = partial(add_chat_text, processor=processor)
    dev_split = make_dataset("./baseline/files_dev").map(attach_text)
    train_split = make_dataset("./baseline/files_train").map(attach_text)

    # Batch size 1 with 24 accumulation steps: 24 samples per optimizer step
    # on each device.
    training_args = TrainingArguments(
        output_dir="gemma3-mm-sft-lora",
        num_train_epochs=3,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=24,
        learning_rate=1e-5,
        optim="paged_adamw_8bit",
        bf16=True,
        fp16=False,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        # The collator reads the raw "image" and "text" columns, so they
        # must not be dropped before batching.
        remove_unused_columns=False,
        logging_steps=10,
        save_steps=100,
    )

    # NOTE(review): an eval dataset is passed but no evaluation strategy is
    # set in the arguments above - confirm whether evaluation ever runs.
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_split,
        eval_dataset=dev_split,
        data_collator=partial(collate, processor=processor),
        peft_config=lora,
    )
    trainer.train()
    trainer.save_model()
# Script entry point: start the fine-tuning run only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()