Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 16, 2025
1 parent 5742e62 commit c61d2e1
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 22 deletions.
4 changes: 2 additions & 2 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ function make_mlir_fn(
(:args, i),
concretein ? Reactant.ConcreteToTraced : Reactant.TracedSetPath;
toscalar,
track_numbers=construct_function_without_args ? (Number,) : (),
track_numbers=construct_function_without_args ? Number : Union{},
)
end

Expand Down Expand Up @@ -201,7 +201,7 @@ function make_mlir_fn(
result,
(:result,),
concretein ? Reactant.TracedTrack : Reactant.TracedSetPath;
track_numbers=construct_function_without_args ? (Number,) : (),
track_numbers=construct_function_without_args ? Number : Union{},
)

# marks buffers to be donated
Expand Down
30 changes: 16 additions & 14 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ Base.@nospecializeinfer function traced_type_inner(
end
end

Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Function}), seen, mode::TraceMode, @nospecialize(track_numbers))
Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Function}), seen, mode::TraceMode, @nospecialize(track_numbers::Type))
# functions are directly returned
if sizeof(T) == 0
return T
Expand Down Expand Up @@ -309,24 +309,28 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(VT::Type{<:Val}
throw("Val type $(Val{T}) cannot be traced")
end

const traced_type_cache = Dict{Tuple{TraceMode, Type}, Dict{Type, Type}}
const traced_type_cache = Dict{Tuple{TraceMode, Type}, Dict{Type, Type}}()

function traced_type_generator(world::UInt, source, self, @nospecialize(T::Type), @nospecialize(mode::Val), @nospecialize(track_numbers::Type))
function traced_type_generator(world::UInt, source, self, @nospecialize(T::Type), @nospecialize(mode::Type{<:Val}), @nospecialize(track_numbers::Type))
@nospecialize
T = T.parameters[1]
mode = mode.parameters[1]
mode = mode.parameters[1]::TraceMode
track_numbers = track_numbers.parameters[1]


min_world = Ref{UInt}(typemin(UInt))
max_world = Ref{UInt}(typemax(UInt))

sig = Tuple{typeof(traced_type_inner), Type{T}, Dict{Type, Type}, TraceMode, track_numbers}
sig = Tuple{typeof(traced_type_inner), Type{T}, Dict{Type, Type}, TraceMode, Type{track_numbers}}

lookup_result = lookup_world(
sig, world, nothing, min_world, max_world
)
@assert lookup_result !== nothing
if lookup_result === nothing
@show sig
stub = Core.GeneratedFunctionStub(identity, Core.svec(:traced_type, :T, :mode, :track_numbers), Core.svec())
return stub(world, source, method_error)
end
match = lookup_result::Core.MethodMatch

mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
Expand Down Expand Up @@ -430,7 +434,7 @@ function make_tracer(
mode;
toscalar=false,
tobatch=nothing,
@nospecialize(track_numbers::Type),
@nospecialize(track_numbers::Type=Union{}),
kwargs...,
)
if mode != NoStopTracedTrack && haskey(seen, prev)
Expand Down Expand Up @@ -686,12 +690,10 @@ function make_tracer(
end

function make_tracer(
seen, @nospecialize(prev::Number), @nospecialize(path), mode; @nospecialize(track_numbers::Type), kwargs...
seen, @nospecialize(prev::Number), @nospecialize(path), mode; @nospecialize(track_numbers::Type=Union{}), kwargs...
)
RT = Core.Typeof(prev)
length(track_numbers) == 0 && return prev
should_convert = RT <: track_numbers
if should_convert
if RT <: track_numbers
if mode == ArrayToConcrete
return ConcreteRNumber(prev)
else
Expand Down Expand Up @@ -787,7 +789,7 @@ function make_tracer(
@nospecialize(prev::NamedTuple),
@nospecialize(path),
mode;
@nospecialize(track_numbers::Type),
@nospecialize(track_numbers::Type=Union{}),
kwargs...,
)
NT = Core.Typeof(prev)
Expand Down Expand Up @@ -822,8 +824,8 @@ function make_tracer(seen, prev::Core.Box, @nospecialize(path), mode; kwargs...)
return res
end

@inline function to_rarray(@nospecialize(x); track_numbers::Union{Bool,Tuple}=())
track_numbers isa Bool && (track_numbers = track_numbers ? (Number,) : ())
@inline function to_rarray(@nospecialize(x); track_numbers::Union{Bool,Type}=false)
track_numbers isa Bool && (track_numbers = track_numbers ? Number : Union{})
return to_rarray_internal(x, track_numbers)
end

Expand Down
12 changes: 6 additions & 6 deletions test/tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ using Test
(Val{:x}, Val{:x}, Val{:x}),
]
tracedty = traced_type(
origty, Val(ConcreteToTraced), ()
origty, Val(ConcreteToTraced), Union{}
)
@test tracedty == targetty

Expand All @@ -112,36 +112,36 @@ using Test
TracedRArray{Float64,3},
]
@test_throws Union{ErrorException,String} traced_type(
type, Val(ConcreteToTraced), ()
type, Val(ConcreteToTraced), Union{}
)
end
end
@testset "traced_type exceptions" begin
@test_throws TracedTypeError Reactant.traced_type(
Real, Val(Reactant.ArrayToConcrete), ()
Real, Val(Reactant.ArrayToConcrete), Union{}
)

struct Node
x::Vector{Float64}
y::Union{Nothing,Node}
end
@test_throws NoFieldMatchError traced_type(
Node, Val(ArrayToConcrete), ()
Node, Val(ArrayToConcrete), Union{}
)
end
end

@testset "specialized dispatches" begin
@test @inferred Union{Float64,ConcreteRArray{Float64}} Reactant.to_rarray(
1.0; track_numbers=(Number,)
1.0; track_numbers=Number
) isa ConcreteRNumber
@test @inferred Reactant.to_rarray(1.0) isa Float64
@test @inferred Reactant.to_rarray(rand(3)) isa ConcreteRArray

x_ra = Reactant.to_rarray(rand(3))
@test @inferred Reactant.to_rarray(x_ra) isa ConcreteRArray

x_ra = Reactant.to_rarray(1.0; track_numbers=(Number,))
x_ra = Reactant.to_rarray(1.0; track_numbers=Number)
@test @inferred Reactant.to_rarray(x_ra) isa ConcreteRNumber
end
end

0 comments on commit c61d2e1

Please sign in to comment.