Diffstat (limited to 'finetuning.py')
| -rwxr-xr-x | finetuning.py | 24 |
1 file changed, 14 insertions, 10 deletions
```diff
diff --git a/finetuning.py b/finetuning.py
index a46b544..3d7b011 100755
--- a/finetuning.py
+++ b/finetuning.py
@@ -67,23 +67,27 @@ def main():
     parser.add_argument("--gradient-checkpointing", action="store_true")
     parser.add_argument("--batch-size", default=1, type=int)
     parser.add_argument("--gradient-accumulation", default=4, type=int)
-    parser.add_argument("--learning-rate", default=1e-4, type=float)
+    parser.add_argument("--learning-rate", default=2e-4, type=float)
     parser.add_argument("--epochs", default=1, type=int)
     parser.add_argument("--warmup-ratio", default=0.03, type=float)
     parser.add_argument("--scheduler-type", default="constant", type=str)
-    parser.add_argument("--logging-steps", default=3, type=int)
+    parser.add_argument("--logging-steps", default=10, type=int)
     parser.add_argument("--lora-alpha", default=32, type=int)
     parser.add_argument("--lora-dropout", default=0.05, type=float)
     parser.add_argument("--lora-r", default=16, type=int)
+    parser.add_argument("--4bit", dest="use_4bit", action="store_true")
 
     args = parser.parse_args()
-
-    bnb_config = BitsAndBytesConfig(
-        load_in_4bit=True,
-        bnb_4bit_use_double_quant=True,
-        bnb_4bit_quant_type="nf4",
-        bnb_4bit_compute_dtype=torch.bfloat16,
-        bnb_4bit_quant_storage=torch.bfloat16,
-    )
+
+    if args.use_4bit:
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.bfloat16,
+            bnb_4bit_quant_storage=torch.bfloat16,
+        )
+    else:
+        bnb_config = None
 
     processor = AutoProcessor.from_pretrained(args.model, use_fast=True)
```
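For context, the sketch below shows how the now-optional `bnb_config` would typically be consumed when the model is loaded with transformers. The model class (`AutoModelForCausalLM`) and checkpoint id are placeholders, since the actual loading call falls outside this hunk; `from_pretrained` accepts `quantization_config=None`, so the `else` branch simply loads the model unquantized.

```python
# Minimal sketch, not the script's actual loading code: the model class
# and checkpoint id below are assumptions, as the diff only shows the
# argument parsing and the BitsAndBytesConfig construction.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

use_4bit = True  # stands in for args.use_4bit from the patched parser

if use_4bit:
    # QLoRA-style 4-bit NF4 quantization, matching the config in the diff
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,         # nested (double) quantization
        bnb_4bit_quant_type="nf4",              # NormalFloat4 data type
        bnb_4bit_compute_dtype=torch.bfloat16,  # dtype used for matmuls
        bnb_4bit_quant_storage=torch.bfloat16,
    )
else:
    bnb_config = None  # quantization_config=None is the from_pretrained default

model = AutoModelForCausalLM.from_pretrained(
    "org/model-name",  # placeholder checkpoint id
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)
```

Gating the `BitsAndBytesConfig` behind a flag keeps full-precision fine-tuning available on hardware without 4-bit bitsandbytes kernel support, at the cost of roughly four times the weight memory when the flag is omitted.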
