Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
aidando73 committed Dec 15, 2024
1 parent 7076e66 commit d9db9a0
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions llama_stack/providers/remote/inference/groq/groq_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,30 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import warnings
from typing import AsyncGenerator, Generator, Literal
import json

from groq import Stream
from groq.types.chat.chat_completion import ChatCompletion
from groq.types.chat.chat_completion_assistant_message_param import (
ChatCompletionAssistantMessageParam,
)
from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from groq.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
)
from groq.types.chat.chat_completion_system_message_param import (
ChatCompletionSystemMessageParam,
)
from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam
from groq.types.chat.chat_completion_user_message_param import (
ChatCompletionUserMessageParam,
)
from groq.types.chat.completion_create_params import CompletionCreateParams
from groq.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
)
from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam
from groq.types.shared.function_definition import FunctionDefinition
from groq.types.shared.function_parameters import FunctionParameters

from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
Expand All @@ -38,13 +39,14 @@
Role,
StopReason,
ToolCall,
ToolCallDelta,
ToolCallParseStatus,
ToolDefinition,
ToolParamDefinition,
ToolCallParseStatus,
ToolCallDelta,
ToolPromptFormat,
)


def convert_chat_completion_request(
request: ChatCompletionRequest,
) -> CompletionCreateParams:
Expand Down Expand Up @@ -85,6 +87,7 @@ def convert_chat_completion_request(
tool_choice=request.tool_choice.value if request.tool_choice else None,
)


def _convert_message(message: Message) -> ChatCompletionMessageParam:
if message.role == Role.system.value:
return ChatCompletionSystemMessageParam(role="system", content=message.content)
Expand All @@ -98,7 +101,6 @@ def _convert_message(message: Message) -> ChatCompletionMessageParam:
raise ValueError(f"Invalid message role: {message.role}")



def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict:
# Groq requires a description for function tools
if tool_definition.description is None:
Expand All @@ -114,13 +116,11 @@ def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict:
key: _convert_groq_tool_parameter(param)
for key, param in tool_parameters.items()
},
)
),
)


def _convert_groq_tool_parameter(
tool_parameter: ToolParamDefinition
) -> dict:
def _convert_groq_tool_parameter(tool_parameter: ToolParamDefinition) -> dict:
param = {
"type": tool_parameter.param_type,
}
Expand Down Expand Up @@ -211,7 +211,9 @@ def _event_type_generator() -> (
elif choice.delta.tool_calls:
# We assume there is only one tool call per chunk, but emit a warning in case we're wrong
if len(choice.delta.tool_calls) > 1:
warnings.warn("Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest.")
warnings.warn(
"Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest."
)

# We assume Groq produces fully formed tool calls for each chunk
tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0])
Expand All @@ -233,6 +235,7 @@ def _event_type_generator() -> (
)
)


def _convert_groq_tool_call(tool_call: ChatCompletionMessageToolCall) -> ToolCall:
return ToolCall(
call_id=tool_call.id,
Expand Down

0 comments on commit d9db9a0

Please sign in to comment.