Skip to content

Commit

Permalink
fix: handle type propagation in getfield
Browse files Browse the repository at this point in the history
  • Loading branch information
baggepinnen committed Mar 14, 2024
1 parent 75cc676 commit 2570016
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions src/struct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,17 @@ juliatype(::Type{Struct{T}}) where T = T
getelements(s::Type{<:Struct}) = fieldnames(juliatype(s))
getelementtypes(s::Type{<:Struct}) = fieldtypes(juliatype(s))

typed_getfield(obj, ::Val{fieldname}) where fieldname = getfield(obj, fieldname)

function symbolic_getproperty(ss, name::Symbol)
s = symtype(ss)
idx = findfirst(isequal(name), getelements(s))
idx === nothing && error("$(juliatype(s)) doesn't have field $(name)!")
T = getelementtypes(s)[idx]
SymbolicUtils.term(getfield, ss, Meta.quot(name), type = T)
if isstructtype(T)
T = Struct{T}
end
SymbolicUtils.term(typed_getfield, ss, Val{name}(), type = T)
end
function symbolic_getproperty(s::Union{Arr, Num}, name::Symbol)
wrap(symbolic_getproperty(unwrap(s), name))
Expand All @@ -54,6 +59,18 @@ end

# We cannot precisely derive the type after `getfield` due to SU limitations,
# so give up and just say Real.
SymbolicUtils.promote_symtype(::typeof(getfield), ::Type{<:Struct}, _...) = Real
function SymbolicUtils.promote_symtype(::typeof(typed_getfield), ::Type{<:Struct{T}}, v::Type{Val{fieldname}}) where {T, fieldname}
FT = fieldtype(T, fieldname)
if isstructtype(FT)
return Struct{FT}
end
FT
end


SymbolicUtils.promote_symtype(::typeof(setfield!), ::Type{<:Struct}, _, ::Type{T}) where T = T
function SymbolicUtils.promote_symtype(s::Type{<:Struct{T}}, _...) where T
s
end

SymbolicUtils.promote_symtype(::typeof(setfield!), ::Type{<:Struct}, _, ::Type{T}) where T = T
SymbolicUtils.promote_symtype(s::Type{<:Struct{T}}, _...) where T = s

0 comments on commit 2570016

Please sign in to comment.