| author | pks <pks@pks.rocks> | 2025-11-29 23:56:12 +0100 |
|---|---|---|
| committer | pks <pks@pks.rocks> | 2025-11-29 23:56:12 +0100 |
| commit | fc6e8636014dbaf882a2fa5685497dc225ee8e29 (patch) | |
| tree | 0ecd23bd0e8b2ff8aba7b7085657919b91f3bc7c /inference.py | |
| parent | abb0df2560b53ec624d8ecaf8444a679fdadb6b2 (diff) | |
WIP
Diffstat (limited to 'inference.py')
| -rw-r--r-- | inference.py | 103 |
1 file changed, 57 insertions, 46 deletions
````diff
diff --git a/inference.py b/inference.py
index b08a3f6..d20dc55 100644
--- a/inference.py
+++ b/inference.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 
+import argparse
 import json
 import os
 import requests
@@ -63,58 +64,66 @@ def make_inputs(processor,
 
 def main():
-    model_id = "google/gemma-3-4b-it"
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", default="google/gemma-3-4b-it")
+    parser.add_argument("--lora-adapter", default=None)
+    parser.add_argument("--mode", choices=["from_scratch", "with_prefix", "translate"])
+    args = parser.parse_args()
+
     model = Gemma3ForConditionalGeneration.from_pretrained(
-        model_id,
+        args.model,
         device_map="cuda",
         dtype=torch.bfloat16,
     ).eval()
-    processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
 
-    if len(sys.argv) == 2:
+    processor = AutoProcessor.from_pretrained(args.model, use_fast=True)
+
+    if args.lora_adapter:
         from peft import PeftModel
-        model = PeftModel.from_pretrained(model, sys.argv[1])
-
-    if True: # Generate German translation given English source
+        model = PeftModel.from_pretrained(model, args.lora_adapter)
+
+    if args.mode == "translate": # Generate German translation given English source
         for filename in glob("*.jsonl"):
             sys.stderr.write(f"Processing {filename=}\n")
-            with open(filename, "r+") as f:
-                data = json.loads(f.read())
-                f.seek(0)
-
-                inputs = make_inputs(processor,
-                                     translation_prompt(data["English"]),
-                                     model.device)
-                input_len = inputs["input_ids"].shape[-1]
-
-                with torch.inference_mode():
-                    generation = model.generate(**inputs,
-                                                max_new_tokens=300,
-                                                do_sample=True,
-                                                top_p=1.0,
-                                                top_k=50)
-                    generation = generation[0][input_len:]
-
-                decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
-                try:
-                    new_data = json.loads(decoded)
-                except:
-                    sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
-
-                data.update(new_data)
-                f.write(json.dumps(data))
-                f.truncate()
-
-                sys.stderr.write(f"{decoded=}\n")
-    elif 2 == 3: # Generate caption & translation from scratch
-        for filename in glob("../Images/*.jpg"):
-            image = "../Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
+            try:
+                with open(filename, "r+") as f:
+                    data = json.loads(f.read())
+                    f.seek(0)
+
+                    inputs = make_inputs(processor,
+                                         translation_prompt(data["English"]),
+                                         model.device)
+                    input_len = inputs["input_ids"].shape[-1]
+
+                    with torch.inference_mode():
+                        generation = model.generate(**inputs,
+                                                    max_new_tokens=300,
+                                                    do_sample=True,
+                                                    top_p=1.0,
+                                                    top_k=50)
+                        generation = generation[0][input_len:]
+
+                    decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
+                    try:
+                        new_data = json.loads(decoded)
+                    except:
+                        sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
+
+                    data.update(new_data)
+                    f.write(json.dumps(data))
+                    f.truncate()
+
+                    sys.stderr.write(f"{decoded=}\n")
+            except:
+                pass
+    elif args.mode == "from_scratch": # Generate caption & translation from scratch
+        for filename in glob("../d/Images/*.jpg"):
+            image = "../d/Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
             sys.stderr.write(f"Processing {filename=}\n")
 
             inputs = make_inputs(processor,
                                  captioning_prompt(Image.open(filename)),
                                  model.device)
             input_len = inputs["input_ids"].shape[-1]
-            
+
             with torch.inference_mode():
                 generation = model.generate(**inputs,
                                             max_new_tokens=300,
@@ -122,19 +131,19 @@ def main():
                                             top_p=1.0,
                                             top_k=50)
                 generation = generation[0][input_len:]
-            
+
             decoded = processor.decode(generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
             try:
                 _ = json.loads(decoded)
             except:
                 sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
-            
+
             sys.stderr.write(f"{decoded=}\n")
             with open(f"{os.path.basename(filename).removesuffix('.jpg')}.jsonl", "w") as f:
                 f.write(f"{decoded}\n")
-    elif False: # Generate German translation given English caption and image
+    elif args.mode == "with_prefix": # Generate German translation given English caption and image
         for filename in glob("./baseline/files_test/*.jsonl"):
-            image = "../Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
+            image = "../d/Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
             sys.stderr.write(f"Processing {filename=}\n")
             with open(filename, "r+") as f:
                 data = json.loads(f.read())
@@ -144,7 +153,7 @@ def main():
                                  model.device)
 
             input_len = inputs["input_ids"].shape[-1] # Will not cut off assistant prefix
-            
+
             with torch.inference_mode():
                 generation = model.generate(**inputs,
                                             max_new_tokens=300,
@@ -153,16 +162,18 @@ def main():
                                             top_k=50)
             generation = generation[0] # batch size 1
             truncated_generation = generation[input_len:]
-            
+
             decoded = processor.decode(truncated_generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
             try:
                 _ = json.loads(decoded)
             except:
                 sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
-            
+
             sys.stderr.write(f"{decoded=}\n")
             with open(f"{os.path.basename(filename)}", "w") as f:
                 f.write(f"{decoded}\n")
+    else:
+        sys.stderr.write(f"Unknown mode '{args.mode}'\n")
 
 
 if __name__ == "__main__":
````
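A few notes on the patch follow. First, the new CLI surface in isolation: a minimal sketch of the argparse setup and the optional LoRA-adapter loading that replace the old `sys.argv` check and the hard-coded `if True:`/`elif False:` mode switches. The invocation and adapter path in the comments are illustrative, not part of the patch; running this needs a CUDA device plus the transformers and peft packages.

```python
#!/usr/bin/env python3
# Sketch of the CLI wiring this patch introduces, e.g.:
#   python inference.py --mode translate --lora-adapter ./my-adapter
# (invocation and adapter path above are illustrative only)
import argparse

import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

parser = argparse.ArgumentParser()
parser.add_argument("--model", default="google/gemma-3-4b-it")
parser.add_argument("--lora-adapter", default=None)
parser.add_argument("--mode", choices=["from_scratch", "with_prefix", "translate"])
args = parser.parse_args()

# Load the base model once; --mode only selects the generation loop later.
model = Gemma3ForConditionalGeneration.from_pretrained(
    args.model,
    device_map="cuda",
    dtype=torch.bfloat16,
).eval()
processor = AutoProcessor.from_pretrained(args.model, use_fast=True)

if args.lora_adapter:
    # Deferred import: peft is only needed when an adapter is requested.
    from peft import PeftModel
    model = PeftModel.from_pretrained(model, args.lora_adapter)
```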
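Second, the decode chain (strip the Markdown JSON fence, drop newlines, parse) is repeated verbatim in all three modes. A hypothetical helper could fold it together with the parse; the name and the None-on-failure behavior below are mine, not part of the patch:

````python
import json

def parse_fenced_json(decoded: str):
    """Hypothetical helper mirroring the patch's repeated decode chain:
    strip a Markdown ```json fence from the model output, flatten
    newlines, and parse. Returns None on malformed JSON instead of
    leaving the result undefined."""
    text = (decoded.removeprefix("```json")
                   .removesuffix("```")
                   .replace("\n", "")
                   .strip())
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return None
````

A sentinel return would also sidestep a pitfall in the translate branch: when `json.loads` fails there, `new_data` stays unbound, so the following `data.update(new_data)` raises a NameError that the newly added blanket `except: pass` then silently swallows.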
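Third, the translate branch updates each .jsonl file in place via `open(..., "r+")`, `seek(0)`, `write`, `truncate`. The same pattern as a standalone function (helper name is hypothetical):

```python
import json

def merge_into_json_file(path: str, new_data: dict) -> None:
    """Hypothetical standalone version of the translate branch's
    read-modify-write: parse the file, merge in the new keys, rewind,
    overwrite, and truncate so a shorter payload leaves no stale tail."""
    with open(path, "r+") as f:
        data = json.loads(f.read())
        data.update(new_data)
        f.seek(0)      # rewind before overwriting in place
        f.write(json.dumps(data))
        f.truncate()   # drop leftover bytes from the longer old content
```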
