diff --git a/src/ensemblegpukernel/lowerlevel_solve.jl b/src/ensemblegpukernel/lowerlevel_solve.jl index b3b97973..68e306ee 100644 --- a/src/ensemblegpukernel/lowerlevel_solve.jl +++ b/src/ensemblegpukernel/lowerlevel_solve.jl @@ -206,13 +206,27 @@ function vectorized_asolve(probs, prob::ODEProblem, alg; us = allocate(backend, typeof(prob.u0), (len, length(probs))) else saveat = if saveat isa AbstractRange - range(convert(eltype(prob.tspan), first(saveat)), + _saveat = range(convert(eltype(prob.tspan), first(saveat)), convert(eltype(prob.tspan), last(saveat)), length = length(saveat)) + convert(StepRangeLen{ + eltype(_saveat), + eltype(_saveat), + eltype(_saveat), + eltype(_saveat) === Float32 ? Int32 : Int64, + }, + _saveat) elseif saveat isa AbstractVector adapt(backend, convert.(eltype(prob.tspan), saveat)) else - prob.tspan[1]:convert(eltype(prob.tspan), saveat):prob.tspan[end] + _saveat = prob.tspan[1]:convert(eltype(prob.tspan), saveat):prob.tspan[end] + convert(StepRangeLen{ + eltype(_saveat), + eltype(_saveat), + eltype(_saveat), + eltype(_saveat) === Float32 ? Int32 : Int64, + }, + _saveat) end ts = allocate(backend, typeof(dt), (length(saveat), length(probs))) fill!(ts, prob.tspan[1]) diff --git a/test/gpu_kernel_de/conversions.jl b/test/gpu_kernel_de/conversions.jl index d524a929..f7d3c62f 100644 --- a/test/gpu_kernel_de/conversions.jl +++ b/test/gpu_kernel_de/conversions.jl @@ -18,25 +18,21 @@ prob = ODEProblem{false}(lorenz, u0, tspan, p) prob_func = (prob, i, repeat) -> remake(prob, p = (@SVector rand(Float32, 3)) .* p) monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false) -## Don't test the problems in which GPUs don't support FP64 completely yet -## Creating StepRangeLen causes some param types to be FP64 inferred by `float` function -if ENV["GROUP"] ∉ ("Metal", "oneAPI") - @test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), - trajectories = 10_000, - saveat = 1:10)[1].t == Float32.(1:10) - - @test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), - trajectories = 10_000, - saveat = 1:0.1:10)[1].t == 1.0f0:0.1f0:10.0f0 - - @test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), - trajectories = 10_000, - saveat = 1:(1.0f0):10)[1].t == 1:1.0f0:10 - - @test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), - trajectories = 10_000, - saveat = 1.0)[1].t == 0.0f0:1.0f0:10.0f0 -end +@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), + trajectories = 10_000, + saveat = 1:10)[1].t == Float32.(1:10) + +@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), + trajectories = 10_000, + saveat = 1:0.1:10)[1].t == 1.0f0:0.1f0:10.0f0 + +@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), + trajectories = 10_000, + saveat = 1:(1.0f0):10)[1].t == 1:1.0f0:10 + +@test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), + trajectories = 10_000, + saveat = 1.0)[1].t == 0.0f0:1.0f0:10.0f0 @test solve(monteprob, GPUTsit5(), EnsembleGPUKernel(backend), trajectories = 10_000,