Skip to content

Commit

Permalink
Style fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Jun 25, 2024
1 parent e3124ce commit 6ed3a62
Showing 1 changed file with 28 additions and 21 deletions.
49 changes: 28 additions & 21 deletions src/host/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,54 +47,61 @@ end
@inline function _copyto!(dest::AbstractArray, bc::Broadcasted)
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
isempty(dest) && return dest

# to help Enzyme.jl, we won't pass the broadcasted object directly
# but instead pass its arguments and reconstruct the object device-side
bc = Broadcast.preprocess(dest, bc)
bcstyle = @static if VERSION >= v"1.10-"
bc.style
else
typeof(BroadcastStyle(typeof(bc)))
end

broadcast_kernel = if ndims(dest) == 1 ||
(isa(IndexStyle(dest), IndexLinear) &&
isa(IndexStyle(bc), IndexLinear))
function (ctx, dest, bcstyle, bcf, bcaxes, nelem, bcargs...)
@static if VERSION >= v"1.10-"
bc2 = Base.Broadcast.Broadcasted(bcstyle, bcf, bcargs, bcaxes)
else
bc2 = Base.Broadcast.Broadcasted{bcstyle}(bcf, bcargs, bcaxes)
end
function (ctx, dest, nelem, bcstyle, bcf, bcaxes, bcargs...)
bc = @static if VERSION >= v"1.10-"
Broadcasted(bcstyle, bcf, bcargs, bcaxes)
else
Broadcasted{bcstyle}(bcf, bcargs, bcaxes)
end

i = 1
while i <= nelem
I = @linearidx(dest, i)
@inbounds dest[I] = bc2[I]
@inbounds dest[I] = bc[I]
i += 1
end
return
end
else
function (ctx, dest, bcstyle, bcf, bcaxes, nelem, bcargs...)
@static if VERSION >= v"1.10-"
bc2 = Base.Broadcast.Broadcasted(bcstyle, bcf, bcargs, bcaxes)
else
bc2 = Base.Broadcast.Broadcasted{bcstyle}(bcf, bcargs, bcaxes)
end
function (ctx, dest, nelem, bcstyle, bcf, bcaxes, bcargs...)
bc = @static if VERSION >= v"1.10-"
Broadcasted(bcstyle, bcf, bcargs, bcaxes)
else
Broadcasted{bcstyle}(bcf, bcargs, bcaxes)
end

i = 0
while i < nelem
i += 1
I = @cartesianidx(dest, i)
@inbounds dest[I] = bc2[I]
@inbounds dest[I] = bc[I]
end
return
end
end

elements = length(dest)
elements_per_thread = typemax(Int)
@static if VERSION >= v"1.10-"
style = bc.style
else
style = typeof(Base.Broadcast.BroadcastStyle(typeof(bc)))
end
heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, style, bc.f, bc.axes, 1, bc.args...;
heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, 1,
bcstyle, bc.f, bc.axes, bc.args...;
elements, elements_per_thread)
config = launch_configuration(backend(dest), heuristic;
elements, elements_per_thread)
gpu_call(broadcast_kernel, dest, style, bc.f, bc.axes, config.elements_per_thread, bc.args...;
gpu_call(broadcast_kernel, dest, config.elements_per_thread::Int,
bcstyle, bc.f, bc.axes, bc.args...;
threads=config.threads, blocks=config.blocks)

if eltype(dest) <: BrokenBroadcast
Expand Down

0 comments on commit 6ed3a62

Please sign in to comment.