diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py
index ab3c61542f..dbb336a421 100644
--- a/src/huggingface_hub/inference/_client.py
+++ b/src/huggingface_hub/inference/_client.py
@@ -188,7 +188,7 @@ def __init__(
                 " It has the exact same behavior as `token`."
             )

-        self.model: Optional[str] = model
+        self.model: Optional[str] = base_url or model
         self.token: Optional[str] = token if token is not None else api_key
         self.headers = headers if headers is not None else {}

@@ -199,9 +199,6 @@ def __init__(
         self.timeout = timeout
         self.proxies = proxies

-        # OpenAI compatibility
-        self.base_url = base_url
-
     def __repr__(self):
         return f"<InferenceClient(model='{self.model}', timeout={self.timeout})>"

@@ -328,9 +325,11 @@ def _inner_post(
             return response.iter_lines() if stream else response.content
         except HTTPError as error:
             if error.response.status_code == 422 and request_parameters.task != "unknown":
-                error.args = (
-                    f"{error.args[0]}\nMake sure '{request_parameters.task}' task is supported by the model.",
-                ) + error.args[1:]
+                msg = str(error.args[0])
+                if len(error.response.text) > 0:
+                    msg += f"\n{error.response.text}\n"
+                msg += f"\nMake sure '{request_parameters.task}' task is supported by the model."
+                error.args = (msg,) + error.args[1:]
             if error.response.status_code == 503:
                 # If Model is unavailable, either raise a TimeoutError...
                 if timeout is not None and time.time() - t0 > timeout:
@@ -934,9 +933,9 @@ def chat_completion(
         provider_helper = get_provider_helper(self.provider, task="conversational")

         # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
-        # `self.base_url` and `self.model` takes precedence over 'model' argument for building URL.
+        # `self.model` takes precedence over 'model' argument for building URL.
         # `model` takes precedence for payload value.
-        model_id_or_url = self.base_url or self.model or model
+        model_id_or_url = self.model or model
         payload_model = model or self.model

         # Prepare the payload
diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py
index cd419b0b60..6f111e68ba 100644
--- a/src/huggingface_hub/inference/_generated/_async_client.py
+++ b/src/huggingface_hub/inference/_generated/_async_client.py
@@ -179,7 +179,7 @@ def __init__(
                 " It has the exact same behavior as `token`."
             )

-        self.model: Optional[str] = model
+        self.model: Optional[str] = base_url or model
         self.token: Optional[str] = token if token is not None else api_key
         self.headers = headers if headers is not None else {}

@@ -191,9 +191,6 @@ def __init__(
         self.trust_env = trust_env
         self.proxies = proxies

-        # OpenAI compatibility
-        self.base_url = base_url
-
         # Keep track of the sessions to close them properly
         self._sessions: Dict["ClientSession", Set["ClientResponse"]] = dict()

@@ -977,9 +974,9 @@ async def chat_completion(
         provider_helper = get_provider_helper(self.provider, task="conversational")

         # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
-        # `self.base_url` and `self.model` takes precedence over 'model' argument for building URL.
+        # `self.model` takes precedence over 'model' argument for building URL.
         # `model` takes precedence for payload value.
-        model_id_or_url = self.base_url or self.model or model
+        model_id_or_url = self.model or model
         payload_model = model or self.model

         # Prepare the payload
diff --git a/tests/test_inference_client.py b/tests/test_inference_client.py
index 5a41baa094..c5437e46bd 100644
--- a/tests/test_inference_client.py
+++ b/tests/test_inference_client.py
@@ -1186,3 +1186,12 @@ def test_chat_completion_error_in_stream():
 def test_resolve_chat_completion_url(model_url: str, expected_url: str):
     url = _build_chat_completion_url(model_url)
     assert url == expected_url
+
+
+def test_pass_url_as_base_url():
+    client = InferenceClient(base_url="http://localhost:8082/v1/")
+    provider = get_provider_helper("hf-inference", "text-generation")
+    request = provider.prepare_request(
+        inputs="The huggingface_hub library is ", parameters={}, headers={}, model=client.model, api_key=None
+    )
+    assert request.url == "http://localhost:8082/v1/"
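
With `base_url` folded into `self.model` at construction time, every task method resolves its request URL against the custom endpoint, not only `chat_completion`. A minimal usage sketch under that assumption; the endpoint reuses the local address from the new test, and `my-deployed-model` is a hypothetical identifier:

from huggingface_hub import InferenceClient

# A sketch, assuming an OpenAI-compatible server is running locally.
# After this change, `base_url` is an alias for `model` in __init__,
# so the two clients below are equivalent.
client_a = InferenceClient(base_url="http://localhost:8082/v1/")
client_b = InferenceClient(model="http://localhost:8082/v1/")
assert client_a.model == client_b.model

# Per the precedence comments in chat_completion: the URL is built from
# `self.model` (the endpoint), while the `model` argument only sets the
# payload's "model" field. "my-deployed-model" is a hypothetical name.
response = client_a.chat_completion(
    messages=[{"role": "user", "content": "Hello!"}],
    model="my-deployed-model",
)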
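
The reworked 422 handler also surfaces the server's response body inside the raised error, between the original message and the existing task hint. A standalone sketch of that message construction; `build_422_message` is a hypothetical helper mirroring the inline logic above, and the error payload is invented for illustration:

def build_422_message(first_arg: str, response_text: str, task: str) -> str:
    # Mirrors the new _inner_post logic: start from the original error text...
    msg = str(first_arg)
    # ...append the server's response body when it is non-empty...
    if len(response_text) > 0:
        msg += f"\n{response_text}\n"
    # ...and keep the existing hint about task support.
    msg += f"\nMake sure '{task}' task is supported by the model."
    return msg

print(build_422_message(
    "422 Client Error: Unprocessable Entity for url: https://example.com",
    '{"error": "Template error: unknown filter"}',
    "conversational",
))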