path: root/inference.py
author    pks <pks@pks.rocks>  2025-11-29 23:56:12 +0100
committer pks <pks@pks.rocks>  2025-11-29 23:56:12 +0100
commit    fc6e8636014dbaf882a2fa5685497dc225ee8e29 (patch)
tree      0ecd23bd0e8b2ff8aba7b7085657919b91f3bc7c /inference.py
parent    abb0df2560b53ec624d8ecaf8444a679fdadb6b2 (diff)
WIP
Diffstat (limited to 'inference.py')
-rw-r--r--  inference.py  103
1 file changed, 57 insertions(+), 46 deletions(-)
diff --git a/inference.py b/inference.py
index b08a3f6..d20dc55 100644
--- a/inference.py
+++ b/inference.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
+import argparse
import json
import os
import requests
@@ -63,58 +64,66 @@ def make_inputs(processor,
def main():
- model_id = "google/gemma-3-4b-it"
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model", default="google/gemma-3-4b-it")
+ parser.add_argument("--lora-adapter", default=None)
+ parser.add_argument("--mode", choices=["from_scratch", "with_prefix", "translate"])
+ args = parser.parse_args()
+
model = Gemma3ForConditionalGeneration.from_pretrained(
- model_id,
+ args.model,
device_map="cuda",
dtype=torch.bfloat16,
).eval()
- processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+ processor = AutoProcessor.from_pretrained(args.model, use_fast=True)
- if len(sys.argv) == 2:
+ if args.lora_adapter:
from peft import PeftModel
- model = PeftModel.from_pretrained(model, sys.argv[1])
-
- if True: # Generate German translation given English source
+ model = PeftModel.from_pretrained(model, args.lora_adapter)
+
+ if args.mode == "translate": # Generate German translation given English source
for filename in glob("*.jsonl"):
sys.stderr.write(f"Processing {filename=}\n")
- with open(filename, "r+") as f:
- data = json.loads(f.read())
- f.seek(0)
-
- inputs = make_inputs(processor,
- translation_prompt(data["English"]),
- model.device)
- input_len = inputs["input_ids"].shape[-1]
-
- with torch.inference_mode():
- generation = model.generate(**inputs,
- max_new_tokens=300,
- do_sample=True,
- top_p=1.0,
- top_k=50)
- generation = generation[0][input_len:]
-
- decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
- try:
- new_data = json.loads(decoded)
- except:
- sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
-
- data.update(new_data)
- f.write(json.dumps(data))
- f.truncate()
-
- sys.stderr.write(f"{decoded=}\n")
- elif 2 == 3: # Generate caption & translation from scratch
- for filename in glob("../Images/*.jpg"):
- image = "../Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
+ try:
+ with open(filename, "r+") as f:
+ data = json.loads(f.read())
+ f.seek(0)
+
+ inputs = make_inputs(processor,
+ translation_prompt(data["English"]),
+ model.device)
+ input_len = inputs["input_ids"].shape[-1]
+
+ with torch.inference_mode():
+ generation = model.generate(**inputs,
+ max_new_tokens=300,
+ do_sample=True,
+ top_p=1.0,
+ top_k=50)
+ generation = generation[0][input_len:]
+
+ decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
+ try:
+ new_data = json.loads(decoded)
+ except json.JSONDecodeError:
+ sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
+ continue  # skip this file: new_data is undefined when parsing fails
+
+ data.update(new_data)
+ f.write(json.dumps(data))
+ f.truncate()
+
+ sys.stderr.write(f"{decoded=}\n")
+ except Exception as e:
+ sys.stderr.write(f"Error processing {filename=}: {e}\n")
+ elif args.mode == "from_scratch": # Generate caption & translation from scratch
+ for filename in glob("../d/Images/*.jpg"):
sys.stderr.write(f"Processing {filename=}\n")
inputs = make_inputs(processor,
captioning_prompt(Image.open(filename)),
model.device)
input_len = inputs["input_ids"].shape[-1]
-
+
with torch.inference_mode():
generation = model.generate(**inputs,
max_new_tokens=300,
@@ -122,19 +131,19 @@ def main():
top_p=1.0,
top_k=50)
generation = generation[0][input_len:]
-
+
decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
try:
_ = json.loads(decoded)
except:
sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
-
+
sys.stderr.write(f"{decoded=}\n")
with open(f"{os.path.basename(filename).removesuffix('.jpg')}.jsonl", "w") as f:
f.write(f"{decoded}\n")
- elif False: # Generate German translation given English caption and image
+ elif args.mode == "with_prefix": # Generate German translation given English caption and image
for filename in glob("./baseline/files_test/*.jsonl"):
- image = "../Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
+ image = "../d/Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
sys.stderr.write(f"Processing {filename=}\n")
with open(filename, "r+") as f:
data = json.loads(f.read())
@@ -144,7 +153,7 @@ def main():
model.device)
input_len = inputs["input_ids"].shape[-1] # Will not cut off assistant prefix
-
+
with torch.inference_mode():
generation = model.generate(**inputs,
max_new_tokens=300,
@@ -153,16 +162,18 @@ def main():
top_k=50)
generation = generation[0] # batch size 1
truncated_generation = generation[input_len:]
-
+
decoded = processor.decode(truncated_generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
try:
_ = json.loads(decoded)
except:
sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
-
+
sys.stderr.write(f"{decoded=}\n")
with open(f"{os.path.basename(filename)}", "w") as f:
f.write(f"{decoded}\n")
+ else:
+ sys.stderr.write(f"Unkown mode '{args.mode}'")
if __name__ == "__main__":
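
A note on the helpers this patch calls: make_inputs(), translation_prompt() and captioning_prompt() are defined earlier in inference.py and do not appear in these hunks. For orientation, here is a minimal sketch of what make_inputs() might look like for a Gemma 3 processor; the body is an assumption based on the standard transformers chat-template API, not the file's actual code:

    def make_inputs(processor, messages, device):
        # Tokenize a chat-formatted prompt (text plus optional images) and
        # move the resulting tensors onto the model's device.
        # ASSUMPTION: the real helper may differ; this follows the
        # documented processor.apply_chat_template() interface, which is
        # compatible with the inputs["input_ids"] access in the diff above.
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        )
        return inputs.to(device)

With the new argparse interface, the script would be invoked along these lines (the adapter path is a placeholder):

    python3 inference.py --mode translate
    python3 inference.py --mode with_prefix --lora-adapter ./my-gemma-lora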