Skip to content

Commit

Permalink
Fix HIO
Browse files Browse the repository at this point in the history
  • Loading branch information
jmeziere committed Jul 23, 2024
1 parent 867932a commit 6a0b4d3
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 30 deletions.
Binary file added src/.State.jl.swp
Binary file not shown.
46 changes: 22 additions & 24 deletions src/Operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ struct ER <: Operator
end

function operate(er::ER, state::State)
BcdiCore.loss(state.state, true, false, false)
state.realSpace .-= state.state.deriv ./ 2
BcdiCore.loss(state.core, true, false, false)
state.realSpace .-= state.core.deriv ./ 2.0
enforceSupport(state)
end

Expand All @@ -39,10 +39,10 @@ struct HIO <: Operator
end

function operate(hio::HIO, state::State)
BcdiCore.loss(state.state, true, false, false)
state.realSpace[state.support] .-= view(state.state.deriv, state.support) ./ 2
state.realSpace[.!state.support] .*= (1 .- hio.beta)
state.realSpace[.!state.support] .-= hio.beta .* view(state.state.deriv, .!state.support) ./ 2
BcdiCore.loss(state.core, true, false, false)
state.realSpace .=
(state.realSpace .- state.core.deriv ./ 2.0) .* state.support .+
(state.realSpace .* (1.0 .- hio.beta) .+ hio.beta .* state.core.deriv ./ 2.0) .* .!state.support
end

"""
Expand All @@ -53,7 +53,6 @@ Create an object that applies shrinkwrap
struct Shrink{T} <: Operator
threshold::Float64
kernel::CuArray{ComplexF64, 3, CUDA.Mem.DeviceBuffer}
space::CuArray{ComplexF64, 3, CUDA.Mem.DeviceBuffer}
plan::T

function Shrink(threshold, sigma, state)
Expand All @@ -70,23 +69,22 @@ struct Shrink{T} <: Operator
end
end

kernelG = CuArray{Float64, 3, CUDA.Mem.DeviceBuffer}(kernel)
kernelG = CUFFT.fft(kernelG)

space = CUDA.zeros(Float64, s)
plan = CUFFT.plan_fft!(space)
plan = state.core.plan
kernelG = CuArray{ComplexF64, 3, CUDA.Mem.DeviceBuffer}(kernel)
plan * kernelG
kernelG .= plan.recipSpace

new{typeof(plan)}(threshold, kernelG, space, plan)
new{typeof(plan)}(threshold, kernelG, plan)
end
end

function operate(shrink::Shrink, state::State)
shrink.space .= abs.(state.realSpace)
shrink.plan * shrink.space
shrink.space .*= shrink.kernel
shrink.plan \ shrink.space
threshVal = shrink.threshold * sqrt(maximum(abs2, shrink.space))
state.support .= abs.(shrink.space) .> threshVal
shrink.plan.tempSpace .= abs.(state.realSpace)
shrink.plan * shrink.plan.tempSpace
shrink.plan.tempSpace .= shrink.plan.recipSpace .* shrink.kernel
shrink.plan \ shrink.plan.tempSpace
threshVal = shrink.threshold * sqrt(maximum(abs2, shrink.plan.realSpace))
state.support .= abs.(shrink.plan.realSpace) .> threshVal
end

"""
Expand All @@ -98,7 +96,7 @@ struct Center <: Operator
xArr::CuArray{Int64, 3, CUDA.Mem.DeviceBuffer}
yArr::CuArray{Int64, 3, CUDA.Mem.DeviceBuffer}
zArr::CuArray{Int64, 3, CUDA.Mem.DeviceBuffer}
space::CuArray{Float64, 3, CUDA.Mem.DeviceBuffer}
space::CuArray{ComplexF64, 3, CUDA.Mem.DeviceBuffer}
support::CuArray{Bool, 3, CUDA.Mem.DeviceBuffer}

function Center(state)
Expand All @@ -116,7 +114,7 @@ struct Center <: Operator
end
end

space = CUDA.zeros(Float64, s)
space = CUDA.zeros(ComplexF64, s)
support = CUDA.zeros(Int64, s)

new(xArr, yArr, zArr, space, support)
Expand All @@ -128,9 +126,9 @@ function operate(center::Center, state::State)

s = size(state.realSpace)
n = reduce(+, center.support)
cenX = round(Int32, mapreduce((r,x)->r*x, +, state.notSupport, center.xArr)/n)
cenY = round(Int32, mapreduce((r,x)->r*x, +, state.notSupport, center.yArr)/n)
cenZ = round(Int32, mapreduce((r,x)->r*x, +, state.notSupport, center.zArr)/n)
cenX = round(Int32, mapreduce((r,x)->r*x, +, center.support, center.xArr)/n)
cenY = round(Int32, mapreduce((r,x)->r*x, +, center.support, center.yArr)/n)
cenZ = round(Int32, mapreduce((r,x)->r*x, +, center.support, center.zArr)/n)
circshift!(center.space, state.realSpace, [s[1]//2+1-cenX, s[2]//2+1-cenY, s[3]//2+1-cenZ])
state.realSpace .= center.space
circshift!(center.support, state.support, [s[1]//2+1-cenX, s[2]//2+1-cenY, s[3]//2+1-cenZ])
Expand Down
12 changes: 6 additions & 6 deletions src/State.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ struct State{T}
realSpace::CuArray{ComplexF64, 3, CUDA.Mem.DeviceBuffer}
shift::Vector{Int64}
support::CuArray{Bool, 3, CUDA.Mem.DeviceBuffer}
state::T
core::T

function State(intens, recSupport)
invInt = CUFFT.ifft(CuArray{Float64, 3, CUDA.Mem.DeviceBuffer}(intens))
Expand All @@ -28,16 +28,16 @@ struct State{T}
end
end
end
shift .= [1,1,1] .- round.(Int64, mapreduce(sqrt, +, intens))
shift .*= -1
shift .= [1,1,1] .- round.(Int64, shift ./ mapreduce(sqrt, +, intens))
intens = CuArray{Float64, 3, CUDA.Mem.DeviceBuffer}(circshift(intens,shift))
recSupport = CuArray{Float64, 3, CUDA.Mem.DeviceBuffer}(circshift(recSupport,shift))
shift .*= -1

realSpace = CUDA.zeros(ComplexF64, s)
state = BcdiCore.TradState("L2", false, realSpace, intens, recSupport)
realSpace .= CUFFT.ifft(sqrt.(intens))
core = BcdiCore.TradState("L2", false, realSpace, intens, recSupport)
realSpace .= CUFFT.ifft(sqrt.(intens) .* recSupport)

new{typeof(state)}(realSpace, shift, support, state)
new{typeof(core)}(realSpace, shift, support, core)
end
end

0 comments on commit 6a0b4d3

Please sign in to comment.