diff --git a/models.py b/models.py index bed1f79..51454b1 100644 --- a/models.py +++ b/models.py @@ -17,6 +17,15 @@ class Tool(BaseModel): type: str function: Function +class FunctionCall(BaseModel): + name: str + arguments: str + +class ToolCall(BaseModel): + id: str + type: str + function: FunctionCall + class ImageUrl(BaseModel): url: str @@ -29,7 +38,18 @@ class Message(BaseModel): role: str name: Optional[str] = None arguments: Optional[str] = None - content: Union[str, List[ContentItem]] + content: Optional[Union[str, List[ContentItem]]] = None + tool_calls: Optional[List[ToolCall]] = None + +class Message(BaseModel): + role: str + name: Optional[str] = None + content: Optional[Union[str, List[ContentItem]]] = None + tool_calls: Optional[List[ToolCall]] = None + tool_call_id: Optional[str] = None + + class Config: + extra = "allow" # 允许额外的字段 class RequestModel(BaseModel): model: str diff --git a/request.py b/request.py index 2be5753..a0715c1 100644 --- a/request.py +++ b/request.py @@ -126,7 +126,8 @@ async def get_gpt_payload(request, engine, provider): messages = [] for msg in request.messages: - name = None + tool_calls = None + tool_call_id = None if isinstance(msg.content, list): content = [] for item in msg.content: @@ -138,9 +139,23 @@ async def get_gpt_payload(request, engine, provider): content.append(image_message) else: content = msg.content - name = msg.name - if name: - messages.append({"role": msg.role, "name": name, "content": content}) + tool_calls = msg.tool_calls + tool_call_id = msg.tool_call_id + + if tool_calls: + tool_calls_list = [] + for tool_call in tool_calls: + tool_calls_list.append({ + "id": tool_call.id, + "type": tool_call.type, + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments + } + }) + messages.append({"role": msg.role, "tool_calls": tool_calls_list}) + elif tool_call_id: + messages.append({"role": msg.role, "tool_call_id": tool_call_id, "content": content}) else: messages.append({"role": msg.role, "content": content}) diff --git a/utils.py b/utils.py index edd250d..b9dd423 100644 --- a/utils.py +++ b/utils.py @@ -12,7 +12,6 @@ def update_config(config_data): model_dict[model] = model if type(model) == dict: model_dict.update({new: old for old, new in model.items()}) - # model_dict.update({old: old for old, new in model.items()}) provider['model'] = model_dict config_data['providers'][index] = provider api_keys_db = config_data['api_keys']