ci: auto fixes from pre-commit.ci
For more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Feb 3, 2025
1 parent 596b4fa commit 34c5607
Showing 32 changed files with 157 additions and 149 deletions.
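
Every hunk below is the same mechanical rewrite: where a long assert previously wrapped its condition in parentheses to fit the line length, the formatter now keeps the condition on the assert line and parenthesizes the message instead. The commit is automated pre-commit.ci output; the exact tool version is not stated in the commit, but the pattern matches the newer Ruff/Black assert-message style (an inference, not a claim from the source). A minimal before/after sketch of the pattern, using a hypothetical variable `value`:

    # before: condition parenthesized, message trailing after the comma
    assert (
        value is not None
    ), "value must be set"

    # after: condition inline, message parenthesized and indented
    assert value is not None, (
        "value must be set"
    )

Both forms compile to identical bytecode; only the layout changes. A handful of hunks are instead f-string normalizations, noted where they appear.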
12 changes: 6 additions & 6 deletions src/_bentoml_impl/loader.py
@@ -179,16 +179,16 @@ def import_service(
     try:
         module_name, _, attrs_str = service_identifier.partition(":")
 
-        assert (
-            module_name and attrs_str
-        ), f'Invalid import target "{service_identifier}", must format as "<module>:<attribute>"'
+        assert module_name and attrs_str, (
+            f'Invalid import target "{service_identifier}", must format as "<module>:<attribute>"'
+        )
         module = importlib.import_module(module_name)
         root_service_name, _, depend_path = attrs_str.partition(".")
         root_service = t.cast("Service[t.Any]", getattr(module, root_service_name))
 
-        assert isinstance(
-            root_service, Service
-        ), f'import target "{module_name}:{attrs_str}" is not a bentoml.Service instance'
+        assert isinstance(root_service, Service), (
+            f'import target "{module_name}:{attrs_str}" is not a bentoml.Service instance'
+        )
 
         if not depend_path:
             svc = root_service
6 changes: 3 additions & 3 deletions src/_bentoml_impl/server/app.py
@@ -258,9 +258,9 @@ def server_request_hook(span: Span, _scope: dict[str, t.Any]) -> None:
             middlewares.append(Middleware(middleware_cls, **options))
         # CORS middleware
         if self.enable_access_control:
-            assert (
-                self.access_control_options.get("allow_origins") is not None
-            ), "To enable cors, access_control_allow_origin must be set"
+            assert self.access_control_options.get("allow_origins") is not None, (
+                "To enable cors, access_control_allow_origin must be set"
+            )
 
             from starlette.middleware.cors import CORSMiddleware
 
6 changes: 3 additions & 3 deletions src/_bentoml_impl/server/serving.py
@@ -194,9 +194,9 @@ def serve_http(
     if isinstance(bento_identifier, Service):
         svc = bento_identifier
         bento_identifier = svc.import_string
-        assert (
-            working_dir is None
-        ), "working_dir should not be set when passing a service in process"
+        assert working_dir is None, (
+            "working_dir should not be set when passing a service in process"
+        )
         # use cwd
         bento_path = pathlib.Path(".")
     else:
6 changes: 3 additions & 3 deletions src/bentoml/_internal/batch/spark.py
@@ -68,9 +68,9 @@ def process(
 ) -> t.Generator[RecordBatch, None, None]:
     svc = _load_bento_spark(bento_tag)
 
-    assert (
-        api_name in svc.apis
-    ), "An error occurred transferring the Bento to the Spark worker."
+    assert api_name in svc.apis, (
+        "An error occurred transferring the Bento to the Spark worker."
+    )
     inference_api = svc.apis[api_name]
     assert inference_api.func is not None, "Inference API function not defined"
 
24 changes: 12 additions & 12 deletions src/bentoml/_internal/client/grpc.py
@@ -86,9 +86,9 @@ def __init__(
         self._interceptors = interceptors
         self._credentials = None
         if ssl:
-            assert (
-                ssl_client_credentials is not None
-            ), "'ssl=True' requires 'ssl_client_credentials'"
+            assert ssl_client_credentials is not None, (
+                "'ssl=True' requires 'ssl_client_credentials'"
+            )
             self._credentials = grpc.ssl_channel_credentials(
                 **{
                     k: load_from_file(v) if isinstance(v, str) else v
@@ -124,9 +124,9 @@ def _create_channel(
         compression: grpc.Compression | None = None,
     ) -> GrpcAsyncChannel:
         if ssl:
-            assert (
-                ssl_client_credentials is not None
-            ), "'ssl=True' requires 'ssl_client_credentials'"
+            assert ssl_client_credentials is not None, (
+                "'ssl=True' requires 'ssl_client_credentials'"
+            )
             return aio.secure_channel(
                 server_url,
                 credentials=grpc.ssl_channel_credentials(
@@ -447,9 +447,9 @@ def __init__(
         self._interceptors = interceptors
         self._credentials = None
         if ssl:
-            assert (
-                ssl_client_credentials is not None
-            ), "'ssl=True' requires 'ssl_client_credentials'"
+            assert ssl_client_credentials is not None, (
+                "'ssl=True' requires 'ssl_client_credentials'"
+            )
             self._credentials = grpc.ssl_channel_credentials(
                 **{
                     k: load_from_file(v) if isinstance(v, str) else v
@@ -483,9 +483,9 @@ def _create_channel(
         compression: grpc.Compression | None = None,
     ) -> GrpcSyncChannel:
         if ssl:
-            assert (
-                ssl_client_credentials is not None
-            ), "'ssl=True' requires 'ssl_client_credentials'"
+            assert ssl_client_credentials is not None, (
+                "'ssl=True' requires 'ssl_client_credentials'"
+            )
             return grpc.secure_channel(
                 server_url,
                 credentials=grpc.ssl_channel_credentials(
8 changes: 4 additions & 4 deletions src/bentoml/_internal/frameworks/onnx.py
@@ -262,9 +262,9 @@ def forward(self, x, bias):
             break
         except PackageNotFoundError:
             pass
-    assert (
-        _onnxruntime_pkg is not None and _onnxruntime_version is not None
-    ), "Failed to find onnxruntime package version."
+    assert _onnxruntime_pkg is not None and _onnxruntime_version is not None, (
+        "Failed to find onnxruntime package version."
+    )
 
     assert _onnxruntime_version is not None, "onnxruntime is not installed"
     if not isinstance(model, onnx.ModelProto):
@@ -291,7 +291,7 @@ def forward(self, x, bias):
     provided_methods = list(signatures.keys())
     if provided_methods != ["run"]:
         raise BentoMLException(
-            f"Provided method names {[m for m in provided_methods if m != 'run']} are invalid. 'bentoml.onnx' will load ONNX model into an 'onnxruntime.InferenceSession' for inference, so the only supported method name is 'run'."
+            f"Provided method names {[m for m in provided_methods if m != 'run']} are invalid. 'bentoml.onnx' will load ONNX model into an 'onnxruntime.InferenceSession' for inference, so the only supported method name is 'run'."
         )
 
     run_input_specs = [MessageToDict(inp) for inp in model.graph.input]
6 changes: 3 additions & 3 deletions src/bentoml/_internal/frameworks/pytorch_lightning.py
@@ -163,9 +163,9 @@ def configure_optimizers(self):

     script_module = model.to_torchscript()
 
-    assert not isinstance(
-        script_module, dict
-    ), "Saving a dict of pytorch_lightning Module into one BentoModel is not supported"
+    assert not isinstance(script_module, dict), (
+        "Saving a dict of pytorch_lightning Module into one BentoModel is not supported"
+    )
 
     return script_save_model(
         name,
44 changes: 24 additions & 20 deletions src/bentoml/_internal/frameworks/transformers.py
@@ -453,9 +453,9 @@ def load_model(bento_model: str | Tag | Model, *args: t.Any, **kwargs: t.Any) ->

     kwargs.setdefault("pipeline_class", pipeline.__class__ if pipeline else None)
 
-    assert (
-        task in get_supported_tasks()
-    ), f"Task '{task}' is not a valid task for pipeline (available: {get_supported_tasks()})."
+    assert task in get_supported_tasks(), (
+        f"Task '{task}' is not a valid task for pipeline (available: {get_supported_tasks()})."
+    )
 
     return (
         transformers.pipeline(task=task, model=bento_model.path, **kwargs)
@@ -474,9 +474,9 @@ def load_model(bento_model: str | Tag | Model, *args: t.Any, **kwargs: t.Any) ->
             protocol: PreTrainedProtocol = cloudpickle.load(f)
         return protocol.from_pretrained(bento_model.path, *args, **kwargs)
     else:
-        assert (
-            len(args) == 0
-        ), "Positional args are not supported for pipeline. Make sure to only use kwargs instead."
+        assert len(args) == 0, (
+            "Positional args are not supported for pipeline. Make sure to only use kwargs instead."
+        )
         with open(bento_model.path_of(PIPELINE_PICKLE_NAME), "rb") as f:
             pipeline_class: type[transformers.Pipeline] = cloudpickle.load(f)
 
@@ -504,9 +504,9 @@ def load_model(bento_model: str | Tag | Model, *args: t.Any, **kwargs: t.Any) ->

         kwargs.setdefault("pipeline_class", pipeline_class)
 
-        assert (
-            task in get_supported_tasks()
-        ), f"Task '{task}' is not a valid task for pipeline (available: {get_supported_tasks()})."
+        assert task in get_supported_tasks(), (
+            f"Task '{task}' is not a valid task for pipeline (available: {get_supported_tasks()})."
+        )
 
         kwargs.update(options.kwargs)
         if len(kwargs) > 0:
@@ -860,7 +860,9 @@ def import_model(
         pretrained = t.cast("PreTrainedProtocol", model)
         assert all(
             hasattr(pretrained, defn) for defn in ("save_pretrained", "from_pretrained")
-        ), f"'pretrained={pretrained}' is not a valid Transformers object. It must have 'save_pretrained' and 'from_pretrained' methods."
+        ), (
+            f"'pretrained={pretrained}' is not a valid Transformers object. It must have 'save_pretrained' and 'from_pretrained' methods."
+        )
         if metadata is None:
             metadata = {}
 
@@ -1001,9 +1003,9 @@ def save_model(
         )
         pretrained_or_pipeline = pipeline
 
-    assert (
-        pretrained_or_pipeline is not None
-    ), "Please provide a pipeline or a pretrained object as a second argument."
+    assert pretrained_or_pipeline is not None, (
+        "Please provide a pipeline or a pretrained object as a second argument."
+    )
 
     # The below API are introduced since 4.18
     if pkg_version_info("transformers")[:2] >= (4, 18):
@@ -1087,13 +1089,13 @@ def save_model(
                 task_name,
             )
 
-            assert (
-                task_name in get_supported_tasks()
-            ), f"Task '{task_name}' failed to register into pipeline registry."
+            assert task_name in get_supported_tasks(), (
+                f"Task '{task_name}' failed to register into pipeline registry."
+            )
         else:
-            assert (
-                task_definition is None
-            ), "'task_definition' must be None if 'task_name' is not provided."
+            assert task_definition is None, (
+                "'task_definition' must be None if 'task_name' is not provided."
+            )
 
             # if task_name is None, then we derive the task from pipeline.task
             options_args = t.cast(
@@ -1124,7 +1126,9 @@ def save_model(
         pretrained = t.cast("PreTrainedProtocol", pretrained_or_pipeline)
         assert all(
             hasattr(pretrained, defn) for defn in ("save_pretrained", "from_pretrained")
-        ), f"'pretrained={pretrained}' is not a valid Transformers object. It must have 'save_pretrained' and 'from_pretrained' methods."
+        ), (
+            f"'pretrained={pretrained}' is not a valid Transformers object. It must have 'save_pretrained' and 'from_pretrained' methods."
+        )
         if metadata is None:
             metadata = {}
 
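
The import_model and save_model hunks above show the variant of this rewrite for conditions that are themselves multi-line: assert all(...) keeps its wrapped condition, and only the message moves into a parenthesized block. A runnable sketch of that shape, with a hypothetical Dummy class standing in for a pretrained object:

    class Dummy:  # hypothetical stand-in for a Transformers pretrained object
        def save_pretrained(self) -> None: ...
        def from_pretrained(self) -> None: ...

    pretrained = Dummy()
    # the condition stays multi-line; only the message is parenthesized
    assert all(
        hasattr(pretrained, defn) for defn in ("save_pretrained", "from_pretrained")
    ), (
        f"'pretrained={pretrained}' is not a valid Transformers object. It must have "
        "'save_pretrained' and 'from_pretrained' methods."
    )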
2 changes: 1 addition & 1 deletion src/bentoml/_internal/frameworks/utils/tensorflow.py
@@ -333,7 +333,7 @@ def _pretty_format_function_call(base: str, name: str, arg_names: t.Tuple[t.Any]

 def _pretty_format_positional(positional: t.Optional["tf_ext.TensorSignature"]) -> str:
     if positional is not None:
-        return f'Positional arguments ({len(positional)} total):\n {" * ".join(str(a) for a in positional)}'  # noqa
+        return f"Positional arguments ({len(positional)} total):\n {' * '.join(str(a) for a in positional)}"  # noqa
     return "No positional arguments.\n"
 
 
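
Not every hunk is an assert rewrite. The change above, and the single-line changes in the io_descriptors files below, are f-string normalizations: outer quotes flipped from single to double, and a space added after commas inside interpolated expressions (names,_ becomes names, _). A small self-contained sketch of the same cleanup, with a hypothetical pairs dict:

    pairs = {"a": 1, "b": 2}

    # before: single-quoted f-string, no space after the comma in the expression
    before = f'items: {", ".join(f"{k}={v}" for k,v in pairs.items())}'

    # after: double-quoted f-string, spaced comma; the rendered text is identical
    after = f"items: {', '.join(f'{k}={v}' for k, v in pairs.items())}"

    assert before == after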
2 changes: 1 addition & 1 deletion src/bentoml/_internal/io_descriptors/file.py
@@ -332,7 +332,7 @@ async def from_proto(
             )
         except KeyError:
             raise BadInput(
-                f"{field.kind} is not a valid File kind. Accepted file kind: {[names for names,_ in pb_v1alpha1.File.FileType.items()]}",
+                f"{field.kind} is not a valid File kind. Accepted file kind: {[names for names, _ in pb_v1alpha1.File.FileType.items()]}",
             ) from None
         content = field.content
         if not content:
2 changes: 1 addition & 1 deletion src/bentoml/_internal/io_descriptors/image.py
@@ -428,7 +428,7 @@ async def from_proto(self, field: pb.File | pb_v1alpha1.File | bytes) -> ImageTy
             )
         except KeyError:
             raise BadInput(
-                f"{field.kind} is not a valid File kind. Accepted file kind: {[names for names,_ in pb_v1alpha1.File.FileType.items()]}",
+                f"{field.kind} is not a valid File kind. Accepted file kind: {[names for names, _ in pb_v1alpha1.File.FileType.items()]}",
             ) from None
         if not field.content:
             raise BadInput("Content is empty!") from None
6 changes: 3 additions & 3 deletions src/bentoml/_internal/io_descriptors/json.py
@@ -191,9 +191,9 @@ def __init__(
         json_encoder: type[json.JSONEncoder] = DefaultJsonEncoder,
     ):
         if pydantic_model is not None:
-            assert issubclass(
-                pydantic_model, pydantic.BaseModel
-            ), "'pydantic_model' must be a subclass of 'pydantic.BaseModel'."
+            assert issubclass(pydantic_model, pydantic.BaseModel), (
+                "'pydantic_model' must be a subclass of 'pydantic.BaseModel'."
+            )
 
         self._pydantic_model = pydantic_model
         self._json_encoder = json_encoder
2 changes: 1 addition & 1 deletion src/bentoml/_internal/io_descriptors/multipart.py
@@ -176,7 +176,7 @@ def __init__(self, **inputs: IODescriptor[t.Any]):
         self._inputs = inputs
 
     def __repr__(self) -> str:
-        return f"Multipart({','.join([f'{k}={v}' for k,v in zip(self._inputs, map(repr, self._inputs.values()))])})"
+        return f"Multipart({','.join([f'{k}={v}' for k, v in zip(self._inputs, map(repr, self._inputs.values()))])})"
 
     def _from_sample(cls, sample: dict[str, t.Any]) -> t.Any:
         raise NotImplementedError("'from_sample' is not supported for Multipart.")
2 changes: 1 addition & 1 deletion src/bentoml/_internal/io_descriptors/pandas.py
@@ -670,7 +670,7 @@ def process_columns_contents(content: pb.Series) -> dict[str, t.Any]:
         # large tabular data
         if len(content.ListFields()) != 1:
             raise BadInput(
-                f"Array contents can only be one of given values key. Use one of '{list(map(lambda f: f[0].name,content.ListFields()))}' instead."
+                f"Array contents can only be one of given values key. Use one of '{list(map(lambda f: f[0].name, content.ListFields()))}' instead."
             ) from None
         return {str(i): c for i, c in enumerate(content.ListFields()[0][1])}
 
36 changes: 18 additions & 18 deletions src/bentoml/_internal/runner/container.py
@@ -332,9 +332,9 @@ def batches_to_batch(
     ) -> tuple[ext.PdDataFrame, list[int]]:
         import pandas as pd
 
-        assert (
-            batch_dim == 0
-        ), "PandasDataFrameContainer does not support batch_dim other than 0"
+        assert batch_dim == 0, (
+            "PandasDataFrameContainer does not support batch_dim other than 0"
+        )
         indices = list(
             itertools.accumulate(subbatch.shape[batch_dim] for subbatch in batches)
         )
@@ -348,9 +348,9 @@ def batch_to_batches(
         indices: t.Sequence[int],
         batch_dim: int = 0,
     ) -> list[ext.PdDataFrame]:
-        assert (
-            batch_dim == 0
-        ), "PandasDataFrameContainer does not support batch_dim other than 0"
+        assert batch_dim == 0, (
+            "PandasDataFrameContainer does not support batch_dim other than 0"
+        )
 
         return [
             batch.iloc[indices[i] : indices[i + 1]].reset_index(drop=True)
@@ -365,9 +365,9 @@ def to_payload(
     ) -> Payload:
         import pandas as pd
 
-        assert (
-            batch_dim == 0
-        ), "PandasDataFrameContainer does not support batch_dim other than 0"
+        assert batch_dim == 0, (
+            "PandasDataFrameContainer does not support batch_dim other than 0"
+        )
 
         if isinstance(batch, pd.Series):
             batch = pd.DataFrame([batch])
@@ -402,9 +402,9 @@ def to_payload(
     def get_batch_size(
         cls, batch: ext.PdDataFrame | ext.PdSeries, batch_dim: int
     ) -> int:
-        assert (
-            batch_dim == 0
-        ), "PandasDataFrameContainer does not support batch_dim other than 0"
+        assert batch_dim == 0, (
+            "PandasDataFrameContainer does not support batch_dim other than 0"
+        )
         return batch.shape
 
     @classmethod
@@ -545,9 +545,9 @@ class DefaultContainer(DataContainer[t.Any, t.List[t.Any]]):
     def batches_to_batch(
         cls, batches: t.Sequence[list[t.Any]], batch_dim: int = 0
     ) -> tuple[list[t.Any], list[int]]:
-        assert (
-            batch_dim == 0
-        ), "Default Runner DataContainer does not support batch_dim other than 0"
+        assert batch_dim == 0, (
+            "Default Runner DataContainer does not support batch_dim other than 0"
+        )
         batch: list[t.Any] = []
         for subbatch in batches:
             batch.extend(subbatch)
@@ -559,9 +559,9 @@ def batches_to_batch(
     def batch_to_batches(
         cls, batch: list[t.Any], indices: t.Sequence[int], batch_dim: int = 0
     ) -> list[list[t.Any]]:
-        assert (
-            batch_dim == 0
-        ), "Default Runner DataContainer does not support batch_dim other than 0"
+        assert batch_dim == 0, (
+            "Default Runner DataContainer does not support batch_dim other than 0"
+        )
         return [batch[indices[i] : indices[i + 1]] for i in range(len(indices) - 1)]
 
     @classmethod
6 changes: 3 additions & 3 deletions src/bentoml/_internal/runner/runner_handle/remote.py
@@ -502,9 +502,9 @@ async def async_run_method(
     ) -> tritongrpcclient.InferResult | tritonhttpclient.InferResult:
         from ..container import AutoContainer
 
-        assert (
-            (len(args) == 0) ^ (len(kwargs) == 0)
-        ), f"Inputs for model '{__bentoml_method.name}' can be given either as positional (args) or keyword arguments (kwargs), but not both. See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#model-configuration"
+        assert (len(args) == 0) ^ (len(kwargs) == 0), (
+            f"Inputs for model '{__bentoml_method.name}' can be given either as positional (args) or keyword arguments (kwargs), but not both. See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#model-configuration"
+        )
 
         pass_args = args if len(args) > 0 else kwargs
 
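
A side note on the condition in this last hunk: (len(args) == 0) ^ (len(kwargs) == 0) applies bitwise XOR to two booleans, so the assert passes only when exactly one of the two input styles is used. A quick illustration with a hypothetical helper:

    def exactly_one_empty(args: tuple, kwargs: dict) -> bool:
        # true only when exactly one of the two collections is empty
        return (len(args) == 0) ^ (len(kwargs) == 0)

    assert exactly_one_empty((1,), {})  # positional only: accepted
    assert exactly_one_empty((), {"x": 1})  # keyword only: accepted
    assert not exactly_one_empty((1,), {"x": 1})  # both given: rejected
    assert not exactly_one_empty((), {})  # neither given: rejected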