Default to base_url if provided (#2805)
* Default to base_url if provided

* Add test
Wauplin authored Jan 30, 2025
1 parent 4bb4e7e commit 1033d0a
Showing 3 changed files with 21 additions and 15 deletions.
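In short: `base_url`, when provided, is now folded into `self.model` at construction time, and the separate `self.base_url` attribute is dropped, so every code path that resolves a URL from `self.model` honors `base_url` automatically. A minimal sketch of the resulting behavior (the local URL is a placeholder):

from huggingface_hub import InferenceClient

# After this commit, `self.model` holds `base_url or model`, so a client
# built with only `base_url` stores that URL as its model target.
client = InferenceClient(base_url="http://localhost:8082/v1/")
assert client.model == "http://localhost:8082/v1/"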
18 changes: 9 additions & 9 deletions src/huggingface_hub/inference/_client.py
@@ -188,7 +188,7 @@ def __init__(
                 " It has the exact same behavior as `token`."
             )
 
-        self.model: Optional[str] = model
+        self.model: Optional[str] = base_url or model
         self.token: Optional[str] = token if token is not None else api_key
         self.headers = headers if headers is not None else {}
@@ -199,9 +199,6 @@ def __init__(
         self.timeout = timeout
         self.proxies = proxies
 
-        # OpenAI compatibility
-        self.base_url = base_url
-
     def __repr__(self):
         return f"<InferenceClient(model='{self.model if self.model else ''}', timeout={self.timeout})>"
@@ -328,9 +325,12 @@ def _inner_post(
             return response.iter_lines() if stream else response.content
         except HTTPError as error:
             if error.response.status_code == 422 and request_parameters.task != "unknown":
-                error.args = (
-                    f"{error.args[0]}\nMake sure '{request_parameters.task}' task is supported by the model.",
-                ) + error.args[1:]
+                msg = str(error.args[0])
+                print(error.response.text)
+                if len(error.response.text) > 0:
+                    msg += f"\n{error.response.text}\n"
+                msg += f"\nMake sure '{request_parameters.task}' task is supported by the model."
+                error.args = (msg,) + error.args[1:]
             if error.response.status_code == 503:
                 # If Model is unavailable, either raise a TimeoutError...
                 if timeout is not None and time.time() - t0 > timeout:
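As an aside for readers of the hunk above: the new 422 handling builds the message incrementally, appending the server's response body (when non-empty) and a task hint before re-raising. A standalone sketch of that logic, with `enrich_422_message` as an illustrative name that does not exist in the library (the commit's debug `print` is omitted):

from requests import HTTPError

def enrich_422_message(error: HTTPError, task: str) -> HTTPError:
    # Assumes `error.response` is attached, as it is in `_inner_post`.
    msg = str(error.args[0])
    if len(error.response.text) > 0:
        msg += f"\n{error.response.text}\n"  # surface the server's own explanation
    msg += f"\nMake sure '{task}' task is supported by the model."
    error.args = (msg,) + error.args[1:]  # keep any remaining args untouched
    return error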
@@ -934,9 +934,9 @@ def chat_completion(
         provider_helper = get_provider_helper(self.provider, task="conversational")
 
         # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
-        # `self.base_url` and `self.model` takes precedence over 'model' argument for building URL.
+        # `self.model` takes precedence over 'model' argument for building URL.
         # `model` takes precedence for payload value.
-        model_id_or_url = self.base_url or self.model or model
+        model_id_or_url = self.model or model
         payload_model = model or self.model
 
         # Prepare the payload
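The two precedence rules in this hunk run in opposite directions, which is easy to misread: the client-level value wins for the URL, while the per-call `model` argument wins for the payload. A minimal sketch (the function name is illustrative, not part of the library):

from typing import Optional, Tuple

def resolve_model(self_model: Optional[str], model_arg: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
    # URL resolution: the client-level model (now `base_url or model` from __init__) wins.
    model_id_or_url = self_model or model_arg
    # Payload value: the per-call argument wins.
    payload_model = model_arg or self_model
    return model_id_or_url, payload_model

# A client built with a base_url keeps its URL even when a model id is passed per call:
assert resolve_model("http://localhost:8082/v1/", "my-org/my-model") == (
    "http://localhost:8082/v1/",
    "my-org/my-model",
)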
9 changes: 3 additions & 6 deletions src/huggingface_hub/inference/_generated/_async_client.py
@@ -179,7 +179,7 @@ def __init__(
                 " It has the exact same behavior as `token`."
             )
 
-        self.model: Optional[str] = model
+        self.model: Optional[str] = base_url or model
         self.token: Optional[str] = token if token is not None else api_key
         self.headers = headers if headers is not None else {}
@@ -191,9 +191,6 @@ def __init__(
         self.trust_env = trust_env
         self.proxies = proxies
 
-        # OpenAI compatibility
-        self.base_url = base_url
-
         # Keep track of the sessions to close them properly
         self._sessions: Dict["ClientSession", Set["ClientResponse"]] = dict()
@@ -977,9 +974,9 @@ async def chat_completion(
         provider_helper = get_provider_helper(self.provider, task="conversational")
 
         # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
-        # `self.base_url` and `self.model` takes precedence over 'model' argument for building URL.
+        # `self.model` takes precedence over 'model' argument for building URL.
         # `model` takes precedence for payload value.
-        model_id_or_url = self.base_url or self.model or model
+        model_id_or_url = self.model or model
         payload_model = model or self.model
 
         # Prepare the payload
9 changes: 9 additions & 0 deletions tests/test_inference_client.py
@@ -1186,3 +1186,12 @@ def test_chat_completion_error_in_stream():
 def test_resolve_chat_completion_url(model_url: str, expected_url: str):
     url = _build_chat_completion_url(model_url)
     assert url == expected_url
+
+
+def test_pass_url_as_base_url():
+    client = InferenceClient(base_url="http://localhost:8082/v1/")
+    provider = get_provider_helper("hf-inference", "text-generation")
+    request = provider.prepare_request(
+        inputs="The huggingface_hub library is ", parameters={}, headers={}, model=client.model, api_key=None
+    )
+    assert request.url == "http://localhost:8082/v1/"
