Skip to content

Commit

Permalink
Merge pull request qutip#73 from albertomercurio/dev/patch-3
Browse files Browse the repository at this point in the history
Improve WignerClenshaw allocations
  • Loading branch information
albertomercurio authored Apr 14, 2024
2 parents e76cd54 + 71a58cd commit 459d47d
Showing 1 changed file with 25 additions and 21 deletions.
46 changes: 25 additions & 21 deletions src/wigner.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ function _wigner(ρ::AbstractArray, xvec::AbstractVector{T}, yvec::AbstractVecto
return _wigner_laguerre(ρ, A, W, g, solver)
end

function _wigner::AbstractArray, xvec::AbstractVector{T}, yvec::AbstractVector{T},
g::Real, solver::WignerClenshaw) where {T <: BlasFloat}
function _wigner::AbstractArray{T1}, xvec::AbstractVector{T}, yvec::AbstractVector{T},
g::Real, solver::WignerClenshaw) where {T1 <: BlasFloat, T <: BlasFloat}

g = convert(T, g)
M = size(ρ, 1)
Expand All @@ -62,9 +62,14 @@ function _wigner(ρ::AbstractArray, xvec::AbstractVector{T}, yvec::AbstractVecto
W .= 2 * ρ[1, end]
L = M - 1

y0 = similar(B, T1)
y1 = similar(B, T1)
y0_old = copy(y0)
res = similar(y0)

while L > 0
L -= 1
ρdiag = _wig_laguerre_clenshaw(L, B, (1 + Int(L!=0))*diag(ρ, L))
ρdiag = _wig_laguerre_clenshaw!(res, L, B, lmul!(1 + Int(L!=0), diag(ρ, L)), y0, y1, y0_old)
@. W = ρdiag + W * A / (L + 1)
end

Expand Down Expand Up @@ -160,23 +165,22 @@ function check_inf(x::T) where T
end


function _wig_laguerre_clenshaw(L::Int, x::AbstractArray{T1}, c::AbstractVector{T2}) where {T1<:Real, T2<:BlasFloat}
if length(c) == 1
y0 = c[1]
y1 = 0
elseif length(c) == 2
y0 = c[1]
y1 = c[2]
else
k = length(c)
y0 = similar(x, T2); y0 .= c[end-1]
y1 = similar(x, T2); y1 .= c[end]
for i in range(3, length(c), step=1)
k -= 1
y0_old = copy(y0)
@. y0 = c[end+1-i] - y1 * sqrt((k - 1) * (L + k - 1) / ((L + k) * k))
@. y1 = y0_old - y1 * ((L + 2 * k - 1) - x) / sqrt((L + k) * k)
end
function _wig_laguerre_clenshaw!(res, L::Int, x, c, y0, y1, y0_old)
length(c) == 1 && return c[1]
length(c) == 2 && return @. c[1] - c[2] * (L + 1 - x) / sqrt(L + 1)

y0 .= c[end-1]
y1 .= c[end]

k = length(c)
for i in range(3, length(c), step=1)
k -= 1
copyto!(y0_old, y0)
@. y0 = c[end+1-i] - y1 * sqrt((k - 1) * (L + k - 1) / ((L + k) * k))
@. y1 = y0_old - y1 * ((L + 2 * k - 1) - x) / sqrt((L + k) * k)
end
return @. y0 - y1 * (L + 1 - x) / sqrt(L + 1)

@. res = y0 - y1 * (L + 1 - x) / sqrt(L + 1)

return res
end

0 comments on commit 459d47d

Please sign in to comment.