summary refs log tree commit diff
path: root/inference.py
diff options
context:
space:
mode:
author	pks <pks@pks.rocks>	2025-12-01 23:19:12 +0100
committer	pks <pks@pks.rocks>	2025-12-01 23:19:12 +0100
commit	ab9bf07297fd3edf7d894ebf92de175d1246a5da (patch)
tree	553dc04d616e3ece9c56f030f2e4be309e6dfb8f /inference.py
parent	1924181a775b6131842d840bfb9554142ca55a3c (diff)
WIP
Diffstat (limited to 'inference.py')
-rwxr-xr-x	inference.py	28
1 file changed, 15 insertions(+), 13 deletions(-)
diff --git a/inference.py b/inference.py
index 2b0b867..351ffd9 100755
--- a/inference.py
+++ b/inference.py
@@ -13,6 +13,10 @@ from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+def clean_str(s):
+ return s.removeprefix("```json").removesuffix("```").replace("\n", "").strip()
+
+
def captioning_prompt(image):
return [
{
@@ -32,6 +36,7 @@ def captioning_prompt(image):
def captioning_prompt_with_source(image, source):
caption = captioning_prompt(image)
prefix = json.dumps({"English": source}).removesuffix("}") + ', "German": "'
+
caption.append({"role": "assistant", "content": [{"type": "text", "text": prefix}]})
return caption
@@ -88,11 +93,10 @@ def main():
dataset = datasets.load_dataset(args.dataset)[args.data_subset]
if args.mode == "translate": # Generate German translation given English source
- for x in dataset:
- sys.stderr.write(f"Processing {x['image']=}\n")
+ for x in dataset:
+ sys.stderr.write(f"Processing id={x['id']=}\n")
- data = json.loads(x["assistant"])
- exit()
+ data = json.loads(x["assistant"])
inputs = make_inputs(processor,
translation_prompt(data["English"]),
@@ -107,23 +111,20 @@ def main():
top_k=50)
generation = generation[0][input_len:]
- decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
+ decoded = clean_str(processor.decode(generation, skip_special_tokens=True))
try:
new_data = json.loads(decoded)
except:
sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
data.update(new_data)
- f.write(json.dumps(data))
- f.truncate()
+ print(json.dumps(data))
- sys.stderr.write(f"{decoded=}\n")
elif args.mode == "from_scratch": # Generate caption & translation from scratch
- for filename in glob("../d/Images/*.jpg"):
- image = "../d/Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
- sys.stderr.write(f"Processing {filename=}\n")
+ for x in dataset:
+ image = x["image"]
inputs = make_inputs(processor,
- captioning_prompt(Image.open(filename)),
+ captioning_prompt(image),
model.device)
input_len = inputs["input_ids"].shape[-1]
@@ -138,7 +139,7 @@ def main():
disable_compile=True)
generation = generation[0][input_len:]
- decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
+ decoded = clean_str(processor.decode(generation, skip_special_tokens=True))
try:
_ = json.loads(decoded)
except:
@@ -147,6 +148,7 @@ def main():
sys.stderr.write(f"{decoded=}\n")
with open(f"{os.path.basename(filename).removesuffix('.jpg')}.jsonl", "w") as f:
f.write(f"{decoded}\n")
+
elif args.mode == "with_prefix": # Generate German translation given English caption and image
for filename in glob("./baseline/files_test/*.jsonl"):
image = "../d/Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"