diff --git a/src/rrules/blas.jl b/src/rrules/blas.jl index c9b0ddc3..0fd47ac2 100644 --- a/src/rrules/blas.jl +++ b/src/rrules/blas.jl @@ -230,7 +230,7 @@ function rrule!!( y, dy = viewify(y_dy) # In this rule we optimise carefully for the special case a == 1 && b == 0, which - # corresponds to simply multiplying symm(A) and B together, and writing the result to C. + # corresponds to simply multiplying symm(A) and x together, and writing the result to y. # This is an extremely common edge case, so it's important to do well for it. y_copy = copy(y) tmp_ref = Ref{Vector{T}}() @@ -269,13 +269,13 @@ function rrule!!( dA[n] -= α * dA_tmp[n] end - # gradient w.r.t. B. + # gradient w.r.t. x. BLAS.symv!(ul, α, A, dy, one(T), dx) # gradient w.r.t. beta. dβ = dot(dy, y) - # gradient w.r.t. C. + # gradient w.r.t. y. BLAS.scal!(β, dy) return NoRData(), NoRData(), dα, NoRData(), NoRData(), dβ, NoRData()