Fix model_validator in InferenceEndpoints due to Pipeline pickling (#552)

* Allow nested connect calls and overload rshift method to connect steps (#490)

* Allow nested connect calls and overload rshift method to connect steps

* Update src/distilabel/steps/base.py

Co-authored-by: Alvaro Bartolome <[email protected]>

* Update tests/unit/pipeline/test_base.py

Co-authored-by: Alvaro Bartolome <[email protected]>

* Update tests/unit/pipeline/test_base.py

Co-authored-by: Alvaro Bartolome <[email protected]>

* Update tests/unit/pipeline/test_base.py

Co-authored-by: Alvaro Bartolome <[email protected]>

* Update tests/unit/pipeline/test_base.py

Co-authored-by: Alvaro Bartolome <[email protected]>

* Update tests/unit/pipeline/test_base.py

Co-authored-by: Alvaro Bartolome <[email protected]>

* Add comment to simplify reading the tests

* Add reference on the Pipeline of alternative ways of connecting the steps

---------

Co-authored-by: Alvaro Bartolome <[email protected]>

* Fix `model_validator` in `InferenceEndpointsLLM` as called after `load` too

Due to the `pickle` usage within the `Pipeline`, the `model_validator` is called not only when `InferenceEndpointsLLM` is first instantiated, but also once the `load` method has been run. By then `base_url` always has a value already, so the validator always raised a `pydantic.ValidationError`. That is now fixed.

* Set `1.0.1` version to release bug-fix

* Revert "Allow nested connect calls and overload rshift method to connect steps (#490)"

This reverts commit c3e7b0d.
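
The validator change in this commit can be sketched with plain functions. This is a minimal illustration only, not the library code: `old_rule` and `new_rule` are hypothetical names, `ValueError` stands in for the exception raised in the real validator, the URL is a placeholder, and the warning that the real validator logs is left out. It assumes, as the commit message states, that the validator runs again after `load` has filled in `base_url`.

```python
from typing import Optional


def old_rule(
    model_id: Optional[str], endpoint_name: Optional[str], base_url: Optional[str]
) -> None:
    """Pre-fix rule: exactly one of the three arguments may be set."""
    provided = [arg for arg in (model_id, endpoint_name, base_url) if arg]
    if len(provided) != 1:
        raise ValueError(
            "Only one of `model_id`, `endpoint_name`, or `base_url` must be provided."
        )


def new_rule(
    model_id: Optional[str], endpoint_name: Optional[str], base_url: Optional[str]
) -> None:
    """Post-fix rule: `base_url` may coexist with `model_id` or `endpoint_name`."""
    # A lone `base_url` is valid.
    if base_url and not (model_id or endpoint_name):
        return
    # Exactly one of `model_id` / `endpoint_name` is valid, with or without `base_url`.
    if model_id and not endpoint_name:
        return
    if endpoint_name and not model_id:
        return
    raise ValueError("Only one of `model_id` or `endpoint_name` must be provided.")


# State once `load` has run: the user passed `model_id`, and `base_url` was filled in.
after_load = dict(
    model_id="mistralai/Mistral-7B-Instruct-v0.2",
    endpoint_name=None,
    base_url="https://example.invalid/placeholder",  # placeholder, not a real endpoint
)

new_rule(**after_load)  # accepted by the relaxed rule

try:
    old_rule(**after_load)
except ValueError as err:
    print(f"old rule rejects the post-load state: {err}")
```

Both rules still reject the case where `model_id` and `endpoint_name` are given together, and the case where none of the three is given. The only behavioural difference is that a populated `base_url` no longer counts against a `model_id` or `endpoint_name`.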

---------

Co-authored-by: Agus <[email protected]>
Co-authored-by: Gabriel Martín Blázquez <[email protected]>
3 people authored Apr 19, 2024
1 parent ffc0f4f commit b870f39
Showing 2 changed files with 17 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/distilabel/__init__.py
@@ -14,6 +14,6 @@

 from rich import traceback as rich_traceback

-__version__ = "1.0.0"
+__version__ = "1.0.1"

 rich_traceback.install(show_locals=True)
24 changes: 16 additions & 8 deletions src/distilabel/llms/huggingface/inference_endpoints.py
@@ -62,7 +62,6 @@ class InferenceEndpointsLLM(AsyncLLM):
```python
from distilabel.llms.huggingface import InferenceEndpointsLLM
# Free serverless Inference API
llm = InferenceEndpointsLLM(
model_id="mistralai/Mistral-7B-Instruct-v0.2",
@@ -123,21 +122,30 @@ class InferenceEndpointsLLM(AsyncLLM):
     def only_one_of_model_id_endpoint_name_or_base_url_provided(
         self,
     ) -> "InferenceEndpointsLLM":
-        """Validates that only one of `model_id`, `endpoint_name`, or `base_url` is provided."""
+        """Validates that only one of `model_id` or `endpoint_name` is provided; and if `base_url` is also
+        provided, a warning will be shown informing the user that the provided `base_url` will be ignored in
+        favour of the dynamically calculated one."""

+        if self.base_url and (self.model_id or self.endpoint_name):
+            self._logger.warning(  # type: ignore
+                f"Since the `base_url={self.base_url}` is available and either one of `model_id` or `endpoint_name`"
+                " is also provided, the `base_url` will either be ignored or overwritten with the one generated"
+                " from either of those args, for serverless or dedicated inference endpoints, respectively."
+            )
+
-        if self.model_id and (not self.endpoint_name and not self.base_url):
+        if self.base_url and not (self.model_id or self.endpoint_name):
             return self

-        if self.endpoint_name and (not self.model_id and not self.base_url):
+        if self.model_id and not self.endpoint_name:
             return self

-        if self.base_url and (not self.model_id and not self.endpoint_name):
+        if self.endpoint_name and not self.model_id:
             return self

         raise ValidationError(
-            "Only one of `model_id`, `endpoint_name`, or `base_url` must be provided. Found"
-            f" `model_id`={self.model_id}, `endpoint_name`={self.endpoint_name}, and"
-            f" `base_url`={self.base_url}."
+            "Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is provided too,"
+            f" it will be overwritten instead. Found `model_id`={self.model_id}, `endpoint_name`={self.endpoint_name},"
+            f" and `base_url`={self.base_url}."
         )

     def load(self) -> None:  # noqa: C901