-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Despecialize make_tracer #540
Changes from all commits
0adac3c
86c583c
9639190
6060012
81a1ace
7f1d2a5
5742e62
c61d2e1
ed7d216
0871a08
5d19fe8
020ef87
e2238ca
deafe73
4d1bd42
67056b8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "Reactant" | ||
uuid = "3c362404-f566-11ee-1572-e11a4b42c853" | ||
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>"] | ||
version = "0.2.18" | ||
version = "0.2.19" | ||
|
||
[deps] | ||
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -722,37 +722,40 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction( | |||||||||||||||
return Core.Typeof(res)(f, res.entry) | ||||||||||||||||
end | ||||||||||||||||
|
||||||||||||||||
function Reactant.traced_type( | ||||||||||||||||
::Type{A}, seen::ST, ::Val{mode}, track_numbers | ||||||||||||||||
) where {A<:CuTracedArray,ST,mode} | ||||||||||||||||
Base.@nospecializeinfer function Reactant.traced_type_inner( | ||||||||||||||||
@nospecialize(A::Type{<:CuTracedArray}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type) | ||||||||||||||||
) | ||||||||||||||||
return A | ||||||||||||||||
end | ||||||||||||||||
|
||||||||||||||||
function Reactant.traced_type( | ||||||||||||||||
::Type{A}, seen::ST, ::Val{mode}, track_numbers | ||||||||||||||||
) where {T,N,A<:CUDA.CuArray{T,N},ST,mode} | ||||||||||||||||
Base.@nospecializeinfer function Reactant.traced_type_inner( | ||||||||||||||||
@nospecialize(A::Type{<:CUDA.CuArray}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type) | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||
) | ||||||||||||||||
T = eltype(A) | ||||||||||||||||
N = ndims(A) | ||||||||||||||||
if mode == Reactant.ArrayToConcrete && T <: Reactant.ReactantPrimitive | ||||||||||||||||
return Reactant.ConcreteRArray{T,N} | ||||||||||||||||
else | ||||||||||||||||
TT = Reactant.traced_type(T, seen, Val(mode), track_numbers) | ||||||||||||||||
TT = Reactant.traced_type_inner(T, seen, mode, track_numbers) | ||||||||||||||||
if TT === T | ||||||||||||||||
return A | ||||||||||||||||
else | ||||||||||||||||
return Array{traced_type(T, seen, Val(mode), track_numbers),N} | ||||||||||||||||
return Array{Reactant.traced_type_inner(T, seen, mode, track_numbers),N} | ||||||||||||||||
end | ||||||||||||||||
end | ||||||||||||||||
end | ||||||||||||||||
|
||||||||||||||||
function Reactant.make_tracer( | ||||||||||||||||
seen, @nospecialize(prev::RT), @nospecialize(path), mode; track_numbers=(), kwargs... | ||||||||||||||||
) where {RT<:CUDA.CuArray} | ||||||||||||||||
seen, @nospecialize(prev::CUDA.CuArray), @nospecialize(path), mode; @nospecialize(track_numbers::Type=Union{}), kwargs... | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||
) | ||||||||||||||||
RT = Core.Typeof(prev) | ||||||||||||||||
if haskey(seen, prev) | ||||||||||||||||
return seen[prev] | ||||||||||||||||
end | ||||||||||||||||
if mode == Reactant.ArrayToConcrete && eltype(RT) <: Reactant.ReactantPrimitive | ||||||||||||||||
return seen[prev] = Reactant.ConcreteRArray(Array(prev)) | ||||||||||||||||
end | ||||||||||||||||
TT = Reactant.traced_type(eltype(RT), (), Val(mode), track_numbers) | ||||||||||||||||
TT = Reactant.traced_type(eltype(RT), Val(mode), track_numbers) | ||||||||||||||||
if TT === eltype(RT) | ||||||||||||||||
return prev | ||||||||||||||||
end | ||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,12 +1,15 @@ | ||||||||||||
module ReactantOffsetArraysExt | ||||||||||||
|
||||||||||||
using OffsetArrays | ||||||||||||
using OffsetArrays: OffsetArray | ||||||||||||
using Reactant: Reactant, MLIR, Ops, TracedRArray | ||||||||||||
|
||||||||||||
function Reactant.traced_type( | ||||||||||||
::Type{<:OffsetArray{<:Any,N,T}}, seen::ST, ::Val{mode}, track_numbers | ||||||||||||
) where {T,N,ST,mode} | ||||||||||||
T2 = Reactant.traced_type(T, seen, Val(mode), track_numbers) | ||||||||||||
Base.@nospecializeinfer function Reactant.traced_type_inner( | ||||||||||||
@nospecialize(OA::Type{<:OffsetArray}), seen, mode::Reactant.TraceMode, @nospecialize(track_numbers::Type=Union{}) | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||
) | ||||||||||||
N = ndims(OA) | ||||||||||||
T = OffsetArrays.parenttype(OA) | ||||||||||||
T2 = Reactant.traced_type_inner(T, seen, mode, track_numbers) | ||||||||||||
return OffsetArray{eltype(T2),N,T2} | ||||||||||||
end | ||||||||||||
|
||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -23,14 +23,23 @@ else | |||||||||||||||||||||||||||
const ReactantFloat = Union{Float16,Float32,Float64} | ||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64} | ||||||||||||||||||||||||||||
@static if isdefined(Core, :BFloat16) | ||||||||||||||||||||||||||||
const ReactantComplexFloat = Union{Complex{Float16},Complex{Core.BFloat16},Complex{Float32},Complex{Float64}} | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||||
const ReactantComplexFloat = Union{Complex{Float16},Complex{Float32},Complex{Float64}} | ||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128} | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
const ReactantComplexInt = Union{Complex{Int8},Complex{UInt8},Complex{Int16},Complex{UInt16},Complex{Int32},Complex{UInt32},Complex{Int64},Complex{UInt64},Complex{Int128},Complex{UInt128}} | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
const ReactantFloatInt = Union{ | ||||||||||||||||||||||||||||
Base.uniontypes(ReactantInt)...,Base.uniontypes(ReactantFloat)... | ||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
const ReactantPrimitive = Union{ | ||||||||||||||||||||||||||||
Bool,Base.uniontypes(ReactantFloatInt)...,Complex{Float32},Complex{Float64} | ||||||||||||||||||||||||||||
Bool,Base.uniontypes(ReactantFloatInt)..., | ||||||||||||||||||||||||||||
Base.uniontypes(ReactantComplexInt)...,Base.uniontypes(ReactantComplexFloat)... | ||||||||||||||||||||||||||||
Comment on lines
+41
to
+42
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
abstract type RNumber{T<:ReactantPrimitive} <: Number end | ||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶