author    pks <pks@pks.rocks>  2025-11-30 22:28:41 +0100
committer pks <pks@pks.rocks>  2025-11-30 22:28:41 +0100
commit    3dbd2c5d3ecb70b778bf7192ab12b98bb5c2ec12 (patch)
tree      62182f128ab451a33720e88454db848bbb6098c2 /finetuning.py
parent    7fd45074ab47b2316c6a451981ad9f1f28cfead7 (diff)
WIP
Diffstat (limited to 'finetuning.py')
-rwxr-xr-x  finetuning.py  15
1 file changed, 8 insertions, 7 deletions
diff --git a/finetuning.py b/finetuning.py
index b0cf1cb..4bf82d9 100755
--- a/finetuning.py
+++ b/finetuning.py
@@ -11,7 +11,6 @@ from glob import glob
 from peft import LoraConfig
 from transformers import (
     AutoProcessor,
-    AutoModelForCausalLM,
     AutoModelForImageTextToText,
     TrainingArguments,
     BitsAndBytesConfig,
@@ -33,7 +32,7 @@ def add_chat_text(example, processor):
     return example
 
 
-def collate(batch, processor): # FIXME: Support batch_size > 1
+def collate(batch, processor, max_length): # FIXME: Support batch_size > 1
     images = [i["image"] for i in batch]
     texts = [i["text"] for i in batch]
     processor_output = processor(
@@ -41,7 +40,7 @@ def collate(batch, processor): # FIXME: Support batch_size > 1
         images=images,
         padding=True,
         truncation=True,
-        max_length=512,
+        max_length=max_length,
         return_tensors="pt",
     )
@@ -76,6 +75,8 @@ def main():
parser.add_argument("--lora-dropout", default=0.05, type=float)
parser.add_argument("--lora-r", default=16, type=int)
parser.add_argument("--bnb4bit", action="store_true")
+ parser.add_argument("--max-grad-norm", default=1.0, type="float")
+ parser.add_argument("--max-length", default=512, type=int)
args = parser.parse_args()
if args.bnb4bit:
@@ -104,9 +105,8 @@ def main():
         r=args.lora_r,
         task_type="CAUSAL_LM",
         bias="none",
-        #target_modules="all-linear",
-        target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
-        #modules_to_save=["lm_head", "embed_tokens"],
+        target_modules="all-linear",
+        modules_to_save=["lm_head", "embed_tokens"],
     )
 
     dataset = load_dataset("asdf2k/caption_translation")
@@ -126,6 +126,7 @@ def main():
         gradient_checkpointing=args.gradient_checkpointing,
         gradient_checkpointing_kwargs={"use_reentrant": False},
         optim=args.optimizer,
+        max_grad_norm=args.max_grad_norm,
         remove_unused_columns=False,
         logging_steps=args.logging_steps,
         save_strategy="epoch",
@@ -135,7 +136,7 @@ def main():
         model=model,
         train_dataset=train_ds,
         eval_dataset=dev_ds,
-        data_collator=partial(collate, processor=processor),
+        data_collator=partial(collate, processor=processor, max_length=args.max_length),
         args=args,
         peft_config=peft_config,
     )
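
For reference, the sketch below restates how the two new flags (--max-length and --max-grad-norm) are wired end to end: argparse into functools.partial for the collator, and argparse into TrainingArguments for gradient clipping. It is an illustration, not part of the commit; anything not visible in the hunks above (the model id, the text= keyword, output_dir) is an assumption.

# Illustrative sketch only; only the pieces visible in the diff are taken from finetuning.py.
import argparse
from functools import partial

from transformers import AutoProcessor, TrainingArguments


def collate(batch, processor, max_length):  # FIXME: Support batch_size > 1
    images = [i["image"] for i in batch]
    texts = [i["text"] for i in batch]
    return processor(
        text=texts,                # assumed; the diff only shows the kwargs below
        images=images,
        padding=True,
        truncation=True,
        max_length=max_length,     # was hard-coded to 512 before this commit
        return_tensors="pt",
    )


parser = argparse.ArgumentParser()
parser.add_argument("--max-grad-norm", default=1.0, type=float)
parser.add_argument("--max-length", default=512, type=int)
args = parser.parse_args()

# Arbitrary image-text checkpoint for illustration; the diff does not name one.
processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")

# The bound function is what the trainer receives as data_collator in the diff.
data_collator = partial(collate, processor=processor, max_length=args.max_length)

# max_grad_norm feeds gradient clipping via TrainingArguments; output_dir is a placeholder.
training_args = TrainingArguments(output_dir="out", max_grad_norm=args.max_grad_norm)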