diff --git a/diffusion/inference/inference_model.py b/diffusion/inference/inference_model.py index 3d1ad50d..9259b2f5 100644 --- a/diffusion/inference/inference_model.py +++ b/diffusion/inference/inference_model.py @@ -70,10 +70,10 @@ def predict(self, model_requests: List[Dict[str, Any]]): if isinstance(inputs, str): prompts.append(inputs) elif isinstance(inputs, Dict): - if 'prompt' not in req: + if 'prompt' not in inputs: raise RuntimeError('"prompt" must be provided to generate call if using a dict as input') prompts.append(inputs['prompt']) - if 'negative_prompt' in req: + if 'negative_prompt' in inputs: negative_prompts.append(inputs['negative_prompt']) else: raise RuntimeError(f'Input must be of type string or dict, but it is type: {type(inputs)}')