Skip to content

Commit f8a617d

Browse files
committed
overload 3-arg and 5-arg mul!()
1 parent 5cfdc03 commit f8a617d

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

src/dual.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -853,12 +853,15 @@ for MT in (StridedMatrix{<:LinearAlgebra.BlasFloat},
853853
@eval Base.:\(m::$MT, x::StridedMatrix{<:Dual}) =
854854
_map_dual_components!((y, x) -> ldiv!(y, m, x), (y, x, _) -> ldiv!(y, m, x), similar(x), x)
855855

856-
@eval function Base.:*(m::$MT, x::StridedVector{<:Dual})
857-
T = valtype(eltype(x))
858-
res = similar(x, (size(m, 1),))
859-
mul!(reinterpret(reshape, T, res), reinterpret(reshape, T, x), m')
860-
return res
861-
end
856+
@eval LinearAlgebra.mul!(C::StridedVector{T}, A::$MT, B::StridedVector{T}) where T <: Dual =
857+
mul!(reinterpret(reshape, valtype(T), C), reinterpret(reshape, valtype(T), B), A')
858+
859+
@eval LinearAlgebra.mul!(C::StridedVector{T}, A::$MT, B::StridedVector{T},
860+
α::Union{LinearAlgebra.BlasFloat, Integer},
861+
β::Union{LinearAlgebra.BlasFloat, Integer}) where T <: Dual =
862+
mul!(reinterpret(reshape, valtype(T), C), reinterpret(reshape, valtype(T), B), A', α, β)
863+
864+
@eval Base.:*(m::$MT, x::StridedVector{<:Dual}) = mul!(similar(x, (size(m, 1),)), m, x)
862865

863866
@eval Base.:*(m::$MT, x::StridedMatrix{<:Dual}) =
864867
_map_dual_components!((y, x) -> mul!(y, m, x), (y, x, _) -> mul!(y, m, x),

0 commit comments

Comments
 (0)