diff --git a/docs/source/conf.py b/docs/source/conf.py index 3c235ced6..3ddc258a5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -20,6 +20,7 @@ 'sphinx.ext.mathjax', 'sphinx.ext.napoleon', 'sphinx.ext.viewcode', + 'sphinx_copybutton', 'pyg', ] diff --git a/pyg_lib/ops/__init__.py b/pyg_lib/ops/__init__.py index bdf8efd53..40ad20460 100644 --- a/pyg_lib/ops/__init__.py +++ b/pyg_lib/ops/__init__.py @@ -104,7 +104,8 @@ def grouped_matmul( r"""Performs dense-dense matrix multiplication according to groups, utilizing dedicated kernels that effectively parallelize over groups. - Example: + .. code-block:: python + inputs = [torch.randn(5, 16), torch.randn(3, 32)] others = [torch.randn(16, 32), torch.randn(32, 64)] @@ -147,7 +148,8 @@ def segment_matmul( the first dimension of :obj:`inputs` as given by :obj:`ptr`, utilizing dedicated kernels that effectively parallelize over groups. - Example: + .. code-block:: python + inputs = torch.randn(8, 16) ptr = torch.tensor([0, 5, 8]) other = torch.randn(2, 16, 32)