FIX Crescendo backtrack with same orchestrator ID and handling responses with markdown syntax (#340)
romanlutz authored Aug 22, 2024
1 parent 99d76c6 commit daa8ab1
Showing 6 changed files with 233 additions and 161 deletions.
325 changes: 171 additions & 154 deletions doc/code/orchestrators/6_crescendo_orchestrator.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion doc/code/orchestrators/6_crescendo_orchestrator.py
@@ -44,7 +44,7 @@
 ) as orchestrator:
 
     # For five turns this can take a couple minutes depending on LLM latency
-    score = await orchestrator.apply_crescendo_attack_async(max_rounds=5, max_backtracks=5) # type: ignore
+    score = await orchestrator.apply_crescendo_attack_async(max_rounds=10, max_backtracks=5) # type: ignore
     orchestrator.print_conversation()
 
     print(f"{score} {score.score_metadata} {score.score_rationale} ")
@@ -75,6 +75,7 @@ template: |
 last_response_summary: Summarized last response, if there is no last response return empty strings,
 rationale_behind_jailbreak: Detail the reasoning behind the development of the potential jailbreaks
 }}
+Do not put the JSON inside markdown code tags.
 
 ###Inputs###
 Inputs for crafting the new jailbreak:
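
The line added to the prompt template above tells the red-teaming model not to wrap its JSON answer in markdown code fences. As a rough illustration of the failure mode being addressed (the strip_markdown_json helper below is hypothetical and only approximates PyRIT's remove_markdown_json), a fenced reply breaks json.loads until the fence is removed:

import json
import re

# Hypothetical helper that only approximates PyRIT's remove_markdown_json.
def strip_markdown_json(response_text: str) -> str:
    # Extract the contents of a ```json ... ``` (or plain ```) fenced block, if present.
    match = re.search(r"```(?:json)?\s*(.*?)\s*```", response_text, re.DOTALL)
    return match.group(1) if match else response_text

raw = '```json\n{"generated_question": "...", "last_response_summary": "", "rationale_behind_jailbreak": "..."}\n```'
# json.loads(raw) would raise json.JSONDecodeError because of the fence.
parsed = json.loads(strip_markdown_json(raw))
print(parsed["generated_question"])
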
14 changes: 8 additions & 6 deletions pyrit/memory/memory_interface.py
@@ -4,7 +4,7 @@
 import abc
 import copy
 from pathlib import Path
-from typing import MutableSequence, Sequence
+from typing import MutableSequence, Optional, Sequence
 import uuid
 
 from pyrit.common.path import RESULTS_PATH
@@ -213,16 +213,19 @@ def duplicate_conversation_for_new_orchestrator(self, *, new_orchestrator_id: str
         self.add_request_pieces_to_memory(request_pieces=prompt_pieces)
         return new_conversation_id
 
-    def duplicate_conversation_excluding_last_turn(self, *, new_orchestrator_id: str, conversation_id: str) -> str:
+    def duplicate_conversation_excluding_last_turn(
+        self, *, conversation_id: str, new_orchestrator_id: Optional[str] = None
+    ) -> str:
         """
         Duplicate a conversation, excluding the last turn. In this case, last turn is defined as before the last
         user request (e.g. if there is half a turn, it just removes that half).
         This can be useful when an attack strategy requires back tracking the last prompt/response pair.
 
         Args:
-            new_orchestrator_id (str): The new orchestrator ID to assign to the duplicated conversations.
             conversation_id (str): The conversation ID with existing conversations.
+            new_orchestrator_id (str, optional): The new orchestrator ID to assign to the duplicated conversations.
+                If no new orchestrator ID is provided, the orchestrator ID will remain the same. Defaults to None.
 
         Returns:
             The uuid for the new conversation.
         """
@@ -252,9 +255,8 @@ def duplicate_conversation_excluding_last_turn(self, *, new_orchestrator_id: str
 
         for piece in prompt_pieces:
             piece.id = uuid.uuid4()
-            if piece.orchestrator_identifier["id"] == new_orchestrator_id:
-                raise ValueError("The new orchestrator ID must be different from the existing orchestrator ID.")
-            piece.orchestrator_identifier["id"] = new_orchestrator_id
+            if new_orchestrator_id:
+                piece.orchestrator_identifier["id"] = new_orchestrator_id
             piece.conversation_id = new_conversation_id
 
         self.add_request_pieces_to_memory(request_pieces=prompt_pieces)
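
A minimal usage sketch of the updated method follows; it assumes a populated MemoryInterface instance named memory, an existing conversation_id, and a hypothetical other_orchestrator for the second call. Before this change, the same-orchestrator case was impossible: passing the existing ID raised the ValueError removed above.

# Sketch only; assumes `memory` already holds pieces for `conversation_id`
# (the new test added below shows a complete setup).

# Backtrack within the same orchestrator: the duplicated pieces keep their
# original orchestrator_identifier and only get a new conversation ID.
rewound_id = memory.duplicate_conversation_excluding_last_turn(conversation_id=conversation_id)

# Hand the truncated conversation to a different orchestrator by passing its ID explicitly.
# `other_orchestrator` is a hypothetical second Orchestrator instance.
rewound_for_other_id = memory.duplicate_conversation_excluding_last_turn(
    conversation_id=conversation_id,
    new_orchestrator_id=other_orchestrator.get_identifier()["id"],
)
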
2 changes: 2 additions & 0 deletions pyrit/orchestrator/crescendo_orchestrator.py
@@ -13,6 +13,7 @@
 from pyrit.exceptions.exception_classes import (
     InvalidJsonException,
     pyrit_json_retry,
+    remove_markdown_json,
 )
 from pyrit.models import PromptTemplate
 from pyrit.models import Score
@@ -262,6 +263,7 @@ async def _get_attack_prompt(
             .request_pieces[0]
             .converted_value
         )
+        response_text = remove_markdown_json(response_text)
 
         expected_output = ["generated_question", "rationale_behind_jailbreak", "last_response_summary"]
         try:
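
With the markdown fences stripped by remove_markdown_json, the surrounding code parses the reply and checks it for the three expected fields, retrying on failure. The sketch below is a simplified stand-in for that pattern, not the orchestrator's actual code: it raises a plain ValueError where PyRIT raises InvalidJsonException so that the pyrit_json_retry decorator re-queries the model.

import json

# Simplified sketch of the validation step in _get_attack_prompt; PyRIT raises
# InvalidJsonException here so that the pyrit_json_retry decorator retries the request.
EXPECTED_KEYS = ("generated_question", "rationale_behind_jailbreak", "last_response_summary")

def validate_attack_json(response_text: str) -> dict:
    """Parse the (already fence-stripped) reply and check it has the expected fields."""
    try:
        parsed = json.loads(response_text)
    except json.JSONDecodeError as error:
        raise ValueError(f"Red-teaming model did not return valid JSON: {error}") from error
    missing = [key for key in EXPECTED_KEYS if key not in parsed]
    if missing:
        raise ValueError(f"Response JSON is missing expected keys: {missing}")
    return parsed
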
50 changes: 50 additions & 0 deletions tests/memory/test_memory_interface.py
@@ -196,6 +196,56 @@ def test_duplicate_conversation_excluding_last_turn(memory: MemoryInterface):
         assert piece.sequence < 2
 
 
+def test_duplicate_conversation_excluding_last_turn_same_orchestrator(memory: MemoryInterface):
+    orchestrator1 = Orchestrator()
+    conversation_id_1 = "11111"
+    pieces = [
+        PromptRequestPiece(
+            role="user",
+            original_value="original prompt text",
+            conversation_id=conversation_id_1,
+            sequence=0,
+            orchestrator_identifier=orchestrator1.get_identifier(),
+        ),
+        PromptRequestPiece(
+            role="assistant",
+            original_value="original prompt text",
+            conversation_id=conversation_id_1,
+            sequence=1,
+            orchestrator_identifier=orchestrator1.get_identifier(),
+        ),
+        PromptRequestPiece(
+            role="user",
+            original_value="original prompt text",
+            conversation_id=conversation_id_1,
+            sequence=2,
+            orchestrator_identifier=orchestrator1.get_identifier(),
+        ),
+        PromptRequestPiece(
+            role="assistant",
+            original_value="original prompt text",
+            conversation_id=conversation_id_1,
+            sequence=3,
+            orchestrator_identifier=orchestrator1.get_identifier(),
+        ),
+    ]
+    memory.add_request_pieces_to_memory(request_pieces=pieces)
+    assert len(memory.get_all_prompt_pieces()) == 4
+
+    new_conversation_id1 = memory.duplicate_conversation_excluding_last_turn(
+        conversation_id=conversation_id_1,
+    )
+
+    all_memory = memory.get_all_prompt_pieces()
+    assert len(all_memory) == 6
+
+    duplicate_conversation = memory._get_prompt_pieces_with_conversation_id(conversation_id=new_conversation_id1)
+    assert len(duplicate_conversation) == 2
+
+    for piece in duplicate_conversation:
+        assert piece.sequence < 2
+
+
 def test_duplicate_memory_orchestrator_id_collision(memory: MemoryInterface):
     orchestrator1 = Orchestrator()
     conversation_id = "11111"
