Skip to content

Commit 09589cd

Browse files
committed
more extra rules for static arrays
more overloads for StaticArrays
1 parent e1c7c7e commit 09589cd

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

src/extra_rules.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,14 @@ function ChainRules.rrule(::DiffractorRuleConfig, ::Type{SArray{S, T, N, L}}, x:
172172
end
173173

174174
function ChainRules.frule((_, ∂x), ::Type{SArray{S, T, N, L}}, x::NTuple{L,T}) where {S, T, N, L}
175-
SArray{S, T, N, L}(x), SArray{S, T, N, L}(∂x.backing)
175+
#TODO: we really shouldn't actually see the isa(∂x, AbstractZero) case since the frule should be called then
176+
Δx = isa(∂x, AbstractZero) ? ∂x : SArray{S, T, N, L}(ChainRulesCore.backing(∂x))
177+
SArray{S, T, N, L}(x), Δx
176178
end
177179

180+
Base.view(t::Tangent{T}, inds) where T<:SVector = view(T(ChainRulesCore.backing(t.data)), inds)
181+
Base.getindex(t::Tangent{<:SVector, <:NamedTuple}, ind::Int) = ChainRulesCore.backing(t.data)[ind]
182+
178183
function ChainRules.frule((_, ∂x), ::Type{SArray{S, T, N, L}}, x::NTuple{L,Any}) where {S, T, N, L}
179184
SArray{S, T, N, L}(x), SArray{S}(∂x)
180185
end

0 commit comments

Comments
 (0)