| author | pks <pks@pks.rocks> | 2025-11-21 23:02:16 +0100 |
|---|---|---|
| committer | pks <pks@pks.rocks> | 2025-11-21 23:02:16 +0100 |
| commit | 1cc209768ad0299c396af56231c43fded7dc3515 (patch) | |
| tree | bc7fc3dcfd094078e052807f4e336dc2a5e0a43b /inference.py | |
| parent | ea7a98c4f6fc0d0df785ef601e9fc30b78e51335 (diff) | |
WIP
Diffstat (limited to 'inference.py')
| -rw-r--r-- | inference.py | 46 |
1 file changed, 43 insertions, 3 deletions
````diff
diff --git a/inference.py b/inference.py
index fc8c1cc..b08a3f6 100644
--- a/inference.py
+++ b/inference.py
@@ -26,6 +26,15 @@ def captioning_prompt(image):
         }
     ]
 
+
+def captioning_prompt_with_source(image, source):
+    caption = captioning_prompt(image)
+    prefix = json.dumps({"English": source}).removesuffix("}") + ', "German": "'
+    caption.append({"role": "assistant", "content": [{"type": "text", "text": prefix}]})
+
+    return caption
+
+
 def translation_prompt(source):
     return [
         {
@@ -66,7 +75,7 @@ def main():
         from peft import PeftModel
         model = PeftModel.from_pretrained(model, sys.argv[1])
 
-    if True:
+    if True: # Generate German translation given English source
         for filename in glob("*.jsonl"):
             sys.stderr.write(f"Processing {filename=}\n")
             with open(filename, "r+") as f:
@@ -97,13 +106,13 @@ def main():
                 f.truncate()
 
                 sys.stderr.write(f"{decoded=}\n")
 
-    else:
+    elif 2 == 3: # Generate caption & translation from scratch
         for filename in glob("../Images/*.jpg"):
+            image = "../Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
             sys.stderr.write(f"Processing {filename=}\n")
 
             inputs = make_inputs(processor, captioning_prompt(Image.open(filename)), model.device)
-
             input_len = inputs["input_ids"].shape[-1]
 
             with torch.inference_mode():
@@ -123,6 +132,37 @@ def main():
             sys.stderr.write(f"{decoded=}\n")
             with open(f"{os.path.basename(filename).removesuffix('.jpg')}.jsonl", "w") as f:
                 f.write(f"{decoded}\n")
+    elif False: # Generate German translation given English caption and image
+        for filename in glob("./baseline/files_test/*.jsonl"):
+            image = "../Images/" + os.path.basename(filename).removesuffix(".jsonl") + ".jpg"
+            sys.stderr.write(f"Processing {filename=}\n")
+            with open(filename, "r+") as f:
+                data = json.loads(f.read())
+                prompt = captioning_prompt_with_source(Image.open(image), data["English"])
+                inputs = make_inputs(processor,
+                                     prompt,
+                                     model.device)
+
+                input_len = inputs["input_ids"].shape[-1] # Will not cut off assistant prefix
+
+                with torch.inference_mode():
+                    generation = model.generate(**inputs,
+                                                max_new_tokens=300,
+                                                do_sample=True,
+                                                top_p=1.0,
+                                                top_k=50)
+                    generation = generation[0] # batch size 1
+                    truncated_generation = generation[input_len:]
+
+                decoded = processor.decode(truncated_generation, skip_special_tokens=True).removeprefix("```json").removesuffix("```").replace("\n", "").strip()
+                try:
+                    _ = json.loads(decoded)
+                except:
+                    sys.stderr.write(f"Error loading JSON from string '{decoded}' for {filename=}\n")
+
+                sys.stderr.write(f"{decoded=}\n")
+                with open(f"{os.path.basename(filename)}", "w") as f:
+                    f.write(f"{decoded}\n")
 
 
 if __name__ == "__main__":
````
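The notable addition is `captioning_prompt_with_source`, which seeds the chat with a partially written assistant turn: the English source is serialized as the start of a JSON object with a dangling `"German": "` key, so the model only has to complete the German value and close the object. A minimal sketch of that prefix construction in isolation (the example strings are made up; only the `json.dumps`/`removesuffix` mechanics come from the diff):

```python
import json

def assistant_prefix(source: str) -> str:
    # Serialize the English source, drop the closing brace, and leave a
    # dangling '"German": "' key for the model to complete.
    return json.dumps({"English": source}).removesuffix("}") + ', "German": "'

prefix = assistant_prefix("A dog runs across the field.")
print(prefix)
# {"English": "A dog runs across the field.", "German": "

# A continuation such as the one below closes the value and the object,
# so prefix + continuation parses as a single valid JSON document:
continuation = 'Ein Hund rennt über das Feld."}'
print(json.loads(prefix + continuation))
# {'English': 'A dog runs across the field.', 'German': 'Ein Hund rennt über das Feld.'}
```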

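The new branch also has to cope with models that wrap their JSON answer in a markdown fence, which is what the `removeprefix`/`removesuffix`/`replace` chain on the decoded string handles. A standalone sketch of that cleanup (the helper name `parse_model_json` is hypothetical, and it swaps the WIP code's bare `except` for `json.JSONDecodeError`):

````python
import json
import sys

def parse_model_json(decoded: str):
    # Mirror the cleanup chain from the diff: strip a leading "```json"
    # fence and a trailing "```", drop newlines, then parse.
    cleaned = (decoded.removeprefix("```json")
                      .removesuffix("```")
                      .replace("\n", "")
                      .strip())
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        sys.stderr.write(f"Error loading JSON from string '{cleaned}'\n")
        return None

print(parse_model_json('```json\n{"English": "hi", "German": "hallo"}\n```'))
# {'English': 'hi', 'German': 'hallo'}
````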