feat: DIA-1402: V1-Submit Prompt auto-refinement job #214

Open: wants to merge 8 commits into master
Changes from all commits
59 changes: 55 additions & 4 deletions adala/agents/base.py
@@ -7,17 +7,24 @@
SerializeAsAny,
)
from abc import ABC
from typing import Optional, Dict, Union, Tuple
from typing import Optional, Dict, Union, Tuple, List
from rich import print
import yaml

from adala.environments.base import Environment, AsyncEnvironment, EnvironmentFeedback
from adala.environments.static_env import StaticEnvironment
from adala.runtimes.base import Runtime, AsyncRuntime
from adala.runtimes._openai import OpenAIChatRuntime
from adala.skills._base import Skill
from adala.skills._base import Skill, TransformSkill
from adala.memories.base import Memory
from adala.skills.skillset import SkillSet, LinearSkillSet
from adala.skills.collection.prompt_improvement import (
PromptImprovementSkillResponseModel,
ErrorResponseModel,
get_prompt_improvement_inputs,
get_prompt_improvement_skill,
ImprovedPromptResponse,
)
from adala.utils.logs import (
print_dataframe,
print_text,
@@ -61,7 +68,7 @@ class Agent(BaseModel, ABC):
default_factory=lambda: {"default": OpenAIChatRuntime(model="gpt-3.5-turbo")}
)
default_runtime: str = "default"
teacher_runtimes: Dict[str, SerializeAsAny[Runtime]] = Field(
teacher_runtimes: Dict[str, SerializeAsAny[Union[Runtime, AsyncRuntime]]] = Field(
default_factory=lambda: {"default": None}
)
default_teacher_runtime: str = "default"
@@ -118,7 +125,7 @@ def skills_validator(cls, v) -> SkillSet:
f"skills must be of type SkillSet or Skill, but received type {type(v)}"
)

@field_validator("runtimes", mode="before")
@field_validator("runtimes", "teacher_runtimes", mode="before")
def runtimes_validator(cls, v) -> Dict[str, Union[Runtime, AsyncRuntime]]:
"""
Validates and creates runtimes
@@ -393,6 +400,50 @@ def learn(

print_text("Train is done!")

async def arefine_skill(
self, skill_name: str, input_variables: List[str]
) -> ImprovedPromptResponse:
"""
beta v2 of Agent.learn() that is:
- compatible with the newer LiteLLM runtimes
- compatible with the newer response_model output formats for skills
- returns chain of thought reasoning in a legible format

Limitations so far:
- single skill at a time
- only returns the improved input_template, doesn't modify the skill in place
- doesn't use examples/feedback
- no iterations/variable cost
"""

skill = self.skills[skill_name]
if not isinstance(skill, TransformSkill):
raise ValueError(f"Skill {skill_name} is not a TransformSkill")

inputs = get_prompt_improvement_inputs(
skill, input_variables, self.get_runtime().model
)
# this is why this function cannot be parallelized over skills - input variables are injected into the response model so that they can be validated with LLM feedback within a single Instructor call
# TODO get around this and use batch_to_batch or a higher-level optimizer over all skills in the skillset
prompt_improvement_skill = get_prompt_improvement_skill(input_variables)
# awkward to go from response model -> dict -> df -> dict -> response model
df = InternalDataFrame.from_records([inputs])
response_df = await prompt_improvement_skill.aapply(
input=df,
runtime=self.get_teacher_runtime(),
)
response_dct = response_df.iloc[0].to_dict()

# unflatten the response
if response_dct.pop("_adala_error", False):
output = ErrorResponseModel(**response_dct)
else:
output = PromptImprovementSkillResponseModel(**response_dct)

# get tokens and token cost
resp = ImprovedPromptResponse(output=output, **response_dct)
return resp


def create_agent_from_dict(json_dict: Dict):
"""
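A minimal usage sketch for the new Agent.arefine_skill() method above. The skill name, input variable, and surrounding agent setup are illustrative assumptions, not part of this diff; the agent is expected to hold a TransformSkill and an async teacher runtime.

import asyncio

from adala.agents import Agent
from adala.skills.collection.prompt_improvement import ErrorResponseModel


async def refine(agent: Agent) -> None:
    # Assumes `agent` is already configured with a TransformSkill named "my_skill"
    # and an async teacher runtime; both names are illustrative.
    resp = await agent.arefine_skill(skill_name="my_skill", input_variables=["text"])
    if isinstance(resp.output, ErrorResponseModel):
        print("Refinement failed:", resp.output.message)
    else:
        print("Reasoning:", resp.output.reasoning)
        print("Improved user prompt:", resp.output.improved_user_prompt)
    # Token usage and (optional) cost accounting come back on the same response.
    print(resp.prompt_tokens, resp.completion_tokens, resp.total_cost_usd)

# asyncio.run(refine(agent))  # run with an Agent instance constructed elsewhere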
149 changes: 149 additions & 0 deletions adala/skills/collection/prompt_improvement.py
@@ -0,0 +1,149 @@
from pydantic import BaseModel, field_validator, Field, ConfigDict
from adala.skills import Skill
from typing import Any, Dict, List, Optional, Union
from adala.skills.collection.text_generation import TextGenerationSkill
from adala.utils.parse import parse_template


# NOTE: these response models are converging with the LSE ResultHandler, slowly pushing typing deeper into the lib with the end goal of combining them


class PromptImprovementSkillResponseModel(BaseModel):

# hidden variable, used for validation
_input_variables: List[str]
reasoning: str
improved_user_prompt: str
# NOTE: not exposed in LSE yet, so default is always used. Should improve this as well when we expose it.
# improved_system_prompt: str

model_config = ConfigDict(
# omit other fields
extra="ignore",
# guard against name collisions with other fields
populate_by_name=False,
)

@field_validator("improved_user_prompt", mode="after")
def validate_used_variables(cls, value: str) -> str:

templates = parse_template(value, include_texts=False)
if not templates:
raise ValueError("At least one input variable must be used in the prompt")

input_vars_used = [t["text"] for t in templates]
if extra_vars_used := set(input_vars_used) - set(cls._input_variables):
raise ValueError(
f"Invalid variable used in prompt: {extra_vars_used}. Valid variables are: {cls._input_variables}"
)

return value


class ErrorResponseModel(BaseModel):
message: str = Field(..., alias="_adala_message")
details: str = Field(..., alias="_adala_details")

model_config = ConfigDict(
# omit other fields
extra="ignore",
# guard against name collisions with other fields
populate_by_name=False,
)


class ImprovedPromptResponse(BaseModel):

output: Union[PromptImprovementSkillResponseModel, ErrorResponseModel]

prompt_tokens: int = Field(alias="_prompt_tokens")
completion_tokens: int = Field(alias="_completion_tokens")

# these can fail to calculate
prompt_cost_usd: Optional[float] = Field(alias="_prompt_cost_usd")
completion_cost_usd: Optional[float] = Field(alias="_completion_cost_usd")
total_cost_usd: Optional[float] = Field(alias="_total_cost_usd")


def get_prompt_improvement_inputs(
student_skill: Skill, input_variables: List[str], student_model: str
) -> Dict[str, Any]:
return {
"model": student_model,
"task_name": student_skill.name,
"task_description": student_skill.description,
"input_variables": input_variables,
"current_system_prompt": student_skill.instructions,
"current_user_prompt": student_skill.input_template,
"response_json_schema": student_skill.response_model.model_json_schema(),
}


def get_prompt_improvement_skill(input_variables: List[str]) -> TextGenerationSkill:

# setting this dynamically - used to validate the improved prompt
PromptImprovementSkillResponseModel._input_variables = input_variables

prompt_improvement_skill = TextGenerationSkill(
name="prompt_improvement",
# system prompt
# TODO add fewshot examples
instructions="""
You are a prompt improvement agent.

# Instructions

Improve the user prompt for an LLM model to complete a task using input variables, with the provided prompt improvement inputs as a starting point. Provide your reasoning for the changes you made to the prompt.


# Notes

- The inputs available to you are: Model, Task Name, Task Description, Input Variables, Current System Prompt, Current User Prompt, Response JSON Schema.
- Input Variables can be accessed in the user prompt using the format {variable_name} (only the variable values are used, not their names).
- Make sure your prompt produces output that will continue to conform to the Response JSON Schema.
- Provide your reasoning for the changes you made to the prompt. Provide the reasoning before providing the improved prompt.

""",
# user prompt
input_template="""
# Prompt Improvement Inputs

## Model

{model}


## Task Name

{task_name}


## Task Description

{task_description}


## Input Variables

{input_variables}


## Current System Prompt

{current_system_prompt}


## Current User Prompt

{current_user_prompt}


## Response JSON Schema

{response_json_schema}

""",
response_model=PromptImprovementSkillResponseModel,
)

return prompt_improvement_skill
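A short sketch of how the dynamically bound input variables drive validation of the improved prompt. The variable names and prompt strings are made up, and the example assumes Pydantic v2 semantics for the class-level _input_variables assignment performed by get_prompt_improvement_skill():

from adala.skills.collection.prompt_improvement import (
    PromptImprovementSkillResponseModel,
    get_prompt_improvement_skill,
)

# Constructing the helper skill binds the allowed variables onto the response
# model (class-level assignment), as Agent.arefine_skill() does before applying it.
prompt_improvement_skill = get_prompt_improvement_skill(input_variables=["text"])

# Passes: only {text} is referenced in the improved prompt.
ok = PromptImprovementSkillResponseModel(
    reasoning="Added an explicit output format and referenced the input directly.",
    improved_user_prompt="Classify the sentiment of the following text: {text}",
)
print(ok.improved_user_prompt)

# Fails: {body} is not an allowed input variable, so the validator raises.
try:
    PromptImprovementSkillResponseModel(
        reasoning="Switched to a different input.",
        improved_user_prompt="Summarize {body}",
    )
except ValueError as exc:
    print(exc)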
63 changes: 62 additions & 1 deletion server/app.py
@@ -2,23 +2,30 @@
from typing import Any, Dict, Generic, List, Optional, TypeVar
import os
import json
import pandas as pd

import fastapi
from fastapi import Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from adala.agents import Agent
from adala.skills import Skill
from adala.runtimes import AsyncRuntime
from aiokafka import AIOKafkaProducer
from aiokafka.errors import UnknownTopicOrPartitionError
from fastapi import HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, SerializeAsAny, field_validator
from pydantic import BaseModel, SerializeAsAny, field_validator, Field, model_validator
from redis import Redis
import time
import uvicorn

from server.handlers.result_handlers import ResultHandler
from server.log_middleware import LogMiddleware
from adala.skills.collection.prompt_improvement import (
ImprovedPromptResponse,
ErrorResponseModel,
)
from server.tasks.stream_inference import streaming_parent_task
from server.utils import (
Settings,
@@ -307,6 +314,60 @@ async def ready(redis_conn: Redis = Depends(_get_redis_conn)):
return {"status": "ok"}


class ImprovedPromptRequest(BaseModel):
"""
Request model for improving a prompt.
"""

agent: Agent
skill_to_improve: str
input_variables: Optional[List[str]] = Field(
default=None,
description="List of variables available to use in the input template of the skill, in case any exist that are not currently used",
)

@field_validator("agent", mode="after")
def validate_teacher_runtime(cls, agent: Agent) -> Agent:
if not isinstance(agent.get_teacher_runtime(), AsyncRuntime):
raise ValueError("Default teacher runtime must be an AsyncRuntime")
return agent

@model_validator(mode="after")
def set_input_variable_list(self):
skill = self.agent.skills[self.skill_to_improve]
if self.input_variables is None:
self.input_variables = skill.get_input_fields()
return self


@app.post("/improved-prompt", response_model=Response[ImprovedPromptResponse])
async def improved_prompt(request: ImprovedPromptRequest):
"""
Improve a given prompt using the specified model and variables.

Args:
request (ImprovedPromptRequest): The request model for improving a prompt.

Returns:
Response: Response wrapper containing the prompt improvement result.
"""

agent = request.agent
data = await agent.arefine_skill(request.skill_to_improve, request.input_variables)

if isinstance(data.output, ErrorResponseModel):
# insert error into Response
return Response(
success=False,
data=data,
message=data.output.details,
errors=[data.output.message],
)
else:
# return output
return Response(data=data)


if __name__ == "__main__":
# for debugging
uvicorn.run("app:app", host="0.0.0.0", port=30001)
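For reference, a hedged client-side sketch of calling the new POST /improved-prompt endpoint. The agent payload is only a placeholder for a fully serialized Agent (skills, runtimes, and an async default teacher runtime are required), and the response handling assumes the Response[ImprovedPromptResponse] wrapper shown above serializes its success/data/message/errors fields by name.

import requests

# Placeholder: a fully serialized Agent config (skills, runtimes, teacher_runtimes)
# built elsewhere; the default teacher runtime must be an AsyncRuntime.
agent_config: dict = {}

payload = {
    "agent": agent_config,
    "skill_to_improve": "my_skill",  # must name a TransformSkill on the agent
    "input_variables": ["text"],     # optional; defaults to the skill's input fields
}

resp = requests.post("http://localhost:30001/improved-prompt", json=payload)
body = resp.json()
if body["success"]:
    print(body["data"]["output"]["improved_user_prompt"])
else:
    print(body["message"], body["errors"])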