Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Despecialize make_tracer #540

Merged
merged 16 commits into from
Jan 16, 2025
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Reactant"
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>"]
version = "0.2.18"
version = "0.2.19"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
25 changes: 14 additions & 11 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -722,37 +722,40 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction(
return Core.Typeof(res)(f, res.entry)
end

function Reactant.traced_type(
::Type{A}, seen::ST, ::Val{mode}, track_numbers
) where {A<:CuTracedArray,ST,mode}
Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(A::Type{<:CuTracedArray}), seen, mode::TraceMode, @nospecialize(track_numbers::Type)
)
return A
end

function Reactant.traced_type(
::Type{A}, seen::ST, ::Val{mode}, track_numbers
) where {T,N,A<:CUDA.CuArray{T,N},ST,mode}
Base.@nospecializeinfer function Reactant.traced_type_inner(
@nospecialize(A::Type{<:CUDA.CuArray}), seen, mode::TraceMode, @nospecialize(track_numbers::Type)
)
T = eltype(A)
N = ndims(A)
if mode == Reactant.ArrayToConcrete && T <: Reactant.ReactantPrimitive
return Reactant.ConcreteRArray{T,N}
else
TT = Reactant.traced_type(T, seen, Val(mode), track_numbers)
TT = Reactant.traced_type_inner(T, seen, mode, track_numbers)
if TT === T
return A
else
return Array{traced_type(T, seen, Val(mode), track_numbers),N}
return Array{Reactant.traced_type_inner(T, seen, mode, track_numbers),N}
end
end
end

function Reactant.make_tracer(
seen, @nospecialize(prev::RT), @nospecialize(path), mode; track_numbers=(), kwargs...
) where {RT<:CUDA.CuArray}
seen, @nospecialize(prev::CUDA.CuArray), @nospecialize(path), mode; track_numbers=(), kwargs...
)
RT = Core.Typeof(prev)
if haskey(seen, prev)
return seen[prev]
end
if mode == Reactant.ArrayToConcrete && eltype(RT) <: Reactant.ReactantPrimitive
return seen[prev] = Reactant.ConcreteRArray(Array(prev))
end
TT = Reactant.traced_type(eltype(RT), (), Val(mode), track_numbers)
TT = Reactant.traced_type(eltype(RT), Val(mode), track_numbers)
if TT === eltype(RT)
return prev
end
Expand Down
9 changes: 6 additions & 3 deletions ext/ReactantOffsetArraysExt.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
module ReactantOffsetArraysExt

using OffsetArrays
using OffsetArrays: OffsetArray
using Reactant: Reactant, MLIR, Ops, TracedRArray

function Reactant.traced_type(
::Type{<:OffsetArray{<:Any,N,T}}, seen::ST, ::Val{mode}, track_numbers
) where {T,N,ST,mode}
Base.@nospecializeinfer function Reactant.traced_type(
@nospecialize(OA::Type{<:OffsetArray}), seen::ST, ::Val{mode}, track_numbers
) where {ST,mode}
N = ndims(OA)
T = OffsetArrays.parenttype(OA)
T2 = Reactant.traced_type(T, seen, Val(mode), track_numbers)
return OffsetArray{eltype(T2),N,T2}
end
Expand Down
13 changes: 11 additions & 2 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,23 @@ else
const ReactantFloat = Union{Float16,Float32,Float64}
end

const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64}
@static if isdefined(Core, :BFloat16)
const ReactantComplexFloat = Union{Complex{Float16},Complex{Core.BFloat16},Complex{Float32},Complex{Float64}}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
const ReactantComplexFloat = Union{Complex{Float16},Complex{Core.BFloat16},Complex{Float32},Complex{Float64}}
const ReactantComplexFloat = Union{
Complex{Float16},Complex{Core.BFloat16},Complex{Float32},Complex{Float64}
}

else
const ReactantComplexFloat = Union{Complex{Float16},Complex{Float32},Complex{Float64}}
end

const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128}

const ReactantComplexInt = Union{Complex{Int8},Complex{UInt8},Complex{Int16},Complex{UInt16},Complex{Int32},Complex{UInt32},Complex{Int64},Complex{UInt64},Complex{Int128},Complex{UInt128}}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
const ReactantComplexInt = Union{Complex{Int8},Complex{UInt8},Complex{Int16},Complex{UInt16},Complex{Int32},Complex{UInt32},Complex{Int64},Complex{UInt64},Complex{Int128},Complex{UInt128}}
const ReactantComplexInt = Union{
Complex{Int8},
Complex{UInt8},
Complex{Int16},
Complex{UInt16},
Complex{Int32},
Complex{UInt32},
Complex{Int64},
Complex{UInt64},
Complex{Int128},
Complex{UInt128},
}


const ReactantFloatInt = Union{
Base.uniontypes(ReactantInt)...,Base.uniontypes(ReactantFloat)...
}

const ReactantPrimitive = Union{
Bool,Base.uniontypes(ReactantFloatInt)...,Complex{Float32},Complex{Float64}
Bool,Base.uniontypes(ReactantFloatInt)...,
Base.uniontypes(ReactantComplexInt)...,Base.uniontypes(ReactantComplexFloat)...
Comment on lines +41 to +42
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
Bool,Base.uniontypes(ReactantFloatInt)...,
Base.uniontypes(ReactantComplexInt)...,Base.uniontypes(ReactantComplexFloat)...
Bool,
Base.uniontypes(ReactantFloatInt)...,
Base.uniontypes(ReactantComplexInt)...,
Base.uniontypes(ReactantComplexFloat)...,

}

abstract type RNumber{T<:ReactantPrimitive} <: Number end
Expand Down
4 changes: 2 additions & 2 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ function make_mlir_fn(
(:args, i),
concretein ? Reactant.ConcreteToTraced : Reactant.TracedSetPath;
toscalar,
track_numbers=construct_function_without_args ? (Number,) : (),
track_numbers=construct_function_without_args ? Number : Union{},
)
end

Expand Down Expand Up @@ -201,7 +201,7 @@ function make_mlir_fn(
result,
(:result,),
concretein ? Reactant.TracedTrack : Reactant.TracedSetPath;
track_numbers=construct_function_without_args ? (Number,) : (),
track_numbers=construct_function_without_args ? Number : Union{},
)

# marks buffers to be donated
Expand Down
Loading
Loading