From d8dc1ea555e673f6af5ceeb630583889189a6157 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Fri, 24 Jan 2025 16:56:08 -0300 Subject: [PATCH 1/5] Fix the pydantic logging validator Prevent the validator from logging unused extra fields before object is validated by the pydantic validator. During the routing based on the request content where pydantic tries several classes for deserialization, this would result in spurious logs. Signed-off-by: Max de Bayser --- vllm/entrypoints/openai/protocol.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 80403f77d5375..76e7043c117af 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -45,14 +45,14 @@ class OpenAIBaseModel(BaseModel): # Cache class field names field_names: ClassVar[Optional[Set[str]]] = None - @model_validator(mode="before") + @model_validator(mode="wrap") @classmethod - def __log_extra_fields__(cls, data): - + def __log_extra_fields__(cls, data, handler): + result = handler(data) field_names = cls.field_names if field_names is None: if not isinstance(data, dict): - return data + return result # Get all class field names and their potential aliases field_names = set() for field_name, field in cls.model_fields.items(): @@ -67,7 +67,7 @@ def __log_extra_fields__(cls, data): "The following fields were present in the request " "but ignored: %s", data.keys() - field_names) - return data + return result class ErrorResponse(OpenAIBaseModel): From 961e3fd17e33a270c5aea71ab251d5680d310ef1 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Mon, 27 Jan 2025 10:24:25 -0300 Subject: [PATCH 2/5] Don't validate non-dict objects Signed-off-by: Max de Bayser --- vllm/entrypoints/openai/protocol.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 76e7043c117af..d346cd77261ec 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -49,10 +49,10 @@ class OpenAIBaseModel(BaseModel): @classmethod def __log_extra_fields__(cls, data, handler): result = handler(data) + if not isinstance(data, dict): + return result field_names = cls.field_names if field_names is None: - if not isinstance(data, dict): - return result # Get all class field names and their potential aliases field_names = set() for field_name, field in cls.model_fields.items(): From 5e053e736827f41d99e94e30b05895bbe276d2e8 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Mon, 27 Jan 2025 10:35:09 -0300 Subject: [PATCH 3/5] trigger ci Signed-off-by: Max de Bayser From 9cea2546b9703113cd1487a9a88336225dd6ffa1 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Mon, 27 Jan 2025 16:43:52 -0300 Subject: [PATCH 4/5] trigger ci Signed-off-by: Max de Bayser From daa1e27992ca863ee0cd41b3357c08e5cf3186de Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Tue, 28 Jan 2025 23:57:57 -0300 Subject: [PATCH 5/5] Fix BatchRequestInput pydantic validation When not enough input fields are given, the deserialization of the body is ambiguous. Signed-off-by: Max de Bayser --- vllm/entrypoints/openai/protocol.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index eb824e20abab0..157953d706fdb 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -6,7 +6,8 @@ from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union import torch -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter, + ValidationInfo, field_validator, model_validator) from typing_extensions import Annotated from vllm.entrypoints.chat_utils import ChatCompletionMessageParam @@ -1285,6 +1286,20 @@ class BatchRequestInput(OpenAIBaseModel): # The parameters of the request. body: Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest] + @field_validator('body', mode='plain') + @classmethod + def check_type_for_url(cls, value: Any, info: ValidationInfo): + # Use url to disambiguate models + url = info.data['url'] + if url == "/v1/chat/completions": + return ChatCompletionRequest.model_validate(value) + if url == "/v1/embeddings": + return TypeAdapter(EmbeddingRequest).validate_python(value) + if url == "/v1/score": + return ScoreRequest.model_validate(value) + return TypeAdapter(Union[ChatCompletionRequest, EmbeddingRequest, + ScoreRequest]).validate_python(value) + class BatchResponseData(OpenAIBaseModel): # HTTP status code of the response.