
from typing import TYPE_CHECKING
if TYPE_CHECKING:
-    import paddle
-    array = paddle.Tensor
-    from paddle import dtype as Dtype
+    import torch
+    array = torch.Tensor
+    from torch import dtype as Dtype
    from typing import Optional, Union, Tuple, Literal
    inf = float('inf')

from ._aliases import _fix_promotion, sum

-from paddle.linalg import * # noqa: F403
+from torch.linalg import * # noqa: F403

-# paddle.linalg doesn't define __all__
-# from paddle.linalg import __all__ as linalg_all
-from paddle import linalg as paddle_linalg
-linalg_all = [i for i in dir(paddle_linalg) if not i.startswith('_')]
+# torch.linalg doesn't define __all__
+# from torch.linalg import __all__ as linalg_all
+from torch import linalg as torch_linalg
+linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]

-# outer is implemented in paddle but isn't in the linalg namespace
-from paddle import outer
+# outer is implemented in torch but isn't in the linalg namespace
+from torch import outer
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot

-# Note: paddle.linalg.cross does not default to axis=-1 (it defaults to the
+# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
+# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743

-# paddle.cross also does not support broadcasting when it would add new
+# torch.cross also does not support broadcasting when it would add new
+# dimensions https://github.com/pytorch/pytorch/issues/39656
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
    if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
        raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
    if not (x1.shape[axis] == x2.shape[axis] == 3):
        raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
-    x1, x2 = paddle.broadcast_tensors(x1, x2)
-    return paddle_linalg.cross(x1, x2, axis=axis)
+    x1, x2 = torch.broadcast_tensors(x1, x2)
+    return torch_linalg.cross(x1, x2, dim=axis)

def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
    from ._aliases import isdtype

    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)

-    # paddle.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
+    # torch.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
    if x1.shape[axis] != x2.shape[axis]:
        raise ValueError("x1 and x2 must have the same size along the given axis")

-    # paddle.linalg.vecdot doesn't support integer dtypes
+    # torch.linalg.vecdot doesn't support integer dtypes
    if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
        if kwargs:
            raise RuntimeError("vecdot kwargs not supported for integral dtypes")

-        x1_ = paddle.moveaxis(x1, axis, -1)
-        x2_ = paddle.moveaxis(x2, axis, -1)
-        x1_, x2_ = paddle.broadcast_tensors(x1_, x2_)
+        x1_ = torch.moveaxis(x1, axis, -1)
+        x2_ = torch.moveaxis(x2, axis, -1)
+        x1_, x2_ = torch.broadcast_tensors(x1_, x2_)

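+        # batched matmul computes the dot product: (..., 1, n) @ (..., n, 1)
+        # gives (..., 1, 1); the two trailing size-1 dims are stripped off below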
        res = x1_[..., None, :] @ x2_[..., None]
        return res[..., 0, 0]
-    return paddle.linalg.vecdot(x1, x2, axis=axis, **kwargs)
+    return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)

def solve(x1: array, x2: array, /, **kwargs) -> array:
    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
-    # paddle tries to emulate NumPy 1 solve behavior by using batched 1-D solve
+    # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve
    # whenever
    # 1. x1.ndim - 1 == x2.ndim
    # 2. x1.shape[:-1] == x2.shape
    #
    # See linalg_solve_is_vector_rhs in
    # aten/src/ATen/native/LinearAlgebraUtils.h and
-    # paddle_META_FUNC(_linalg_solve_ex) in
-    # aten/src/ATen/native/BatchLinearAlgebra.cpp in the Pypaddle source code.
+    # TORCH_META_FUNC(_linalg_solve_ex) in
+    # aten/src/ATen/native/BatchLinearAlgebra.cpp in the PyTorch source code.
    #
    # The easiest way to work around this is to prepend a size 1 dimension to
    # x2, since x2 is already one dimension less than x1.
    #
-    # See https://github.com/pypaddle/pypaddle/issues/52915
+    # See https://github.com/pytorch/pytorch/issues/52915
    if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape:
        x2 = x2[None]
-    return paddle.linalg.solve(x1, x2, **kwargs)
+    return torch.linalg.solve(x1, x2, **kwargs)

-# paddle.trace doesn't support the offset argument and doesn't support stacking
+# torch.trace doesn't support the offset argument and doesn't support stacking
def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
    # Use our wrapped sum to make sure it does upcasting correctly
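+    # torch.diagonal with dim1=-2, dim2=-1 pulls out the offset-th diagonal of
+    # each matrix in the stack, so summing over the last axis gives the trace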
-    return sum(paddle.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
+    return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)

def vector_norm(
    x: array,
@@ -90,30 +92,30 @@ def vector_norm(
    ord: Union[int, float, Literal[inf, -inf]] = 2,
    **kwargs,
) -> array:
-    # paddle.vector_norm incorrectly treats axis=() the same as axis=None
+    # torch.vector_norm incorrectly treats axis=() the same as axis=None
    if axis == ():
        out = kwargs.get('out')
        if out is None:
            dtype = None
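+            # the norm of a complex input is real-valued, so the zeros_like
+            # buffer below gets the corresponding real dtype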
-            if x.dtype == paddle.complex64:
-                dtype = paddle.float32
-            elif x.dtype == paddle.complex128:
-                dtype = paddle.float64
+            if x.dtype == torch.complex64:
+                dtype = torch.float32
+            elif x.dtype == torch.complex128:
+                dtype = torch.float64

-            out = paddle.zeros_like(x, dtype=dtype)
+            out = torch.zeros_like(x, dtype=dtype)

        # The norm of a single scalar works out to abs(x) in every case except
-        # for p=0, which is x != 0.
+        # for ord=0, which is x != 0.
        if ord == 0:
            out[:] = (x != 0)
        else:
-            out[:] = paddle.abs(x)
+            out[:] = torch.abs(x)
        return out
-    return paddle.linalg.vector_norm(x, p=ord, axis=axis, keepdim=keepdims, **kwargs)
+    return torch.linalg.vector_norm(x, ord=ord, dim=axis, keepdim=keepdims, **kwargs)

__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
                        'cross', 'vecdot', 'solve', 'trace', 'vector_norm']

-_all_ignore = ['paddle_linalg', 'sum']
+_all_ignore = ['torch_linalg', 'sum']

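+# linalg_all was only needed to build __all__ above; drop it so it is not
+# exported from this namespace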
del linalg_all