| author | pks <pks@pks.rocks> | 2025-11-30 22:30:01 +0100 |
|---|---|---|
| committer | pks <pks@pks.rocks> | 2025-11-30 22:30:01 +0100 |
| commit | 14dc940c6787c521e164289ce7ba9ef3cd901bf5 | |
| tree | 98eea1387740609c5c2631bfcda5f24443652b77 | |
| parent | 3dbd2c5d3ecb70b778bf7192ab12b98bb5c2ec12 | |
WIP
| mode | file | lines changed |
|---|---|---|
| -rwxr-xr-x | finetuning.py | 6 |

1 file changed, 3 insertions(+), 3 deletions(-)
```diff
diff --git a/finetuning.py b/finetuning.py
index 4bf82d9..490855e 100755
--- a/finetuning.py
+++ b/finetuning.py
@@ -75,7 +75,7 @@ def main():
     parser.add_argument("--lora-dropout", default=0.05, type=float)
     parser.add_argument("--lora-r", default=16, type=int)
     parser.add_argument("--bnb4bit", action="store_true")
-    parser.add_argument("--max-grad-norm", default=1.0, type="float")
+    parser.add_argument("--max-grad-norm", default=1.0, type=float)
     parser.add_argument("--max-length", default=512, type=int)
 
     args = parser.parse_args()
@@ -113,7 +113,7 @@ def main():
     dev_ds = dataset["dev"].map(partial(add_chat_text, processor=processor))
     train_ds = dataset["train"].map(partial(add_chat_text, processor=processor))
 
-    args = TrainingArguments(
+    training_args = TrainingArguments(
         output_dir="gemma3-mm-sft-lora",
         per_device_train_batch_size=args.batch_size,
         gradient_accumulation_steps=args.gradient_accumulation,
@@ -137,7 +137,7 @@ def main():
         train_dataset=train_ds,
         eval_dataset=dev_ds,
         data_collator=partial(collate, processor=processor, max_length=args.max_length),
-        args=args,
+        args=training_args,
         peft_config=peft_config,
     )
```
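The first hunk fixes an `argparse` typo: `type=` must be a conversion callable, not the string `"float"`. A minimal standalone sketch of that behaviour follows; the option name is copied from the diff, everything else is illustrative:

```python
import argparse

parser = argparse.ArgumentParser()
# Fixed form: `type` is a callable used to convert the raw string argument.
parser.add_argument("--max-grad-norm", default=1.0, type=float)
args = parser.parse_args(["--max-grad-norm", "0.5"])
print(type(args.max_grad_norm), args.max_grad_norm)  # <class 'float'> 0.5

# Pre-commit form: the string "float" is not callable. If the option is
# omitted, the non-string default (1.0) is never run through `type`, which is
# how the typo can go unnoticed until someone actually passes the flag.
broken = argparse.ArgumentParser()
try:
    # Some Python versions reject a non-callable `type` here already; older
    # ones only fail below, when the option value has to be converted.
    broken.add_argument("--max-grad-norm", default=1.0, type="float")
    broken.parse_args(["--max-grad-norm", "0.5"])
except (ValueError, SystemExit):
    pass  # argparse refuses the non-callable type either way
```

The other two hunks rename the `TrainingArguments` instance so it no longer rebinds `args`: after `args = TrainingArguments(...)`, the later `args.max_length` in the `data_collator` partial would look up the attribute on the `TrainingArguments` object instead of the parsed command-line namespace.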
