Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Typo in ReadMe + client_kwargs for OpenAI-compatible APIs in Detector + prompt on Decoder #41

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 32 additions & 26 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ Code:

```python
# content of tests/test_weather.py
import invariant.testing.testing.functional as F
import invariant.testing.functional as F
from invariant.testing import Trace, assert_equals

def test_weather():
Expand Down Expand Up @@ -388,31 +388,37 @@ This section provides a detailed overview of the analyzer's components, includin

**Table of Contents**

- [Use Cases](#use-cases)
- [Why Agent Debugging Matters](#why-agent-debugging-matters)
- [Why Agent Security Matters](#why-agent-security-matters)
- [Features](#features)
- [Getting Started](#getting-started)
- [Use Cases](#use-cases-1)
- [Debugging Coding Agents](#debugging-coding-agents)
- [Prevent Data Leaks In Your Productivity Agent](#prevent-data-leaks-in-your-productivity-agent)
- [Detect Vulnerabilities in Your Code Generation Agent](#detect-vulnerabilities-in-your-code-generation-agent)
- [Enforce Access Control In Your RAG-based Chat Agent](#enforce-access-control-in-your-rag-based-chat-agent)
- [Documentation](#documentation)
- [Policy Language](#policy-language)
- [Example Rule](#example-rule)
- [Trace Format](#trace-format)
- [Trace Example](#trace-example)
- [Debugging and Printing Inputs](#debugging-and-printing-inputs)
- [Custom Error Types](#custom-error-types)
- [Predicates](#predicates)
- [Semantic Tool Call Matching](#semantic-tool-call-matching)
- [Integration](#integration)
- [Analyzing Agent Traces](#analyzing-agent-traces)
- [Real-Time Monitoring of an OpenAI Agent](#real-time-monitoring-of-an-openai-agent)
- [Real-Time Monitoring of a `langchain` Agent](#real-time-monitoring-of-a-langchain-agent)
- [Automatic Issue Resolution (Handlers)](#automatic-issue-resolution-handlers)
- [Roadmap](#roadmap)
- [Quickstart](#quickstart)
- [Table Of Contents](#table-of-contents)
- [Testing](#testing)
- [A quick example](#a-quick-example)
- [Testing Features](#testing-features)
- [Explorer](#explorer)
- [Analyzer](#analyzer)
- [Use Cases](#use-cases)
- [Why Agent Debugging Matters](#why-agent-debugging-matters)
- [Why Agent Security Matters](#why-agent-security-matters)
- [Analyzer Features](#analyzer-features)
- [Getting Started](#getting-started)
- [Use Cases](#use-cases-1)
- [Debugging Coding Agents](#debugging-coding-agents)
- [Prevent Data Leaks In Your Productivity Agent](#prevent-data-leaks-in-your-productivity-agent)
- [Detect Vulnerabilities in Your Code Generation Agent](#detect-vulnerabilities-in-your-code-generation-agent)
- [Enforce Access Control In Your RAG-based Chat Agent](#enforce-access-control-in-your-rag-based-chat-agent)
- [Analyzer Documentation](#analyzer-documentation)
- [Policy Language](#policy-language)
- [Example Rule](#example-rule)
- [Trace Format](#trace-format)
- [Trace Example](#trace-example)
- [Debugging and Printing Inputs](#debugging-and-printing-inputs)
- [Custom Error Types](#custom-error-types)
- [Predicates](#predicates)
- [Semantic Tool Call Matching](#semantic-tool-call-matching)
- [Integration](#integration)
- [Analyzing Agent Traces](#analyzing-agent-traces)
- [Error Localization](#error-localization)
- [Real-Time Monitoring of an OpenAI Agent](#real-time-monitoring-of-an-openai-agent)
- [Real-Time Monitoring of a `langchain` Agent](#real-time-monitoring-of-a-langchain-agent)

### Policy Language

Expand Down
9 changes: 8 additions & 1 deletion invariant/testing/custom_types/invariant_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any, Literal, Union

from _pytest.python_api import ApproxBase

from invariant.testing.scorers.code import execute, is_valid_json, is_valid_python
from invariant.testing.scorers.llm.classifier import Classifier
from invariant.testing.scorers.llm.detector import Detector
Expand Down Expand Up @@ -308,6 +309,7 @@ def extract(
model: str = "gpt-4o",
client: str = "OpenAI",
use_cached_result: bool = True,
client_kwargs={},
) -> list[InvariantString]:
"""Extract values from the underlying string using an LLM.

Expand All @@ -322,7 +324,12 @@ def extract(
use_cached_result (bool): Whether to use a cached result if available.

"""
llm_detector = Detector(predicate_rule=predicate, model=model, client=client)
llm_detector = Detector(
predicate_rule=predicate,
model=model,
client=client,
client_kwargs=client_kwargs,
)
detections = llm_detector.detect(self.value, use_cached_result)
ret = []
for substr, r in detections:
Expand Down
4 changes: 2 additions & 2 deletions invariant/testing/scorers/llm/clients/client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ class ClientFactory:
"""Factory for creating LLM clients."""

@staticmethod
def get(client_name: str, client_kwargs: dict | None = None) -> LLMClient:
    """Get an LLM client by name.

    Args:
        client_name: One of the SupportedClients values (e.g. "OpenAI",
            "Anthropic").
        client_kwargs: Optional keyword arguments forwarded to the underlying
            SDK client constructor (e.g. base_url / api_key for
            OpenAI-compatible endpoints). Defaults to None rather than a
            required positional parameter, so pre-existing
            ClientFactory.get(name) callers keep working; a mutable {}
            default is deliberately avoided.

    Returns:
        An LLMClient implementation for the requested provider.

    Raises:
        ValueError: If client_name is not a supported client.

    """
    if client_name == SupportedClients.OPENAI:
        return OpenAIClient(client_kwargs or {})
    if client_name == SupportedClients.ANTHROPIC:
        # NOTE(review): client_kwargs is silently ignored for Anthropic —
        # AnthropicClient takes no constructor kwargs in this changeset.
        # Confirm whether it should also accept them for consistency.
        return AnthropicClient()
    raise ValueError(f"Invalid client name: {client_name}")
4 changes: 2 additions & 2 deletions invariant/testing/scorers/llm/clients/open_ai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
class OpenAIClient(LLMClient):
"""Client for interacting with OpenAI."""

def __init__(self, client_kwargs: dict | None = None):
    """Create a client for interacting with OpenAI.

    Args:
        client_kwargs: Optional keyword arguments forwarded verbatim to
            openai.OpenAI (e.g. base_url / api_key to target an
            OpenAI-compatible API). Defaults to None — not a required
            positional parameter, so the pre-existing no-argument
            construction OpenAIClient() keeps working, and not a mutable
            {} default, which would be shared across calls.

    """
    # Add OPENAI_API_KEY to your environment variables.
    self.client = openai.OpenAI(**(client_kwargs or {}))

def get_name(self) -> str:
    """Return the provider name ("OpenAI") identifying this client."""
    return "OpenAI"
Expand Down
18 changes: 15 additions & 3 deletions invariant/testing/scorers/llm/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import logging
from typing import Any, Tuple

from invariant.testing.cache import CacheManager
from invariant.testing.custom_types.addresses import Range
from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
from pydantic import BaseModel

from invariant.testing.cache import CacheManager
from invariant.testing.custom_types.addresses import Range

from .clients.anthropic_client import AnthropicClient
from .clients.client import SupportedClients
from .clients.client_factory import ClientFactory
Expand All @@ -28,6 +29,16 @@
Detections:
[("1", "Zurich"), ("2", "Geneva"), ("2", "Bern"), ("3", "Bern")]

Your response must be in the following format:
{{
"detections": [
{{"line": 1, "substring": "Zurich"}},
{{"line": 2, "substring": "Geneva"}},
{{"line": 2, "substring": "Bern"}},
{{"line": 3, "substring": "Bern"}}
]
}}

Use the following predicate rule to find the detections in the next user message:
{predicate_rule}
"""
Expand Down Expand Up @@ -99,6 +110,7 @@ def __init__(
predicate_rule: str,
model: str = "gpt-4o",
client: str = "OpenAI",
client_kwargs: dict = {},
):
"""Instantiate Detector object.

Expand All @@ -114,7 +126,7 @@ def __init__(
"""
self.model = model
self.prompt = self._get_prompt(predicate_rule, client)
self.client = ClientFactory.get(client)
self.client = ClientFactory.get(client, client_kwargs)
self.cache_manager = CacheManager(
CACHE_DIRECTORY_LLM_DETECTOR, expiry=CACHE_TIMEOUT
)
Expand Down
Loading