author     pks <pks@pks.rocks>  2025-12-05 22:23:16 +0100
committer  pks <pks@pks.rocks>  2025-12-05 22:23:16 +0100
commit     f243793b76a8ace9b8a690cb02afb0e91a5b0531 (patch)
tree       7ec30fdd8642a3edfe41ed09ced8271200150909
parent     c0ed7b3ada7f41faaad9a2a64697d6a0e385ed86 (diff)
WIP
-rwxr-xr-x  inference.py   110
-rw-r--r--  inference2.py  179
-rwxr-xr-x  setup.sh         8
3 files changed, 240 insertions, 57 deletions
diff --git a/inference.py b/inference.py
index 69ab33b..c8adb16 100755
--- a/inference.py
+++ b/inference.py
@@ -10,7 +10,7 @@ import torch
from glob import glob
from PIL import Image
-from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+from transformers import AutoProcessor, AutoModelForImageTextToText
def clean_str(s):
@@ -71,18 +71,19 @@ def make_inputs(processor,
def main():
parser = argparse.ArgumentParser()
- parser.add_argument("--model", default="google/gemma-3-4b-it")
- parser.add_argument("--lora-adapter", default=None)
- parser.add_argument("--mode", choices=["from_scratch", "with_prefix", "translate"])
+ parser.add_argument("--model", default="google/gemma-3-4b-it", type=str)
+ parser.add_argument("--attention-implementation", default="eager", type=str)
+ parser.add_argument("--lora-adapter", default=None, type=str)
+ parser.add_argument("--mode", choices=["from_scratch", "with_prefix", "translate"], type=str, required=True)
parser.add_argument("--dataset", default="asdf2k/caption_translation", type=str)
- parser.add_argument("--data-subset", choices=["train", "dev", "test"], default="test")
+ parser.add_argument("--data-subset", choices=["train", "dev", "test"], default="test", type=str)
args = parser.parse_args()
- model = Gemma3ForConditionalGeneration.from_pretrained(
+ model = AutoModelForImageTextToText.from_pretrained(
args.model,
device_map="cuda",
dtype=torch.bfloat16,
- attn_implementation="eager",
+ attn_implementation=args.attention_implementation,
).eval()
processor = AutoProcessor.from_pretrained(args.model, use_fast=True)
@@ -94,7 +95,7 @@ def main():
if args.mode == "translate": # Generate German translation given English source
for x in dataset:
- sys.stderr.write(f"Processing id={x['id']=}\n")
+ sys.stderr.write(f"Processing id={x['id']}\n")
data = json.loads(x["assistant"])
@@ -104,21 +105,23 @@ def main():
input_len = inputs["input_ids"].shape[-1]
with torch.inference_mode():
- generation = model.generate(**inputs,
- max_new_tokens=300,
- do_sample=True,
- top_p=1.0,
- top_k=50)
- generation = generation[0][input_len:]
-
- decoded = clean_str(processor.decode(generation, skip_special_tokens=True))
+ output = model.generate(**inputs,
+ max_new_tokens=args.max_new_tokens,
+ do_sample=args.do_sample,
+ temperature=args.temperature,
+ top_p=args.top_p,
+ top_k=args.top_k,
+ disable_compile=True)
+            output = output[0][input_len:]
+
+ output = clean_str(processor.decode(output, skip_special_tokens=True))
+
try:
- new_data = json.loads(decoded)
+ output = json.loads(output)
except:
- sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
+                sys.stderr.write(f"Error loading JSON from string '{output}' for id={x['id']}\n")
- data.update(new_data)
- print(json.dumps(data))
+ print(json.dumps(output))
elif args.mode == "from_scratch": # Generate caption & translation from scratch
for x in dataset:
@@ -129,57 +132,50 @@ def main():
input_len = inputs["input_ids"].shape[-1]
with torch.inference_mode():
- generation = model.generate(**inputs,
- max_new_tokens=300,
- do_sample=True,
- temperature=0.8,
- top_p=1.0,
- top_k=50,
- eos_token_id=stop_token_ids,
- disable_compile=True)
- generation = generation[0][input_len:]
-
- decoded = clean_str(processor.decode(generation, skip_special_tokens=True))
+ output = model.generate(**inputs,
+ max_new_tokens=args.max_new_tokens,
+ do_sample=args.do_sample,
+ temperature=args.temperature,
+ top_p=args.top_p,
+ top_k=args.top_k,
+ eos_token_id=stop_token_ids,
+ disable_compile=True)
+ output = output[0][input_len:]
+
+ output = clean_str(processor.decode(output, skip_special_tokens=True))
try:
- _ = json.loads(decoded)
+ output = json.loads(output)
except:
- sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
+                sys.stderr.write(f"Error loading JSON from string '{output}' for id={x['id']}\n")
- sys.stderr.write(f"{decoded=}\n")
- with open(f"{os.path.basename(filename).removesuffix('.jpg')}.jsonl", "w") as f:
- f.write(f"{decoded}\n")
+ print(json.dumps(output))
elif args.mode == "with_prefix": # Generate German translation given English caption and image
- for filename in glob("./baseline/files_test/*.jsonl"):
- image = "../d/Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
- sys.stderr.write(f"Processing {filename=}\n")
- with open(filename, "r+") as f:
- data = json.loads(f.read())
+ for x in dataset:
+ sys.stderr.write(f"Processing id={x['id']}\n")
+            data = json.loads(x["assistant"])
-            prompt = captioning_prompt_with_source(Image.open(image), data["English"])
+            prompt = captioning_prompt_with_source(x["image"], data["English"])
inputs = make_inputs(processor,
prompt,
model.device)
-
input_len = inputs["input_ids"].shape[-1] # Will not cut off assistant prefix
with torch.inference_mode():
- generation = model.generate(**inputs,
- max_new_tokens=300,
- do_sample=True,
- top_p=1.0,
- top_k=50)
- generation = generation[0] # batch size 1
- truncated_generation = generation[input_len:]
-
- decoded = processor.decode(truncated_generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
+ output = model.generate(**inputs,
+ max_new_tokens=args.max_new_tokens,
+ do_sample=args.do_sample,
+                                        temperature=args.temperature,
+ top_p=args.top_p,
+ top_k=args.top_k)
+                output = output[0][input_len:]
+
+ output = clean_str(processor.decode(output, skip_special_tokens=True))
try:
- _ = json.loads(decoded)
+ output = json.loads(output)
except:
- sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
-
- sys.stderr.write(f"{decoded=}\n")
- with open(f"{os.path.basename(filename)}", "w") as f:
- f.write(f"{decoded}\n")
+                sys.stderr.write(f"Error loading JSON from string '{output}' for id={x['id']}\n")
+
+ print(json.dumps(output))
else:
sys.stderr.write(f"Unkown mode '{args.mode}'")
diff --git a/inference2.py b/inference2.py
new file mode 100644
index 0000000..67e633a
--- /dev/null
+++ b/inference2.py
@@ -0,0 +1,179 @@
+#!/usr/bin/env python3
+
+import argparse
+import datasets
+import json
+import os
+import requests
+import sys
+import torch
+from glob import glob
+from PIL import Image
+from transformers import AutoProcessor, AutoModelForImageTextToText
+
+
+def clean_str(s):
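+    # The model tends to wrap its JSON answer in a markdown ```json fence; strip the fence and flatten newlines before parsing.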
+ return s.removeprefix("```json").removesuffix("```").replace("\n", "").strip()
+
+
+def captioning_prompt(image):
+ return [
+ {
+ "role": "system",
+ "content": [{"type": "text", "text": "You are a professional English-German translator and also a renowned photography critic."}]
+ },
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "image": image},
+ {"type": "text", "text": "Write a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else."}
+ ]
+ }
+ ]
+
+
+def captioning_prompt_with_source(image, source):
+ prompt = captioning_prompt(image)
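+    # Seed the assistant turn with a partial JSON object that already contains the English caption,
+    # so the model only has to complete the German value.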
+ prefix = json.dumps({"English": source}).removesuffix("}") + ', "German": "'
+ prompt.append({"role": "assistant", "content": [{"type": "text", "text": prefix}]})
+
+ return prompt
+
+
+def translation_prompt(source):
+ return [
+ {
+ "role": "system",
+ "content": [{"type": "text", "text": "You are a professional English-German translator."}]
+ },
+ {
+ "role": "user",
+ "content": [
+                {"type": "text", "text": f"Translate the following caption into German. The output needs to be JSON, the only key being 'Translation' for the translation. Only output the JSON, nothing else. Caption: {source}"}
+ ]
+ }
+ ]
+
+
+def make_inputs(processor,
+ messages,
+ device):
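+    # apply_chat_template tokenizes the conversation and returns a tensor dict; the dtype cast below
+    # should only affect floating-point tensors (e.g. pixel values), while integer token ids stay untouched.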
+ return processor.apply_chat_template(
+ messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt"
+ ).to(device, dtype=torch.bfloat16)
+
+
+def generate_and_parse(model,
+ processor,
+ messages,
+ args,
+ example_id=None):
+ sys.stderr.write(f"Processing {example_id=}\n")
+ inputs = make_inputs(processor, messages, model.device)
+ input_len = inputs["input_ids"].shape[-1]
+
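+    # Gemma's chat template ends each turn with <end_of_turn>, so treat it as a stop token in addition to EOS.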
+ stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids("<end_of_turn>")]
+
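+    # Generate without gradient tracking; disable_compile skips transformers' automatic torch.compile of the decoding step.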
+ with torch.inference_mode():
+ generation = model.generate(
+ **inputs,
+ max_new_tokens=args.max_new_tokens,
+ do_sample=not args.do_not_sample,
+ temperature=args.temperature,
+ top_p=args.top_p,
+ top_k=args.top_k,
+ eos_token_id=stop_token_ids,
+ disable_compile=True,
+ )
+
+ output_tokens = generation[0][input_len:]
+ output_text = clean_str(processor.decode(output_tokens, skip_special_tokens=True))
+
+ try:
+ return json.loads(output_text)
+ except Exception:
+ if example_id is not None:
+ sys.stderr.write(
+ f"Error loading JSON from string '{output_text}' for id={example_id}\n"
+ )
+ else:
+ sys.stderr.write(
+ f"Error loading JSON from string '{output_text}'\n"
+ )
+ return output_text
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model", default="google/gemma-3-4b-it", type=str)
+ parser.add_argument("--attention-implementation", default="eager", type=str)
+ parser.add_argument("--lora-adapter", default=None, type=str)
+ parser.add_argument("--mode", choices=["from_scratch", "with_prefix", "translate"], type=str, required=True)
+ parser.add_argument("--dataset", default="asdf2k/caption_translation", type=str)
+ parser.add_argument("--data-subset", choices=["train", "dev", "test"], default="test", type=str)
+ parser.add_argument("--max-new-tokens", default=300, type=int)
+ parser.add_argument("--top-p", default=1.0, type=int)
+ parser.add_argument("--top-k", default=50, type=int)
+ parser.add_argument("--temperature", default=0.8, type=int)
+ parser.add_argument("--do-not-sample", action="store_true")
+ args = parser.parse_args()
+
+ model = AutoModelForImageTextToText.from_pretrained(
+ args.model,
+ device_map="cuda",
+ dtype=torch.bfloat16,
+ attn_implementation=args.attention_implementation,
+ ).eval()
+ processor = AutoProcessor.from_pretrained(args.model, use_fast=True)
+
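+    # Optionally stack a trained LoRA adapter on top of the base model; peft is imported lazily so it stays an optional dependency.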
+ if args.lora_adapter:
+ from peft import PeftModel
+ model = PeftModel.from_pretrained(model, args.lora_adapter)
+
+ dataset = datasets.load_dataset(args.dataset)[args.data_subset]
+
+ if args.mode == "from_scratch": # Generate caption & translation from scratch
+ for x in dataset:
+ output = generate_and_parse(
+ model,
+ processor,
+ captioning_prompt(x["image"]),
+ args,
+ example_id=x["id"],
+ )
+ print(f"{x['id']}\t{json.dumps(output)}")
+
+ elif args.mode == "translate": # Generate German translation given English source
+ for x in dataset:
+ input_data = json.loads(x["assistant"])
+ output = generate_and_parse(
+ model,
+ processor,
+ translation_prompt(input_data["English"]),
+ args,
+ example_id=x["id"],
+ )
+ output = {"English": input_data["English"], "German": output["Translation"]}
+ print(f"{x['id']}\t{json.dumps(output)}")
+
+ elif args.mode == "with_prefix": # Generate German translation given English caption and image
+ for x in dataset:
+ assistant_output_as_input = json.loads(x["assistant"])
+ output = generate_and_parse(
+ model,
+ processor,
+ captioning_prompt_with_source(x["image"], assistant_output_as_input["English"]),
+ args,
+ example_id=x["id"],
+ )
+ print(f"{x['id']}\t{json.dumps(output)}")
+ else:
+ sys.stderr.write(f"Unkown mode '{args.mode}'")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/setup.sh b/setup.sh
new file mode 100755
index 0000000..f85003b
--- /dev/null
+++ b/setup.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+sudo apt install jq pipx
+pipx install uv
+export PATH=~/.local/bin:$PATH
+uv sync
+export PATH=$(pwd)/.venv/bin:$PATH
+huggingface-cli login
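+
+# Example run afterwards (flags as defined in inference2.py; output is one "<id>\t<json>" line per example):
+#   python inference2.py --mode translate --data-subset test > translations.tsv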