From 6ed3a626e3c0f095aa228685fa6f1c14d12aef96 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 25 Jun 2024 12:40:03 +0200 Subject: [PATCH] Style fixes. --- src/host/broadcast.jl | 49 ++++++++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/src/host/broadcast.jl b/src/host/broadcast.jl index fbf167bb..f3d841ca 100644 --- a/src/host/broadcast.jl +++ b/src/host/broadcast.jl @@ -47,37 +47,47 @@ end @inline function _copyto!(dest::AbstractArray, bc::Broadcasted) axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) isempty(dest) && return dest + + # to help Enzyme.jl, we won't pass the broadcasted object directly + # but instead pass its arguments and reconstruct the object device-side bc = Broadcast.preprocess(dest, bc) + bcstyle = @static if VERSION >= v"1.10-" + bc.style + else + typeof(BroadcastStyle(typeof(bc))) + end broadcast_kernel = if ndims(dest) == 1 || (isa(IndexStyle(dest), IndexLinear) && isa(IndexStyle(bc), IndexLinear)) - function (ctx, dest, bcstyle, bcf, bcaxes, nelem, bcargs...) -@static if VERSION >= v"1.10-" - bc2 = Base.Broadcast.Broadcasted(bcstyle, bcf, bcargs, bcaxes) -else - bc2 = Base.Broadcast.Broadcasted{bcstyle}(bcf, bcargs, bcaxes) -end + function (ctx, dest, nelem, bcstyle, bcf, bcaxes, bcargs...) + bc = @static if VERSION >= v"1.10-" + Broadcasted(bcstyle, bcf, bcargs, bcaxes) + else + Broadcasted{bcstyle}(bcf, bcargs, bcaxes) + end + i = 1 while i <= nelem I = @linearidx(dest, i) - @inbounds dest[I] = bc2[I] + @inbounds dest[I] = bc[I] i += 1 end return end else - function (ctx, dest, bcstyle, bcf, bcaxes, nelem, bcargs...) -@static if VERSION >= v"1.10-" - bc2 = Base.Broadcast.Broadcasted(bcstyle, bcf, bcargs, bcaxes) -else - bc2 = Base.Broadcast.Broadcasted{bcstyle}(bcf, bcargs, bcaxes) -end + function (ctx, dest, nelem, bcstyle, bcf, bcaxes, bcargs...) + bc = @static if VERSION >= v"1.10-" + Broadcasted(bcstyle, bcf, bcargs, bcaxes) + else + Broadcasted{bcstyle}(bcf, bcargs, bcaxes) + end + i = 0 while i < nelem i += 1 I = @cartesianidx(dest, i) - @inbounds dest[I] = bc2[I] + @inbounds dest[I] = bc[I] end return end @@ -85,16 +95,13 @@ end elements = length(dest) elements_per_thread = typemax(Int) -@static if VERSION >= v"1.10-" - style = bc.style -else - style = typeof(Base.Broadcast.BroadcastStyle(typeof(bc))) -end - heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, style, bc.f, bc.axes, 1, bc.args...; + heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, 1, + bcstyle, bc.f, bc.axes, bc.args...; elements, elements_per_thread) config = launch_configuration(backend(dest), heuristic; elements, elements_per_thread) - gpu_call(broadcast_kernel, dest, style, bc.f, bc.axes, config.elements_per_thread, bc.args...; + gpu_call(broadcast_kernel, dest, config.elements_per_thread::Int, + bcstyle, bc.f, bc.axes, bc.args...; threads=config.threads, blocks=config.blocks) if eltype(dest) <: BrokenBroadcast