diff options
| author | pks <pks@pks.rocks> | 2025-12-02 11:17:12 +0100 |
|---|---|---|
| committer | pks <pks@pks.rocks> | 2025-12-02 11:17:12 +0100 |
| commit | d19bd3fbf54f08db5e563b17c64991ef9b9706a6 (patch) | |
| tree | b5792882747e0dd5b9cc6412ea578bf337d4bfc3 | |
| parent | 0670bfe198e12cb25800cfae48677d90ea978cf2 (diff) | |
WIP
| -rwxr-xr-x | finetuning.py | 8 |
1 file changed, 4 insertions, 4 deletions
diff --git a/finetuning.py b/finetuning.py index a0ad1fe..935adfa 100755 --- a/finetuning.py +++ b/finetuning.py @@ -5,7 +5,7 @@ import json import os import torch -from datasets import Dataset, Image, load_dataset +from datasets import Dataset, load_dataset from functools import partial from glob import glob from peft import LoraConfig @@ -75,10 +75,10 @@ def main(): parser.add_argument("--lora-dropout", default=0.05, type=float) parser.add_argument("--lora-r", default=16, type=int) parser.add_argument("--bnb-4bit", action="store_true") - parser.add_argument("--lora-small", action="store_true") parser.add_argument("--max-grad-norm", default=1.0, type=float) parser.add_argument("--max-length", default=512, type=int) - parser.add_argument("--lora-config", choices=["S", "M", "L"], default="M") + parser.add_argument("--lora-config", choices=["S", "M", "L"], default="M", type=str) + parser.add_argument("--dataset", default="asdf2k/caption_translation", type=str) args = parser.parse_args() if args.bnb_4bit: @@ -120,7 +120,7 @@ def main(): peft_config = LoraConfig(**lora_kwargs) - dataset = load_dataset("asdf2k/caption_translation") + dataset = load_dataset(args.dataset) dev_ds = dataset["dev"].map(partial(add_chat_text, processor=processor)) train_ds = dataset["train"].map(partial(add_chat_text, processor=processor)) |
