Skip to content

Commit

Permalink
clean up and typo fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
HamidShojanazeri committed Aug 7, 2023
1 parent 9b0eae4 commit 754b5d2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def main(
length_penalty: int=1, #[optional] Exponential penalty to the length that is used with beam-based generation.
enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
max_padding_length: int=None, # the max padding length to be used with tokenizer padding the prompts.
**kwargs
):
Expand Down Expand Up @@ -59,10 +59,11 @@ def main(
"pad_token": "<PAD>",
}
)
model.resize_token_embeddings(model.config.vocab_size + 1)

safety_checker = get_safety_checker(enable_azure_content_safety,
enable_sensitive_topics,
enable_saleforce_content_safety,
enable_salesforce_content_safety,
)

# Safety check of the user prompt
Expand All @@ -77,15 +78,14 @@ def main(
if not is_safe:
print(method)
print(report)
print("Skipping the inferece as the prompt is not safe.")
print("Skipping the inference as the prompt is not safe.")
sys.exit(1) # Exit the program with an error status

if peft_model:
model = load_peft_model(model, peft_model)

model.eval()
batch = tokenizer(user_prompt, padding='max_length', truncation=True,max_length=max_padding_length,return_tensors="pt")
model.resize_token_embeddings(model.config.vocab_size + 1)
batch = {k: v.to("cuda") for k, v in batch.items()}
start = time.perf_counter()
with torch.no_grad():
Expand Down
4 changes: 2 additions & 2 deletions inference/safety_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,14 @@ def __call__(self, output_text):
# Function to determine which safety checker to use based on the options selected
def get_safety_checker(enable_azure_content_safety,
enable_sensitive_topics,
enable_saleforce_content_safety,
enable_salesforce_content_safety,
):
safety_checker = []
if enable_azure_content_safety:
safety_checker.append(AzureSaftyChecker())
if enable_sensitive_topics:
safety_checker.append(AuditNLGSensitiveTopics())
if enable_saleforce_content_safety:
if enable_salesforce_content_safety:
safety_checker.append(SalesforceSafetyChecker())
return safety_checker

Expand Down

0 comments on commit 754b5d2

Please sign in to comment.