summaryrefslogtreecommitdiff
path: root/finetuning.py
diff options
context:
space:
mode:
Diffstat (limited to 'finetuning.py')
-rwxr-xr-xfinetuning.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/finetuning.py b/finetuning.py
index 490855e..ed7e954 100755
--- a/finetuning.py
+++ b/finetuning.py
@@ -99,14 +99,21 @@ def main():
low_cpu_mem_usage=True,
)
+ if args.lora_small:
+ target_modules = ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"]
+ modules_to_save = []
+ else:
+ target_modules="all-linear",
+ modules_to_save=["lm_head", "embed_tokens"],
+
peft_config = LoraConfig(
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
r=args.lora_r,
task_type="CAUSAL_LM",
bias="none",
- target_modules="all-linear",
- modules_to_save=["lm_head", "embed_tokens"],
+ target_modules=target_modules,
+ modules_to_save=modules_to_save,
)
dataset = load_dataset("asdf2k/caption_translation")