summary refs log tree commit diff
path: root/inference.py
diff options
context:
space:
mode:
author: pks <pks@pks.rocks> 2025-12-01 11:24:14 +0100
committer: pks <pks@pks.rocks> 2025-12-01 11:24:14 +0100
commit: 3645232a2e350b37a20deb67f88c654f15efb635 (patch)
tree: 59ca989bb9c86282c3369a3ffbacacb0282a462c /inference.py
parent: 54a77fc6145494644b98f2bab2320cbe3f9fcc6d (diff)
WIP
Diffstat (limited to 'inference.py')
-rwxr-xr-x [-rw-r--r--] inference.py 6
1 files changed, 5 insertions, 1 deletions
diff --git a/inference.py b/inference.py
index d20dc55..297f423 100644..100755
--- a/inference.py
+++ b/inference.py
@@ -74,6 +74,7 @@ def main():
args.model_id,
device_map="cuda",
dtype=torch.bfloat16,
+ attn_implementation="eager",
).eval()
processor = AutoProcessor.from_pretrained(args.model_id, use_fast=True)
@@ -128,8 +129,11 @@ def main():
generation = model.generate(**inputs,
max_new_tokens=300,
do_sample=True,
+ temperature=0.8,
top_p=1.0,
- top_k=50)
+ top_k=50,
+ eos_token_id=stop_token_ids,
+ disable_compile=True)
generation = generation[0][input_len:]
decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()