diff --git a/invariant/analyzer/runtime/utils/image.py b/invariant/analyzer/runtime/utils/image.py new file mode 100644 index 0000000..63a3d3e --- /dev/null +++ b/invariant/analyzer/runtime/utils/image.py @@ -0,0 +1,212 @@ +from anthropic import Anthropic +from typing import Tuple, Literal, List +import base64 +import io +from PIL import Image, ImageDraw + + +def add_rectangles_to_image( + image: str, + coordinates: list[tuple[int, int]], + color: str = "red", + border_width: int = 3, + box_width: int = 10, + image_format: Literal["jpeg", "png"] = "png", +) -> str: + """Add multiple rectangles to a base64 encoded image. + + Args: + image: The base64 encoded image. + coordinates: The coordinates of the rectangles. Between 0 and 1. + + Returns: + The base64 encoded image with the rectangles added. + """ + image = Image.open(io.BytesIO(base64.b64decode(image))) + draw = ImageDraw.Draw(image) + + for coordinate in coordinates: + draw.rectangle( + ( + coordinate[0] * image.width - box_width // 2, + coordinate[1] * image.height - box_width // 2, + coordinate[0] * image.width + box_width // 2, + coordinate[1] * image.height + box_width // 2, + ), + outline=color, + width=border_width, + ) + + # Convert the image back to a base64 string + buffered = io.BytesIO() + image.save(buffered, format=image_format) + base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8") + + return base64_image + + +CLICK_AVAILABLE_SYSTEM_PROMPT = """ +You are an advanced image analysis model tasked with evaluating whether a click on a user interface element violates a given policy. + +### Task: +You will be provided with an image showing a user interface and a policy. In the image, there will be a red rectangle that marks the location of a user click. You must analyze the image and the policy to determine if the clicked element (indicated by the red rectangle) violates the policy. + +### Key Points: +- Focus only on the red rectangle, which represents the location of the click. This rectangle is not necessarily surrounding an element but indicates where the click occurred. +- Use the policy provided and the context from the image to decide whether the click violates the policy. + +### Policy Format: +The policy will be provided as a string enclosed in and tags. + +### Output Format: +Your response should always be structured in JSON with the following keys: +- `allowed`: # True if the click is allowed, False if the click violates the policy. +- `reason`: # Explanation for the decision (allowed or not allowed). + +### Example Input/Output: +INPUT: +{ + "image": IMAGE, + "policy": " Never click `Buy` buttons. " +} +RESPONSE: +{ + "allowed": true, + "reason": "This click is allowed as it does not click a `Buy` button." +} + +INPUT: +{ + "image": IMAGE, + "policy": " Never click `Operations` buttons. " +} +RESPONSE: +{ + "allowed": false, + "reason": "This click is not allowed as it clicks the `Operations` button." +} +""" + +def format_image( + base64_image: str, + prompt: str | None = None, + image_format: Literal["jpeg", "png"] = "png", +): + """Given a base64 encoded image, return a message that can be sent to the LLM. + + Args: + base64_image: The base64 encoded image. + prompt: The prompt to add to the image. + image_format: The format of the image. + + Returns: + A message that can be sent to the LLM. + """ + content = [] + + # Add the image to the content + content.append( + { + "type": "image", + "source": { + "type": "base64", + "data": base64_image, + "media_type": f"image/{image_format}", + }, + } + ) + + # If the prompt is set, add it to the content + if prompt is not None: + content.append({"type": "text", "text": prompt}) + + return {"role": "user", "content": content} + + +def format_policy(policy: str) -> dict: + """Given a policy, return a message that can be sent to the LLM. + + Args: + policy: The policy to format. + + Returns: + A message that can be sent to the LLM. + """ + return {"role": "user", "content": f' {policy} '} + + +class ClaudeModel: + def __init__( + self, + api_key: str, + model_name: str = "claude-3-5-sonnet-20240620", + max_tokens: int = 1000, + system_prompt: str = CLICK_AVAILABLE_SYSTEM_PROMPT, + image_format: Literal["jpeg", "png"] = "png", + bbox_config: dict | None = None, + ): + """ + Args: + api_key: The API key for the Anthropic API. + model_name: The name of the model to use. + max_tokens: The maximum number of tokens to generate. + system_prompt: The system prompt to use. + """ + self.api_key = api_key + self.model_name = model_name + self.max_tokens = max_tokens + self.system_prompt = system_prompt + self.image_format = image_format + + + self.client = Anthropic(api_key=self.api_key) + self.bbox_config = bbox_config or { + "color": "red", + "border_width": 5, + "box_width": 75, + } + + def _format_request(self, image: Image.Image, policy: str) -> List[dict]: + image_message = format_image(image, image_format=self.image_format) + policy_message = format_policy(policy) + return [image_message, policy_message] + + def _get_response(self, messages: List[dict]) -> str: + """Get a response from the model using the stable API. + + Args: + messages: The messages to send to the model. + + Returns: + The response from the model. + """ + response = self.client.messages.create( + model=self.model_name, + max_tokens=self.max_tokens, + system=self.system_prompt, + messages=messages, + ) + + return response.content[0].text + + def evaluate(self, image: str, policy: str, click_coordinates: Tuple[float, float]) -> str: + """Evaluate whether a click is allowed. + + Args: + image: Base64 encoded image. + policy: The policy to evaluate. + click_coordinates: The coordinates of the click. Values are between 0 and 1. + + Returns: + json object with fields according to prompt + """ + assert 0 <= click_coordinates[0] <= 1 and 0 <= click_coordinates[1] <= 1, "Click coordinates must be between 0 and 1" + print(image[:10], len(image), click_coordinates) + image = add_rectangles_to_image( + image, + [click_coordinates], + image_format=self.image_format, + **self.bbox_config, + ) + messages = self._format_request(image, policy) + return self._get_response(messages) diff --git a/invariant/analyzer/stdlib/invariant/detectors/__init__.py b/invariant/analyzer/stdlib/invariant/detectors/__init__.py index 0e8c6e1..d2db455 100644 --- a/invariant/analyzer/stdlib/invariant/detectors/__init__.py +++ b/invariant/analyzer/stdlib/invariant/detectors/__init__.py @@ -4,4 +4,5 @@ from invariant.analyzer.stdlib.invariant.detectors.code import * from invariant.analyzer.stdlib.invariant.detectors.pii import * from invariant.analyzer.stdlib.invariant.detectors.copyright import * +from invariant.analyzer.stdlib.invariant.detectors.image import image_policy_violations diff --git a/invariant/analyzer/stdlib/invariant/detectors/image.py b/invariant/analyzer/stdlib/invariant/detectors/image.py new file mode 100644 index 0000000..ae1dcaa --- /dev/null +++ b/invariant/analyzer/stdlib/invariant/detectors/image.py @@ -0,0 +1,79 @@ +from invariant.analyzer.runtime.functions import cache +from invariant.analyzer.runtime.utils.image import ClaudeModel +from typing import Dict, Tuple, Optional +import json +import os + + +# Global instance of the ClaudeModel +IMAGE_POLICY_MODEL = None + +@cache +def image_policy_violations( + data: str, + policy: str, + click_coordinates: Tuple[float, float], + api_key: Optional[str] = None, + model_name: str = "claude-3-7-sonnet-20250219", + **config: Dict +) -> bool: + """ + Detects policy violations in images using Claude's vision capabilities. + + Args: + data: The image data as a base64 string. It may have a prefix like data:image/jpeg;base64, in which case it will be stripped. + policy: The policy to evaluate against the image + click_coordinates: coordinates for click evaluation (x, y) between 0 and 1 + api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY environment variable) + model_name: The Claude model to use + **config: Additional configuration options + + Returns: + bool: True if the policy is violated, False otherwise + """ + global IMAGE_POLICY_MODEL + + # Get API key from environment if not provided + if api_key is None: + api_key = os.environ.get("ANTHROPIC_API_KEY") + if not api_key: + return True + + # Initialize the model if not already done + if IMAGE_POLICY_MODEL is None: + IMAGE_POLICY_MODEL = ClaudeModel( + api_key=api_key, + model_name=model_name, + **config + ) + + # Strip base64 prefix if present + if data.startswith('" + # Sample click coordinates + self.sample_coordinates = (0.5, 0.5) + self.api_key = "sk-ant-api03-1234567890" + + def _is_unsafe_click(self, screenshot, policy, click_coordinates): + # This is a hack for mocking the calls from inside DSL + result = image_policy_violations(screenshot, policy, click_coordinates, self.api_key) + return result + + @unittest.mock.patch('invariant.analyzer.stdlib.invariant.detectors.image._evaluate_single_image') + def test_image_policy_violations(self, mock_evaluate): + # Configure the mock to return different values based on inputs + def mock_evaluate_side_effect(image, policy, click_coordinates): + return True + + mock_evaluate.side_effect = mock_evaluate_side_effect + + # Test with a policy that should trigger a violation + policy_str_violation = """ + from invariant.detectors.image import image_policy_violations + + raise PolicyViolation("image policy violation", msg) if: + (msg: Message) + input.is_unsafe_click(msg.content, "Harmful content policy", (0.5, 0.5)) + """ + + trace = [user(self.sample_image_data)] + + # Test with policy that should trigger violation + policy = Policy.from_string(policy_str_violation) + result = policy.analyze(trace, is_unsafe_click=self._is_unsafe_click) + self.assertEqual(len(result.errors), 1) + + # Verify the mock was called with expected arguments + mock_evaluate.assert_called() + + @unittest.mock.patch('invariant.analyzer.stdlib.invariant.detectors.image._evaluate_single_image') + def test_image_policy_with_data_uri_prefix(self, mock_evaluate): + # Configure the mock to always return False (no violation) + mock_evaluate.return_value = False + + policy_str = """ + from invariant.detectors.image import image_policy_violations + + raise PolicyViolation("image policy violation", msg) if: + (msg: Message) + input.is_unsafe_click(msg.content, "Test policy", (0.5, 0.5)) + """ + + # Create a trace with the prefixed image data + trace = [user(self.sample_image_data_with_prefix)] + + policy = Policy.from_string(policy_str) + result = policy.analyze(trace, is_unsafe_click=self._is_unsafe_click) + self.assertEqual(len(result.errors), 0) + + # Verify the mock was called + mock_evaluate.assert_called()