author    pks <pks@pks.rocks>    2025-12-02 11:17:12 +0100
committer pks <pks@pks.rocks>    2025-12-02 11:17:12 +0100
commit    d19bd3fbf54f08db5e563b17c64991ef9b9706a6 (patch)
tree      b5792882747e0dd5b9cc6412ea578bf337d4bfc3 /finetuning.py
parent    0670bfe198e12cb25800cfae48677d90ea978cf2 (diff)
WIP
Diffstat (limited to 'finetuning.py')
-rwxr-xr-x  finetuning.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/finetuning.py b/finetuning.py
index a0ad1fe..935adfa 100755
--- a/finetuning.py
+++ b/finetuning.py
@@ -5,7 +5,7 @@ import json
 import os
 import torch

-from datasets import Dataset, Image, load_dataset
+from datasets import Dataset, load_dataset
 from functools import partial
 from glob import glob
 from peft import LoraConfig
@@ -75,10 +75,10 @@ def main():
     parser.add_argument("--lora-dropout", default=0.05, type=float)
     parser.add_argument("--lora-r", default=16, type=int)
     parser.add_argument("--bnb-4bit", action="store_true")
-    parser.add_argument("--lora-small", action="store_true")
     parser.add_argument("--max-grad-norm", default=1.0, type=float)
     parser.add_argument("--max-length", default=512, type=int)
-    parser.add_argument("--lora-config", choices=["S", "M", "L"], default="M")
+    parser.add_argument("--lora-config", choices=["S", "M", "L"], default="M", type=str)
+    parser.add_argument("--dataset", default="asdf2k/caption_translation", type=str)
     args = parser.parse_args()

     if args.bnb_4bit:
@@ -120,7 +120,7 @@ def main():
     peft_config = LoraConfig(**lora_kwargs)

-    dataset = load_dataset("asdf2k/caption_translation")
+    dataset = load_dataset(args.dataset)

     dev_ds = dataset["dev"].map(partial(add_chat_text, processor=processor))
     train_ds = dataset["train"].map(partial(add_chat_text, processor=processor))
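
Notes (illustrative, not part of the commit):

The practical effect of this change is that the training corpus is no longer hardcoded: load_dataset() now receives whatever Hub repository id was passed via --dataset, defaulting to the previous literal asdf2k/caption_translation. A minimal standalone sketch of the resulting flow; add_chat_text here is a placeholder, since the real helper lives in a part of finetuning.py this diff does not show:

import argparse
from functools import partial

from datasets import load_dataset


def add_chat_text(example, processor=None):
    # Placeholder for the real helper in finetuning.py (not shown in this
    # diff), which builds the chat-formatted training text via the processor.
    example["text"] = example.get("caption", "")
    return example


def main():
    parser = argparse.ArgumentParser()
    # New in this commit: the dataset id is a CLI flag instead of a literal.
    parser.add_argument("--dataset", default="asdf2k/caption_translation", type=str)
    args = parser.parse_args()

    dataset = load_dataset(args.dataset)
    dev_ds = dataset["dev"].map(partial(add_chat_text, processor=None))
    train_ds = dataset["train"].map(partial(add_chat_text, processor=None))
    print(len(train_ds), len(dev_ds))


if __name__ == "__main__":
    main()

Any Hub dataset exposing "train" and "dev" splits with compatible columns can now be swapped in from the command line, e.g. --dataset my-org/my-captions (hypothetical id).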
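
The second hunk also finishes retiring the old --lora-small boolean in favour of the --lora-config size presets, which the third hunk's context feeds into LoraConfig(**lora_kwargs). How S/M/L map onto lora_kwargs is not visible in this diff; the sketch below shows one plausible shape for that mapping, with entirely hypothetical preset values:

from peft import LoraConfig

# Hypothetical presets: the real S/M/L values live in a part of
# finetuning.py that this diff does not show.
LORA_PRESETS = {
    "S": {"r": 8, "lora_alpha": 16},
    "M": {"r": 16, "lora_alpha": 32},
    "L": {"r": 64, "lora_alpha": 128},
}


def build_peft_config(size: str = "M", dropout: float = 0.05) -> LoraConfig:
    lora_kwargs = dict(
        lora_dropout=dropout,
        target_modules="all-linear",  # assumption; a valid peft shorthand
        task_type="CAUSAL_LM",        # assumption; depends on the base model
        **LORA_PRESETS[size],
    )
    return LoraConfig(**lora_kwargs)


peft_config = build_peft_config("M")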
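
Finally, the --bnb-4bit switch gates a branch whose body is also outside this diff. It presumably enables 4-bit quantized model loading; a common transformers setup for that looks like the following, shown purely as an illustration of what such a branch typically configures:

import torch
from transformers import BitsAndBytesConfig

# A common 4-bit (QLoRA-style) load configuration; whether finetuning.py
# uses these exact settings is not visible in this diff.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
# The config would then be passed to the model loader, e.g.
# AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)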