Make all methods async def again; add completion() for meta-reference #270

Merged
merged 5 commits on Oct 19, 2024
40 changes: 22 additions & 18 deletions docs/resources/llama-stack-spec.html
@@ -21,7 +21,7 @@
"info": {
"title": "[DRAFT] Llama Stack Specification",
"version": "0.0.1",
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-10 15:29:56.831109"
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-18 20:48:17.730988"
},
"servers": [
{
@@ -2830,8 +2830,11 @@
"CompletionResponse": {
"type": "object",
"properties": {
"completion_message": {
"$ref": "#/components/schemas/CompletionMessage"
"content": {
"type": "string"
},
"stop_reason": {
"$ref": "#/components/schemas/StopReason"
},
"logprobs": {
"type": "array",
@@ -2842,7 +2845,8 @@
},
"additionalProperties": false,
"required": [
"completion_message"
"content",
"stop_reason"
],
"title": "Completion response."
},
@@ -6075,49 +6079,49 @@
],
"tags": [
{
"name": "Evaluations"
},
{
"name": "Inspect"
"name": "Models"
},
{
"name": "RewardScoring"
},
{
"name": "Datasets"
"name": "MemoryBanks"
},
{
"name": "Models"
"name": "Shields"
},
{
"name": "Telemetry"
"name": "SyntheticDataGeneration"
},
{
"name": "PostTraining"
"name": "Inference"
},
{
"name": "SyntheticDataGeneration"
"name": "Inspect"
},
{
"name": "BatchInference"
},
{
"name": "Inference"
"name": "Memory"
},
{
"name": "Datasets"
},
{
"name": "Agents"
},
{
"name": "Memory"
"name": "PostTraining"
},
{
"name": "Safety"
"name": "Telemetry"
},
{
"name": "Shields"
"name": "Safety"
},
{
"name": "MemoryBanks"
"name": "Evaluations"
},
{
"name": "BuiltinTool",
31 changes: 17 additions & 14 deletions docs/resources/llama-stack-spec.yaml
@@ -501,14 +501,17 @@ components:
CompletionResponse:
additionalProperties: false
properties:
completion_message:
$ref: '#/components/schemas/CompletionMessage'
content:
type: string
logprobs:
items:
$ref: '#/components/schemas/TokenLogProbs'
type: array
stop_reason:
$ref: '#/components/schemas/StopReason'
required:
- completion_message
- content
- stop_reason
title: Completion response.
type: object
CompletionResponseStreamChunk:
@@ -2507,7 +2510,7 @@ info:
description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\
\ draft and subject to change.\n Generated at 2024-10-10 15:29:56.831109"
\ draft and subject to change.\n Generated at 2024-10-18 20:48:17.730988"
title: '[DRAFT] Llama Stack Specification'
version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@@ -3712,21 +3715,21 @@ security:
servers:
- url: http://any-hosted-llama-stack.com
tags:
- name: Evaluations
- name: Inspect
- name: RewardScoring
- name: Datasets
- name: Models
- name: Telemetry
- name: PostTraining
- name: RewardScoring
- name: MemoryBanks
- name: Shields
- name: SyntheticDataGeneration
- name: BatchInference
- name: Inference
- name: Agents
- name: Inspect
- name: BatchInference
- name: Memory
- name: Datasets
- name: Agents
- name: PostTraining
- name: Telemetry
- name: Safety
- name: Shields
- name: MemoryBanks
- name: Evaluations
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
name: BuiltinTool
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
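
Under the revised CompletionResponse schema above, a non-streaming completion now returns the generated text and stop reason directly instead of a nested CompletionMessage; only content and stop_reason are required, and logprobs stays optional. An illustrative response body (the text and the "end_of_turn" value are assumptions for the example, not taken from the spec):

    {
      "content": "The capital of France is Paris.",
      "stop_reason": "end_of_turn"
    }
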
4 changes: 1 addition & 3 deletions llama_stack/apis/agents/agents.py
@@ -421,10 +421,8 @@ async def create_agent(
agent_config: AgentConfig,
) -> AgentCreateResponse: ...

# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `AgentTurnCreateResponse` depending on the value of `stream`.
@webmethod(route="/agents/turn/create")
def create_agent_turn(
async def create_agent_turn(
self,
agent_id: str,
session_id: str,
6 changes: 3 additions & 3 deletions llama_stack/apis/agents/client.py
@@ -67,14 +67,14 @@ async def create_agent_session(
response.raise_for_status()
return AgentSessionCreateResponse(**response.json())

def create_agent_turn(
async def create_agent_turn(
self,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
if request.stream:
return self._stream_agent_turn(request)
else:
return self._nonstream_agent_turn(request)
return await self._nonstream_agent_turn(request)

async def _stream_agent_turn(
self, request: AgentTurnCreateRequest
@@ -126,7 +126,7 @@ async def _run_agent(

for content in user_prompts:
cprint(f"User> {content}", color="white", attrs=["bold"])
iterator = api.create_agent_turn(
iterator = await api.create_agent_turn(
AgentTurnCreateRequest(
agent_id=create_response.agent_id,
session_id=session_response.session_id,
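
With create_agent_turn now a coroutine, call sites await it first and, when streaming, iterate the awaited value, as the updated _run_agent above does. A minimal sketch of that consumption pattern (not part of the diff; request stands for an AgentTurnCreateRequest with stream=True):

    iterator = await api.create_agent_turn(request)
    async for chunk in iterator:
        print(chunk)
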
7 changes: 4 additions & 3 deletions llama_stack/apis/inference/client.py
@@ -42,10 +42,10 @@ async def initialize(self) -> None:
async def shutdown(self) -> None:
pass

def completion(self, request: CompletionRequest) -> AsyncGenerator:
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError()

def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -139,7 +139,7 @@ async def run_main(
else:
logprobs_config = None

iterator = client.chat_completion(
assert stream, "Non streaming not supported here"
iterator = await client.chat_completion(
model=model,
messages=[message],
stream=stream,
13 changes: 6 additions & 7 deletions llama_stack/apis/inference/inference.py
@@ -88,7 +88,8 @@ class CompletionRequest(BaseModel):
class CompletionResponse(BaseModel):
"""Completion response."""

completion_message: CompletionMessage
content: str
stop_reason: StopReason
logprobs: Optional[List[TokenLogProbs]] = None


@@ -113,7 +114,7 @@ class BatchCompletionRequest(BaseModel):
class BatchCompletionResponse(BaseModel):
"""Batch completion response."""

completion_message_batch: List[CompletionMessage]
batch: List[CompletionResponse]


@json_schema_type
@@ -165,7 +166,7 @@ class BatchChatCompletionRequest(BaseModel):

@json_schema_type
class BatchChatCompletionResponse(BaseModel):
completion_message_batch: List[CompletionMessage]
batch: List[ChatCompletionResponse]


@json_schema_type
@@ -181,10 +182,8 @@ def get_model(self, identifier: str) -> ModelDef: ...
class Inference(Protocol):
model_store: ModelStore

# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `CompletionResponse` depending on the value of `stream`.
@webmethod(route="/inference/completion")
def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -196,7 +195,7 @@ def completion(
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
@webmethod(route="/inference/chat_completion")
def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
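
The updated Inference protocol keeps a single async def signature whose awaited value is either a plain response or an async generator of stream chunks, depending on stream. A self-contained toy sketch of that calling convention (names invented for illustration, not llama-stack APIs):

    import asyncio
    from typing import AsyncGenerator, Union


    class ToyInference:
        # Illustrates only the await-then-maybe-iterate convention.
        async def completion(
            self, prompt: str, stream: bool = False
        ) -> Union[str, AsyncGenerator[str, None]]:
            if stream:
                # Return the async generator itself; the caller iterates it.
                return self._stream(prompt)
            return f"echo: {prompt}"

        async def _stream(self, prompt: str) -> AsyncGenerator[str, None]:
            for token in prompt.split():
                yield token


    async def main() -> None:
        api = ToyInference()

        # Non-streaming: the awaited value is the response.
        print(await api.completion("hello world"))

        # Streaming: the awaited value is an async generator to iterate.
        chunks = await api.completion("hello world", stream=True)
        async for chunk in chunks:
            print(chunk)


    asyncio.run(main())
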
12 changes: 6 additions & 6 deletions llama_stack/distribution/routers/routers.py
@@ -70,7 +70,7 @@ async def shutdown(self) -> None:
async def register_model(self, model: ModelDef) -> None:
await self.routing_table.register_model(model)

def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -93,11 +93,11 @@ def chat_completion(
)
provider = self.routing_table.get_provider_impl(model)
if stream:
return (chunk async for chunk in provider.chat_completion(**params))
return (chunk async for chunk in await provider.chat_completion(**params))
else:
return provider.chat_completion(**params)
return await provider.chat_completion(**params)

def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -114,9 +114,9 @@ def completion(
logprobs=logprobs,
)
if stream:
return (chunk async for chunk in provider.completion(**params))
return (chunk async for chunk in await provider.completion(**params))
else:
return provider.completion(**params)
return await provider.completion(**params)

async def embeddings(
self,
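
The streaming branches above work because awaiting provider.chat_completion(...) (or completion(...)) returns the provider's async generator, which the router re-wraps in a fresh async generator expression and hands back without consuming any chunks itself. A small self-contained illustration of that re-wrapping (toy names, not the router code):

    import asyncio
    from typing import AsyncGenerator


    async def provider_stream() -> AsyncGenerator[int, None]:
        for i in range(3):
            yield i


    async def provider_call() -> AsyncGenerator[int, None]:
        # Stands in for an awaited provider.chat_completion(...) with stream=True:
        # the awaited value is an async generator.
        return provider_stream()


    async def main() -> None:
        # Router-style re-wrap: await the call, then wrap the resulting generator.
        wrapped = (chunk async for chunk in await provider_call())
        async for chunk in wrapped:
            print(chunk)


    asyncio.run(main())
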
4 changes: 2 additions & 2 deletions llama_stack/providers/adapters/inference/bedrock/bedrock.py
@@ -47,7 +47,7 @@ async def initialize(self) -> None:
async def shutdown(self) -> None:
self.client.close()

def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -283,7 +283,7 @@ def _tools_to_tool_config(
)
return tool_config

def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -48,7 +48,7 @@ async def initialize(self) -> None:
async def shutdown(self) -> None:
pass

def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -58,7 +58,7 @@ def completion(
) -> AsyncGenerator:
raise NotImplementedError()

def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -84,7 +84,7 @@ def chat_completion(
if stream:
return self._stream_chat_completion(request, client)
else:
return self._nonstream_chat_completion(request, client)
return await self._nonstream_chat_completion(request, client)

async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
@@ -51,7 +51,7 @@ async def initialize(self) -> None:
async def shutdown(self) -> None:
pass

def completion(
async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -61,7 +61,7 @@ def completion(
) -> AsyncGenerator:
raise NotImplementedError()

def chat_completion(
async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -87,7 +87,7 @@ def chat_completion(
if stream:
return self._stream_chat_completion(request, client)
else:
return self._nonstream_chat_completion(request, client)
return await self._nonstream_chat_completion(request, client)

async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: Fireworks