Skip to content

Commit

Permalink
Merge pull request #280 from tshort/mass-matrix
Browse files Browse the repository at this point in the history
WIP: Mass matrix support for rosenbrock23
  • Loading branch information
utkarsh530 authored Jun 1, 2023
2 parents b56f1ed + 0277230 commit 187648e
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ oneAPIExt = ["oneAPI"]
[compat]
AMDGPU = "0.4.9"
Adapt = "3"
CUDA = "4.1"
CUDA = "4.1.0"
ChainRulesCore = "1"
DiffEqBase = "6.122"
DocStringExtensions = "0.8, 0.9"
Expand Down
28 changes: 23 additions & 5 deletions src/perform_step/gpu_rosenbrock23_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
Tgrad = build_tgrad(f)
dT = Tgrad(uprev, p, t)

W = I - γ * J
mass_matrix = integ.f.mass_matrix
W = mass_matrix - γ * J
W_fact = W

# F = lu(W)
Expand All @@ -50,7 +51,11 @@

F₁ = f(uprev + dto2 * k1, p, t + dto2)

k2 = W_fact \ (F₁ - k1) + k1
if mass_matrix === I
k2 = W_fact \ (F₁ - k1) + k1
else
k2 = W_fact \ (F₁ - mass_matrix * k1) + k1
end

integ.u = uprev + dt * k2

Expand Down Expand Up @@ -149,6 +154,8 @@ end

EEst = convert(T, Inf)

mass_matrix = integ.f.mass_matrix

while EEst > convert(T, 1.0)
dt < convert(T, 1.0f-14) && error("dt<dtmin")

Expand All @@ -163,7 +170,7 @@ end
Tgrad = build_tgrad(f)
dT = Tgrad(uprev, p, t)

W = I - γ * J
W = mass_matrix - γ * J
W_fact = W

# F = lu(W)
Expand All @@ -172,13 +179,24 @@ end

F₁ = f(uprev + dto2 * k1, p, t + dto2)

k2 = W_fact \ (F₁ - k1) + k1
if mass_matrix === I
k2 = W_fact \ (F₁ - k1) + k1
else
k2 = W_fact \ (F₁ - mass_matrix * k1) + k1
end

u = uprev + dt * k2

e32 = T(6) + sqrt(T(2))
F₂ = f(u, p, t + dt)
k3 = W_fact \ (F₂ - e32 * (k2 - F₁) - 2 * (k1 - F₀) + dt * dT)

if mass_matrix === I
k3 = W_fact \ (F₂ - e32 * (k2 - F₁) - 2 * (k1 - F₀) + dt * dT)

else
k3 = W_fact \ (F₂ - mass_matrix * (e32 * k2 + 2 * k1) +
e32 * F₁ + 2 * F₀ + dt * dT)
end

tmp = dto6 * (k1 - 2 * k2 + k3)
tmp = tmp ./ (abstol .+ max.(abs.(uprev), abs.(u)) * reltol)
Expand Down
39 changes: 39 additions & 0 deletions test/gpu_kernel_de/stiff_ode/gpu_ode_mass_matrix.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
using DiffEqGPU, StaticArrays, OrdinaryDiffEq, LinearAlgebra

include("../../utils.jl")

function rober(u, p, t)
y₁, y₂, y₃ = u
k₁, k₂, k₃ = p
return @SVector [
-k₁ * y₁ + k₃ * y₂ * y₃,
k₁ * y₁ - k₂ * y₂^2 - k₃ * y₂ * y₃,
y₁ + y₂ + y₃ - 1]
end
function rober_jac(u, p, t)
y₁, y₂, y₃ = u
k₁, k₂, k₃ = p
return @SMatrix[(k₁*-1) (y₃*k₃) (k₃*y₂)
k₁ (y₂ * k₂ * -2+y₃ * k₃ * -1) (k₃*y₂*-1)
0 (y₂*2*k₂) (0)]
end
M = @SMatrix [1.0f0 0.0f0 0.0f0
0.0f0 1.0f0 0.0f0
0.0f0 0.0f0 0.0f0]
ff = ODEFunction(rober, mass_matrix = M)
prob = ODEProblem(ff, @SVector([1.0f0, 0.0f0, 0.0f0]), (0.0f0, 1.0f5),
(0.04f0, 3.0f7, 1.0f4))

monteprob = EnsembleProblem(prob, safetycopy = false)

alg = GPURosenbrock23()

bench_sol = solve(prob, Rosenbrock23(), dt = 0.1, abstol = 1.0f-5, reltol = 1.0f-5)

sol = solve(monteprob, alg, EnsembleGPUKernel(backend),
trajectories = 2,
dt = 0.1f0,
adaptive = true, abstol = 1.0f-5, reltol = 1.0f-5)

@test norm(bench_sol.u[1] - sol[1].u[1]) < 8e-4
@test norm(bench_sol.u[end] - sol[1].u[end]) < 8e-4
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ const GROUP = get(ENV, "GROUP", "CUDA")

using SafeTestsets, Test

@time @safetestset "GPU Kernelized Stiff ODE Mass Matrix" begin include("gpu_kernel_de/stiff_ode/gpu_ode_mass_matrix.jl") end
@time @testset "GPU Kernelized Non Stiff ODE Regression" begin include("gpu_kernel_de/gpu_ode_regression.jl") end
@time @safetestset "GPU Kernelized Non Stiff ODE DiscreteCallback" begin include("gpu_kernel_de/gpu_ode_discrete_callbacks.jl") end
@time @testset "GPU Kernelized Stiff ODE Regression" begin include("gpu_kernel_de/stiff_ode/gpu_ode_regression.jl") end
Expand Down

0 comments on commit 187648e

Please sign in to comment.