Skip to content

Commit

Permalink
fix: tracing
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 19, 2025
1 parent 5a07dca commit 83aceac
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 25 deletions.
55 changes: 38 additions & 17 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,36 +285,25 @@ Base.@nospecializeinfer function traced_type_inner(
mode::TraceMode,
@nospecialize(args::Vararg)
)
T = T0.parameters[1]
if mode == ConcreteToTraced
return TracedRNumber{T}
return TracedRNumber{T0.parameters[1]}
elseif mode == TracedToConcrete
return ConcreteRNumber{T}
return T0
else
throw("Abstract RNumber cannot be made concrete")
end
end

Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::UnionAll)) =
UnionAll(TV.var, base_typet(TV.body))
Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::DataType)) =
TracedRArray{TV.parameters...}

Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::UnionAll)) =
UnionAll(TV.var, base_typec(TV.body))
Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::DataType)) =
(TV <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV.parameters...}

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(T::Type{<:ConcreteRArray}),
@nospecialize(CA::Type{<:ConcreteRArray}),
seen,
mode::TraceMode,
@nospecialize(args::Vararg)
)
if mode == ConcreteToTraced
return base_typet(T)
return TracedRArray{CA.parameters[1],CA.parameters[2]}
elseif mode == TracedToConcrete
return T
return CA
else
throw("Abstract RArray cannot be made concrete")
end
Expand Down Expand Up @@ -346,6 +335,38 @@ Base.@nospecializeinfer function traced_type_inner(
return error("This should not happen...")
end

Base.@nospecializeinfer function traced_type_inner(
TR::Type{<:TracedRNumber},
seen,
mode::TraceMode,
@nospecialize(track_numbers),
@nospecialize(batchmode),
@nospecialize(tobatch)
)
T = TR.parameters[1]
if mode == ConcreteToTraced
throw("TracedRArray $(TracedRArray{T,N}) cannot be traced")
elseif mode == TracedToConcrete
return ConcreteRNumber{T}
elseif mode == TracedTrack || mode == NoStopTracedTrack
return TracedRNumber{T}
elseif mode == TracedSetPath
if batchmode == BatchNone
return TracedRNumber{T}
elseif batchmode == BatchScalar
if tobatch === nothing
return TracedRNumber{T}
else
return TracedRArray{T,length(tobatch)}
end
else
error("Cannot BatchArray on a scalar")
end
else
throw("$(TracedRNumber{T}) cannot be made concrete in mode $mode")
end
end

Base.@nospecializeinfer function traced_type_inner(
TR::Type{<:TracedRArray},
seen,
Expand All @@ -359,7 +380,7 @@ Base.@nospecializeinfer function traced_type_inner(
if mode == ConcreteToTraced
throw("TracedRArray $(TracedRArray{T,N}) cannot be traced")
elseif mode == TracedToConcrete
return base_typec(TracedRArray{T,N})
return ConcreteRArray{T,N}
elseif mode == TracedTrack || mode == NoStopTracedTrack
return TracedRArray{T,N}
elseif mode == TracedSetPath
Expand Down
13 changes: 5 additions & 8 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,18 @@ end
end

sinexp(x) = sin(exp(x))
sinexpbc(x) = sinexp.(x)

@testset "Broadcast combined" begin
x = rand(2, 10)

r_res = sinexpbc(x)
r_res = sinexp.(x)

a = Reactant.ConcreteRArray(x)

c_res = @allowscalar sinexpbc(a)
c_res = @allowscalar sinexp.(a)
@test c_res r_res

@test @jit(sinexpbc(a)) r_res
@test @jit(sinexp.(a)) r_res
end

sumexp(x) = sum(exp, x)
Expand Down Expand Up @@ -82,13 +81,11 @@ end
@test f_res r_res
end

bcast_cos(x) = cos.(x)

@testset "Basic cos" begin
x = rand(3, 2)
c = Reactant.ConcreteRArray(x)

@test @jit(bcast_cos(c)) cos.(x)
@test @jit(cos.(c)) cos.(x)
end

f_var(args...) = sum(args)
Expand Down Expand Up @@ -376,7 +373,7 @@ end
b = Reactant.to_rarray(_b)
c = Reactant.to_rarray(_c)

# vcat test
# vcat test
y = @jit vcat(a, b)
@test y == vcat(a, _b)
@test y isa ConcreteRArray{typeof_a,1}
Expand Down

0 comments on commit 83aceac

Please sign in to comment.