| author | pks <pks@pks.rocks> | 2025-12-05 22:23:16 +0100 |
|---|---|---|
| committer | pks <pks@pks.rocks> | 2025-12-05 22:23:16 +0100 |
| commit | f243793b76a8ace9b8a690cb02afb0e91a5b0531 (patch) | |
| tree | 7ec30fdd8642a3edfe41ed09ced8271200150909 | |
| parent | c0ed7b3ada7f41faaad9a2a64697d6a0e385ed86 (diff) | |
WIP
| mode | file | lines changed |
|---|---|---|
| -rwxr-xr-x | inference.py | 110 |
| -rw-r--r-- | inference2.py | 179 |
| -rwxr-xr-x | setup.sh | 8 |
3 files changed, 240 insertions, 57 deletions
diff --git a/inference.py b/inference.py
index 69ab33b..c8adb16 100755
--- a/inference.py
+++ b/inference.py
@@ -10,7 +10,7 @@ import torch
 from glob import glob
 from PIL import Image
-from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+from transformers import AutoProcessor, AutoModelForImageTextToText
 
 
 def clean_str(s):
@@ -71,18 +71,19 @@ def make_inputs(processor,
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model", default="google/gemma-3-4b-it")
-    parser.add_argument("--lora-adapter", default=None)
-    parser.add_argument("--mode", choices=["from_scratch", "with_prefix", "translate"])
+    parser.add_argument("--model", default="google/gemma-3-4b-it", type=str)
+    parser.add_argument("--attention-implementation", default="eager", type=str)
+    parser.add_argument("--lora-adapter", default=None, type=str)
+    parser.add_argument("--mode", choices=["from_scratch", "with_prefix", "translate"], type=str, required=True)
     parser.add_argument("--dataset", default="asdf2k/caption_translation", type=str)
-    parser.add_argument("--data-subset", choices=["train", "dev", "test"], default="test")
+    parser.add_argument("--data-subset", choices=["train", "dev", "test"], default="test", type=str)
     args = parser.parse_args()
 
-    model = Gemma3ForConditionalGeneration.from_pretrained(
+    model = AutoModelForImageTextToText.from_pretrained(
         args.model,
         device_map="cuda",
         dtype=torch.bfloat16,
-        attn_implementation="eager",
+        attn_implementation=args.attention_implementation,
     ).eval()
     processor = AutoProcessor.from_pretrained(args.model, use_fast=True)
@@ -94,7 +95,7 @@ def main():
 
     if args.mode == "translate": # Generate German translation given English source
         for x in dataset:
-            sys.stderr.write(f"Processing id={x['id']=}\n")
+            sys.stderr.write(f"Processing id={x['id']}\n")
 
             data = json.loads(x["assistant"])
@@ -104,21 +105,23 @@ def main():
             input_len = inputs["input_ids"].shape[-1]
 
             with torch.inference_mode():
-                generation = model.generate(**inputs,
-                                            max_new_tokens=300,
-                                            do_sample=True,
-                                            top_p=1.0,
-                                            top_k=50)
-                generation = generation[0][input_len:]
-
-            decoded = clean_str(processor.decode(generation, skip_special_tokens=True))
+                output = model.generate(**inputs,
+                                        max_new_tokens=args.max_new_tokens,
+                                        do_sample=args.do_sample,
+                                        temperature=args.temperature,
+                                        top_p=args.top_p,
+                                        top_k=args.top_k,
+                                        disable_compile=True)
+                output = output[0][input_len:]
+
+            output = clean_str(processor.decode(output, skip_special_tokens=True))
+
             try:
-                new_data = json.loads(decoded)
+                output = json.loads(output)
             except:
-                sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
+                sys.stderr.write(f"Error loading JSON from string '{output}' for id={x['id']}\n")
 
-            data.update(new_data)
-            print(json.dumps(data))
+            print(json.dumps(output))
     elif args.mode == "from_scratch": # Generate caption & translation from scratch
         for x in dataset:
@@ -129,57 +132,50 @@ def main():
             input_len = inputs["input_ids"].shape[-1]
 
             with torch.inference_mode():
-                generation = model.generate(**inputs,
-                                            max_new_tokens=300,
-                                            do_sample=True,
-                                            temperature=0.8,
-                                            top_p=1.0,
-                                            top_k=50,
-                                            eos_token_id=stop_token_ids,
-                                            disable_compile=True)
-                generation = generation[0][input_len:]
-
-            decoded = clean_str(processor.decode(generation, skip_special_tokens=True))
+                output = model.generate(**inputs,
+                                        max_new_tokens=args.max_new_tokens,
+                                        do_sample=args.do_sample,
+                                        temperature=args.temperature,
+                                        top_p=args.top_p,
+                                        top_k=args.top_k,
+                                        eos_token_id=stop_token_ids,
+                                        disable_compile=True)
+                output = output[0][input_len:]
+
+            output = clean_str(processor.decode(output, skip_special_tokens=True))
 
             try:
-                _ = json.loads(decoded)
+                output = json.loads(output)
             except:
-                sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
+                sys.stderr.write(f"Error loading JSON from string '{output}' for id={x['id']}\n")
 
-            sys.stderr.write(f"{decoded=}\n")
-            with open(f"{os.path.basename(filename).removesuffix('.jpg')}.jsonl", "w") as f:
-                f.write(f"{decoded}\n")
+            print(json.dumps(output))
     elif args.mode == "with_prefix": # Generate German translation given English caption and image
-        for filename in glob("./baseline/files_test/*.jsonl"):
-            image = "../d/Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
-            sys.stderr.write(f"Processing {filename=}\n")
-            with open(filename, "r+") as f:
-                data = json.loads(f.read())
+        for x in dataset:
+            sys.stderr.write(f"Processing id={x['id']}\n")
+            data = json.loads(x["assistant"])
 
-            prompt = captioning_prompt_with_source(Image.open(image), data["English"])
+            prompt = captioning_prompt_with_source(x["image"], data["English"])
 
             inputs = make_inputs(processor, prompt, model.device)
             input_len = inputs["input_ids"].shape[-1] # Will not cut off assistant prefix
 
             with torch.inference_mode():
-                generation = model.generate(**inputs,
-                                            max_new_tokens=300,
-                                            do_sample=True,
-                                            top_p=1.0,
-                                            top_k=50)
-                generation = generation[0] # batch size 1
-                truncated_generation = generation[input_len:]
-
-            decoded = processor.decode(truncated_generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
+                output = model.generate(**inputs,
+                                        max_new_tokens=args.max_new_tokens,
+                                        do_sample=args.do_sample,
+                                        temperature=args.temperature,
+                                        top_p=args.top_p,
+                                        top_k=args.top_k)
+                output = output[0][input_len:]
+
+            output = clean_str(processor.decode(output, skip_special_tokens=True))
 
             try:
-                _ = json.loads(decoded)
+                output = json.loads(output)
             except:
-                sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
-
-            sys.stderr.write(f"{decoded=}\n")
-            with open(f"{os.path.basename(filename)}", "w") as f:
-                f.write(f"{decoded}\n")
+                sys.stderr.write(f"Error loading JSON from string '{output}' for id={x['id']}\n")
+
+            print(json.dumps(output))
     else:
         sys.stderr.write(f"Unknown mode '{args.mode}'")
diff --git a/inference2.py b/inference2.py
new file mode 100644
index 0000000..67e633a
--- /dev/null
+++ b/inference2.py
@@ -0,0 +1,179 @@
+#!/usr/bin/env python3
+
+import argparse
+import datasets
+import json
+import os
+import requests
+import sys
+import torch
+from glob import glob
+from PIL import Image
+from transformers import AutoProcessor, AutoModelForImageTextToText
+
+
+def clean_str(s):
+    return s.removeprefix("```json").removesuffix("```").replace("\n", "").strip()
+
+
+def captioning_prompt(image):
+    return [
+        {
+            "role": "system",
+            "content": [{"type": "text", "text": "You are a professional English-German translator and also a renowned photography critic."}]
+        },
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "image": image},
+                {"type": "text", "text": "Write a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else."}
+            ]
+        }
+    ]
+
+
+def captioning_prompt_with_source(image, source):
+    prompt = captioning_prompt(image)
+    prefix = json.dumps({"English": source}).removesuffix("}") + ', "German": "'
+    prompt.append({"role": "assistant", "content": [{"type": "text", "text": prefix}]})
+
+    return prompt
+
+
+def translation_prompt(source):
+    return [
+        {
+            "role": "system",
+            "content": [{"type": "text", "text": "You are a professional English-German translator."}]
+        },
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": f"Translate the following caption into German. The output needs to be JSON, the only key being 'Translation' for the translation. Only output the JSON, nothing else. Caption: {source}"}
+            ]
+        }
+    ]
+
+
+def make_inputs(processor,
+                messages,
+                device):
+    return processor.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt"
+    ).to(device, dtype=torch.bfloat16)
+
+
+def generate_and_parse(model,
+                       processor,
+                       messages,
+                       args,
+                       example_id=None):
+    sys.stderr.write(f"Processing {example_id=}\n")
+    inputs = make_inputs(processor, messages, model.device)
+    input_len = inputs["input_ids"].shape[-1]
+
+    stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids("<end_of_turn>")]
+
+    with torch.inference_mode():
+        generation = model.generate(
+            **inputs,
+            max_new_tokens=args.max_new_tokens,
+            do_sample=not args.do_not_sample,
+            temperature=args.temperature,
+            top_p=args.top_p,
+            top_k=args.top_k,
+            eos_token_id=stop_token_ids,
+            disable_compile=True,
+        )
+
+    output_tokens = generation[0][input_len:]
+    output_text = clean_str(processor.decode(output_tokens, skip_special_tokens=True))
+
+    try:
+        return json.loads(output_text)
+    except Exception:
+        if example_id is not None:
+            sys.stderr.write(
+                f"Error loading JSON from string '{output_text}' for id={example_id}\n"
+            )
+        else:
+            sys.stderr.write(
+                f"Error loading JSON from string '{output_text}'\n"
+            )
+        return output_text
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", default="google/gemma-3-4b-it", type=str)
+    parser.add_argument("--attention-implementation", default="eager", type=str)
+    parser.add_argument("--lora-adapter", default=None, type=str)
+    parser.add_argument("--mode", choices=["from_scratch", "with_prefix", "translate"], type=str, required=True)
+    parser.add_argument("--dataset", default="asdf2k/caption_translation", type=str)
+    parser.add_argument("--data-subset", choices=["train", "dev", "test"], default="test", type=str)
+    parser.add_argument("--max-new-tokens", default=300, type=int)
+    parser.add_argument("--top-p", default=1.0, type=float)
+    parser.add_argument("--top-k", default=50, type=int)
+    parser.add_argument("--temperature", default=0.8, type=float)
+    parser.add_argument("--do-not-sample", action="store_true")
+    args = parser.parse_args()
+
+    model = AutoModelForImageTextToText.from_pretrained(
+        args.model,
+        device_map="cuda",
+        dtype=torch.bfloat16,
+        attn_implementation=args.attention_implementation,
+    ).eval()
+    processor = AutoProcessor.from_pretrained(args.model, use_fast=True)
+
+    if args.lora_adapter:
+        from peft import PeftModel
+        model = PeftModel.from_pretrained(model, args.lora_adapter)
+
+    dataset = datasets.load_dataset(args.dataset)[args.data_subset]
+
+    if args.mode == "from_scratch": # Generate caption & translation from scratch
+        for x in dataset:
+            output = generate_and_parse(
+                model,
+                processor,
+                captioning_prompt(x["image"]),
+                args,
+                example_id=x["id"],
+            )
+            print(f"{x['id']}\t{json.dumps(output)}")
+
+    elif args.mode == "translate": # Generate German translation given English source
+        for x in dataset:
+            input_data = json.loads(x["assistant"])
+            output = generate_and_parse(
+                model,
+                processor,
+                translation_prompt(input_data["English"]),
+                args,
+                example_id=x["id"],
+            )
+            output = {"English": input_data["English"], "German": output["Translation"]}
+            print(f"{x['id']}\t{json.dumps(output)}")
+
+    elif args.mode == "with_prefix": # Generate German translation given English caption and image
+        for x in dataset:
+            assistant_output_as_input = json.loads(x["assistant"])
+            output = generate_and_parse(
+                model,
+                processor,
+                captioning_prompt_with_source(x["image"], assistant_output_as_input["English"]),
+                args,
+                example_id=x["id"],
+            )
+            print(f"{x['id']}\t{json.dumps(output)}")
+    else:
+        sys.stderr.write(f"Unknown mode '{args.mode}'")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/setup.sh b/setup.sh
new file mode 100755
index 0000000..f85003b
--- /dev/null
+++ b/setup.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+sudo apt install jq pipx
+pipx install uv
+export PATH=~/.local/bin:$PATH
+uv sync
+export PATH=$(pwd)/.venv/bin:$PATH
+huggingface-cli login
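
A possible invocation of the new script (a sketch, not part of the committed files: --mode, --data-subset, --lora-adapter, --temperature and --max-new-tokens are the flags defined in inference2.py, while the adapter path and output file names are placeholders):

. ./setup.sh   # source it so the PATH exports take effect in the current shell
# caption + translation from scratch on the test split; output is one "id<TAB>json" line per example
python inference2.py --mode from_scratch --data-subset test > from_scratch.tsv
# translation-only mode with a fine-tuned LoRA adapter (placeholder path)
python inference2.py --mode translate --lora-adapter ./my-lora-adapter --temperature 0.8 --max-new-tokens 300 > translate.tsv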
