Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
akihironitta committed Jul 22, 2024
1 parent 1042b32 commit 85afc77
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 37 deletions.
25 changes: 12 additions & 13 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,18 @@ repos:
name: Upgrade Python syntax
args: [--py38-plus]

# TODO
# - repo: https://github.com/PyCQA/autoflake
# rev: v2.3.1
# hooks:
# - id: autoflake
# name: Remove unused imports and variables
# args: [
# --remove-all-unused-imports,
# --remove-unused-variables,
# --remove-duplicate-keys,
# --ignore-init-module-imports,
# --in-place,
# ]
- repo: https://github.com/PyCQA/autoflake
rev: v2.3.1
hooks:
- id: autoflake
name: Remove unused imports and variables
args: [
--remove-all-unused-imports,
--remove-unused-variables,
--remove-duplicate-keys,
--ignore-init-module-imports,
--in-place,
]

- repo: https://github.com/google/yapf
rev: v0.40.2
Expand Down
22 changes: 21 additions & 1 deletion .ruff.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,27 @@
include = ["pyproject.toml", "pyg_lib/**/*.py"]
extend-exclude = [
"pyg_lib/testing.py",
"test",
"tools",
"setup.py",
"benchmark",
]
src = ["pyg_lib"]
line-length = 80
target-version = "py38"

[lint]
select = ["D"]
select = [
"D",
]
ignore = [
"D100", # TODO Don't ignore "Missing docstring in public module"
"D104", # TODO Don't ignore "Missing docstring in public package"
"D205", # Ignore "blank line required between summary line and description"
]

[lint.pydocstyle]
convention = "google"

[format]
quote-style = "single"
37 changes: 21 additions & 16 deletions pyg_lib/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch import Tensor


def pytreeify(cls):
def _pytreeify(cls):
r"""A pytree is Python nested data structure. It is a tree in the sense
that nodes are Python collections (e.g., list, tuple, dict) and the leaves
are Python values.
Expand Down Expand Up @@ -56,7 +56,7 @@ def new_backward(ctx, *flat_grad_outputs):
return cls


@pytreeify
@_pytreeify
class GroupedMatmul(torch.autograd.Function):
@staticmethod
def forward(ctx, args: Tuple[Tensor]) -> Tuple[Tensor]:
Expand Down Expand Up @@ -96,8 +96,11 @@ def backward(ctx, *outs_grad: Tuple[Tensor]) -> Tuple[Tensor]:
return tuple(inputs_grad + others_grad)


def grouped_matmul(inputs: List[Tensor], others: List[Tensor],
biases: Optional[List[Tensor]] = None) -> List[Tensor]:
def grouped_matmul(
inputs: List[Tensor],
others: List[Tensor],
biases: Optional[List[Tensor]] = None,
) -> List[Tensor]:
r"""Performs dense-dense matrix multiplication according to groups,
utilizing dedicated kernels that effectively parallelize over groups.
Expand Down Expand Up @@ -135,14 +138,17 @@ def grouped_matmul(inputs: List[Tensor], others: List[Tensor],
return outs


def segment_matmul(inputs: Tensor, ptr: Tensor, other: Tensor,
bias: Optional[Tensor] = None) -> Tensor:
def segment_matmul(
inputs: Tensor,
ptr: Tensor,
other: Tensor,
bias: Optional[Tensor] = None,
) -> Tensor:
r"""Performs dense-dense matrix multiplication according to segments along
the first dimension of :obj:`inputs` as given by :obj:`ptr`, utilizing
dedicated kernels that effectively parallelize over groups.
.. code-block:: python
Example:
inputs = torch.randn(8, 16)
ptr = torch.tensor([0, 5, 8])
other = torch.randn(2, 16, 32)
Expand All @@ -153,11 +159,11 @@ def segment_matmul(inputs: Tensor, ptr: Tensor, other: Tensor,
assert out[5:8] == inputs[5:8] @ other[1]
Args:
input (torch.Tensor): The left operand 2D matrix of shape
inputs (torch.Tensor): The left operand 2D matrix of shape
:obj:`[N, K]`.
ptr (torch.Tensor): Compressed vector of shape :obj:`[B + 1]`, holding
the boundaries of segments.
For best performance, given as a CPU tensor.
the boundaries of segments. For best performance, given as a CPU
tensor.
other (torch.Tensor): The right operand 3D tensor of shape
:obj:`[B, K, M]`.
bias (torch.Tensor, optional): Optional bias term of shape
Expand All @@ -181,7 +187,7 @@ def sampled_add(
) -> Tensor:
r"""Performs a sampled **addition** of :obj:`left` and :obj:`right`
according to the indices specified in :obj:`left_index` and
:obj:`right_index`:
:obj:`right_index`.
.. math::
\textrm{out} = \textrm{left}[\textrm{left_index}] +
Expand Down Expand Up @@ -213,7 +219,7 @@ def sampled_sub(
) -> Tensor:
r"""Performs a sampled **subtraction** of :obj:`left` by :obj:`right`
according to the indices specified in :obj:`left_index` and
:obj:`right_index`:
:obj:`right_index`.
.. math::
\textrm{out} = \textrm{left}[\textrm{left_index}] -
Expand Down Expand Up @@ -245,7 +251,7 @@ def sampled_mul(
) -> Tensor:
r"""Performs a sampled **multiplication** of :obj:`left` and :obj:`right`
according to the indices specified in :obj:`left_index` and
:obj:`right_index`:
:obj:`right_index`.
.. math::
\textrm{out} = \textrm{left}[\textrm{left_index}] *
Expand Down Expand Up @@ -277,7 +283,7 @@ def sampled_div(
) -> Tensor:
r"""Performs a sampled **division** of :obj:`left` by :obj:`right`
according to the indices specified in :obj:`left_index` and
:obj:`right_index`:
:obj:`right_index`.
.. math::
\textrm{out} = \textrm{left}[\textrm{left_index}] /
Expand Down Expand Up @@ -351,7 +357,6 @@ def softmax_csr(
:rtype: :class:`Tensor`
Examples:
>>> src = torch.randn(4, 4)
>>> ptr = torch.tensor([0, 4])
>>> softmax(src, ptr)
Expand Down
1 change: 1 addition & 0 deletions pyg_lib/ops/scatter_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def fused_scatter_reduce(
dim_size: int,
reduce_list: List[str],
) -> Tensor:
r"""Fuses multiple scatter operations into a single kernel."""
# TODO (matthias): Add support for `out`.
# TODO (matthias): Add backward functionality.
# TODO (matthias): Add support for inputs.dim() != 2.
Expand Down
18 changes: 11 additions & 7 deletions pyg_lib/sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def neighbor_sample(
:obj:`node_time` as default for seed nodes.
Needs to be specified in case edge-level sampling is used via
:obj:`edge_time`. (default: :obj:`None`)
edge-weight (torch.Tensor, optional): If given, will perform biased
edge_weight (torch.Tensor, optional): If given, will perform biased
sampling based on the weight of each edge. (default: :obj:`None`)
csc (bool, optional): If set to :obj:`True`, assumes that the graph is
given in CSC format :obj:`(colptr, row)`. (default: :obj:`False`)
Expand Down Expand Up @@ -117,10 +117,8 @@ def hetero_neighbor_sample(
.. note ::
Similar to :meth:`neighbor_sample`, but expects a dictionary of node
types (:obj:`str`) and edge types (:obj:`Tuple[str, str, str]`) for
each non-boolean argument.
Args:
kwargs: Arguments of :meth:`neighbor_sample`.
each non-boolean argument. See :meth:`neighbor_sample` for more
details.
"""
src_node_types = {k[0] for k in rowptr_dict.keys()}
dst_node_types = {k[-1] for k in rowptr_dict.keys()}
Expand Down Expand Up @@ -193,8 +191,14 @@ def subgraph(
return torch.ops.pyg.subgraph(rowptr, col, nodes, return_edge_id)


def random_walk(rowptr: Tensor, col: Tensor, seed: Tensor, walk_length: int,
p: float = 1.0, q: float = 1.0) -> Tensor:
def random_walk(
rowptr: Tensor,
col: Tensor,
seed: Tensor,
walk_length: int,
p: float = 1.0,
q: float = 1.0,
) -> Tensor:
r"""Samples random walks of length :obj:`walk_length` from all node
indices in :obj:`seed` in the graph given by :obj:`(rowptr, col)`, as
described in the `"node2vec: Scalable Feature Learning for Networks"
Expand Down

0 comments on commit 85afc77

Please sign in to comment.