57
57
58
58
# # Gradient
59
59
60
- struct EnzymeForwardGradientExtras{C ,O} <: GradientExtras
60
+ struct EnzymeForwardGradientExtras{B ,O} <: GradientExtras
61
61
shadow:: O
62
62
end
63
63
64
- function DI. prepare_gradient (f, :: AutoEnzyme{<:ForwardMode} , x)
65
- C = pick_chunksize ( length (x))
66
- shadow = chunkedonehot (x, Val (C ))
67
- return EnzymeForwardGradientExtras {C ,typeof(shadow)} (shadow)
64
+ function DI. prepare_gradient (f, backend :: AutoEnzyme{<:ForwardMode} , x)
65
+ B = pick_batchsize (backend, length (x))
66
+ shadow = chunkedonehot (x, Val (B ))
67
+ return EnzymeForwardGradientExtras {B ,typeof(shadow)} (shadow)
68
68
end
69
69
70
70
function DI. gradient (
71
- f, backend:: AutoEnzyme{<:ForwardMode} , x, extras:: EnzymeForwardGradientExtras{C }
72
- ) where {C }
73
- grad_tup = gradient (forward_mode (backend), f, x, Val {C} ( ); shadow= extras. shadow)
71
+ f, backend:: AutoEnzyme{<:ForwardMode} , x, extras:: EnzymeForwardGradientExtras{B }
72
+ ) where {B }
73
+ grad_tup = gradient (forward_mode (backend), f, x, Val (B ); shadow= extras. shadow)
74
74
return reshape (collect (grad_tup), size (x))
75
75
end
76
76
@@ -81,38 +81,38 @@ function DI.value_and_gradient(
81
81
end
82
82
83
83
function DI. gradient! (
84
- f, grad, backend:: AutoEnzyme{<:ForwardMode} , x, extras:: EnzymeForwardGradientExtras{C }
85
- ) where {C }
86
- grad_tup = gradient (forward_mode (backend), f, x, Val {C} ( ); shadow= extras. shadow)
84
+ f, grad, backend:: AutoEnzyme{<:ForwardMode} , x, extras:: EnzymeForwardGradientExtras{B }
85
+ ) where {B }
86
+ grad_tup = gradient (forward_mode (backend), f, x, Val (B ); shadow= extras. shadow)
87
87
return copyto! (grad, grad_tup)
88
88
end
89
89
90
90
function DI. value_and_gradient! (
91
- f, grad, backend:: AutoEnzyme{<:ForwardMode} , x, extras:: EnzymeForwardGradientExtras{C }
92
- ) where {C }
93
- grad_tup = gradient (forward_mode (backend), f, x, Val {C} ( ); shadow= extras. shadow)
91
+ f, grad, backend:: AutoEnzyme{<:ForwardMode} , x, extras:: EnzymeForwardGradientExtras{B }
92
+ ) where {B }
93
+ grad_tup = gradient (forward_mode (backend), f, x, Val (B ); shadow= extras. shadow)
94
94
return f (x), copyto! (grad, grad_tup)
95
95
end
96
96
97
97
# # Jacobian
98
98
99
- struct EnzymeForwardOneArgJacobianExtras{C ,O} <: JacobianExtras
99
+ struct EnzymeForwardOneArgJacobianExtras{B ,O} <: JacobianExtras
100
100
shadow:: O
101
101
end
102
102
103
- function DI. prepare_jacobian (f, :: AutoEnzyme{<:Union{ForwardMode,Nothing}} , x)
104
- C = pick_chunksize ( length (x))
105
- shadow = chunkedonehot (x, Val (C ))
106
- return EnzymeForwardOneArgJacobianExtras {C ,typeof(shadow)} (shadow)
103
+ function DI. prepare_jacobian (f, backend :: AutoEnzyme{<:Union{ForwardMode,Nothing}} , x)
104
+ B = pick_batchsize (backend, length (x))
105
+ shadow = chunkedonehot (x, Val (B ))
106
+ return EnzymeForwardOneArgJacobianExtras {B ,typeof(shadow)} (shadow)
107
107
end
108
108
109
109
function DI. jacobian (
110
110
f,
111
111
backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
112
112
x,
113
- extras:: EnzymeForwardOneArgJacobianExtras{C } ,
114
- ) where {C }
115
- jac_wrongshape = jacobian (forward_mode (backend), f, x, Val {C} ( ); shadow= extras. shadow)
113
+ extras:: EnzymeForwardOneArgJacobianExtras{B } ,
114
+ ) where {B }
115
+ jac_wrongshape = jacobian (forward_mode (backend), f, x, Val (B ); shadow= extras. shadow)
116
116
nx = length (x)
117
117
ny = length (jac_wrongshape) ÷ length (x)
118
118
return reshape (jac_wrongshape, ny, nx)
0 commit comments