From d28f19604dff7bd18bdc0bdc44887b9bbbc45623 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=9D=D0=B8=D0=B7=D0=B0=D0=BC=D0=BE=D0=B2=20=D0=A2=D0=B8?= =?UTF-8?q?=D0=BC=D1=83=D1=80=20=D0=94=D0=B0=D0=BC=D0=B8=D1=80=D0=BE=D0=B2?= =?UTF-8?q?=D0=B8=D1=87?= Date: Thu, 12 Sep 2024 14:12:04 +0500 Subject: [PATCH] fix `get_system_prompts_summary` if there are no system prompts --- src/llamator/client/client_config.py | 12 +++++++++++- tests/test_llamator.py | 25 ++++++++++++++++--------- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/src/llamator/client/client_config.py b/src/llamator/client/client_config.py index 6e0484c..ea900b4 100644 --- a/src/llamator/client/client_config.py +++ b/src/llamator/client/client_config.py @@ -33,7 +33,17 @@ def get_model_description(self): return self.target_client.model_description def get_system_prompts_summary(self, attack_client: ClientBase) -> str: + system_prompts = self.get_system_prompts() + model_description = self.get_model_description() + if system_prompts is None: + system_prompts = [] + if model_description is not None: + system_prompts += [model_description] + if len(system_prompts) == 0: + return "" + if len(system_prompts) == 1: + return system_prompts[0] if self.system_prompts_summary is None: # Only compute summary once (lazy, on first call) - self.system_prompts_summary = summarize_system_prompts(attack_client, self.get_system_prompts()) + self.system_prompts_summary = summarize_system_prompts(attack_client, system_prompts) return self.system_prompts_summary diff --git a/tests/test_llamator.py b/tests/test_llamator.py index ce9b862..463e0cd 100644 --- a/tests/test_llamator.py +++ b/tests/test_llamator.py @@ -36,17 +36,20 @@ def test_openaiclient(): ) tests = [ - "aim_jailbreak", - "base64_injection", + # "aim_jailbreak", + # "base64_injection", # "complimentary_transition", # "do_anything_now_jailbreak", + # "RU_do_anything_now_jailbreak", # "ethical_compliance", # "harmful_behavior", - # "self_refine", # "linguistic_evasion", + # "self_refine", + # "RU_self_refine", + # "sycophancy_test", # "typoglycemia_attack", # "ucar", - # "sycophancy_test", + # "RU_ucar", ] config = { @@ -90,21 +93,25 @@ def test_langchainclient_yandegpt(): tested_model = ClientLangChain( backend="yandex_gpt", - system_prompts=["You are a support bot"], + model_description="You are a support bot", model_uri=f"gpt://{folder_ID}/yandexgpt-lite/latest", ) tests = [ - "aim_jailbreak", - "base64_injection", - # "complimentary_transition", + # "aim_jailbreak", + # "base64_injection", + "complimentary_transition", # "do_anything_now_jailbreak", + # "RU_do_anything_now_jailbreak", # "ethical_compliance", # "harmful_behavior", - # "self_refine", # "linguistic_evasion", + # "self_refine", + # "RU_self_refine", + "sycophancy_test", # "typoglycemia_attack", # "ucar", + # "RU_ucar", ] config = {