Skip to content

Commit

Permalink
Autodocument type hints (#343)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
rusty1s and pre-commit-ci[bot] authored Aug 17, 2024
1 parent fb0b5df commit 95aeaaa
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 131 deletions.
11 changes: 11 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys

import pyg_sphinx_theme
from sphinx.application import Sphinx

import pyg_lib

Expand All @@ -21,6 +22,7 @@
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinx_copybutton',
'sphinx_autodoc_typehints',
'pyg',
]

Expand All @@ -37,3 +39,12 @@
'python': ('http://docs.python.org', None),
'torch': ('https://pytorch.org/docs/master', None),
}

typehints_use_rtype = False
typehints_defaults = 'comma'


def setup(app: Sphinx) -> None:
r"""Setup sphinx application."""
# Do not drop type hints in signatures:
del app.events.listeners['autodoc-process-signature']
4 changes: 2 additions & 2 deletions pyg_lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ def load_library(lib_name: str) -> None:
load_library('libpyg')

import pyg_lib.ops # noqa
import pyg_lib.sampler # noqa
import pyg_lib.partition # noqa
import pyg_lib.sampler # noqa


def cuda_version() -> int:
r"""Returns the CUDA version for which :obj:`pyg_lib` was compiled with.
Returns:
(int): The CUDA version.
The CUDA version.
"""
return torch.ops.pyg.cuda_version()

Expand Down
4 changes: 2 additions & 2 deletions pyg_lib/home.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def get_home_dir() -> str:
variable :obj:`$PYG_LIB_HOME` which defaults to :obj:`"~/.cache/pyg_lib"`.
Returns:
(str): The cache directory.
The cache directory.
"""
if _home_dir is not None:
return _home_dir
Expand All @@ -29,7 +29,7 @@ def set_home_dir(path: str):
r"""Sets the cache directory used for storing all :obj:`pyg-lib` data.
Args:
path (str): The path to a local folder.
path: The path to a local folder.
"""
global _home_dir
_home_dir = path
98 changes: 39 additions & 59 deletions pyg_lib/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,12 @@ def grouped_matmul(
assert outs[1] == inputs[1] @ others[1]
Args:
inputs (List[torch.Tensor]): List of left operand 2D matrices of shapes
:obj:`[N_i, K_i]`.
others (List[torch.Tensor]): List of right operand 2D matrices of
shapes :obj:`[K_i, M_i]`.
biases (List[torch.Tensor], optional): Optional bias terms to apply for
each element. (default: :obj:`None`)
inputs: List of left operand 2D matrices of shapes :obj:`[N_i, K_i]`.
others: List of right operand 2D matrices of shapes :obj:`[K_i, M_i]`.
biases: Optional bias terms to apply for each element.
Returns:
List[torch.Tensor]: List of 2D output matrices of shapes
:obj:`[N_i, M_i]`.
List of 2D output matrices of shapes :obj:`[N_i, M_i]`.
"""
# Combine inputs into a single tuple for autograd:
outs = list(GroupedMatmul.apply(tuple(inputs + others)))
Expand Down Expand Up @@ -160,18 +156,14 @@ def segment_matmul(
assert out[5:8] == inputs[5:8] @ other[1]
Args:
inputs (torch.Tensor): The left operand 2D matrix of shape
:obj:`[N, K]`.
ptr (torch.Tensor): Compressed vector of shape :obj:`[B + 1]`, holding
the boundaries of segments. For best performance, given as a CPU
tensor.
other (torch.Tensor): The right operand 3D tensor of shape
:obj:`[B, K, M]`.
bias (torch.Tensor, optional): Optional bias term of shape
:obj:`[B, M]` (default: :obj:`None`)
inputs: The left operand 2D matrix of shape :obj:`[N, K]`.
ptr: Compressed vector of shape :obj:`[B + 1]`, holding the boundaries
of segments. For best performance, given as a CPU tensor.
other: The right operand 3D tensor of shape :obj:`[B, K, M]`.
bias: The bias term of shape :obj:`[B, M]`.
Returns:
torch.Tensor: The 2D output matrix of shape :obj:`[N, M]`.
The 2D output matrix of shape :obj:`[N, M]`.
"""
out = torch.ops.pyg.segment_matmul(inputs, ptr, other)
if bias is not None:
Expand All @@ -198,15 +190,13 @@ def sampled_add(
being more runtime and memory-efficient.
Args:
left (torch.Tensor): The left tensor.
right (torch.Tensor): The right tensor.
left_index (torch.LongTensor, optional): The values to sample from the
:obj:`left` tensor. (default: :obj:`None`)
right_index (torch.LongTensor, optional): The values to sample from the
:obj:`right` tensor. (default: :obj:`None`)
left: The left tensor.
right: The right tensor.
left_index: The values to sample from the :obj:`left` tensor.
right_index: The values to sample from the :obj:`right` tensor.
Returns:
torch.Tensor: The output tensor.
The output tensor.
"""
out = torch.ops.pyg.sampled_op(left, right, left_index, right_index, "add")
return out
Expand All @@ -230,15 +220,13 @@ def sampled_sub(
being more runtime and memory-efficient.
Args:
left (torch.Tensor): The left tensor.
right (torch.Tensor): The right tensor.
left_index (torch.LongTensor, optional): The values to sample from the
:obj:`left` tensor. (default: :obj:`None`)
right_index (torch.LongTensor, optional): The values to sample from the
:obj:`right` tensor. (default: :obj:`None`)
left: The left tensor.
right: The right tensor.
left_index: The values to sample from the :obj:`left` tensor.
right_index: The values to sample from the :obj:`right` tensor.
Returns:
torch.Tensor: The output tensor.
The output tensor.
"""
out = torch.ops.pyg.sampled_op(left, right, left_index, right_index, "sub")
return out
Expand All @@ -262,15 +250,13 @@ def sampled_mul(
thus being more runtime and memory-efficient.
Args:
left (torch.Tensor): The left tensor.
right (torch.Tensor): The right tensor.
left_index (torch.LongTensor, optional): The values to sample from the
:obj:`left` tensor. (default: :obj:`None`)
right_index (torch.LongTensor, optional): The values to sample from the
:obj:`right` tensor. (default: :obj:`None`)
left: The left tensor.
right: The right tensor.
left_index: The values to sample from the :obj:`left` tensor.
right_index: The values to sample from the :obj:`right` tensor.
Returns:
torch.Tensor: The output tensor.
The output tensor.
"""
out = torch.ops.pyg.sampled_op(left, right, left_index, right_index, "mul")
return out
Expand All @@ -294,15 +280,13 @@ def sampled_div(
being more runtime and memory-efficient.
Args:
left (torch.Tensor): The left tensor.
right (torch.Tensor): The right tensor.
left_index (torch.LongTensor, optional): The values to sample from the
:obj:`left` tensor. (default: :obj:`None`)
right_index (torch.LongTensor, optional): The values to sample from the
:obj:`right` tensor. (default: :obj:`None`)
left: The left tensor.
right: The right tensor.
left_index: The values to sample from the :obj:`left` tensor.
right_index: The values to sample from the :obj:`right` tensor.
Returns:
torch.Tensor: The output tensor.
The output tensor.
"""
out = torch.ops.pyg.sampled_op(left, right, left_index, right_index, "div")
return out
Expand All @@ -323,13 +307,12 @@ def index_sort(
device.
Args:
inputs (torch.Tensor): A vector with positive integer values.
max_value (int, optional): The maximum value stored inside
:obj:`inputs`. This value can be an estimation, but needs to be
greater than or equal to the real maximum. (default: :obj:`None`)
inputs: A vector with positive integer values.
max_value: The maximum value stored inside :obj:`inputs`. This value
can be an estimation, but needs to be greater than or equal to the
real maximum.
Returns:
Tuple[torch.LongTensor, torch.LongTensor]:
A tuple containing sorted values and indices of the elements in the
original :obj:`input` tensor.
"""
Expand All @@ -349,14 +332,6 @@ def softmax_csr(
:attr:`ptr`, and then proceeds to compute the softmax individually for
each group.
Args:
src (Tensor): The source tensor.
ptr (LongTensor): Groups defined by CSR representation.
dim (int, optional): The dimension in which to normalize.
(default: :obj:`0`)
:rtype: :class:`Tensor`
Examples:
>>> src = torch.randn(4, 4)
>>> ptr = torch.tensor([0, 4])
Expand All @@ -365,6 +340,11 @@ def softmax_csr(
[0.1453, 0.2591, 0.5907, 0.2410],
[0.0598, 0.2923, 0.1206, 0.0921],
[0.7792, 0.3502, 0.1638, 0.2145]])
Args:
src: The source tensor.
ptr: Groups defined by CSR representation.
dim: The dimension in which to normalize.
"""
dim = dim + src.dim() if dim < 0 else dim
return torch.ops.pyg.softmax_csr(src, ptr, dim)
Expand Down
19 changes: 8 additions & 11 deletions pyg_lib/partition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,16 @@ def metis(
<https://arxiv.org/abs/1905.07953>`_ paper.
Args:
rowptr (torch.Tensor): Compressed source node indices.
col (torch.Tensor): Target node indices.
num_partitions (int): The number of partitions.
node_weight (torch.Tensor, optional): Optional node weights.
(default: :obj:`None`)
edge_weight (torch.Tensor, optional): Optional edge weights.
(default: :obj:`None`)
recursive (bool, optional): If set to :obj:`True`, will use multilevel
recursive bisection instead of multilevel k-way partitioning.
(default: :obj:`False`)
rowptr: Compressed source node indices.
col: Target node indices.
num_partitions: The number of partitions.
node_weight: The node weights.
edge_weight: The edge weights.
recursive: If set to :obj:`True`, will use multilevel recursive
bisection instead of multilevel k-way partitioning.
Returns:
torch.Tensor: A vector that assings each node to a partition.
A vector that assings each node to a partition.
"""
return torch.ops.pyg.metis(rowptr, col, num_partitions, node_weight,
edge_weight, recursive)
Expand Down
Loading

0 comments on commit 95aeaaa

Please sign in to comment.