commit 8c77b1f424f0b00bf76fe959e66e1858fd0672b1
tree cde109a47b84edd3544726e1fd526b308081469c
parent 14dc940c6787c521e164289ce7ba9ef3cd901bf5
author pks <pks@pks.rocks> 2025-11-30 22:40:03 +0100
committer pks <pks@pks.rocks> 2025-11-30 22:40:03 +0100
WIP
 finetuning.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)
diff --git a/finetuning.py b/finetuning.py
index 490855e..ed7e954 100755
--- a/finetuning.py
+++ b/finetuning.py
@@ -99,14 +99,21 @@ def main():
         low_cpu_mem_usage=True,
     )
 
+    if args.lora_small:
+        target_modules = ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"]
+        modules_to_save = []
+    else:
+        target_modules = "all-linear"
+        modules_to_save = ["lm_head", "embed_tokens"]
+
     peft_config = LoraConfig(
         lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         r=args.lora_r,
         task_type="CAUSAL_LM",
         bias="none",
-        target_modules="all-linear",
-        modules_to_save=["lm_head", "embed_tokens"],
+        target_modules=target_modules,
+        modules_to_save=modules_to_save,
     )
 
     dataset = load_dataset("asdf2k/caption_translation")
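The change swaps the hard-coded LoRA target selection for a flag-controlled one: with --lora_small, only the attention and MLP projection matrices get adapters and no full modules are saved; otherwise the previous all-linear behaviour with fully trained lm_head/embed_tokens is kept. Below is a minimal sketch of how the resulting config is applied with peft; the base model name, the hard-coded hyperparameter values, and the surrounding wiring are assumptions, since the hunk shows only the config itself.

```python
# Minimal sketch of applying the LoraConfig from the hunk above.
# "Qwen/Qwen2.5-0.5B" and the r/alpha/dropout values are placeholders,
# not taken from the diff; lora_small mirrors args.lora_small.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B",  # hypothetical base model
    low_cpu_mem_usage=True,
)

lora_small = True
if lora_small:
    # Small variant: adapt only the attention and MLP projections,
    # leave the embedding and LM head frozen (nothing saved in full).
    target_modules = ["q_proj", "o_proj", "k_proj", "v_proj",
                      "gate_proj", "up_proj", "down_proj"]
    modules_to_save = []
else:
    # Original behaviour: wrap every linear layer and also train (and
    # store full copies of) lm_head and embed_tokens.
    target_modules = "all-linear"
    modules_to_save = ["lm_head", "embed_tokens"]

peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
    bias="none",
    target_modules=target_modules,
    modules_to_save=modules_to_save,
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # shows how much smaller --lora_small is
```

Dropping lm_head and embed_tokens from modules_to_save matters for checkpoint size: modules listed there are stored as full weight copies rather than low-rank adapter pairs, and for a causal LM those two matrices can dwarf the adapters themselves.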
