
Commit 7118894

update README
1 parent 8e5cc94 commit 7118894

3 files changed: +54 −52 lines changed

README.md (+2 −2)

@@ -2,8 +2,8 @@

 This is a small wrapper around common array libraries that is compatible with
 the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
-NumPy, CuPy, PyTorch, Dask, JAX, ndonnx and `sparse` are supported. If you want
+NumPy, CuPy, PyTorch, Dask, JAX, ndonnx, `sparse` and Paddle are supported. If you want
 support for other array libraries, or if you encounter any issues, please [open
 an issue](https://github.com/data-apis/array-api-compat/issues).

-See the documentation for more details https://data-apis.org/array-api-compat/
+See the documentation for more details <https://data-apis.org/array-api-compat/>
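For context, a minimal usage sketch of the wrapper described above (illustrative, not part of this commit; assumes NumPy and array-api-compat are installed). `array_namespace` is the library's entry point: it returns a standard-compliant namespace for whatever array object it is given.

```python
# Minimal sketch: obtain an Array-API-standard namespace for an existing array
# and call standard functions through it. The input array is just an example.
import numpy as np
from array_api_compat import array_namespace

x = np.asarray([1.0, 2.0, 3.0])
xp = array_namespace(x)   # here: the compat namespace wrapping NumPy
total = xp.sum(x)         # standard spelling, executed by the underlying library
```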

array_api_compat/torch/fft.py (+13 −13)

@@ -2,14 +2,14 @@

 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
-    import paddle
-    array = paddle.Tensor
+    import torch
+    array = torch.Tensor
     from typing import Union, Sequence, Literal

-from paddle.fft import * # noqa: F403
-import paddle.fft
+from torch.fft import * # noqa: F403
+import torch.fft

-# Several paddle fft functions do not map axes to dim
+# Several torch fft functions do not map axes to dim

 def fftn(
     x: array,
@@ -20,7 +20,7 @@ def fftn(
     norm: Literal["backward", "ortho", "forward"] = "backward",
     **kwargs,
 ) -> array:
-    return paddle.fft.fftn(x, s=s, axes=axes, norm=norm, **kwargs)
+    return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs)

 def ifftn(
     x: array,
@@ -31,7 +31,7 @@ def ifftn(
     norm: Literal["backward", "ortho", "forward"] = "backward",
     **kwargs,
 ) -> array:
-    return paddle.fft.ifftn(x, s=s, axes=axes, norm=norm, **kwargs)
+    return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs)

 def rfftn(
     x: array,
@@ -42,7 +42,7 @@ def rfftn(
     norm: Literal["backward", "ortho", "forward"] = "backward",
     **kwargs,
 ) -> array:
-    return paddle.fft.rfftn(x, s=s, axes=axes, norm=norm, **kwargs)
+    return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs)

 def irfftn(
     x: array,
@@ -53,7 +53,7 @@ def irfftn(
     norm: Literal["backward", "ortho", "forward"] = "backward",
     **kwargs,
 ) -> array:
-    return paddle.fft.irfftn(x, s=s, axes=axes, norm=norm, **kwargs)
+    return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs)

 def fftshift(
     x: array,
@@ -62,7 +62,7 @@ def fftshift(
     axes: Union[int, Sequence[int]] = None,
     **kwargs,
 ) -> array:
-    return paddle.fft.fftshift(x, axes=axes, **kwargs)
+    return torch.fft.fftshift(x, dim=axes, **kwargs)

 def ifftshift(
     x: array,
@@ -71,10 +71,10 @@ def ifftshift(
     axes: Union[int, Sequence[int]] = None,
     **kwargs,
 ) -> array:
-    return paddle.fft.ifftshift(x, axes=axes, **kwargs)
+    return torch.fft.ifftshift(x, dim=axes, **kwargs)


-__all__ = paddle.fft.__all__ + [
+__all__ = torch.fft.__all__ + [
     "fftn",
     "ifftn",
     "rfftn",
@@ -83,4 +83,4 @@ def ifftshift(
     "ifftshift",
 ]

-_all_ignore = ['paddle']
+_all_ignore = ['torch']
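The wrappers in this file exist because `torch.fft` spells the standard's `axes` keyword as `dim`. A short sketch of that mapping (illustrative, not part of the commit; assumes PyTorch and this package are installed):

```python
# Sketch: the compat wrapper accepts the standard's `axes` keyword and forwards
# it to torch.fft's `dim` keyword, so both calls below compute the same result.
import torch
from array_api_compat.torch import fft as compat_fft

x = torch.randn(4, 8, 8)
a = compat_fft.fftn(x, axes=(-2, -1))  # standard keyword, via the wrapper
b = torch.fft.fftn(x, dim=(-2, -1))    # native torch keyword
assert torch.allclose(a, b)
```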

array_api_compat/torch/linalg.py (+39 −37)

@@ -2,84 +2,86 @@

 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
-    import paddle
-    array = paddle.Tensor
-    from paddle import dtype as Dtype
+    import torch
+    array = torch.Tensor
+    from torch import dtype as Dtype
     from typing import Optional, Union, Tuple, Literal
     inf = float('inf')

 from ._aliases import _fix_promotion, sum

-from paddle.linalg import * # noqa: F403
+from torch.linalg import * # noqa: F403

-# paddle.linalg doesn't define __all__
-# from paddle.linalg import __all__ as linalg_all
-from paddle import linalg as paddle_linalg
-linalg_all = [i for i in dir(paddle_linalg) if not i.startswith('_')]
+# torch.linalg doesn't define __all__
+# from torch.linalg import __all__ as linalg_all
+from torch import linalg as torch_linalg
+linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]

-# outer is implemented in paddle but aren't in the linalg namespace
-from paddle import outer
+# outer is implemented in torch but aren't in the linalg namespace
+from torch import outer
 # These functions are in both the main and linalg namespaces
 from ._aliases import matmul, matrix_transpose, tensordot

-# Note: paddle.linalg.cross does not default to axis=-1 (it defaults to the
+# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
+# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743

-# paddle.cross also does not support broadcasting when it would add new
+# torch.cross also does not support broadcasting when it would add new
+# dimensions https://github.com/pytorch/pytorch/issues/39656
 def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
     if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
         raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
     if not (x1.shape[axis] == x2.shape[axis] == 3):
         raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
-    x1, x2 = paddle.broadcast_tensors(x1, x2)
-    return paddle_linalg.cross(x1, x2, axis=axis)
+    x1, x2 = torch.broadcast_tensors(x1, x2)
+    return torch_linalg.cross(x1, x2, dim=axis)

 def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
     from ._aliases import isdtype

     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)

-    # paddle.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
+    # torch.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
     if x1.shape[axis] != x2.shape[axis]:
         raise ValueError("x1 and x2 must have the same size along the given axis")

-    # paddle.linalg.vecdot doesn't support integer dtypes
+    # torch.linalg.vecdot doesn't support integer dtypes
     if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
         if kwargs:
             raise RuntimeError("vecdot kwargs not supported for integral dtypes")

-        x1_ = paddle.moveaxis(x1, axis, -1)
-        x2_ = paddle.moveaxis(x2, axis, -1)
-        x1_, x2_ = paddle.broadcast_tensors(x1_, x2_)
+        x1_ = torch.moveaxis(x1, axis, -1)
+        x2_ = torch.moveaxis(x2, axis, -1)
+        x1_, x2_ = torch.broadcast_tensors(x1_, x2_)

         res = x1_[..., None, :] @ x2_[..., None]
         return res[..., 0, 0]
-    return paddle.linalg.vecdot(x1, x2, axis=axis, **kwargs)
+    return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)

 def solve(x1: array, x2: array, /, **kwargs) -> array:
     x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
-    # paddle tries to emulate NumPy 1 solve behavior by using batched 1-D solve
+    # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve
     # whenever
     # 1. x1.ndim - 1 == x2.ndim
     # 2. x1.shape[:-1] == x2.shape
     #
     # See linalg_solve_is_vector_rhs in
     # aten/src/ATen/native/LinearAlgebraUtils.h and
-    # paddle_META_FUNC(_linalg_solve_ex) in
-    # aten/src/ATen/native/BatchLinearAlgebra.cpp in the Pypaddle source code.
+    # TORCH_META_FUNC(_linalg_solve_ex) in
+    # aten/src/ATen/native/BatchLinearAlgebra.cpp in the PyTorch source code.
     #
     # The easiest way to work around this is to prepend a size 1 dimension to
     # x2, since x2 is already one dimension less than x1.
     #
-    # See https://github.com/pypaddle/pypaddle/issues/52915
+    # See https://github.com/pytorch/pytorch/issues/52915
     if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape:
         x2 = x2[None]
-    return paddle.linalg.solve(x1, x2, **kwargs)
+    return torch.linalg.solve(x1, x2, **kwargs)

-# paddle.trace doesn't support the offset argument and doesn't support stacking
+# torch.trace doesn't support the offset argument and doesn't support stacking
 def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
     # Use our wrapped sum to make sure it does upcasting correctly
-    return sum(paddle.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
+    return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)

 def vector_norm(
     x: array,
@@ -90,30 +92,30 @@ def vector_norm(
     ord: Union[int, float, Literal[inf, -inf]] = 2,
     **kwargs,
 ) -> array:
-    # paddle.vector_norm incorrectly treats axis=() the same as axis=None
+    # torch.vector_norm incorrectly treats axis=() the same as axis=None
     if axis == ():
         out = kwargs.get('out')
         if out is None:
             dtype = None
-            if x.dtype == paddle.complex64:
-                dtype = paddle.float32
-            elif x.dtype == paddle.complex128:
-                dtype = paddle.float64
+            if x.dtype == torch.complex64:
+                dtype = torch.float32
+            elif x.dtype == torch.complex128:
+                dtype = torch.float64

-            out = paddle.zeros_like(x, dtype=dtype)
+            out = torch.zeros_like(x, dtype=dtype)

         # The norm of a single scalar works out to abs(x) in every case except
-        # for p=0, which is x != 0.
+        # for ord=0, which is x != 0.
         if ord == 0:
             out[:] = (x != 0)
         else:
-            out[:] = paddle.abs(x)
+            out[:] = torch.abs(x)
         return out
-    return paddle.linalg.vector_norm(x, p=ord, axis=axis, keepdim=keepdims, **kwargs)
+    return torch.linalg.vector_norm(x, ord=ord, dim=axis, keepdim=keepdims, **kwargs)

 __all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
                         'cross', 'vecdot', 'solve', 'trace', 'vector_norm']

-_all_ignore = ['paddle_linalg', 'sum']
+_all_ignore = ['torch_linalg', 'sum']

 del linalg_all
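Likewise for `linalg`: the wrapped functions take the standard's `axis` keyword and enforce the semantics noted in the comments above (size-3 check for `cross`, no broadcasting along the contracted axis for `vecdot`) before deferring to `torch.linalg`. A short sketch, assuming a recent PyTorch (≥ 1.13, which provides `torch.linalg.vecdot`); the tensors and names are illustrative:

```python
# Sketch: cross/vecdot through the compat namespace, using the standard's
# `axis` keyword; internally this maps to torch.linalg's `dim`.
import torch
from array_api_compat.torch import linalg as compat_linalg

a = torch.randn(5, 3)
b = torch.randn(5, 3)
c = compat_linalg.cross(a, b, axis=-1)   # cross product over the last axis
v = compat_linalg.vecdot(a, b, axis=0)   # contract along axis 0 -> shape (3,)
assert c.shape == (5, 3) and v.shape == (3,)
```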
