Skip to content

Commit

Permalink
Extend Remapper to FiniteDifferenceSpaces, update interpolate interface
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski authored and Sbozzolo committed Oct 28, 2024
1 parent 7297f5d commit 1504d19
Show file tree
Hide file tree
Showing 4 changed files with 285 additions and 34 deletions.
34 changes: 34 additions & 0 deletions ext/cuda/remapping_distributed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,40 @@ function set_interpolated_values_kernel!(
return nothing
end

# GPU, vertical case
#
# Device kernel for purely vertical interpolation. For each target level `j` and
# each field `k`, writes
#     out[j, k] = A * field_values[k][v_lo] + B * field_values[k][v_hi]
# where `(A, B) = vert_interpolation_weights[j]` and
# `(v_lo, v_hi) = vert_bounding_indices[j]`.
#
# Arguments:
# - `out`: destination array, indexed as `out[level, field]`.
# - the two `::Nothing` arguments: placeholders for the horizontal weights and
#   indices, which do not exist for a vertical-only space.
# - `vert_interpolation_weights`, `vert_bounding_indices`: per-target-level pairs.
# - `field_values`: indexable collection with one entry per field; each entry is
#   read with a 5-index `CartesianIndex` whose 4th index is the vertical level.
#   NOTE(review): presumably a tuple of field data layouts prepared by the
#   caller -- confirm against `_set_interpolated_values_device!`.
function set_interpolated_values_kernel!(
    out::AbstractArray,
    ::Nothing,
    ::Nothing,
    vert_interpolation_weights,
    vert_bounding_indices,
    field_values,
)
    # TODO: Check the memory access pattern. This was not optimized and likely inefficient!
    num_fields = length(field_values)
    num_vert = length(vert_bounding_indices)

    # The y dimension of the launch configuration spans target levels,
    # the z dimension spans fields.
    vindex = (blockIdx().y - Int32(1)) * blockDim().y + threadIdx().y
    findex = (blockIdx().z - Int32(1)) * blockDim().z + threadIdx().z

    totalThreadsY = gridDim().y * blockDim().y
    totalThreadsZ = gridDim().z * blockDim().z

    CI = CartesianIndex
    # Each thread starts at its own (vindex, findex) and strides by the total
    # number of threads, so any launch configuration covers all the work.
    for j in vindex:totalThreadsY:num_vert
        v_lo, v_hi = vert_bounding_indices[j]
        A, B = vert_interpolation_weights[j]
        for k in findex:totalThreadsZ:num_fields
            # Defensive bounds guard; the loop ranges already enforce it.
            if j <= num_vert && k <= num_fields
                out[j, k] = (
                    A * field_values[k][CI(1, 1, 1, v_lo, 1)] +
                    B * field_values[k][CI(1, 1, 1, v_hi, 1)]
                )
            end
        end
    end
    return nothing
end

function _set_interpolated_values_device!(
out::AbstractArray,
fields::AbstractArray{<:Fields.Field},
Expand Down
232 changes: 201 additions & 31 deletions src/Remapping/distributed_remapping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@
# process-local interpolate points in the correct shape with respect to the global
# interpolation (and where to collect results)
#
# Horizontal and vertical interpolation can each be switched off, so that we can
# interpolate purely horizontal or purely vertical Fields.
#
# To process multiple Fields at the same time, some of the scratch spaces gain an extra
# dimension (buffer_length). With this extra dimension, we can batch the work and process up
# to buffer_length fields at the same time. This reduces the number of kernel launches and
Expand Down Expand Up @@ -116,18 +119,22 @@ function target_hcoords_pid_bitmask(target_hcoords, topology, pid)
return pid_hcoord.(target_hcoords) .== pid
end


# TODO: define an inner construct and restrict types, as was done in
# https://github.com/CliMA/RRTMGP.jl/pull/352
# to avoid potential compilation issues.
struct Remapper{
CC <: ClimaComms.AbstractCommsContext,
SPACE <: Spaces.AbstractSpace,
T1 <: AbstractArray,
T1, # <: Union{AbstractArray, Nothing},
TARG_Z <: Union{Nothing, AA1} where {AA1 <: AbstractArray},
T3 <: AbstractArray,
T4 <: Tuple,
T5 <: AbstractArray,
T3, # <: Union{AbstractArray, Nothing},
T4, # <: Union{Tuple, Nothing},
T5, # <: Union{AbstractArray, Nothing},
VERT_W <: Union{Nothing, AA2} where {AA2 <: AbstractArray},
VERT_IND <: Union{Nothing, AA3} where {AA3 <: AbstractArray},
T8 <: AbstractArray,
T9 <: AbstractArray,
T8, # <: AbstractArray,
T9, # <: AbstractArray,
T10 <: AbstractArray,
T11 <: Union{Tuple{Colon}, Tuple{Colon, Colon}, Tuple{Colon, Colon, Colon}},
}
Expand Down Expand Up @@ -200,13 +207,16 @@ struct Remapper{
end

"""
Remapper(space, target_hcoords, target_zcoords; buffer_length = 1)
Remapper(space, target_hcoords, target_zcoords, buffer_length = 1)
Remapper(space; target_hcoords, target_zcoords, buffer_length = 1)
Remapper(space, target_hcoords; buffer_length = 1)
Remapper(space, target_zcoords; buffer_length = 1)
Return a `Remapper` responsible for interpolating any `Field` defined on the given `space`
to the Cartesian product of `target_hcoords` with `target_zcoords`.
`target_zcoords` can be `nothing` for interpolation on horizontal spaces.
`target_zcoords` can be `nothing` for interpolation on horizontal spaces. Similarly,
`target_hcoords` can be `nothing` for interpolation on vertical spaces.
The `Remapper` is designed to not be tied to any particular `Field`. You can use the same
`Remapper` for any `Field` as long as they are all defined on the same `topology`.
Expand All @@ -220,13 +230,40 @@ Keyword arguments
for interpolation. Effectively, this controls how many fields can be remapped simultaneously
in `interpolate`. When more fields than `buffer_length` are passed, the remapper will batch
the work in sizes of `buffer_length`.
"""
function Remapper(
function Remapper end

# Keyword-only entry point: both sets of target coordinates are optional keywords
# (defaulting to `nothing`) and are forwarded unchanged to `_Remapper`.
function Remapper(
    space::Spaces.AbstractSpace;
    target_hcoords::Union{AbstractArray, Nothing} = nothing,
    target_zcoords::Union{AbstractArray, Nothing} = nothing,
    buffer_length::Int = 1,
)
    return _Remapper(space; target_zcoords, buffer_length, target_hcoords)
end

# Keyword-only entry point for vertical (finite difference) spaces: only
# `target_zcoords` is accepted, and it is required.
function Remapper(
    space::Spaces.FiniteDifferenceSpace;
    target_zcoords::AbstractArray,
    buffer_length::Int = 1,
)
    return _Remapper(space; target_zcoords, buffer_length)
end

# Positional form: horizontal targets followed by vertical targets (the latter
# may be `nothing`). Forwards everything to `_Remapper` as keywords.
function Remapper(
    space::Spaces.AbstractSpace,
    target_hcoords::AbstractArray,
    target_zcoords::Union{AbstractArray, Nothing};
    buffer_length::Int = 1,
)
    return _Remapper(space; target_zcoords, buffer_length, target_hcoords)
end

# Positional form with horizontal targets only: `target_zcoords` is passed to
# `_Remapper` as `nothing`.
function Remapper(
    space::Spaces.AbstractSpace,
    target_hcoords::AbstractArray;
    buffer_length::Int = 1,
)
    return _Remapper(
        space;
        target_zcoords = nothing,
        target_hcoords,
        buffer_length,
    )
end

function _Remapper(
space::Spaces.AbstractSpace;
target_zcoords::Union{AbstractArray, Nothing},
target_hcoords::AbstractArray,
buffer_length::Int = 1,
)

comms_ctx = ClimaComms.context(space)
Expand Down Expand Up @@ -367,11 +404,44 @@ function Remapper(
)
end

# Build a `Remapper` for a purely vertical space (`FiniteDifferenceSpace`).
#
# Only the vertical interpolation data (weights and bounding indices for
# `target_zcoords`) is computed; every horizontal-related slot of the `Remapper`
# is filled with `nothing`. The two scratch arrays have size
# `(length(target_zcoords), buffer_length)`.
function _Remapper(
    space::Spaces.FiniteDifferenceSpace;
    target_zcoords::AbstractArray,
    buffer_length::Int = 1,
)
    # Purely vertical case
    comms_ctx = ClimaComms.context(space)
    FT = Spaces.undertype(space)
    ArrayType = ClimaComms.array_type(space)

    vert_interpolation_weights =
        ArrayType(vertical_interpolation_weights(space, target_zcoords))
    vert_bounding_indices =
        ArrayType(vertical_bounding_indices(space, target_zcoords))

    # One column per field in the batch, one row per target level.
    local_interpolated_values =
        ArrayType(zeros(FT, (length(target_zcoords), buffer_length)))
    interpolated_values =
        ArrayType(zeros(FT, (length(target_zcoords), buffer_length)))
    # A single colon: the target grid has only the vertical dimension.
    colons = (:,)

    return Remapper(
        comms_ctx,
        space,
        nothing, # local_target_hcoords,
        target_zcoords,
        nothing, # local_target_hcoords_bitmask,
        nothing, # local_horiz_interpolation_weights,
        nothing, # local_horiz_indices,
        vert_interpolation_weights,
        vert_bounding_indices,
        local_interpolated_values,
        nothing, # field_values,
        interpolated_values,
        buffer_length,
        colons,
    )
end

"""
_set_interpolated_values!(remapper, field)
Expand Down Expand Up @@ -439,6 +509,38 @@ function set_interpolated_values_cpu_kernel!(
end
end

# CPU, vertical case
#
# For every field in `fields` and every target level, writes
#     out[vindex, field_index] = A * values[v_lo] + B * values[v_hi]
# where `(A, B) = vert_interpolation_weights[vindex]` and
# `(v_lo, v_hi) = vert_bounding_indices[vindex]`.
#
# The three `::Nothing` arguments are the slots used by other methods of this
# function for horizontal weights, horizontal indices and a scratch buffer;
# none of them is needed here. An empty `fields` is a no-op.
function set_interpolated_values_cpu_kernel!(
    out::AbstractArray,
    fields::AbstractArray{<:Fields.Field},
    ::Nothing,
    ::Nothing,
    vert_interpolation_weights,
    vert_bounding_indices,
    ::Nothing,
)
    for (field_index, field) in enumerate(fields)
        field_values = Fields.field_values(field)

        # Every target level has its own pair of bounding indices, so the two
        # bounding values are read for each level.
        # NOTE(review): `@inbounds` relies on `out`, the weights and the bounding
        # indices all having at least `length(vert_interpolation_weights)` rows.
        @inbounds for (vindex, (A, B)) in enumerate(vert_interpolation_weights)
            (v_lo, v_hi) = vert_bounding_indices[vindex]
            out[vindex, field_index] = (
                A * field_values[CartesianIndex(1, 1, 1, v_lo, 1)] +
                B * field_values[CartesianIndex(1, 1, 1, v_hi, 1)]
            )
        end
    end
    return nothing
end

# CPU, 2D case
function set_interpolated_values_cpu_kernel!(
out::AbstractArray,
Expand Down Expand Up @@ -757,6 +859,8 @@ function interpolate(remapper::Remapper, fields)
error("Field is defined on a different space than remapper")
end

isa_vertical_space = remapper.space isa Spaces.FiniteDifferenceSpace

index_field_begin, index_field_end =
1, min(length(fields), remapper.buffer_length)

Expand All @@ -777,13 +881,15 @@ function interpolate(remapper::Remapper, fields)
remapper,
view(fields, index_field_begin:index_field_end),
)
# Reshape the output so that it is a nice grid.
_apply_mpi_bitmask!(remapper, num_fields)

if !isa_vertical_space
# For spaces with an horizontal component, reshape the output so that it is a nice grid.
_apply_mpi_bitmask!(remapper, num_fields)
end

# Finally, we have to send all the _interpolated_values to root and sum them up to
# obtain the final answer. Only the root will contain something useful. This also
# moves the data off the GPU
ret = _collect_and_return_interpolated_values!(remapper, num_fields)
return ret
# obtain the final answer. Only the root will contain something useful.
return _collect_and_return_interpolated_values!(remapper, num_fields)
end

# Non-root processes
Expand All @@ -794,7 +900,12 @@ function interpolate(remapper::Remapper, fields)
end

"""
interpolate(field::ClimaCore.Fields, target_hcoords, target_zcoords)
interpolate(field::ClimaCore.Fields;
hresolution = 100,
vresolution = 50,
target_hcoords = get_target_hcoords(space; hresolution),
target_zcoords = get_target_zcoords(space; vresolution)
)
Interpolate the given fields on the Cartesian product of `target_hcoords` with
`target_zcoords` (if not empty).
Expand All @@ -810,22 +921,77 @@ Example
Given `field`, a `Field` defined on a cubed sphere.
By default, a target uniform grid is chosen (with resolution `hresolution` and
`vresolution`), so remapping is simply
```julia
longpts = range(-180.0, 180.0, 21)
latpts = range(-80.0, 80.0, 21)
zpts = range(0.0, 1000.0, 21)
julia> interpolate(field, hcoords, zcoords)
```
hcoords = [Geometry.LatLongPoint(lat, long) for long in longpts, lat in latpts]
zcoords = [Geometry.ZPoint(z) for z in zpts]
Coordinates can be specified:
```julia
julia> longpts = range(-180.0, 180.0, 21)
julia> latpts = range(-80.0, 80.0, 21)
julia> zpts = range(0.0, 1000.0, 21)
interpolate(field, hcoords, zcoords)
julia> hcoords = [Geometry.LatLongPoint(lat, long) for long in longpts, lat in latpts]
julia> zcoords = [Geometry.ZPoint(z) for z in zpts]
julia> interpolate(field, hcoords, zcoords)
```
"""
function interpolate(
    field::Fields.Field;
    vresolution = 50,
    hresolution = 100,
    target_hcoords = get_target_hcoords(axes(field); hresolution),
    target_zcoords = get_target_zcoords(axes(field); vresolution),
)
    # `hresolution` and `vresolution` only feed the default target coordinates
    # above. Explicitly passed `target_hcoords` / `target_zcoords` take
    # precedence and must be forwarded, otherwise they are silently ignored.
    return interpolate(field, target_hcoords, target_zcoords)
end

# Convenience method: build a throw-away `Remapper` on the space of `field` for
# the given target coordinates and use it to interpolate `field`.
function interpolate(field::Fields.Field, target_hcoords, target_zcoords)
    return interpolate(
        Remapper(axes(field), target_hcoords, target_zcoords),
        field,
    )
end

# Generic fallback: take the horizontal part of `space` and defer to the method
# of `get_target_hcoords` defined for it. `hresolution` is required here.
function get_target_hcoords(space::Spaces.AbstractSpace; hresolution)
    hspace = Spaces.horizontal_space(space)
    return get_target_hcoords(hspace; hresolution)
end

# Default horizontal targets for a 2D spectral element space: `hresolution`
# evenly spaced points between the domain bounds along each of the two
# intervals, combined as a Cartesian product (`hresolution^2` pairs).
#
# NOTE(review): the result is a lazy product iterator yielding `(p1, p2)`
# tuples of 1D points, not an `AbstractArray` of 2D points. The `Remapper`
# methods in this file annotate `target_hcoords::AbstractArray`, so callers may
# need to materialize/convert the result -- confirm the intended element type.
function get_target_hcoords(
    space::Spaces.SpectralElementSpace2D;
    hresolution = 180,
)
    topology = Spaces.topology(space)
    mesh = topology.mesh
    domain = Meshes.domain(mesh)
    # Point types are taken from the domain bounds so the targets use the same
    # coordinate types as the domain.
    PT1 = typeof(domain.interval1.coord_min)
    PT2 = typeof(domain.interval2.coord_min)
    x1min = Geometry.component(domain.interval1.coord_min, 1)
    x2min = Geometry.component(domain.interval2.coord_min, 1)
    x1max = Geometry.component(domain.interval1.coord_max, 1)
    x2max = Geometry.component(domain.interval2.coord_max, 1)
    x1 = map(PT1, range(x1min, x1max; length = hresolution))
    x2 = map(PT2, range(x2min, x2max; length = hresolution))
    # Pass the two collections as separate arguments: `product((x1, x2))` would
    # iterate over the 2-tuple itself instead of forming the Cartesian product.
    return Base.Iterators.product(x1, x2)
end

# Default horizontal targets for a 1D spectral element space: `hresolution`
# evenly spaced points between the lower and upper bound of the domain, with
# the same point type as the lower bound.
#
# NOTE(review): this reads the bounds through `domain.interval1`; confirm that
# the domain of a 1D mesh exposes that field.
function get_target_hcoords(space::Spaces.SpectralElementSpace1D; hresolution = 180)
    topology = Spaces.topology(space)
    mesh = topology.mesh
    domain = Meshes.domain(mesh)
    PointType = typeof(domain.interval1.coord_min)
    xmin = Geometry.component(domain.interval1.coord_min, 1)
    xmax = Geometry.component(domain.interval1.coord_max, 1)
    return PointType.(range(xmin, xmax; length = hresolution))
end

# Default vertical targets: `vresolution` evenly spaced `ZPoint`s between
# `z_min(space)` and `z_max(space)`.
function get_target_zcoords(space; vresolution = 50)
    z_levels = range(z_min(space), z_max(space); length = vresolution)
    return Geometry.ZPoint.(z_levels)
end

# dest has to be allowed to be nothing because interpolation happens only on the root
# process
function interpolate!(
Expand All @@ -837,6 +1003,7 @@ function interpolate!(
if only_one_field
fields = [fields]
end
isa_vertical_space = remapper.space isa Spaces.FiniteDifferenceSpace

if !isnothing(dest)
# !isnothing(dest) means that this is the root process, in this case, the size have
Expand Down Expand Up @@ -869,11 +1036,14 @@ function interpolate!(
remapper,
view(fields, index_field_begin:index_field_end),
)
# Reshape the output so that it is a nice grid.
_apply_mpi_bitmask!(remapper, num_fields)

if !isa_vertical_space
# For spaces with an horizontal component, reshape the output so that it is a nice grid.
_apply_mpi_bitmask!(remapper, num_fields)
end

# Finally, we have to send all the _interpolated_values to root and sum them up to
# obtain the final answer. Only the root will contain something useful. This also
# moves the data off the GPU
# obtain the final answer.
_collect_interpolated_values!(
dest,
remapper,
Expand Down
Loading

0 comments on commit 1504d19

Please sign in to comment.