From 83aceac2236476e05d1184d060b5ab9361a53c15 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 17 Jan 2025 17:33:47 -0500 Subject: [PATCH] fix: tracing --- src/Tracing.jl | 55 ++++++++++++++++++++++++++++++++++---------------- test/basic.jl | 13 +++++------- 2 files changed, 43 insertions(+), 25 deletions(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index bad5f3997..181fa6c7f 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -285,36 +285,25 @@ Base.@nospecializeinfer function traced_type_inner( mode::TraceMode, @nospecialize(args::Vararg) ) - T = T0.parameters[1] if mode == ConcreteToTraced - return TracedRNumber{T} + return TracedRNumber{T0.parameters[1]} elseif mode == TracedToConcrete - return ConcreteRNumber{T} + return T0 else throw("Abstract RNumber cannot be made concrete") end end -Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::UnionAll)) = - UnionAll(TV.var, base_typet(TV.body)) -Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::DataType)) = - TracedRArray{TV.parameters...} - -Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::UnionAll)) = - UnionAll(TV.var, base_typec(TV.body)) -Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::DataType)) = - (TV <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV.parameters...} - Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:ConcreteRArray}), + @nospecialize(CA::Type{<:ConcreteRArray}), seen, mode::TraceMode, @nospecialize(args::Vararg) ) if mode == ConcreteToTraced - return base_typet(T) + return TracedRArray{CA.parameters[1],CA.parameters[2]} elseif mode == TracedToConcrete - return T + return CA else throw("Abstract RArray cannot be made concrete") end @@ -346,6 +335,38 @@ Base.@nospecializeinfer function traced_type_inner( return error("This should not happen...") end +Base.@nospecializeinfer function traced_type_inner( + TR::Type{<:TracedRNumber}, + seen, + mode::TraceMode, + @nospecialize(track_numbers), + @nospecialize(batchmode), + @nospecialize(tobatch) +) + T = TR.parameters[1] + if mode == ConcreteToTraced + throw("TracedRArray $(TracedRArray{T,N}) cannot be traced") + elseif mode == TracedToConcrete + return ConcreteRNumber{T} + elseif mode == TracedTrack || mode == NoStopTracedTrack + return TracedRNumber{T} + elseif mode == TracedSetPath + if batchmode == BatchNone + return TracedRNumber{T} + elseif batchmode == BatchScalar + if tobatch === nothing + return TracedRNumber{T} + else + return TracedRArray{T,length(tobatch)} + end + else + error("Cannot BatchArray on a scalar") + end + else + throw("$(TracedRNumber{T}) cannot be made concrete in mode $mode") + end +end + Base.@nospecializeinfer function traced_type_inner( TR::Type{<:TracedRArray}, seen, @@ -359,7 +380,7 @@ Base.@nospecializeinfer function traced_type_inner( if mode == ConcreteToTraced throw("TracedRArray $(TracedRArray{T,N}) cannot be traced") elseif mode == TracedToConcrete - return base_typec(TracedRArray{T,N}) + return ConcreteRArray{T,N} elseif mode == TracedTrack || mode == NoStopTracedTrack return TracedRArray{T,N} elseif mode == TracedSetPath diff --git a/test/basic.jl b/test/basic.jl index 019623226..60a3b464c 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -34,19 +34,18 @@ end end sinexp(x) = sin(exp(x)) -sinexpbc(x) = sinexp.(x) @testset "Broadcast combined" begin x = rand(2, 10) - r_res = sinexpbc(x) + r_res = sinexp.(x) a = Reactant.ConcreteRArray(x) - c_res = @allowscalar sinexpbc(a) + c_res = @allowscalar sinexp.(a) @test c_res ≈ r_res - @test @jit(sinexpbc(a)) ≈ r_res + @test @jit(sinexp.(a)) ≈ r_res end sumexp(x) = sum(exp, x) @@ -82,13 +81,11 @@ end @test f_res ≈ r_res end -bcast_cos(x) = cos.(x) - @testset "Basic cos" begin x = rand(3, 2) c = Reactant.ConcreteRArray(x) - @test @jit(bcast_cos(c)) ≈ cos.(x) + @test @jit(cos.(c)) ≈ cos.(x) end f_var(args...) = sum(args) @@ -376,7 +373,7 @@ end b = Reactant.to_rarray(_b) c = Reactant.to_rarray(_c) - # vcat test + # vcat test y = @jit vcat(a, b) @test y == vcat(a, _b) @test y isa ConcreteRArray{typeof_a,1}