diff --git a/inference/inference.py b/inference/inference.py
index 81668e3fb..8f502b178 100644
--- a/inference/inference.py
+++ b/inference/inference.py
@@ -99,11 +99,7 @@ def main(
         print("Skipping the inference as the prompt is not safe.")
         sys.exit(1) # Exit the program with an error status
 
-    if peft_model:
-        model = load_peft_model(model, peft_model)
-
-    model.eval()
-    batch = tokenizer(user_prompt, padding='max_length', truncation=True,max_length=max_padding_length,return_tensors="pt")
+    batch = tokenizer(user_prompt, padding='max_length', truncation=True, max_length=max_padding_length, return_tensors="pt")
     batch = {k: v.to("cuda") for k, v in batch.items()}
 
     start = time.perf_counter()
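
For context, below is a minimal standalone sketch of the tokenization step this hunk keeps, assuming a Hugging Face tokenizer and a CUDA device; the checkpoint name, the example prompt, and the pad-token setup are illustrative assumptions, while user_prompt and max_padding_length stand in for the variables defined elsewhere in main().

    # Sketch only: assumes transformers is installed and a CUDA device is available.
    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed checkpoint
    tokenizer.pad_token = tokenizer.eos_token  # assumption: Llama tokenizers define no pad token by default

    user_prompt = "Explain gradient descent in one paragraph."  # placeholder prompt
    max_padding_length = 512  # placeholder for the script's CLI parameter

    # Pad/truncate to a fixed length and return PyTorch tensors, as in the retained line.
    batch = tokenizer(
        user_prompt,
        padding="max_length",
        truncation=True,
        max_length=max_padding_length,
        return_tensors="pt",
    )
    # Move every tensor in the encoding onto the GPU before generation.
    batch = {k: v.to("cuda") for k, v in batch.items()}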