diff --git a/adala/agents/base.py b/adala/agents/base.py index 4bba6374..62626b6c 100644 --- a/adala/agents/base.py +++ b/adala/agents/base.py @@ -1,4 +1,5 @@ import logging +import traceback from pydantic import ( BaseModel, Field, @@ -7,7 +8,7 @@ SerializeAsAny, ) from abc import ABC -from typing import Optional, Dict, Union, Tuple +from typing import Optional, Dict, Union, Tuple, List from rich import print import yaml @@ -15,9 +16,11 @@ from adala.environments.static_env import StaticEnvironment from adala.runtimes.base import Runtime, AsyncRuntime from adala.runtimes._openai import OpenAIChatRuntime -from adala.skills._base import Skill +from adala.skills._base import Skill, TransformSkill from adala.memories.base import Memory from adala.skills.skillset import SkillSet, LinearSkillSet +from adala.skills.collection.prompt_improvement import ImprovedPromptResponse + from adala.utils.logs import ( print_dataframe, print_text, @@ -26,7 +29,7 @@ is_running_in_jupyter, ) from adala.utils.internal_data import InternalDataFrame - +from adala.utils.types import BatchData logger = logging.getLogger(__name__) @@ -61,7 +64,7 @@ class Agent(BaseModel, ABC): default_factory=lambda: {"default": OpenAIChatRuntime(model="gpt-3.5-turbo")} ) default_runtime: str = "default" - teacher_runtimes: Dict[str, SerializeAsAny[Runtime]] = Field( + teacher_runtimes: Dict[str, SerializeAsAny[Union[Runtime, AsyncRuntime]]] = Field( default_factory=lambda: {"default": None} ) default_teacher_runtime: str = "default" @@ -118,7 +121,7 @@ def skills_validator(cls, v) -> SkillSet: f"skills must be of type SkillSet or Skill, but received type {type(v)}" ) - @field_validator("runtimes", mode="before") + @field_validator("runtimes", "teacher_runtimes", mode="before") def runtimes_validator(cls, v) -> Dict[str, Union[Runtime, AsyncRuntime]]: """ Validates and creates runtimes @@ -393,6 +396,48 @@ def learn( print_text("Train is done!") + async def arefine_skill( + self, + skill_name: str, + input_variables: List[str], + batch_data: Optional[BatchData] = None, + ) -> ImprovedPromptResponse: + """ + beta v2 of Agent.learn() that is: + - compatible with the newer LiteLLM runtimes + - compatible with the newer response_model output formats for skills + - returns chain of thought reasoning in a legible format + + Limitations so far: + - single skill at a time + - only returns the improved input_template, doesn't modify the skill in place + - doesn't use examples/feedback + - no iterations/variable cost + """ + + skill = self.skills[skill_name] + if not isinstance(skill, TransformSkill): + raise ValueError(f"Skill {skill_name} is not a TransformSkill") + + # get default runtimes + runtime = self.get_runtime() + teacher_runtime = self.get_teacher_runtime() + + # get inputs + # TODO: replace it with async environment.get_data_batch() + if batch_data is None: + predictions = None + else: + inputs = InternalDataFrame.from_records(batch_data or []) + predictions = await self.skills.aapply(inputs, runtime=runtime) + + response = await skill.aimprove( + predictions=predictions, + teacher_runtime=teacher_runtime, + target_input_variables=input_variables, + ) + return response + def create_agent_from_dict(json_dict: Dict): """ diff --git a/adala/runtimes/_litellm.py b/adala/runtimes/_litellm.py index 06976d27..fc69f5c9 100644 --- a/adala/runtimes/_litellm.py +++ b/adala/runtimes/_litellm.py @@ -277,15 +277,21 @@ def record_to_record( usage = completion.usage dct = to_jsonable_python(response) except IncompleteOutputException as e: + logger.error(f"Incomplete output error: {str(e)}") + logger.error(f"Traceback:\n{traceback.format_exc()}") usage = e.total_usage dct = _log_llm_exception(e) except InstructorRetryException as e: + logger.error(f"Instructor retry error: {str(e)}") + logger.error(f"Traceback:\n{traceback.format_exc()}") usage = e.total_usage # get root cause error from retries n_attempts = e.n_attempts e = e.__cause__.last_attempt.exception() dct = _log_llm_exception(e) except Exception as e: + logger.error(f"Other error: {str(e)}") + logger.error(f"Traceback:\n{traceback.format_exc()}") # usage = e.total_usage # not available here, so have to approximate by hand, assuming the same error occurred each time n_attempts = retries.stop.max_attempt_number @@ -485,8 +491,41 @@ async def record_to_record( extra_fields: Optional[Dict[str, Any]] = None, field_schema: Optional[Dict] = None, instructions_first: bool = True, + response_model: Optional[Type[BaseModel]] = None, ) -> Dict[str, str]: - raise NotImplementedError("record_to_record is not implemented") + """ + Execute LiteLLM request given record and templates for input, + instructions and output. + + Args: + record: Record to be used for input, instructions and output templates. + input_template: Template for input message. + instructions_template: Template for instructions message. + output_template: Template for output message. + extra_fields: Extra fields to be used in templates. + field_schema: Field jsonschema to be used for parsing templates. + instructions_first: If True, instructions will be sent before input. + + Returns: + Dict[str, str]: The processed record. + """ + # Create a single-row DataFrame from the input record + input_df = InternalDataFrame([record]) + + # Use the batch_to_batch method to process the single-row DataFrame + output_df = await self.batch_to_batch( + input_df, + input_template=input_template, + instructions_template=instructions_template, + output_template=output_template, + extra_fields=extra_fields, + field_schema=field_schema, + instructions_first=instructions_first, + response_model=response_model, + ) + + # Extract the single row from the output DataFrame and convert it to a dictionary + return output_df.iloc[0].to_dict() class LiteLLMVisionRuntime(LiteLLMChatRuntime): diff --git a/adala/skills/_base.py b/adala/skills/_base.py index b86a5a41..ea60dbf1 100644 --- a/adala/skills/_base.py +++ b/adala/skills/_base.py @@ -1,5 +1,6 @@ import logging import string +import traceback from pydantic import ( BaseModel, Field, @@ -479,6 +480,50 @@ def improve( self.instructions = new_prompt + async def aimprove(self, teacher_runtime: AsyncRuntime, target_input_variables: List[str], predictions: Optional[InternalDataFrame] = None): + """ + Improves the skill. + """ + + from adala.skills.collection.prompt_improvement import PromptImprovementSkill, ImprovedPromptResponse, ErrorResponseModel, PromptImprovementSkillResponseModel + response_dct = {} + try: + prompt_improvement_skill = PromptImprovementSkill( + skill_to_improve=self, + input_variables=target_input_variables, + ) + if predictions is None: + input_df = InternalDataFrame() + else: + input_df = predictions + response_df = await prompt_improvement_skill.aapply( + input=input_df, + runtime=teacher_runtime, + ) + + # awkward to go from response model -> dict -> df -> dict -> response model + response_dct = response_df.iloc[0].to_dict() + + # unflatten the response + if response_dct.pop("_adala_error", False): + output = ErrorResponseModel(**response_dct) + else: + output = PromptImprovementSkillResponseModel(**response_dct) + + except Exception as e: + logger.error(f"Error improving skill: {e}. Traceback: {traceback.format_exc()}") + output = ErrorResponseModel( + _adala_message=str(e), + _adala_details=traceback.format_exc(), + ) + + # get tokens and token cost + resp = ImprovedPromptResponse(output=output, **response_dct) + logger.debug(f"resp: {resp}") + + return resp + + class SampleTransformSkill(TransformSkill): sample_size: int @@ -548,30 +593,22 @@ class AnalysisSkill(Skill): Analysis skill that analyzes a dataframe and returns a record (e.g. for data analysis purposes). See base class Skill for more information about the attributes. """ - + input_prefix: str = "" input_separator: str = "\n" chunk_size: Optional[int] = None - def apply( - self, - input: Union[InternalDataFrame, InternalSeries, Dict], - runtime: Runtime, - ) -> InternalDataFrame: - """ - Applies the skill to a dataframe and returns a record. - - Args: - input (InternalDataFrame): The input data to be processed. - runtime (Runtime): The runtime instance to be used for processing. + def _iter_over_chunks(self, input: InternalDataFrame, chunk_size: Optional[int] = None): - Returns: - InternalSeries: The record containing the analysis results. - """ + if input.empty: + yield "" + return + if isinstance(input, InternalSeries): input = input.to_frame() elif isinstance(input, dict): input = InternalDataFrame([input]) + extra_fields = self._get_extra_fields() # if chunk_size is specified, split the input into chunks and process each chunk separately @@ -582,25 +619,65 @@ def apply( ) else: chunks = [input] - outputs = [] + total = input.shape[0] // self.chunk_size if self.chunk_size is not None else 1 for chunk in tqdm(chunks, desc="Processing chunks", total=total): - agg_chunk = ( - chunk.reset_index() + agg_chunk = chunk\ + .reset_index()\ .apply( lambda row: self.input_template.format( **row, **extra_fields, i=int(row.name) + 1 ), axis=1, - ) - .str.cat(sep=self.input_separator) - ) + ).str.cat(sep=self.input_separator) + + yield agg_chunk + + def apply( + self, + input: Union[InternalDataFrame, InternalSeries, Dict], + runtime: Runtime, + ) -> InternalDataFrame: + """ + Applies the skill to a dataframe and returns a record. + + Args: + input (InternalDataFrame): The input data to be processed. + runtime (Runtime): The runtime instance to be used for processing. + + Returns: + InternalSeries: The record containing the analysis results. + """ + outputs = [] + for agg_chunk in self._iter_over_chunks(input): output = runtime.record_to_record( - {"input": agg_chunk}, + {"input": f"{self.input_prefix}{agg_chunk}"}, + input_template="{input}", + output_template=self.output_template, + instructions_template=self.instructions, + instructions_first=self.instructions_first, + response_model=self.response_model, + ) + outputs.append(InternalSeries(output)) + output = InternalDataFrame(outputs) + + return output + + async def aapply( + self, + input: Union[InternalDataFrame, InternalSeries, Dict], + runtime: AsyncRuntime, + ) -> InternalDataFrame: + """ + Applies the skill to a dataframe and returns a record. + """ + outputs = [] + for agg_chunk in self._iter_over_chunks(input): + output = await runtime.record_to_record( + {"input": f"{self.input_prefix}{agg_chunk}"}, input_template="{input}", output_template=self.output_template, instructions_template=self.instructions, - extra_fields=extra_fields, instructions_first=self.instructions_first, response_model=self.response_model, ) diff --git a/adala/skills/collection/prompt_improvement.py b/adala/skills/collection/prompt_improvement.py new file mode 100644 index 00000000..5f1742e0 --- /dev/null +++ b/adala/skills/collection/prompt_improvement.py @@ -0,0 +1,186 @@ +import json +import logging +from pydantic import BaseModel, field_validator, Field, ConfigDict, model_validator +from adala.skills import Skill +from typing import Any, Dict, List, Optional, Union +from adala.skills import AnalysisSkill +from adala.utils.parse import parse_template +from adala.utils.types import ErrorResponseModel + +logger = logging.getLogger(__name__) + + +class PromptImprovementSkillResponseModel(BaseModel): + + + reasoning: str = Field(..., description="The reasoning for the changes made to the prompt") + new_prompt_title: str = Field(..., description="The new short title for the prompt") + new_prompt_content: str = Field(..., description="The new content for the prompt") + + # model_config = ConfigDict( + # # omit other fields + # extra="ignore", + # # guard against name collisions with other fields + # populate_by_name=False, + # ) + + # @field_validator("new_prompt_content", mode="after") + # def validate_used_variables(cls, value: str) -> str: + + # templates = parse_template(value, include_texts=False) + # if not templates: + # raise ValueError("At least one input variable must be used in the prompt") + + # input_vars_used = [t["text"] for t in templates] + # if extra_vars_used := set(input_vars_used) - set(cls._input_variables): + # raise ValueError( + # f"Invalid variable used in prompt: {extra_vars_used}. Valid variables are: {cls._input_variables}" + # ) + + # return value + + +class ImprovedPromptResponse(BaseModel): + + output: Union[PromptImprovementSkillResponseModel, ErrorResponseModel] + + prompt_tokens: int = Field(alias="_prompt_tokens", default=None) + completion_tokens: int = Field(alias="_completion_tokens", default=None) + + # these can fail to calculate + prompt_cost_usd: Optional[float] = Field(alias="_prompt_cost_usd", default=None) + completion_cost_usd: Optional[float] = Field(alias="_completion_cost_usd", default=None) + total_cost_usd: Optional[float] = Field(alias="_total_cost_usd", default=None) + + +class PromptImprovementSkill(AnalysisSkill): + + skill_to_improve: Skill + # TODO: include model provider to specialize the prompt format for a specific model provider + # model_provider: str + input_variables: List[str] + + name: str = "prompt_improvement" + instructions: str = "" # Automatically generated + input_template: str = "" # Not used + input_prefix: str = "Here are a few prediction results after applying the current prompt for your analysis.\n\n" + input_separator: str = "\n\n" + + response_model = PromptImprovementSkillResponseModel + + + @model_validator(mode="after") + def validate_prompts(self): + input_variables = '\n'.join(self.input_variables) + + # rewrite the instructions with the actual values + self.instructions = f"""\ +You are a prompt engineer tasked with generating or enhancing a prompt for a Language Learning Model (LLM). Your goal is to create an effective prompt based on the given context, input data and requirements. + +First, carefully review the following context information: + +# Given context + +## Task name +{self.skill_to_improve.name} + +## Task description +{self.skill_to_improve.description} + +## Input variables to use +{input_variables} + +## Target response schema +```json +{json.dumps(self.skill_to_improve.response_model.model_json_schema(), indent=2)} +``` +Now, examine the current prompt (if provided): + +# Current prompt +{self.skill_to_improve.input_template} + +If a current prompt is provided, analyze it for potential improvements or errors. Consider how well it addresses the task description, input data and if it effectively utilizes all provided input variables. + +Before creating the new prompt, provide a detailed reasoning for your choices. Include: +1. How you addressed the context and task description +2. Any potential errors or improvements you identified in the previous prompt (if applicable) +3. How your new prompt better suits the target model provider +4. How your prompt is designed to generate responses matching the provided schema + +Next, generate a new short prompt title that accurately reflects the task and purpose of the prompt. + +Finally, create the new prompt content. Ensure that you: +1. Incorporate all provided input variables, formatted with "{{" and "}}" brackets +2. Address the specific task description provided in the context +3. Consider the target model provider's capabilities and limitations +4. Maintain or improve upon any relevant information from the current prompt (if provided) +5. Structure the prompt to elicit a response that matches the provided response schema + +Present your output in JSON format including the following fields: +- reasoning +- new_prompt_title +- new_prompt_content + + +# Example of the expected input and output: + +Input context: + +## Target model provider +OpenAI + +## Task description +Generate a summary of the input text. + +## Allowed input variables +text +document_metadata + +## Target response schema +```json +{{ + "summary": {{ + "type": "string" + }}, + "categories": {{ + "type": "string", + "enum": ["news", "science", "politics", "sports", "entertainment"] + }} +}} +``` + +Check the following example to see how the model should respond: + +Current prompt: +``` +Generate a summary of the input text: "{{text}}". +``` + +# Current prompt output + +Generate a summary of the input text: "The quick brown fox jumps over the lazy dog." --> {{"summary": "The quick brown fox jumps over the lazy dog.", "categories": "news"}} + +Generate a summary of the input text: "When was the Battle of Hastings?" --> {{"summary": "The Battle of Hastings was a decisive Norman victory in 1066, marking the end of Anglo-Saxon rule in England.", "categories": "history"}} + +Generate a summary of the input text: "What is the capital of France?" --> {{ "summary": "The capital of France is Paris.", "categories": "geography"}} + + +Your output: +```json +{{ + "reasoning": "The current prompt is too vague. It doesn't specify the format or style of the summary. Addidionally, the categories instructions are not provided. It results in low quality outputs, like "summary" asnwers the question but not summarizes the input text. "history" category is not provided in the response schema, so it is not possible to produce the output. Also, not all requested input variables are used. To ensure high quality responses, I need to make the following changes: ...", + "new_prompt_title": "Including categories instructions in the summary", + "new_prompt_content": "Generate a detailed summary of the input text:\n'''{{text}}'''.\nUse the document metadata to guide the model to produce categories.\n#Metadata:\n'''{{document_metadata}}'''.\nEnsure high quality output by asking the model to produce a detailed summary and to categorize the document." +}} +``` + +Ensure that your refined prompt is clear, concise, and effectively guides the LLM to produce high quality responses. + +""" + + # Create the output template for JSON output based on the response model fields + fields = self.skill_to_improve.response_model.model_fields + field_template = ", ".join([f'"{field}": "{{{field}}}"'for field in fields]) + self.output_template = "{{" + field_template + "}}" + self.input_template = f"{self.skill_to_improve.input_template} --> {self.output_template}" + return self diff --git a/adala/utils/types.py b/adala/utils/types.py new file mode 100644 index 00000000..38db5eb7 --- /dev/null +++ b/adala/utils/types.py @@ -0,0 +1,24 @@ +from pydantic import BaseModel, Field, ConfigDict, field_validator +from typing import List, Optional, Union +from adala.utils.parse import parse_template + + +class BatchData(BaseModel): + """ + Model for a batch of data submitted to a streaming job + """ + + job_id: str + data: List[dict] + + +class ErrorResponseModel(BaseModel): + message: str = Field(..., alias="_adala_message") + details: str = Field(..., alias="_adala_details") + + model_config = ConfigDict( + # omit other fields + extra="ignore", + # guard against name collisions with other fields + populate_by_name=False, + ) diff --git a/server/app.py b/server/app.py index 485fb40c..16dd9f28 100644 --- a/server/app.py +++ b/server/app.py @@ -2,23 +2,28 @@ from typing import Any, Dict, Generic, List, Optional, TypeVar import os import json +import pandas as pd import fastapi from fastapi import Request, status from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse from adala.agents import Agent +from adala.skills import Skill +from adala.runtimes import AsyncRuntime from aiokafka import AIOKafkaProducer from aiokafka.errors import UnknownTopicOrPartitionError from fastapi import HTTPException, Depends from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel, SerializeAsAny, field_validator +from pydantic import BaseModel, SerializeAsAny, field_validator, Field, model_validator from redis import Redis import time import uvicorn +from adala.utils.types import BatchData, ErrorResponseModel from server.handlers.result_handlers import ResultHandler from server.log_middleware import LogMiddleware +from adala.skills.collection.prompt_improvement import ImprovedPromptResponse from server.tasks.stream_inference import streaming_parent_task from server.utils import ( Settings, @@ -126,15 +131,6 @@ def validate_result_handler(cls, value: Dict) -> ResultHandler: return result_handler -class BatchData(BaseModel): - """ - Model for a batch of data submitted to a streaming job - """ - - job_id: str - data: List[dict] - - @app.get("/") def get_index(): return {"status": "ok"} @@ -307,6 +303,59 @@ async def ready(redis_conn: Redis = Depends(_get_redis_conn)): return {"status": "ok"} +class ImprovedPromptRequest(BaseModel): + """ + Request model for improving a prompt. + """ + + agent: Agent + skill_to_improve: str + input_variables: Optional[List[str]] = Field( + default=None, + description="List of variables available to use in the input template of the skill, in case any exist that are not currently used", + ) + batch_data: Optional[BatchData] = Field( + default=None, + description="Batch of data to run the skill on", + ) + + @field_validator("agent", mode="after") + def validate_teacher_runtime(cls, agent: Agent) -> Agent: + if not isinstance(agent.get_teacher_runtime(), AsyncRuntime): + raise ValueError("Default teacher runtime must be an AsyncRuntime") + return agent + + @model_validator(mode="after") + def set_input_variable_list(self): + skill = self.agent.skills[self.skill_to_improve] + if self.input_variables is None: + self.input_variables = skill.get_input_fields() + return self + + +@app.post("/improved-prompt", response_model=Response[ImprovedPromptResponse]) +async def improved_prompt(request: ImprovedPromptRequest): + """ + Improve a given prompt using the specified model and variables. + + Args: + request (ImprovedPromptRequest): The request model for improving a prompt. + + Returns: + Response: Response model for prompt improvement skill + """ + + improved_prompt_response = await request.agent.arefine_skill( + skill_name=request.skill_to_improve, + input_variables=request.input_variables, + batch_data=request.batch_data.data if request.batch_data else None + ) + + return Response[ImprovedPromptResponse]( + success=not isinstance(improved_prompt_response.output, ErrorResponseModel), + data=improved_prompt_response + ) + if __name__ == "__main__": # for debugging uvicorn.run("app:app", host="0.0.0.0", port=30001) diff --git a/tests/test_refine_skill.py b/tests/test_refine_skill.py new file mode 100644 index 00000000..4ffce8a3 --- /dev/null +++ b/tests/test_refine_skill.py @@ -0,0 +1,155 @@ +import pytest +import os +from adala.agents.base import Agent +from adala.skills._base import TransformSkill +from adala.skills.collection.prompt_improvement import ImprovedPromptResponse +from unittest.mock import patch + + +@pytest.fixture +def agent_json(): + return { + "runtimes": { + "default": { + "type": "AsyncLiteLLMChatRuntime", + "model": "gpt-4o-mini", + "api_key": os.getenv("OPENAI_API_KEY"), + "max_tokens": 200, + "temperature": 0, + "batch_size": 100, + "timeout": 10, + "verbose": False, + } + }, + "teacher_runtimes": { + "default": { + "type": "AsyncLiteLLMChatRuntime", + "model": "gpt-4o-mini", + "api_key": os.getenv("OPENAI_API_KEY"), + "max_tokens": 1000, + "temperature": 0, + "batch_size": 100, + "timeout": 10, + "verbose": False, + } + }, + "skills": [ + { + "type": "ClassificationSkill", + "name": "my_classification_skill", + "instructions": "", + "input_template": "{text} {id}", + "field_schema": { + "output": { + "type": "string", + "enum": ["positive", "negative", "neutral"], + } + }, + } + ], + } + + +@pytest.mark.use_openai +@pytest.mark.asyncio +async def test_arefine_skill_no_input_data(client, agent_json): + skill_name = "my_classification_skill" + + payload = { + "agent": agent_json, + "skill_to_improve": skill_name, + "input_variables": ["text", "id"], + } + + response = client.post("/improved-prompt", json=payload) + + assert response.status_code == 200 + result = response.json() + + assert "data" in result + assert "output" in result["data"] + output = result["data"]["output"] + + assert "reasoning" in output + assert "new_prompt_title" in output + assert "new_prompt_content" in output + assert '{text}' in output["new_prompt_content"] + + +@pytest.mark.use_openai +@pytest.mark.asyncio +async def test_arefine_skill_with_input_data(client, agent_json): + skill_name = "my_classification_skill" + + batch_data = [ + {"text": "This is a test text", "id": "1"}, + {"text": "This is another test text", "id": "2"}, + ] + + payload = { + "agent": agent_json, + "skill_to_improve": skill_name, + "input_variables": ["text", "id"], + "batch_data": { + 'job_id': '123', + 'data': batch_data, + } + } + + response = client.post("/improved-prompt", json=payload) + + assert response.status_code == 200 + result = response.json() + + assert "data" in result + assert "output" in result["data"] + output = result["data"]["output"] + + assert "reasoning" in output + assert "new_prompt_title" in output + assert "new_prompt_content" in output + assert '{text}' in output["new_prompt_content"] + assert '{id}' in output["new_prompt_content"] + + +@pytest.mark.use_openai +@pytest.mark.asyncio +async def test_arefine_skill_error_handling(client, agent_json): + skill_name = "my_classification_skill" + + batch_data = None + + agent_json["teacher_runtimes"]["default"]["model"] = "nonexistent" + + payload = { + "agent": agent_json, + "skill_to_improve": skill_name, + "input_variables": ["text", "id"], + "batch_data": batch_data, + } + response = client.post("/improved-prompt", json=payload) + assert response.status_code == 422 + + # test runtime failure + agent_json["teacher_runtimes"]["default"]["model"] = "gpt-4o" + with patch("instructor.AsyncInstructor.create_with_completion") as mock_create: + + def side_effect(*args, **kwargs): + if skill_name in str(kwargs): + raise Exception(f"Simulated OpenAI API failure for {skill_name}") + return mock_create.return_value + + mock_create.side_effect = side_effect + + resp = client.post( + "/improved-prompt", + json={ + "agent": agent_json, + "skill_to_improve": skill_name, + "input_variables": ["text", "id"], + }, + ) + assert resp.raise_for_status() + resp_json = resp.json() + assert not resp_json["success"] + assert f"Simulated OpenAI API failure for {skill_name}" == resp_json["data"]["output"]["_adala_details"] diff --git a/tests/test_server.py b/tests/test_server.py index 68e20eda..901e898d 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -6,6 +6,7 @@ from tempfile import NamedTemporaryFile import pandas as pd from copy import deepcopy +from unittest.mock import patch # TODO manage which keys correspond to which models/deployments, probably using a litellm Router OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")