summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorpks <pks@pks.rocks>2025-11-29 23:56:12 +0100
committerpks <pks@pks.rocks>2025-11-29 23:56:12 +0100
commitfc6e8636014dbaf882a2fa5685497dc225ee8e29 (patch)
tree0ecd23bd0e8b2ff8aba7b7085657919b91f3bc7c
parentabb0df2560b53ec624d8ecaf8444a679fdadb6b2 (diff)
WIP
-rw-r--r--README5
-rw-r--r--finetuning.py (renamed from finetune.py)0
-rw-r--r--inference.py103
-rw-r--r--pyproject.toml7
-rw-r--r--uv.lock114
5 files changed, 128 insertions, 101 deletions
diff --git a/README b/README
new file mode 100644
index 0000000..2c9e5bd
--- /dev/null
+++ b/README
@@ -0,0 +1,5 @@
+# Setup
+sudo apt install jq pipx
+pipx install uv
+huggingface-cli login
+
diff --git a/finetune.py b/finetuning.py
index 988a027..988a027 100644
--- a/finetune.py
+++ b/finetuning.py
diff --git a/inference.py b/inference.py
index b08a3f6..d20dc55 100644
--- a/inference.py
+++ b/inference.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
+import argparse
import json
import os
import requests
@@ -63,58 +64,66 @@ def make_inputs(processor,
def main():
- model_id = "google/gemma-3-4b-it"
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model", default="google/gemma-3-4b-it")
+ parser.add_argument("--lora-adapter", default=None)
+ parser.add_argument("--mode", choices=["from_scratch", "with_prefix", "translate"])
+ args = parser.parse_args()
+
model = Gemma3ForConditionalGeneration.from_pretrained(
- model_id,
+ args.model_id,
device_map="cuda",
dtype=torch.bfloat16,
).eval()
- processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+ processor = AutoProcessor.from_pretrained(args.model_id, use_fast=True)
- if len(sys.argv) == 2:
+ if args.lora_adapter:
from peft import PeftModel
- model = PeftModel.from_pretrained(model, sys.argv[1])
-
- if True: # Generate German translation given English source
+ model = PeftModel.from_pretrained(model, args.lora_adapter)
+
+ if args.mode == "translate": # Generate German translation given English source
for filename in glob("*.jsonl"):
sys.stderr.write(f"Processing {filename=}\n")
- with open(filename, "r+") as f:
- data = json.loads(f.read())
- f.seek(0)
-
- inputs = make_inputs(processor,
- translation_prompt(data["English"]),
- model.device)
- input_len = inputs["input_ids"].shape[-1]
-
- with torch.inference_mode():
- generation = model.generate(**inputs,
- max_new_tokens=300,
- do_sample=True,
- top_p=1.0,
- top_k=50)
- generation = generation[0][input_len:]
-
- decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
- try:
- new_data = json.loads(decoded)
- except:
- sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
-
- data.update(new_data)
- f.write(json.dumps(data))
- f.truncate()
-
- sys.stderr.write(f"{decoded=}\n")
- elif 2 == 3: # Generate caption & translation from scratch
- for filename in glob("../Images/*.jpg"):
- image = "../Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
+ try:
+ with open(filename, "r+") as f:
+ data = json.loads(f.read())
+ f.seek(0)
+
+ inputs = make_inputs(processor,
+ translation_prompt(data["English"]),
+ model.device)
+ input_len = inputs["input_ids"].shape[-1]
+
+ with torch.inference_mode():
+ generation = model.generate(**inputs,
+ max_new_tokens=300,
+ do_sample=True,
+ top_p=1.0,
+ top_k=50)
+ generation = generation[0][input_len:]
+
+ decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
+ try:
+ new_data = json.loads(decoded)
+ except:
+ sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
+
+ data.update(new_data)
+ f.write(json.dumps(data))
+ f.truncate()
+
+ sys.stderr.write(f"{decoded=}\n")
+ except:
+ pass
+ elif args.mode == "from_scratch": # Generate caption & translation from scratch
+ for filename in glob("../d/Images/*.jpg"):
+ image = "../d/Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
sys.stderr.write(f"Processing {filename=}\n")
inputs = make_inputs(processor,
captioning_prompt(Image.open(filename)),
model.device)
input_len = inputs["input_ids"].shape[-1]
-
+
with torch.inference_mode():
generation = model.generate(**inputs,
max_new_tokens=300,
@@ -122,19 +131,19 @@ def main():
top_p=1.0,
top_k=50)
generation = generation[0][input_len:]
-
+
decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
try:
_ = json.loads(decoded)
except:
sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
-
+
sys.stderr.write(f"{decoded=}\n")
with open(f"{os.path.basename(filename).removesuffix('.jpg')}.jsonl", "w") as f:
f.write(f"{decoded}\n")
- elif False: # Generate German translation given English caption and image
+ elif args.mode == "with_prefix": # Generate German translation given English caption and image
for filename in glob("./baseline/files_test/*.jsonl"):
- image = "../Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
+ image = "../d/Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
sys.stderr.write(f"Processing {filename=}\n")
with open(filename, "r+") as f:
data = json.loads(f.read())
@@ -144,7 +153,7 @@ def main():
model.device)
input_len = inputs["input_ids"].shape[-1] # Will not cut off assistant prefix
-
+
with torch.inference_mode():
generation = model.generate(**inputs,
max_new_tokens=300,
@@ -153,16 +162,18 @@ def main():
top_k=50)
generation = generation[0] # batch size 1
truncated_generation = generation[input_len:]
-
+
decoded = processor.decode(truncated_generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
try:
_ = json.loads(decoded)
except:
sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
-
+
sys.stderr.write(f"{decoded=}\n")
with open(f"{os.path.basename(filename)}", "w") as f:
f.write(f"{decoded}\n")
+ else:
+ sys.stderr.write(f"Unkown mode '{args.mode}'")
if __name__ == "__main__":
diff --git a/pyproject.toml b/pyproject.toml
index bbf1288..c650ffa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,16 +1,18 @@
[project]
-name = "gemma3_4b"
+name = "CaptionTranslation"
version = "0.1.0"
-description = "Add your description here"
+description = "Finetuning gemma3 towards translating captions"
requires-python = ">=3.13"
dependencies = [
"absl-py>=2.3.1",
"accelerate>=1.11.0",
+ "adamw-bf16>=0.0.3",
"bitsandbytes>=0.48.2",
"datasets>=4.4.1",
"diskcache>=5.6.3",
"huggingface-hub>=0.36.0",
"ipdb>=0.13.13",
+ "jsonargparse[signatures]>=4.35.0",
"ninja>=1.13.0",
"openai>=2.8.0",
"packaging>=25.0",
@@ -18,6 +20,7 @@ dependencies = [
"pillow>=12.0.0",
"sacrebleu>=2.5.1",
"termcolor>=3.2.0",
+ "torchao>=0.14.1",
"torchvision>=0.24.1",
"transformers>=4.57.1",
"trl>=0.25.1",
diff --git a/uv.lock b/uv.lock
index 92c1f0e..162bf27 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,5 +1,5 @@
version = 1
-revision = 2
+revision = 3
requires-python = ">=3.13"
[[package]]
@@ -13,7 +13,7 @@ wheels = [
[[package]]
name = "accelerate"
-version = "1.11.0"
+version = "1.12.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "huggingface-hub" },
@@ -24,9 +24,21 @@ dependencies = [
{ name = "safetensors" },
{ name = "torch" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/23/60/2757c4f03a8705dbf80b1268b03881927878dca5ed07d74f733fb6c219e0/accelerate-1.11.0.tar.gz", hash = "sha256:bb1caf2597b4cd632b917b5000c591d10730bb024a79746f1ee205bba80bd229", size = 393715, upload-time = "2025-10-20T14:42:25.025Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/4a/8e/ac2a9566747a93f8be36ee08532eb0160558b07630a081a6056a9f89bf1d/accelerate-1.12.0.tar.gz", hash = "sha256:70988c352feb481887077d2ab845125024b2a137a5090d6d7a32b57d03a45df6", size = 398399, upload-time = "2025-11-21T11:27:46.973Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/77/85/85951bc0f9843e2c10baaa1b6657227056095de08f4d1eea7d8b423a6832/accelerate-1.11.0-py3-none-any.whl", hash = "sha256:a628fa6beb069b8e549460fc449135d5bd8d73e7a11fd09f0bc9fc4ace7f06f1", size = 375777, upload-time = "2025-10-20T14:42:23.256Z" },
+ { url = "https://files.pythonhosted.org/packages/9f/d2/c581486aa6c4fbd7394c23c47b83fa1a919d34194e16944241daf9e762dd/accelerate-1.12.0-py3-none-any.whl", hash = "sha256:3e2091cd341423207e2f084a6654b1efcd250dc326f2a37d6dde446e07cabb11", size = 380935, upload-time = "2025-11-21T11:27:44.522Z" },
+]
+
+[[package]]
+name = "adamw-bf16"
+version = "0.0.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "torch" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/2d/2f/5eb18ee927565eb420583a7b5c95779e196b3e14b958e98a1ee44bfe2b18/adamw_bf16-0.0.3.tar.gz", hash = "sha256:826b046e23ebd34ab1fd0e8700ea4e3f419d0d66e6211b814cf120440796cc0a", size = 16442, upload-time = "2024-04-04T16:07:46.9Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/7c/4c/195a2835e02f862b6fe4d7067142ce66cde47a45274840c4e1b4b81fcc18/adamw_bf16-0.0.3-py3-none-any.whl", hash = "sha256:6795187b69b92cc962d0b02471173b997e6fb15b9ffc36825e9de3a207bebeec", size = 16847, upload-time = "2024-04-04T16:07:44.572Z" },
]
[[package]]
@@ -142,11 +154,11 @@ wheels = [
[[package]]
name = "asttokens"
-version = "3.0.0"
+version = "3.0.1"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/4a/e7/82da0a03e7ba5141f05cce0d302e6eed121ae055e0456ca228bf693984bc/asttokens-3.0.0.tar.gz", hash = "sha256:0dcd8baa8d62b0c1d118b399b2ddba3c4aff271d0d7a9e0d4c1681c79035bbc7", size = 61978, upload-time = "2024-11-30T04:30:14.439Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/be/a5/8e3f9b6771b0b408517c82d97aed8f2036509bc247d46114925e32fe33f0/asttokens-3.0.1.tar.gz", hash = "sha256:71a4ee5de0bde6a31d64f6b13f2293ac190344478f081c3d1bccfcf5eacb0cb7", size = 62308, upload-time = "2025-11-15T16:43:48.578Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/25/8a/c46dcc25341b5bce5472c718902eb3d38600a903b14fa6aeecef3f21a46f/asttokens-3.0.0-py3-none-any.whl", hash = "sha256:e3078351a059199dd5138cb1c706e6430c05eff2ff136af5eb4790f9d28932e2", size = 26918, upload-time = "2024-11-30T04:30:10.946Z" },
+ { url = "https://files.pythonhosted.org/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl", hash = "sha256:15a3ebc0f43c2d0a50eeafea25e19046c68398e487b9f1f5b517f7c0f40f976a", size = 27047, upload-time = "2025-11-15T16:43:16.109Z" },
]
[[package]]
@@ -159,21 +171,6 @@ wheels = [
]
[[package]]
-name = "bitsandbytes"
-version = "0.48.2"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
- { name = "numpy" },
- { name = "packaging" },
- { name = "torch" },
-]
-wheels = [
- { url = "https://files.pythonhosted.org/packages/d5/56/e58f1eeb41c71d3789d401ad147fc0fa78fa7004ea58616ba6abc50889c8/bitsandbytes-0.48.2-py3-none-manylinux_2_24_aarch64.whl", hash = "sha256:defbfa374d93809de3811cd2bca6978d1d51ecaa39f5bdd2018e1394a4886603", size = 33764723, upload-time = "2025-10-29T21:40:36.389Z" },
- { url = "https://files.pythonhosted.org/packages/3b/72/f6934097c94e023967bf7eb24e50987aca8a40f4cabf4e0282957f6258c2/bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:cd289562cb7308ee2a707e6884fecca9bbbcfc9ec33a86df2a45e0779692c1a3", size = 59399452, upload-time = "2025-10-29T21:40:39.941Z" },
- { url = "https://files.pythonhosted.org/packages/17/2d/309b3ba276d1fe5f3db4fb5e0a6926d10654656cdbfe49881aa95f4cfcc2/bitsandbytes-0.48.2-py3-none-win_amd64.whl", hash = "sha256:a048c285eb6ff53a8d189880e9dfa421d2bfb54e8cab263311757cf5b742d865", size = 58992447, upload-time = "2025-10-29T21:40:43.921Z" },
-]
-
-[[package]]
name = "certifi"
version = "2025.11.12"
source = { registry = "https://pypi.org/simple" }
@@ -416,7 +413,7 @@ source = { virtual = "." }
dependencies = [
{ name = "absl-py" },
{ name = "accelerate" },
- { name = "bitsandbytes" },
+ { name = "adamw-bf16" },
{ name = "datasets" },
{ name = "diskcache" },
{ name = "huggingface-hub" },
@@ -428,6 +425,7 @@ dependencies = [
{ name = "pillow" },
{ name = "sacrebleu" },
{ name = "termcolor" },
+ { name = "torchao" },
{ name = "torchvision" },
{ name = "transformers" },
{ name = "trl" },
@@ -438,7 +436,7 @@ dependencies = [
requires-dist = [
{ name = "absl-py", specifier = ">=2.3.1" },
{ name = "accelerate", specifier = ">=1.11.0" },
- { name = "bitsandbytes", specifier = ">=0.48.2" },
+ { name = "adamw-bf16", specifier = ">=0.0.3" },
{ name = "datasets", specifier = ">=4.4.1" },
{ name = "diskcache", specifier = ">=5.6.3" },
{ name = "huggingface-hub", specifier = ">=0.36.0" },
@@ -450,6 +448,7 @@ requires-dist = [
{ name = "pillow", specifier = ">=12.0.0" },
{ name = "sacrebleu", specifier = ">=2.5.1" },
{ name = "termcolor", specifier = ">=3.2.0" },
+ { name = "torchao", specifier = ">=0.14.1" },
{ name = "torchvision", specifier = ">=0.24.1" },
{ name = "transformers", specifier = ">=4.57.1" },
{ name = "trl", git = "https://github.com/huggingface/trl.git" },
@@ -932,11 +931,11 @@ wheels = [
[[package]]
name = "networkx"
-version = "3.5"
+version = "3.6"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/6c/4f/ccdb8ad3a38e583f214547fd2f7ff1fc160c43a75af88e6aec213404b96a/networkx-3.5.tar.gz", hash = "sha256:d4c6f9cf81f52d69230866796b82afbccdec3db7ae4fbd1b65ea750feed50037", size = 2471065, upload-time = "2025-05-29T11:35:07.804Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/e8/fc/7b6fd4d22c8c4dc5704430140d8b3f520531d4fe7328b8f8d03f5a7950e8/networkx-3.6.tar.gz", hash = "sha256:285276002ad1f7f7da0f7b42f004bcba70d381e936559166363707fdad3d72ad", size = 2511464, upload-time = "2025-11-24T03:03:47.158Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl", hash = "sha256:0030d386a9a06dee3565298b4a734b68589749a544acbb6c412dc9e2489ec6ec", size = 2034406, upload-time = "2025-05-29T11:35:04.961Z" },
+ { url = "https://files.pythonhosted.org/packages/07/c7/d64168da60332c17d24c0d2f08bdf3987e8d1ae9d84b5bbd0eec2eb26a55/networkx-3.6-py3-none-any.whl", hash = "sha256:cdb395b105806062473d3be36458d8f1459a4e4b98e236a66c3a48996e07684f", size = 2063713, upload-time = "2025-11-24T03:03:45.21Z" },
]
[[package]]
@@ -1107,7 +1106,7 @@ wheels = [
[[package]]
name = "openai"
-version = "2.8.0"
+version = "2.8.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
@@ -1119,9 +1118,9 @@ dependencies = [
{ name = "tqdm" },
{ name = "typing-extensions" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/04/0c/b9321e12f89e236f5e9a46346c30fb801818e22ba33b798a5aca84be895c/openai-2.8.0.tar.gz", hash = "sha256:4851908f6d6fcacbd47ba659c5ac084f7725b752b6bfa1e948b6fbfc111a6bad", size = 602412, upload-time = "2025-11-13T18:15:25.847Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/d5/e4/42591e356f1d53c568418dc7e30dcda7be31dd5a4d570bca22acb0525862/openai-2.8.1.tar.gz", hash = "sha256:cb1b79eef6e809f6da326a7ef6038719e35aa944c42d081807bfa1be8060f15f", size = 602490, upload-time = "2025-11-17T22:39:59.549Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/5b/e1/0a6560bab7fb7b5a88d35a505b859c6d969cb2fa2681b568eb5d95019dec/openai-2.8.0-py3-none-any.whl", hash = "sha256:ba975e347f6add2fe13529ccb94d54a578280e960765e5224c34b08d7e029ddf", size = 1022692, upload-time = "2025-11-13T18:15:23.621Z" },
+ { url = "https://files.pythonhosted.org/packages/55/4f/dbc0c124c40cb390508a82770fb9f6e3ed162560181a85089191a851c59a/openai-2.8.1-py3-none-any.whl", hash = "sha256:c6c3b5a04994734386e8dad3c00a393f56d3b68a27cd2e8acae91a59e4122463", size = 1022688, upload-time = "2025-11-17T22:39:57.675Z" },
]
[[package]]
@@ -1462,7 +1461,7 @@ wheels = [
[[package]]
name = "pydantic"
-version = "2.12.4"
+version = "2.12.5"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "annotated-types" },
@@ -1470,9 +1469,9 @@ dependencies = [
{ name = "typing-extensions" },
{ name = "typing-inspection" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/96/ad/a17bc283d7d81837c061c49e3eaa27a45991759a1b7eae1031921c6bd924/pydantic-2.12.4.tar.gz", hash = "sha256:0f8cb9555000a4b5b617f66bfd2566264c4984b27589d3b845685983e8ea85ac", size = 821038, upload-time = "2025-11-05T10:50:08.59Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/69/44/36f1a6e523abc58ae5f928898e4aca2e0ea509b5aa6f6f392a5d882be928/pydantic-2.12.5.tar.gz", hash = "sha256:4d351024c75c0f085a9febbb665ce8c0c6ec5d30e903bdb6394b7ede26aebb49", size = 821591, upload-time = "2025-11-26T15:11:46.471Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/82/2f/e68750da9b04856e2a7ec56fc6f034a5a79775e9b9a81882252789873798/pydantic-2.12.4-py3-none-any.whl", hash = "sha256:92d3d202a745d46f9be6df459ac5a064fdaa3c1c4cd8adcfa332ccf3c05f871e", size = 463400, upload-time = "2025-11-05T10:50:06.732Z" },
+ { url = "https://files.pythonhosted.org/packages/5a/87/b70ad306ebb6f9b585f114d0ac2137d792b48be34d732d60e597c2f8465a/pydantic-2.12.5-py3-none-any.whl", hash = "sha256:e561593fccf61e8a20fc46dfc2dfe075b8be7d0188df33f221ad1f0139180f9d", size = 463580, upload-time = "2025-11-26T15:11:44.605Z" },
]
[[package]]
@@ -1724,24 +1723,24 @@ wheels = [
[[package]]
name = "safetensors"
-version = "0.6.2"
+version = "0.7.0"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/ac/cc/738f3011628920e027a11754d9cae9abec1aed00f7ae860abbf843755233/safetensors-0.6.2.tar.gz", hash = "sha256:43ff2aa0e6fa2dc3ea5524ac7ad93a9839256b8703761e76e2d0b2a3fa4f15d9", size = 197968, upload-time = "2025-08-08T13:13:58.654Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/29/9c/6e74567782559a63bd040a236edca26fd71bc7ba88de2ef35d75df3bca5e/safetensors-0.7.0.tar.gz", hash = "sha256:07663963b67e8bd9f0b8ad15bb9163606cd27cc5a1b96235a50d8369803b96b0", size = 200878, upload-time = "2025-11-19T15:18:43.199Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/4d/b1/3f5fd73c039fc87dba3ff8b5d528bfc5a32b597fea8e7a6a4800343a17c7/safetensors-0.6.2-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:9c85ede8ec58f120bad982ec47746981e210492a6db876882aa021446af8ffba", size = 454797, upload-time = "2025-08-08T13:13:52.066Z" },
- { url = "https://files.pythonhosted.org/packages/8c/c9/bb114c158540ee17907ec470d01980957fdaf87b4aa07914c24eba87b9c6/safetensors-0.6.2-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:d6675cf4b39c98dbd7d940598028f3742e0375a6b4d4277e76beb0c35f4b843b", size = 432206, upload-time = "2025-08-08T13:13:50.931Z" },
- { url = "https://files.pythonhosted.org/packages/d3/8e/f70c34e47df3110e8e0bb268d90db8d4be8958a54ab0336c9be4fe86dac8/safetensors-0.6.2-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d2d2b3ce1e2509c68932ca03ab8f20570920cd9754b05063d4368ee52833ecd", size = 473261, upload-time = "2025-08-08T13:13:41.259Z" },
- { url = "https://files.pythonhosted.org/packages/2a/f5/be9c6a7c7ef773e1996dc214e73485286df1836dbd063e8085ee1976f9cb/safetensors-0.6.2-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:93de35a18f46b0f5a6a1f9e26d91b442094f2df02e9fd7acf224cfec4238821a", size = 485117, upload-time = "2025-08-08T13:13:43.506Z" },
- { url = "https://files.pythonhosted.org/packages/c9/55/23f2d0a2c96ed8665bf17a30ab4ce5270413f4d74b6d87dd663258b9af31/safetensors-0.6.2-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:89a89b505f335640f9120fac65ddeb83e40f1fd081cb8ed88b505bdccec8d0a1", size = 616154, upload-time = "2025-08-08T13:13:45.096Z" },
- { url = "https://files.pythonhosted.org/packages/98/c6/affb0bd9ce02aa46e7acddbe087912a04d953d7a4d74b708c91b5806ef3f/safetensors-0.6.2-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fc4d0d0b937e04bdf2ae6f70cd3ad51328635fe0e6214aa1fc811f3b576b3bda", size = 520713, upload-time = "2025-08-08T13:13:46.25Z" },
- { url = "https://files.pythonhosted.org/packages/fe/5d/5a514d7b88e310c8b146e2404e0dc161282e78634d9358975fd56dfd14be/safetensors-0.6.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8045db2c872db8f4cbe3faa0495932d89c38c899c603f21e9b6486951a5ecb8f", size = 485835, upload-time = "2025-08-08T13:13:49.373Z" },
- { url = "https://files.pythonhosted.org/packages/7a/7b/4fc3b2ba62c352b2071bea9cfbad330fadda70579f617506ae1a2f129cab/safetensors-0.6.2-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:81e67e8bab9878bb568cffbc5f5e655adb38d2418351dc0859ccac158f753e19", size = 521503, upload-time = "2025-08-08T13:13:47.651Z" },
- { url = "https://files.pythonhosted.org/packages/5a/50/0057e11fe1f3cead9254315a6c106a16dd4b1a19cd247f7cc6414f6b7866/safetensors-0.6.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:b0e4d029ab0a0e0e4fdf142b194514695b1d7d3735503ba700cf36d0fc7136ce", size = 652256, upload-time = "2025-08-08T13:13:53.167Z" },
- { url = "https://files.pythonhosted.org/packages/e9/29/473f789e4ac242593ac1656fbece6e1ecd860bb289e635e963667807afe3/safetensors-0.6.2-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:fa48268185c52bfe8771e46325a1e21d317207bcabcb72e65c6e28e9ffeb29c7", size = 747281, upload-time = "2025-08-08T13:13:54.656Z" },
- { url = "https://files.pythonhosted.org/packages/68/52/f7324aad7f2df99e05525c84d352dc217e0fa637a4f603e9f2eedfbe2c67/safetensors-0.6.2-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:d83c20c12c2d2f465997c51b7ecb00e407e5f94d7dec3ea0cc11d86f60d3fde5", size = 692286, upload-time = "2025-08-08T13:13:55.884Z" },
- { url = "https://files.pythonhosted.org/packages/ad/fe/cad1d9762868c7c5dc70c8620074df28ebb1a8e4c17d4c0cb031889c457e/safetensors-0.6.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:d944cea65fad0ead848b6ec2c37cc0b197194bec228f8020054742190e9312ac", size = 655957, upload-time = "2025-08-08T13:13:57.029Z" },
- { url = "https://files.pythonhosted.org/packages/59/a7/e2158e17bbe57d104f0abbd95dff60dda916cf277c9f9663b4bf9bad8b6e/safetensors-0.6.2-cp38-abi3-win32.whl", hash = "sha256:cab75ca7c064d3911411461151cb69380c9225798a20e712b102edda2542ddb1", size = 308926, upload-time = "2025-08-08T13:14:01.095Z" },
- { url = "https://files.pythonhosted.org/packages/2c/c3/c0be1135726618dc1e28d181b8c442403d8dbb9e273fd791de2d4384bcdd/safetensors-0.6.2-cp38-abi3-win_amd64.whl", hash = "sha256:c7b214870df923cbc1593c3faee16bec59ea462758699bd3fee399d00aac072c", size = 320192, upload-time = "2025-08-08T13:13:59.467Z" },
+ { url = "https://files.pythonhosted.org/packages/fa/47/aef6c06649039accf914afef490268e1067ed82be62bcfa5b7e886ad15e8/safetensors-0.7.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c82f4d474cf725255d9e6acf17252991c3c8aac038d6ef363a4bf8be2f6db517", size = 467781, upload-time = "2025-11-19T15:18:35.84Z" },
+ { url = "https://files.pythonhosted.org/packages/e8/00/374c0c068e30cd31f1e1b46b4b5738168ec79e7689ca82ee93ddfea05109/safetensors-0.7.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:94fd4858284736bb67a897a41608b5b0c2496c9bdb3bf2af1fa3409127f20d57", size = 447058, upload-time = "2025-11-19T15:18:34.416Z" },
+ { url = "https://files.pythonhosted.org/packages/f1/06/578ffed52c2296f93d7fd2d844cabfa92be51a587c38c8afbb8ae449ca89/safetensors-0.7.0-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e07d91d0c92a31200f25351f4acb2bc6aff7f48094e13ebb1d0fb995b54b6542", size = 491748, upload-time = "2025-11-19T15:18:09.79Z" },
+ { url = "https://files.pythonhosted.org/packages/ae/33/1debbbb70e4791dde185edb9413d1fe01619255abb64b300157d7f15dddd/safetensors-0.7.0-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8469155f4cb518bafb4acf4865e8bb9d6804110d2d9bdcaa78564b9fd841e104", size = 503881, upload-time = "2025-11-19T15:18:16.145Z" },
+ { url = "https://files.pythonhosted.org/packages/8e/1c/40c2ca924d60792c3be509833df711b553c60effbd91da6f5284a83f7122/safetensors-0.7.0-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:54bef08bf00a2bff599982f6b08e8770e09cc012d7bba00783fc7ea38f1fb37d", size = 623463, upload-time = "2025-11-19T15:18:21.11Z" },
+ { url = "https://files.pythonhosted.org/packages/9b/3a/13784a9364bd43b0d61eef4bea2845039bc2030458b16594a1bd787ae26e/safetensors-0.7.0-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:42cb091236206bb2016d245c377ed383aa7f78691748f3bb6ee1bfa51ae2ce6a", size = 532855, upload-time = "2025-11-19T15:18:25.719Z" },
+ { url = "https://files.pythonhosted.org/packages/a0/60/429e9b1cb3fc651937727befe258ea24122d9663e4d5709a48c9cbfceecb/safetensors-0.7.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dac7252938f0696ddea46f5e855dd3138444e82236e3be475f54929f0c510d48", size = 507152, upload-time = "2025-11-19T15:18:33.023Z" },
+ { url = "https://files.pythonhosted.org/packages/3c/a8/4b45e4e059270d17af60359713ffd83f97900d45a6afa73aaa0d737d48b6/safetensors-0.7.0-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1d060c70284127fa805085d8f10fbd0962792aed71879d00864acda69dbab981", size = 541856, upload-time = "2025-11-19T15:18:31.075Z" },
+ { url = "https://files.pythonhosted.org/packages/06/87/d26d8407c44175d8ae164a95b5a62707fcc445f3c0c56108e37d98070a3d/safetensors-0.7.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:cdab83a366799fa730f90a4ebb563e494f28e9e92c4819e556152ad55e43591b", size = 674060, upload-time = "2025-11-19T15:18:37.211Z" },
+ { url = "https://files.pythonhosted.org/packages/11/f5/57644a2ff08dc6325816ba7217e5095f17269dada2554b658442c66aed51/safetensors-0.7.0-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:672132907fcad9f2aedcb705b2d7b3b93354a2aec1b2f706c4db852abe338f85", size = 771715, upload-time = "2025-11-19T15:18:38.689Z" },
+ { url = "https://files.pythonhosted.org/packages/86/31/17883e13a814bd278ae6e266b13282a01049b0c81341da7fd0e3e71a80a3/safetensors-0.7.0-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:5d72abdb8a4d56d4020713724ba81dac065fedb7f3667151c4a637f1d3fb26c0", size = 714377, upload-time = "2025-11-19T15:18:40.162Z" },
+ { url = "https://files.pythonhosted.org/packages/4a/d8/0c8a7dc9b41dcac53c4cbf9df2b9c83e0e0097203de8b37a712b345c0be5/safetensors-0.7.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b0f6d66c1c538d5a94a73aa9ddca8ccc4227e6c9ff555322ea40bdd142391dd4", size = 677368, upload-time = "2025-11-19T15:18:41.627Z" },
+ { url = "https://files.pythonhosted.org/packages/05/e5/cb4b713c8a93469e3c5be7c3f8d77d307e65fe89673e731f5c2bfd0a9237/safetensors-0.7.0-cp38-abi3-win32.whl", hash = "sha256:c74af94bf3ac15ac4d0f2a7c7b4663a15f8c2ab15ed0fc7531ca61d0835eccba", size = 326423, upload-time = "2025-11-19T15:18:45.74Z" },
+ { url = "https://files.pythonhosted.org/packages/5d/e6/ec8471c8072382cb91233ba7267fd931219753bb43814cbc71757bfd4dab/safetensors-0.7.0-cp38-abi3-win_amd64.whl", hash = "sha256:d1239932053f56f3456f32eb9625590cc7582e905021f94636202a864d470755", size = 341380, upload-time = "2025-11-19T15:18:44.427Z" },
]
[[package]]
@@ -1980,6 +1979,15 @@ wheels = [
]
[[package]]
+name = "torchao"
+version = "0.14.1"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/91/56/19abb32bbdc55d9fdebf8d6315a8f7d8ae10e387a91c631abd92afe0056b/torchao-0.14.1-cp310-abi3-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:50f68db5e41952e88daa383fc2f358541e617654f388f508d5c7580c3bee9447", size = 7197175, upload-time = "2025-10-24T01:02:59.223Z" },
+ { url = "https://files.pythonhosted.org/packages/41/a7/b888635fbb6ae951cffd41e1318966cbed96ec762b4999815ab68269e23f/torchao-0.14.1-py3-none-any.whl", hash = "sha256:c9896e14531817bc2ca6847b3fe71c42592ab80a43628b36668b2d6d6713fb5b", size = 1067611, upload-time = "2025-10-24T01:03:01.357Z" },
+]
+
+[[package]]
name = "torchmetrics"
version = "0.10.3"
source = { registry = "https://pypi.org/simple" }
@@ -2044,7 +2052,7 @@ wheels = [
[[package]]
name = "transformers"
-version = "4.57.1"
+version = "4.57.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock" },
@@ -2058,9 +2066,9 @@ dependencies = [
{ name = "tokenizers" },
{ name = "tqdm" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/d6/68/a39307bcc4116a30b2106f2e689130a48de8bd8a1e635b5e1030e46fcd9e/transformers-4.57.1.tar.gz", hash = "sha256:f06c837959196c75039809636cd964b959f6604b75b8eeec6fdfc0440b89cc55", size = 10142511, upload-time = "2025-10-14T15:39:26.18Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/dd/70/d42a739e8dfde3d92bb2fff5819cbf331fe9657323221e79415cd5eb65ee/transformers-4.57.3.tar.gz", hash = "sha256:df4945029aaddd7c09eec5cad851f30662f8bd1746721b34cc031d70c65afebc", size = 10139680, upload-time = "2025-11-25T15:51:30.139Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/71/d3/c16c3b3cf7655a67db1144da94b021c200ac1303f82428f2beef6c2e72bb/transformers-4.57.1-py3-none-any.whl", hash = "sha256:b10d05da8fa67dc41644dbbf9bc45a44cb86ae33da6f9295f5fbf5b7890bd267", size = 11990925, upload-time = "2025-10-14T15:39:23.085Z" },
+ { url = "https://files.pythonhosted.org/packages/6a/6b/2f416568b3c4c91c96e5a365d164f8a4a4a88030aa8ab4644181fdadce97/transformers-4.57.3-py3-none-any.whl", hash = "sha256:c77d353a4851b1880191603d36acb313411d3577f6e2897814f333841f7003f4", size = 11993463, upload-time = "2025-11-25T15:51:26.493Z" },
]
[[package]]
@@ -2077,7 +2085,7 @@ wheels = [
[[package]]
name = "trl"
version = "0.26.0.dev0"
-source = { git = "https://github.com/huggingface/trl.git#1b1242cc6522feb4eb063feb20097a79b11b127a" }
+source = { git = "https://github.com/huggingface/trl.git#cb5fdf9ab57c1635774c8549d81a87866bb9b8a7" }
dependencies = [
{ name = "accelerate" },
{ name = "datasets" },