diff options
| -rwxr-xr-x | finetuning.py | 8 |
1 file changed, 4 insertions, 4 deletions
diff --git a/finetuning.py b/finetuning.py
index a0ad1fe..935adfa 100755
--- a/finetuning.py
+++ b/finetuning.py
@@ -5,7 +5,7 @@
 import json
 import os
 import torch
-from datasets import Dataset, Image, load_dataset
+from datasets import Dataset, load_dataset
 from functools import partial
 from glob import glob
 from peft import LoraConfig
@@ -75,10 +75,10 @@ def main():
     parser.add_argument("--lora-dropout", default=0.05, type=float)
     parser.add_argument("--lora-r", default=16, type=int)
     parser.add_argument("--bnb-4bit", action="store_true")
-    parser.add_argument("--lora-small", action="store_true")
     parser.add_argument("--max-grad-norm", default=1.0, type=float)
     parser.add_argument("--max-length", default=512, type=int)
-    parser.add_argument("--lora-config", choices=["S", "M", "L"], default="M")
+    parser.add_argument("--lora-config", choices=["S", "M", "L"], default="M", type=str)
+    parser.add_argument("--dataset", default="asdf2k/caption_translation", type=str)

     args = parser.parse_args()
     if args.bnb_4bit:
@@ -120,7 +120,7 @@ def main():

     peft_config = LoraConfig(**lora_kwargs)

-    dataset = load_dataset("asdf2k/caption_translation")
+    dataset = load_dataset(args.dataset)
     dev_ds = dataset["dev"].map(partial(add_chat_text, processor=processor))
     train_ds = dataset["train"].map(partial(add_chat_text, processor=processor))

