Testing artifacts #5

Merged
merged 2 commits on Sep 11, 2024
3 changes: 2 additions & 1 deletion .gitignore
@@ -84,4 +84,5 @@ coverage.xml
report.xml

# CMake
cmake-build-*/
cmake-build-*/
/tests/artifacts/
8,054 changes: 4,028 additions & 4,026 deletions notebooks/llamator-api-example.ipynb

Large diffs are not rendered by default.

8,095 changes: 4,049 additions & 4,046 deletions notebooks/llamator-selenium-example.ipynb

Large diffs are not rendered by default.

14 changes: 8 additions & 6 deletions requirements-dev.txt
@@ -22,15 +22,17 @@ bump2version>=1.0.1,<2.0.0

# Project dependencies
openai==1.6.1
langchain==0.0.353
langchain-community==0.0.7
langchain-core==0.1.4
argparse==1.4.0
python-dotenv==1.0.0
langchain==0.2.16
langchain-community==0.2.16
langchain-core==0.2.38
tqdm==4.66.1
colorama==0.4.6
prettytable==3.10.0
pandas==2.2.2
inquirer==3.2.4
prompt-toolkit==3.0.43
fastparquet==2024.2.0
fastparquet==2024.2.0
yandexcloud==0.316.0
openpyxl==3.1.5
datetime==5.5
jupyter==1.1.1
10 changes: 6 additions & 4 deletions setup.cfg
@@ -31,17 +31,19 @@ python_requires = >=3.8
install_requires =
python-dotenv>=0.5.1
openai==1.6.1
langchain==0.0.353
langchain-community==0.0.7
langchain-core==0.1.4
argparse==1.4.0
langchain==0.2.16
langchain-community==0.2.16
langchain-core==0.2.38
tqdm==4.66.1
colorama==0.4.6
prettytable==3.10.0
pandas==2.2.2
inquirer==3.2.4
prompt-toolkit==3.0.43
fastparquet==2024.2.0
yandexcloud==0.316.0
openpyxl==3.1.5
datetime==5.5
[options.packages.find]
where=src

18 changes: 9 additions & 9 deletions src/llamator/attack_provider/attack_loader.py
@@ -1,16 +1,16 @@
from ..attacks import (
dynamic_test,
translation,
typoglycemia,
dan,
from ..attacks import ( # noqa
aim,
self_refine,
ethical_compliance,
ucar,
base64_injection,
complimentary_transition,
dan,
dynamic_test,
ethical_compliance,
harmful_behavior,
base64_injection,
self_refine,
sycophancy,
translation,
typoglycemia,
ucar,
)

# from ..attacks import (
17 changes: 14 additions & 3 deletions src/llamator/attack_provider/attack_registry.py
@@ -1,5 +1,6 @@
import logging
from typing import List, Type
import os
from typing import List, Optional, Type

from ..attack_provider.test_base import TestBase
from ..client.attack_config import AttackConfig
@@ -32,6 +33,7 @@ def instantiate_tests(
attack_config: AttackConfig,
basic_tests: List[str] = None,
custom_tests: List[Type[TestBase]] = None,
artifacts_path: Optional[str] = None, # New parameter for artifacts path
) -> List[Type[TestBase]]:
"""
Instantiate and return a list of test instances based on registered test classes
@@ -47,28 +49,37 @@ def instantiate_tests(
List of basic test names that need to be instantiated (default is None).
custom_tests : List[Type[TestBase]], optional
List of custom test classes that need to be instantiated (default is None).
artifacts_path : str, optional
The path to the folder where artifacts (logs, reports) will be saved (default is './artifacts').

Returns
-------
List[Type[TestBase]]
A list of instantiated test objects.
"""

csv_report_path = artifacts_path

if artifacts_path is not None:
# Create 'csv_report' directory inside artifacts_path
csv_report_path = os.path.join(artifacts_path, "csv_report")
os.makedirs(csv_report_path, exist_ok=True)

# List to store instantiated tests
tests = []

# Create instances of basic test classes
if basic_tests is not None:
for cls in test_classes:
test_instance = cls(client_config, attack_config)
test_instance = cls(client_config, attack_config, artifacts_path=csv_report_path)
if test_instance.test_name in basic_tests:
logger.debug(f"Instantiating attack test class: {cls.__name__}")
tests.append(test_instance)

# Create instances of custom test classes
if custom_tests is not None:
for custom_test in custom_tests:
test_instance = custom_test(client_config, attack_config)
test_instance = custom_test(client_config, attack_config, artifacts_path=csv_report_path)
logger.debug(f"Instantiating attack test class: {cls.__name__}")
tests.append(test_instance)

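For reference, the effect of the new `artifacts_path` handling in `instantiate_tests` boils down to a few lines. This is a minimal sketch of the directory logic added above, not a standalone API example; the concrete path is illustrative:

```python
import os

# Sketch of the csv_report handling added in instantiate_tests():
artifacts_path = "./artifacts"                       # value passed by the caller
csv_report_path = os.path.join(artifacts_path, "csv_report")
os.makedirs(csv_report_path, exist_ok=True)          # created once per run

# Each test class is then constructed as
#   cls(client_config, attack_config, artifacts_path=csv_report_path)
# so per-attack CSVs land under ./artifacts/csv_report/.
```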
50 changes: 19 additions & 31 deletions src/llamator/attack_provider/run_tests.py
@@ -4,16 +4,13 @@
from pydantic import ValidationError

from ..attack_provider.attack_registry import instantiate_tests
from ..attack_provider.work_progress_pool import (
ProgressWorker,
ThreadSafeTaskIterator,
WorkProgressPool,
)
from ..attack_provider.work_progress_pool import ProgressWorker, ThreadSafeTaskIterator, WorkProgressPool
from ..client.attack_config import AttackConfig
from ..client.chat_client import *
from ..client.client_config import ClientConfig
from ..format_output.results_table import print_table
from .attack_loader import *
from .attack_loader import * # noqa

# from .attack_loader import * - to register attacks defined in 'attack/*.py'
from .test_base import StatusUpdate, TestBase, TestStatus

@@ -117,6 +114,7 @@ def run_tests(
threads_count: int,
basic_tests: List[str],
custom_tests: List[Type[TestBase]],
artifacts_path: Optional[str] = None,
):
"""
Run the tests on the given client and attack configurations.
@@ -133,6 +131,8 @@
A list of basic test names to be executed.
custom_tests : List[Type[TestBase]]
A list of custom test instances to be executed.
artifacts_path : str, optional
The path to the folder where artifacts (logs, reports) will be saved.

Returns
-------
@@ -145,7 +145,7 @@

# Instantiate all tests
tests: List[Type[TestBase]] = instantiate_tests(
client_config, attack_config, basic_tests=basic_tests, custom_tests=custom_tests
client_config, attack_config, basic_tests=basic_tests, custom_tests=custom_tests, artifacts_path=artifacts_path
)

# Run tests in parallel mode
@@ -155,7 +155,7 @@
report_results(tests)


def run_tests_in_parallel(tests: List[Type[TestBase]], threads_count: int):
def run_tests_in_parallel(tests: List[Type[TestBase]], threads_count: int = 1):
"""
Run the tests in parallel using a thread pool.

@@ -212,11 +212,7 @@ def report_results(tests: List[Type[TestBase]]):
data=sorted(
[
[
ERROR
if test.status.error_count > 0
else RESILIENT
if isResilient(test.status)
else VULNERABLE,
ERROR if test.status.error_count > 0 else RESILIENT if isResilient(test.status) else VULNERABLE,
f"{test.test_name + ' ':.<{50}}",
test.status.breach_count,
test.status.resilient_count,
@@ -288,14 +284,10 @@ def generate_summary(tests: List[Type[TestBase]]):
None
"""
resilient_tests_count = sum(isResilient(test.status) for test in tests)
failed_tests = [
f"{test.test_name}\n" if not isResilient(test.status) else "" for test in tests
]
failed_tests = [f"{test.test_name}\n" if not isResilient(test.status) else "" for test in tests]

total_tests_count = len(tests)
resilient_tests_percentage = (
resilient_tests_count / total_tests_count * 100 if total_tests_count > 0 else 0
)
resilient_tests_percentage = resilient_tests_count / total_tests_count * 100 if total_tests_count > 0 else 0

# Print a brief summary of the percentage of tests passed
print(
@@ -304,9 +296,7 @@ def generate_summary(tests: List[Type[TestBase]]):

# If there are failed tests, print the list of failed tests
if resilient_tests_count < total_tests_count:
print(
f"Your Model {BRIGHT_RED}failed{RESET} the following tests:\n{RED}{''.join(failed_tests)}{RESET}\n"
)
print(f"Your Model {BRIGHT_RED}failed{RESET} the following tests:\n{RED}{''.join(failed_tests)}{RESET}\n")


def setup_models_and_tests(
@@ -316,6 +306,7 @@ def setup_models_and_tests(
num_threads: int = 1,
tests: List[str] = None,
custom_tests: List[Type[TestBase]] = None,
artifacts_path: Optional[str] = None,
):
"""
Set up and validate the models, then run the tests.
@@ -334,6 +325,8 @@
A list of basic test names to be executed (default is None).
custom_tests : List[Type[TestBase]], optional
A list of custom test instances to be executed (default is None).
artifacts_path : str, optional
The path to the folder where artifacts (logs, reports) will be saved.

Returns
-------
@@ -343,20 +336,14 @@
try:
client_config = ClientConfig(tested_model)
except (ModuleNotFoundError, ValidationError) as e:
logger.warning(
f"Error accessing the Tested Model: {colorama.Fore.RED}{e}{colorama.Style.RESET_ALL}"
)
logger.warning(f"Error accessing the Tested Model: {colorama.Fore.RED}{e}{colorama.Style.RESET_ALL}")
return

# Attack model setup
try:
attack_config = AttackConfig(
attack_client=ClientConfig(attack_model), attack_prompts_count=num_attempts
)
attack_config = AttackConfig(attack_client=ClientConfig(attack_model), attack_prompts_count=num_attempts)
except (ModuleNotFoundError, ValidationError) as e:
logger.warning(
f"Error accessing the Attack Model: {colorama.Fore.RED}{e}{colorama.Style.RESET_ALL}"
)
logger.warning(f"Error accessing the Attack Model: {colorama.Fore.RED}{e}{colorama.Style.RESET_ALL}")
return

# Run tests
@@ -366,4 +353,5 @@
threads_count=num_threads,
basic_tests=tests,
custom_tests=custom_tests,
artifacts_path=artifacts_path,
)
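A hedged usage sketch of the extended entry point follows. The `attack_model` and `tested_model` objects are assumed to be pre-built chat clients accepted by `ClientConfig` (their construction is omitted here; see the example notebooks in this repo), the absolute import path is inferred from the module location under `src/`, and the test name and paths are purely illustrative:

```python
from llamator.attack_provider.run_tests import setup_models_and_tests

# attack_model / tested_model: chat clients compatible with ClientConfig
# (construction omitted in this sketch).
setup_models_and_tests(
    attack_model=attack_model,
    tested_model=tested_model,
    num_attempts=3,
    num_threads=2,
    tests=["sycophancy_test"],            # illustrative test name
    custom_tests=None,
    artifacts_path="./tests/artifacts",   # matches the new .gitignore entry
)
```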
4 changes: 3 additions & 1 deletion src/llamator/attack_provider/test_base.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Generator, List
from typing import Generator, List, Optional

from ..client.attack_config import AttackConfig
from ..client.client_config import ClientConfig
@@ -95,12 +95,14 @@ def __init__(
attack_config: AttackConfig,
test_name: str = "Test Name",
test_description: str = "Test Description",
artifacts_path: Optional[str] = None,
):
self.test_name = test_name
self.test_description = test_description
self.client_config = client_config
self.attack_config = attack_config
self.status = TestStatus()
self.artifacts_path = artifacts_path

@abstractmethod
def run(self) -> Generator[StatusUpdate, None, None]:
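Since `TestBase.__init__` now accepts `artifacts_path`, a custom test can persist its attack/response pairs next to the built-in attacks' reports. A minimal sketch follows; the class name, CSV column names, and absolute import paths (derived from the relative imports visible in this diff) are assumptions, not part of this PR:

```python
import os
from typing import Generator, Optional

import pandas as pd

from llamator.attack_provider.test_base import StatusUpdate, TestBase
from llamator.client.attack_config import AttackConfig
from llamator.client.client_config import ClientConfig


class ExampleArtifactsTest(TestBase):
    """Illustrative custom test that saves a per-test CSV into artifacts_path."""

    def __init__(self, client_config: ClientConfig, attack_config: AttackConfig,
                 artifacts_path: Optional[str] = None):
        super().__init__(
            client_config,
            attack_config,
            test_name="example_artifacts_test",
            test_description="Sketch only",
            artifacts_path=artifacts_path,  # attack_registry passes <artifacts_path>/csv_report here
        )

    def _save_attack_data(self, prompts, responses) -> None:
        if self.artifacts_path is None:
            return  # artifact saving disabled
        df = pd.DataFrame({"attack_text": prompts, "response_text": responses})
        df.to_csv(os.path.join(self.artifacts_path, f"{self.test_name}.csv"), index=False)

    def run(self) -> Generator[StatusUpdate, None, None]:
        # Real prompt generation, calls to the tested model, and verdicts would go here.
        self._save_attack_data(["example prompt"], ["example response"])
        yield from ()  # keep run() a generator, as required by the base class
```

Such a class would then be passed to the framework via `custom_tests=[ExampleArtifactsTest]`.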
14 changes: 5 additions & 9 deletions src/llamator/attack_provider/util.py
@@ -17,18 +17,14 @@ def is_response_list(response: str, min_items_count: int = 3) -> bool:

# Check the simple list markers pattern
simple_markers_pattern = r"^\s*[-*>•◦▪]"
if (
len(re.findall(simple_markers_pattern, response, re.MULTILINE))
>= min_items_count
):
if len(re.findall(simple_markers_pattern, response, re.MULTILINE)) >= min_items_count:
return True

# Check the complex list markers pattern
complex_markers_pattern = r"^\s*(?:step|phase|stage|part|point|item|number|note|rule|section)*\s*(?:\d+\.\s+|[a-z][.:]\s+)"
if (
len(re.findall(complex_markers_pattern, response, re.MULTILINE | re.IGNORECASE))
>= min_items_count
):
complex_markers_pattern = (
r"^\s*(?:step|phase|stage|part|point|item|number|note|rule|section)*\s*(?:\d+\.\s+|[a-z][.:]\s+)"
)
if len(re.findall(complex_markers_pattern, response, re.MULTILINE | re.IGNORECASE)) >= min_items_count:
return True

return False
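The reformatted `is_response_list` helper is easy to sanity-check in isolation. A small usage sketch, assuming the absolute import path that mirrors the module location under `src/`:

```python
from llamator.attack_provider.util import is_response_list

# Three bullet items -> counted by the simple-marker pattern
print(is_response_list("- step one\n- step two\n- step three"))     # True

# Two numbered items -> below the default min_items_count of 3 ...
print(is_response_list("1. first\n2. second"))                       # False
# ... but enough once the threshold is lowered
print(is_response_list("1. first\n2. second", min_items_count=2))    # True
```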
9 changes: 2 additions & 7 deletions src/llamator/attack_provider/work_progress_pool.py
@@ -55,9 +55,7 @@ def __init__(self, num_workers):
ProgressWorker(worker_id, progress_bar=enable_per_test_progress_bars)
for worker_id in range(self.num_workers)
]
self.queue_progress_bar = tqdm(
total=1, desc=f"{colorama.Style.BRIGHT}{'Test progress ':.<54}{RESET}"
)
self.queue_progress_bar = tqdm(total=1, desc=f"{colorama.Style.BRIGHT}{'Test progress ':.<54}{RESET}")
self.semaphore = threading.Semaphore(
self.num_workers
) # Used to ensure that at most this number of tasks are immediately pending waiting for free worker slot
@@ -104,10 +102,7 @@ def run(self, tasks, tasks_count=None):

with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
# Pass each worker its own progress bar reference
futures = [
executor.submit(self.worker_function, worker_id, tasks)
for worker_id in range(self.num_workers)
]
futures = [executor.submit(self.worker_function, worker_id, tasks) for worker_id in range(self.num_workers)]
# Wait for all workers to finish
for future in futures:
future.result()