Skip to content

Commit

Permalink
It's Perfect. It's flawless. Really something.
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith committed Feb 11, 2023
1 parent f8c7e6d commit 7da6082
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 137 deletions.
15 changes: 2 additions & 13 deletions src/jet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,8 @@ function (∂⃖ₙ::∂⃖{N})(::typeof(map), f, a::Array) where {N}
∂f = ∂☆{N}()(ZeroBundle{N}(f),
TaylorBundle{N}(x,
(one(x), (zero(x) for i = 1:(N-1))...,)))
@assert isa(∂f, TaylorBundle) || isa(∂f, ExplicitTangentBundle{1})
Jet{typeof(x), typeof(x), N}(x, ∂f.primal,
isa(∂f, ExplicitTangentBundle) ? ∂f.tangent.partials : ∂f.tangent.coeffs)
@assert isa(∂f, TaylorBundle)
Jet{typeof(x), typeof(x), N}(x, ∂f.primal, ∂f.tangent.coeffs)
end
∂⃖ₙ(mapev, js, a)
end
Expand Down Expand Up @@ -248,13 +247,3 @@ end
($((:(jet_taylor_ev(Val{$i}(), coeffs, j)) for i = 1:O)...),))
end
end

function (j::Jet{S, T, 1} where {S,T})(x::ExplicitTangentBundle{1})
domain_check(j, x.primal)
coeffs = x.tangent.partials
ExplicitTangentBundle{1}(j[0], (jet_taylor_ev(Val{1}(), coeffs, j),))
end

function (j::Jet{S, T, N} where T)(x::ExplicitTangentBundle{N, M}) where {S, N, M}
error("TODO")
end
49 changes: 4 additions & 45 deletions src/stage1/forward.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
partial(x::TangentBundle, i) = partial(getfield(x, :tangent), i)
partial(x::ExplicitTangent, i) = getfield(getfield(x, :partials), i)
partial(x::TaylorTangent, i) = getfield(getfield(x, :coeffs), i)
partial(x::UniformTangent, i) = getfield(x, :val)
partial(x::ProductTangent, i) = ProductTangent(map(x->partial(x, i), getfield(x, :factors)))
Expand All @@ -25,15 +24,6 @@ my_frule(::ZeroBundle{1, typeof(my_frule)}, args::ATB{1}...) = nothing
shuffle_down(b::UniformBundle{N, B, U}) where {N, B, U} =
UniformBundle{N-1, <:Any, U}(UniformBundle{1, B, U}(b.primal, b.tangent.val), b.tangent.val)

function shuffle_down(b::ExplicitTangentBundle{N, B}) where {N, B}
# N.B: This depends on the special properties of the canonical tangent index order
ExplicitTangentBundle{N-1}(
ExplicitTangentBundle{1}(b.primal, (partial(b, 1),)),
ntuple(1<<(N-1)-1) do i
ExplicitTangentBundle{1}(partial(b, 2*i), (partial(b, 2*i+1),))
end)
end

function shuffle_down(b::TaylorBundle{N, B}) where {N, B}
TaylorBundle{N-1}(
TaylorBundle{1}(b.primal, (b.tangent.coeffs[1],)),
Expand All @@ -58,31 +48,12 @@ function shuffle_up(r::CompositeBundle{1})
return TaylorBundle{2}(z₀, (z₁, z₁₂))
end

function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N}
primal(b) === a[TaylorTangentIndex(1)] || return false
return all(1:(N-1)) do i
b[TaylorTangentIndex(i)] === a[TaylorTangentIndex(i+1)]
end
end

# Check whether the tangent bundle element is taylor-like
isswifty(::TaylorBundle) = true
isswifty(::UniformBundle) = true
isswifty(b::CompositeBundle) = all(isswifty, b.tup)
isswifty(::Any) = false

function shuffle_up(r::CompositeBundle{N}) where {N}
a, b = r.tup
if isswifty(a) && isswifty(b) && taylor_compatible(a, b)
return TaylorBundle{N+1}(primal(a),
ntuple(i->i == N+1 ?
b[TaylorTangentIndex(i-1)] : a[TaylorTangentIndex(i)],
N+1))
else
return TangentBundle{N+1}(r.tup[1].primal,
(r.tup[1].tangent.partials..., primal(b),
ntuple(i->partial(b,i), 1<<(N+1)-1)...))
end
return TaylorBundle{N+1}(primal(a),
ntuple(i->i == N+1 ?
b[TaylorTangentIndex(i-1)] : a[TaylorTangentIndex(i)],
N+1))
end

function shuffle_up(r::UniformBundle{N, B, U}) where {N, B, U}
Expand Down Expand Up @@ -134,18 +105,6 @@ end
(::∂☆{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆internal{N}()(args...)

# Special case rules for performance
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::ExplicitTangentBundle{N}, s::AbstractTangentBundle{N}) where {N}
s = primal(s)
ExplicitTangentBundle{N}(getfield(primal(x), s),
map(x->lifted_getfield(x, s), x.tangent.partials))
end

@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::ExplicitTangentBundle{N}, s::ATB{N}, inbounds::ATB{N}) where {N}
s = primal(s)
ExplicitTangentBundle{N}(getfield(primal(x), s, primal(inbounds)),
map(x->lifted_getfield(x, s), x.tangent.partials))
end

@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TaylorBundle{N}, s::AbstractTangentBundle{N}) where {N}
s = primal(s)
TaylorBundle{N}(getfield(primal(x), s),
Expand Down
5 changes: 2 additions & 3 deletions src/stage1/mixed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,8 @@ function (∂⃖ₙ::∂⃖{N})(∂☆ₘ::∂☆{M}, ::ZeroBundle{M, typeof(map
∂f = ∂☆{N+M}()(ZeroBundle{N+M}(primal(f)),
TaylorBundle{N+M}(x,
(one(x), (zero(x) for i = 1:(N+M-1))...,)))
@assert isa(∂f, TaylorBundle) || isa(∂f, ExplicitTangentBundle{1})
Jet{typeof(x), N+M}(x, ∂f.primal,
isa(∂f, ExplicitTangentBundle) ? ∂f.tangent.partials : ∂f.tangent.coeffs)
@assert isa(∂f, TaylorBundle)
Jet{typeof(x), N+M}(x, ∂f.primal, ∂f.tangent.coeffs)
end
∂⃖ₙ(mapev_unbundled, ∂☆ₘ, js, a)
end
2 changes: 1 addition & 1 deletion src/stage1/recurse_fwd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct ∂☆new{N}; end
(::∂☆new{N})(B::Type, a::AbstractTangentBundle{N}...) where {N} =
CompositeBundle{N, B}(a)

@generated (::∂☆new{N})(B::Type) where {N} = return :(ZeroBundle{$N}($(Expr(:new, :B))))
(::∂☆new{N})(B::Type) where {N} = return ZeroBundle{N}(B)

# Sometimes we don't know whether or not we need to the ZeroBundle when doing
# the transform, so this can happen - allow it for now.
Expand Down
73 changes: 0 additions & 73 deletions src/tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,6 @@ end

abstract type AbstractTangentSpace; end

"""
struct ExplicitTangent{P}
A fully explicit coordinate representation of the tangent space,
represented by a vector of `2^(N-1)` partials.
"""
struct ExplicitTangent{P <: Tuple} <: AbstractTangentSpace
partials::P
end

struct TaylorTangent{C <: Tuple} <: AbstractTangentSpace
coeffs::C
end
Expand Down Expand Up @@ -151,46 +141,9 @@ struct TangentBundle{N, B, P <: AbstractTangentSpace} <: AbstractTangentBundle{N
TangentBundle{N}(B, P) where {N} = new{N, typeof(B), typeof(P)}(B,P)
end

const ExplicitTangentBundle{N, B, P} = TangentBundle{N, B, ExplicitTangent{P}}

check_tangent_invariant(lp, N) = @assert lp == 2^N - 1
@ChainRulesCore.non_differentiable check_tangent_invariant(lp, N)

function ExplicitTangentBundle{N}(primal::B, partials::P) where {N, B, P}
check_tangent_invariant(length(partials), N)
TangentBundle{N}(primal, ExplicitTangent{P}(partials))
end

function ExplicitTangentBundle{N,B}(primal::B, partials::P) where {N, B, P}
check_tangent_invariant(length(partials), N)
TangentBundle{N}(primal, ExplicitTangent{P}(partials))
end

function ExplicitTangentBundle{N,B,P}(primal::B, partials::P) where {N, B, P}
check_tangent_invariant(length(partials), N)
TangentBundle{N}(primal, ExplicitTangent{P}(partials))
end

function Base.show(io::IO, x::ExplicitTangentBundle)
print(io, x.primal)
print(io, " + ")
x = x.tangent
print(io, x.partials[1], " ∂₁")
length(x.partials) >= 2 && print(io, " + ", x.partials[2], " ∂₂")
length(x.partials) >= 3 && print(io, " + ", x.partials[3], " ∂₁ ∂₂")
length(x.partials) >= 4 && print(io, " + ", x.partials[4], " ∂₃")
length(x.partials) >= 5 && print(io, " + ", x.partials[5], " ∂₁ ∂₃")
length(x.partials) >= 6 && print(io, " + ", x.partials[6], " ∂₂ ∂₃")
length(x.partials) >= 7 && print(io, " + ", x.partials[7], " ∂₁ ∂₂ ∂₃")
end

function Base.getindex(a::ExplicitTangentBundle{N}, b::TaylorTangentIndex) where {N}
if b.i === N
return a.tangent.partials[end]
end
error("$(typeof(a)) is not taylor-like. Taylor indexing is ambiguous")
end

const TaylorBundle{N, B, P} = TangentBundle{N, B, TaylorTangent{P}}

function TaylorBundle{N, B}(primal::B, coeffs) where {N, B}
Expand Down Expand Up @@ -268,24 +221,6 @@ end
expand_singleton_to_array(asize, a::AbstractZero) = fill(a, asize...)
expand_singleton_to_array(asize, a::AbstractArray) = a

function unbundle(atb::ExplicitTangentBundle{Order, A}) where {Order, Dim, T, A<:AbstractArray{T, Dim}}
asize = size(atb.primal)
StructArray{ExplicitTangentBundle{Order, T}}((atb.primal, map(a->expand_singleton_to_array(asize, a), atb.tangent.partials)...))
end

function StructArrays.staticschema(::Type{<:ExplicitTangentBundle{N, B, T}}) where {N, B, T}
Tuple{B, T.parameters...}
end

function StructArrays.component(m::ExplicitTangentBundle{N, B, T}, i::Int) where {N, B, T}
i == 1 && return m.primal
return m.tangent.partials[i - 1]
end

function StructArrays.createinstance(T::Type{<:ExplicitTangentBundle}, args...)
T(first(args), Base.tail(args))
end

function unbundle(atb::TaylorBundle{Order, A}) where {Order, Dim, T, A<:AbstractArray{T, Dim}}
StructArray{TaylorBundle{Order, T}}((atb.primal, atb.tangent.coeffs...))
end
Expand Down Expand Up @@ -323,14 +258,6 @@ function StructArrays.createinstance(T::Type{<:ZeroBundle}, args...)
T(args[1], args[2])
end

function rebundle(A::AbstractArray{<:ExplicitTangentBundle{N}}) where {N}
ExplicitTangentBundle{N}(
map(x->x.primal, A),
ntuple(2^N-1) do i
map(x->x.tangent.partials[i], A)
end)
end

function rebundle(A::AbstractArray{<:TaylorBundle{N}}) where {N}
TaylorBundle{N}(
map(x->x.primal, A),
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ ChainRules.rrule(::typeof(my_tuple), args...) = args, Δ->Core.tuple(NoTangent()

# Minimal 2-nd order forward smoke test
@test Diffractor.∂☆{2}()(Diffractor.ZeroBundle{2}(sin),
Diffractor.TaylorBundle{2}(1.0, (1.0 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0)
Diffractor.TaylorBundle{2}(1.0, (1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0)

function simple_control_flow(b, x)
if b
Expand Down
1 change: 0 additions & 1 deletion test/stage2_fwd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ module stage2_fwd

self_minus(a) = myminus(a, a)
let self_minus′′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus), Float64}, 2)
# TODO: The IR for this currently contains Union{Diffractor.TangentBundle{2, Float64, Diffractor.ExplicitTangent{Tuple{Float64, Float64, Float64}}}, Diffractor.TangentBundle{2, Float64, Diffractor.TaylorTangent{Tuple{Float64, Float64}}}}
# We should have Diffractor be able to prove uniformity
@test_broken isa(self_minus′′, Core.OpaqueClosure{Tuple{Float64}, Float64})
@test self_minus′′(1.0) == 0.
Expand Down

0 comments on commit 7da6082

Please sign in to comment.