diff --git a/src/.State.jl.swp b/src/.State.jl.swp new file mode 100644 index 0000000..beae2cd Binary files /dev/null and b/src/.State.jl.swp differ diff --git a/src/Operators.jl b/src/Operators.jl index 3ae56ca..f1d3c6e 100644 --- a/src/Operators.jl +++ b/src/Operators.jl @@ -24,8 +24,8 @@ struct ER <: Operator end function operate(er::ER, state::State) - BcdiCore.loss(state.state, true, false, false) - state.realSpace .-= state.state.deriv ./ 2 + BcdiCore.loss(state.core, true, false, false) + state.realSpace .-= state.core.deriv ./ 2.0 enforceSupport(state) end @@ -39,10 +39,10 @@ struct HIO <: Operator end function operate(hio::HIO, state::State) - BcdiCore.loss(state.state, true, false, false) - state.realSpace[state.support] .-= view(state.state.deriv, state.support) ./ 2 - state.realSpace[.!state.support] .*= (1 .- hio.beta) - state.realSpace[.!state.support] .-= hio.beta .* view(state.state.deriv, .!state.support) ./ 2 + BcdiCore.loss(state.core, true, false, false) + state.realSpace .= + (state.realSpace .- state.core.deriv ./ 2.0) .* state.support .+ + (state.realSpace .* (1.0 .- hio.beta) .+ hio.beta .* state.core.deriv ./ 2.0) .* .!state.support end """ @@ -53,7 +53,6 @@ Create an object that applies shrinkwrap struct Shrink{T} <: Operator threshold::Float64 kernel::CuArray{ComplexF64, 3, CUDA.Mem.DeviceBuffer} - space::CuArray{ComplexF64, 3, CUDA.Mem.DeviceBuffer} plan::T function Shrink(threshold, sigma, state) @@ -70,23 +69,22 @@ struct Shrink{T} <: Operator end end - kernelG = CuArray{Float64, 3, CUDA.Mem.DeviceBuffer}(kernel) - kernelG = CUFFT.fft(kernelG) - - space = CUDA.zeros(Float64, s) - plan = CUFFT.plan_fft!(space) + plan = state.core.plan + kernelG = CuArray{ComplexF64, 3, CUDA.Mem.DeviceBuffer}(kernel) + plan * kernelG + kernelG .= plan.recipSpace - new{typeof(plan)}(threshold, kernelG, space, plan) + new{typeof(plan)}(threshold, kernelG, plan) end end function operate(shrink::Shrink, state::State) - shrink.space .= abs.(state.realSpace) - shrink.plan * shrink.space - shrink.space .*= shrink.kernel - shrink.plan \ shrink.space - threshVal = shrink.threshold * sqrt(maximum(abs2, shrink.space)) - state.support .= abs.(shrink.space) .> threshVal + shrink.plan.tempSpace .= abs.(state.realSpace) + shrink.plan * shrink.plan.tempSpace + shrink.plan.tempSpace .= shrink.plan.recipSpace .* shrink.kernel + shrink.plan \ shrink.plan.tempSpace + threshVal = shrink.threshold * sqrt(maximum(abs2, shrink.plan.realSpace)) + state.support .= abs.(shrink.plan.realSpace) .> threshVal end """ @@ -98,7 +96,7 @@ struct Center <: Operator xArr::CuArray{Int64, 3, CUDA.Mem.DeviceBuffer} yArr::CuArray{Int64, 3, CUDA.Mem.DeviceBuffer} zArr::CuArray{Int64, 3, CUDA.Mem.DeviceBuffer} - space::CuArray{Float64, 3, CUDA.Mem.DeviceBuffer} + space::CuArray{ComplexF64, 3, CUDA.Mem.DeviceBuffer} support::CuArray{Bool, 3, CUDA.Mem.DeviceBuffer} function Center(state) @@ -116,7 +114,7 @@ struct Center <: Operator end end - space = CUDA.zeros(Float64, s) + space = CUDA.zeros(ComplexF64, s) support = CUDA.zeros(Int64, s) new(xArr, yArr, zArr, space, support) @@ -128,9 +126,9 @@ function operate(center::Center, state::State) s = size(state.realSpace) n = reduce(+, center.support) - cenX = round(Int32, mapreduce((r,x)->r*x, +, state.notSupport, center.xArr)/n) - cenY = round(Int32, mapreduce((r,x)->r*x, +, state.notSupport, center.yArr)/n) - cenZ = round(Int32, mapreduce((r,x)->r*x, +, state.notSupport, center.zArr)/n) + cenX = round(Int32, mapreduce((r,x)->r*x, +, center.support, center.xArr)/n) + cenY = round(Int32, mapreduce((r,x)->r*x, +, center.support, center.yArr)/n) + cenZ = round(Int32, mapreduce((r,x)->r*x, +, center.support, center.zArr)/n) circshift!(center.space, state.realSpace, [s[1]//2+1-cenX, s[2]//2+1-cenY, s[3]//2+1-cenZ]) state.realSpace .= center.space circshift!(center.support, state.support, [s[1]//2+1-cenX, s[2]//2+1-cenY, s[3]//2+1-cenZ]) diff --git a/src/State.jl b/src/State.jl index 7788621..7db5281 100644 --- a/src/State.jl +++ b/src/State.jl @@ -8,7 +8,7 @@ struct State{T} realSpace::CuArray{ComplexF64, 3, CUDA.Mem.DeviceBuffer} shift::Vector{Int64} support::CuArray{Bool, 3, CUDA.Mem.DeviceBuffer} - state::T + core::T function State(intens, recSupport) invInt = CUFFT.ifft(CuArray{Float64, 3, CUDA.Mem.DeviceBuffer}(intens)) @@ -28,16 +28,16 @@ struct State{T} end end end - shift .= [1,1,1] .- round.(Int64, mapreduce(sqrt, +, intens)) - shift .*= -1 + shift .= [1,1,1] .- round.(Int64, shift ./ mapreduce(sqrt, +, intens)) intens = CuArray{Float64, 3, CUDA.Mem.DeviceBuffer}(circshift(intens,shift)) recSupport = CuArray{Float64, 3, CUDA.Mem.DeviceBuffer}(circshift(recSupport,shift)) + shift .*= -1 realSpace = CUDA.zeros(ComplexF64, s) - state = BcdiCore.TradState("L2", false, realSpace, intens, recSupport) - realSpace .= CUFFT.ifft(sqrt.(intens)) + core = BcdiCore.TradState("L2", false, realSpace, intens, recSupport) + realSpace .= CUFFT.ifft(sqrt.(intens) .* recSupport) - new{typeof(state)}(realSpace, shift, support, state) + new{typeof(core)}(realSpace, shift, support, core) end end