| author | pks <pks@pks.rocks> | 2025-11-30 22:28:41 +0100 |
|---|---|---|
| committer | pks <pks@pks.rocks> | 2025-11-30 22:28:41 +0100 |
| commit | 3dbd2c5d3ecb70b778bf7192ab12b98bb5c2ec12 (patch) | |
| tree | 62182f128ab451a33720e88454db848bbb6098c2 | |
| parent | 7fd45074ab47b2316c6a451981ad9f1f28cfead7 (diff) | |
WIP
| -rwxr-xr-x | finetuning.py | 15 |
1 file changed, 8 insertions(+), 7 deletions(-)
```diff
diff --git a/finetuning.py b/finetuning.py
index b0cf1cb..4bf82d9 100755
--- a/finetuning.py
+++ b/finetuning.py
@@ -11,7 +11,6 @@ from glob import glob
 from peft import LoraConfig
 from transformers import (
     AutoProcessor,
-    AutoModelForCausalLM,
     AutoModelForImageTextToText,
     TrainingArguments,
     BitsAndBytesConfig,
@@ -33,7 +32,7 @@ def add_chat_text(example, processor):
     return example
 
 
-def collate(batch, processor): # FIXME: Support batch_size > 1
+def collate(batch, processor, max_length): # FIXME: Support batch_size > 1
     images = [i["image"] for i in batch]
     texts = [i["text"] for i in batch]
     processor_output = processor(
@@ -41,7 +40,7 @@ def collate(batch, processor): # FIXME: Support batch_size > 1
         images=images,
         padding=True,
         truncation=True,
-        max_length=512,
+        max_length=max_length,
         return_tensors="pt",
     )
 
@@ -76,6 +75,8 @@ def main():
     parser.add_argument("--lora-dropout", default=0.05, type=float)
     parser.add_argument("--lora-r", default=16, type=int)
     parser.add_argument("--bnb4bit", action="store_true")
+    parser.add_argument("--max-grad-norm", default=1.0, type=float)
+    parser.add_argument("--max-length", default=512, type=int)
     args = parser.parse_args()
 
     if args.bnb4bit:
@@ -104,9 +105,8 @@ def main():
         r=args.lora_r,
         task_type="CAUSAL_LM",
         bias="none",
-        #target_modules="all-linear",
-        target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
-        #modules_to_save=["lm_head", "embed_tokens"],
+        target_modules="all-linear",
+        modules_to_save=["lm_head", "embed_tokens"],
     )
 
     dataset = load_dataset("asdf2k/caption_translation")
@@ -126,6 +126,7 @@ def main():
         gradient_checkpointing=args.gradient_checkpointing,
         gradient_checkpointing_kwargs={"use_reentrant": False},
         optim=args.optimizer,
+        max_grad_norm=args.max_grad_norm,
         remove_unused_columns=False,
         logging_steps=args.logging_steps,
         save_strategy="epoch",
@@ -135,7 +136,7 @@ def main():
         model=model,
         train_dataset=train_ds,
         eval_dataset=dev_ds,
-        data_collator=partial(collate, processor=processor),
+        data_collator=partial(collate, processor=processor, max_length=args.max_length),
         args=args,
         peft_config=peft_config,
     )
```
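The collator change above threads the new `--max-length` flag through `functools.partial`, so the trainer still receives a callable that takes only the batch. A minimal sketch of that pattern follows; the `stub_processor` and the sample batch are illustrative stand-ins, not part of the script, which uses the real `AutoProcessor`:

```python
from functools import partial


def collate(batch, processor, max_length):
    # Same shape as the script's collator: apply the processor to the whole
    # batch with a caller-supplied truncation length.
    texts = [item["text"] for item in batch]
    return processor(texts, max_length=max_length)


def stub_processor(texts, max_length):
    # Stand-in for AutoProcessor: just truncates each text to max_length chars.
    return {"input": [t[:max_length] for t in texts]}


# partial() bakes in processor and max_length, leaving a callable that takes
# only the batch -- the shape a Trainer expects of a data_collator.
data_collator = partial(collate, processor=stub_processor, max_length=16)
print(data_collator([{"text": "a rather long caption"}, {"text": "short"}]))
```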
