summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorpks <pks@pks.rocks>2025-12-06 22:30:31 +0100
committerpks <pks@pks.rocks>2025-12-06 22:30:31 +0100
commit976f53f314608c63d22128dd607ba3eb6161b3a1 (patch)
tree39aef4150a49d439b6e49841214a02a03c3e1306
parent1769ed698990d805cceba78a4fe887aa2374ff99 (diff)
WIP
-rwxr-xr-xinference.py47
1 files changed, 30 insertions, 17 deletions
diff --git a/inference.py b/inference.py
index dd5af9a..c522dd4 100755
--- a/inference.py
+++ b/inference.py
@@ -58,29 +58,47 @@ def translation_prompt(source):
def make_inputs(processor,
messages,
device):
- ret = processor.apply_chat_template(
+ return processor.apply_chat_template(
+ messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt"
+ ).to(device, dtype=torch.bfloat16)
+
+
+def make_inputs_with_prefix(processor,
+ messages,
+ device):
+ prefix = processor.apply_chat_template(
messages,
add_generation_prompt=False,
tokenize=False,
return_dict=True,
return_tensors="pt"
- ).to(device, dtype=torch.bfloat16)
+ ).removesuffix("<end_of_turn>\n")
- print()
- print(f"{ret=}")
- print()
- exit()
+ ret = processor(
+ text=prefix,
+ images=messages[1]['content'][0]['image'], # FIXME: That's not great
+ return_tensors="pt"
+ ).to(device, dtype=torch.bfloat16)
- return ret
+ return ret, prefix
def generate_and_parse(model,
processor,
messages,
+ mode,
args,
- example_id=None):
+ example_id):
sys.stderr.write(f"Processing {example_id=}\n")
- inputs = make_inputs(processor, messages, model.device)
+
+ if mode == "with_prefix":
+ inputs, prefix = make_inputs_with_prefix(processor, messages, model.device)
+ else:
+ inputs = make_inputs(processor, messages, model.device)
input_len = inputs["input_ids"].shape[-1]
stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids("<end_of_turn>")]
@@ -99,18 +117,13 @@ def generate_and_parse(model,
output_tokens = generation[0][input_len:]
output_text = clean_str(processor.decode(output_tokens, skip_special_tokens=True))
+ if mode == "with_prefix":
+ output_text = (prefix + output_text).removeprefix("<bos><start_of_turn>userYou are a professional English-German translator and also a renowned photography critic.<start_of_image>Write a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else.<end_of_turn><start_of_turn>model")
try:
return json.loads(output_text)
except Exception:
- if example_id is not None:
- sys.stderr.write(
- f"Error loading JSON from string '{output_text}' for id={example_id}\n"
- )
- else:
- sys.stderr.write(
- f"Error loading JSON from string '{output_text}'\n"
- )
+ sys.stderr.write(f"Error loading JSON from string '{output_text}' for id={example_id}\n")
return output_text