summaryrefslogtreecommitdiff
path: root/finetuning.py
diff options
context:
space:
mode:
authorpks <pks@pks.rocks>2025-11-30 22:40:03 +0100
committerpks <pks@pks.rocks>2025-11-30 22:40:03 +0100
commit8c77b1f424f0b00bf76fe959e66e1858fd0672b1 (patch)
treecde109a47b84edd3544726e1fd526b308081469c /finetuning.py
parent14dc940c6787c521e164289ce7ba9ef3cd901bf5 (diff)
WIP
Diffstat (limited to 'finetuning.py')
-rwxr-xr-xfinetuning.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/finetuning.py b/finetuning.py
index 490855e..ed7e954 100755
--- a/finetuning.py
+++ b/finetuning.py
@@ -99,14 +99,21 @@ def main():
low_cpu_mem_usage=True,
)
+ if args.lora_small:
+ target_modules = ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"]
+ modules_to_save = []
+ else:
+ target_modules="all-linear",
+ modules_to_save=["lm_head", "embed_tokens"],
+
peft_config = LoraConfig(
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
r=args.lora_r,
task_type="CAUSAL_LM",
bias="none",
- target_modules="all-linear",
- modules_to_save=["lm_head", "embed_tokens"],
+ target_modules=target_modules,
+ modules_to_save=modules_to_save,
)
dataset = load_dataset("asdf2k/caption_translation")