From 07d49c8dcea10ea5c364ee45274809a9440d0b48 Mon Sep 17 00:00:00 2001 From: Florencio Cano Gabarda Date: Fri, 21 Feb 2025 15:26:19 +0100 Subject: [PATCH] Add resume functionality --- garak/_config.py | 2 +- garak/cli.py | 17 +++++++++++++++++ garak/command.py | 22 ++++++++++++++-------- garak/probes/base.py | 5 ++++- 4 files changed, 36 insertions(+), 10 deletions(-) diff --git a/garak/_config.py b/garak/_config.py index 1997c5bcc..ad61b9ac0 100644 --- a/garak/_config.py +++ b/garak/_config.py @@ -28,7 +28,7 @@ from garak import __version__ as version system_params = ( - "verbose narrow_output parallel_requests parallel_attempts skip_unknown".split() + "verbose narrow_output parallel_requests parallel_attempts skip_unknown resume".split() ) run_params = "seed deprefix eval_threshold generations probe_tags interactive".split() plugins_params = "model_type model_name extended_detectors".split() diff --git a/garak/cli.py b/garak/cli.py index e0e37df18..c8736c8dc 100644 --- a/garak/cli.py +++ b/garak/cli.py @@ -107,6 +107,13 @@ def main(arguments=None) -> None: action="store_true", help="allow skip of unknown probes, detectors, or buffs", ) + parser.add_argument( + "--resume", + "-R", + type=str, + default=None, + help="resume previous unfinished scan", + ) ## RUN parser.add_argument( @@ -367,6 +374,16 @@ def main(arguments=None) -> None: if "buffs" in args: _config.plugins.buff_spec = args.buffs + # Parse existing attempts + if _config.system.resume: + import json + _config.system.previous_attempts = [] + with open(_config.system.resume, 'r') as fin: + for line in fin: + attempt_json = json.loads(line.strip()) + if attempt_json['entry_type'] == 'attempt': + _config.system.previous_attempts.append((attempt_json['seq'], attempt_json['prompt'])) + # base config complete if hasattr(_config.run, "seed") and isinstance(_config.run.seed, int): diff --git a/garak/command.py b/garak/command.py index bc9da83a0..4c96bb75f 100644 --- a/garak/command.py +++ b/garak/command.py @@ 
-74,15 +74,21 @@ def start_run(): f"Can't create reporting directory {report_path}, quitting" ) from e - filename = f"garak.{_config.transient.run_id}.report.jsonl" - if not _config.reporting.report_prefix: - filename = f"garak.{_config.transient.run_id}.report.jsonl" + if _config.system.resume: + _config.transient.report_filename = _config.system.resume + _config.transient.reportfile = open( + _config.transient.report_filename, "a", buffering=1, encoding="utf-8" + ) else: - filename = _config.reporting.report_prefix + ".report.jsonl" - _config.transient.report_filename = str(report_path / filename) - _config.transient.reportfile = open( - _config.transient.report_filename, "w", buffering=1, encoding="utf-8" - ) + filename = f"garak.{_config.transient.run_id}.report.jsonl" + if not _config.reporting.report_prefix: + filename = f"garak.{_config.transient.run_id}.report.jsonl" + else: + filename = _config.reporting.report_prefix + ".report.jsonl" + _config.transient.report_filename = str(report_path / filename) + _config.transient.reportfile = open( + _config.transient.report_filename, "w", buffering=1, encoding="utf-8" + ) setup_dict = {"entry_type": "start_run setup"} for k, v in _config.__dict__.items(): if k[:2] != "__" and type(v) in ( diff --git a/garak/probes/base.py b/garak/probes/base.py index b3fbdb025..8fbb7ae77 100644 --- a/garak/probes/base.py +++ b/garak/probes/base.py @@ -209,7 +209,10 @@ def probe(self, generator) -> Iterable[garak.attempt.Attempt]: attempts_todo: Iterable[garak.attempt.Attempt] = [] prompts = list(self.prompts) for seq, prompt in enumerate(prompts): - attempts_todo.append(self._mint_attempt(prompt, seq)) + if hasattr(_config.system, 'previous_attempts') and (seq, prompt) in _config.system.previous_attempts: + continue + else: + attempts_todo.append(self._mint_attempt(prompt, seq)) # buff hook if len(_config.buffmanager.buffs) > 0: