author     pks <pks@pks.rocks>  2025-11-21 23:02:16 +0100
committer  pks <pks@pks.rocks>  2025-11-21 23:02:16 +0100
commit     1cc209768ad0299c396af56231c43fded7dc3515 (patch)
tree       bc7fc3dcfd094078e052807f4e336dc2a5e0a43b /inference.py
parent     ea7a98c4f6fc0d0df785ef601e9fc30b78e51335 (diff)
WIP
Diffstat (limited to 'inference.py')
-rw-r--r--  inference.py  46
1 file changed, 43 insertions(+), 3 deletions(-)
diff --git a/inference.py b/inference.py
index fc8c1cc..b08a3f6 100644
--- a/inference.py
+++ b/inference.py
@@ -26,6 +26,15 @@ def captioning_prompt(image):
         }
     ]

+
+def captioning_prompt_with_source(image, source):
+    caption = captioning_prompt(image)
+    prefix = json.dumps({"English": source}).removesuffix("}") + ', "German": "'
+    caption.append({"role": "assistant", "content": [{"type": "text", "text": prefix}]})
+
+    return caption
+
+
 def translation_prompt(source):
     return [
         {
@@ -66,7 +75,7 @@ def main():
         from peft import PeftModel
         model = PeftModel.from_pretrained(model, sys.argv[1])

-    if True:
+    if True:  # Generate German translation given English source
         for filename in glob("*.jsonl"):
             sys.stderr.write(f"Processing {filename=}\n")
             with open(filename, "r+") as f:
@@ -97,13 +106,13 @@ def main():
                 f.truncate()
             sys.stderr.write(f"{decoded=}\n")

-    else:
+    elif 2 == 3:  # Generate caption & translation from scratch
         for filename in glob("../Images/*.jpg"):
+ image = "../Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
sys.stderr.write(f"Processing {filename=}\n")
inputs = make_inputs(processor,
captioning_prompt(Image.open(filename)),
model.device)
-
input_len = inputs["input_ids"].shape[-1]
with torch.inference_mode():
@@ -123,6 +132,37 @@ def main():
             sys.stderr.write(f"{decoded=}\n")
             with open(f"{os.path.basename(filename).removesuffix('.jpg')}.jsonl", "w") as f:
                 f.write(f"{decoded}\n")
+    elif False:  # Generate German translation given English caption and image
+        for filename in glob("./baseline/files_test/*.jsonl"):
+            image = "../Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
+            sys.stderr.write(f"Processing {filename=}\n")
+            with open(filename, "r+") as f:
+                data = json.loads(f.read())
+            prompt = captioning_prompt_with_source(Image.open(image), data["English"])
+            inputs = make_inputs(processor,
+                                 prompt,
+                                 model.device)
+
+            input_len = inputs["input_ids"].shape[-1]  # Will not cut off assistant prefix
+
+            with torch.inference_mode():
+                generation = model.generate(**inputs,
+                                            max_new_tokens=300,
+                                            do_sample=True,
+                                            top_p=1.0,
+                                            top_k=50)
+                generation = generation[0]  # batch size 1
+                truncated_generation = generation[input_len:]
+
+            decoded = processor.decode(truncated_generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
+            try:
+                _ = json.loads(decoded)
+            except json.JSONDecodeError:
+ sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
+
+ sys.stderr.write(f"{decoded=}\n")
+ with open(f"{os.path.basename(filename)}", "w") as f:
+ f.write(f"{decoded}\n")
if __name__ == "__main__":
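
Note on the new helper: captioning_prompt_with_source prefills the final
assistant turn with the partial JSON string {"English": <source>, "German": "
so that the model's continuation is forced to be the German translation plus
the closing "}. Below is a minimal, model-free sketch of that round trip;
fake_generate is a hypothetical stand-in for model.generate plus
processor.decode, and the sketch assumes the prefilled prefix has to be
re-attached after slicing the generation at input_len.

import json

def forced_json_prefix(source: str) -> str:
    # Open a JSON object with the known English field and leave the German
    # value dangling; the model only has to emit the translation plus '"}'.
    return json.dumps({"English": source}).removesuffix("}") + ', "German": "'

def parse_forced_completion(prefix: str, continuation: str) -> dict:
    # The slice generation[input_len:] drops the prompt, including the
    # prefilled prefix, so rejoin the two halves before parsing.
    return json.loads(prefix + continuation.strip())

# Hypothetical stand-in for the VLM call.
def fake_generate(prompt: str) -> str:
    return 'Ein Hund liegt im Gras."}'

prefix = forced_json_prefix("A dog lying in the grass.")
record = parse_forced_completion(prefix, fake_generate(prefix))
print(record["German"])  # -> Ein Hund liegt im Gras.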