diff options
| author | pks <pks@pks.rocks> | 2025-12-01 11:24:14 +0100 |
|---|---|---|
| committer | pks <pks@pks.rocks> | 2025-12-01 11:24:14 +0100 |
| commit | 3645232a2e350b37a20deb67f88c654f15efb635 (patch) | |
| tree | 59ca989bb9c86282c3369a3ffbacacb0282a462c /inference.py | |
| parent | 54a77fc6145494644b98f2bab2320cbe3f9fcc6d (diff) | |
WIP
Diffstat (limited to 'inference.py')
| -rwxr-xr-x[-rw-r--r--] | inference.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/inference.py b/inference.py index d20dc55..297f423 100644..100755 --- a/inference.py +++ b/inference.py @@ -74,6 +74,7 @@ def main(): args.model_id, device_map="cuda", dtype=torch.bfloat16, + attn_implementation="eager", ).eval() processor = AutoProcessor.from_pretrained(args.model_id, use_fast=True) @@ -128,8 +129,11 @@ def main(): generation = model.generate(**inputs, max_new_tokens=300, do_sample=True, + temperature=0.8, top_p=1.0, - top_k=50) + top_k=50, + eos_token_id=stop_token_ids, + disable_compile=True) generation = generation[0][input_len:] decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip() |
