diff options
Diffstat (limited to 'finetuning.py')
| -rwxr-xr-x | finetuning.py | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/finetuning.py b/finetuning.py index abb4b89..f787f95 100755 --- a/finetuning.py +++ b/finetuning.py @@ -55,7 +55,7 @@ def add_chat_text(example, processor): def collate(batch, processor): # FIXME: Support batch_size > 1 images = [i["image"] for i in batch] texts = [i["text"] for i in batch] - out = processor( + processor_output = processor( text=texts, images=images, padding=True, @@ -70,13 +70,14 @@ def collate(batch, processor): # FIXME: Support batch_size > 1 ) ] - out["labels"] = out["input_ids"].clone() + labels = processor_output["input_ids"].clone() + labels[labels == processor.tokenizer.pad_token_id] = -100 + labels[labels == image_token_id] = -100 + labels[labels == 262144] = -100 + processor_output["labels"] = labels - out["labels"][labels == processor.tokenizer.pad_token_id] = -100 - out["labels"][labels == image_token_id] = -100 - out["labels"][labels == 262144] = -100 + return processor_output - return out def main(): parser = argparse.ArgumentParser() |
