summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorpks <pks@pks.rocks>2025-11-29 23:07:22 +0100
committerpks <pks@pks.rocks>2025-11-29 23:07:22 +0100
commit8fd25e95dc733d38db219083e28396045090a556 (patch)
treeeba45b54747aa233a54d141e8221f9dc4cb12d65
parent1cc209768ad0299c396af56231c43fded7dc3515 (diff)
use torchao's 8bit adamw
-rw-r--r--finetune.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/finetune.py b/finetune.py
index 01012c5..ba2a2c7 100644
--- a/finetune.py
+++ b/finetune.py
@@ -24,7 +24,7 @@ def make_dataset(base="./baseline"):
with open(filename, "r") as f:
data = json.loads(f.read())
print(f"{data=}")
- image_path = f"../Images/{os.path.basename(filename).removesuffix(".jsonl")}.jpg"
+ image_path = f"../d/Images/{os.path.basename(filename).removesuffix(".jsonl")}.jpg"
user_prompts.append(prompt)
assistant_replies.append(json.dumps({
"English": data["English"],
@@ -106,7 +106,8 @@ def main():
bf16=True,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
- optim="paged_adamw_8bit",
+ #optim="paged_adamw_8bit",
+ optim="adamw_torch_8bit",
remove_unused_columns=False,
logging_steps=10,
save_steps=100,