summary | refs | log | tree | commit | diff
path: root/finetuning.py
diff options
context:
space:
mode:
author pks <pks@pks.rocks> 2025-11-30 21:23:16 +0100
committer pks <pks@pks.rocks> 2025-11-30 21:23:16 +0100
commit d1924e295bce5018ded513c23d863a8b2cfa5d61 (patch)
tree b7f59d5d04805f32ca7c21dbf7e26528d0249132 /finetuning.py
parent d047a5c8f0047ec7953e8850597e62dfdfdd93d5 (diff)
fix
Diffstat (limited to 'finetuning.py')
-rwxr-xr-x finetuning.py 13
1 files changed, 7 insertions, 6 deletions
diff --git a/finetuning.py b/finetuning.py
index abb4b89..f787f95 100755
--- a/finetuning.py
+++ b/finetuning.py
@@ -55,7 +55,7 @@ def add_chat_text(example, processor):
def collate(batch, processor): # FIXME: Support batch_size > 1
images = [i["image"] for i in batch]
texts = [i["text"] for i in batch]
- out = processor(
+ processor_output = processor(
text=texts,
images=images,
padding=True,
@@ -70,13 +70,14 @@ def collate(batch, processor): # FIXME: Support batch_size > 1
)
]
- out["labels"] = out["input_ids"].clone()
+ labels = processor_output["input_ids"].clone()
+ labels[labels == processor.tokenizer.pad_token_id] = -100
+ labels[labels == image_token_id] = -100
+ labels[labels == 262144] = -100
+ processor_output["labels"] = labels
- out["labels"][labels == processor.tokenizer.pad_token_id] = -100
- out["labels"][labels == image_token_id] = -100
- out["labels"][labels == 262144] = -100
+ return processor_output
- return out
def main():
parser = argparse.ArgumentParser()