Skip to content

Commit

Permalink
Add devices/properties badges (pytorch#2321)
Browse files Browse the repository at this point in the history
Summary:
Add badges of supported properties and devices to functionals and transforms.

This commit adds `.. devices::` and `.. properties::` directives to sphinx.

APIs with these directives will have badges (based off of shields.io) which link to the
page with description of these features.

Continuation of pytorch#2316
Excluded dtypes for further improvement, and actually added badges to most of functional/transforms.

Pull Request resolved: pytorch#2321

Reviewed By: hwangjeff

Differential Revision: D35489063

Pulled By: mthrok

fbshipit-source-id: f68a70ebb22df29d5e9bd171273bd19007a81762
  • Loading branch information
mthrok authored and facebook-github-bot committed Apr 8, 2022
1 parent eb23a24 commit 72ae755
Show file tree
Hide file tree
Showing 14 changed files with 628 additions and 31 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ instance/
docs/_build/
docs/src/
docs/source/tutorials
docs/source/gen_images
docs/source/gen_modules

# PyBuilder
Expand Down
6 changes: 6 additions & 0 deletions docs/source/_static/css/custom.css
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,9 @@ dt > em.sig-param:last-of-type::after {
content: "\a";
white-space: pre;
}
/* For shields */
article.pytorch-article img.shield-badge {
width: unset;
margin-top: -18px;
margin-bottom: 9px;
}
16 changes: 13 additions & 3 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))
import os
import sys

sys.path.insert(0, os.path.abspath("."))
import re
import warnings
from datetime import datetime
Expand Down Expand Up @@ -342,3 +342,13 @@ def inject_minigalleries(app, what, name, obj, options, lines):

def setup(app):
app.connect("autodoc-process-docstring", inject_minigalleries)


from custom_directives import SupportedDevices, SupportedProperties

# Register custom directives

from docutils.parsers import rst

rst.directives.register_directive("devices", SupportedDevices)
rst.directives.register_directive("properties", SupportedProperties)
116 changes: 116 additions & 0 deletions docs/source/custom_directives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import hashlib
from pathlib import Path
from typing import List
from urllib.parse import quote, urlencode

import requests
from docutils import nodes
from docutils.parsers.rst.directives.images import Image


_THIS_DIR = Path(__file__).parent

# Color palette from PyTorch Developer Day 2021 Presentation Template
YELLOW = "F9DB78"
GREEN = "70AD47"
BLUE = "00B0F0"
PINK = "FF71DA"
ORANGE = "FF8300"
TEAL = "00E5D1"
GRAY = "7F7F7F"


def _get_cache_path(key, ext):
filename = f"{hashlib.sha256(key).hexdigest()}{ext}"
cache_dir = _THIS_DIR / "gen_images"
cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir / filename


def _download(url, path):
response = requests.get(url)
response.raise_for_status()
with open(path, "wb") as file:
file.write(response.content)


def _fetch_image(url):
path = _get_cache_path(url.encode("utf-8"), ext=".svg")
if not path.exists():
_download(url, path)
return str(path.relative_to(_THIS_DIR))


class BaseShield(Image):
def run(self, params, alt, section) -> List[nodes.Node]:
url = f"https://img.shields.io/static/v1?{urlencode(params, quote_via=quote)}"
path = _fetch_image(url)
self.arguments = [path]
self.options["alt"] = alt
if "class" not in self.options:
self.options["class"] = []
self.options["class"].append("shield-badge")
self.options["target"] = f"supported_features.html#{section}"
return super().run()


def _parse_devices(arg: str):
devices = sorted(arg.strip().split())

valid_values = {"CPU", "CUDA"}
if any(val not in valid_values for val in devices):
raise ValueError(
f"One or more device values are not valid. The valid values are {valid_values}. Given value: '{arg}'"
)
return ", ".join(sorted(devices))


def _parse_properties(arg: str):
properties = sorted(arg.strip().split())

valid_values = {"Autograd", "TorchScript"}
if any(val not in valid_values for val in properties):
raise ValueError(
"One or more property values are not valid. "
f"The valid values are {valid_values}. "
f"Given value: '{arg}'"
)
return ", ".join(sorted(properties))


class SupportedDevices(BaseShield):
"""List the supported devices"""

required_arguments = 1
final_argument_whitespace = True

def run(self) -> List[nodes.Node]:
devices = _parse_devices(self.arguments[0])
alt = f"This feature supports the following devices: {devices}"
params = {
"label": "Devices",
"message": devices,
"labelColor": GRAY,
"color": BLUE,
"style": "flat-square",
}
return super().run(params, alt, "devices")


class SupportedProperties(BaseShield):
"""List the supported properties"""

required_arguments = 1
final_argument_whitespace = True

def run(self) -> List[nodes.Node]:
properties = _parse_properties(self.arguments[0])
alt = f"This API supports the following properties: {properties}"
params = {
"label": "Properties",
"message": properties,
"labelColor": GRAY,
"color": GREEN,
"style": "flat-square",
}
return super().run(params, alt, "properties")
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ Features described in this documentation are classified by release status:
:caption: Torchaudio Documentation

Index <self>
supported_features

API References
--------------
Expand Down
7 changes: 7 additions & 0 deletions docs/source/refs.bib
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
@misc{RESAMPLE,
author = {Julius O. Smith},
title = {Digital Audio Resampling Home Page "Theory of Ideal Bandlimited Interpolation" section},
url = {https://ccrma.stanford.edu/~jos/resample/Theory_Ideal_Bandlimited_Interpolation.html},
month = {September},
year = {2020}
}
@article{voxpopuli,
author = {Changhan Wang and
Morgane Rivi{\`{e}}re and
Expand Down
98 changes: 98 additions & 0 deletions docs/source/supported_features.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
Supported Features
==================

Each TorchAudio API supports a subset of PyTorch features, such as
devices and data types.
Supported features are indicated in API references like the following:

.. devices:: CPU CUDA

.. properties:: Autograd TorchScript

These icons mean that they are verified through automated testing.

.. note::

Missing feature icons mean that they are not tested, and this can mean
different things, depending on the API.

1. The API is compatible with the feature but not tested.
2. The API is not compatible with the feature.

In case of 2, the API might explicitly raise an error, but that is not guaranteed.
For example, APIs without an Autograd badge might throw an error during backpropagation,
or silently return a wrong gradient.

If you use an API that hasn't been labeled as supporting a feature, you might want to first verify that the
feature works fine.

Devices
-------

CPU
^^^

.. devices:: CPU

TorchAudio APIs that support CPU can perform their computation on CPU tensors.


CUDA
^^^^

.. devices:: CUDA

TorchAudio APIs that support CUDA can perform their computation on CUDA devices.

In case of functions, move the tensor arguments to CUDA device before passing them to a function.

For example:

.. code:: python
cuda = torch.device("cuda")
waveform = waveform.to(cuda)
spectrogram = torchaudio.functional.spectrogram(waveform)
Classes with CUDA support are implemented with :py:func:`torch.nn.Module`.
It is also necessary to move the instance to CUDA device, before passing CUDA tensors.

For example:

.. code:: python
cuda = torch.device("cuda")
resampler = torchaudio.transforms.Resample(8000, 16000)
resampler.to(cuda)
waveform.to(cuda)
resampled = resampler(waveform)
Properties
----------

Autograd
^^^^^^^^

.. properties:: Autograd

TorchAudio APIs with autograd support can correctly backpropagate gradients.

For the basics of autograd, please refer to this `tutorial <https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html>`_.

.. note::

APIs without this mark may or may not raise an error during backpropagation.
The absence of an error raised during backpropagation does not necessarily mean the gradient is correct.

TorchScript
^^^^^^^^^^^

.. properties:: TorchScript

TorchAudio APIs with TorchScript support can be serialized and executed in non-Python environments.

For details on TorchScript, please refer to the `documentation <https://pytorch.org/docs/stable/jit.html>`_.
27 changes: 27 additions & 0 deletions test/torchaudio_unittest/functional/autograd_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,33 @@ def test_bandreject_biquad(self, central_freq, Q):
Q = torch.tensor(Q)
self.assert_grad(F.bandreject_biquad, (x, sr, central_freq, Q))

def test_deemph_biquad(self):
torch.random.manual_seed(2434)
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1)
self.assert_grad(F.deemph_biquad, (x, 44100))

def test_flanger(self):
torch.random.manual_seed(2434)
x = get_whitenoise(sample_rate=8000, duration=0.01, n_channels=1)
self.assert_grad(F.flanger, (x, 44100))

def test_gain(self):
torch.random.manual_seed(2434)
x = get_whitenoise(sample_rate=8000, duration=0.01, n_channels=1)
self.assert_grad(F.gain, (x, 1.1))

def test_overdrive(self):
torch.random.manual_seed(2434)
x = get_whitenoise(sample_rate=8000, duration=0.01, n_channels=1)
self.assert_grad(F.gain, (x,))

@parameterized.expand([(True,), (False,)])
def test_phaser(self, sinusoidal):
torch.random.manual_seed(2434)
sr = 8000
x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
self.assert_grad(F.phaser, (x, sr, sinusoidal))

@parameterized.expand(
[
(True,),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def test_MuLawDecoding(self):
tensor = torch.rand((1, 10))
self._assert_consistency(T.MuLawDecoding(), tensor)

def test_ComputeDelta(self):
tensor = torch.rand((1, 10))
self._assert_consistency(T.ComputeDeltas(), tensor)

def test_Fade(self):
waveform = common_utils.get_whitenoise()
fade_in_len = 3000
Expand Down
Loading

0 comments on commit 72ae755

Please sign in to comment.