Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move @noinline into code generation utilities #205

Merged
merged 2 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions src/overloads/gradient_tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ SCT = SparseConnectivityTracer

## 1-to-1

@noinline function gradient_tracer_1_to_1(
t::T, is_der1_zero::Bool
) where {T<:GradientTracer}
function gradient_tracer_1_to_1(t::T, is_der1_zero::Bool) where {T<:GradientTracer}
if is_der1_zero && !isemptytracer(t)
return myempty(T)
else
Expand Down Expand Up @@ -36,7 +34,7 @@ function generate_code_gradient_1_to_1(M::Symbol, f::Function)

expr_gradienttracer = quote
function $M.$fname(t::$SCT.GradientTracer)
return $SCT.gradient_tracer_1_to_1(t, $is_der1_zero_g)
return @noinline $SCT.gradient_tracer_1_to_1(t, $is_der1_zero_g)
end
end

Expand All @@ -55,7 +53,7 @@ function generate_code_gradient_1_to_1(M::Symbol, f::Function)

t = $SCT.tracer(d)
is_der1_zero = $SCT.is_der1_zero_local($M.$fname, x)
t_out = $SCT.gradient_tracer_1_to_1(t, is_der1_zero)
t_out = @noinline $SCT.gradient_tracer_1_to_1(t, is_der1_zero)
return $SCT.Dual(p_out, t_out)
end
end
Expand All @@ -65,7 +63,7 @@ end

## 2-to-1

@noinline function gradient_tracer_2_to_1(
function gradient_tracer_2_to_1(
tx::T, ty::T, is_der1_arg1_zero::Bool, is_der1_arg2_zero::Bool
) where {T<:GradientTracer}
# TODO: add tests for isempty
Expand Down Expand Up @@ -116,7 +114,7 @@ function generate_code_gradient_2_to_1(M::Symbol, f::Function)

expr_tracer_tracer = quote
function $M.$fname(tx::T, ty::T) where {T<:$SCT.GradientTracer}
return $SCT.gradient_tracer_2_to_1(
return @noinline $SCT.gradient_tracer_2_to_1(
tx, ty, $is_der1_arg1_zero_g, $is_der1_arg2_zero_g
)
end
Expand All @@ -141,7 +139,7 @@ function generate_code_gradient_2_to_1(M::Symbol, f::Function)
ty = $SCT.tracer(dy)
is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y)
is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y)
t_out = $SCT.gradient_tracer_2_to_1(
t_out = @noinline $SCT.gradient_tracer_2_to_1(
tx, ty, is_der1_arg1_zero, is_der1_arg2_zero
)
return $SCT.Dual(p_out, t_out)
Expand All @@ -164,12 +162,12 @@ function generate_code_gradient_2_to_1_typed(

expr_tracer_type = quote
function $M.$fname(tx::$SCT.GradientTracer, ::$Z)
return $SCT.gradient_tracer_1_to_1(tx, $is_der1_arg1_zero_g)
return @noinline $SCT.gradient_tracer_1_to_1(tx, $is_der1_arg1_zero_g)
end
end
expr_type_tracer = quote
function $M.$fname(::$Z, ty::$SCT.GradientTracer)
return $SCT.gradient_tracer_1_to_1(ty, $is_der1_arg2_zero_g)
return @noinline $SCT.gradient_tracer_1_to_1(ty, $is_der1_arg2_zero_g)
end
end

Expand All @@ -188,7 +186,7 @@ function generate_code_gradient_2_to_1_typed(

tx = $SCT.tracer(dx)
is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y)
t_out = $SCT.gradient_tracer_1_to_1(tx, is_der1_arg1_zero)
t_out = @noinline $SCT.gradient_tracer_1_to_1(tx, is_der1_arg1_zero)
return $SCT.Dual(p_out, t_out)
end
end
Expand All @@ -208,7 +206,7 @@ function generate_code_gradient_2_to_1_typed(

ty = $SCT.tracer(dy)
is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y)
t_out = $SCT.gradient_tracer_1_to_1(ty, is_der1_arg2_zero)
t_out = @noinline $SCT.gradient_tracer_1_to_1(ty, is_der1_arg2_zero)
return $SCT.Dual(p_out, t_out)
end
end
Expand All @@ -218,7 +216,7 @@ end

## 1-to-2

@noinline function gradient_tracer_1_to_2(
function gradient_tracer_1_to_2(
t::T, is_der1_out1_zero::Bool, is_der1_out2_zero::Bool
) where {T<:GradientTracer}
if isemptytracer(t) # TODO: add test
Expand All @@ -237,7 +235,9 @@ function generate_code_gradient_1_to_2(M::Symbol, f::Function)

expr_gradienttracer = quote
function $M.$fname(t::$SCT.GradientTracer)
return $SCT.gradient_tracer_1_to_2(t, $is_der1_out1_zero_g, $is_der1_out2_zero_g)
return @noinline $SCT.gradient_tracer_1_to_2(
t, $is_der1_out1_zero_g, $is_der1_out2_zero_g
)
end
end

Expand All @@ -257,7 +257,7 @@ function generate_code_gradient_1_to_2(M::Symbol, f::Function)
t = $SCT.tracer(d)
is_der1_out2_zero = $SCT.is_der1_out2_zero_local($M.$fname, x)
is_der1_out1_zero = $SCT.is_der1_out1_zero_local($M.$fname, x)
t_out1, t_out2 = $SCT.gradient_tracer_1_to_2(
t_out1, t_out2 = @noinline $SCT.gradient_tracer_1_to_2(
t, is_der1_out1_zero, is_der1_out2_zero
)
return ($SCT.Dual(p_out1, t_out1), $SCT.Dual(p_out2, t_out2)) # TODO: this was wrong, add test
Expand Down
34 changes: 21 additions & 13 deletions src/overloads/hessian_tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ SCT = SparseConnectivityTracer
# 𝟙[∇γ] = 𝟙[∂φ]⋅𝟙[∇α]
# 𝟙[∇²γ] = 𝟙[∂φ]⋅𝟙[∇²α] ∨ 𝟙[∂²φ]⋅(𝟙[∇α] ∨ 𝟙[∇α]ᵀ)

@noinline function hessian_tracer_1_to_1(
function hessian_tracer_1_to_1(
t::T, is_der1_zero::Bool, is_der2_zero::Bool
) where {P<:AbstractHessianPattern,T<:HessianTracer{P}}
if isemptytracer(t) # TODO: add test
Expand Down Expand Up @@ -65,7 +65,7 @@ function generate_code_hessian_1_to_1(M::Symbol, f::Function)
expr_hessiantracer = quote
## HessianTracer
function $M.$fname(t::$SCT.HessianTracer)
return $SCT.hessian_tracer_1_to_1(t, $is_der1_zero_g, $is_der2_zero_g)
return @noinline $SCT.hessian_tracer_1_to_1(t, $is_der1_zero_g, $is_der2_zero_g)
end
end

Expand All @@ -85,7 +85,7 @@ function generate_code_hessian_1_to_1(M::Symbol, f::Function)
t = $SCT.tracer(d)
is_der1_zero = $SCT.is_der1_zero_local($M.$fname, x)
is_der2_zero = $SCT.is_der2_zero_local($M.$fname, x)
t_out = $SCT.hessian_tracer_1_to_1(t, is_der1_zero, is_der2_zero)
t_out = @noinline $SCT.hessian_tracer_1_to_1(t, is_der1_zero, is_der2_zero)
return $SCT.Dual(p_out, t_out)
end
end
Expand All @@ -96,7 +96,7 @@ end

## 2-to-1

@noinline function hessian_tracer_2_to_1(
function hessian_tracer_2_to_1(
tx::T,
ty::T,
is_der1_arg1_zero::Bool,
Expand Down Expand Up @@ -189,7 +189,7 @@ function generate_code_hessian_2_to_1(

expr_tracer_tracer = quote
function $M.$fname(tx::T, ty::T) where {T<:$SCT.HessianTracer}
return $SCT.hessian_tracer_2_to_1(
return @noinline $SCT.hessian_tracer_2_to_1(
tx,
ty,
$is_der1_arg1_zero_g,
Expand Down Expand Up @@ -232,7 +232,7 @@ function generate_code_hessian_2_to_1(
is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y)
is_der2_arg2_zero = $SCT.is_der2_arg2_zero_local($M.$fname, x, y)
is_der_cross_zero = $SCT.is_der_cross_zero_local($M.$fname, x, y)
t_out = $SCT.hessian_tracer_2_to_1(
t_out = @noinline $SCT.hessian_tracer_2_to_1(
tx,
ty,
is_der1_arg1_zero,
Expand Down Expand Up @@ -263,12 +263,16 @@ function generate_code_hessian_2_to_1_typed(

expr_tracer_type = quote
function $M.$fname(tx::$SCT.HessianTracer, y::$Z)
return $SCT.hessian_tracer_1_to_1(tx, $is_der1_arg1_zero_g, $is_der2_arg1_zero_g)
return @noinline $SCT.hessian_tracer_1_to_1(
tx, $is_der1_arg1_zero_g, $is_der2_arg1_zero_g
)
end
end
expr_type_tracer = quote
function $M.$fname(x::$Z, ty::$SCT.HessianTracer)
return $SCT.hessian_tracer_1_to_1(ty, $is_der1_arg2_zero_g, $is_der2_arg2_zero_g)
return @noinline $SCT.hessian_tracer_1_to_1(
ty, $is_der1_arg2_zero_g, $is_der2_arg2_zero_g
)
end
end

Expand All @@ -288,7 +292,9 @@ function generate_code_hessian_2_to_1_typed(
tx = $SCT.tracer(dx)
is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y)
is_der2_arg1_zero = $SCT.is_der2_arg1_zero_local($M.$fname, x, y)
t_out = $SCT.hessian_tracer_1_to_1(tx, is_der1_arg1_zero, is_der2_arg1_zero)
t_out = @noinline $SCT.hessian_tracer_1_to_1(
tx, is_der1_arg1_zero, is_der2_arg1_zero
)
return $SCT.Dual(p_out, t_out)
end
end
Expand All @@ -309,7 +315,9 @@ function generate_code_hessian_2_to_1_typed(
ty = $SCT.tracer(dy)
is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y)
is_der2_arg2_zero = $SCT.is_der2_arg2_zero_local($M.$fname, x, y)
t_out = $SCT.hessian_tracer_1_to_1(ty, is_der1_arg2_zero, is_der2_arg2_zero)
t_out = @noinline $SCT.hessian_tracer_1_to_1(
ty, is_der1_arg2_zero, is_der2_arg2_zero
)
return $SCT.Dual(p_out, t_out)
end
end
Expand All @@ -319,7 +327,7 @@ end

## 1-to-2

@noinline function hessian_tracer_1_to_2(
function hessian_tracer_1_to_2(
t::T,
is_der1_out1_zero::Bool,
is_der2_out1_zero::Bool,
Expand All @@ -344,7 +352,7 @@ function generate_code_hessian_1_to_2(M::Symbol, f::Function)

expr_hessiantracer = quote
function $M.$fname(t::$SCT.HessianTracer)
return $SCT.hessian_tracer_1_to_2(
return @noinline $SCT.hessian_tracer_1_to_2(
t,
$is_der1_out1_zero_g,
$is_der2_out1_zero_g,
Expand Down Expand Up @@ -375,7 +383,7 @@ function generate_code_hessian_1_to_2(M::Symbol, f::Function)
is_der2_out1_zero = $SCT.is_der2_out1_zero_local($M.$fname, x)
is_der1_out2_zero = $SCT.is_der1_out2_zero_local($M.$fname, x)
is_der2_out2_zero = $SCT.is_der2_out2_zero_local($M.$fname, x)
t_out1, t_out2 = $SCT.hessian_tracer_1_to_2(
t_out1, t_out2 = @noinline $SCT.hessian_tracer_1_to_2(
d,
is_der1_out1_zero,
is_der2_out1_zero,
Expand Down
1 change: 0 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ ADNLPModels = "54578032-b7ea-4c30-94aa-7cbd1cce6c9a"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand Down
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ Pkg.develop(;
)

using SparseConnectivityTracer
using Compat: pkgversion
using Documenter: Documenter, DocMeta
using Test

Expand Down
1 change: 0 additions & 1 deletion test/test_gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ using SparseConnectivityTracer
using SparseConnectivityTracer: GradientTracer, Dual, MissingPrimalError
using Test

using Compat: Returns
using Random: rand, GLOBAL_RNG
using LinearAlgebra: det, dot, logdet

Expand Down
Loading