summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rwxr-xr-x inference.py 7
1 files changed, 3 insertions, 4 deletions
diff --git a/inference.py b/inference.py
index c522dd4..9f17c5f 100755
--- a/inference.py
+++ b/inference.py
@@ -90,12 +90,11 @@ def make_inputs_with_prefix(processor,
def generate_and_parse(model,
processor,
messages,
- mode,
args,
example_id):
sys.stderr.write(f"Processing {example_id=}\n")
- if mode == "with_prefix":
+ if args.mode == "with_prefix":
inputs, prefix = make_inputs_with_prefix(processor, messages, model.device)
else:
inputs = make_inputs(processor, messages, model.device)
@@ -117,8 +116,8 @@ def generate_and_parse(model,
output_tokens = generation[0][input_len:]
output_text = clean_str(processor.decode(output_tokens, skip_special_tokens=True))
- if mode == "with_prefix":
- output_text = (prefix + output_text).removeprefix("<bos><start_of_turn>userYou are a professional English-German translator and also a renowned photography critic.<start_of_image>Write a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else.<end_of_turn><start_of_turn>model")
+ if args.mode == "with_prefix":
+ output_text = prefix[prefix.index("{"):] + output_text
try:
return json.loads(output_text)