author    pks <pks@pks.rocks>  2025-11-30 22:30:01 +0100
committer pks <pks@pks.rocks>  2025-11-30 22:30:01 +0100
commit    14dc940c6787c521e164289ce7ba9ef3cd901bf5 (patch)
tree      98eea1387740609c5c2631bfcda5f24443652b77 /finetuning.py
parent    3dbd2c5d3ecb70b778bf7192ab12b98bb5c2ec12 (diff)
WIP
Diffstat (limited to 'finetuning.py')
-rwxr-xr-x  finetuning.py  6
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/finetuning.py b/finetuning.py
index 4bf82d9..490855e 100755
--- a/finetuning.py
+++ b/finetuning.py
@@ -75,7 +75,7 @@ def main():
parser.add_argument("--lora-dropout", default=0.05, type=float)
parser.add_argument("--lora-r", default=16, type=int)
parser.add_argument("--bnb4bit", action="store_true")
- parser.add_argument("--max-grad-norm", default=1.0, type="float")
+ parser.add_argument("--max-grad-norm", default=1.0, type=float)
parser.add_argument("--max-length", default=512, type=int)
args = parser.parse_args()
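
Note on the first hunk: argparse's type parameter must be a callable that converts the raw command-line string into a value, so the built-in float works while the string "float" does not (argparse rejects a non-callable converter when the argument is defined). A minimal, self-contained sketch of the corrected pattern:

import argparse

parser = argparse.ArgumentParser()
# type=float is a callable converter; type="float" would be rejected by argparse
parser.add_argument("--max-grad-norm", default=1.0, type=float)
args = parser.parse_args(["--max-grad-norm", "0.5"])
assert args.max_grad_norm == 0.5
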
@@ -113,7 +113,7 @@ def main():
dev_ds = dataset["dev"].map(partial(add_chat_text, processor=processor))
train_ds = dataset["train"].map(partial(add_chat_text, processor=processor))
- args = TrainingArguments(
+ training_args = TrainingArguments(
output_dir="gemma3-mm-sft-lora",
per_device_train_batch_size=args.batch_size,
gradient_accumulation_steps=args.gradient_accumulation,
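
Note on the second hunk: rebinding args to the TrainingArguments instance would shadow the argparse namespace that the rest of main() still reads (for example args.max_length in the collator below), so the trainer configuration gets its own name. A hedged sketch of the intent, assuming the Hugging Face TrainingArguments shown in the diff and repeating only the fields visible in this hunk:

from transformers import TrainingArguments

# Keep the parsed CLI options (args) and the trainer configuration
# (training_args) under separate names so later args.* lookups keep working.
training_args = TrainingArguments(
    output_dir="gemma3-mm-sft-lora",
    per_device_train_batch_size=args.batch_size,
    gradient_accumulation_steps=args.gradient_accumulation,
)
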
@@ -137,7 +137,7 @@ def main():
train_dataset=train_ds,
eval_dataset=dev_ds,
data_collator=partial(collate, processor=processor, max_length=args.max_length),
- args=args,
+ args=training_args,
peft_config=peft_config,
)
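
Note on the third hunk: the trainer's args keyword now receives the TrainingArguments object rather than the argparse namespace that previously shared the name. A hedged sketch of the call site after the rename; the trainer class and the model variable are assumptions, since the diff only shows keyword arguments (peft_config suggests trl's SFTTrainer):

from functools import partial
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,  # assumption: built earlier in main(), not shown in this diff
    train_dataset=train_ds,
    eval_dataset=dev_ds,
    data_collator=partial(collate, processor=processor, max_length=args.max_length),
    args=training_args,  # the TrainingArguments instance, not the CLI namespace
    peft_config=peft_config,
)
trainer.train()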