path: root/finetuning.py
author     pks <pks@pks.rocks>  2025-11-30 22:53:48 +0100
committer  pks <pks@pks.rocks>  2025-11-30 22:53:48 +0100
commit     82ffee64ad6e8e837d529fd89ff03da9042e52fd (patch)
tree       41f5720dfdbcd2d236b84a77fc9930143ee48923 /finetuning.py
parent     e82be8356620659497d5aaccac8d76435a8391c4 (diff)
WIP
Diffstat (limited to 'finetuning.py')
-rwxr-xr-x  finetuning.py  32
1 file changed, 17 insertions, 15 deletions
diff --git a/finetuning.py b/finetuning.py
index 379f9b0..f4543b7 100755
--- a/finetuning.py
+++ b/finetuning.py
@@ -100,22 +100,24 @@ def main():
         low_cpu_mem_usage=True,
     )
-    if args.lora_small:
-        target_modules = ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"]
-        modules_to_save = None
+    lora_kwargs = { "lora_alpha": args.lora_alpha,
+                    "lora_dropout": args.lora_dropout,
+                    "r": args.lora_r,
+                    "task_type": "CAUSAL_LM",
+                    "bias": "none",
+                    }
+    if args.lora_config == "S":
+        lora_kwargs["target_modules"] = ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"]
+    elif args.lora_config == "M":
+        lora_kwargs["target_modules"] = "all-linear"
+    elif args.lora_config == "L":
+        lora_kwargs["target_modules"] = "all-linear"
+        lora_kwargs["modules_to_save"] = ["lm_head", "embed_tokens"]
     else:
-        target_modules = "all-linear"
-        modules_to_save = ["lm_head", "embed_tokens"]
-
-    peft_config = LoraConfig(
-        lora_alpha=args.lora_alpha,
-        lora_dropout=args.lora_dropout,
-        r=args.lora_r,
-        task_type="CAUSAL_LM",
-        bias="none",
-        target_modules=target_modules,
-        #modules_to_save=modules_to_save,
-    )
+        sys.stderr.write(f"Unknown LoRA config: '{args.lora_config}'\n")
+        exit(1)
+
+    peft_config = LoraConfig(**lora_kwargs)
     dataset = load_dataset("asdf2k/caption_translation")
     dev_ds = dataset["dev"].map(partial(add_chat_text, processor=processor))
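
Note on the change above: it replaces the old boolean --lora_small switch with a three-way --lora_config option (S, M, L) and gathers the LoRA hyperparameters into a single lora_kwargs dict before building the LoraConfig, which also lets the "L" variant enable modules_to_save (previously commented out). The commit does not show how --lora_config is declared or how peft_config is attached to the model; the sketch below is a hypothetical reconstruction assuming argparse and PEFT's get_peft_model, with placeholder defaults and model name, not the script's actual code.

    # Hypothetical sketch (not part of this commit): one way the new --lora_config
    # option and the surrounding setup could look, assuming argparse and PEFT.
    import argparse

    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM

    parser = argparse.ArgumentParser()
    parser.add_argument("--lora_config", choices=["S", "M", "L"], default="S",
                        help="S: attention/MLP projections only, "
                             "M: all linear layers, "
                             "L: all linear layers plus trained lm_head/embed_tokens")
    parser.add_argument("--lora_r", type=int, default=16)            # defaults are assumptions
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--lora_dropout", type=float, default=0.05)
    args = parser.parse_args()

    # Same construction as in the diff: collect kwargs, then pick target modules per config.
    lora_kwargs = {"lora_alpha": args.lora_alpha,
                   "lora_dropout": args.lora_dropout,
                   "r": args.lora_r,
                   "task_type": "CAUSAL_LM",
                   "bias": "none"}
    if args.lora_config == "S":
        lora_kwargs["target_modules"] = ["q_proj", "o_proj", "k_proj", "v_proj",
                                         "gate_proj", "up_proj", "down_proj"]
    elif args.lora_config == "M":
        lora_kwargs["target_modules"] = "all-linear"
    else:  # "L"; argparse choices already reject anything else
        lora_kwargs["target_modules"] = "all-linear"
        lora_kwargs["modules_to_save"] = ["lm_head", "embed_tokens"]

    # Placeholder base model; the real script loads its model elsewhere.
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B", low_cpu_mem_usage=True)
    model = get_peft_model(model, LoraConfig(**lora_kwargs))
    model.print_trainable_parameters()  # shows how many weights each config trains

Under those assumptions the script would be invoked roughly as "python finetuning.py --lora_config L --lora_r 16", with print_trainable_parameters() confirming which weights the chosen config leaves trainable.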