From d905a9ad9f228c7d76e05ee0311e93d0c87e0f0e Mon Sep 17 00:00:00 2001 From: Matt Bernstein Date: Mon, 23 Sep 2024 17:19:46 -0400 Subject: [PATCH 1/7] feat: DIA-1402: V1-Submit Prompt auto-refinement job --- server/app.py | 81 +++++++++++++++++++++++++++++- server/prompt_improvement_skill.py | 64 +++++++++++++++++++++++ 2 files changed, 144 insertions(+), 1 deletion(-) create mode 100644 server/prompt_improvement_skill.py diff --git a/server/app.py b/server/app.py index 485fb40..263e3fb 100644 --- a/server/app.py +++ b/server/app.py @@ -2,23 +2,27 @@ from typing import Any, Dict, Generic, List, Optional, TypeVar import os import json +import pandas as pd import fastapi from fastapi import Request, status from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse from adala.agents import Agent +from adala.skills import Skill +from adala.runtimes import AsyncRuntime from aiokafka import AIOKafkaProducer from aiokafka.errors import UnknownTopicOrPartitionError from fastapi import HTTPException, Depends from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel, SerializeAsAny, field_validator +from pydantic import BaseModel, SerializeAsAny, field_validator, Field from redis import Redis import time import uvicorn from server.handlers.result_handlers import ResultHandler from server.log_middleware import LogMiddleware +from server.prompt_improvement_skill import PromptImprovementSkillResponseModel, get_prompt_improvement_inputs, get_prompt_improvement_skill from server.tasks.stream_inference import streaming_parent_task from server.utils import ( Settings, @@ -307,6 +311,81 @@ async def ready(redis_conn: Redis = Depends(_get_redis_conn)): return {"status": "ok"} +class ImprovedPromptRequest(BaseModel): + """ + Request model for improving a prompt. + """ + student_skill: Skill + student_model: str + teacher_runtime: AsyncRuntime + input_variables: List[str] + + # same code as for ResultHandler in SubmitStreamingRequest + @field_validator("student_skill", mode="before") + def validate_skill(cls, value: Dict) -> Skill: + if "type" not in value: + raise HTTPException( + status_code=400, detail="Missing type in student_skill" + ) + skill = Skill.create_from_registry(value.pop("type"), **value) + return skill + + # same code as for ResultHandler in SubmitStreamingRequest + @field_validator("teacher_runtime", mode="before") + def validate_teacher_runtime(cls, value: Dict) -> AsyncRuntime: + if "type" not in value: + raise HTTPException( + status_code=400, detail="Missing type in teacher_runtime" + ) + runtime = AsyncRuntime.create_from_registry(value.pop("type"), **value) + return runtime + +class ImprovedPromptResponse(BaseModel): + + output: Optional[PromptImprovementSkillResponseModel] = None + + prompt_tokens: int = Field(alias="_prompt_tokens") + completion_tokens: int = Field(alias="_completion_tokens") + + # these can fail to calculate + prompt_cost_usd: Optional[float] = Field(alias="_prompt_cost_usd") + completion_cost_usd: Optional[float] = Field(alias="_completion_cost_usd") + total_cost_usd: Optional[float] = Field(alias="_total_cost_usd") + +@app.post("/improved-prompt", response_model=Response[ImprovedPromptResponse]) +async def improved_prompt(request: ImprovedPromptRequest): + """ + Improve a given prompt using the specified model and variables. + + Args: + request (ImprovedPromptRequest): The request model for improving a prompt. 
+ + Returns: + Response: Response model for prompt improvement skill + """ + + inputs = get_prompt_improvement_inputs(request.student_skill, request.input_variables, request.student_model) + prompt_improvement_skill = get_prompt_improvement_skill(request.input_variables) + # someday we can stop doing this... + df = pd.DataFrame.from_records([inputs]) + response_df = await prompt_improvement_skill.aapply( + input=df, + runtime=request.teacher_runtime, + ) + response_dct = response_df.iloc[0].to_dict() + + # get tokens and token cost + data = ImprovedPromptResponse(**response_dct) + + if response_dct.get("_adala_error", False): + # insert error into Response + return Response(success=False, data=data, message=response_dct["_adala_details"], errors=[response_dct["_adala_message"]]) + else: + # insert output into Response + data.output = PromptImprovementSkillResponseModel(**response_dct) + return Response(data=data) + + if __name__ == "__main__": # for debugging uvicorn.run("app:app", host="0.0.0.0", port=30001) diff --git a/server/prompt_improvement_skill.py b/server/prompt_improvement_skill.py new file mode 100644 index 0000000..78abbef --- /dev/null +++ b/server/prompt_improvement_skill.py @@ -0,0 +1,64 @@ +from pydantic import BaseModel, field_validator +from adala.skills import Skill +from typing import Any, Dict, List, Tuple +from adala.skills.collection.text_generation import TextGenerationSkill + + +class PromptImprovementSkillResponseModel(BaseModel): + + # hidden variable, used for validation + _input_variables: List[str] + reasoning: str + improved_user_prompt: str + # NOTE: not exposed in LSE yet, so default is always used. Should improve this as well when we expose it. + # improved_system_prompt: str + + + @field_validator("improved_user_prompt", mode="after") + def validate_used_variables(cls, value: str) -> str: + + start_variable_idx = value.find("{") + end_variable_idx = value.rfind("}") + if start_variable_idx == -1 or end_variable_idx == -1 or start_variable_idx >= end_variable_idx: + raise ValueError("At least one input variable must be used in the prompt") + + try: + value.format(**{var: "value" for var in cls._input_variables}) + except KeyError as e: + raise ValueError(f"Invalid variable used in prompt: {e}. Valid variables are: {cls._input_variables}") + + return value + + +def get_prompt_improvement_inputs(student_skill: Skill, input_variables: List[str], student_model: str) -> Dict[str, Any]: + return { + "model": student_model, + "task_name": student_skill.name, + "task_description": student_skill.description, + "input_variables": input_variables, + "current_system_prompt": student_skill.instructions, + "current_user_prompt": student_skill.input_template, + "response_json_schema": student_skill.response_model.model_json_schema(), + } + + +def get_prompt_improvement_skill(input_variables: List[str]) -> TextGenerationSkill: + + # setting this dynamically - used to validate the improved prompt + PromptImprovementSkillResponseModel._input_variables = input_variables + + prompt_improvement_skill = TextGenerationSkill( + name="prompt_improvement", + instructions="Improve the user prompt for the provided LLM model to complete the task using the provided input variables, with the provided user prompt as a starting point. Variables can be accessed in the user prompt using the format {variable_name} (only the variable values are used, not their names). Make sure your prompt produces output that will continue to conform to the provided json schema. 
Provide your reasoning for the changes you made to the prompt.", + input_template=''' + Model: {model} + Task Name: {task_name} + Task Description: {task_description} + Input Variables: {input_variables} + Current System Prompt: {current_system_prompt} + Current User Prompt: {current_user_prompt} + Response JSON Schema: {response_json_schema}''', + response_model=PromptImprovementSkillResponseModel + ) + + return prompt_improvement_skill From a53675eafbad37070e6d02711413dc4eec9c7ad4 Mon Sep 17 00:00:00 2001 From: Matt Bernstein Date: Mon, 30 Sep 2024 17:53:11 -0400 Subject: [PATCH 2/7] - move prompt improvement to agent - make agent with a teacher runtime serializable - add test --- adala/agents/base.py | 50 +++++++++++- .../skills/collection/prompt_improvement.py | 39 +++++++++- server/app.py | 76 ++++++------------- tests/test_server.py | 12 +++ 4 files changed, 120 insertions(+), 57 deletions(-) rename server/prompt_improvement_skill.py => adala/skills/collection/prompt_improvement.py (68%) diff --git a/adala/agents/base.py b/adala/agents/base.py index 4bba637..e0f3333 100644 --- a/adala/agents/base.py +++ b/adala/agents/base.py @@ -7,7 +7,7 @@ SerializeAsAny, ) from abc import ABC -from typing import Optional, Dict, Union, Tuple +from typing import Optional, Dict, Union, Tuple, List from rich import print import yaml @@ -16,8 +16,10 @@ from adala.runtimes.base import Runtime, AsyncRuntime from adala.runtimes._openai import OpenAIChatRuntime from adala.skills._base import Skill +from adala.skills.collection.text_generation import TextGenerationSkill from adala.memories.base import Memory from adala.skills.skillset import SkillSet, LinearSkillSet +from adala.skills.collection.prompt_improvement import PromptImprovementSkillResponseModel, ErrorResponseModel, get_prompt_improvement_inputs, get_prompt_improvement_skill, ImprovedPromptResponse from adala.utils.logs import ( print_dataframe, print_text, @@ -61,7 +63,7 @@ class Agent(BaseModel, ABC): default_factory=lambda: {"default": OpenAIChatRuntime(model="gpt-3.5-turbo")} ) default_runtime: str = "default" - teacher_runtimes: Dict[str, SerializeAsAny[Runtime]] = Field( + teacher_runtimes: Dict[str, SerializeAsAny[Union[Runtime, AsyncRuntime]]] = Field( default_factory=lambda: {"default": None} ) default_teacher_runtime: str = "default" @@ -118,7 +120,7 @@ def skills_validator(cls, v) -> SkillSet: f"skills must be of type SkillSet or Skill, but received type {type(v)}" ) - @field_validator("runtimes", mode="before") + @field_validator("runtimes", "teacher_runtimes", mode="before") def runtimes_validator(cls, v) -> Dict[str, Union[Runtime, AsyncRuntime]]: """ Validates and creates runtimes @@ -392,6 +394,48 @@ def learn( break print_text("Train is done!") + + + async def arefine_skill(self, skill_name: str, input_variables: List[str]) -> ImprovedPromptResponse: + """ + beta v2 of Agent.learn() that is: + - compatible with the newer LiteLLM runtimes + - compatible with the newer response_model output formats for skills + - returns chain of thought reasoning in a legible format + + Limitations so far: + - single skill at a time + - only returns the improved input_template, doesn't modify the skill in place + - doesn't use examples/feedback + - no iterations/variable cost + """ + + skill = self.skills[skill_name] + if not isinstance(skill, TextGenerationSkill): + raise ValueError(f"Skill {skill_name} is not a TextGenerationSkill") + + inputs = get_prompt_improvement_inputs(skill, input_variables, self.get_runtime().model) + # this is why this 
function cannot be parallelized over skills - input variables are injected into the response model so that they can be validated with LLM feedback within a single Instructor call + # TODO find a way to get around this and use batch_to_batch or a higher-level optimizer over all skills in the skillset + prompt_improvement_skill = get_prompt_improvement_skill(input_variables) + # awkward to go from response model -> dict -> df -> dict -> response model + df = InternalDataFrame.from_records([inputs]) + response_df = await prompt_improvement_skill.aapply( + input=df, + runtime=self.get_teacher_runtime(), + ) + response_dct = response_df.iloc[0].to_dict() + + # get tokens and token cost + data = ImprovedPromptResponse(**response_dct) + + if response_dct.get("_adala_error", False): + # insert error into Response + data.output = ErrorResponseModel(**response_dct) + else: + # insert output into Response + data.output = PromptImprovementSkillResponseModel(**response_dct) + return data def create_agent_from_dict(json_dict: Dict): diff --git a/server/prompt_improvement_skill.py b/adala/skills/collection/prompt_improvement.py similarity index 68% rename from server/prompt_improvement_skill.py rename to adala/skills/collection/prompt_improvement.py index 78abbef..546e3e3 100644 --- a/server/prompt_improvement_skill.py +++ b/adala/skills/collection/prompt_improvement.py @@ -1,9 +1,11 @@ -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, field_validator, Field, ConfigDict from adala.skills import Skill -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Union from adala.skills.collection.text_generation import TextGenerationSkill +# NOTE: these response models are converging with the LSE ResultHandler, slowly pushing typing deeper into the lib with the end goal of combining them + class PromptImprovementSkillResponseModel(BaseModel): # hidden variable, used for validation @@ -13,7 +15,13 @@ class PromptImprovementSkillResponseModel(BaseModel): # NOTE: not exposed in LSE yet, so default is always used. Should improve this as well when we expose it. # improved_system_prompt: str - + model_config = ConfigDict( + # omit other fields + extra="ignore", + # guard against name collisions with other fields + populate_by_name=False, + ) + @field_validator("improved_user_prompt", mode="after") def validate_used_variables(cls, value: str) -> str: @@ -28,6 +36,31 @@ def validate_used_variables(cls, value: str) -> str: raise ValueError(f"Invalid variable used in prompt: {e}. 
Valid variables are: {cls._input_variables}") return value + +class ErrorResponseModel(BaseModel): + message: str = Field(..., alias="_adala_message") + details: str = Field(..., alias="_adala_details") + + model_config = ConfigDict( + # omit other fields + extra="ignore", + # guard against name collisions with other fields + populate_by_name=False, + ) + +class ImprovedPromptResponse(BaseModel): + + output: Union[PromptImprovementSkillResponseModel, ErrorResponseModel] + + prompt_tokens: int = Field(alias="_prompt_tokens") + completion_tokens: int = Field(alias="_completion_tokens") + + # these can fail to calculate + prompt_cost_usd: Optional[float] = Field(alias="_prompt_cost_usd") + completion_cost_usd: Optional[float] = Field(alias="_completion_cost_usd") + total_cost_usd: Optional[float] = Field(alias="_total_cost_usd") + + def get_prompt_improvement_inputs(student_skill: Skill, input_variables: List[str], student_model: str) -> Dict[str, Any]: diff --git a/server/app.py b/server/app.py index 263e3fb..b85e6b0 100644 --- a/server/app.py +++ b/server/app.py @@ -15,14 +15,14 @@ from aiokafka.errors import UnknownTopicOrPartitionError from fastapi import HTTPException, Depends from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel, SerializeAsAny, field_validator, Field +from pydantic import BaseModel, SerializeAsAny, field_validator, Field, model_validator from redis import Redis import time import uvicorn from server.handlers.result_handlers import ResultHandler from server.log_middleware import LogMiddleware -from server.prompt_improvement_skill import PromptImprovementSkillResponseModel, get_prompt_improvement_inputs, get_prompt_improvement_skill +from adala.skills.collection.prompt_improvement import ImprovedPromptResponse, ErrorResponseModel from server.tasks.stream_inference import streaming_parent_task from server.utils import ( Settings, @@ -315,42 +315,27 @@ class ImprovedPromptRequest(BaseModel): """ Request model for improving a prompt. 
""" - student_skill: Skill - student_model: str - teacher_runtime: AsyncRuntime - input_variables: List[str] - - # same code as for ResultHandler in SubmitStreamingRequest - @field_validator("student_skill", mode="before") - def validate_skill(cls, value: Dict) -> Skill: - if "type" not in value: - raise HTTPException( - status_code=400, detail="Missing type in student_skill" - ) - skill = Skill.create_from_registry(value.pop("type"), **value) - return skill - - # same code as for ResultHandler in SubmitStreamingRequest - @field_validator("teacher_runtime", mode="before") - def validate_teacher_runtime(cls, value: Dict) -> AsyncRuntime: - if "type" not in value: - raise HTTPException( - status_code=400, detail="Missing type in teacher_runtime" - ) - runtime = AsyncRuntime.create_from_registry(value.pop("type"), **value) - return runtime + agent: Agent + skill_to_improve: str + input_variables: Optional[Dict[str, List[str]]] = Field( + default=None, + description="List of variables available to use in the input template of the skill, in case any exist that are not currently used" + ) -class ImprovedPromptResponse(BaseModel): + @field_validator("agent", mode="after") + def validate_teacher_runtime(cls, agent: Agent) -> Agent: + if not isinstance(agent.get_teacher_runtime(), AsyncRuntime): + raise ValueError("Default teacher runtime must be an AsyncRuntime") + return agent - output: Optional[PromptImprovementSkillResponseModel] = None - - prompt_tokens: int = Field(alias="_prompt_tokens") - completion_tokens: int = Field(alias="_completion_tokens") + @model_validator(mode="after") + def set_input_variable_list(self): + skill = self.agent.skills[self.skill_to_improve] + if self.input_variables is None: + self.input_variables = skill.get_input_fields() + return self + - # these can fail to calculate - prompt_cost_usd: Optional[float] = Field(alias="_prompt_cost_usd") - completion_cost_usd: Optional[float] = Field(alias="_completion_cost_usd") - total_cost_usd: Optional[float] = Field(alias="_total_cost_usd") @app.post("/improved-prompt", response_model=Response[ImprovedPromptResponse]) async def improved_prompt(request: ImprovedPromptRequest): @@ -364,25 +349,14 @@ async def improved_prompt(request: ImprovedPromptRequest): Response: Response model for prompt improvement skill """ - inputs = get_prompt_improvement_inputs(request.student_skill, request.input_variables, request.student_model) - prompt_improvement_skill = get_prompt_improvement_skill(request.input_variables) - # someday we can stop doing this... 
- df = pd.DataFrame.from_records([inputs]) - response_df = await prompt_improvement_skill.aapply( - input=df, - runtime=request.teacher_runtime, - ) - response_dct = response_df.iloc[0].to_dict() - - # get tokens and token cost - data = ImprovedPromptResponse(**response_dct) + agent = request.agent + data = await agent.arefine_skill(request.skill_to_improve, request.input_variables) - if response_dct.get("_adala_error", False): + if isinstance(data.output, ErrorResponseModel): # insert error into Response - return Response(success=False, data=data, message=response_dct["_adala_details"], errors=[response_dct["_adala_message"]]) + return Response(success=False, data=data, message=data.output.details, errors=[data.output.message]) else: - # insert output into Response - data.output = PromptImprovementSkillResponseModel(**response_dct) + # return output return Response(data=data) diff --git a/tests/test_server.py b/tests/test_server.py index 68e20ed..61be49c 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -474,3 +474,15 @@ def test_streaming_azure(client): assert ( actual_output == expected_output ), "adala did not return expected output" + + +def test_prompt_improvement_endpoint(client): + agent = SUBMIT_PAYLOAD["agent"] + agent['teacher_runtimes'] = agent['runtimes'] + agent['teacher_runtimes']['default']['model'] = 'gpt-4o' + resp = client.post("/improved-prompt", json={ + "agent": agent, + "skill_to_improve": "text_classifier", + "input_variables": ["text"], + }) + resp.raise_for_status() \ No newline at end of file From 7a2b95f82fe8a26a11c619b134f961fb09e3d033 Mon Sep 17 00:00:00 2001 From: Matt Bernstein Date: Tue, 1 Oct 2024 10:20:56 -0400 Subject: [PATCH 3/7] add failure test --- adala/agents/base.py | 26 ++++++++++++-------------- server/app.py | 2 +- tests/test_server.py | 35 ++++++++++++++++++++++++++++++++++- 3 files changed, 47 insertions(+), 16 deletions(-) diff --git a/adala/agents/base.py b/adala/agents/base.py index e0f3333..3b80504 100644 --- a/adala/agents/base.py +++ b/adala/agents/base.py @@ -15,8 +15,7 @@ from adala.environments.static_env import StaticEnvironment from adala.runtimes.base import Runtime, AsyncRuntime from adala.runtimes._openai import OpenAIChatRuntime -from adala.skills._base import Skill -from adala.skills.collection.text_generation import TextGenerationSkill +from adala.skills._base import Skill, TransformSkill from adala.memories.base import Memory from adala.skills.skillset import SkillSet, LinearSkillSet from adala.skills.collection.prompt_improvement import PromptImprovementSkillResponseModel, ErrorResponseModel, get_prompt_improvement_inputs, get_prompt_improvement_skill, ImprovedPromptResponse @@ -411,12 +410,12 @@ async def arefine_skill(self, skill_name: str, input_variables: List[str]) -> Im """ skill = self.skills[skill_name] - if not isinstance(skill, TextGenerationSkill): - raise ValueError(f"Skill {skill_name} is not a TextGenerationSkill") + if not isinstance(skill, TransformSkill): + raise ValueError(f"Skill {skill_name} is not a TransformSkill") inputs = get_prompt_improvement_inputs(skill, input_variables, self.get_runtime().model) # this is why this function cannot be parallelized over skills - input variables are injected into the response model so that they can be validated with LLM feedback within a single Instructor call - # TODO find a way to get around this and use batch_to_batch or a higher-level optimizer over all skills in the skillset + # TODO get around this and use batch_to_batch or a higher-level 
optimizer over all skills in the skillset prompt_improvement_skill = get_prompt_improvement_skill(input_variables) # awkward to go from response model -> dict -> df -> dict -> response model df = InternalDataFrame.from_records([inputs]) @@ -426,16 +425,15 @@ async def arefine_skill(self, skill_name: str, input_variables: List[str]) -> Im ) response_dct = response_df.iloc[0].to_dict() - # get tokens and token cost - data = ImprovedPromptResponse(**response_dct) - - if response_dct.get("_adala_error", False): - # insert error into Response - data.output = ErrorResponseModel(**response_dct) + # unflatten the response + if response_dct.pop("_adala_error", False): + output = ErrorResponseModel(**response_dct) else: - # insert output into Response - data.output = PromptImprovementSkillResponseModel(**response_dct) - return data + output = PromptImprovementSkillResponseModel(**response_dct) + + # get tokens and token cost + resp = ImprovedPromptResponse(output=output, **response_dct) + return resp def create_agent_from_dict(json_dict: Dict): diff --git a/server/app.py b/server/app.py index b85e6b0..e7a155c 100644 --- a/server/app.py +++ b/server/app.py @@ -317,7 +317,7 @@ class ImprovedPromptRequest(BaseModel): """ agent: Agent skill_to_improve: str - input_variables: Optional[Dict[str, List[str]]] = Field( + input_variables: Optional[List[str]] = Field( default=None, description="List of variables available to use in the input template of the skill, in case any exist that are not currently used" ) diff --git a/tests/test_server.py b/tests/test_server.py index 61be49c..d01845a 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -6,6 +6,7 @@ from tempfile import NamedTemporaryFile import pandas as pd from copy import deepcopy +from unittest.mock import patch # TODO manage which keys correspond to which models/deployments, probably using a litellm Router OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") @@ -477,6 +478,8 @@ def test_streaming_azure(client): def test_prompt_improvement_endpoint(client): + + # test success agent = SUBMIT_PAYLOAD["agent"] agent['teacher_runtimes'] = agent['runtimes'] agent['teacher_runtimes']['default']['model'] = 'gpt-4o' @@ -485,4 +488,34 @@ def test_prompt_improvement_endpoint(client): "skill_to_improve": "text_classifier", "input_variables": ["text"], }) - resp.raise_for_status() \ No newline at end of file + resp.raise_for_status() + assert resp.json()['success'] + assert resp.json()["data"]["output"]["improved_user_prompt"] is not None + + # test failure in payload + agent['teacher_runtimes']['default']['model'] = 'nonexistent' + resp = client.post("/improved-prompt", json={ + "agent": agent, + "skill_to_improve": "text_classifier", + "input_variables": ["text"], + }) + assert resp.status_code == 422 + + # test runtime failure + agent['teacher_runtimes']['default']['model'] = 'gpt-4o' + with patch('instructor.AsyncInstructor.create_with_completion') as mock_create: + def side_effect(*args, **kwargs): + if 'text_classifier' in str(kwargs): + raise Exception("Simulated OpenAI API failure for text_classifier") + return mock_create.return_value + + mock_create.side_effect = side_effect + + resp = client.post("/improved-prompt", json={ + "agent": agent, + "skill_to_improve": "text_classifier", + "input_variables": ["text"], + }) + resp.raise_for_status() + assert not resp.json()['success'] + assert "Simulated OpenAI API failure for text_classifier" == resp.json()["message"] \ No newline at end of file From dd6ad8ffd8edb03c13e45e523869fe5689eac2a1 Mon Sep 17 
00:00:00 2001 From: Matt Bernstein Date: Tue, 1 Oct 2024 10:21:50 -0400 Subject: [PATCH 4/7] black --- adala/agents/base.py | 23 +++++-- adala/skills/collection/prompt_improvement.py | 37 +++++++---- server/app.py | 20 ++++-- tests/test_server.py | 64 +++++++++++-------- 4 files changed, 90 insertions(+), 54 deletions(-) diff --git a/adala/agents/base.py b/adala/agents/base.py index 3b80504..06e624b 100644 --- a/adala/agents/base.py +++ b/adala/agents/base.py @@ -18,7 +18,13 @@ from adala.skills._base import Skill, TransformSkill from adala.memories.base import Memory from adala.skills.skillset import SkillSet, LinearSkillSet -from adala.skills.collection.prompt_improvement import PromptImprovementSkillResponseModel, ErrorResponseModel, get_prompt_improvement_inputs, get_prompt_improvement_skill, ImprovedPromptResponse +from adala.skills.collection.prompt_improvement import ( + PromptImprovementSkillResponseModel, + ErrorResponseModel, + get_prompt_improvement_inputs, + get_prompt_improvement_skill, + ImprovedPromptResponse, +) from adala.utils.logs import ( print_dataframe, print_text, @@ -393,9 +399,10 @@ def learn( break print_text("Train is done!") - - async def arefine_skill(self, skill_name: str, input_variables: List[str]) -> ImprovedPromptResponse: + async def arefine_skill( + self, skill_name: str, input_variables: List[str] + ) -> ImprovedPromptResponse: """ beta v2 of Agent.learn() that is: - compatible with the newer LiteLLM runtimes @@ -408,12 +415,14 @@ async def arefine_skill(self, skill_name: str, input_variables: List[str]) -> Im - doesn't use examples/feedback - no iterations/variable cost """ - + skill = self.skills[skill_name] if not isinstance(skill, TransformSkill): raise ValueError(f"Skill {skill_name} is not a TransformSkill") - - inputs = get_prompt_improvement_inputs(skill, input_variables, self.get_runtime().model) + + inputs = get_prompt_improvement_inputs( + skill, input_variables, self.get_runtime().model + ) # this is why this function cannot be parallelized over skills - input variables are injected into the response model so that they can be validated with LLM feedback within a single Instructor call # TODO get around this and use batch_to_batch or a higher-level optimizer over all skills in the skillset prompt_improvement_skill = get_prompt_improvement_skill(input_variables) @@ -424,7 +433,7 @@ async def arefine_skill(self, skill_name: str, input_variables: List[str]) -> Im runtime=self.get_teacher_runtime(), ) response_dct = response_df.iloc[0].to_dict() - + # unflatten the response if response_dct.pop("_adala_error", False): output = ErrorResponseModel(**response_dct) diff --git a/adala/skills/collection/prompt_improvement.py b/adala/skills/collection/prompt_improvement.py index 546e3e3..7cd9ad2 100644 --- a/adala/skills/collection/prompt_improvement.py +++ b/adala/skills/collection/prompt_improvement.py @@ -6,8 +6,9 @@ # NOTE: these response models are converging with the LSE ResultHandler, slowly pushing typing deeper into the lib with the end goal of combining them + class PromptImprovementSkillResponseModel(BaseModel): - + # hidden variable, used for validation _input_variables: List[str] reasoning: str @@ -21,22 +22,29 @@ class PromptImprovementSkillResponseModel(BaseModel): # guard against name collisions with other fields populate_by_name=False, ) - + @field_validator("improved_user_prompt", mode="after") def validate_used_variables(cls, value: str) -> str: start_variable_idx = value.find("{") end_variable_idx = value.rfind("}") - if 
start_variable_idx == -1 or end_variable_idx == -1 or start_variable_idx >= end_variable_idx: + if ( + start_variable_idx == -1 + or end_variable_idx == -1 + or start_variable_idx >= end_variable_idx + ): raise ValueError("At least one input variable must be used in the prompt") try: value.format(**{var: "value" for var in cls._input_variables}) except KeyError as e: - raise ValueError(f"Invalid variable used in prompt: {e}. Valid variables are: {cls._input_variables}") + raise ValueError( + f"Invalid variable used in prompt: {e}. Valid variables are: {cls._input_variables}" + ) return value + class ErrorResponseModel(BaseModel): message: str = Field(..., alias="_adala_message") details: str = Field(..., alias="_adala_details") @@ -48,10 +56,11 @@ class ErrorResponseModel(BaseModel): populate_by_name=False, ) + class ImprovedPromptResponse(BaseModel): - + output: Union[PromptImprovementSkillResponseModel, ErrorResponseModel] - + prompt_tokens: int = Field(alias="_prompt_tokens") completion_tokens: int = Field(alias="_completion_tokens") @@ -61,9 +70,9 @@ class ImprovedPromptResponse(BaseModel): total_cost_usd: Optional[float] = Field(alias="_total_cost_usd") - - -def get_prompt_improvement_inputs(student_skill: Skill, input_variables: List[str], student_model: str) -> Dict[str, Any]: +def get_prompt_improvement_inputs( + student_skill: Skill, input_variables: List[str], student_model: str +) -> Dict[str, Any]: return { "model": student_model, "task_name": student_skill.name, @@ -79,19 +88,19 @@ def get_prompt_improvement_skill(input_variables: List[str]) -> TextGenerationSk # setting this dynamically - used to validate the improved prompt PromptImprovementSkillResponseModel._input_variables = input_variables - + prompt_improvement_skill = TextGenerationSkill( name="prompt_improvement", instructions="Improve the user prompt for the provided LLM model to complete the task using the provided input variables, with the provided user prompt as a starting point. Variables can be accessed in the user prompt using the format {variable_name} (only the variable values are used, not their names). Make sure your prompt produces output that will continue to conform to the provided json schema. Provide your reasoning for the changes you made to the prompt.", - input_template=''' + input_template=""" Model: {model} Task Name: {task_name} Task Description: {task_description} Input Variables: {input_variables} Current System Prompt: {current_system_prompt} Current User Prompt: {current_user_prompt} - Response JSON Schema: {response_json_schema}''', - response_model=PromptImprovementSkillResponseModel + Response JSON Schema: {response_json_schema}""", + response_model=PromptImprovementSkillResponseModel, ) - + return prompt_improvement_skill diff --git a/server/app.py b/server/app.py index e7a155c..304817e 100644 --- a/server/app.py +++ b/server/app.py @@ -22,7 +22,10 @@ from server.handlers.result_handlers import ResultHandler from server.log_middleware import LogMiddleware -from adala.skills.collection.prompt_improvement import ImprovedPromptResponse, ErrorResponseModel +from adala.skills.collection.prompt_improvement import ( + ImprovedPromptResponse, + ErrorResponseModel, +) from server.tasks.stream_inference import streaming_parent_task from server.utils import ( Settings, @@ -315,11 +318,12 @@ class ImprovedPromptRequest(BaseModel): """ Request model for improving a prompt. 
""" + agent: Agent skill_to_improve: str input_variables: Optional[List[str]] = Field( default=None, - description="List of variables available to use in the input template of the skill, in case any exist that are not currently used" + description="List of variables available to use in the input template of the skill, in case any exist that are not currently used", ) @field_validator("agent", mode="after") @@ -327,14 +331,13 @@ def validate_teacher_runtime(cls, agent: Agent) -> Agent: if not isinstance(agent.get_teacher_runtime(), AsyncRuntime): raise ValueError("Default teacher runtime must be an AsyncRuntime") return agent - + @model_validator(mode="after") def set_input_variable_list(self): skill = self.agent.skills[self.skill_to_improve] if self.input_variables is None: self.input_variables = skill.get_input_fields() return self - @app.post("/improved-prompt", response_model=Response[ImprovedPromptResponse]) @@ -348,13 +351,18 @@ async def improved_prompt(request: ImprovedPromptRequest): Returns: Response: Response model for prompt improvement skill """ - + agent = request.agent data = await agent.arefine_skill(request.skill_to_improve, request.input_variables) if isinstance(data.output, ErrorResponseModel): # insert error into Response - return Response(success=False, data=data, message=data.output.details, errors=[data.output.message]) + return Response( + success=False, + data=data, + message=data.output.details, + errors=[data.output.message], + ) else: # return output return Response(data=data) diff --git a/tests/test_server.py b/tests/test_server.py index d01845a..cecfa0f 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -478,44 +478,54 @@ def test_streaming_azure(client): def test_prompt_improvement_endpoint(client): - + # test success agent = SUBMIT_PAYLOAD["agent"] - agent['teacher_runtimes'] = agent['runtimes'] - agent['teacher_runtimes']['default']['model'] = 'gpt-4o' - resp = client.post("/improved-prompt", json={ - "agent": agent, - "skill_to_improve": "text_classifier", - "input_variables": ["text"], - }) + agent["teacher_runtimes"] = agent["runtimes"] + agent["teacher_runtimes"]["default"]["model"] = "gpt-4o" + resp = client.post( + "/improved-prompt", + json={ + "agent": agent, + "skill_to_improve": "text_classifier", + "input_variables": ["text"], + }, + ) resp.raise_for_status() - assert resp.json()['success'] + assert resp.json()["success"] assert resp.json()["data"]["output"]["improved_user_prompt"] is not None # test failure in payload - agent['teacher_runtimes']['default']['model'] = 'nonexistent' - resp = client.post("/improved-prompt", json={ - "agent": agent, - "skill_to_improve": "text_classifier", - "input_variables": ["text"], - }) + agent["teacher_runtimes"]["default"]["model"] = "nonexistent" + resp = client.post( + "/improved-prompt", + json={ + "agent": agent, + "skill_to_improve": "text_classifier", + "input_variables": ["text"], + }, + ) assert resp.status_code == 422 # test runtime failure - agent['teacher_runtimes']['default']['model'] = 'gpt-4o' - with patch('instructor.AsyncInstructor.create_with_completion') as mock_create: + agent["teacher_runtimes"]["default"]["model"] = "gpt-4o" + with patch("instructor.AsyncInstructor.create_with_completion") as mock_create: + def side_effect(*args, **kwargs): - if 'text_classifier' in str(kwargs): + if "text_classifier" in str(kwargs): raise Exception("Simulated OpenAI API failure for text_classifier") return mock_create.return_value - + mock_create.side_effect = side_effect - - resp = 
client.post("/improved-prompt", json={ - "agent": agent, - "skill_to_improve": "text_classifier", - "input_variables": ["text"], - }) + + resp = client.post( + "/improved-prompt", + json={ + "agent": agent, + "skill_to_improve": "text_classifier", + "input_variables": ["text"], + }, + ) resp.raise_for_status() - assert not resp.json()['success'] - assert "Simulated OpenAI API failure for text_classifier" == resp.json()["message"] \ No newline at end of file + assert not resp.json()["success"] + assert "Simulated OpenAI API failure for text_classifier" == resp.json()["message"] From 1f7e24782751461d7eb76bdc43e3e59b27b450b5 Mon Sep 17 00:00:00 2001 From: Matt Bernstein Date: Tue, 1 Oct 2024 10:47:28 -0400 Subject: [PATCH 5/7] mark as openai --- tests/test_server.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_server.py b/tests/test_server.py index cecfa0f..352c9e0 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -477,6 +477,8 @@ def test_streaming_azure(client): ), "adala did not return expected output" +# can't use vcr here because server makes async requests +@pytest.mark.use_openai def test_prompt_improvement_endpoint(client): # test success From e07c6d1d7f8e30668a415c0e7f7f9d269012724d Mon Sep 17 00:00:00 2001 From: Matt Bernstein Date: Tue, 1 Oct 2024 10:55:07 -0400 Subject: [PATCH 6/7] use parse_template for validation --- adala/skills/collection/prompt_improvement.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/adala/skills/collection/prompt_improvement.py b/adala/skills/collection/prompt_improvement.py index 7cd9ad2..7f4380b 100644 --- a/adala/skills/collection/prompt_improvement.py +++ b/adala/skills/collection/prompt_improvement.py @@ -2,6 +2,7 @@ from adala.skills import Skill from typing import Any, Dict, List, Optional, Union from adala.skills.collection.text_generation import TextGenerationSkill +from adala.utils.parse import parse_template # NOTE: these response models are converging with the LSE ResultHandler, slowly pushing typing deeper into the lib with the end goal of combining them @@ -26,20 +27,14 @@ class PromptImprovementSkillResponseModel(BaseModel): @field_validator("improved_user_prompt", mode="after") def validate_used_variables(cls, value: str) -> str: - start_variable_idx = value.find("{") - end_variable_idx = value.rfind("}") - if ( - start_variable_idx == -1 - or end_variable_idx == -1 - or start_variable_idx >= end_variable_idx - ): + templates = parse_template(value, include_texts=False) + if not templates: raise ValueError("At least one input variable must be used in the prompt") - try: - value.format(**{var: "value" for var in cls._input_variables}) - except KeyError as e: + input_vars_used = [t["text"] for t in templates] + if extra_vars_used := set(input_vars_used) - set(cls._input_variables): raise ValueError( - f"Invalid variable used in prompt: {e}. Valid variables are: {cls._input_variables}" + f"Invalid variable used in prompt: {extra_vars_used}. 
Valid variables are: {cls._input_variables}" ) return value From 330d0363ad6968184a9c4496fb5fefd024c42639 Mon Sep 17 00:00:00 2001 From: Matt Bernstein Date: Tue, 1 Oct 2024 11:17:01 -0400 Subject: [PATCH 7/7] update prompt for prompt improvement skill --- adala/skills/collection/prompt_improvement.py | 64 ++++++++++++++++--- 1 file changed, 56 insertions(+), 8 deletions(-) diff --git a/adala/skills/collection/prompt_improvement.py b/adala/skills/collection/prompt_improvement.py index 7f4380b..2160c0c 100644 --- a/adala/skills/collection/prompt_improvement.py +++ b/adala/skills/collection/prompt_improvement.py @@ -86,15 +86,63 @@ def get_prompt_improvement_skill(input_variables: List[str]) -> TextGenerationSk prompt_improvement_skill = TextGenerationSkill( name="prompt_improvement", - instructions="Improve the user prompt for the provided LLM model to complete the task using the provided input variables, with the provided user prompt as a starting point. Variables can be accessed in the user prompt using the format {variable_name} (only the variable values are used, not their names). Make sure your prompt produces output that will continue to conform to the provided json schema. Provide your reasoning for the changes you made to the prompt.", + # system prompt + # TODO add fewshot examples + instructions=""" + You are a prompt improvement agent. + + # Instructions + + Improve the user prompt for an LLM model to complete a task using input variables, with the provided prompt improvement inputs as a starting point. Provide your reasoning for the changes you made to the prompt. + + + # Notes + + - The inputs available to you are: Model, Task Name, Task Description, Input Variables, Current System Prompt, Current User Prompt, Response JSON Schema. + - Input Variables can be accessed in the user prompt using the format {variable_name} (only the variable values are used, not their names). + - Make sure your prompt produces output that will continue to conform to the Response JSON Schema. + - Provide your reasoning for the changes you made to the prompt. Provide the reasoning before providing the improved prompt. + + """, + # user prompt input_template=""" - Model: {model} - Task Name: {task_name} - Task Description: {task_description} - Input Variables: {input_variables} - Current System Prompt: {current_system_prompt} - Current User Prompt: {current_user_prompt} - Response JSON Schema: {response_json_schema}""", + # Prompt Improvement Inputs + + ## Model + + {model} + + + ## Task Name + + {task_name} + + + ## Task Description + + {task_description} + + + ## Input Variables + + {input_variables} + + + ## Current System Prompt + + {current_system_prompt} + + + ## Current User Prompt + + {current_user_prompt} + + + ## Response JSON Schema + + {response_json_schema} + + """, response_model=PromptImprovementSkillResponseModel, )