diff --git a/docs/source/configurable.rst b/docs/source/configurable.rst
index f58fda94..90d2d5de 100644
--- a/docs/source/configurable.rst
+++ b/docs/source/configurable.rst
@@ -46,6 +46,8 @@ Let's take a look at the core config.
       parallel_attempts: false
       lite: true
       show_z: false
+      enable_experimental: false
+      max_workers: 500
 
     run:
       seed:
@@ -93,6 +95,7 @@ such as ``show_100_pass_modules``.
 * ``narrow_output`` - Support output on narrower CLIs
 * ``show_z`` - Display Z-scores and visual indicators on CLI. It's good, but may be too much info until one has seen garak run a couple of times
 * ``enable_experimental`` - Enable experimental function CLI flags. Disabled by default. Experimental functions may disrupt your installation and provide unusual/unstable results. Can only be set by editing core config, so a git checkout of garak is recommended for this.
+* ``max_workers`` - Cap on how many parallel workers can be requested. When raising this in order to use higher parallelisation, keep an eye on system resources (e.g. ``ulimit -n 4096`` on Linux)
 
 ``run`` config items
 """"""""""""""""""""
diff --git a/garak/cli.py b/garak/cli.py
index e0e37df1..45311999 100644
--- a/garak/cli.py
+++ b/garak/cli.py
@@ -64,6 +64,17 @@ def main(arguments=None) -> None:
 
     import argparse
 
+    def worker_count_validation(workers):
+        """argparse ``type=`` hook: parse a worker count, requiring 0 < n <= max_workers."""
+        iworkers = int(workers)
+        if iworkers <= 0:
+            raise argparse.ArgumentTypeError("Need >0 workers (int), got %s" % workers)
+        if iworkers > _config.system.max_workers:
+            raise argparse.ArgumentTypeError(
+                "Parallel worker count capped at %s (config.system.max_workers)" % _config.system.max_workers
+            )
+        return iworkers
+
     parser = argparse.ArgumentParser(
         prog="python -m garak",
         description="LLM safety & security scanning tool",
@@ -92,15 +103,15 @@ def main(arguments=None) -> None:
     )
     parser.add_argument(
         "--parallel_requests",
-        type=int,
+        type=worker_count_validation,
         default=_config.system.parallel_requests,
         help="How many generator requests to launch in parallel for a given prompt. Ignored for models that support multiple generations per call.",
     )
     parser.add_argument(
         "--parallel_attempts",
-        type=int,
+        type=worker_count_validation,
         default=_config.system.parallel_attempts,
-        help="How many probe attempts to launch in parallel.",
+        help="How many probe attempts to launch in parallel. Raise this for faster runs when using non-local models.",
     )
     parser.add_argument(
         "--skip_unknown",
@@ -484,7 +495,9 @@ def main(arguments=None) -> None:
                 if has_changes:
                     exit(1)  # exit with error code to denote changes
                 else:
-                    print("No revisions applied. Please verify options provided for `--fix`")
+                    print(
+                        "No revisions applied. Please verify options provided for `--fix`"
+                    )
         elif args.report:
             from garak.report import Report
 
diff --git a/garak/generators/base.py b/garak/generators/base.py
index e09d4f30..fd2c269c 100644
--- a/garak/generators/base.py
+++ b/garak/generators/base.py
@@ -12,6 +12,7 @@
 
 from garak import _config
 from garak.configurable import Configurable
+from garak.exception import GarakException
 import garak.resources.theme
 
 
@@ -162,13 +163,27 @@ def generate(
                 )
                 multi_generator_bar.set_description(self.fullname[:55])
 
-                with Pool(_config.system.parallel_requests) as pool:
-                    for result in pool.imap_unordered(
-                        self._call_model, [prompt] * generations_this_call
-                    ):
-                        self._verify_model_result(result)
-                        outputs.append(result[0])
-                        multi_generator_bar.update(1)
+                pool_size = min(  # no more workers than jobs, never above the configured cap
+                    generations_this_call,
+                    _config.system.parallel_requests,
+                    _config.system.max_workers,
+                )
+
+                try:
+                    with Pool(pool_size) as pool:
+                        for result in pool.imap_unordered(
+                            self._call_model, [prompt] * generations_this_call
+                        ):
+                            self._verify_model_result(result)
+                            outputs.append(result[0])
+                            multi_generator_bar.update(1)
+                except OSError as o:
+                    if o.errno == 24:  # EMFILE: too many open files
+                        msg = "Parallelisation limit hit. Try reducing parallel_requests or raising limit (e.g. ulimit -n 4096)"
+                        logging.critical(msg)
+                        raise GarakException(msg) from o
+                    else:
+                        raise
 
             else:
                 generation_iterator = tqdm.tqdm(
diff --git a/garak/probes/base.py b/garak/probes/base.py
index b3fbdb02..dbf880f0 100644
--- a/garak/probes/base.py
+++ b/garak/probes/base.py
@@ -17,7 +17,7 @@
 
 from garak import _config
 from garak.configurable import Configurable
-from garak.exception import PluginConfigurationError
+from garak.exception import GarakException
 import garak.attempt
 import garak.resources.theme
 
@@ -178,17 +178,31 @@ def _execute_all(self, attempts) -> Iterable[garak.attempt.Attempt]:
             attempt_bar = tqdm.tqdm(total=len(attempts), leave=False)
             attempt_bar.set_description(self.probename.replace("garak.", ""))
 
-            with Pool(_config.system.parallel_attempts) as attempt_pool:
-                for result in attempt_pool.imap_unordered(
-                    self._execute_attempt, attempts
-                ):
-                    _config.transient.reportfile.write(
-                        json.dumps(result.as_dict()) + "\n"
-                    )
-                    attempts_completed.append(
-                        result
-                    )  # these will be out of original order
-                    attempt_bar.update(1)
+            pool_size = min(  # no more workers than attempts, never above the configured cap
+                len(attempts),
+                _config.system.parallel_attempts,
+                _config.system.max_workers,
+            )
+
+            try:
+                with Pool(pool_size) as attempt_pool:
+                    for result in attempt_pool.imap_unordered(
+                        self._execute_attempt, attempts
+                    ):
+                        _config.transient.reportfile.write(
+                            json.dumps(result.as_dict()) + "\n"
+                        )
+                        attempts_completed.append(
+                            result
+                        )  # these will be out of original order
+                        attempt_bar.update(1)
+            except OSError as o:
+                if o.errno == 24:  # EMFILE: too many open files
+                    msg = "Parallelisation limit hit. Try reducing parallel_attempts or raising limit (e.g. ulimit -n 4096)"
+                    logging.critical(msg)
+                    raise GarakException(msg) from o
+                else:
+                    raise
 
         else:
             attempt_iterator = tqdm.tqdm(attempts, leave=False)
diff --git a/garak/resources/garak.core.yaml b/garak/resources/garak.core.yaml
index 72f7caa8..00e7e9c3 100644
--- a/garak/resources/garak.core.yaml
+++ b/garak/resources/garak.core.yaml
@@ -7,6 +7,7 @@ system:
   lite: true
   show_z: false
   enable_experimental: false
+  max_workers: 500
 
 run:
   seed: