| author | pks <pks@pks.rocks> | 2025-12-01 12:07:08 +0100 |
|---|---|---|
| committer | pks <pks@pks.rocks> | 2025-12-01 12:07:08 +0100 |
| commit | e7dd88970fddea62006ba7b6620db6a31c97f5ed (patch) | |
| tree | 2a17c9d279254c91a14c8f0ac7ec0c04eaad52c0 /inference.py | |
| parent | 3645232a2e350b37a20deb67f88c654f15efb635 (diff) | |
WIP
Diffstat (limited to 'inference.py')
| -rwxr-xr-x | inference.py | 68 |
|---|---|---|

1 file changed, 35 insertions, 33 deletions
````diff
diff --git a/inference.py b/inference.py
index 297f423..2b0b867 100755
--- a/inference.py
+++ b/inference.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 
 import argparse
+import datasets
 import json
 import os
 import requests
@@ -68,54 +69,55 @@ def main():
     parser.add_argument("--model", default="google/gemma-3-4b-it")
     parser.add_argument("--lora-adapter", default=None)
     parser.add_argument("--mode", choices=["from_scratch", "with_prefix", "translate"])
+    parser.add_argument("--dataset", default="asdf2k/caption_translation", type=str)
+    parser.add_argument("--data-subset", choices=["train", "dev", "test"], default="test")
     args = parser.parse_args()
     model = Gemma3ForConditionalGeneration.from_pretrained(
-        args.model_id,
+        args.model,
         device_map="cuda",
         dtype=torch.bfloat16,
         attn_implementation="eager",
     ).eval()
-    processor = AutoProcessor.from_pretrained(args.model_id, use_fast=True)
+    processor = AutoProcessor.from_pretrained(args.model, use_fast=True)
     if args.lora_adapter:
         from peft import PeftModel
         model = PeftModel.from_pretrained(model, args.lora_adapter)
 
+    dataset = datasets.load_dataset(args.dataset)[args.data_subset]
+
     if args.mode == "translate":
         # Generate German translation given English source
-        for filename in glob("*.jsonl"):
-            sys.stderr.write(f"Processing {filename=}\n")
+        for x in dataset:
+            sys.stderr.write(f"Processing {x['image']=}\n")
+
+            data = json.loads(x["assistant"])
+            exit()
+
+            inputs = make_inputs(processor,
+                                 translation_prompt(data["English"]),
+                                 model.device)
+            input_len = inputs["input_ids"].shape[-1]
+
+            with torch.inference_mode():
+                generation = model.generate(**inputs,
+                                            max_new_tokens=300,
+                                            do_sample=True,
+                                            top_p=1.0,
+                                            top_k=50)
+                generation = generation[0][input_len:]
+
+            decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
             try:
-                with open(filename, "r+") as f:
-                    data = json.loads(f.read())
-                    f.seek(0)
-
-                    inputs = make_inputs(processor,
-                                         translation_prompt(data["English"]),
-                                         model.device)
-                    input_len = inputs["input_ids"].shape[-1]
-
-                    with torch.inference_mode():
-                        generation = model.generate(**inputs,
-                                                    max_new_tokens=300,
-                                                    do_sample=True,
-                                                    top_p=1.0,
-                                                    top_k=50)
-                        generation = generation[0][input_len:]
-
-                    decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
-                    try:
-                        new_data = json.loads(decoded)
-                    except:
-                        sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
-
-                    data.update(new_data)
-                    f.write(json.dumps(data))
-                    f.truncate()
-
-                    sys.stderr.write(f"{decoded=}\n")
+                new_data = json.loads(decoded)
             except:
-                pass
+                sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
+
+            data.update(new_data)
+            f.write(json.dumps(data))
+            f.truncate()
+
+            sys.stderr.write(f"{decoded=}\n")
     elif args.mode == "from_scratch":
         # Generate caption & translation from scratch
         for filename in glob("../d/Images/*.jpg"):
             image = "../d/Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
````
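The substantive change is that translate mode now pulls its data from the Hugging Face Hub instead of local `*.jsonl` files. A minimal sketch of that loading pattern, assuming the `asdf2k/caption_translation` dataset is reachable and that its `image` and `assistant` columns have the schema the diff implies (an identifier plus a JSON string with an `"English"` field):

```python
import json

import datasets

# load_dataset() returns a DatasetDict keyed by split name; the script picks
# the split via --data-subset ("test" by default).
dataset = datasets.load_dataset("asdf2k/caption_translation")["test"]

for x in dataset:
    # Each row behaves like a dict. The diff treats "assistant" as a JSON
    # string containing at least an "English" caption (assumed schema).
    data = json.loads(x["assistant"])
    print(x["image"], data.get("English"))
    break  # inspect a single example only
```

In the script itself the dataset name and split come from `args.dataset` and `args.data_subset`, and each parsed record is fed to `translation_prompt()` before generation.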
