diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index 9c793591..959ad88a 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -1,2 +1,3 @@ style = "sciml" -format_markdown = true \ No newline at end of file +format_markdown = true +format_docstrings = true \ No newline at end of file diff --git a/docs/pages.jl b/docs/pages.jl index 52975c00..4c1b1d3c 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -2,13 +2,15 @@ pages = ["index.md", "getting_started.md", - "Tutorials" => Any["GPU Ensembles" => Any["tutorials/gpu_ensemble_basic.md", + "Tutorials" => Any[ + "GPU Ensembles" => Any["tutorials/gpu_ensemble_basic.md", "tutorials/parallel_callbacks.md", "tutorials/multigpu.md", "tutorials/lower_level_api.md", "tutorials/weak_order_conv_sde.md"], "Within-Method GPU" => Any["tutorials/within_method_gpu.md"]], - "Examples" => Any["GPU Ensembles" => Any["examples/sde.md", + "Examples" => Any[ + "GPU Ensembles" => Any["examples/sde.md", "examples/ad.md", "examples/reductions.md"], "Within-Method GPU" => Any["examples/reaction_diffusion.md", @@ -17,5 +19,5 @@ pages = ["index.md", "manual/ensemblegpuarray.md", "manual/backends.md", "manual/optimal_trajectories.md", - "manual/choosing_ensembler.md"], + "manual/choosing_ensembler.md"] ] diff --git a/docs/src/examples/sde.md b/docs/src/examples/sde.md index 82987291..44bd24ef 100644 --- a/docs/src/examples/sde.md +++ b/docs/src/examples/sde.md @@ -27,6 +27,7 @@ prob = SDEProblem(lorenz, multiplicative_noise, u0, tspan, p) const pre_p = [rand(Float32, 3) for i in 1:10_000] prob_func = (prob, i, repeat) -> remake(prob, p = pre_p[i] .* p) monteprob = EnsembleProblem(prob, prob_func = prob_func) -sol = solve(monteprob, SOSRI(), EnsembleGPUArray(CUDA.CUDABackend()), trajectories = 10_000, +sol = solve( + monteprob, SOSRI(), EnsembleGPUArray(CUDA.CUDABackend()), trajectories = 10_000, saveat = 1.0f0) ``` diff --git a/docs/src/tutorials/gpu_ensemble_basic.md b/docs/src/tutorials/gpu_ensemble_basic.md index f3115107..42fbd3af 100644 --- a/docs/src/tutorials/gpu_ensemble_basic.md +++ b/docs/src/tutorials/gpu_ensemble_basic.md @@ -25,7 +25,8 @@ Changing this to being GPU-parallelized is as simple as changing the ensemble me `EnsembleGPUArray`: ```@example lorenz -sol = solve(monteprob, Tsit5(), EnsembleGPUArray(CUDA.CUDABackend()), trajectories = 10_000, +sol = solve( + monteprob, Tsit5(), EnsembleGPUArray(CUDA.CUDABackend()), trajectories = 10_000, saveat = 1.0f0); ``` diff --git a/docs/src/tutorials/lower_level_api.md b/docs/src/tutorials/lower_level_api.md index 740a4da1..3afdc887 100644 --- a/docs/src/tutorials/lower_level_api.md +++ b/docs/src/tutorials/lower_level_api.md @@ -89,12 +89,14 @@ end ## Finally use the lower API for faster solves! (Fixed time-stepping) -@time CUDA.@sync sol = DiffEqGPU.vectorized_map_solve(probs, Tsit5(), EnsembleGPUArray(0.0), +@time CUDA.@sync sol = DiffEqGPU.vectorized_map_solve( + probs, Tsit5(), EnsembleGPUArray(0.0), batch, false, dt = 0.001f0, save_everystep = false, dense = false) ## Adaptive time-stepping (Notice the boolean argument) -@time CUDA.@sync sol = DiffEqGPU.vectorized_map_solve(probs, Tsit5(), EnsembleGPUArray(0.0), +@time CUDA.@sync sol = DiffEqGPU.vectorized_map_solve( + probs, Tsit5(), EnsembleGPUArray(0.0), batch, true, dt = 0.001f0, save_everystep = false, dense = false) ``` diff --git a/src/algorithms.jl b/src/algorithms.jl index dd330734..a539ca39 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -16,7 +16,7 @@ struct EnsembleCPUArray <: EnsembleArrayAlgorithm end """ ```julia -EnsembleGPUArray(backend,cpu_offload = 0.2) +EnsembleGPUArray(backend, cpu_offload = 0.2) ``` An `EnsembleArrayAlgorithm` which utilizes the GPU kernels to parallelize each ODE solve @@ -73,13 +73,14 @@ function lorenz(du, u, p, t) du[3] = u[1] * u[2] - p[3] * u[3] end -u0 = Float32[1.0;0.0;0.0] -tspan = (0.0f0,100.0f0) -p = [10.0f0,28.0f0,8/3f0] -prob = ODEProblem(lorenz,u0,tspan,p) -prob_func = (prob,i,repeat) -> remake(prob,p=rand(Float32,3).*p) -monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy=false) -@time sol = solve(monteprob,Tsit5(),EnsembleGPUArray(CUDADevice()),trajectories=10_000,saveat=1.0f0) +u0 = Float32[1.0; 0.0; 0.0] +tspan = (0.0f0, 100.0f0) +p = [10.0f0, 28.0f0, 8 / 3.0f0] +prob = ODEProblem(lorenz, u0, tspan, p) +prob_func = (prob, i, repeat) -> remake(prob, p = rand(Float32, 3) .* p) +monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false) +@time sol = solve(monteprob, Tsit5(), EnsembleGPUArray(CUDADevice()), + trajectories = 10_000, saveat = 1.0f0) ``` """ struct EnsembleGPUArray{Backend} <: EnsembleArrayAlgorithm @@ -89,7 +90,7 @@ end """ ```julia -EnsembleGPUKernel(backend,cpu_offload = 0.2) +EnsembleGPUKernel(backend, cpu_offload = 0.2) ``` A massively-parallel ensemble algorithm which generates a unique GPU kernel for the entire @@ -146,7 +147,7 @@ prob_func = (prob, i, repeat) -> remake(prob, p = (@SVector rand(Float32, 3)) .* monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false) @time sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(), trajectories = 10_000, - adaptive = false, dt = 0.1f0) + adaptive = false, dt = 0.1f0) ``` """ struct EnsembleGPUKernel{Dev} <: EnsembleKernelAlgorithm diff --git a/src/ensemblegpuarray/lowerlevel_solve.jl b/src/ensemblegpuarray/lowerlevel_solve.jl index de98ec35..2c279914 100644 --- a/src/ensemblegpuarray/lowerlevel_solve.jl +++ b/src/ensemblegpuarray/lowerlevel_solve.jl @@ -3,23 +3,22 @@ Lower level API for `EnsembleArrayAlgorithm`. Avoids conversion of solution to C ```julia vectorized_map_solve(probs, alg, - ensemblealg::Union{EnsembleArrayAlgorithm}, I, - adaptive) + ensemblealg::Union{EnsembleArrayAlgorithm}, I, + adaptive) ``` ## Arguments -- `probs`: the GPU-setup problems generated by the ensemble. -- `alg`: the kernel-based differential equation solver. Most of the solvers from OrdinaryDiffEq.jl - are supported. -- `ensemblealg`: The `EnsembleGPUArray()` algorithm. -- `I`: The iterator argument. Can be set to for e.g. 1:10_000 to simulate 10,000 trajectories. -- `adaptive`: The Boolean argument for time-stepping. Use `true` to enable adaptive time-stepping. + - `probs`: the GPU-setup problems generated by the ensemble. + - `alg`: the kernel-based differential equation solver. Most of the solvers from OrdinaryDiffEq.jl + are supported. + - `ensemblealg`: The `EnsembleGPUArray()` algorithm. + - `I`: The iterator argument. Can be set to for e.g. 1:10_000 to simulate 10,000 trajectories. + - `adaptive`: The Boolean argument for time-stepping. Use `true` to enable adaptive time-stepping. ## Keyword Arguments Only a subset of the common solver arguments are supported. - """ function vectorized_map_solve end diff --git a/src/ensemblegpukernel/callbacks.jl b/src/ensemblegpukernel/callbacks.jl index aa537a9c..e382100a 100644 --- a/src/ensemblegpukernel/callbacks.jl +++ b/src/ensemblegpukernel/callbacks.jl @@ -46,7 +46,7 @@ struct GPUContinuousCallback{F1, F2, F3, F4, F5, F6, T, T2, T3, I, R} <: interp_points, save_positions::F6, dtrelax::R, abstol::T, reltol::T2, repeat_nudge::T3) where {F1, F2, F3, F4, F5, F6, T, T2, - T3, I, R, + T3, I, R } if save_positions != (false, false) error("Callback `save_positions` are incompatible with kernel-based GPU ODE solvers due requiring static sizing. Please ensure `save_positions = (false,false)` is set in all callback definitions used with such solvers.") diff --git a/src/ensemblegpukernel/integrators/integrator_utils.jl b/src/ensemblegpukernel/integrators/integrator_utils.jl index bad07883..fcfe1b33 100644 --- a/src/ensemblegpukernel/integrators/integrator_utils.jl +++ b/src/ensemblegpukernel/integrators/integrator_utils.jl @@ -10,11 +10,12 @@ function build_adaptive_controller_cache(alg::A, ::Type{T}) where {A, T} return beta1, beta2, qmax, qmin, gamma, qoldinit, qold end -@inline function savevalues!(integrator::DiffEqBase.AbstractODEIntegrator{ +@inline function savevalues!( + integrator::DiffEqBase.AbstractODEIntegrator{ AlgType, IIP, S, - T, + T }, ts, us, force = false) where {AlgType <: GPUODEAlgorithm, IIP, S, T} @@ -45,7 +46,8 @@ end saved, savedexactly end -@inline function DiffEqBase.terminate!(integrator::DiffEqBase.AbstractODEIntegrator{AlgType, +@inline function DiffEqBase.terminate!( + integrator::DiffEqBase.AbstractODEIntegrator{AlgType, IIP, S, T}, retcode = ReturnCode.Terminated) where { @@ -53,15 +55,16 @@ end GPUODEAlgorithm, IIP, S, - T, + T } integrator.retcode = retcode end -@inline function apply_discrete_callback!(integrator::DiffEqBase.AbstractODEIntegrator{ +@inline function apply_discrete_callback!( + integrator::DiffEqBase.AbstractODEIntegrator{ AlgType, IIP, - S, T, + S, T }, ts, us, callback::GPUDiscreteCallback) where { @@ -79,10 +82,11 @@ end integrator.u_modified, saved_in_cb end -@inline function apply_discrete_callback!(integrator::DiffEqBase.AbstractODEIntegrator{ +@inline function apply_discrete_callback!( + integrator::DiffEqBase.AbstractODEIntegrator{ AlgType, IIP, - S, T, + S, T }, ts, us, callback::GPUDiscreteCallback, @@ -93,10 +97,11 @@ end args...) end -@inline function apply_discrete_callback!(integrator::DiffEqBase.AbstractODEIntegrator{ +@inline function apply_discrete_callback!( + integrator::DiffEqBase.AbstractODEIntegrator{ AlgType, IIP, - S, T, + S, T }, ts, us, discrete_modified::Bool, @@ -110,10 +115,11 @@ end discrete_modified || bool, saved_in_cb || saved_in_cb2 end -@inline function apply_discrete_callback!(integrator::DiffEqBase.AbstractODEIntegrator{ +@inline function apply_discrete_callback!( + integrator::DiffEqBase.AbstractODEIntegrator{ AlgType, IIP, - S, T, + S, T }, ts, us, discrete_modified::Bool, @@ -126,11 +132,12 @@ end discrete_modified || bool, saved_in_cb || saved_in_cb2 end -@inline function interpolate(integrator::DiffEqBase.AbstractODEIntegrator{ +@inline function interpolate( + integrator::DiffEqBase.AbstractODEIntegrator{ AlgType, IIP, S, - T, + T }, t) where {AlgType <: GPUODEAlgorithm, IIP, S, T} θ = (t - integrator.tprev) / integrator.dt @@ -142,11 +149,12 @@ end b7θ * integrator.k7) end -@inline function _change_t_via_interpolation!(integrator::DiffEqBase.AbstractODEIntegrator{ +@inline function _change_t_via_interpolation!( + integrator::DiffEqBase.AbstractODEIntegrator{ AlgType, IIP, S, - T, + T }, t, modify_save_endpoint::Type{Val{T1}}) where { @@ -155,7 +163,7 @@ end IIP, S, T, - T1, + T1 } # Can get rid of an allocation here with a function # get_tmp_arr(integrator.cache) which gives a pointer to some @@ -169,11 +177,12 @@ end #integrator.dt = integrator.t - integrator.tprev end end -@inline function DiffEqBase.change_t_via_interpolation!(integrator::DiffEqBase.AbstractODEIntegrator{ +@inline function DiffEqBase.change_t_via_interpolation!( + integrator::DiffEqBase.AbstractODEIntegrator{ AlgType, IIP, S, - T, + T }, t, modify_save_endpoint::Type{Val{T1}} = Val{ @@ -184,12 +193,13 @@ end IIP, S, T, - T1, + T1 } _change_t_via_interpolation!(integrator, t, modify_save_endpoint) end -@inline function apply_callback!(integrator::DiffEqBase.AbstractODEIntegrator{AlgType, IIP, +@inline function apply_callback!( + integrator::DiffEqBase.AbstractODEIntegrator{AlgType, IIP, S, T}, callback::GPUContinuousCallback, cb_time, prev_sign, event_idx, ts, @@ -219,7 +229,8 @@ end true, saved_in_cb end -@inline function handle_callbacks!(integrator::DiffEqBase.AbstractODEIntegrator{AlgType, +@inline function handle_callbacks!( + integrator::DiffEqBase.AbstractODEIntegrator{AlgType, IIP, S, T}, ts, us) where {AlgType <: GPUODEAlgorithm, IIP, S, T} discrete_callbacks = integrator.callback.discrete_callbacks @@ -232,7 +243,8 @@ end if !(continuous_callbacks isa Tuple{}) event_occurred = false - time, upcrossing, event_occurred, event_idx, idx, counter = DiffEqBase.find_first_continuous_callback(integrator, + time, upcrossing, event_occurred, event_idx, idx, counter = DiffEqBase.find_first_continuous_callback( + integrator, continuous_callbacks...) if event_occurred @@ -256,16 +268,18 @@ end return false, saved_in_cb end -@inline function DiffEqBase.find_callback_time(integrator::DiffEqBase.AbstractODEIntegrator{ +@inline function DiffEqBase.find_callback_time( + integrator::DiffEqBase.AbstractODEIntegrator{ AlgType, IIP, S, - T, + T }, callback::DiffEqGPU.GPUContinuousCallback, counter) where {AlgType <: GPUODEAlgorithm, IIP, S, T} - event_occurred, interp_index, prev_sign, prev_sign_index, event_idx = DiffEqBase.determine_event_occurance(integrator, + event_occurred, interp_index, prev_sign, prev_sign_index, event_idx = DiffEqBase.determine_event_occurance( + integrator, callback, counter) @@ -321,15 +335,16 @@ end GPUODEAlgorithm, IIP, S, - T, + T } return nothing end -@inline function DiffEqBase.get_condition(integrator::DiffEqBase.AbstractODEIntegrator{ +@inline function DiffEqBase.get_condition( + integrator::DiffEqBase.AbstractODEIntegrator{ AlgType, IIP, - S, T, + S, T }, callback, abst) where {AlgType <: GPUODEAlgorithm, IIP, S, T @@ -345,11 +360,12 @@ end end # interp_points = 0 or equivalently nothing -@inline function DiffEqBase.determine_event_occurance(integrator::DiffEqBase.AbstractODEIntegrator{ +@inline function DiffEqBase.determine_event_occurance( + integrator::DiffEqBase.AbstractODEIntegrator{ AlgType, IIP, S, - T, + T }, callback::DiffEqGPU.GPUContinuousCallback, counter) where { diff --git a/src/ensemblegpukernel/integrators/nonstiff/interpolants.jl b/src/ensemblegpukernel/integrators/nonstiff/interpolants.jl index 7163ec4b..fc4b1cd4 100644 --- a/src/ensemblegpukernel/integrators/nonstiff/interpolants.jl +++ b/src/ensemblegpukernel/integrators/nonstiff/interpolants.jl @@ -1,13 +1,13 @@ # Default: Hermite Interpolation @inline @muladd function _ode_interpolant(Θ, dt, y₀, integ::DiffEqBase.AbstractODEIntegrator{AlgType, - IIP, S, T, + IIP, S, T }) where { AlgType <: GPUODEAlgorithm, IIP, S, - T, + T } y₁ = integ.u k1 = integ.k1 @@ -44,7 +44,7 @@ end @inline @muladd function _ode_interpolant(Θ, dt, y₀, integ::T) where {T <: - Union{GPUV7I, GPUAV7I}} + Union{GPUV7I, GPUAV7I}} b1Θ, b4Θ, b5Θ, b6Θ, b7Θ, b8Θ, b9Θ, b11Θ, b12Θ, b13Θ, b14Θ, b15Θ, b16Θ = bΘs(integ, Θ) @unpack c11, a1101, a1104, a1105, a1106, a1107, a1108, a1109, c12, a1201, a1204, @@ -56,29 +56,46 @@ end @unpack k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, uprev, f, t, p = integ - k11 = f(uprev + - dt * (a1101 * k1 + a1104 * k4 + a1105 * k5 + a1106 * k6 + - a1107 * k7 + a1108 * k8 + a1109 * k9), p, t + c11 * dt) - k12 = f(uprev + - dt * (a1201 * k1 + a1204 * k4 + a1205 * k5 + a1206 * k6 + - a1207 * k7 + a1208 * k8 + a1209 * k9 + a1211 * k11), p, + k11 = f( + uprev + + dt * (a1101 * k1 + a1104 * k4 + a1105 * k5 + a1106 * k6 + + a1107 * k7 + a1108 * k8 + a1109 * k9), + p, + t + c11 * dt) + k12 = f( + uprev + + dt * (a1201 * k1 + a1204 * k4 + a1205 * k5 + a1206 * k6 + + a1207 * k7 + a1208 * k8 + a1209 * k9 + a1211 * k11), + p, t + c12 * dt) - k13 = f(uprev + - dt * (a1301 * k1 + a1304 * k4 + a1305 * k5 + a1306 * k6 + - a1307 * k7 + a1308 * k8 + a1309 * k9 + a1311 * k11 + - a1312 * k12), p, t + c13 * dt) - k14 = f(uprev + - dt * (a1401 * k1 + a1404 * k4 + a1405 * k5 + a1406 * k6 + - a1407 * k7 + a1408 * k8 + a1409 * k9 + a1411 * k11 + - a1412 * k12 + a1413 * k13), p, t + c14 * dt) - k15 = f(uprev + - dt * (a1501 * k1 + a1504 * k4 + a1505 * k5 + a1506 * k6 + - a1507 * k7 + a1508 * k8 + a1509 * k9 + a1511 * k11 + - a1512 * k12 + a1513 * k13), p, t + c15 * dt) - k16 = f(uprev + - dt * (a1601 * k1 + a1604 * k4 + a1605 * k5 + a1606 * k6 + - a1607 * k7 + a1608 * k8 + a1609 * k9 + a1611 * k11 + - a1612 * k12 + a1613 * k13), p, t + c16 * dt) + k13 = f( + uprev + + dt * (a1301 * k1 + a1304 * k4 + a1305 * k5 + a1306 * k6 + + a1307 * k7 + a1308 * k8 + a1309 * k9 + a1311 * k11 + + a1312 * k12), + p, + t + c13 * dt) + k14 = f( + uprev + + dt * (a1401 * k1 + a1404 * k4 + a1405 * k5 + a1406 * k6 + + a1407 * k7 + a1408 * k8 + a1409 * k9 + a1411 * k11 + + a1412 * k12 + a1413 * k13), + p, + t + c14 * dt) + k15 = f( + uprev + + dt * (a1501 * k1 + a1504 * k4 + a1505 * k5 + a1506 * k6 + + a1507 * k7 + a1508 * k8 + a1509 * k9 + a1511 * k11 + + a1512 * k12 + a1513 * k13), + p, + t + c15 * dt) + k16 = f( + uprev + + dt * (a1601 * k1 + a1604 * k4 + a1605 * k5 + a1606 * k6 + + a1607 * k7 + a1608 * k8 + a1609 * k9 + a1611 * k11 + + a1612 * k12 + a1613 * k13), + p, + t + c16 * dt) return y₀ + dt * (integ.k1 * b1Θ @@ -128,7 +145,7 @@ end @inline @muladd function _ode_interpolant(Θ, dt, y₀, integ::T) where {T <: - Union{GPUV9I, GPUAV9I}} + Union{GPUV9I, GPUAV9I}} b1Θ, b8Θ, b9Θ, b10Θ, b11Θ, b12Θ, b13Θ, b14Θ, b15Θ, b17Θ, b18Θ, b19Θ, b20Θ, b21Θ, b22Θ, b23Θ, b24Θ, b25Θ, b26Θ = bΘs(integ, Θ) @@ -146,53 +163,80 @@ end @unpack k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, uprev, f, t, p = integ - k11 = f(uprev + - dt * (a1701 * k1 + a1708 * k2 + a1709 * k3 + a1710 * k4 + - a1711 * k5 + a1712 * k6 + a1713 * k7 + a1714 * k8 + a1715 * k9), + k11 = f( + uprev + + dt * (a1701 * k1 + a1708 * k2 + a1709 * k3 + a1710 * k4 + + a1711 * k5 + a1712 * k6 + a1713 * k7 + a1714 * k8 + a1715 * k9), p, t + c17 * dt) - k12 = f(uprev + - dt * (a1801 * k1 + a1808 * k2 + a1809 * k3 + a1810 * k4 + - a1811 * k5 + a1812 * k6 + a1813 * k7 + a1814 * k8 + - a1815 * k9 + a1817 * k11), p, t + c18 * dt) - k13 = f(uprev + - dt * (a1901 * k1 + a1908 * k2 + a1909 * k3 + a1910 * k4 + - a1911 * k5 + a1912 * k6 + a1913 * k7 + a1914 * k8 + - a1915 * k9 + a1917 * k11 + a1918 * k12), p, t + c19 * dt) - k14 = f(uprev + - dt * (a2001 * k1 + a2008 * k2 + a2009 * k3 + a2010 * k4 + - a2011 * k5 + a2012 * k6 + a2013 * k7 + a2014 * k8 + - a2015 * k9 + a2017 * k11 + a2018 * k12 + a2019 * k13), p, + k12 = f( + uprev + + dt * (a1801 * k1 + a1808 * k2 + a1809 * k3 + a1810 * k4 + + a1811 * k5 + a1812 * k6 + a1813 * k7 + a1814 * k8 + + a1815 * k9 + a1817 * k11), + p, + t + c18 * dt) + k13 = f( + uprev + + dt * (a1901 * k1 + a1908 * k2 + a1909 * k3 + a1910 * k4 + + a1911 * k5 + a1912 * k6 + a1913 * k7 + a1914 * k8 + + a1915 * k9 + a1917 * k11 + a1918 * k12), + p, + t + c19 * dt) + k14 = f( + uprev + + dt * (a2001 * k1 + a2008 * k2 + a2009 * k3 + a2010 * k4 + + a2011 * k5 + a2012 * k6 + a2013 * k7 + a2014 * k8 + + a2015 * k9 + a2017 * k11 + a2018 * k12 + a2019 * k13), + p, t + c20 * dt) - k15 = f(uprev + - dt * (a2101 * k1 + a2108 * k2 + a2109 * k3 + a2110 * k4 + - a2111 * k5 + a2112 * k6 + a2113 * k7 + a2114 * k8 + - a2115 * k9 + a2117 * k11 + a2118 * k12 + a2119 * k13 + - a2120 * k14), p, t + c21 * dt) - k16 = f(uprev + - dt * (a2201 * k1 + a2208 * k2 + a2209 * k3 + a2210 * k4 + - a2211 * k5 + a2212 * k6 + a2213 * k7 + a2214 * k8 + - a2215 * k9 + a2217 * k11 + a2218 * k12 + a2219 * k13 + - a2220 * k14 + a2221 * k15), p, t + c22 * dt) - k17 = f(uprev + - dt * (a2301 * k1 + a2308 * k2 + a2309 * k3 + a2310 * k4 + - a2311 * k5 + a2312 * k6 + a2313 * k7 + a2314 * k8 + - a2315 * k9 + a2317 * k11 + a2318 * k12 + a2319 * k13 + - a2320 * k14 + a2321 * k15), p, t + c23 * dt) - k18 = f(uprev + - dt * (a2401 * k1 + a2408 * k2 + a2409 * k3 + a2410 * k4 + - a2411 * k5 + a2412 * k6 + a2413 * k7 + a2414 * k8 + - a2415 * k9 + a2417 * k11 + a2418 * k12 + a2419 * k13 + - a2420 * k14 + a2421 * k15), p, t + c24 * dt) - k19 = f(uprev + - dt * (a2501 * k1 + a2508 * k2 + a2509 * k3 + a2510 * k4 + - a2511 * k5 + a2512 * k6 + a2513 * k7 + a2514 * k8 + - a2515 * k9 + a2517 * k11 + a2518 * k12 + a2519 * k13 + - a2520 * k14 + a2521 * k15), p, t + c25 * dt) - k20 = f(uprev + - dt * (a2601 * k1 + a2608 * k2 + a2609 * k3 + a2610 * k4 + - a2611 * k5 + a2612 * k6 + a2613 * k7 + a2614 * k8 + - a2615 * k9 + a2617 * k11 + a2618 * k12 + a2619 * k13 + - a2620 * k14 + a2621 * k15), p, t + c26 * dt) + k15 = f( + uprev + + dt * (a2101 * k1 + a2108 * k2 + a2109 * k3 + a2110 * k4 + + a2111 * k5 + a2112 * k6 + a2113 * k7 + a2114 * k8 + + a2115 * k9 + a2117 * k11 + a2118 * k12 + a2119 * k13 + + a2120 * k14), + p, + t + c21 * dt) + k16 = f( + uprev + + dt * (a2201 * k1 + a2208 * k2 + a2209 * k3 + a2210 * k4 + + a2211 * k5 + a2212 * k6 + a2213 * k7 + a2214 * k8 + + a2215 * k9 + a2217 * k11 + a2218 * k12 + a2219 * k13 + + a2220 * k14 + a2221 * k15), + p, + t + c22 * dt) + k17 = f( + uprev + + dt * (a2301 * k1 + a2308 * k2 + a2309 * k3 + a2310 * k4 + + a2311 * k5 + a2312 * k6 + a2313 * k7 + a2314 * k8 + + a2315 * k9 + a2317 * k11 + a2318 * k12 + a2319 * k13 + + a2320 * k14 + a2321 * k15), + p, + t + c23 * dt) + k18 = f( + uprev + + dt * (a2401 * k1 + a2408 * k2 + a2409 * k3 + a2410 * k4 + + a2411 * k5 + a2412 * k6 + a2413 * k7 + a2414 * k8 + + a2415 * k9 + a2417 * k11 + a2418 * k12 + a2419 * k13 + + a2420 * k14 + a2421 * k15), + p, + t + c24 * dt) + k19 = f( + uprev + + dt * (a2501 * k1 + a2508 * k2 + a2509 * k3 + a2510 * k4 + + a2511 * k5 + a2512 * k6 + a2513 * k7 + a2514 * k8 + + a2515 * k9 + a2517 * k11 + a2518 * k12 + a2519 * k13 + + a2520 * k14 + a2521 * k15), + p, + t + c25 * dt) + k20 = f( + uprev + + dt * (a2601 * k1 + a2608 * k2 + a2609 * k3 + a2610 * k4 + + a2611 * k5 + a2612 * k6 + a2613 * k7 + a2614 * k8 + + a2615 * k9 + a2617 * k11 + a2618 * k12 + a2619 * k13 + + a2620 * k14 + a2621 * k15), + p, + t + c26 * dt) return y₀ + dt * @@ -207,7 +251,7 @@ end @inline @muladd function _ode_interpolant(Θ, dt, y₀, integ::T) where {T <: - Union{GPUT5I, GPUAT5I}} + Union{GPUT5I, GPUAT5I}} b1θ, b2θ, b3θ, b4θ, b5θ, b6θ, b7θ = SimpleDiffEq.bθs(integ.rs, Θ) return y₀ + dt * @@ -218,7 +262,7 @@ end @inline @muladd function _ode_interpolant(Θ, dt, y₀, integ::T) where {T <: - Union{GPURB23I, GPUARB23I}} + Union{GPURB23I, GPUARB23I}} c1 = Θ * (1 - Θ) / (1 - 2 * integ.d) c2 = Θ * (Θ - 2 * integ.d) / (1 - 2 * integ.d) return y₀ + dt * (c1 * integ.k1 + c2 * integ.k2) diff --git a/src/ensemblegpukernel/integrators/nonstiff/types.jl b/src/ensemblegpukernel/integrators/nonstiff/types.jl index 7c9e3ba1..22f7ca69 100644 --- a/src/ensemblegpukernel/integrators/nonstiff/types.jl +++ b/src/ensemblegpukernel/integrators/nonstiff/types.jl @@ -354,7 +354,8 @@ end vector_event_last_time = 0 last_event_error = zero(T) - integ = GPUAT5I{IIP, S, T, ST, P, F, N, TOL, typeof(qoldinit), TS, CB, typeof(alg)}(alg, + integ = GPUAT5I{IIP, S, T, ST, P, F, N, TOL, typeof(qoldinit), TS, CB, typeof(alg)}( + alg, f, copy(u0), copy(u0), diff --git a/src/ensemblegpukernel/integrators/stiff/types.jl b/src/ensemblegpukernel/integrators/stiff/types.jl index 021f8bf1..27f34f75 100644 --- a/src/ensemblegpukernel/integrators/stiff/types.jl +++ b/src/ensemblegpukernel/integrators/stiff/types.jl @@ -2,19 +2,20 @@ AlgType, IIP, S, - T, + T })(t) where { AlgType <: GPUODEAlgorithm, IIP, S, - T, + T } Θ = (t - integrator.tprev) / integrator.dt _ode_interpolant(Θ, integrator.dt, integrator.uprev, integrator) end -@inline function DiffEqBase.u_modified!(integrator::DiffEqBase.AbstractODEIntegrator{ +@inline function DiffEqBase.u_modified!( + integrator::DiffEqBase.AbstractODEIntegrator{ AlgType, IIP, S, T}, @@ -94,7 +95,7 @@ mutable struct GPUARosenbrock23Integrator{ Q, TS, CB, - AlgType, + AlgType } <: DiffEqBase.AbstractODEIntegrator{AlgType, IIP, S, T} alg::AlgType @@ -151,7 +152,8 @@ const GPUARB23I = GPUARosenbrock23Integrator d = T(2) d = 1 / (d + sqrt(d)) - integ = GPUARB23I{IIP, S, T, ST, P, F, N, TOL, typeof(qoldinit), TS, CB, typeof(alg)}(alg, + integ = GPUARB23I{IIP, S, T, ST, P, F, N, TOL, typeof(qoldinit), TS, CB, typeof(alg)}( + alg, f, copy(u0), copy(u0), @@ -271,7 +273,7 @@ mutable struct GPUARodas4Integrator{ TS, CB, TabType, - AlgType, + AlgType } <: DiffEqBase.AbstractODEIntegrator{AlgType, IIP, S, T} alg::AlgType @@ -435,7 +437,7 @@ end # Adaptive Step mutable struct GPUARodas5PIntegrator{IIP, S, T, ST, P, F, N, TOL, Q, TS, CB, TabType, - AlgType, + AlgType } <: DiffEqBase.AbstractODEIntegrator{AlgType, IIP, S, T} alg::AlgType @@ -490,7 +492,8 @@ const GPUARodas5PI = GPUARodas5PIntegrator tab = Rodas5PTableau(T, T) - integ = GPUARodas5PI{IIP, S, T, ST, P, F, N, TOL, typeof(qoldinit), TS, CB, typeof(tab), + integ = GPUARodas5PI{ + IIP, S, T, ST, P, F, N, TOL, typeof(qoldinit), TS, CB, typeof(tab), typeof(alg)}(alg, f, copy(u0), @@ -601,7 +604,7 @@ end # Adaptive Step mutable struct GPUAKvaerno3Integrator{IIP, S, T, ST, P, F, N, TOL, Q, TS, CB, TabType, - AlgType, + AlgType } <: DiffEqBase.AbstractODEIntegrator{AlgType, IIP, S, T} alg::AlgType @@ -766,7 +769,7 @@ end # Adaptive Step mutable struct GPUAKvaerno5Integrator{IIP, S, T, ST, P, F, N, TOL, Q, TS, CB, TabType, - AlgType, + AlgType } <: DiffEqBase.AbstractODEIntegrator{AlgType, IIP, S, T} alg::AlgType diff --git a/src/ensemblegpukernel/linalg/linsolve.jl b/src/ensemblegpukernel/linalg/linsolve.jl index 220c106b..8dc3b98b 100644 --- a/src/ensemblegpukernel/linalg/linsolve.jl +++ b/src/ensemblegpukernel/linalg/linsolve.jl @@ -28,10 +28,11 @@ end b::StaticVector{<:Any, Tb}) where {Ta, Tb} d = det(a) T = typeof((one(Ta) * zero(Tb) + one(Ta) * zero(Tb)) / d) - @inbounds return similar_type(b, T)(((a[2, 2] * a[3, 3] - a[2, 3] * a[3, 2]) * b[1] + - (a[1, 3] * a[3, 2] - a[1, 2] * a[3, 3]) * b[2] + - (a[1, 2] * a[2, 3] - a[1, 3] * a[2, 2]) * b[3]) / - d, + @inbounds return similar_type(b, T)( + ((a[2, 2] * a[3, 3] - a[2, 3] * a[3, 2]) * b[1] + + (a[1, 3] * a[3, 2] - a[1, 2] * a[3, 3]) * b[2] + + (a[1, 2] * a[2, 3] - a[1, 3] * a[2, 2]) * b[3]) / + d, ((a[2, 3] * a[3, 1] - a[2, 1] * a[3, 3]) * b[1] + (a[1, 1] * a[3, 3] - a[1, 3] * a[3, 1]) * b[2] + (a[1, 3] * a[2, 1] - a[1, 1] * a[2, 3]) * b[3]) / d, diff --git a/src/ensemblegpukernel/linalg/lu.jl b/src/ensemblegpukernel/linalg/lu.jl index 9c1f0ba0..c10fda61 100644 --- a/src/ensemblegpukernel/linalg/lu.jl +++ b/src/ensemblegpukernel/linalg/lu.jl @@ -148,8 +148,8 @@ function __lu(A::StaticLUMatrix{M, N, T}, ::Val{Pivot}) where {M, N, T, Pivot} Lrest, Urest, prest = __lu(Arest, Val(Pivot)) p = [SVector{1, Int}(kp); ps[prest]] L = [[SVector{1}(one(eltype(Ls))); Ls[prest]] [zeros(typeof(SMatrix{1}(Lrest[1, - :]))); - Lrest]] + :]))); + Lrest]] U = [Ufirst; [zeros(typeof(Urest[:, 1])) Urest]] end return (L, U, p) diff --git a/src/ensemblegpukernel/lowerlevel_solve.jl b/src/ensemblegpukernel/lowerlevel_solve.jl index b3b97973..b3c48779 100644 --- a/src/ensemblegpukernel/lowerlevel_solve.jl +++ b/src/ensemblegpukernel/lowerlevel_solve.jl @@ -1,9 +1,9 @@ """ ```julia vectorized_solve(probs, prob::Union{ODEProblem, SDEProblem}alg; - dt, saveat = nothing, - save_everystep = true, - debug = false, callback = CallbackSet(nothing), tstops = nothing) + dt, saveat = nothing, + save_everystep = true, + debug = false, callback = CallbackSet(nothing), tstops = nothing) ``` A lower level interface to the kernel generation solvers of EnsembleGPUKernel with fixed @@ -55,22 +55,24 @@ function vectorized_solve(probs, prob::ODEProblem, alg; _saveat = range(convert(eltype(prob.tspan), first(saveat)), convert(eltype(prob.tspan), last(saveat)), length = length(saveat)) - convert(StepRangeLen{ + convert( + StepRangeLen{ eltype(_saveat), eltype(_saveat), eltype(_saveat), - eltype(_saveat) === Float32 ? Int32 : Int64, + eltype(_saveat) === Float32 ? Int32 : Int64 }, _saveat) elseif saveat isa AbstractVector adapt(backend, convert.(eltype(prob.tspan), saveat)) else _saveat = prob.tspan[1]:convert(eltype(prob.tspan), saveat):prob.tspan[end] - convert(StepRangeLen{ + convert( + StepRangeLen{ eltype(_saveat), eltype(_saveat), eltype(_saveat), - eltype(_saveat) === Float32 ? Int32 : Int64, + eltype(_saveat) === Float32 ? Int32 : Int64 }, _saveat) end @@ -154,10 +156,10 @@ end """ ```julia vectorized_asolve(probs, prob::ODEProblem, alg; - dt = 0.1f0, saveat = nothing, - save_everystep = false, - abstol = 1.0f-6, reltol = 1.0f-3, - callback = CallbackSet(nothing), tstops = nothing) + dt = 0.1f0, saveat = nothing, + save_everystep = false, + abstol = 1.0f-6, reltol = 1.0f-3, + callback = CallbackSet(nothing), tstops = nothing) ``` A lower level interface to the kernel generation solvers of EnsembleGPUKernel with adaptive diff --git a/src/ensemblegpukernel/perform_step/gpu_kvaerno3_perform_step.jl b/src/ensemblegpukernel/perform_step/gpu_kvaerno3_perform_step.jl index 8f231c37..5863f206 100644 --- a/src/ensemblegpukernel/perform_step/gpu_kvaerno3_perform_step.jl +++ b/src/ensemblegpukernel/perform_step/gpu_kvaerno3_perform_step.jl @@ -86,7 +86,8 @@ end @inline function step!(integ::GPUAKvaerno3I{false, S, T}, ts, us) where {T, S} - beta1, beta2, qmax, qmin, gamma, qoldinit, _ = build_adaptive_controller_cache(integ.alg, + beta1, beta2, qmax, qmin, gamma, qoldinit, _ = build_adaptive_controller_cache( + integ.alg, T) dt = integ.dtnew diff --git a/src/ensemblegpukernel/perform_step/gpu_kvaerno5_perform_step.jl b/src/ensemblegpukernel/perform_step/gpu_kvaerno5_perform_step.jl index 566d1c4f..dc17ca18 100644 --- a/src/ensemblegpukernel/perform_step/gpu_kvaerno5_perform_step.jl +++ b/src/ensemblegpukernel/perform_step/gpu_kvaerno5_perform_step.jl @@ -117,7 +117,8 @@ end @inline function step!(integ::GPUAKvaerno5I{false, S, T}, ts, us) where {T, S} - beta1, beta2, qmax, qmin, gamma, qoldinit, _ = build_adaptive_controller_cache(integ.alg, + beta1, beta2, qmax, qmin, gamma, qoldinit, _ = build_adaptive_controller_cache( + integ.alg, T) dt = integ.dtnew diff --git a/src/ensemblegpukernel/perform_step/gpu_rodas4_perform_step.jl b/src/ensemblegpukernel/perform_step/gpu_rodas4_perform_step.jl index 89bce2f1..b72b34ba 100644 --- a/src/ensemblegpukernel/perform_step/gpu_rodas4_perform_step.jl +++ b/src/ensemblegpukernel/perform_step/gpu_rodas4_perform_step.jl @@ -115,7 +115,8 @@ end @inline function step!(integ::GPUARodas4I{false, S, T}, ts, us) where {T, S} - beta1, beta2, qmax, qmin, gamma, qoldinit, _ = build_adaptive_controller_cache(integ.alg, + beta1, beta2, qmax, qmin, gamma, qoldinit, _ = build_adaptive_controller_cache( + integ.alg, T) dt = integ.dtnew diff --git a/src/ensemblegpukernel/perform_step/gpu_rodas5P_perform_step.jl b/src/ensemblegpukernel/perform_step/gpu_rodas5P_perform_step.jl index f4b900e6..4c3b8909 100644 --- a/src/ensemblegpukernel/perform_step/gpu_rodas5P_perform_step.jl +++ b/src/ensemblegpukernel/perform_step/gpu_rodas5P_perform_step.jl @@ -147,7 +147,8 @@ end @inline function step!(integ::GPUARodas5PI{false, S, T}, ts, us) where {T, S} - beta1, beta2, qmax, qmin, gamma, qoldinit, _ = build_adaptive_controller_cache(integ.alg, + beta1, beta2, qmax, qmin, gamma, qoldinit, _ = build_adaptive_controller_cache( + integ.alg, T) dt = integ.dtnew diff --git a/src/ensemblegpukernel/perform_step/gpu_rosenbrock23_perform_step.jl b/src/ensemblegpukernel/perform_step/gpu_rosenbrock23_perform_step.jl index 27481271..4613f272 100644 --- a/src/ensemblegpukernel/perform_step/gpu_rosenbrock23_perform_step.jl +++ b/src/ensemblegpukernel/perform_step/gpu_rosenbrock23_perform_step.jl @@ -72,7 +72,8 @@ end #############################Adaptive Version##################################### @inline function step!(integ::GPUARB23I{false, S, T}, ts, us) where {S, T} - beta1, beta2, qmax, qmin, gamma, qoldinit, _ = build_adaptive_controller_cache(integ.alg, + beta1, beta2, qmax, qmin, gamma, qoldinit, _ = build_adaptive_controller_cache( + integ.alg, T) dt = integ.dtnew t = integ.t diff --git a/src/ensemblegpukernel/perform_step/gpu_tsit5_perform_step.jl b/src/ensemblegpukernel/perform_step/gpu_tsit5_perform_step.jl index 158e4d2c..9fbce179 100644 --- a/src/ensemblegpukernel/perform_step/gpu_tsit5_perform_step.jl +++ b/src/ensemblegpukernel/perform_step/gpu_tsit5_perform_step.jl @@ -66,7 +66,8 @@ end #############################Adaptive Version##################################### @inline function step!(integ::GPUAT5I{false, S, T}, ts, us) where {S, T} - beta1, beta2, qmax, qmin, gamma, qoldinit, _ = build_adaptive_controller_cache(integ.alg, + beta1, beta2, qmax, qmin, gamma, qoldinit, _ = build_adaptive_controller_cache( + integ.alg, T) c1, c2, c3, c4, c5, c6 = integ.cs dt = integ.dtnew diff --git a/src/ensemblegpukernel/perform_step/gpu_vern7_perform_step.jl b/src/ensemblegpukernel/perform_step/gpu_vern7_perform_step.jl index 29ddd4a5..ed523f6b 100644 --- a/src/ensemblegpukernel/perform_step/gpu_vern7_perform_step.jl +++ b/src/ensemblegpukernel/perform_step/gpu_vern7_perform_step.jl @@ -42,8 +42,10 @@ k6 = f(uprev + dt * (a061 * k1 + a063 * k3 + a064 * k4 + a065 * k5), p, t + c6 * dt) k7 = f(uprev + dt * (a071 * k1 + a073 * k3 + a074 * k4 + a075 * k5 + a076 * k6), p, t + c7 * dt) - k8 = f(uprev + - dt * (a081 * k1 + a083 * k3 + a084 * k4 + a085 * k5 + a086 * k6 + a087 * k7), p, + k8 = f( + uprev + + dt * (a081 * k1 + a083 * k3 + a084 * k4 + a085 * k5 + a086 * k6 + a087 * k7), + p, t + c8 * dt) g9 = uprev + dt * @@ -77,7 +79,8 @@ end #############################Adaptive Version##################################### @inline function step!(integ::GPUAV7I{false, S, T}, ts, us) where {S, T} - beta1, beta2, qmax, qmin, gamma, qoldinit, _ = build_adaptive_controller_cache(integ.alg, + beta1, beta2, qmax, qmin, gamma, qoldinit, _ = build_adaptive_controller_cache( + integ.alg, T) dt = integ.dtnew @@ -121,8 +124,9 @@ end k6 = f(uprev + dt * (a061 * k1 + a063 * k3 + a064 * k4 + a065 * k5), p, t + c6 * dt) k7 = f(uprev + dt * (a071 * k1 + a073 * k3 + a074 * k4 + a075 * k5 + a076 * k6), p, t + c7 * dt) - k8 = f(uprev + - dt * (a081 * k1 + a083 * k3 + a084 * k4 + a085 * k5 + a086 * k6 + a087 * k7), + k8 = f( + uprev + + dt * (a081 * k1 + a083 * k3 + a084 * k4 + a085 * k5 + a086 * k6 + a087 * k7), p, t + c8 * dt) g9 = uprev + diff --git a/src/ensemblegpukernel/perform_step/gpu_vern9_perform_step.jl b/src/ensemblegpukernel/perform_step/gpu_vern9_perform_step.jl index 9f19f7f3..d901c5b3 100644 --- a/src/ensemblegpukernel/perform_step/gpu_vern9_perform_step.jl +++ b/src/ensemblegpukernel/perform_step/gpu_vern9_perform_step.jl @@ -51,22 +51,32 @@ k9 = f(uprev + dt * (a0901 * k1 + a0906 * k6 + a0907 * k7 + a0908 * k8), p, t + c8 * dt) k10 = f(uprev + dt * (a1001 * k1 + a1006 * k6 + a1007 * k7 + a1008 * k8 + a1009 * k9), p, t + c9 * dt) - k11 = f(uprev + - dt * - (a1101 * k1 + a1106 * k6 + a1107 * k7 + a1108 * k8 + a1109 * k9 + a1110 * k10), + k11 = f( + uprev + + dt * + (a1101 * k1 + a1106 * k6 + a1107 * k7 + a1108 * k8 + a1109 * k9 + a1110 * k10), p, t + c10 * dt) - k12 = f(uprev + - dt * - (a1201 * k1 + a1206 * k6 + a1207 * k7 + a1208 * k8 + a1209 * k9 + a1210 * k10 + - a1211 * k11), p, t + c11 * dt) - k13 = f(uprev + - dt * - (a1301 * k1 + a1306 * k6 + a1307 * k7 + a1308 * k8 + a1309 * k9 + a1310 * k10 + - a1311 * k11 + a1312 * k12), p, t + c12 * dt) - k14 = f(uprev + - dt * - (a1401 * k1 + a1406 * k6 + a1407 * k7 + a1408 * k8 + a1409 * k9 + a1410 * k10 + - a1411 * k11 + a1412 * k12 + a1413 * k13), p, t + c13 * dt) + k12 = f( + uprev + + dt * + (a1201 * k1 + a1206 * k6 + a1207 * k7 + a1208 * k8 + a1209 * k9 + a1210 * k10 + + a1211 * k11), + p, + t + c11 * dt) + k13 = f( + uprev + + dt * + (a1301 * k1 + a1306 * k6 + a1307 * k7 + a1308 * k8 + a1309 * k9 + a1310 * k10 + + a1311 * k11 + a1312 * k12), + p, + t + c12 * dt) + k14 = f( + uprev + + dt * + (a1401 * k1 + a1406 * k6 + a1407 * k7 + a1408 * k8 + a1409 * k9 + a1410 * k10 + + a1411 * k11 + a1412 * k12 + a1413 * k13), + p, + t + c13 * dt) g15 = uprev + dt * (a1501 * k1 + a1506 * k6 + a1507 * k7 + a1508 * k8 + a1509 * k9 + a1510 * k10 + @@ -104,7 +114,8 @@ end #############################Adaptive Version##################################### @inline function step!(integ::GPUAV9I{false, S, T}, ts, us) where {S, T} - beta1, beta2, qmax, qmin, gamma, qoldinit, _ = build_adaptive_controller_cache(integ.alg, + beta1, beta2, qmax, qmin, gamma, qoldinit, _ = build_adaptive_controller_cache( + integ.alg, T) dt = integ.dtnew @@ -155,29 +166,40 @@ end k8 = f(uprev + dt * (a0801 * k1 + a0806 * k6 + a0807 * k7), p, t + c7 * dt) k9 = f(uprev + dt * (a0901 * k1 + a0906 * k6 + a0907 * k7 + a0908 * k8), p, t + c8 * dt) - k10 = f(uprev + - dt * (a1001 * k1 + a1006 * k6 + a1007 * k7 + a1008 * k8 + a1009 * k9), + k10 = f( + uprev + + dt * (a1001 * k1 + a1006 * k6 + a1007 * k7 + a1008 * k8 + a1009 * k9), p, t + c9 * dt) - k11 = f(uprev + - dt * - (a1101 * k1 + a1106 * k6 + a1107 * k7 + a1108 * k8 + a1109 * k9 + - a1110 * k10), + k11 = f( + uprev + + dt * + (a1101 * k1 + a1106 * k6 + a1107 * k7 + a1108 * k8 + a1109 * k9 + + a1110 * k10), p, t + c10 * dt) - k12 = f(uprev + - dt * - (a1201 * k1 + a1206 * k6 + a1207 * k7 + a1208 * k8 + a1209 * k9 + - a1210 * k10 + - a1211 * k11), p, t + c11 * dt) - k13 = f(uprev + - dt * - (a1301 * k1 + a1306 * k6 + a1307 * k7 + a1308 * k8 + a1309 * k9 + - a1310 * k10 + - a1311 * k11 + a1312 * k12), p, t + c12 * dt) - k14 = f(uprev + - dt * - (a1401 * k1 + a1406 * k6 + a1407 * k7 + a1408 * k8 + a1409 * k9 + - a1410 * k10 + - a1411 * k11 + a1412 * k12 + a1413 * k13), p, t + c13 * dt) + k12 = f( + uprev + + dt * + (a1201 * k1 + a1206 * k6 + a1207 * k7 + a1208 * k8 + a1209 * k9 + + a1210 * k10 + + a1211 * k11), + p, + t + c11 * dt) + k13 = f( + uprev + + dt * + (a1301 * k1 + a1306 * k6 + a1307 * k7 + a1308 * k8 + a1309 * k9 + + a1310 * k10 + + a1311 * k11 + a1312 * k12), + p, + t + c12 * dt) + k14 = f( + uprev + + dt * + (a1401 * k1 + a1406 * k6 + a1407 * k7 + a1408 * k8 + a1409 * k9 + + a1410 * k10 + + a1411 * k11 + a1412 * k12 + a1413 * k13), + p, + t + c13 * dt) g15 = uprev + dt * (a1501 * k1 + a1506 * k6 + a1507 * k7 + a1508 * k8 + a1509 * k9 + diff --git a/src/ensemblegpukernel/problems/ode_problems.jl b/src/ensemblegpukernel/problems/ode_problems.jl index d7f420d9..5017f502 100644 --- a/src/ensemblegpukernel/problems/ode_problems.jl +++ b/src/ensemblegpukernel/problems/ode_problems.jl @@ -1,6 +1,7 @@ import SciMLBase: @add_kwonly, AbstractODEProblem, AbstractODEFunction, - FunctionWrapperSpecialize, StandardODEProblem, prepare_initial_state, promote_tspan, - warn_paramtype + FunctionWrapperSpecialize, StandardODEProblem, prepare_initial_state, + promote_tspan, + warn_paramtype struct ImmutableODEProblem{uType, tType, isinplace, P, F, K, PT} <: AbstractODEProblem{uType, tType, isinplace} diff --git a/src/ensemblegpukernel/tableaus/kvaerno_tableaus.jl b/src/ensemblegpukernel/tableaus/kvaerno_tableaus.jl index e1d4bcee..5fc4d0dd 100644 --- a/src/ensemblegpukernel/tableaus/kvaerno_tableaus.jl +++ b/src/ensemblegpukernel/tableaus/kvaerno_tableaus.jl @@ -38,7 +38,8 @@ function Kvaerno3Tableau(T, T2) α32 = ((-2θ + 3θ^2) + (6θ * (1 - θ) / c2) * γ) α41 = convert(T2, 0.0) α42 = convert(T2, 0.0) - Kvaerno3Tableau(γ, a31, a32, a41, a42, a43, btilde1, btilde2, btilde3, btilde4, c3, α31, + Kvaerno3Tableau( + γ, a31, a32, a41, a42, a43, btilde1, btilde2, btilde3, btilde4, c3, α31, α32, α41, α42) end diff --git a/src/ensemblegpukernel/tableaus/verner_tableaus.jl b/src/ensemblegpukernel/tableaus/verner_tableaus.jl index 939e30bb..681df2df 100644 --- a/src/ensemblegpukernel/tableaus/verner_tableaus.jl +++ b/src/ensemblegpukernel/tableaus/verner_tableaus.jl @@ -438,7 +438,8 @@ function Vern7Tableau(T::Type{T1}, T2::Type{T1}) where {T1} extra = Vern7ExtraStages(T, T2) interp = Vern7InterpolationCoefficients(T) - Vern7Tableau(c2, c3, c4, c5, c6, c7, c8, a021, a031, a032, a041, a043, a051, a053, a054, + Vern7Tableau( + c2, c3, c4, c5, c6, c7, c8, a021, a031, a032, a041, a043, a051, a053, a054, a061, a063, a064, a065, a071, a073, a074, a075, a076, a081, a083, a084, a085, a086, a087, a091, a093, a094, a095, a096, a097, a098, a101, a103, a104, a105, a106, a107, b1, b4, b5, b6, b7, b8, b9, btilde1, btilde4, diff --git a/src/solve.jl b/src/solve.jl index ca41064c..c8326f73 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -64,8 +64,10 @@ function SciMLBase.__solve(ensembleprob::SciMLBase.AbstractEnsembleProblem, converged::Bool = false u = ensembleprob.u_init === nothing ? - similar(batch_solve(ensembleprob, alg, ensemblealg, 1:batch_size, adaptive; - unstable_check = unstable_check, kwargs...), 0) : + similar( + batch_solve(ensembleprob, alg, ensemblealg, 1:batch_size, adaptive; + unstable_check = unstable_check, kwargs...), + 0) : ensembleprob.u_init if nprocs() == 1 @@ -152,10 +154,13 @@ function batch_solve(ensembleprob, alg, get(kwargs, :save_everystep, true) error("Using different time-spans require either turning off save_everystep or using saveat. If using saveat, it should be of same length across the ensemble.") end - if !all(Base.Fix2((prob1, prob2) -> isequal(sizeof(get(prob1.kwargs, :saveat, + if !all( + Base.Fix2( + (prob1, prob2) -> isequal(sizeof(get(prob1.kwargs, :saveat, nothing)), sizeof(get(prob2.kwargs, :saveat, - nothing))), probs[1]), + nothing))), + probs[1]), probs) error("Using different saveat in EnsembleGPUKernel requires all of them to be of same length. Use saveats of same size only.") end @@ -168,26 +173,27 @@ function batch_solve(ensembleprob, alg, solts, solus = batch_solve_up_kernel(ensembleprob, probs, alg, ensemblealg, I, adaptive; saveat = saveat, kwargs...) [begin - ts = @view solts[:, i] - us = @view solus[:, i] - sol_idx = findlast(x -> x != probs[i].tspan[1], ts) - if sol_idx === nothing - @error "No solution found" tspan=probs[i].tspan[1] ts - error("Batch solve failed") - end - @views ensembleprob.output_func(SciMLBase.build_solution(probs[i], - alg, - ts[1:sol_idx], - us[1:sol_idx], - k = nothing, - stats = nothing, - calculate_error = false, - retcode = sol_idx != - length(ts) ? - ReturnCode.Terminated : - ReturnCode.Success), - i)[1] - end + ts = @view solts[:, i] + us = @view solus[:, i] + sol_idx = findlast(x -> x != probs[i].tspan[1], ts) + if sol_idx === nothing + @error "No solution found" tspan=probs[i].tspan[1] ts + error("Batch solve failed") + end + @views ensembleprob.output_func( + SciMLBase.build_solution(probs[i], + alg, + ts[1:sol_idx], + us[1:sol_idx], + k = nothing, + stats = nothing, + calculate_error = false, + retcode = sol_idx != + length(ts) ? + ReturnCode.Terminated : + ReturnCode.Success), + i)[1] + end for i in eachindex(probs)] else @@ -227,13 +233,17 @@ function batch_solve(ensembleprob, alg, probs[1] = orig_prob - [ensembleprob.output_func(SciMLBase.build_solution(probs[i], alg, - map(t -> probs[i].tspan[1] + - (probs[i].tspan[2] - - probs[i].tspan[1]) * - t, sol.t), solus[i], - stats = sol.stats, - retcode = sol.retcode), i)[1] + [ensembleprob.output_func( + SciMLBase.build_solution(probs[i], alg, + map( + t -> probs[i].tspan[1] + + (probs[i].tspan[2] - + probs[i].tspan[1]) * + t, + sol.t), solus[i], + stats = sol.stats, + retcode = sol.retcode), + i)[1] for i in 1:length(probs)] else p = reduce(hcat, @@ -241,10 +251,12 @@ function batch_solve(ensembleprob, alg, for i in 1:length(I)) sol, solus = batch_solve_up(ensembleprob, probs, alg, ensemblealg, I, u0, p; adaptive = adaptive, kwargs...) - [ensembleprob.output_func(SciMLBase.build_solution(probs[i], alg, sol.t, - solus[i], - stats = sol.stats, - retcode = sol.retcode), i)[1] + [ensembleprob.output_func( + SciMLBase.build_solution(probs[i], alg, sol.t, + solus[i], + stats = sol.stats, + retcode = sol.retcode), + i)[1] for i in 1:length(probs)] end end @@ -254,7 +266,8 @@ function batch_solve_up_kernel(ensembleprob, probs, alg, ensemblealg, I, adaptiv kwargs...) _callback = CallbackSet(generate_callback(probs[1], length(I), ensemblealg; kwargs...)) - _callback = CallbackSet(convert.(DiffEqGPU.GPUDiscreteCallback, + _callback = CallbackSet( + convert.(DiffEqGPU.GPUDiscreteCallback, _callback.discrete_callbacks)..., convert.(DiffEqGPU.GPUContinuousCallback, _callback.continuous_callbacks)...) diff --git a/test/ensemblegpuarray_oop.jl b/test/ensemblegpuarray_oop.jl index 534ba283..0f99a81c 100644 --- a/test/ensemblegpuarray_oop.jl +++ b/test/ensemblegpuarray_oop.jl @@ -17,8 +17,8 @@ function lorenz_jac(u, p, t) y = u[2] z = u[3] SA[-σ σ 0 - ρ-z -1 -x - y x -β] + ρ-z -1 -x + y x -β] end function lorenz_tgrad(u, p, t) diff --git a/test/gpu_kernel_de/forward_diff.jl b/test/gpu_kernel_de/forward_diff.jl index c3ba5662..076bede5 100644 --- a/test/gpu_kernel_de/forward_diff.jl +++ b/test/gpu_kernel_de/forward_diff.jl @@ -15,13 +15,13 @@ function lorenz(u, p, t) end u0 = @SVector [ForwardDiff.Dual(1.0f0, (1.0f0, 0.0f0, 0.0f0, 0.0f0, 0.0f0, 0.0f0)); - ForwardDiff.Dual(0.0f0, (0.0f0, 1.0f0, 0.0f0, 0.0f0, 0.0f0, 0.0f0)); - ForwardDiff.Dual(0.0f0, (0.0f0, 0.0f0, 1.0f0, 0.0f0, 0.0f0, 0.0f0))] + ForwardDiff.Dual(0.0f0, (0.0f0, 1.0f0, 0.0f0, 0.0f0, 0.0f0, 0.0f0)); + ForwardDiff.Dual(0.0f0, (0.0f0, 0.0f0, 1.0f0, 0.0f0, 0.0f0, 0.0f0))] p = @SVector [ ForwardDiff.Dual(10.0f0, (0.0f0, 0.0f0, 0.0f0, 1.0f0, 0.0f0, 0.0f0)), ForwardDiff.Dual(28.0f0, (0.0f0, 0.0f0, 0.0f0, 0.0f0, 1.0f0, 0.0f0)), - ForwardDiff.Dual(8 / 3.0f0, (0.0f0, 0.0f0, 0.0f0, 0.0f0, 0.0f0, 1.0f0)), + ForwardDiff.Dual(8 / 3.0f0, (0.0f0, 0.0f0, 0.0f0, 0.0f0, 0.0f0, 1.0f0)) ] tspan = (0.0f0, 10.0f0) diff --git a/test/gpu_kernel_de/gpu_sde_regression.jl b/test/gpu_kernel_de/gpu_sde_regression.jl index 63dd0562..b8101a1e 100644 --- a/test/gpu_kernel_de/gpu_sde_regression.jl +++ b/test/gpu_kernel_de/gpu_sde_regression.jl @@ -56,12 +56,14 @@ for alg in algs monteprob = EnsembleProblem(prob) dt = Float32(1 // 2^(8)) - sol = solve(monteprob, alg, EnsembleGPUKernel(backend, 0.0), dt = dt, trajectories = 10, + sol = solve( + monteprob, alg, EnsembleGPUKernel(backend, 0.0), dt = dt, trajectories = 10, adaptive = false) @test sol.converged == true - sol = solve(monteprob, alg, EnsembleGPUKernel(backend, 0.0), dt = dt, trajectories = 10, + sol = solve( + monteprob, alg, EnsembleGPUKernel(backend, 0.0), dt = dt, trajectories = 10, adaptive = false, save_everystep = false) @test sol.converged == true @@ -69,7 +71,8 @@ for alg in algs saveat = [0.3f0, 0.5f0] - sol = solve(monteprob, alg, EnsembleGPUKernel(backend, 0.0), dt = dt, trajectories = 10, + sol = solve( + monteprob, alg, EnsembleGPUKernel(backend, 0.0), dt = dt, trajectories = 10, adaptive = false, saveat = saveat) end @@ -97,12 +100,14 @@ noise_rate_prototype = @SMatrix zeros(Float32, 2, 4) prob = SDEProblem(f, g, u0, (0.0f0, 1.0f0), noise_rate_prototype = noise_rate_prototype) monteprob = EnsembleProblem(prob) -sol = solve(monteprob, GPUEM(), EnsembleGPUKernel(backend, 0.0), dt = dt, trajectories = 10, +sol = solve( + monteprob, GPUEM(), EnsembleGPUKernel(backend, 0.0), dt = dt, trajectories = 10, adaptive = false) @test sol.converged == true -sol = solve(monteprob, GPUEM(), EnsembleGPUKernel(backend, 0.0), dt = dt, trajectories = 10, +sol = solve( + monteprob, GPUEM(), EnsembleGPUKernel(backend, 0.0), dt = dt, trajectories = 10, adaptive = false, save_everystep = false) @test sol.converged == true diff --git a/test/gpu_kernel_de/stiff_ode/gpu_ode_mass_matrix.jl b/test/gpu_kernel_de/stiff_ode/gpu_ode_mass_matrix.jl index e18ccfe5..11f159a4 100644 --- a/test/gpu_kernel_de/stiff_ode/gpu_ode_mass_matrix.jl +++ b/test/gpu_kernel_de/stiff_ode/gpu_ode_mass_matrix.jl @@ -14,12 +14,12 @@ function rober_jac(u, p, t) y₁, y₂, y₃ = u k₁, k₂, k₃ = p return @SMatrix[(k₁*-1) (y₃*k₃) (k₃*y₂) - k₁ (y₂ * k₂ * -2+y₃ * k₃ * -1) (k₃*y₂*-1) - 0 (y₂*2*k₂) (0)] + k₁ (y₂ * k₂ * -2+y₃ * k₃ * -1) (k₃*y₂*-1) + 0 (y₂*2*k₂) (0)] end M = @SMatrix [1.0f0 0.0f0 0.0f0 - 0.0f0 1.0f0 0.0f0 - 0.0f0 0.0f0 0.0f0] + 0.0f0 1.0f0 0.0f0 + 0.0f0 0.0f0 0.0f0] ff = ODEFunction(rober, mass_matrix = M) prob = ODEProblem(ff, @SVector([1.0f0, 0.0f0, 0.0f0]), (0.0f0, 1.0f5), (0.04f0, 3.0f7, 1.0f4)) diff --git a/test/gpu_kernel_de/stiff_ode/gpu_ode_regression.jl b/test/gpu_kernel_de/stiff_ode/gpu_ode_regression.jl index 01cb848f..8e393b3c 100644 --- a/test/gpu_kernel_de/stiff_ode/gpu_ode_regression.jl +++ b/test/gpu_kernel_de/stiff_ode/gpu_ode_regression.jl @@ -134,7 +134,8 @@ for alg in algs if GROUP == "CUDA" monteprob = EnsembleProblem(large_prob, safetycopy = false) - local sol = solve(monteprob, alg, EnsembleGPUKernel(backend, 0.0), trajectories = 2, + local sol = solve( + monteprob, alg, EnsembleGPUKernel(backend, 0.0), trajectories = 2, adaptive = true, dt = 0.1f0) end end