feat: DIA-1402: V1-Submit Prompt auto-refinement job #214

Merged
merged 14 commits on Oct 6, 2024
55 changes: 50 additions & 5 deletions adala/agents/base.py
@@ -1,4 +1,5 @@
import logging
import traceback
from pydantic import (
BaseModel,
Field,
@@ -7,17 +8,19 @@
SerializeAsAny,
)
from abc import ABC
from typing import Optional, Dict, Union, Tuple
from typing import Optional, Dict, Union, Tuple, List
from rich import print
import yaml

from adala.environments.base import Environment, AsyncEnvironment, EnvironmentFeedback
from adala.environments.static_env import StaticEnvironment
from adala.runtimes.base import Runtime, AsyncRuntime
from adala.runtimes._openai import OpenAIChatRuntime
from adala.skills._base import Skill
from adala.skills._base import Skill, TransformSkill
from adala.memories.base import Memory
from adala.skills.skillset import SkillSet, LinearSkillSet
from adala.skills.collection.prompt_improvement import ImprovedPromptResponse

from adala.utils.logs import (
print_dataframe,
print_text,
Expand All @@ -26,7 +29,7 @@
is_running_in_jupyter,
)
from adala.utils.internal_data import InternalDataFrame

from adala.utils.types import BatchData
logger = logging.getLogger(__name__)


@@ -61,7 +64,7 @@ class Agent(BaseModel, ABC):
default_factory=lambda: {"default": OpenAIChatRuntime(model="gpt-3.5-turbo")}
)
default_runtime: str = "default"
teacher_runtimes: Dict[str, SerializeAsAny[Runtime]] = Field(
teacher_runtimes: Dict[str, SerializeAsAny[Union[Runtime, AsyncRuntime]]] = Field(
default_factory=lambda: {"default": None}
)
default_teacher_runtime: str = "default"
@@ -118,7 +121,7 @@ def skills_validator(cls, v) -> SkillSet:
f"skills must be of type SkillSet or Skill, but received type {type(v)}"
)

@field_validator("runtimes", mode="before")
@field_validator("runtimes", "teacher_runtimes", mode="before")
def runtimes_validator(cls, v) -> Dict[str, Union[Runtime, AsyncRuntime]]:
"""
Validates and creates runtimes
@@ -393,6 +396,48 @@ def learn(

print_text("Train is done!")

async def arefine_skill(
self,
skill_name: str,
input_variables: List[str],
batch_data: Optional[BatchData] = None,
) -> ImprovedPromptResponse:
"""
beta v2 of Agent.learn() that is:
- compatible with the newer LiteLLM runtimes
- compatible with the newer response_model output formats for skills
- returns chain of thought reasoning in a legible format

Limitations so far:
- single skill at a time
- only returns the improved input_template, doesn't modify the skill in place
- doesn't use examples/feedback
- no iterations/variable cost
"""

skill = self.skills[skill_name]
if not isinstance(skill, TransformSkill):
raise ValueError(f"Skill {skill_name} is not a TransformSkill")

# get default runtimes
runtime = self.get_runtime()
teacher_runtime = self.get_teacher_runtime()

# get inputs
# TODO: replace it with async environment.get_data_batch()
if batch_data is None:
predictions = None
else:
inputs = InternalDataFrame.from_records(batch_data or [])
predictions = await self.skills.aapply(inputs, runtime=runtime)

response = await skill.aimprove(
predictions=predictions,
teacher_runtime=teacher_runtime,
target_input_variables=input_variables,
)
return response
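
A minimal usage sketch of the new `arefine_skill` entry point (not part of the diff): the skill name, the input variable, and the omission of `batch_data` are illustrative assumptions, and the agent is presumed to already hold a `TransformSkill` plus configured default and teacher runtimes.

```python
# Illustrative sketch only; "classification" and "text" are assumed names.
import asyncio

async def refine_example(agent):
    # With batch_data=None, no predictions are generated before refinement.
    response = await agent.arefine_skill(
        skill_name="classification",
        input_variables=["text"],
        batch_data=None,
    )
    # response is an ImprovedPromptResponse wrapping either a successful
    # prompt-improvement payload or an ErrorResponseModel.
    print(response.output)

# asyncio.run(refine_example(agent))
```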


def create_agent_from_dict(json_dict: Dict):
"""
41 changes: 40 additions & 1 deletion adala/runtimes/_litellm.py
@@ -277,15 +281,21 @@ def record_to_record(
usage = completion.usage
dct = to_jsonable_python(response)
except IncompleteOutputException as e:
logger.error(f"Incomplete output error: {str(e)}")
logger.error(f"Traceback:\n{traceback.format_exc()}")
usage = e.total_usage
dct = _log_llm_exception(e)
except InstructorRetryException as e:
logger.error(f"Instructor retry error: {str(e)}")
logger.error(f"Traceback:\n{traceback.format_exc()}")
usage = e.total_usage
# get root cause error from retries
n_attempts = e.n_attempts
e = e.__cause__.last_attempt.exception()
dct = _log_llm_exception(e)
except Exception as e:
logger.error(f"Other error: {str(e)}")
logger.error(f"Traceback:\n{traceback.format_exc()}")
# usage = e.total_usage
# not available here, so have to approximate by hand, assuming the same error occurred each time
n_attempts = retries.stop.max_attempt_number
@@ -485,8 +491,41 @@ async def record_to_record(
extra_fields: Optional[Dict[str, Any]] = None,
field_schema: Optional[Dict] = None,
instructions_first: bool = True,
response_model: Optional[Type[BaseModel]] = None,
) -> Dict[str, str]:
raise NotImplementedError("record_to_record is not implemented")
"""
Execute LiteLLM request given record and templates for input,
instructions and output.

Args:
record: Record to be used for input, instructions and output templates.
input_template: Template for input message.
instructions_template: Template for instructions message.
output_template: Template for output message.
extra_fields: Extra fields to be used in templates.
field_schema: Field jsonschema to be used for parsing templates.
instructions_first: If True, instructions will be sent before input.
response_model: Optional Pydantic model used to parse and validate the LLM output.

Returns:
Dict[str, str]: The processed record.
"""
# Create a single-row DataFrame from the input record
input_df = InternalDataFrame([record])

# Use the batch_to_batch method to process the single-row DataFrame
output_df = await self.batch_to_batch(
input_df,
input_template=input_template,
instructions_template=instructions_template,
output_template=output_template,
extra_fields=extra_fields,
field_schema=field_schema,
instructions_first=instructions_first,
response_model=response_model,
)

# Extract the single row from the output DataFrame and convert it to a dictionary
return output_df.iloc[0].to_dict()
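
The async `record_to_record` above is a thin adapter over `batch_to_batch`. Here is a self-contained sketch of the same wrap-and-unwrap pattern, with plain pandas standing in for `InternalDataFrame` and a stub in place of the real batch call (both assumptions):

```python
# Sketch of the single-record delegation pattern; pandas and process_batch are stand-ins.
import asyncio
import pandas as pd

async def process_batch(df: pd.DataFrame) -> pd.DataFrame:
    # Stub for batch_to_batch: return the input with one extra column.
    out = df.copy()
    out["label"] = "neutral"
    return out

async def process_record(record: dict) -> dict:
    input_df = pd.DataFrame([record])          # wrap the record in a one-row frame
    output_df = await process_batch(input_df)  # reuse the batch code path
    return output_df.iloc[0].to_dict()         # unwrap the single output row

# asyncio.run(process_record({"text": "hello"})) -> {"text": "hello", "label": "neutral"}
```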


class LiteLLMVisionRuntime(LiteLLMChatRuntime):
123 changes: 100 additions & 23 deletions adala/skills/_base.py
@@ -1,5 +1,6 @@
import logging
import string
import traceback
from pydantic import (
BaseModel,
Field,
@@ -479,6 +480,50 @@ def improve(
self.instructions = new_prompt


async def aimprove(self, teacher_runtime: AsyncRuntime, target_input_variables: List[str], predictions: Optional[InternalDataFrame] = None):
"""
Improves the skill.
"""

from adala.skills.collection.prompt_improvement import PromptImprovementSkill, ImprovedPromptResponse, ErrorResponseModel, PromptImprovementSkillResponseModel
response_dct = {}
try:
prompt_improvement_skill = PromptImprovementSkill(
skill_to_improve=self,
input_variables=target_input_variables,
)
if predictions is None:
input_df = InternalDataFrame()
else:
input_df = predictions
response_df = await prompt_improvement_skill.aapply(
input=input_df,
runtime=teacher_runtime,
)

# awkward to go from response model -> dict -> df -> dict -> response model
response_dct = response_df.iloc[0].to_dict()

# unflatten the response
if response_dct.pop("_adala_error", False):
output = ErrorResponseModel(**response_dct)
else:
output = PromptImprovementSkillResponseModel(**response_dct)

except Exception as e:
logger.error(f"Error improving skill: {e}. Traceback: {traceback.format_exc()}")
output = ErrorResponseModel(
_adala_message=str(e),
_adala_details=traceback.format_exc(),
)

# get tokens and token cost
resp = ImprovedPromptResponse(output=output, **response_dct)
logger.debug(f"resp: {resp}")

return resp
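
The "unflatten" step above rebuilds a response model from the flat dict stored in the output row, keyed on the `_adala_error` flag. A simplified sketch of that dispatch with stand-in models (the real `ErrorResponseModel` and `PromptImprovementSkillResponseModel` live in `adala.skills.collection.prompt_improvement` and carry different field names):

```python
# Simplified stand-ins; field names here are illustrative, not the real adala models.
from pydantic import BaseModel

class ErrorStub(BaseModel):
    message: str = ""
    details: str = ""

class ImprovedPromptStub(BaseModel):
    reasoning: str = ""
    new_prompt_content: str = ""

def unflatten(row: dict) -> BaseModel:
    # The DataFrame row is a flat dict; a boolean flag decides which model to rebuild.
    if row.pop("error", False):
        return ErrorStub(**row)
    return ImprovedPromptStub(**row)

# unflatten({"error": True, "message": "rate limit hit", "details": "..."})
# unflatten({"reasoning": "...", "new_prompt_content": "..."})
```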


class SampleTransformSkill(TransformSkill):
sample_size: int

@@ -548,30 +593,22 @@ class AnalysisSkill(Skill):
Analysis skill that analyzes a dataframe and returns a record (e.g. for data analysis purposes).
See base class Skill for more information about the attributes.
"""

input_prefix: str = ""
input_separator: str = "\n"
chunk_size: Optional[int] = None

def apply(
self,
input: Union[InternalDataFrame, InternalSeries, Dict],
runtime: Runtime,
) -> InternalDataFrame:
"""
Applies the skill to a dataframe and returns a record.

Args:
input (InternalDataFrame): The input data to be processed.
runtime (Runtime): The runtime instance to be used for processing.
def _iter_over_chunks(self, input: InternalDataFrame, chunk_size: Optional[int] = None):

Returns:
InternalSeries: The record containing the analysis results.
"""
if input.empty:
Contributor comment: this one is a wip, the chunk iteration flow will be improved within the current PR
yield ""
return

if isinstance(input, InternalSeries):
input = input.to_frame()
elif isinstance(input, dict):
input = InternalDataFrame([input])


extra_fields = self._get_extra_fields()

# if chunk_size is specified, split the input into chunks and process each chunk separately
@@ -582,25 +619,65 @@ def apply(
)
else:
chunks = [input]
outputs = []

total = input.shape[0] // self.chunk_size if self.chunk_size is not None else 1
for chunk in tqdm(chunks, desc="Processing chunks", total=total):
agg_chunk = (
chunk.reset_index()
agg_chunk = chunk\
.reset_index()\
.apply(
lambda row: self.input_template.format(
**row, **extra_fields, i=int(row.name) + 1
),
axis=1,
)
.str.cat(sep=self.input_separator)
)
).str.cat(sep=self.input_separator)

yield agg_chunk

def apply(
self,
input: Union[InternalDataFrame, InternalSeries, Dict],
runtime: Runtime,
) -> InternalDataFrame:
"""
Applies the skill to a dataframe and returns a record.

Args:
input (InternalDataFrame): The input data to be processed.
runtime (Runtime): The runtime instance to be used for processing.

Returns:
InternalSeries: The record containing the analysis results.
"""
outputs = []
for agg_chunk in self._iter_over_chunks(input):
output = runtime.record_to_record(
{"input": agg_chunk},
{"input": f"{self.input_prefix}{agg_chunk}"},
input_template="{input}",
output_template=self.output_template,
instructions_template=self.instructions,
instructions_first=self.instructions_first,
response_model=self.response_model,
)
outputs.append(InternalSeries(output))
output = InternalDataFrame(outputs)

return output

async def aapply(
self,
input: Union[InternalDataFrame, InternalSeries, Dict],
runtime: AsyncRuntime,
) -> InternalDataFrame:
"""
Applies the skill to a dataframe and returns a record.
"""
outputs = []
for agg_chunk in self._iter_over_chunks(input):
output = await runtime.record_to_record(
{"input": f"{self.input_prefix}{agg_chunk}"},
input_template="{input}",
output_template=self.output_template,
instructions_template=self.instructions,
extra_fields=self._get_extra_fields(),
instructions_first=self.instructions_first,
response_model=self.response_model,
)
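
For orientation, a conceptual sketch of what the chunked aggregation in `_iter_over_chunks` does: split the frame into fixed-size chunks, render each row through the input template, and join the rendered rows with the separator so every chunk becomes one prompt string. Plain pandas stands in for `InternalDataFrame`, and the template and column names are assumptions.

```python
# Conceptual sketch of chunked prompt aggregation; names and template are illustrative.
from typing import Iterator, Optional
import pandas as pd

def iter_prompt_chunks(
    df: pd.DataFrame,
    template: str,
    separator: str = "\n",
    chunk_size: Optional[int] = None,
) -> Iterator[str]:
    if df.empty:
        yield ""
        return
    chunks = (
        [df.iloc[i:i + chunk_size] for i in range(0, len(df), chunk_size)]
        if chunk_size
        else [df]
    )
    for chunk in chunks:
        # Render each row through the template, numbering rows within the chunk.
        rendered = chunk.reset_index().apply(
            lambda row: template.format(**row, i=int(row.name) + 1), axis=1
        )
        yield separator.join(rendered)

# Example: rows are numbered within each chunk of 2.
# df = pd.DataFrame({"text": ["a", "b", "c"]})
# list(iter_prompt_chunks(df, "{i}. {text}", chunk_size=2))
# -> ["1. a\n2. b", "1. c"]
```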