diff options
| -rwxr-xr-x | inference.py | 7 |
1 files changed, 3 insertions, 4 deletions
diff --git a/inference.py b/inference.py index c522dd4..9f17c5f 100755 --- a/inference.py +++ b/inference.py @@ -90,12 +90,11 @@ def make_inputs_with_prefix(processor, def generate_and_parse(model, processor, messages, - mode, args, example_id): sys.stderr.write(f"Processing {example_id=}\n") - if mode == "with_prefix": + if args.mode == "with_prefix": inputs, prefix = make_inputs_with_prefix(processor, messages, model.device) else: inputs = make_inputs(processor, messages, model.device) @@ -117,8 +116,8 @@ def generate_and_parse(model, output_tokens = generation[0][input_len:] output_text = clean_str(processor.decode(output_tokens, skip_special_tokens=True)) - if mode == "with_prefix": - output_text = (prefix + output_text).removeprefix("<bos><start_of_turn>userYou are a professional English-German translator and also a renowned photography critic.<start_of_image>Write a detailed caption for this image in a single sentence. Translate the caption into German. The output needs to be JSON, the keys being 'English' and 'German' for the respective captions. Only output the JSON, nothing else.<end_of_turn><start_of_turn>model") + if args.mode == "with_prefix": + output_text = prefix[prefix.index("{"):] + output_text try: return json.loads(output_text) |
