From c61d2e1663c8ca1213a89d53ef2e706dff8d1009 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 16 Jan 2025 11:29:55 -0500 Subject: [PATCH] fixes --- src/TracedUtils.jl | 4 ++-- src/Tracing.jl | 30 ++++++++++++++++-------------- test/tracing.jl | 12 ++++++------ 3 files changed, 24 insertions(+), 22 deletions(-) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index ab7a55643..c9184e41d 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -131,7 +131,7 @@ function make_mlir_fn( (:args, i), concretein ? Reactant.ConcreteToTraced : Reactant.TracedSetPath; toscalar, - track_numbers=construct_function_without_args ? (Number,) : (), + track_numbers=construct_function_without_args ? Number : Union{}, ) end @@ -201,7 +201,7 @@ function make_mlir_fn( result, (:result,), concretein ? Reactant.TracedTrack : Reactant.TracedSetPath; - track_numbers=construct_function_without_args ? (Number,) : (), + track_numbers=construct_function_without_args ? Number : Union{}, ) # marks buffers to be donated diff --git a/src/Tracing.jl b/src/Tracing.jl index 15de26f3c..796d6b43a 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -153,7 +153,7 @@ Base.@nospecializeinfer function traced_type_inner( end end -Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Function}), seen, mode::TraceMode, @nospecialize(track_numbers)) +Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Function}), seen, mode::TraceMode, @nospecialize(track_numbers::Type)) # functions are directly returned if sizeof(T) == 0 return T @@ -309,24 +309,28 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(VT::Type{<:Val} throw("Val type $(Val{T}) cannot be traced") end -const traced_type_cache = Dict{Tuple{TraceMode, Type}, Dict{Type, Type}} +const traced_type_cache = Dict{Tuple{TraceMode, Type}, Dict{Type, Type}}() -function traced_type_generator(world::UInt, source, self, @nospecialize(T::Type), @nospecialize(mode::Val), @nospecialize(track_numbers::Type)) +function traced_type_generator(world::UInt, source, self, @nospecialize(T::Type), @nospecialize(mode::Type{<:Val}), @nospecialize(track_numbers::Type)) @nospecialize T = T.parameters[1] - mode = mode.parameters[1] + mode = mode.parameters[1]::TraceMode track_numbers = track_numbers.parameters[1] min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) - sig = Tuple{typeof(traced_type_inner), Type{T}, Dict{Type, Type}, TraceMode, track_numbers} + sig = Tuple{typeof(traced_type_inner), Type{T}, Dict{Type, Type}, TraceMode, Type{track_numbers}} lookup_result = lookup_world( sig, world, nothing, min_world, max_world ) - @assert lookup_result !== nothing + if lookup_result === nothing + @show sig + stub = Core.GeneratedFunctionStub(identity, Core.svec(:traced_type, :T, :mode, :track_numbers), Core.svec()) + return stub(world, source, method_error) + end match = lookup_result::Core.MethodMatch mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, @@ -430,7 +434,7 @@ function make_tracer( mode; toscalar=false, tobatch=nothing, - @nospecialize(track_numbers::Type), + @nospecialize(track_numbers::Type=Union{}), kwargs..., ) if mode != NoStopTracedTrack && haskey(seen, prev) @@ -686,12 +690,10 @@ function make_tracer( end function make_tracer( - seen, @nospecialize(prev::Number), @nospecialize(path), mode; @nospecialize(track_numbers::Type), kwargs... + seen, @nospecialize(prev::Number), @nospecialize(path), mode; @nospecialize(track_numbers::Type=Union{}), kwargs... ) RT = Core.Typeof(prev) - length(track_numbers) == 0 && return prev - should_convert = RT <: track_numbers - if should_convert + if RT <: track_numbers if mode == ArrayToConcrete return ConcreteRNumber(prev) else @@ -787,7 +789,7 @@ function make_tracer( @nospecialize(prev::NamedTuple), @nospecialize(path), mode; - @nospecialize(track_numbers::Type), + @nospecialize(track_numbers::Type=Union{}), kwargs..., ) NT = Core.Typeof(prev) @@ -822,8 +824,8 @@ function make_tracer(seen, prev::Core.Box, @nospecialize(path), mode; kwargs...) return res end -@inline function to_rarray(@nospecialize(x); track_numbers::Union{Bool,Tuple}=()) - track_numbers isa Bool && (track_numbers = track_numbers ? (Number,) : ()) +@inline function to_rarray(@nospecialize(x); track_numbers::Union{Bool,Type}=false) + track_numbers isa Bool && (track_numbers = track_numbers ? Number : Union{}) return to_rarray_internal(x, track_numbers) end diff --git a/test/tracing.jl b/test/tracing.jl index a005a2bdd..53f79bc71 100644 --- a/test/tracing.jl +++ b/test/tracing.jl @@ -95,7 +95,7 @@ using Test (Val{:x}, Val{:x}, Val{:x}), ] tracedty = traced_type( - origty, Val(ConcreteToTraced), () + origty, Val(ConcreteToTraced), Union{} ) @test tracedty == targetty @@ -112,13 +112,13 @@ using Test TracedRArray{Float64,3}, ] @test_throws Union{ErrorException,String} traced_type( - type, Val(ConcreteToTraced), () + type, Val(ConcreteToTraced), Union{} ) end end @testset "traced_type exceptions" begin @test_throws TracedTypeError Reactant.traced_type( - Real, Val(Reactant.ArrayToConcrete), () + Real, Val(Reactant.ArrayToConcrete), Union{} ) struct Node @@ -126,14 +126,14 @@ using Test y::Union{Nothing,Node} end @test_throws NoFieldMatchError traced_type( - Node, Val(ArrayToConcrete), () + Node, Val(ArrayToConcrete), Union{} ) end end @testset "specialized dispatches" begin @test @inferred Union{Float64,ConcreteRArray{Float64}} Reactant.to_rarray( - 1.0; track_numbers=(Number,) + 1.0; track_numbers=Number ) isa ConcreteRNumber @test @inferred Reactant.to_rarray(1.0) isa Float64 @test @inferred Reactant.to_rarray(rand(3)) isa ConcreteRArray @@ -141,7 +141,7 @@ using Test x_ra = Reactant.to_rarray(rand(3)) @test @inferred Reactant.to_rarray(x_ra) isa ConcreteRArray - x_ra = Reactant.to_rarray(1.0; track_numbers=(Number,)) + x_ra = Reactant.to_rarray(1.0; track_numbers=Number) @test @inferred Reactant.to_rarray(x_ra) isa ConcreteRNumber end end