@@ -32,25 +32,25 @@ def tensor(data, requires_grad=False, dtype=float32, device="cpu"):
32
32
return Tensor (data , requires_grad = requires_grad , dtype = dtype , device = device )
33
33
34
34
35
- def ones (* shape , dtype = None , requires_grad = True , device = "cpu" ):
35
+ def ones (* shape , dtype = None , requires_grad = False , device = "cpu" ):
36
36
shape = tuple (* shape ) if all (isinstance (arg , (list , tuple )) for arg in shape ) else shape
37
37
38
38
return Tensor (np .ones (shape , dtype = dtype ), requires_grad = requires_grad , device = device )
39
39
40
40
41
- def zeros (* shape , dtype = None , requires_grad = True , device = "cpu" ):
41
+ def zeros (* shape , dtype = None , requires_grad = False , device = "cpu" ):
42
42
shape = tuple (* shape ) if all (isinstance (arg , (list , tuple )) for arg in shape ) else shape
43
43
44
44
return Tensor (np .zeros (shape , dtype = dtype ), requires_grad = requires_grad , device = device )
45
45
46
46
47
- def rand (* shape , dtype = None , requires_grad = True , device = "cpu" ):
47
+ def rand (* shape , dtype = None , requires_grad = False , device = "cpu" ):
48
48
shape = tuple (* shape ) if all (isinstance (arg , (list , tuple )) for arg in shape ) else shape
49
49
50
50
return Tensor (np .random .rand (* shape ).astype (dtype ), requires_grad = requires_grad , device = device )
51
51
52
52
53
- def randn (* shape , dtype = None , requires_grad = True , device = "cpu" ):
53
+ def randn (* shape , dtype = None , requires_grad = False , device = "cpu" ):
54
54
shape = tuple (* shape ) if all (isinstance (arg , (list , tuple )) for arg in shape ) else shape
55
55
56
56
return Tensor (
@@ -60,7 +60,7 @@ def randn(*shape, dtype=None, requires_grad=True, device="cpu"):
60
60
)
61
61
62
62
63
- def arange (start = 0 , end = None , step = 1 , dtype = None , requires_grad = True , device = "cpu" ):
63
+ def arange (start = 0 , end = None , step = 1 , dtype = None , requires_grad = False , device = "cpu" ):
64
64
if end is None :
65
65
start , end = 0 , start
66
66
return Tensor (
@@ -70,11 +70,11 @@ def arange(start=0, end=None, step=1, dtype=None, requires_grad=True, device="cp
70
70
)
71
71
72
72
73
- def ones_like (tensor , dtype = None , requires_grad = True , device = "cpu" ):
73
+ def ones_like (tensor , dtype = None , requires_grad = False , device = "cpu" ):
74
74
return Tensor (np .ones_like (tensor .data , dtype ), requires_grad = requires_grad , device = device )
75
75
76
76
77
- def zeros_like (tensor , dtype = None , requires_grad = True , device = "cpu" ):
77
+ def zeros_like (tensor , dtype = None , requires_grad = False , device = "cpu" ):
78
78
return Tensor (np .zeros_like (tensor .data , dtype ), requires_grad = requires_grad , device = device )
79
79
80
80
@@ -195,6 +195,7 @@ def flip(x, axis):
195
195
return x .flip (axis = axis )
196
196
197
197
def where (condition , x , y ):
198
+ x = tensor (x , device = condition .device ) if not isinstance (x , Tensor ) else x
198
199
return x .where (condition , y )
199
200
200
201
def equal (x , y ):
0 commit comments