Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use LazyBroadcast for strain_rate methods #3575

Merged
merged 1 commit into from
Feb 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/cache/diagnostic_edmf_precomputed_quantities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -943,7 +943,8 @@ NVTX.@annotate function set_diagnostic_edmf_precomputed_quantities_env_closures!
ᶠu⁰ = p.scratch.ᶠtemp_C123
@. ᶠu⁰ = C123(ᶠinterp(Y.c.uₕ)) + C123(ᶠu³⁰)
ᶜstrain_rate = p.scratch.ᶜtemp_UVWxUVW
compute_strain_rate_center!(ᶜstrain_rate, ᶠu⁰)
bc_strain_rate = compute_strain_rate_center(ᶠu⁰)
@. ᶜstrain_rate = bc_strain_rate
@. ᶜstrain_rate_norm = norm_sqr(ᶜstrain_rate)

ᶜprandtl_nvec = p.scratch.ᶜtemp_scalar
Expand Down
3 changes: 2 additions & 1 deletion src/cache/prognostic_edmf_precomputed_quantities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,8 @@ NVTX.@annotate function set_prognostic_edmf_precomputed_quantities_closures!(
ᶠu⁰ = p.scratch.ᶠtemp_C123
@. ᶠu⁰ = C123(ᶠinterp(Y.c.uₕ)) + C123(ᶠu³⁰)
ᶜstrain_rate = p.scratch.ᶜtemp_UVWxUVW
compute_strain_rate_center!(ᶜstrain_rate, ᶠu⁰)
bc_strain_rate = compute_strain_rate_center(ᶠu⁰)
@. ᶜstrain_rate = bc_strain_rate
@. ᶜstrain_rate_norm = norm_sqr(ᶜstrain_rate)

ᶜprandtl_nvec = p.scratch.ᶜtemp_scalar
Expand Down
6 changes: 4 additions & 2 deletions src/prognostic_equations/edmfx_sgs_flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ function edmfx_sgs_diffusive_flux_tendency!(

# momentum
ᶠstrain_rate = p.scratch.ᶠtemp_UVWxUVW
compute_strain_rate_face!(ᶠstrain_rate, ᶜu⁰)
bc_strain_rate = compute_strain_rate_face(ᶜu⁰)
@. ᶠstrain_rate = bc_strain_rate
@. Yₜ.c.uₕ -= C12(ᶜdivᵥ(-(2 * ᶠρaK_u * ᶠstrain_rate)) / Y.c.ρ)
# apply boundary condition for momentum flux
ᶜdivᵥ_uₕ = Operators.DivergenceF2C(
Expand Down Expand Up @@ -292,7 +293,8 @@ function edmfx_sgs_diffusive_flux_tendency!(

# momentum
ᶠstrain_rate = p.scratch.ᶠtemp_UVWxUVW
compute_strain_rate_face!(ᶠstrain_rate, ᶜu)
bc_strain_rate = compute_strain_rate_face(ᶜu)
@. ᶠstrain_rate = bc_strain_rate
@. Yₜ.c.uₕ -= C12(ᶜdivᵥ(-(2 * ᶠρaK_u * ᶠstrain_rate)) / Y.c.ρ)
# apply boundary condition for momentum flux
ᶜdivᵥ_uₕ = Operators.DivergenceF2C(
Expand Down
3 changes: 2 additions & 1 deletion src/prognostic_equations/gm_sgs_closures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ NVTX.@annotate function compute_gm_mixing_length!(ᶜmixing_length, Y, p)
ᶠu = p.scratch.ᶠtemp_C123
@. ᶠu = C123(ᶠinterp(Y.c.uₕ)) + C123(ᶠu³)
ᶜstrain_rate = p.scratch.ᶜtemp_UVWxUVW
compute_strain_rate_center!(ᶜstrain_rate, ᶠu)
bc_strain_rate = compute_strain_rate_center(ᶠu)
@. ᶜstrain_rate = bc_strain_rate
@. ᶜstrain_rate_norm = norm_sqr(ᶜstrain_rate)

ᶜprandtl_nvec = p.scratch.ᶜtemp_scalar_2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ function vertical_diffusion_boundary_layer_tendency!(

if diffuse_momentum(p.atmos.vert_diff)
ᶠstrain_rate = p.scratch.ᶠtemp_UVWxUVW
compute_strain_rate_face!(ᶠstrain_rate, ᶜu)
bc_strain_rate = compute_strain_rate_face(ᶜu)
@. ᶠstrain_rate = bc_strain_rate
@. Yₜ.c.uₕ -= C12(
ᶜdivᵥ(-2 * ᶠinterp(Y.c.ρ) * ᶠinterp(ᶜK_u) * ᶠstrain_rate) / Y.c.ρ,
)
Expand Down
34 changes: 16 additions & 18 deletions src/utils/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,28 +78,27 @@ state.
compute_kinetic(Y::Fields.FieldVector) = compute_kinetic(Y.c.uₕ, Y.f.u₃)

"""
compute_strain_rate_center!(ϵ::Field, u::Field)
bc_ϵ = compute_strain_rate_center(u::Field)
@. ϵ = bc_ϵ

Compute the strain_rate at cell centers, storing in `ϵ` from
velocity at cell faces.
Compute the strain_rate at cell centers from velocity at cell faces.
"""
function compute_strain_rate_center!(ϵ::Fields.Field, u::Fields.Field)
function compute_strain_rate_center(u::Fields.Field)
@assert eltype(u) <: C123
axis_uvw = Geometry.UVWAxis()
@. ϵ =
(
Geometry.project((axis_uvw,), ᶜgradᵥ(UVW(u))) +
adjoint(Geometry.project((axis_uvw,), ᶜgradᵥ(UVW(u))))
) / 2
return @lazy @. (
Geometry.project((axis_uvw,), ᶜgradᵥ(UVW(u))) +
adjoint(Geometry.project((axis_uvw,), ᶜgradᵥ(UVW(u))))
) / 2
end

"""
compute_strain_rate_face!(ϵ::Field, u::Field)
bc_ϵ = compute_strain_rate_face(u::Field)
@. ϵ = bc_ϵ

Compute the strain_rate at cell faces, storing in `ϵ` from
velocity at cell centers.
Compute the strain_rate at cell faces from velocity at cell centers.
"""
function compute_strain_rate_face!(ϵ::Fields.Field, u::Fields.Field)
function compute_strain_rate_face(u::Fields.Field)
@assert eltype(u) <: C123
∇ᵥuvw_boundary =
Geometry.outer(Geometry.WVector(0), Geometry.UVWVector(0, 0, 0))
Expand All @@ -108,11 +107,10 @@ function compute_strain_rate_face!(ϵ::Fields.Field, u::Fields.Field)
top = Operators.SetGradient(∇ᵥuvw_boundary),
)
axis_uvw = Geometry.UVWAxis()
@. ϵ =
(
Geometry.project((axis_uvw,), ᶠgradᵥ(UVW(u))) +
adjoint(Geometry.project((axis_uvw,), ᶠgradᵥ(UVW(u))))
) / 2
return @lazy @. (
Geometry.project((axis_uvw,), ᶠgradᵥ(UVW(u))) +
adjoint(Geometry.project((axis_uvw,), ᶠgradᵥ(UVW(u))))
) / 2
end

"""
Expand Down
8 changes: 6 additions & 2 deletions test/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,12 @@ end
ᶠu = @. UVW(Geometry.UVector(ᶠu)) +
UVW(Geometry.VVector(ᶠv)) +
UVW(Geometry.WVector(ᶠw))
CA.compute_strain_rate_center!(ᶜϵ, Geometry.Covariant123Vector.(ᶠu))
CA.compute_strain_rate_face!(ᶠϵ, Geometry.Covariant123Vector.(ᶜu))
bc_strain_rate =
CA.compute_strain_rate_center(Geometry.Covariant123Vector.(ᶠu))
@. ᶜϵ = bc_strain_rate
bc_strain_rate =
CA.compute_strain_rate_face(Geometry.Covariant123Vector.(ᶜu))
@. ᶠϵ = bc_strain_rate

# Center valued strain rate
@test ᶜϵ.components.data.:1 == ᶜϵ.components.data.:1 .* FT(0)
Expand Down
Loading