author    pks <pks@pks.rocks>  2025-12-01 12:07:08 +0100
committer pks <pks@pks.rocks>  2025-12-01 12:07:08 +0100
commit    e7dd88970fddea62006ba7b6620db6a31c97f5ed (patch)
tree      2a17c9d279254c91a14c8f0ac7ec0c04eaad52c0 /inference.py
parent    3645232a2e350b37a20deb67f88c654f15efb635 (diff)
WIP
Diffstat (limited to 'inference.py')
-rwxr-xr-x  inference.py  68
1 file changed, 35 insertions(+), 33 deletions(-)
diff --git a/inference.py b/inference.py
index 297f423..2b0b867 100755
--- a/inference.py
+++ b/inference.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
import argparse
+import datasets
import json
import os
import requests
@@ -68,54 +69,55 @@ def main():
parser.add_argument("--model", default="google/gemma-3-4b-it")
parser.add_argument("--lora-adapter", default=None)
parser.add_argument("--mode", choices=["from_scratch", "with_prefix", "translate"])
+ parser.add_argument("--dataset", default="asdf2k/caption_translation", type=str)
+ parser.add_argument("--data-subset", choices=["train", "dev", "test"], default="test")
args = parser.parse_args()
model = Gemma3ForConditionalGeneration.from_pretrained(
- args.model_id,
+ args.model,
device_map="cuda",
dtype=torch.bfloat16,
attn_implementation="eager",
).eval()
- processor = AutoProcessor.from_pretrained(args.model_id, use_fast=True)
+ processor = AutoProcessor.from_pretrained(args.model, use_fast=True)
if args.lora_adapter:
from peft import PeftModel
model = PeftModel.from_pretrained(model, args.lora_adapter)
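+ # load the requested split of the caption dataset (assumes the HF dataset exposes train/dev/test splits)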
+ dataset = datasets.load_dataset(args.dataset)[args.data_subset]
+
if args.mode == "translate": # Generate German translation given English source
- for filename in glob("*.jsonl"):
- sys.stderr.write(f"Processing {filename=}\n")
+ for x in dataset:
+ sys.stderr.write(f"Processing {x['image']=}\n")
+
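+ # each record is expected to carry its caption JSON as a string in the "assistant" field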
+ data = json.loads(x["assistant"])
+
+ inputs = make_inputs(processor,
+ translation_prompt(data["English"]),
+ model.device)
+ input_len = inputs["input_ids"].shape[-1]
+
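+ # top_p=1.0 disables nucleus filtering, so this is effectively plain top-k sampling with k=50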
+ with torch.inference_mode():
+ generation = model.generate(**inputs,
+ max_new_tokens=300,
+ do_sample=True,
+ top_p=1.0,
+ top_k=50)
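+ # keep only the newly generated tokens, dropping the echoed prompt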
+ generation = generation[0][input_len:]
+
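+ # the model often wraps its JSON answer in a ```json ... ``` fence; strip it before parsing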
+ decoded = processor.decode(generation, skip_special_tokens=True)
+ decoded = decoded.removeprefix("```json").removesuffix("```").replace("\n", "").strip()
try:
- with open(filename, "r+") as f:
- data = json.loads(f.read())
- f.seek(0)
-
- inputs = make_inputs(processor,
- translation_prompt(data["English"]),
- model.device)
- input_len = inputs["input_ids"].shape[-1]
-
- with torch.inference_mode():
- generation = model.generate(**inputs,
- max_new_tokens=300,
- do_sample=True,
- top_p=1.0,
- top_k=50)
- generation = generation[0][input_len:]
-
- decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
- try:
- new_data = json.loads(decoded)
- except:
- sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
-
- data.update(new_data)
- f.write(json.dumps(data))
- f.truncate()
-
- sys.stderr.write(f"{decoded=}\n")
+ new_data = json.loads(decoded)
except:
- pass
+ sys.stderr.write(f"Error loading JSON from string '{decoded}' for {x['image']=}\n")
+ continue  # new_data is undefined on a parse failure; skip this record
+
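+ # merge the parsed translation fields back into the source record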
+ data.update(new_data)
+ print(json.dumps(data))  # no open file handle in the dataset-driven flow; emit the merged record to stdout for now
+
+ sys.stderr.write(f"{decoded=}\n")
elif args.mode == "from_scratch": # Generate caption & translation from scratch
for filename in glob("../d/Images/*.jpg"):
image = "../d/Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"