| author | pks <pks@pks.rocks> | 2025-12-01 23:19:12 +0100 |
|---|---|---|
| committer | pks <pks@pks.rocks> | 2025-12-01 23:19:12 +0100 |
| commit | ab9bf07297fd3edf7d894ebf92de175d1246a5da (patch) | |
| tree | 553dc04d616e3ece9c56f030f2e4be309e6dfb8f | |
| parent | 1924181a775b6131842d840bfb9554142ca55a3c (diff) | |
WIP
| -rwxr-xr-x | inference.py | 28 |
1 file changed, 15 insertions, 13 deletions
diff --git a/inference.py b/inference.py
index 2b0b867..351ffd9 100755
--- a/inference.py
+++ b/inference.py
@@ -13,6 +13,10 @@ from PIL import Image
 from transformers import AutoProcessor, Gemma3ForConditionalGeneration
 
 
+def clean_str(s):
+    return s.removeprefix("```json").removesuffix("```").replace("\n", "").strip()
+
+
 def captioning_prompt(image):
     return [
         {
@@ -32,6 +36,7 @@ def captioning_prompt(image):
 def captioning_prompt_with_source(image, source):
     caption = captioning_prompt(image)
     prefix = json.dumps({"English": source}).removesuffix("}") + ', "German": "'
+    caption.append({"role": "assistant", "content": [{"type": "text", "text": prefix}]})
 
     return caption
 
@@ -88,11 +93,10 @@ def main():
     dataset = datasets.load_dataset(args.dataset)[args.data_subset]
     if args.mode == "translate":
         # Generate German translation given English source
-        for x in dataset:
-            sys.stderr.write(f"Processing {x['image']=}\n")
+        for x in dataset:
+            sys.stderr.write(f"Processing id={x['id']=}\n")
 
-            data = json.loads(x["assistant"])
-            exit()
+            data = json.loads(x["assistant"])
 
             inputs = make_inputs(processor,
                                  translation_prompt(data["English"]),
@@ -107,23 +111,20 @@ def main():
                                             top_k=50)
                 generation = generation[0][input_len:]
 
-            decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
+            decoded = clean_str(processor.decode(generation, skip_special_tokens=True))
             try:
                 new_data = json.loads(decoded)
             except:
                 sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
             data.update(new_data)
-            f.write(json.dumps(data))
-            f.truncate()
+            print(json.dumps(data))
 
-            sys.stderr.write(f"{decoded=}\n")
     elif args.mode == "from_scratch":
         # Generate caption & translation from scratch
-        for filename in glob("../d/Images/*.jpg"):
-            image = "../d/Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
-            sys.stderr.write(f"Processing {filename=}\n")
+        for x in dataset:
+            image = x["image"]
 
             inputs = make_inputs(processor,
-                                 captioning_prompt(Image.open(filename)),
+                                 captioning_prompt(image),
                                  model.device)
             input_len = inputs["input_ids"].shape[-1]
@@ -138,7 +139,7 @@ def main():
                                             disable_compile=True)
                 generation = generation[0][input_len:]
 
-            decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
+            decoded = clean_str(processor.decode(generation, skip_special_tokens=True))
             try:
                 _ = json.loads(decoded)
             except:
@@ -147,6 +148,7 @@ def main():
             sys.stderr.write(f"{decoded=}\n")
             with open(f"{os.path.basename(filename).removesuffix('.jpg')}.jsonl", "w") as f:
                 f.write(f"{decoded}\n")
+
     elif args.mode == "with_prefix":
         # Generate German translation given English caption and image
         for filename in glob("./baseline/files_test/*.jsonl"):
             image = "../d/Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
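The new `clean_str` helper factors out the post-processing that was previously duplicated at both decode sites: the model tends to wrap its JSON answer in a Markdown code fence, which has to be stripped before `json.loads` can parse it. A minimal sketch of its behaviour, with a made-up raw generation string purely for illustration:

````python
def clean_str(s):
    # Strip the Markdown code fence around the model's JSON output,
    # drop embedded newlines, and trim surrounding whitespace.
    return s.removeprefix("```json").removesuffix("```").replace("\n", "").strip()


# Hypothetical raw generation, not taken from an actual model run.
raw = '```json\n{"English": "A dog on a beach.", "German": "Ein Hund am Strand."}\n```'
print(clean_str(raw))
# {"English": "A dog on a beach.", "German": "Ein Hund am Strand."}
````

Two smaller behavioural changes ride along: the translate branch now prints each merged record to stdout instead of rewriting a file in place (the old `f.write`/`f.truncate` pair), so results can be collected with a shell redirect, and `captioning_prompt_with_source` now actually appends the prefilled assistant turn it builds, which ends in an open `"German": "` field so that generation continues directly with the German translation.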
