diff --git a/inference/inference.py b/inference/inference.py
index b5bd73f5d..985ce68d5 100644
--- a/inference/inference.py
+++ b/inference/inference.py
@@ -31,7 +31,7 @@ def main(
     length_penalty: int=1, #[optional] Exponential penalty to the length that is used with beam-based generation.
     enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
-    enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
+    enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
     max_padding_length: int=None, # the max padding length to be used with tokenizer padding the prompts.
     **kwargs
 ):
@@ -59,10 +59,11 @@ def main(
             "pad_token": "<PAD>",
         }
     )
+    model.resize_token_embeddings(model.config.vocab_size + 1)
 
     safety_checker = get_safety_checker(enable_azure_content_safety,
                                         enable_sensitive_topics,
-                                        enable_saleforce_content_safety,
+                                        enable_salesforce_content_safety,
                                         )
 
     # Safety check of the user prompt
@@ -77,7 +78,7 @@ def main(
             if not is_safe:
                 print(method)
                 print(report)
-        print("Skipping the inferece as the prompt is not safe.")
+        print("Skipping the inference as the prompt is not safe.")
         sys.exit(1)  # Exit the program with an error status
 
     if peft_model:
@@ -85,7 +86,6 @@ def main(
     model.eval()
 
     batch = tokenizer(user_prompt, padding='max_length', truncation=True,max_length=max_padding_length,return_tensors="pt")
-    model.resize_token_embeddings(model.config.vocab_size + 1)
     batch = {k: v.to("cuda") for k, v in batch.items()}
     start = time.perf_counter()
     with torch.no_grad():
diff --git a/inference/safety_utils.py b/inference/safety_utils.py
index 9c6d0c361..bc321eb92 100644
--- a/inference/safety_utils.py
+++ b/inference/safety_utils.py
@@ -154,14 +154,14 @@ def __call__(self, output_text):
 # Function to determine which safety checker to use based on the options selected
 def get_safety_checker(enable_azure_content_safety,
                        enable_sensitive_topics,
-                       enable_saleforce_content_safety,
+                       enable_salesforce_content_safety,
                        ):
     safety_checker = []
     if enable_azure_content_safety:
         safety_checker.append(AzureSaftyChecker())
     if enable_sensitive_topics:
         safety_checker.append(AuditNLGSensitiveTopics())
-    if enable_saleforce_content_safety:
+    if enable_salesforce_content_safety:
         safety_checker.append(SalesforceSafetyChecker())
     return safety_checker
 
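
Besides the Salesforce spelling fixes, the functional change above moves `model.resize_token_embeddings(model.config.vocab_size + 1)` from the generation path up to immediately after the pad token is added, so the embedding matrix already matches the enlarged vocabulary before the PEFT adapter is loaded and before the prompt is tokenized. Below is a minimal sketch of that ordering, assuming a stock Hugging Face causal LM; the checkpoint path and prompt are placeholders, not values from the repo.

```python
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

model_path = "path/to/llama-checkpoint"  # placeholder checkpoint path

tokenizer = LlamaTokenizer.from_pretrained(model_path)
model = LlamaForCausalLM.from_pretrained(model_path)

# Add the pad token and grow the embedding matrix right away, before any
# adapter loading or tokenization, mirroring the ordering in the diff.
tokenizer.add_special_tokens({"pad_token": "<PAD>"})
model.resize_token_embeddings(model.config.vocab_size + 1)

model.eval()
batch = tokenizer("placeholder prompt", padding="max_length", truncation=True,
                  max_length=64, return_tensors="pt")
with torch.no_grad():
    output = model.generate(**batch, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

Resizing once, right after the vocabulary changes, keeps the generation loop free of model mutations; an equivalent and arguably sturdier idiom is `model.resize_token_embeddings(len(tokenizer))`, which tracks the tokenizer's size directly.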