summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorpks <pks@pks.rocks>2025-12-01 11:24:14 +0100
committerpks <pks@pks.rocks>2025-12-01 11:24:14 +0100
commit3645232a2e350b37a20deb67f88c654f15efb635 (patch)
tree59ca989bb9c86282c3369a3ffbacacb0282a462c
parent54a77fc6145494644b98f2bab2320cbe3f9fcc6d (diff)
WIP
-rwxr-xr-xfinetuning.py2
-rwxr-xr-x[-rw-r--r--]inference.py6
-rw-r--r--pyproject.toml1
3 files changed, 7 insertions, 2 deletions
diff --git a/finetuning.py b/finetuning.py
index 62f506c..a0ad1fe 100755
--- a/finetuning.py
+++ b/finetuning.py
@@ -80,7 +80,7 @@ def main():
parser.add_argument("--max-length", default=512, type=int)
parser.add_argument("--lora-config", choices=["S", "M", "L"], default="M")
args = parser.parse_args()
-
+
if args.bnb_4bit:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
diff --git a/inference.py b/inference.py
index d20dc55..297f423 100644..100755
--- a/inference.py
+++ b/inference.py
@@ -74,6 +74,7 @@ def main():
args.model_id,
device_map="cuda",
dtype=torch.bfloat16,
+ attn_implementation="eager",
).eval()
processor = AutoProcessor.from_pretrained(args.model_id, use_fast=True)
@@ -128,8 +129,11 @@ def main():
generation = model.generate(**inputs,
max_new_tokens=300,
do_sample=True,
+ temperature=0.8,
top_p=1.0,
- top_k=50)
+ top_k=50,
+ eos_token_id=stop_token_ids,
+ disable_compile=True)
generation = generation[0][input_len:]
decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
diff --git a/pyproject.toml b/pyproject.toml
index f76839e..7248d3a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,6 +10,7 @@ dependencies = [
"bitsandbytes>=0.48.2",
"datasets>=4.4.1",
"diskcache>=5.6.3",
+ "hf_transfer",
"huggingface-hub>=0.36.0",
"ipdb>=0.13.13",
"jsonargparse[signatures]>=4.35.0",