forked from pytorch/audio
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add devices/properties badges (pytorch#2321)
Summary: Add badges of supported properties and devices to functionals and transforms. This commit adds `.. devices::` and `.. properties::` directives to sphinx. APIs with these directives will have badges (based off of shields.io) which link to the page with description of these features. Continuation of pytorch#2316 Excluded dtypes for further improvement, and actually added badges to most of functional/transforms. Pull Request resolved: pytorch#2321 Reviewed By: hwangjeff Differential Revision: D35489063 Pulled By: mthrok fbshipit-source-id: f68a70ebb22df29d5e9bd171273bd19007a81762
- Loading branch information
1 parent
eb23a24
commit 72ae755
Showing
14 changed files
with
628 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import hashlib | ||
from pathlib import Path | ||
from typing import List | ||
from urllib.parse import quote, urlencode | ||
|
||
import requests | ||
from docutils import nodes | ||
from docutils.parsers.rst.directives.images import Image | ||
|
||
|
||
_THIS_DIR = Path(__file__).parent | ||
|
||
# Color palette from PyTorch Developer Day 2021 Presentation Template | ||
YELLOW = "F9DB78" | ||
GREEN = "70AD47" | ||
BLUE = "00B0F0" | ||
PINK = "FF71DA" | ||
ORANGE = "FF8300" | ||
TEAL = "00E5D1" | ||
GRAY = "7F7F7F" | ||
|
||
|
||
def _get_cache_path(key, ext): | ||
filename = f"{hashlib.sha256(key).hexdigest()}{ext}" | ||
cache_dir = _THIS_DIR / "gen_images" | ||
cache_dir.mkdir(parents=True, exist_ok=True) | ||
return cache_dir / filename | ||
|
||
|
||
def _download(url, path): | ||
response = requests.get(url) | ||
response.raise_for_status() | ||
with open(path, "wb") as file: | ||
file.write(response.content) | ||
|
||
|
||
def _fetch_image(url): | ||
path = _get_cache_path(url.encode("utf-8"), ext=".svg") | ||
if not path.exists(): | ||
_download(url, path) | ||
return str(path.relative_to(_THIS_DIR)) | ||
|
||
|
||
class BaseShield(Image): | ||
def run(self, params, alt, section) -> List[nodes.Node]: | ||
url = f"https://img.shields.io/static/v1?{urlencode(params, quote_via=quote)}" | ||
path = _fetch_image(url) | ||
self.arguments = [path] | ||
self.options["alt"] = alt | ||
if "class" not in self.options: | ||
self.options["class"] = [] | ||
self.options["class"].append("shield-badge") | ||
self.options["target"] = f"supported_features.html#{section}" | ||
return super().run() | ||
|
||
|
||
def _parse_devices(arg: str): | ||
devices = sorted(arg.strip().split()) | ||
|
||
valid_values = {"CPU", "CUDA"} | ||
if any(val not in valid_values for val in devices): | ||
raise ValueError( | ||
f"One or more device values are not valid. The valid values are {valid_values}. Given value: '{arg}'" | ||
) | ||
return ", ".join(sorted(devices)) | ||
|
||
|
||
def _parse_properties(arg: str): | ||
properties = sorted(arg.strip().split()) | ||
|
||
valid_values = {"Autograd", "TorchScript"} | ||
if any(val not in valid_values for val in properties): | ||
raise ValueError( | ||
"One or more property values are not valid. " | ||
f"The valid values are {valid_values}. " | ||
f"Given value: '{arg}'" | ||
) | ||
return ", ".join(sorted(properties)) | ||
|
||
|
||
class SupportedDevices(BaseShield): | ||
"""List the supported devices""" | ||
|
||
required_arguments = 1 | ||
final_argument_whitespace = True | ||
|
||
def run(self) -> List[nodes.Node]: | ||
devices = _parse_devices(self.arguments[0]) | ||
alt = f"This feature supports the following devices: {devices}" | ||
params = { | ||
"label": "Devices", | ||
"message": devices, | ||
"labelColor": GRAY, | ||
"color": BLUE, | ||
"style": "flat-square", | ||
} | ||
return super().run(params, alt, "devices") | ||
|
||
|
||
class SupportedProperties(BaseShield): | ||
"""List the supported properties""" | ||
|
||
required_arguments = 1 | ||
final_argument_whitespace = True | ||
|
||
def run(self) -> List[nodes.Node]: | ||
properties = _parse_properties(self.arguments[0]) | ||
alt = f"This API supports the following properties: {properties}" | ||
params = { | ||
"label": "Properties", | ||
"message": properties, | ||
"labelColor": GRAY, | ||
"color": GREEN, | ||
"style": "flat-square", | ||
} | ||
return super().run(params, alt, "properties") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
Supported Features | ||
================== | ||
|
||
Each TorchAudio API supports a subset of PyTorch features, such as | ||
devices and data types. | ||
Supported features are indicated in API references like the following: | ||
|
||
.. devices:: CPU CUDA | ||
|
||
.. properties:: Autograd TorchScript | ||
|
||
These icons mean that they are verified through automated testing. | ||
|
||
.. note:: | ||
|
||
Missing feature icons mean that they are not tested, and this can mean | ||
different things, depending on the API. | ||
|
||
1. The API is compatible with the feature but not tested. | ||
2. The API is not compatible with the feature. | ||
|
||
In case of 2, the API might explicitly raise an error, but that is not guaranteed. | ||
For example, APIs without an Autograd badge might throw an error during backpropagation, | ||
or silently return a wrong gradient. | ||
|
||
If you use an API that hasn't been labeled as supporting a feature, you might want to first verify that the | ||
feature works fine. | ||
|
||
Devices | ||
------- | ||
|
||
CPU | ||
^^^ | ||
|
||
.. devices:: CPU | ||
|
||
TorchAudio APIs that support CPU can perform their computation on CPU tensors. | ||
|
||
|
||
CUDA | ||
^^^^ | ||
|
||
.. devices:: CUDA | ||
|
||
TorchAudio APIs that support CUDA can perform their computation on CUDA devices. | ||
|
||
In case of functions, move the tensor arguments to CUDA device before passing them to a function. | ||
|
||
For example: | ||
|
||
.. code:: python | ||
cuda = torch.device("cuda") | ||
waveform = waveform.to(cuda) | ||
spectrogram = torchaudio.functional.spectrogram(waveform) | ||
Classes with CUDA support are implemented with :py:func:`torch.nn.Module`. | ||
It is also necessary to move the instance to CUDA device, before passing CUDA tensors. | ||
|
||
For example: | ||
|
||
.. code:: python | ||
cuda = torch.device("cuda") | ||
resampler = torchaudio.transforms.Resample(8000, 16000) | ||
resampler.to(cuda) | ||
waveform.to(cuda) | ||
resampled = resampler(waveform) | ||
Properties | ||
---------- | ||
|
||
Autograd | ||
^^^^^^^^ | ||
|
||
.. properties:: Autograd | ||
|
||
TorchAudio APIs with autograd support can correctly backpropagate gradients. | ||
|
||
For the basics of autograd, please refer to this `tutorial <https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html>`_. | ||
|
||
.. note:: | ||
|
||
APIs without this mark may or may not raise an error during backpropagation. | ||
The absence of an error raised during backpropagation does not necessarily mean the gradient is correct. | ||
|
||
TorchScript | ||
^^^^^^^^^^^ | ||
|
||
.. properties:: TorchScript | ||
|
||
TorchAudio APIs with TorchScript support can be serialized and executed in non-Python environments. | ||
|
||
For details on TorchScript, please refer to the `documentation <https://pytorch.org/docs/stable/jit.html>`_. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.