Skip to content

Commit

Permalink
Remove KernelAbstractions
Browse files Browse the repository at this point in the history
  • Loading branch information
ph-kev committed Jan 31, 2025
1 parent 8c32300 commit be8fb76
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 84 deletions.
7 changes: 4 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@ authors = ["Climate Modeling Alliance"]
version = "0.12.9"

[deps]
# Required for backwards compatibility with Julia <1.9
ClimaParams = "5c42b081-d73a-476f-9059-fd94b934656c"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RootSolvers = "7181ea78-2dcb-4de3-ab41-2b8ab5a31e74"
# Required for backwards compatibility with Julia <1.9
ClimaParams = "5c42b081-d73a-476f-9059-fd94b934656c"

[weakdeps]
ClimaParams = "5c42b081-d73a-476f-9059-fd94b934656c"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"

[extensions]
CreateParametersExt = "ClimaParams"
KernelAbstractionsExt = "KernelAbstractions"

[compat]
ClimaParams = "0.10"
Expand Down
8 changes: 8 additions & 0 deletions ext/KernelAbstractionsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module KernelAbstractionsExt

import Thermodynamics.Utils as Util
import KernelAbstractions as KA

Util.ka_print(args...) = KA.@print(args...)

end
5 changes: 2 additions & 3 deletions src/Thermodynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ const DSE = DocStringExtensions
import RootSolvers
const RS = RootSolvers

import KernelAbstractions
const KA = KernelAbstractions

include("Utils.jl")
export Utils
include("Parameters.jl")
import .Parameters
const TP = Parameters
Expand Down
23 changes: 23 additions & 0 deletions src/Utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
module Utils

import Thermodynamics

"""
print_(args...)
Print whatever is in `args`. Compatible with both CPU and GPU.
If CUDA is not loaded, then KernelAbstractions is not loaded which means
`Base.print` is called. If CUDA is loaded, then KernelAbstractions is loaded and
`KernelAbstractions.@print` is called.
"""
function print_(args...)
if !isnothing(Base.get_extension(Thermodynamics, :KernelAbstractionsExt))
return ka_print(args...)
end
return Base.print(args...)
end

function ka_print end

end
18 changes: 9 additions & 9 deletions src/config_numerical_method.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@
# saturation adjustment, for different combinations
# of thermodynamic variable inputs.

# KA.@print only accepts literal strings, so we must
# print only accepts literal strings, so we must
# branch to print which method is being used.
@inline function print_numerical_method(
::Type{sat_adjust_method},
) where {sat_adjust_method}
if sat_adjust_method <: RS.NewtonsMethod
KA.@print(" Method=NewtonsMethod")
Utils.print_(" Method=NewtonsMethod")
elseif sat_adjust_method <: RS.NewtonsMethodAD
KA.@print(" Method=NewtonsMethodAD")
Utils.print_(" Method=NewtonsMethodAD")
elseif sat_adjust_method <: RS.SecantMethod
KA.@print(" Method=SecantMethod")
Utils.print_(" Method=SecantMethod")
elseif sat_adjust_method <: RS.RegulaFalsiMethod
KA.@print(" Method=RegulaFalsiMethod")
Utils.print_(" Method=RegulaFalsiMethod")
else
error("Unsupported numerical method")
end
Expand All @@ -26,9 +26,9 @@ end
T_guess::Real,
) where {sat_adjust_method}
if sat_adjust_method <: RS.NewtonsMethod
KA.@print(", T_guess=", T_guess)
Utils.print_(", T_guess=", T_guess)
elseif sat_adjust_method <: RS.NewtonsMethodAD
KA.@print(", T_guess=", T_guess)
Utils.print_(", T_guess=", T_guess)
end
end

Expand All @@ -37,9 +37,9 @@ end
T_guess::Nothing,
) where {sat_adjust_method}
if sat_adjust_method <: RS.NewtonsMethod
KA.@print(", T_guess=nothing")
Utils.print_(", T_guess=nothing")
elseif sat_adjust_method <: RS.NewtonsMethodAD
KA.@print(", T_guess=nothing")
Utils.print_(", T_guess=nothing")
end
end

Expand Down
138 changes: 69 additions & 69 deletions src/relations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1644,16 +1644,16 @@ See also [`saturation_adjustment`](@ref).
DataCollection.log_meta(sol)
if !sol.converged
if print_warning()
KA.@print("-----------------------------------------\n")
KA.@print("maxiter reached in saturation_adjustment:\n")
Utils.print_("-----------------------------------------\n")
Utils.print_("maxiter reached in saturation_adjustment:\n")
print_numerical_method(sat_adjust_method)
print_T_guess(sat_adjust_method, T_guess)
KA.@print(", e_int=", e_int)
KA.@print(", ρ=", ρ)
KA.@print(", q_tot=", q_tot)
KA.@print(", T=", sol.root)
KA.@print(", maxiter=", maxiter)
KA.@print(", tol=", tol.tol, "\n")
Utils.print_(", e_int=", e_int)
Utils.print_(", ρ=", ρ)
Utils.print_(", q_tot=", q_tot)
Utils.print_(", T=", sol.root)
Utils.print_(", maxiter=", maxiter)
Utils.print_(", tol=", tol.tol, "\n")
end
if error_on_non_convergence()
error("Failed to converge with printed set of inputs.")
Expand Down Expand Up @@ -1748,16 +1748,16 @@ See also [`saturation_adjustment`](@ref).
)
if !sol.converged
if print_warning()
KA.@print("-----------------------------------------\n")
KA.@print("maxiter reached in saturation_adjustment_peq:\n")
Utils.print_("-----------------------------------------\n")
Utils.print_("maxiter reached in saturation_adjustment_peq:\n")
print_numerical_method(sat_adjust_method)
print_T_guess(sat_adjust_method, T_guess)
KA.@print(", e_int=", e_int)
KA.@print(", p=", p)
KA.@print(", q_tot=", q_tot)
KA.@print(", T=", sol.root)
KA.@print(", maxiter=", maxiter)
KA.@print(", tol=", tol.tol, "\n")
Utils.print_(", e_int=", e_int)
Utils.print_(", p=", p)
Utils.print_(", q_tot=", q_tot)
Utils.print_(", T=", sol.root)
Utils.print_(", maxiter=", maxiter)
Utils.print_(", tol=", tol.tol, "\n")
end
if error_on_non_convergence()
error("Failed to converge with printed set of inputs.")
Expand Down Expand Up @@ -1861,16 +1861,16 @@ See also [`saturation_adjustment`](@ref).
)
if !sol.converged
if print_warning()
KA.@print("-----------------------------------------\n")
KA.@print("maxiter reached in saturation_adjustment_phq:\n")
Utils.print_("-----------------------------------------\n")
Utils.print_("maxiter reached in saturation_adjustment_phq:\n")
print_numerical_method(sat_adjust_method)
print_T_guess(sat_adjust_method, T_guess)
KA.@print(", h=", h)
KA.@print(", p=", p)
KA.@print(", q_tot=", q_tot)
KA.@print(", T=", sol.root)
KA.@print(", maxiter=", maxiter)
KA.@print(", tol=", tol.tol, "\n")
Utils.print_(", h=", h)
Utils.print_(", p=", p)
Utils.print_(", q_tot=", q_tot)
Utils.print_(", T=", sol.root)
Utils.print_(", maxiter=", maxiter)
Utils.print_(", tol=", tol.tol, "\n")
end
if error_on_non_convergence()
error("Failed to converge with printed set of inputs.")
Expand Down Expand Up @@ -1958,16 +1958,16 @@ See also [`saturation_adjustment`](@ref).
)
if !sol.converged
if print_warning()
KA.@print("-----------------------------------------\n")
KA.@print("maxiter reached in saturation_adjustment_ρpq:\n")
Utils.print_("-----------------------------------------\n")
Utils.print_("maxiter reached in saturation_adjustment_ρpq:\n")
print_numerical_method(sat_adjust_method)
print_T_guess(sat_adjust_method, T_guess)
KA.@print(", ρ=", ρ)
KA.@print(", p=", p)
KA.@print(", q_tot=", q_tot)
KA.@print(", T=", sol.root)
KA.@print(", maxiter=", maxiter)
KA.@print(", tol=", tol.tol, "\n")
Utils.print_(", ρ=", ρ)
Utils.print_(", p=", p)
Utils.print_(", q_tot=", q_tot)
Utils.print_(", T=", sol.root)
Utils.print_(", maxiter=", maxiter)
Utils.print_(", tol=", tol.tol, "\n")
end
if error_on_non_convergence()
error("Failed to converge with printed set of inputs.")
Expand Down Expand Up @@ -2066,15 +2066,15 @@ See also [`saturation_adjustment`](@ref).
)
if !sol.converged
if print_warning()
KA.@print("-----------------------------------------\n")
KA.@print("maxiter reached in saturation_adjustment_given_ρθq:\n")
KA.@print(" Method=SecantMethod")
KA.@print(", ρ=", ρ)
KA.@print(", θ_liq_ice=", θ_liq_ice)
KA.@print(", q_tot=", q_tot)
KA.@print(", T=", sol.root)
KA.@print(", maxiter=", maxiter)
KA.@print(", tol=", tol.tol, "\n")
Utils.print_("-----------------------------------------\n")
Utils.print_("maxiter reached in saturation_adjustment_given_ρθq:\n")
Utils.print_(" Method=SecantMethod")
Utils.print_(", ρ=", ρ)
Utils.print_(", θ_liq_ice=", θ_liq_ice)
Utils.print_(", q_tot=", q_tot)
Utils.print_(", T=", sol.root)
Utils.print_(", maxiter=", maxiter)
Utils.print_(", tol=", tol.tol, "\n")
end
if error_on_non_convergence()
error("Failed to converge with printed set of inputs.")
Expand Down Expand Up @@ -2181,16 +2181,16 @@ See also [`saturation_adjustment`](@ref).
)
if !sol.converged
if print_warning()
KA.@print("-----------------------------------------\n")
KA.@print("maxiter reached in saturation_adjustment_given_pθq:\n")
Utils.print_("-----------------------------------------\n")
Utils.print_("maxiter reached in saturation_adjustment_given_pθq:\n")
print_numerical_method(sat_adjust_method)
print_T_guess(sat_adjust_method, T_guess)
KA.@print(", p=", p)
KA.@print(", θ_liq_ice=", θ_liq_ice)
KA.@print(", q_tot=", q_tot)
KA.@print(", T=", sol.root)
KA.@print(", maxiter=", maxiter)
KA.@print(", tol=", tol.tol, "\n")
Utils.print_(", p=", p)
Utils.print_(", θ_liq_ice=", θ_liq_ice)
Utils.print_(", q_tot=", q_tot)
Utils.print_(", T=", sol.root)
Utils.print_(", maxiter=", maxiter)
Utils.print_(", tol=", tol.tol, "\n")
end
if error_on_non_convergence()
error("Failed to converge with printed set of inputs.")
Expand Down Expand Up @@ -2382,17 +2382,17 @@ The air temperature and `q_tot` where
)
if !sol.converged
if print_warning()
KA.@print("-----------------------------------------\n")
KA.@print(
Utils.print_("-----------------------------------------\n")
Utils.print_(
"maxiter reached in temperature_and_humidity_given_TᵥρRH:\n"
)
KA.@print(" Method=SecantMethod")
KA.@print(", T_virt=", T_virt)
KA.@print(", RH=", RH)
KA.@print(", ρ=", ρ)
KA.@print(", T=", sol.root)
KA.@print(", maxiter=", maxiter)
KA.@print(", tol=", tol.tol, "\n")
Utils.print_(" Method=SecantMethod")
Utils.print_(", T_virt=", T_virt)
Utils.print_(", RH=", RH)
Utils.print_(", ρ=", ρ)
Utils.print_(", T=", sol.root)
Utils.print_(", maxiter=", maxiter)
Utils.print_(", tol=", tol.tol, "\n")
end
if error_on_non_convergence()
error("Failed to converge with printed set of inputs.")
Expand Down Expand Up @@ -2483,19 +2483,19 @@ by finding the root of
)
if !sol.converged
if print_warning()
KA.@print("-----------------------------------------\n")
KA.@print(
Utils.print_("-----------------------------------------\n")
Utils.print_(
"maxiter reached in air_temperature_given_ρθq_nonlinear:\n"
)
KA.@print(" Method=SecantMethod")
KA.@print(", θ_liq_ice=", θ_liq_ice)
KA.@print(", ρ=", ρ)
KA.@print(", q.tot=", q.tot)
KA.@print(", q.liq=", q.liq)
KA.@print(", q.ice=", q.ice)
KA.@print(", T=", sol.root)
KA.@print(", maxiter=", maxiter)
KA.@print(", tol=", tol.tol, "\n")
Utils.print_(" Method=SecantMethod")
Utils.print_(", θ_liq_ice=", θ_liq_ice)
Utils.print_(", ρ=", ρ)
Utils.print_(", q.tot=", q.tot)
Utils.print_(", q.liq=", q.liq)
Utils.print_(", q.ice=", q.ice)
Utils.print_(", T=", sol.root)
Utils.print_(", maxiter=", maxiter)
Utils.print_(", tol=", tol.tol, "\n")
end
if error_on_non_convergence()
error("Failed to converge with printed set of inputs.")
Expand Down

0 comments on commit be8fb76

Please sign in to comment.