From da4cebc691784d59734dd7a1e1d384ced6c53c97 Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Fri, 10 Nov 2023 01:18:39 -0500 Subject: [PATCH 1/2] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 009c0e9b..5f8c8690 100644 --- a/Project.toml +++ b/Project.toml @@ -40,7 +40,7 @@ oneAPIExt = ["oneAPI"] [compat] AMDGPU = "0.4.9" Adapt = "3" -CUDA = "4.1.0" +CUDA = "4.1.0, 5" ChainRulesCore = "1" DiffEqBase = "6.122" DocStringExtensions = "0.8, 0.9" From da3b7ab66edf4e75a526b03890bbe9971c7dcb46 Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Sat, 11 Nov 2023 11:54:03 -0500 Subject: [PATCH 2/2] format --- src/ensemblegpuarray/kernels.jl | 42 +-- src/ensemblegpuarray/lowerlevel_solve.jl | 6 +- src/ensemblegpuarray/problem_generation.jl | 8 +- src/ensemblegpukernel/callbacks.jl | 60 ++--- .../integrators/integrator_utils.jl | 246 +++++++++--------- .../integrators/nonstiff/interpolants.jl | 34 +-- .../integrators/nonstiff/types.jl | 60 ++--- .../integrators/stiff/interpolants.jl | 12 +- .../integrators/stiff/types.jl | 152 +++++------ src/ensemblegpukernel/kernels.jl | 10 +- src/ensemblegpukernel/linalg/linsolve.jl | 30 +-- src/ensemblegpukernel/linalg/lu.jl | 2 +- src/ensemblegpukernel/lowerlevel_solve.jl | 34 +-- src/ensemblegpukernel/nlsolve/type.jl | 14 +- .../perform_step/gpu_em_perform_step.jl | 2 +- .../perform_step/gpu_siea_perform_step.jl | 2 +- .../problems/ode_problems.jl | 22 +- src/solve.jl | 32 +-- 18 files changed, 384 insertions(+), 384 deletions(-) diff --git a/src/ensemblegpuarray/kernels.jl b/src/ensemblegpuarray/kernels.jl index a2b209dc..95dea7ea 100644 --- a/src/ensemblegpuarray/kernels.jl +++ b/src/ensemblegpuarray/kernels.jl @@ -13,8 +13,8 @@ function Adapt.adapt_structure(to, ps::ParamWrapper{P, T}) where {P, T} end @kernel function gpu_kernel(f, du, @Const(u), - @Const(params::AbstractArray{ParamWrapper{P, T}}), - @Const(t)) where {P, T} + @Const(params::AbstractArray{ParamWrapper{P, T}}), + @Const(t)) where {P, T} i = @index(Global, Linear) @inbounds p = params[i].params @inbounds tspan = params[i].data @@ -25,8 +25,8 @@ end end @kernel function gpu_kernel_oop(f, du, @Const(u), - @Const(params::AbstractArray{ParamWrapper{P, T}}), - @Const(t)) where {P, T} + @Const(params::AbstractArray{ParamWrapper{P, T}}), + @Const(t)) where {P, T} i = @index(Global, Linear) @inbounds p = params[i].params @inbounds tspan = params[i].data @@ -59,8 +59,8 @@ end end @kernel function jac_kernel(f, J, @Const(u), - @Const(params::AbstractArray{ParamWrapper{P, T}}), - @Const(t)) where {P, T} + @Const(params::AbstractArray{ParamWrapper{P, T}}), + @Const(t)) where {P, T} i = @index(Global, Linear) - 1 section = (1 + (i * size(u, 1))):((i + 1) * size(u, 1)) @inbounds p = params[i + 1].params @@ -73,8 +73,8 @@ end end @kernel function jac_kernel_oop(f, J, @Const(u), - @Const(params::AbstractArray{ParamWrapper{P, T}}), - @Const(t)) where {P, T} + @Const(params::AbstractArray{ParamWrapper{P, T}}), + @Const(t)) where {P, T} i = @index(Global, Linear) - 1 section = (1 + (i * size(u, 1))):((i + 1) * size(u, 1)) @@ -122,7 +122,7 @@ end end @kernel function continuous_condition_kernel(condition, out, @Const(u), @Const(t), - @Const(p)) + @Const(p)) i = @index(Global, Linear) @views @inbounds out[i] = condition(u[:, i], t, FakeIntegrator(u[:, i], t, p[:, i])) end @@ -141,8 +141,8 @@ function workgroupsize(backend, n) end @kernel function W_kernel(jac, W, @Const(u), - @Const(params::AbstractArray{ParamWrapper{P, T}}), @Const(gamma), - @Const(t)) where {P, T} + @Const(params::AbstractArray{ParamWrapper{P, T}}), @Const(gamma), + @Const(t)) where {P, T} i = @index(Global, Linear) len = size(u, 1) _W = @inbounds @view(W[:, :, i]) @@ -176,9 +176,9 @@ end end @kernel function W_kernel_oop(jac, W, @Const(u), - @Const(params::AbstractArray{ParamWrapper{P, T}}), - @Const(gamma), - @Const(t)) where {P, T} + @Const(params::AbstractArray{ParamWrapper{P, T}}), + @Const(gamma), + @Const(t)) where {P, T} i = @index(Global, Linear) len = size(u, 1) @@ -218,7 +218,7 @@ end end @kernel function Wt_kernel(f::AbstractArray{T}, W, @Const(u), @Const(p), @Const(gamma), - @Const(t)) where {T} + @Const(t)) where {T} i = @index(Global, Linear) len = size(u, 1) _W = @inbounds @view(W[:, :, i]) @@ -240,7 +240,7 @@ end end @kernel function Wt_kernel_oop(f::AbstractArray{T}, W, @Const(u), @Const(p), @Const(gamma), - @Const(t)) where {T} + @Const(t)) where {T} i = @index(Global, Linear) len = size(u, 1) _W = @inbounds @view(W[:, :, i]) @@ -268,7 +268,7 @@ end end @kernel function gpu_kernel_tgrad(f::AbstractArray{T}, du, @Const(u), @Const(p), - @Const(t)) where {T} + @Const(t)) where {T} i = @index(Global, Linear) @inbounds f = f[i].tgrad if eltype(p) <: Number @@ -279,7 +279,7 @@ end end @kernel function gpu_kernel_oop_tgrad(f::AbstractArray{T}, du, @Const(u), @Const(p), - @Const(t)) where {T} + @Const(t)) where {T} i = @index(Global, Linear) @inbounds f = f[i].tgrad if eltype(p) <: Number @@ -320,13 +320,13 @@ LinSolveGPUSplitFactorize() = LinSolveGPUSplitFactorize(0, 0) LinearSolve.needs_concrete_A(::LinSolveGPUSplitFactorize) = true function LinearSolve.init_cacheval(linsol::LinSolveGPUSplitFactorize, A, b, u, Pl, Pr, - maxiters::Int, abstol, reltol, verbose::Bool, - assumptions::LinearSolve.OperatorAssumptions) + maxiters::Int, abstol, reltol, verbose::Bool, + assumptions::LinearSolve.OperatorAssumptions) LinSolveGPUSplitFactorize(linsol.len, length(u) ÷ linsol.len) end function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::LinSolveGPUSplitFactorize, - args...; kwargs...) + args...; kwargs...) p = cache.cacheval A = cache.A b = cache.b diff --git a/src/ensemblegpuarray/lowerlevel_solve.jl b/src/ensemblegpuarray/lowerlevel_solve.jl index 5e6494ce..de98ec35 100644 --- a/src/ensemblegpuarray/lowerlevel_solve.jl +++ b/src/ensemblegpuarray/lowerlevel_solve.jl @@ -24,9 +24,9 @@ Only a subset of the common solver arguments are supported. function vectorized_map_solve end function vectorized_map_solve(probs, alg, - ensemblealg::Union{EnsembleArrayAlgorithm}, I, - adaptive; - kwargs...) + ensemblealg::Union{EnsembleArrayAlgorithm}, I, + adaptive; + kwargs...) # @assert all(Base.Fix2((prob1, prob2) -> isequal(prob1.tspan, prob2.tspan),probs[1]),probs) # u0 = reduce(hcat, Array(probs[i].u0) for i in 1:length(I)) diff --git a/src/ensemblegpuarray/problem_generation.jl b/src/ensemblegpuarray/problem_generation.jl index 9075cae2..4e6fe732 100644 --- a/src/ensemblegpuarray/problem_generation.jl +++ b/src/ensemblegpuarray/problem_generation.jl @@ -1,8 +1,8 @@ function generate_problem(prob::SciMLBase.AbstractODEProblem, - u0, - p, - jac_prototype, - colorvec) + u0, + p, + jac_prototype, + colorvec) _f = let f = prob.f.f, kernel = DiffEqBase.isinplace(prob) ? gpu_kernel : gpu_kernel_oop function (du, u, p, t) version = get_backend(u) diff --git a/src/ensemblegpukernel/callbacks.jl b/src/ensemblegpukernel/callbacks.jl index 1268bb9b..d95c251a 100644 --- a/src/ensemblegpukernel/callbacks.jl +++ b/src/ensemblegpukernel/callbacks.jl @@ -5,8 +5,8 @@ struct GPUDiscreteCallback{F1, F2, F3, F4, F5} <: SciMLBase.AbstractDiscreteCall finalize::F4 save_positions::F5 function GPUDiscreteCallback(condition::F1, affect!::F2, - initialize::F3, finalize::F4, - save_positions::F5) where {F1, F2, F3, F4, F5} + initialize::F3, finalize::F4, + save_positions::F5) where {F1, F2, F3, F4, F5} if save_positions != (false, false) error("Callback `save_positions` are incompatible with kernel-based GPU ODE solvers due requiring static sizing. Please ensure `save_positions = (false,false)` is set in all callback definitions used with such solvers.") end @@ -15,9 +15,9 @@ struct GPUDiscreteCallback{F1, F2, F3, F4, F5} <: SciMLBase.AbstractDiscreteCall end end function GPUDiscreteCallback(condition, affect!; - initialize = SciMLBase.INITIALIZE_DEFAULT, - finalize = SciMLBase.FINALIZE_DEFAULT, - save_positions = (false, false)) + initialize = SciMLBase.INITIALIZE_DEFAULT, + finalize = SciMLBase.FINALIZE_DEFAULT, + save_positions = (false, false)) GPUDiscreteCallback(condition, affect!, initialize, finalize, save_positions) end @@ -42,12 +42,12 @@ struct GPUContinuousCallback{F1, F2, F3, F4, F5, F6, T, T2, T3, I, R} <: reltol::T2 repeat_nudge::T3 function GPUContinuousCallback(condition::F1, affect!::F2, affect_neg!::F3, - initialize::F4, finalize::F5, idxs::I, rootfind, - interp_points, save_positions::F6, dtrelax::R, abstol::T, - reltol::T2, - repeat_nudge::T3) where {F1, F2, F3, F4, F5, F6, T, T2, - T3, I, R, - } + initialize::F4, finalize::F5, idxs::I, rootfind, + interp_points, save_positions::F6, dtrelax::R, abstol::T, + reltol::T2, + repeat_nudge::T3) where {F1, F2, F3, F4, F5, F6, T, T2, + T3, I, R, + } if save_positions != (false, false) error("Callback `save_positions` are incompatible with kernel-based GPU ODE solvers due requiring static sizing. Please ensure `save_positions = (false,false)` is set in all callback definitions used with such solvers.") end @@ -61,15 +61,15 @@ struct GPUContinuousCallback{F1, F2, F3, F4, F5, F6, T, T2, T3, I, R} <: end function GPUContinuousCallback(condition, affect!, affect_neg!; - initialize = SciMLBase.INITIALIZE_DEFAULT, - finalize = SciMLBase.FINALIZE_DEFAULT, - idxs = nothing, - rootfind = LeftRootFind, - save_positions = (false, false), - interp_points = 10, - dtrelax = 1, - abstol = 10eps(Float32), reltol = 0, - repeat_nudge = 1 // 100) + initialize = SciMLBase.INITIALIZE_DEFAULT, + finalize = SciMLBase.FINALIZE_DEFAULT, + idxs = nothing, + rootfind = LeftRootFind, + save_positions = (false, false), + interp_points = 10, + dtrelax = 1, + abstol = 10eps(Float32), reltol = 0, + repeat_nudge = 1 // 100) GPUContinuousCallback(condition, affect!, affect_neg!, initialize, finalize, idxs, rootfind, interp_points, @@ -78,15 +78,15 @@ function GPUContinuousCallback(condition, affect!, affect_neg!; end function GPUContinuousCallback(condition, affect!; - initialize = SciMLBase.INITIALIZE_DEFAULT, - finalize = SciMLBase.FINALIZE_DEFAULT, - idxs = nothing, - rootfind = LeftRootFind, - save_positions = (false, false), - affect_neg! = affect!, - interp_points = 10, - dtrelax = 1, - abstol = 10eps(Float32), reltol = 0, repeat_nudge = 1 // 100) + initialize = SciMLBase.INITIALIZE_DEFAULT, + finalize = SciMLBase.FINALIZE_DEFAULT, + idxs = nothing, + rootfind = LeftRootFind, + save_positions = (false, false), + affect_neg! = affect!, + interp_points = 10, + dtrelax = 1, + abstol = 10eps(Float32), reltol = 0, repeat_nudge = 1 // 100) GPUContinuousCallback(condition, affect!, affect_neg!, initialize, finalize, idxs, rootfind, interp_points, save_positions, @@ -101,7 +101,7 @@ function Base.convert(::Type{GPUContinuousCallback}, x::T) where {T <: Continuou end function generate_callback(callback::DiscreteCallback, I, - ensemblealg) + ensemblealg) if ensemblealg isa EnsembleGPUArray backend = ensemblealg.backend cur = adapt(backend, [false for i in 1:I]) diff --git a/src/ensemblegpukernel/integrators/integrator_utils.jl b/src/ensemblegpukernel/integrators/integrator_utils.jl index e883d6d0..f2d4f7da 100644 --- a/src/ensemblegpukernel/integrators/integrator_utils.jl +++ b/src/ensemblegpukernel/integrators/integrator_utils.jl @@ -11,13 +11,13 @@ function build_adaptive_controller_cache(alg::A, ::Type{T}) where {A, T} end @inline function savevalues!(integrator::DiffEqBase.AbstractODEIntegrator{ - AlgType, - IIP, - S, - T, - }, ts, - us, - force = false) where {AlgType <: GPUODEAlgorithm, IIP, S, T} + AlgType, + IIP, + S, + T, + }, ts, + us, + force = false) where {AlgType <: GPUODEAlgorithm, IIP, S, T} saved, savedexactly = false, false saveat = integrator.saveat @@ -46,28 +46,28 @@ end end @inline function DiffEqBase.terminate!(integrator::DiffEqBase.AbstractODEIntegrator{AlgType, - IIP, S, - T}, - retcode = ReturnCode.Terminated) where { - AlgType <: - GPUODEAlgorithm, - IIP, - S, - T, -} + IIP, S, + T}, + retcode = ReturnCode.Terminated) where { + AlgType <: + GPUODEAlgorithm, + IIP, + S, + T, + } integrator.retcode = retcode end @inline function apply_discrete_callback!(integrator::DiffEqBase.AbstractODEIntegrator{ - AlgType, - IIP, - S, T, - }, - ts, us, - callback::GPUDiscreteCallback) where { - AlgType <: - GPUODEAlgorithm, - IIP, S, T} + AlgType, + IIP, + S, T, + }, + ts, us, + callback::GPUDiscreteCallback) where { + AlgType <: + GPUODEAlgorithm, + IIP, S, T} saved_in_cb = false if callback.condition(integrator.u, integrator.t, integrator) # handle saveat @@ -80,29 +80,29 @@ end end @inline function apply_discrete_callback!(integrator::DiffEqBase.AbstractODEIntegrator{ - AlgType, - IIP, - S, T, - }, - ts, us, - callback::GPUDiscreteCallback, - args...) where {AlgType <: GPUODEAlgorithm, IIP, - S, T} + AlgType, + IIP, + S, T, + }, + ts, us, + callback::GPUDiscreteCallback, + args...) where {AlgType <: GPUODEAlgorithm, IIP, + S, T} apply_discrete_callback!(integrator, ts, us, apply_discrete_callback!(integrator, ts, us, callback)..., args...) end @inline function apply_discrete_callback!(integrator::DiffEqBase.AbstractODEIntegrator{ - AlgType, - IIP, - S, T, - }, - ts, us, - discrete_modified::Bool, - saved_in_cb::Bool, callback::GPUDiscreteCallback, - args...) where {AlgType <: GPUODEAlgorithm, IIP, - S, T} + AlgType, + IIP, + S, T, + }, + ts, us, + discrete_modified::Bool, + saved_in_cb::Bool, callback::GPUDiscreteCallback, + args...) where {AlgType <: GPUODEAlgorithm, IIP, + S, T} bool, saved_in_cb2 = apply_discrete_callback!(integrator, ts, us, apply_discrete_callback!(integrator, ts, us, callback)..., @@ -111,28 +111,28 @@ end end @inline function apply_discrete_callback!(integrator::DiffEqBase.AbstractODEIntegrator{ - AlgType, - IIP, - S, T, - }, - ts, us, - discrete_modified::Bool, - saved_in_cb::Bool, - callback::GPUDiscreteCallback) where { - AlgType <: - GPUODEAlgorithm, - IIP, S, T} + AlgType, + IIP, + S, T, + }, + ts, us, + discrete_modified::Bool, + saved_in_cb::Bool, + callback::GPUDiscreteCallback) where { + AlgType <: + GPUODEAlgorithm, + IIP, S, T} bool, saved_in_cb2 = apply_discrete_callback!(integrator, ts, us, callback) discrete_modified || bool, saved_in_cb || saved_in_cb2 end @inline function interpolate(integrator::DiffEqBase.AbstractODEIntegrator{ - AlgType, - IIP, - S, - T, - }, - t) where {AlgType <: GPUODEAlgorithm, IIP, S, T} + AlgType, + IIP, + S, + T, + }, + t) where {AlgType <: GPUODEAlgorithm, IIP, S, T} θ = (t - integrator.tprev) / integrator.dt b1θ, b2θ, b3θ, b4θ, b5θ, b6θ, b7θ = SimpleDiffEq.bθs(integrator.rs, θ) return integrator.uprev + @@ -143,20 +143,20 @@ end end @inline function _change_t_via_interpolation!(integrator::DiffEqBase.AbstractODEIntegrator{ - AlgType, + AlgType, + IIP, + S, + T, + }, + t, + modify_save_endpoint::Type{Val{T1}}) where { + AlgType <: + GPUODEAlgorithm, IIP, S, T, - }, - t, - modify_save_endpoint::Type{Val{T1}}) where { - AlgType <: - GPUODEAlgorithm, - IIP, - S, - T, - T1, -} + T1, + } # Can get rid of an allocation here with a function # get_tmp_arr(integrator.cache) which gives a pointer to some # cache array which can be modified. @@ -170,30 +170,30 @@ end end end @inline function DiffEqBase.change_t_via_interpolation!(integrator::DiffEqBase.AbstractODEIntegrator{ - AlgType, + AlgType, + IIP, + S, + T, + }, + t, + modify_save_endpoint::Type{Val{T1}} = Val{ + false, + }) where { + AlgType <: + GPUODEAlgorithm, IIP, S, T, - }, - t, - modify_save_endpoint::Type{Val{T1}} = Val{ - false, - }) where { - AlgType <: - GPUODEAlgorithm, - IIP, - S, - T, - T1, -} + T1, + } _change_t_via_interpolation!(integrator, t, modify_save_endpoint) end @inline function apply_callback!(integrator::DiffEqBase.AbstractODEIntegrator{AlgType, IIP, - S, T}, - callback::GPUContinuousCallback, - cb_time, prev_sign, event_idx, ts, - us) where {AlgType <: GPUODEAlgorithm, IIP, S, T} + S, T}, + callback::GPUContinuousCallback, + cb_time, prev_sign, event_idx, ts, + us) where {AlgType <: GPUODEAlgorithm, IIP, S, T} DiffEqBase.change_t_via_interpolation!(integrator, integrator.tprev + cb_time) # handle saveat @@ -220,8 +220,8 @@ end end @inline function handle_callbacks!(integrator::DiffEqBase.AbstractODEIntegrator{AlgType, - IIP, S, T}, - ts, us) where {AlgType <: GPUODEAlgorithm, IIP, S, T} + IIP, S, T}, + ts, us) where {AlgType <: GPUODEAlgorithm, IIP, S, T} discrete_callbacks = integrator.callback.discrete_callbacks continuous_callbacks = integrator.callback.continuous_callbacks atleast_one_callback = false @@ -257,14 +257,14 @@ end end @inline function DiffEqBase.find_callback_time(integrator::DiffEqBase.AbstractODEIntegrator{ - AlgType, - IIP, - S, - T, - }, - callback::DiffEqGPU.GPUContinuousCallback, - counter) where {AlgType <: GPUODEAlgorithm, - IIP, S, T} + AlgType, + IIP, + S, + T, + }, + callback::DiffEqGPU.GPUContinuousCallback, + counter) where {AlgType <: GPUODEAlgorithm, + IIP, S, T} event_occurred, interp_index, prev_sign, prev_sign_index, event_idx = DiffEqBase.determine_event_occurance(integrator, callback, counter) @@ -314,26 +314,26 @@ end end @inline function SciMLBase.get_tmp_cache(integrator::DiffEqBase.AbstractODEIntegrator{ - AlgType, - IIP, - S, T}) where { - AlgType <: - GPUODEAlgorithm, - IIP, - S, - T, -} + AlgType, + IIP, + S, T}) where { + AlgType <: + GPUODEAlgorithm, + IIP, + S, + T, + } return nothing end @inline function DiffEqBase.get_condition(integrator::DiffEqBase.AbstractODEIntegrator{ - AlgType, - IIP, - S, T, - }, - callback, - abst) where {AlgType <: GPUODEAlgorithm, IIP, S, T -} + AlgType, + IIP, + S, T, + }, + callback, + abst) where {AlgType <: GPUODEAlgorithm, IIP, S, T + } if abst == integrator.t tmp = integrator.u elseif abst == integrator.tprev @@ -346,16 +346,16 @@ end # interp_points = 0 or equivalently nothing @inline function DiffEqBase.determine_event_occurance(integrator::DiffEqBase.AbstractODEIntegrator{ - AlgType, - IIP, - S, - T, - }, - callback::DiffEqGPU.GPUContinuousCallback, - counter) where { - AlgType <: - GPUODEAlgorithm, IIP, - S, T} + AlgType, + IIP, + S, + T, + }, + callback::DiffEqGPU.GPUContinuousCallback, + counter) where { + AlgType <: + GPUODEAlgorithm, IIP, + S, T} event_occurred = false interp_index = 0 diff --git a/src/ensemblegpukernel/integrators/nonstiff/interpolants.jl b/src/ensemblegpukernel/integrators/nonstiff/interpolants.jl index 277d5acf..134ea1aa 100644 --- a/src/ensemblegpukernel/integrators/nonstiff/interpolants.jl +++ b/src/ensemblegpukernel/integrators/nonstiff/interpolants.jl @@ -1,14 +1,14 @@ # Default: Hermite Interpolation @inline @muladd function _ode_interpolant(Θ, dt, y₀, - integ::DiffEqBase.AbstractODEIntegrator{AlgType, - IIP, S, T, - }) where { - AlgType <: - GPUODEAlgorithm, - IIP, - S, - T, -} + integ::DiffEqBase.AbstractODEIntegrator{AlgType, + IIP, S, T, + }) where { + AlgType <: + GPUODEAlgorithm, + IIP, + S, + T, + } y₁ = integ.u k1 = integ.k1 k2 = integ.k2 @@ -43,8 +43,8 @@ end end @inline @muladd function _ode_interpolant(Θ, dt, y₀, - integ::T) where {T <: - Union{GPUV7I, GPUAV7I}} + integ::T) where {T <: + Union{GPUV7I, GPUAV7I}} b1Θ, b4Θ, b5Θ, b6Θ, b7Θ, b8Θ, b9Θ, b11Θ, b12Θ, b13Θ, b14Θ, b15Θ, b16Θ = bΘs(integ, Θ) @unpack c11, a1101, a1104, a1105, a1106, a1107, a1108, a1109, c12, a1201, a1204, @@ -127,8 +127,8 @@ end end @inline @muladd function _ode_interpolant(Θ, dt, y₀, - integ::T) where {T <: - Union{GPUV9I, GPUAV9I}} + integ::T) where {T <: + Union{GPUV9I, GPUAV9I}} b1Θ, b8Θ, b9Θ, b10Θ, b11Θ, b12Θ, b13Θ, b14Θ, b15Θ, b17Θ, b18Θ, b19Θ, b20Θ, b21Θ, b22Θ, b23Θ, b24Θ, b25Θ, b26Θ = bΘs(integ, Θ) @@ -206,8 +206,8 @@ end end @inline @muladd function _ode_interpolant(Θ, dt, y₀, - integ::T) where {T <: - Union{GPUT5I, GPUAT5I}} + integ::T) where {T <: + Union{GPUT5I, GPUAT5I}} b1θ, b2θ, b3θ, b4θ, b5θ, b6θ, b7θ = SimpleDiffEq.bθs(integ.rs, Θ) return y₀ + dt * @@ -217,8 +217,8 @@ end end @inline @muladd function _ode_interpolant(Θ, dt, y₀, - integ::T) where {T <: - Union{GPURB23I, GPUARB23I}} + integ::T) where {T <: + Union{GPURB23I, GPUARB23I}} c1 = Θ * (1 - Θ) / (1 - 2 * integ.d) c2 = Θ * (Θ - 2 * integ.d) / (1 - 2 * integ.d) return y₀ + dt * (c1 * integ.k1 + c2 * integ.k2) diff --git a/src/ensemblegpukernel/integrators/nonstiff/types.jl b/src/ensemblegpukernel/integrators/nonstiff/types.jl index 8c5a6796..7c9e3ba1 100644 --- a/src/ensemblegpukernel/integrators/nonstiff/types.jl +++ b/src/ensemblegpukernel/integrators/nonstiff/types.jl @@ -311,11 +311,11 @@ end # Initialization of Integrators ####################################################################################### @inline function init(alg::GPUTsit5, f::F, IIP::Bool, u0::S, t0::T, dt::T, - p::P, tstops::TS, - callback::CB, - save_everystep::Bool, - saveat::ST) where {F, P, T, S, - TS, CB, ST} + p::P, tstops::TS, + callback::CB, + save_everystep::Bool, + saveat::ST) where {F, P, T, S, + TS, CB, ST} cs, as, rs = SimpleDiffEq._build_tsit5_caches(T) !IIP && @assert S <: SArray @@ -340,11 +340,11 @@ end end @inline function init(alg::GPUTsit5, f::F, IIP::Bool, u0::S, t0::T, tf::T, dt::T, - p::P, - abstol::TOL, reltol::TOL, - internalnorm::N, tstops::TS, - callback::CB, - saveat::ST) where {F, P, S, T, N, TOL, TS, CB, ST} + p::P, + abstol::TOL, reltol::TOL, + internalnorm::N, tstops::TS, + callback::CB, + saveat::ST) where {F, P, S, T, N, TOL, TS, CB, ST} cs, as, btildes, rs = SimpleDiffEq._build_atsit5_caches(T) !IIP && @assert S <: SArray @@ -398,11 +398,11 @@ end end @inline function init(alg::GPUVern7, f::F, IIP::Bool, u0::S, t0::T, dt::T, - p::P, tstops::TS, - callback::CB, - save_everystep::Bool, - saveat::ST) where {F, P, T, S, - TS, CB, ST} + p::P, tstops::TS, + callback::CB, + save_everystep::Bool, + saveat::ST) where {F, P, T, S, + TS, CB, ST} tab = Vern7Tableau(T, T) !IIP && @assert S <: SArray @@ -437,11 +437,11 @@ end end @inline function init(alg::GPUVern7, f::F, IIP::Bool, u0::S, t0::T, tf::T, dt::T, - p::P, - abstol::TOL, reltol::TOL, - internalnorm::N, tstops::TS, - callback::CB, - saveat::ST) where {F, P, S, T, N, TOL, TS, CB, ST} + p::P, + abstol::TOL, reltol::TOL, + internalnorm::N, tstops::TS, + callback::CB, + saveat::ST) where {F, P, S, T, N, TOL, TS, CB, ST} !IIP && @assert S <: SArray tab = Vern7Tableau(T, T) @@ -495,11 +495,11 @@ end end @inline function init(alg::GPUVern9, f::F, IIP::Bool, u0::S, t0::T, dt::T, - p::P, tstops::TS, - callback::CB, - save_everystep::Bool, - saveat::ST) where {F, P, T, S, - TS, CB, ST} + p::P, tstops::TS, + callback::CB, + save_everystep::Bool, + saveat::ST) where {F, P, T, S, + TS, CB, ST} tab = Vern9Tableau(T, T) !IIP && @assert S <: SArray @@ -534,11 +534,11 @@ end end @inline function init(alg::GPUVern9, f::F, IIP::Bool, u0::S, t0::T, tf::T, dt::T, - p::P, - abstol::TOL, reltol::TOL, - internalnorm::N, tstops::TS, - callback::CB, - saveat::ST) where {F, P, S, T, N, TOL, TS, CB, ST} + p::P, + abstol::TOL, reltol::TOL, + internalnorm::N, tstops::TS, + callback::CB, + saveat::ST) where {F, P, S, T, N, TOL, TS, CB, ST} !IIP && @assert S <: SArray tab = Vern9Tableau(T, T) diff --git a/src/ensemblegpukernel/integrators/stiff/interpolants.jl b/src/ensemblegpukernel/integrators/stiff/interpolants.jl index 60f9fa3d..c33c56f9 100644 --- a/src/ensemblegpukernel/integrators/stiff/interpolants.jl +++ b/src/ensemblegpukernel/integrators/stiff/interpolants.jl @@ -1,16 +1,16 @@ @inline @muladd function _ode_interpolant(Θ, dt, y₀, - integ::T) where { - T <: - Union{GPURodas4I, GPUARodas4I}} + integ::T) where { + T <: + Union{GPURodas4I, GPUARodas4I}} Θ1 = 1 - Θ y₁ = integ.u return Θ1 * y₀ + Θ * (y₁ + Θ1 * (integ.k1 + Θ * integ.k2)) end @inline @muladd function _ode_interpolant(Θ, dt, y₀, - integ::T) where { - T <: - Union{GPURodas5PI, GPUARodas5PI}} + integ::T) where { + T <: + Union{GPURodas5PI, GPUARodas5PI}} Θ1 = 1 - Θ y₁ = integ.u return Θ1 * y₀ + Θ * (y₁ + Θ1 * (integ.k1 + Θ * (integ.k2 + Θ * integ.k3))) diff --git a/src/ensemblegpukernel/integrators/stiff/types.jl b/src/ensemblegpukernel/integrators/stiff/types.jl index 4f4ed556..f61e4a14 100644 --- a/src/ensemblegpukernel/integrators/stiff/types.jl +++ b/src/ensemblegpukernel/integrators/stiff/types.jl @@ -1,25 +1,25 @@ @inline function (integrator::DiffEqBase.AbstractODEIntegrator{ - AlgType, - IIP, - S, - T, -})(t) where { - AlgType <: - GPUODEAlgorithm, - IIP, - S, - T, -} + AlgType, + IIP, + S, + T, + })(t) where { + AlgType <: + GPUODEAlgorithm, + IIP, + S, + T, + } Θ = (t - integrator.tprev) / integrator.dt _ode_interpolant(Θ, integrator.dt, integrator.uprev, integrator) end @inline function DiffEqBase.u_modified!(integrator::DiffEqBase.AbstractODEIntegrator{ - AlgType, - IIP, S, - T}, - bool::Bool) where {AlgType <: GPUODEAlgorithm, IIP, - S, T} + AlgType, + IIP, S, + T}, + bool::Bool) where {AlgType <: GPUODEAlgorithm, IIP, + S, T} integrator.u_modified = bool end @@ -55,12 +55,12 @@ end const GPURB23I = GPURosenbrock23Integrator @inline function init(alg::GPURosenbrock23, f::F, IIP::Bool, u0::S, t0::T, dt::T, - p::P, tstops::TS, - callback::CB, - save_everystep::Bool, - saveat::ST) where {F, P, T, - S, - TS, CB, ST} + p::P, tstops::TS, + callback::CB, + save_everystep::Bool, + saveat::ST) where {F, P, T, + S, + TS, CB, ST} !IIP && @assert S <: SArray event_last_time = 1 vector_event_last_time = 0 @@ -135,12 +135,12 @@ end const GPUARB23I = GPUARosenbrock23Integrator @inline function init(alg::GPURosenbrock23, f::F, IIP::Bool, u0::S, t0::T, tf::T, - dt::T, p::P, - abstol::TOL, reltol::TOL, - internalnorm::N, tstops::TS, - callback::CB, - saveat::ST) where {F, P, S, T, N, TOL, TS, - CB, ST} + dt::T, p::P, + abstol::TOL, reltol::TOL, + internalnorm::N, tstops::TS, + callback::CB, + saveat::ST) where {F, P, S, T, N, TOL, TS, + CB, ST} !IIP && @assert S <: SArray qoldinit = T(1e-4) @@ -223,12 +223,12 @@ end const GPURodas4I = GPURodas4Integrator @inline function init(alg::GPURodas4, f::F, IIP::Bool, u0::S, t0::T, dt::T, - p::P, tstops::TS, - callback::CB, - save_everystep::Bool, - saveat::ST) where {F, P, T, - S, - TS, CB, ST} + p::P, tstops::TS, + callback::CB, + save_everystep::Bool, + saveat::ST) where {F, P, T, + S, + TS, CB, ST} !IIP && @assert S <: SArray event_last_time = 1 vector_event_last_time = 0 @@ -311,12 +311,12 @@ end const GPUARodas4I = GPUARodas4Integrator @inline function init(alg::GPURodas4, f::F, IIP::Bool, u0::S, t0::T, tf::T, - dt::T, p::P, - abstol::TOL, reltol::TOL, - internalnorm::N, tstops::TS, - callback::CB, - saveat::ST) where {F, P, S, T, N, TOL, TS, - CB, ST} + dt::T, p::P, + abstol::TOL, reltol::TOL, + internalnorm::N, tstops::TS, + callback::CB, + saveat::ST) where {F, P, S, T, N, TOL, TS, + CB, ST} !IIP && @assert S <: SArray qoldinit = T(1e-4) event_last_time = 1 @@ -397,12 +397,12 @@ end const GPURodas5PI = GPURodas5PIntegrator @inline function init(alg::GPURodas5P, f::F, IIP::Bool, u0::S, t0::T, dt::T, - p::P, tstops::TS, - callback::CB, - save_everystep::Bool, - saveat::ST) where {F, P, T, - S, - TS, CB, ST} + p::P, tstops::TS, + callback::CB, + save_everystep::Bool, + saveat::ST) where {F, P, T, + S, + TS, CB, ST} !IIP && @assert S <: SArray event_last_time = 1 vector_event_last_time = 0 @@ -476,12 +476,12 @@ end const GPUARodas5PI = GPUARodas5PIntegrator @inline function init(alg::GPURodas5P, f::F, IIP::Bool, u0::S, t0::T, tf::T, - dt::T, p::P, - abstol::TOL, reltol::TOL, - internalnorm::N, tstops::TS, - callback::CB, - saveat::ST) where {F, P, S, T, N, TOL, TS, - CB, ST} + dt::T, p::P, + abstol::TOL, reltol::TOL, + internalnorm::N, tstops::TS, + callback::CB, + saveat::ST) where {F, P, S, T, N, TOL, TS, + CB, ST} !IIP && @assert S <: SArray qoldinit = T(1e-4) event_last_time = 1 @@ -562,12 +562,12 @@ end const GPUKvaerno3I = GPUKvaerno3Integrator @inline function init(alg::GPUKvaerno3, f::F, IIP::Bool, u0::S, t0::T, dt::T, - p::P, tstops::TS, - callback::CB, - save_everystep::Bool, - saveat::ST) where {F, P, T, - S, - TS, CB, ST} + p::P, tstops::TS, + callback::CB, + save_everystep::Bool, + saveat::ST) where {F, P, T, + S, + TS, CB, ST} !IIP && @assert S <: SArray event_last_time = 1 vector_event_last_time = 0 @@ -641,12 +641,12 @@ end const GPUAKvaerno3I = GPUAKvaerno3Integrator @inline function init(alg::GPUKvaerno3, f::F, IIP::Bool, u0::S, t0::T, tf::T, - dt::T, p::P, - abstol::TOL, reltol::TOL, - internalnorm::N, tstops::TS, - callback::CB, - saveat::ST) where {F, P, S, T, N, TOL, TS, - CB, ST} + dt::T, p::P, + abstol::TOL, reltol::TOL, + internalnorm::N, tstops::TS, + callback::CB, + saveat::ST) where {F, P, S, T, N, TOL, TS, + CB, ST} !IIP && @assert S <: SArray qoldinit = T(1e-4) event_last_time = 1 @@ -727,12 +727,12 @@ end const GPUKvaerno5I = GPUKvaerno5Integrator @inline function init(alg::GPUKvaerno5, f::F, IIP::Bool, u0::S, t0::T, dt::T, - p::P, tstops::TS, - callback::CB, - save_everystep::Bool, - saveat::ST) where {F, P, T, - S, - TS, CB, ST} + p::P, tstops::TS, + callback::CB, + save_everystep::Bool, + saveat::ST) where {F, P, T, + S, + TS, CB, ST} !IIP && @assert S <: SArray event_last_time = 1 vector_event_last_time = 0 @@ -806,12 +806,12 @@ end const GPUAKvaerno5I = GPUAKvaerno5Integrator @inline function init(alg::GPUKvaerno5, f::F, IIP::Bool, u0::S, t0::T, tf::T, - dt::T, p::P, - abstol::TOL, reltol::TOL, - internalnorm::N, tstops::TS, - callback::CB, - saveat::ST) where {F, P, S, T, N, TOL, TS, - CB, ST} + dt::T, p::P, + abstol::TOL, reltol::TOL, + internalnorm::N, tstops::TS, + callback::CB, + saveat::ST) where {F, P, S, T, N, TOL, TS, + CB, ST} !IIP && @assert S <: SArray qoldinit = T(1e-4) event_last_time = 1 diff --git a/src/ensemblegpukernel/kernels.jl b/src/ensemblegpukernel/kernels.jl index b1b26a65..51d650ff 100644 --- a/src/ensemblegpukernel/kernels.jl +++ b/src/ensemblegpukernel/kernels.jl @@ -1,7 +1,7 @@ @kernel function ode_solve_kernel(@Const(probs), alg, _us, _ts, dt, callback, - tstops, nsteps, - saveat, ::Val{save_everystep}) where {save_everystep} + tstops, nsteps, + saveat, ::Val{save_everystep}) where {save_everystep} i = @index(Global, Linear) # get the actual problem for this thread @@ -52,9 +52,9 @@ end @kernel function ode_asolve_kernel(@Const(probs), alg, _us, _ts, dt, callback, tstops, - abstol, reltol, - saveat, - ::Val{save_everystep}) where {save_everystep} + abstol, reltol, + saveat, + ::Val{save_everystep}) where {save_everystep} i = @index(Global, Linear) # get the actual problem for this thread diff --git a/src/ensemblegpukernel/linalg/linsolve.jl b/src/ensemblegpukernel/linalg/linsolve.jl index 02d4f1a8..220c106b 100644 --- a/src/ensemblegpukernel/linalg/linsolve.jl +++ b/src/ensemblegpukernel/linalg/linsolve.jl @@ -6,16 +6,16 @@ end @inline function _linear_solve(::Size{(1, 1)}, - ::Size{(1,)}, - a::StaticMatrix{<:Any, <:Any, Ta}, - b::StaticVector{<:Any, Tb}) where {Ta, Tb} + ::Size{(1,)}, + a::StaticMatrix{<:Any, <:Any, Ta}, + b::StaticVector{<:Any, Tb}) where {Ta, Tb} @inbounds return similar_type(b, typeof(a[1] \ b[1]))(a[1] \ b[1]) end @inline function _linear_solve(::Size{(2, 2)}, - ::Size{(2,)}, - a::StaticMatrix{<:Any, <:Any, Ta}, - b::StaticVector{<:Any, Tb}) where {Ta, Tb} + ::Size{(2,)}, + a::StaticMatrix{<:Any, <:Any, Ta}, + b::StaticVector{<:Any, Tb}) where {Ta, Tb} d = det(a) T = typeof((one(Ta) * zero(Tb) + one(Ta) * zero(Tb)) / d) @inbounds return similar_type(b, T)((a[2, 2] * b[1] - a[1, 2] * b[2]) / d, @@ -23,9 +23,9 @@ end end @inline function _linear_solve(::Size{(3, 3)}, - ::Size{(3,)}, - a::StaticMatrix{<:Any, <:Any, Ta}, - b::StaticVector{<:Any, Tb}) where {Ta, Tb} + ::Size{(3,)}, + a::StaticMatrix{<:Any, <:Any, Ta}, + b::StaticVector{<:Any, Tb}) where {Ta, Tb} d = det(a) T = typeof((one(Ta) * zero(Tb) + one(Ta) * zero(Tb)) / d) @inbounds return similar_type(b, T)(((a[2, 2] * a[3, 3] - a[2, 3] * a[3, 2]) * b[1] + @@ -43,9 +43,9 @@ end for Sa in [(2, 2), (3, 3)] # not needed for Sa = (1, 1); @eval begin @inline function _linear_solve(::Size{$Sa}, - ::Size{Sb}, - a::StaticMatrix{<:Any, <:Any, Ta}, - b::StaticMatrix{<:Any, <:Any, Tb}) where {Sb, Ta, Tb} + ::Size{Sb}, + a::StaticMatrix{<:Any, <:Any, Ta}, + b::StaticMatrix{<:Any, <:Any, Tb}) where {Sb, Ta, Tb} d = det(a) T = typeof((one(Ta) * zero(Tb) + one(Ta) * zero(Tb)) / d) if isbitstype(T) @@ -70,9 +70,9 @@ end end @generated function _linear_solve_general(::Size{Sa}, - ::Size{Sb}, - a::StaticMatrix{<:Any, <:Any, Ta}, - b::StaticVecOrMat{Tb}) where {Sa, Sb, Ta, Tb} + ::Size{Sb}, + a::StaticMatrix{<:Any, <:Any, Ta}, + b::StaticVecOrMat{Tb}) where {Sa, Sb, Ta, Tb} if Sa[1] != Sb[1] return quote throw(DimensionMismatch("Left and right hand side first dimensions do not match in backdivide (got sizes $Sa and $Sb)")) diff --git a/src/ensemblegpukernel/linalg/lu.jl b/src/ensemblegpukernel/linalg/lu.jl index 14b5e2fb..9c1f0ba0 100644 --- a/src/ensemblegpukernel/linalg/lu.jl +++ b/src/ensemblegpukernel/linalg/lu.jl @@ -80,7 +80,7 @@ function __lu(A::StaticMatrix{1, 1, T}, ::Val{Pivot}) where {T, Pivot} end function __lu(A::LinearAlgebra.HermOrSym{T, <:StaticMatrix{1, 1, T}}, - ::Val{Pivot}) where {T, Pivot} + ::Val{Pivot}) where {T, Pivot} (SMatrix{1, 1}(one(T)), A.data, SVector(1)) end diff --git a/src/ensemblegpukernel/lowerlevel_solve.jl b/src/ensemblegpukernel/lowerlevel_solve.jl index 5c5485ad..b3b97973 100644 --- a/src/ensemblegpukernel/lowerlevel_solve.jl +++ b/src/ensemblegpukernel/lowerlevel_solve.jl @@ -23,10 +23,10 @@ Only a subset of the common solver arguments are supported. function vectorized_solve end function vectorized_solve(probs, prob::ODEProblem, alg; - dt, saveat = nothing, - save_everystep = true, - debug = false, callback = CallbackSet(nothing), tstops = nothing, - kwargs...) + dt, saveat = nothing, + save_everystep = true, + debug = false, callback = CallbackSet(nothing), tstops = nothing, + kwargs...) backend = get_backend(probs) backend = maybe_prefer_blocks(backend) # if saveat is specified, we'll use a vector of timestamps. @@ -101,10 +101,10 @@ end # SDEProblems over GPU cannot support u0 as a Number type, because GPU kernels compiled only through u0 being StaticArrays function vectorized_solve(probs, prob::SDEProblem, alg; - dt, saveat = nothing, - save_everystep = true, - debug = false, - kwargs...) + dt, saveat = nothing, + save_everystep = true, + debug = false, + kwargs...) backend = get_backend(probs) backend = maybe_prefer_blocks(backend) @@ -177,11 +177,11 @@ Only a subset of the common solver arguments are supported. function vectorized_asolve end function vectorized_asolve(probs, prob::ODEProblem, alg; - dt = 0.1f0, saveat = nothing, - save_everystep = false, - abstol = 1.0f-6, reltol = 1.0f-3, - debug = false, callback = CallbackSet(nothing), tstops = nothing, - kwargs...) + dt = 0.1f0, saveat = nothing, + save_everystep = false, + abstol = 1.0f-6, reltol = 1.0f-3, + debug = false, callback = CallbackSet(nothing), tstops = nothing, + kwargs...) backend = get_backend(probs) backend = maybe_prefer_blocks(backend) @@ -242,9 +242,9 @@ function vectorized_asolve(probs, prob::ODEProblem, alg; end function vectorized_asolve(probs, prob::SDEProblem, alg; - dt, saveat = nothing, - save_everystep = true, - debug = false, - kwargs...) + dt, saveat = nothing, + save_everystep = true, + debug = false, + kwargs...) error("Adaptive time-stepping is not supported yet with GPUEM.") end diff --git a/src/ensemblegpukernel/nlsolve/type.jl b/src/ensemblegpukernel/nlsolve/type.jl index 57bd4ffc..eb310ccb 100644 --- a/src/ensemblegpukernel/nlsolve/type.jl +++ b/src/ensemblegpukernel/nlsolve/type.jl @@ -20,7 +20,7 @@ struct NLSolver{uType, gamType, tmpType, tType, JType, WType, pType} <: Abstract end function NLSolver{tType}(z, tmp, ztmp, γ, c, α, κ, J, W, dt, t, p, - iter, maxiters, tmp2 = nothing) where {tType} + iter, maxiters, tmp2 = nothing) where {tType} NLSolver{typeof(z), typeof(γ), typeof(tmp2), tType, typeof(J), typeof(W), typeof(p)}(z, tmp, tmp2, @@ -72,16 +72,16 @@ end end @inline function build_nlsolver(alg, u, p, - t, dt, - f, - γ, c) + t, dt, + f, + γ, c) build_nlsolver(alg, u, p, t, dt, f, γ, c, 1) end @inline function build_nlsolver(alg, u, p, - t, dt, - f, - γ, c, α) + t, dt, + f, + γ, c, α) # define fields of non-linear solver z = u tmp = u diff --git a/src/ensemblegpukernel/perform_step/gpu_em_perform_step.jl b/src/ensemblegpukernel/perform_step/gpu_em_perform_step.jl index 6557ac2d..ca09c2b8 100644 --- a/src/ensemblegpukernel/perform_step/gpu_em_perform_step.jl +++ b/src/ensemblegpukernel/perform_step/gpu_em_perform_step.jl @@ -1,5 +1,5 @@ @kernel function em_kernel(@Const(probs), _us, _ts, dt, - saveat, ::Val{save_everystep}) where {save_everystep} + saveat, ::Val{save_everystep}) where {save_everystep} i = @index(Global, Linear) # get the actual problem for this thread diff --git a/src/ensemblegpukernel/perform_step/gpu_siea_perform_step.jl b/src/ensemblegpukernel/perform_step/gpu_siea_perform_step.jl index a7dfc3fe..201983e1 100644 --- a/src/ensemblegpukernel/perform_step/gpu_siea_perform_step.jl +++ b/src/ensemblegpukernel/perform_step/gpu_siea_perform_step.jl @@ -62,7 +62,7 @@ function SIEAConstantCache(::Type{T}, ::Type{T2}) where {T, T2} end @kernel function siea_kernel(@Const(probs), _us, _ts, dt, - saveat, ::Val{save_everystep}) where {save_everystep} + saveat, ::Val{save_everystep}) where {save_everystep} i = @index(Global, Linear) # get the actual problem for this thread diff --git a/src/ensemblegpukernel/problems/ode_problems.jl b/src/ensemblegpukernel/problems/ode_problems.jl index 33d715ff..d7f420d9 100644 --- a/src/ensemblegpukernel/problems/ode_problems.jl +++ b/src/ensemblegpukernel/problems/ode_problems.jl @@ -17,9 +17,9 @@ struct ImmutableODEProblem{uType, tType, isinplace, P, F, K, PT} <: """An internal argument for storing traits about the solving process.""" problem_type::PT @add_kwonly function ImmutableODEProblem{iip}(f::AbstractODEFunction{iip}, - u0, tspan, p = NullParameters(), - problem_type = StandardODEProblem(); - kwargs...) where {iip} + u0, tspan, p = NullParameters(), + problem_type = StandardODEProblem(); + kwargs...) where {iip} _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) warn_paramtype(p) @@ -42,10 +42,10 @@ struct ImmutableODEProblem{uType, tType, isinplace, P, F, K, PT} <: This is determined automatically, but not inferred. """ function ImmutableODEProblem{iip}(f, - u0, - tspan, - p = NullParameters(); - kwargs...) where {iip} + u0, + tspan, + p = NullParameters(); + kwargs...) where {iip} _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) _f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f) @@ -53,14 +53,14 @@ struct ImmutableODEProblem{uType, tType, isinplace, P, F, K, PT} <: end @add_kwonly function ImmutableODEProblem{iip, recompile}(f, u0, tspan, - p = NullParameters(); - kwargs...) where {iip, recompile} + p = NullParameters(); + kwargs...) where {iip, recompile} ImmutableODEProblem{iip}(ODEFunction{iip, recompile}(f), u0, tspan, p; kwargs...) end function ImmutableODEProblem{iip, FunctionWrapperSpecialize}(f, u0, tspan, - p = NullParameters(); - kwargs...) where {iip} + p = NullParameters(); + kwargs...) where {iip} _u0 = prepare_initial_state(u0) _tspan = promote_tspan(tspan) if !(f isa FunctionWrappersWrappers.FunctionWrappersWrapper) diff --git a/src/solve.jl b/src/solve.jl index 304da20b..ca41064c 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1,11 +1,11 @@ function SciMLBase.__solve(ensembleprob::SciMLBase.AbstractEnsembleProblem, - alg::Union{SciMLBase.DEAlgorithm, Nothing, - DiffEqGPU.GPUODEAlgorithm, DiffEqGPU.GPUSDEAlgorithm}, - ensemblealg::Union{EnsembleArrayAlgorithm, - EnsembleKernelAlgorithm}; - trajectories, batch_size = trajectories, - unstable_check = (dt, u, p, t) -> false, adaptive = true, - kwargs...) + alg::Union{SciMLBase.DEAlgorithm, Nothing, + DiffEqGPU.GPUODEAlgorithm, DiffEqGPU.GPUSDEAlgorithm}, + ensemblealg::Union{EnsembleArrayAlgorithm, + EnsembleKernelAlgorithm}; + trajectories, batch_size = trajectories, + unstable_check = (dt, u, p, t) -> false, adaptive = true, + kwargs...) if trajectories == 1 return SciMLBase.__solve(ensembleprob, alg, EnsembleSerial(); trajectories = 1, kwargs...) @@ -121,9 +121,9 @@ function SciMLBase.__solve(ensembleprob::SciMLBase.AbstractEnsembleProblem, end function batch_solve(ensembleprob, alg, - ensemblealg::Union{EnsembleArrayAlgorithm, EnsembleKernelAlgorithm}, I, - adaptive; - kwargs...) + ensemblealg::Union{EnsembleArrayAlgorithm, EnsembleKernelAlgorithm}, I, + adaptive; + kwargs...) @assert !isempty(I) #@assert all(p->p.f === probs[1].f,probs) @@ -251,7 +251,7 @@ function batch_solve(ensembleprob, alg, end function batch_solve_up_kernel(ensembleprob, probs, alg, ensemblealg, I, adaptive; - kwargs...) + kwargs...) _callback = CallbackSet(generate_callback(probs[1], length(I), ensemblealg; kwargs...)) _callback = CallbackSet(convert.(DiffEqGPU.GPUDiscreteCallback, @@ -321,9 +321,9 @@ function batch_solve_up(ensembleprob, probs, alg, ensemblealg, I, u0, p; kwargs. end function seed_duals(x::Matrix{V}, ::Type{T}, - ::ForwardDiff.Chunk{N} = ForwardDiff.Chunk(@view(x[:, 1]), - typemax(Int64))) where {V, T, - N} + ::ForwardDiff.Chunk{N} = ForwardDiff.Chunk(@view(x[:, 1]), + typemax(Int64))) where {V, T, + N} seeds = ForwardDiff.construct_seeds(ForwardDiff.Partials{N, V}) duals = [ForwardDiff.Dual{T}(x[i, j], seeds[i]) for i in 1:size(x, 1), j in 1:size(x, 2)] @@ -346,7 +346,7 @@ end struct DiffEqGPUAdjTag end function ChainRulesCore.rrule(::typeof(batch_solve_up), ensembleprob, probs, alg, - ensemblealg, I, u0, p; kwargs...) + ensemblealg, I, u0, p; kwargs...) pdual = seed_duals(p, DiffEqGPUAdjTag) u0 = convert.(eltype(pdual), u0) @@ -412,7 +412,7 @@ function ChainRulesCore.rrule(::typeof(batch_solve_up), ensembleprob, probs, alg end function solve_batch(prob, alg, ensemblealg::EnsembleThreads, II, pmap_batch_size; - kwargs...) + kwargs...) if length(II) == 1 || Threads.nthreads() == 1 return SciMLBase.solve_batch(prob, alg, EnsembleSerial(), II, pmap_batch_size; kwargs...)