Skip to content

Commit

Permalink
Support for new tools and agents API (#73)
Browse files Browse the repository at this point in the history
This PR updates the client SDK with latest tool and agent API changes.
Adds utility to define an agent with memory, which abstracts the
creation of memory bank and memory tool from the user.
Server changes: meta-llama/llama-stack#673
Test plan:

LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml"
pytest -v tests/client-sdk/agents/test_agents.py
  • Loading branch information
dineshyv authored Jan 9, 2025
1 parent f5cffac commit 40da0d0
Show file tree
Hide file tree
Showing 49 changed files with 2,781 additions and 278 deletions.
6 changes: 6 additions & 0 deletions src/llama_stack_client/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,9 @@ def __init__(self, **kwargs: Any) -> None:

class SyncHttpxClientWrapper(DefaultHttpxClient):
def __del__(self) -> None:
if self.is_closed:
return

try:
self.close()
except Exception:
Expand Down Expand Up @@ -1334,6 +1337,9 @@ def __init__(self, **kwargs: Any) -> None:

class AsyncHttpxClientWrapper(DefaultAsyncHttpxClient):
def __del__(self) -> None:
if self.is_closed:
return

try:
# TODO(someday): support non asyncio runtimes here
asyncio.get_running_loop().create_task(self.aclose())
Expand Down
316 changes: 233 additions & 83 deletions src/llama_stack_client/_client.py

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion src/llama_stack_client/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,11 @@ def construct_type(*, value: object, type_: object) -> object:
_, items_type = get_args(type_) # Dict[_, items_type]
return {key: construct_type(value=item, type_=items_type) for key, item in value.items()}

if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)):
if (
not is_literal_type(type_)
and inspect.isclass(origin)
and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel))
):
if is_list(value):
return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value]

Expand Down
22 changes: 13 additions & 9 deletions src/llama_stack_client/lib/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,27 @@
from typing import List, Optional, Tuple, Union

from llama_stack_client import LlamaStackClient
from llama_stack_client.types import (Attachment, ToolResponseMessage,
UserMessage)
from llama_stack_client.types import ToolResponseMessage, UserMessage
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup

from .custom_tool import CustomTool
from .client_tool import ClientTool


class Agent:
def __init__(
self,
client: LlamaStackClient,
agent_config: AgentConfig,
custom_tools: Tuple[CustomTool] = (),
client_tools: Tuple[ClientTool] = (),
memory_bank_id: Optional[str] = None,
):
self.client = client
self.agent_config = agent_config
self.agent_id = self._create_agent(agent_config)
self.custom_tools = {t.get_name(): t for t in custom_tools}
self.client_tools = {t.get_name(): t for t in client_tools}
self.sessions = []
self.memory_bank_id = memory_bank_id

def _create_agent(self, agent_config: AgentConfig) -> int:
agentic_system_create_response = self.client.agents.create(
Expand Down Expand Up @@ -53,31 +55,33 @@ def _has_tool_call(self, chunk):
def _run_tool(self, chunk):
message = chunk.event.payload.turn.output_message
tool_call = message.tool_calls[0]
if tool_call.tool_name not in self.custom_tools:
if tool_call.tool_name not in self.client_tools:
return ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=f"Unknown tool `{tool_call.tool_name}` was called.",
role="ipython",
)
tool = self.custom_tools[tool_call.tool_name]
tool = self.client_tools[tool_call.tool_name]
result_messages = tool.run([message])
next_message = result_messages[0]
return next_message

def create_turn(
self,
messages: List[Union[UserMessage, ToolResponseMessage]],
attachments: Optional[List[Attachment]] = None,
session_id: Optional[str] = None,
toolgroups: Optional[List[Toolgroup]] = None,
documents: Optional[List[Document]] = None,
):
response = self.client.agents.turn.create(
agent_id=self.agent_id,
# use specified session_id or last session created
session_id=session_id or self.session_id[-1],
messages=messages,
attachments=attachments,
stream=True,
documents=documents,
toolgroups=toolgroups,
)
for chunk in response:
if hasattr(chunk, "error"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from abc import abstractmethod
from typing import Dict, List, Union

from llama_stack_client.types import (FunctionCallToolDefinition,
ToolResponseMessage, UserMessage)
from llama_stack_client.types.tool_param_definition_param import \
ToolParamDefinitionParam
from llama_stack_client.types import ToolResponseMessage, UserMessage
from llama_stack_client.types.tool_def_param import (
Parameter,
ToolDefParam,
)


class CustomTool:
class ClientTool:
"""
Developers can define their custom tools that models can use
by extending this class.
Expand All @@ -37,7 +38,7 @@ def get_description(self) -> str:
raise NotImplementedError

@abstractmethod
def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:
def get_params_definition(self) -> Dict[str, Parameter]:
raise NotImplementedError

def get_instruction_string(self) -> str:
Expand All @@ -48,16 +49,20 @@ def parameters_for_system_prompt(self) -> str:
{
"name": self.get_name(),
"description": self.get_description(),
"parameters": {name: definition.__dict__ for name, definition in self.get_params_definition().items()},
"parameters": {
name: definition.__dict__
for name, definition in self.get_params_definition().items()
},
}
)

def get_tool_definition(self) -> FunctionCallToolDefinition:
return FunctionCallToolDefinition(
type="function_call",
function_name=self.get_name(),
def get_tool_definition(self) -> ToolDefParam:
return ToolDefParam(
name=self.get_name(),
description=self.get_description(),
parameters=self.get_params_definition(),
parameters=list(self.get_params_definition().values()),
metadata={},
tool_prompt_format="python_list",
)

@abstractmethod
Expand Down
32 changes: 14 additions & 18 deletions src/llama_stack_client/lib/agents/event_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,25 +125,21 @@ def _get_log_event(self, chunk, previous_event_type=None, previous_step_type=Non
)

for r in details.tool_responses:
yield LogEvent(
role=step_type,
content=f"Tool:{r.tool_name} Response:{r.content}",
color="green",
)

# memory retrieval
if step_type == "memory_retrieval" and event_type == "step_complete":
details = event.payload.step_details
inserted_context = interleaved_content_as_str(details.inserted_context)
content = (
f"fetched {len(inserted_context)} bytes from {details.memory_bank_ids}"
)
if r.tool_name == "query_memory":
inserted_context = interleaved_content_as_str(r.content)
content = f"fetched {len(inserted_context)} bytes from memory"

yield LogEvent(
role=step_type,
content=content,
color="cyan",
)
yield LogEvent(
role=step_type,
content=content,
color="cyan",
)
else:
yield LogEvent(
role=step_type,
content=f"Tool:{r.tool_name} Response:{r.content}",
color="green",
)

def _get_event_type_step_type(self, chunk):
if hasattr(chunk, "event"):
Expand Down
42 changes: 42 additions & 0 deletions src/llama_stack_client/resources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@
EvalResourceWithStreamingResponse,
AsyncEvalResourceWithStreamingResponse,
)
from .tools import (
ToolsResource,
AsyncToolsResource,
ToolsResourceWithRawResponse,
AsyncToolsResourceWithRawResponse,
ToolsResourceWithStreamingResponse,
AsyncToolsResourceWithStreamingResponse,
)
from .agents import (
AgentsResource,
AsyncAgentsResource,
Expand Down Expand Up @@ -120,6 +128,14 @@
EvalTasksResourceWithStreamingResponse,
AsyncEvalTasksResourceWithStreamingResponse,
)
from .toolgroups import (
ToolgroupsResource,
AsyncToolgroupsResource,
ToolgroupsResourceWithRawResponse,
AsyncToolgroupsResourceWithRawResponse,
ToolgroupsResourceWithStreamingResponse,
AsyncToolgroupsResourceWithStreamingResponse,
)
from .memory_banks import (
MemoryBanksResource,
AsyncMemoryBanksResource,
Expand All @@ -128,6 +144,14 @@
MemoryBanksResourceWithStreamingResponse,
AsyncMemoryBanksResourceWithStreamingResponse,
)
from .tool_runtime import (
ToolRuntimeResource,
AsyncToolRuntimeResource,
ToolRuntimeResourceWithRawResponse,
AsyncToolRuntimeResourceWithRawResponse,
ToolRuntimeResourceWithStreamingResponse,
AsyncToolRuntimeResourceWithStreamingResponse,
)
from .post_training import (
PostTrainingResource,
AsyncPostTrainingResource,
Expand Down Expand Up @@ -162,6 +186,24 @@
)

__all__ = [
"ToolgroupsResource",
"AsyncToolgroupsResource",
"ToolgroupsResourceWithRawResponse",
"AsyncToolgroupsResourceWithRawResponse",
"ToolgroupsResourceWithStreamingResponse",
"AsyncToolgroupsResourceWithStreamingResponse",
"ToolsResource",
"AsyncToolsResource",
"ToolsResourceWithRawResponse",
"AsyncToolsResourceWithRawResponse",
"ToolsResourceWithStreamingResponse",
"AsyncToolsResourceWithStreamingResponse",
"ToolRuntimeResource",
"AsyncToolRuntimeResource",
"ToolRuntimeResourceWithRawResponse",
"AsyncToolRuntimeResourceWithRawResponse",
"ToolRuntimeResourceWithStreamingResponse",
"AsyncToolRuntimeResourceWithStreamingResponse",
"AgentsResource",
"AsyncAgentsResource",
"AgentsResourceWithRawResponse",
Expand Down
33 changes: 21 additions & 12 deletions src/llama_stack_client/resources/agents/turn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Any, Iterable, cast
from typing import Any, List, Iterable, cast
from typing_extensions import Literal, overload

import httpx
Expand All @@ -26,7 +26,6 @@
from ..._base_client import make_request_options
from ...types.agents import turn_create_params, turn_retrieve_params
from ...types.agents.turn import Turn
from ...types.shared_params.attachment import Attachment
from ...types.agents.turn_create_response import TurnCreateResponse

__all__ = ["TurnResource", "AsyncTurnResource"]
Expand Down Expand Up @@ -59,8 +58,9 @@ def create(
agent_id: str,
messages: Iterable[turn_create_params.Message],
session_id: str,
attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN,
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
stream: Literal[False] | NotGiven = NOT_GIVEN,
toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN,
x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
Expand Down Expand Up @@ -89,7 +89,8 @@ def create(
messages: Iterable[turn_create_params.Message],
session_id: str,
stream: Literal[True],
attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN,
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN,
x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
Expand Down Expand Up @@ -118,7 +119,8 @@ def create(
messages: Iterable[turn_create_params.Message],
session_id: str,
stream: bool,
attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN,
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN,
x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
Expand Down Expand Up @@ -146,8 +148,9 @@ def create(
agent_id: str,
messages: Iterable[turn_create_params.Message],
session_id: str,
attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN,
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN,
toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN,
x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
Expand All @@ -170,8 +173,9 @@ def create(
"agent_id": agent_id,
"messages": messages,
"session_id": session_id,
"attachments": attachments,
"documents": documents,
"stream": stream,
"toolgroups": toolgroups,
},
turn_create_params.TurnCreateParams,
),
Expand Down Expand Up @@ -261,8 +265,9 @@ async def create(
agent_id: str,
messages: Iterable[turn_create_params.Message],
session_id: str,
attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN,
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
stream: Literal[False] | NotGiven = NOT_GIVEN,
toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN,
x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
Expand Down Expand Up @@ -291,7 +296,8 @@ async def create(
messages: Iterable[turn_create_params.Message],
session_id: str,
stream: Literal[True],
attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN,
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN,
x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
Expand Down Expand Up @@ -320,7 +326,8 @@ async def create(
messages: Iterable[turn_create_params.Message],
session_id: str,
stream: bool,
attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN,
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN,
x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
Expand Down Expand Up @@ -348,8 +355,9 @@ async def create(
agent_id: str,
messages: Iterable[turn_create_params.Message],
session_id: str,
attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN,
documents: Iterable[turn_create_params.Document] | NotGiven = NOT_GIVEN,
stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN,
toolgroups: List[turn_create_params.Toolgroup] | NotGiven = NOT_GIVEN,
x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
Expand All @@ -372,8 +380,9 @@ async def create(
"agent_id": agent_id,
"messages": messages,
"session_id": session_id,
"attachments": attachments,
"documents": documents,
"stream": stream,
"toolgroups": toolgroups,
},
turn_create_params.TurnCreateParams,
),
Expand Down
Loading

0 comments on commit 40da0d0

Please sign in to comment.