diff options
| -rwxr-xr-x | inference.py | 47 |
1 file changed, 30 insertions, 17 deletions
diff --git a/inference.py b/inference.py index dd5af9a..c522dd4 100755 --- a/inference.py +++ b/inference.py @@ -58,29 +58,47 @@ def translation_prompt(source): def make_inputs(processor, messages, device): - ret = processor.apply_chat_template( + return processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt" + ).to(device, dtype=torch.bfloat16) + + +def make_inputs_with_prefix(processor, + messages, + device): + prefix = processor.apply_chat_template( messages, add_generation_prompt=False, tokenize=False, return_dict=True, return_tensors="pt" - ).to(device, dtype=torch.bfloat16) + ).removesuffix("<end_of_turn>\n") - print() - print(f"{ret=}") - print() - exit() + ret = processor( + text=prefix, + images=messages[1]['content'][0]['image'], # FIXME: That's not great + return_tensors="pt" + ).to(device, dtype=torch.bfloat16) - return ret + return ret, prefix def generate_and_parse(model, processor, messages, + mode, args, - example_id=None): + example_id): sys.stderr.write(f"Processing {example_id=}\n") - inputs = make_inputs(processor, messages, model.device) + + if mode == "with_prefix": + inputs, prefix = make_inputs_with_prefix(processor, messages, model.device) + else: + inputs = make_inputs(processor, messages, model.device) input_len = inputs["input_ids"].shape[-1] stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids("<end_of_turn>")] @@ -99,18 +117,13 @@ def generate_and_parse(model, output_tokens = generation[0][input_len:] output_text = clean_str(processor.decode(output_tokens, skip_special_tokens=True)) + if mode == "with_prefix": + output_text = (prefix + output_text).removeprefix("<bos><start_of_turn>userYou are a professional English-German translator and also a renowned photography critic.<start_of_image>Write a detailed caption for this image in a single sentence. Translate the caption into German. 
The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else.<end_of_turn><start_of_turn>model") try: return json.loads(output_text) except Exception: - if example_id is not None: - sys.stderr.write( - f"Error loading JSON from string '{output_text}' for id={example_id}\n" - ) - else: - sys.stderr.write( - f"Error loading JSON from string '{output_text}'\n" - ) + sys.stderr.write(f"Error loading JSON from string '{output_text}' for id={example_id}\n") return output_text |
