Implement InferenceClient.chat_completion + use new types for text-generation (#2094)

* ChatCompletion types

* style + minijinja

* remove pydantic + add templating + first draft

* remove print

* type update

* Use spec-ed object for text_generation

* move remaining text_generation stuff to _common

* add stream types for chat completion task

* stream in chat_completion

* to be tested but should be good

* fix docs

* Update text generation types

* another update

* update chat completion definition

* fix tests

* add chat_completion test

* chat completion async test

* what a shitshow

* fix test using TGI

* more tests

* make style

* fix async

* docs

* make minijinja optional

* add role to chat output + adapt tests

* changes from feedback

* fix tests

* stuff

* fix token types

* typo

* update string
Wauplin authored Mar 20, 2024
1 parent 57ee5fc commit 2b8c1c5
Showing 27 changed files with 2,330 additions and 955 deletions.
29 changes: 0 additions & 29 deletions docs/source/en/package_reference/inference_client.md
@@ -53,35 +53,6 @@ For most tasks, the return value has a built-in type (string, list, image...). H

[[autodoc]] huggingface_hub.inference._common.ModelStatus

### Text generation types

The [`~InferenceClient.text_generation`] task has greater support than other tasks in `InferenceClient`. In
particular, user inputs and server outputs are validated using [Pydantic](https://docs.pydantic.dev/latest/)
if this package is installed. Therefore, we recommend installing it (`pip install pydantic`)
for a better user experience.

You can find below the dataclasses used to validate data and in particular [`~huggingface_hub.inference._text_generation.TextGenerationParameters`] (input),
[`~huggingface_hub.inference._text_generation.TextGenerationResponse`] (output) and
[`~huggingface_hub.inference._text_generation.TextGenerationStreamResponse`] (streaming output).

[[autodoc]] huggingface_hub.inference._text_generation.TextGenerationParameters

[[autodoc]] huggingface_hub.inference._text_generation.TextGenerationResponse

[[autodoc]] huggingface_hub.inference._text_generation.TextGenerationStreamResponse

[[autodoc]] huggingface_hub.inference._text_generation.InputToken

[[autodoc]] huggingface_hub.inference._text_generation.Token

[[autodoc]] huggingface_hub.inference._text_generation.FinishReason

[[autodoc]] huggingface_hub.inference._text_generation.BestOfSequence

[[autodoc]] huggingface_hub.inference._text_generation.Details

[[autodoc]] huggingface_hub.inference._text_generation.StreamDetails

## InferenceAPI

[`InferenceAPI`] is the legacy way to call the Inference API. The interface is more simplistic and requires knowing
32 changes: 28 additions & 4 deletions docs/source/en/package_reference/inference_types.md
@@ -49,6 +49,26 @@ This part of the lib is still under development and will be improved in future r



## chat_completion

[[autodoc]] huggingface_hub.ChatCompletionInput

[[autodoc]] huggingface_hub.ChatCompletionInputMessage

[[autodoc]] huggingface_hub.ChatCompletionOutput

[[autodoc]] huggingface_hub.ChatCompletionOutputChoice

[[autodoc]] huggingface_hub.ChatCompletionOutputChoiceMessage

[[autodoc]] huggingface_hub.ChatCompletionStreamOutput

[[autodoc]] huggingface_hub.ChatCompletionStreamOutputChoice

[[autodoc]] huggingface_hub.ChatCompletionStreamOutputDelta
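
For illustration, a minimal sketch of how these types fit together at call time (not part of this diff; the model id and parameter values are assumptions):

```python
# Minimal sketch of InferenceClient.chat_completion with the new types.
from huggingface_hub import InferenceClient

client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta")  # illustrative model id
messages = [{"role": "user", "content": "What is the capital of France?"}]

# Non-streaming call: returns a ChatCompletionOutput.
output = client.chat_completion(messages, max_tokens=100)
print(output.choices[0].message.content)

# Streaming call: yields ChatCompletionStreamOutput chunks carrying a delta.
for chunk in client.chat_completion(messages, max_tokens=100, stream=True):
    print(chunk.choices[0].delta.content or "", end="")
```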



## depth_estimation

[[autodoc]] huggingface_hub.DepthEstimationInput
@@ -203,19 +223,23 @@ This part of the lib is still under development and will be improved in future r

## text_generation

[[autodoc]] huggingface_hub.PrefillToken

[[autodoc]] huggingface_hub.TextGenerationInput

[[autodoc]] huggingface_hub.TextGenerationOutput

[[autodoc]] huggingface_hub.TextGenerationOutputDetails

[[autodoc]] huggingface_hub.TextGenerationOutputSequenceDetails

[[autodoc]] huggingface_hub.TextGenerationOutputToken

[[autodoc]] huggingface_hub.TextGenerationParameters

[[autodoc]] huggingface_hub.TextGenerationSequenceDetails
[[autodoc]] huggingface_hub.TextGenerationPrefillToken

[[autodoc]] huggingface_hub.TextGenerationStreamDetails

[[autodoc]] huggingface_hub.Token
[[autodoc]] huggingface_hub.TextGenerationStreamOutput
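
For illustration, a minimal sketch showing where the renamed output types appear (not part of this diff; the model id is an assumption):

```python
# Minimal sketch of text_generation with the renamed output types.
from huggingface_hub import InferenceClient

client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta")  # illustrative model id

# details=True returns a TextGenerationOutput whose `details` field is a
# TextGenerationOutputDetails (finish reason, generated tokens, ...).
output = client.text_generation("The capital of France is", details=True, max_new_tokens=10)
print(output.generated_text)
print(output.details.finish_reason, len(output.details.tokens))

# stream=True with details=True yields TextGenerationStreamOutput chunks,
# each carrying a TextGenerationOutputToken.
for chunk in client.text_generation(
    "The capital of France is", stream=True, details=True, max_new_tokens=10
):
    print(chunk.token.text, end="")
```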



3 changes: 0 additions & 3 deletions pyproject.toml
@@ -1,9 +1,6 @@
[tool.mypy]
ignore_missing_imports = true
no_implicit_optional = true
plugins = [
"pydantic.mypy",
]
scripts_are_modules = true

[tool.pytest.ini_options]
14 changes: 4 additions & 10 deletions setup.py
@@ -14,28 +14,22 @@ def get_version() -> str:
install_requires = [
"filelock",
"fsspec>=2023.5.0",
"packaging>=20.9",
"pyyaml>=5.1",
"requests",
"tqdm>=4.42.1",
"pyyaml>=5.1",
"typing-extensions>=3.7.4.3", # to be able to import TypeAlias
"packaging>=20.9",
]

extras = {}

extras["cli"] = [
"InquirerPy==0.3.4",
# Note: installs `prompt-toolkit` in the background
"InquirerPy==0.3.4", # Note: installs `prompt-toolkit` in the background
]

extras["inference"] = [
"aiohttp", # for AsyncInferenceClient
# On Python 3.8, Pydantic 2.x and tensorflow don't play well together
# Let's limit pydantic to 1.x for now. Since Tensorflow 2.14, Python3.8 is not supported anyway so impact should be
# limited. We still trigger some CIs on Python 3.8 so we need this workaround.
# NOTE: when relaxing constraint to support v3.x, make sure to adapt `src/huggingface_hub/inference/_text_generation.py`.
"pydantic>1.1,<3.0; python_version>'3.8'",
"pydantic>1.1,<2.0; python_version=='3.8'",
"minijinja>=1.0", # for chat-completion if not TGI-served
]
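
The new `minijinja` dependency renders a model's chat template client-side when the endpoint is not TGI-served. A sketch of what that templating looks like, assuming the minijinja Python bindings' `Environment(templates=...)` / `render_template` API (the template below is a simplified stand-in for a model's real `chat_template`):

```python
# Sketch of client-side chat templating with minijinja; the template is a
# simplified stand-in for a model's real chat_template.
from minijinja import Environment

chat_template = (
    "{% for message in messages %}"
    "<|{{ message.role }}|>\n{{ message.content }}\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}<|assistant|>\n{% endif %}"
)

env = Environment(templates={"chat": chat_template})
prompt = env.render_template(
    "chat",
    messages=[{"role": "user", "content": "Hello!"}],
    add_generation_prompt=True,
)
print(prompt)  # <|user|>\nHello!\n<|assistant|>\n
```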

extras["torch"] = [
32 changes: 26 additions & 6 deletions src/huggingface_hub/__init__.py
@@ -270,6 +270,14 @@
"AutomaticSpeechRecognitionOutput",
"AutomaticSpeechRecognitionOutputChunk",
"AutomaticSpeechRecognitionParameters",
"ChatCompletionInput",
"ChatCompletionInputMessage",
"ChatCompletionOutput",
"ChatCompletionOutputChoice",
"ChatCompletionOutputChoiceMessage",
"ChatCompletionStreamOutput",
"ChatCompletionStreamOutputChoice",
"ChatCompletionStreamOutputDelta",
"DepthEstimationInput",
"DepthEstimationOutput",
"DocumentQuestionAnsweringInput",
@@ -298,7 +306,6 @@
"ObjectDetectionInput",
"ObjectDetectionOutputElement",
"ObjectDetectionParameters",
"PrefillToken",
"QuestionAnsweringInput",
"QuestionAnsweringInputData",
"QuestionAnsweringOutputElement",
@@ -320,8 +327,12 @@
"TextGenerationInput",
"TextGenerationOutput",
"TextGenerationOutputDetails",
"TextGenerationOutputSequenceDetails",
"TextGenerationOutputToken",
"TextGenerationParameters",
"TextGenerationSequenceDetails",
"TextGenerationPrefillToken",
"TextGenerationStreamDetails",
"TextGenerationStreamOutput",
"TextToAudioGenerationParameters",
"TextToAudioInput",
"TextToAudioOutput",
@@ -330,7 +341,6 @@
"TextToImageOutput",
"TextToImageParameters",
"TextToImageTargetSize",
"Token",
"TokenClassificationInput",
"TokenClassificationOutputElement",
"TokenClassificationParameters",
@@ -722,6 +732,14 @@ def __dir__():
AutomaticSpeechRecognitionOutput, # noqa: F401
AutomaticSpeechRecognitionOutputChunk, # noqa: F401
AutomaticSpeechRecognitionParameters, # noqa: F401
ChatCompletionInput, # noqa: F401
ChatCompletionInputMessage, # noqa: F401
ChatCompletionOutput, # noqa: F401
ChatCompletionOutputChoice, # noqa: F401
ChatCompletionOutputChoiceMessage, # noqa: F401
ChatCompletionStreamOutput, # noqa: F401
ChatCompletionStreamOutputChoice, # noqa: F401
ChatCompletionStreamOutputDelta, # noqa: F401
DepthEstimationInput, # noqa: F401
DepthEstimationOutput, # noqa: F401
DocumentQuestionAnsweringInput, # noqa: F401
@@ -750,7 +768,6 @@ def __dir__():
ObjectDetectionInput, # noqa: F401
ObjectDetectionOutputElement, # noqa: F401
ObjectDetectionParameters, # noqa: F401
PrefillToken, # noqa: F401
QuestionAnsweringInput, # noqa: F401
QuestionAnsweringInputData, # noqa: F401
QuestionAnsweringOutputElement, # noqa: F401
@@ -772,8 +789,12 @@
TextGenerationInput, # noqa: F401
TextGenerationOutput, # noqa: F401
TextGenerationOutputDetails, # noqa: F401
TextGenerationOutputSequenceDetails, # noqa: F401
TextGenerationOutputToken, # noqa: F401
TextGenerationParameters, # noqa: F401
TextGenerationSequenceDetails, # noqa: F401
TextGenerationPrefillToken, # noqa: F401
TextGenerationStreamDetails, # noqa: F401
TextGenerationStreamOutput, # noqa: F401
TextToAudioGenerationParameters, # noqa: F401
TextToAudioInput, # noqa: F401
TextToAudioOutput, # noqa: F401
@@ -782,7 +803,6 @@
TextToImageOutput, # noqa: F401
TextToImageParameters, # noqa: F401
TextToImageTargetSize, # noqa: F401
Token, # noqa: F401
TokenClassificationInput, # noqa: F401
TokenClassificationOutputElement, # noqa: F401
TokenClassificationParameters, # noqa: F401
23 changes: 12 additions & 11 deletions src/huggingface_hub/_inference_endpoints.py
@@ -2,7 +2,7 @@
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Dict, Optional, Union

from .inference._client import InferenceClient
from .inference._generated._async_client import AsyncInferenceClient
@@ -71,8 +71,9 @@ class InferenceEndpoint:
The type of the Inference Endpoint (public, protected, private).
raw (`Dict`):
The raw dictionary data returned from the API.
token (`str`, *optional*):
Authentication token for the Inference Endpoint, if set when requesting the API.
token (`str` or `bool`, *optional*):
Authentication token for the Inference Endpoint, if set when requesting the API. Will default to the
locally saved token if not provided. Pass `token=False` if you don't want to send your token to the server.
Example:
```python
@@ -120,12 +121,12 @@ class InferenceEndpoint:
raw: Dict = field(repr=False)

# Internal fields
_token: Optional[str] = field(repr=False, compare=False)
_token: Union[str, bool, None] = field(repr=False, compare=False)
_api: "HfApi" = field(repr=False, compare=False)

@classmethod
def from_raw(
cls, raw: Dict, namespace: str, token: Optional[str] = None, api: Optional["HfApi"] = None
cls, raw: Dict, namespace: str, token: Union[str, bool, None] = None, api: Optional["HfApi"] = None
) -> "InferenceEndpoint":
"""Initialize object from raw dictionary."""
if api is None:
@@ -230,7 +231,7 @@ def fetch(self) -> "InferenceEndpoint":
Returns:
[`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
"""
obj = self._api.get_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)
obj = self._api.get_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type]
self.raw = obj.raw
self._populate_from_raw()
return self
@@ -295,7 +296,7 @@ def update(
framework=framework,
revision=revision,
task=task,
token=self._token,
token=self._token, # type: ignore [arg-type]
)

# Mutate current object
@@ -316,7 +317,7 @@ def pause(self) -> "InferenceEndpoint":
Returns:
[`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
"""
obj = self._api.pause_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)
obj = self._api.pause_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type]
self.raw = obj.raw
self._populate_from_raw()
return self
@@ -330,7 +331,7 @@ def resume(self) -> "InferenceEndpoint":
Returns:
[`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
"""
obj = self._api.resume_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)
obj = self._api.resume_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type]
self.raw = obj.raw
self._populate_from_raw()
return self
@@ -348,7 +349,7 @@ def scale_to_zero(self) -> "InferenceEndpoint":
Returns:
[`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
"""
obj = self._api.scale_to_zero_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)
obj = self._api.scale_to_zero_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type]
self.raw = obj.raw
self._populate_from_raw()
return self
@@ -361,7 +362,7 @@ def delete(self) -> None:
This is an alias for [`HfApi.delete_inference_endpoint`].
"""
self._api.delete_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)
self._api.delete_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type]

def _populate_from_raw(self) -> None:
"""Populate fields from raw dictionary.
10 changes: 5 additions & 5 deletions src/huggingface_hub/hf_api.py
@@ -1242,7 +1242,7 @@ class HfApi:
def __init__(
self,
endpoint: Optional[str] = None,
token: Optional[str] = None,
token: Union[str, bool, None] = None,
library_name: Optional[str] = None,
library_version: Optional[str] = None,
user_agent: Union[Dict, str, None] = None,
@@ -1256,9 +1256,9 @@ def __init__(
directly at the root of `huggingface_hub`.
Args:
token (`str`, *optional*):
Hugging Face token. Will default to the locally saved token if
not provided.
token (`str` or `bool`, *optional*):
Hugging Face token. Will default to the locally saved token if not provided.
Pass `token=False` if you don't want to send your token to the server.
library_name (`str`, *optional*):
The name of the library that is making the HTTP request. Will be added to
the user-agent header. Example: `"transformers"`.
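
For illustration, a minimal sketch of what the broadened `token: Union[str, bool, None]` typing enables (not part of this diff; the repo id is an assumption):

```python
# Sketch: token=False forces unauthenticated requests, so no locally
# saved token is ever sent to the server.
from huggingface_hub import HfApi

api = HfApi(token=False)
print(api.file_exists("gpt2", "config.json"))  # anonymous check against a public repo
```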
@@ -2527,7 +2527,7 @@ def file_exists(
*,
repo_type: Optional[str] = None,
revision: Optional[str] = None,
token: Optional[str] = None,
token: Union[str, bool, None] = None,
) -> bool:
"""
Checks if a file exists in a repository on the Hugging Face Hub.