diff options
| author | pks <pks@pks.rocks> | 2025-11-30 21:23:16 +0100 |
|---|---|---|
| committer | pks <pks@pks.rocks> | 2025-11-30 21:23:16 +0100 |
| commit | d1924e295bce5018ded513c23d863a8b2cfa5d61 (patch) | |
| tree | b7f59d5d04805f32ca7c21dbf7e26528d0249132 /finetuning.py | |
| parent | d047a5c8f0047ec7953e8850597e62dfdfdd93d5 (diff) | |
fix
Diffstat (limited to 'finetuning.py')
| -rwxr-xr-x | finetuning.py | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/finetuning.py b/finetuning.py index abb4b89..f787f95 100755 --- a/finetuning.py +++ b/finetuning.py @@ -55,7 +55,7 @@ def add_chat_text(example, processor): def collate(batch, processor): # FIXME: Support batch_size > 1 images = [i["image"] for i in batch] texts = [i["text"] for i in batch] - out = processor( + processor_output = processor( text=texts, images=images, padding=True, @@ -70,13 +70,14 @@ def collate(batch, processor): # FIXME: Support batch_size > 1 ) ] - out["labels"] = out["input_ids"].clone() + labels = processor_output["input_ids"].clone() + labels[labels == processor.tokenizer.pad_token_id] = -100 + labels[labels == image_token_id] = -100 + labels[labels == 262144] = -100 + processor_output["labels"] = labels - out["labels"][labels == processor.tokenizer.pad_token_id] = -100 - out["labels"][labels == image_token_id] = -100 - out["labels"][labels == 262144] = -100 + return processor_output - return out def main(): parser = argparse.ArgumentParser() |
