Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attempt to fix thick disc transfer functions #203

Merged
merged 7 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/GradusSpectralModels/src/GradusSpectralModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function LineProfile(
E₀ = FitParam(1.0),
kwargs...,
)
setup = integration_setup(profile, table((get_value(θ), get_value(a))); kwargs...)
setup = integration_setup(profile, table((get_value(a), get_value(θ))); kwargs...)
LineProfile((; setup = setup, table = table), K, a, θ, rin, rout, E₀)
end

Expand Down
8 changes: 4 additions & 4 deletions src/corona/profiles/radial.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
struct RadialDiscProfile{T,I} <: AbstractDiscProfile
radii::Vector{T}
ε::Vector{T}
t::Vector{T}
struct RadialDiscProfile{V<:AbstractVector,I} <: AbstractDiscProfile
radii::V
ε::V
t::V
interp_ε::I
interp_t::I
end
Expand Down
2 changes: 1 addition & 1 deletion src/interpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function _tuple_set(tuple::NTuple{N}, index, v)::NTuple{N} where {N}
elseif index == N
(tuple[1:end-1]..., v)
else
(tuple[1:index-1], v, tuple[index+1:end])
(tuple[1:index-1]..., v, tuple[index+1:end]...)
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/reverberation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ function lag_frequency(
kwargs...,
) where {T}
other_kwargs, em_setup = EmissivityProfileSetup(T, spectrum; kwargs...)
solver_kwargs, tf_setup = _TransferFunctionSetup(T; other_kwargs...)
solver_kwargs, tf_setup = _TransferFunctionSetup(m, d; other_kwargs...)

prof = emissivity_profile(em_setup, m, d, model; solver_kwargs...)
t0 = continuum_time(m, x, model; solver_kwargs...)
Expand Down
2 changes: 2 additions & 0 deletions src/tracing/precision-solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ function _find_offset_for_measure(
d::AbstractAccretionGeometry,
θₒ;
zero_atol = 1e-7,
root_solver = DEFAULT_ROOT_SOLVER(),
offset_max = 20.0,
initial_r = offset_max / 2,
max_time = 2 * x[2],
Expand Down Expand Up @@ -48,6 +49,7 @@ function _find_offset_for_measure(

best = eltype(x)[0.0, 1.0]
r0_candidate, resid = root_solve(
root_solver,
_offset_objective,
initial_r,
(measure, _velfunc, _solve_geodesic, best);
Expand Down
72 changes: 48 additions & 24 deletions src/transfer-functions/cunningham-transfer-functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function _promote_disc_for_transfer_functions(::ThinDisc{T}) where {T}
plane, plane
end

struct _TransferFunctionSetup{T}
struct _TransferFunctionSetup{T,A}
h::T
θ_offset::T
"Tolerance for root finding"
Expand All @@ -15,20 +15,33 @@ struct _TransferFunctionSetup{T}
β₀::T
N::Int
N_extrema::Int
root_solver::A
end

function _TransferFunctionSetup(
m::AbstractMetric{T};
m::AbstractMetric{T},
d::AbstractAccretionGeometry;
θ_offset = T(0.6),
zero_atol = T(1e-7),
N = 80,
N_extrema = 17,
α₀ = 0,
β₀ = 0,
h = T(1e-6),
root_solver = nothing,
kwargs...,
) where {T}
setup = _TransferFunctionSetup{T}(
# specialize the algorithm depending on whether we are calculating for a thin or thick disc
_alg = if isnothing(root_solver)
if d isa AbstractThickAccretionDisc
RootsAlg()
else
NonLinearAlg()
end
else
root_solver
end
setup = _TransferFunctionSetup{T,typeof(_alg)}(
h,
θ_offset,
zero_atol,
Expand All @@ -37,6 +50,7 @@ function _TransferFunctionSetup(
convert(T, β₀),
N,
N_extrema,
_alg,
)
kwargs, setup
end
Expand Down Expand Up @@ -196,6 +210,7 @@ function _setup_workhorse_jacobian_with_kwargs(
zero_atol = setup.zero_atol,
offset_max = offset_max,
max_time = max_time,
root_solver = setup.root_solver,
β₀ = setup.β₀,
α₀ = setup.α₀,
tracer_kwargs...,
Expand Down Expand Up @@ -260,22 +275,31 @@ function _rear_workhorse(
)
function _thick_workhorse::T)::NTuple{4,T} where {T}
g, gp, r = datum_workhorse(θ)
r₊, _ = _find_offset_for_radius(
m,
x,
d,
rₑ,
θ;
initial_r = r,
zero_atol = setup.zero_atol,
offset_max = offset_max,
max_time = max_time,
β₀ = setup.β₀,
α₀ = setup.α₀,
tracer_kwargs...,
# don't echo warnings
warn = false,
)
r₊ = try
r_thick, _ = _find_offset_for_radius(
m,
x,
d,
rₑ,
θ;
initial_r = r,
zero_atol = setup.zero_atol,
root_solver = setup.root_solver,
offset_max = offset_max,
max_time = max_time,
β₀ = setup.β₀,
α₀ = setup.α₀,
tracer_kwargs...,
# don't echo warnings
warn = false,
)
r_thick
catch
# if we fail, for whatever reason, to root solve on the thick discs,
# we don't care, we just need a NaN value and then set that point to
# "not visible"
NaN
end
is_visible, J = if !isnan(r₊) && isapprox(r, r₊, atol = 1e-3)
# trace jacobian on updated impact parameters
α, β = _rθ_to_αβ(r₊, θ; α₀ = setup.α₀, β₀ = setup.β₀)
Expand All @@ -290,11 +314,12 @@ end

function _cunningham_transfer_function!(
data::_TransferDataAccumulator,
setup::_TransferFunctionSetup,
workhorse,
θiterator,
θ_offset,
rₑ,
)
θ_offset = setup.θ_offset
for (i, θ) in enumerate(θiterator)
θ_corrected = θ + 1e-4
insert_data!(data, i, θ_corrected, workhorse(θ))
Expand Down Expand Up @@ -328,7 +353,7 @@ function cunningham_transfer_function(
rₑ::T;
kwargs...,
) where {Q,T}
solver_kwargs, setup = _TransferFunctionSetup(m; kwargs...)
solver_kwargs, setup = _TransferFunctionSetup(m, d; kwargs...)
cunningham_transfer_function(setup, m, x, d, rₑ; solver_kwargs...)
end

Expand Down Expand Up @@ -362,8 +387,7 @@ function cunningham_transfer_function(
chart = chart,
solver_kwargs...,
)
gmin, gmax =
_cunningham_transfer_function!(data, workhorse, θiterator, setup.θ_offset, rₑ)
gmin, gmax = _cunningham_transfer_function!(data, setup, workhorse, θiterator, rₑ)
CunninghamTransferData(
data.data[2, :],
data.data[3, :],
Expand Down Expand Up @@ -424,7 +448,7 @@ function interpolated_transfer_branches(
radii;
kwargs...,
)
solver_kwargs, setup = _TransferFunctionSetup(m; kwargs...)
solver_kwargs, setup = _TransferFunctionSetup(m, d; kwargs...)
interpolated_transfer_branches(setup, m, x, d, radii; solver_kwargs...)
end

Expand Down
6 changes: 4 additions & 2 deletions src/transfer-functions/integration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ function _integrate_transfer_problem!(
transfer_function_radial_interpolation,
r_limits,
g_grid;
pure_radial = setup.pure_radial,
g_scale = 1,
) where {T}
g_grid_view = @views g_grid[1:end-1]
Expand All @@ -327,7 +328,7 @@ function _integrate_transfer_problem!(

Δrₑ = rₑ - r_prev
# integration weight for this annulus
θ = Δrₑ * rₑ * setup.pure_radial(rₑ) * π / (branch.gmax - branch.gmin)
θ = Δrₑ * rₑ * pure_radial(rₑ) * π / (branch.gmax - branch.gmin)

@inbounds for j in eachindex(g_grid_view)
glo = g_grid[j] / g_scale
Expand Down Expand Up @@ -355,6 +356,7 @@ function _integrate_transfer_problem!(
r_limits,
g_grid,
t_grid;
pure_radial = setup.pure_radial,
g_scale = 1,
) where {T}
g_grid_view = @views g_grid[1:end-1]
Expand All @@ -371,7 +373,7 @@ function _integrate_transfer_problem!(

Δrₑ = rₑ - r_prev
# integration weight for this annulus
θ = Δrₑ * rₑ * setup.pure_radial(rₑ) * π / (branch.gmax - branch.gmin)
θ = Δrₑ * rₑ * pure_radial(rₑ) * π / (branch.gmax - branch.gmin)

# time delay for this annuli
t_source_disc = setup.time(rₑ)
Expand Down
44 changes: 35 additions & 9 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,49 @@
abstract type AbstractRootAlgorithm end
struct RootsAlg <: AbstractRootAlgorithm end
Base.@kwdef struct NonLinearAlg{A} <: AbstractRootAlgorithm
alg::A = SimpleNonlinearSolve.SimpleBroyden()
end

DEFAULT_ROOT_SOLVER() = NonLinearAlg()

"""
root_solve(f_objective, initial_value, args)
Wrapper to different root solving backends to make root solve fast and efficient
"""
function root_solve(
f_objective,
initial_value::T,
args;
kwargs...,
) where {T<:Union{<:Number,<:SVector{1}}}
root_solve(DEFAULT_ROOT_SOLVER(), f_objective, initial_value, args; kwargs...)
end
function root_solve(
::RootsAlg,
f_objective,
initial_value::T,
args;
abstol = 1e-9,
kwargs...,
) where {T<:Union{<:Number,<:SVector{1}}}
x0 = Roots.find_zero(
r -> f_objective(r, args),
initial_value,
Roots.Order0();
atol = abstol,
)
resid = f_objective(x0, args)
x0, resid
end
function root_solve(
alg::NonLinearAlg,
f_objective,
initial_value::T,
args;
abstol = 1e-9,
kwargs...,
) where {T<:Union{<:Number,<:SVector{1}}}
# Roots.find_zero(r -> f_objective(r, args), initial_value, Roots.Order0(); atol = abstol)
x0, f = if T <: Number
function _obj_wrapper(x::SVector, p)
@inbounds SVector{1,eltype(x)}(f_objective(x[1], p))
Expand All @@ -20,14 +53,7 @@ function root_solve(
initial_value, f_objective
end
prob = SimpleNonlinearSolve.NonlinearProblem{false}(f, x0, args)
sol = solve(
prob,
SimpleNonlinearSolve.SimpleBroyden();
abstol = abstol,
reltol = abstol,
maxiters = 500,
kwargs...,
)
sol = solve(prob, alg.alg, abstol = abstol, reltol = abstol, maxiters = 500, kwargs...)
sol.u[1], sol.resid[1]
end

Expand Down
15 changes: 14 additions & 1 deletion test/transfer-functions/test-thick-disc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,17 @@ d = ShakuraSunyaev(m)
tf = cunningham_transfer_function(m, x, d, 3.0; β₀ = 1.0)

total = sum(filter(!isnan, tf.f))
@test total 12.245276038643347 atol = 1e-4
@test total 12.253422180875667 atol = 1e-4

m = KerrMetric(1.0, 0.2)
x = SVector(0.0, 10_000, deg2rad(20), 0.0)
d = ShakuraSunyaev(m; eddington_ratio = 0.2)

tf = cunningham_transfer_function(m, x, d, 5.469668466100368; β₀ = 1.0)
total = sum(filter(!isnan, tf.f))
@test total 20.83469 atol = 1e-4

# the transfer function here is pretty horrible as it's almost impossible to actually
# see; this is a test to make sure it doesn't error
# an offset to the isco of 4-e2 resolves this, but that's quite a lot
tf = cunningham_transfer_function(m, x, d, Gradus.isco(m) + 1e-2; β₀ = 1.0)
23 changes: 21 additions & 2 deletions test/unit/interpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ ff(x, y) = 3x^2 + x * y - sin(y)
ff(x) = SVector(ff(x[1], x[2]))
X1 = collect(range(0, 1, 1000))
X2 = collect(range(0, 1, 1000))
vals = reshape([ff(x, y) for y in X1, x in X2], (length(X1), length(X2)))
vals = reshape([ff(x, y) for x in X1, y in X2], (length(X1), length(X2)))
cache = MultilinearInterpolator{2}(vals)

# check that dual cache works too
function _interpolate_wrapper(cache, X1, X2, vals)
function _f(x)
x2, x1 = x
x1, x2 = x
SVector(interpolate!(cache, (X1, X2), vals, (x1, x2)))
end
end
Expand Down Expand Up @@ -89,3 +89,22 @@ intp = Gradus.interpolate!(cache, (X1, X2), vals, (0.0, 1.5))
intp = Gradus.interpolate!(cache, (X1, X2), vals, (0.0, 1.5))
@test intp.a == [1.5]
@test intp.b == [1.5 0; 0 1.5]

# try higher dimensional interpolation

X1 = range(0.0, 1.0, 3)
X2 = range(1.0, 2.0, 4)
X3 = range(-1.0, 0.0, 2)

f3(x, y, z) = 2x + 3y + 7z
vals = [f3(x, y, z) for x in X1, y in X2, z in X3]

cache = MultilinearInterpolator{3}(vals)
intp = Gradus.interpolate!(cache, (X1, X2, X3), vals, (0.5, 1.5, -0.5))
@test intp == f3(0.5, 1.5, -0.5)

intp = Gradus.interpolate!(cache, (X1, X2, X3), vals, (1.0, 1.1, -0.1))
@test intp == f3(1.0, 1.1, -0.1)

intp = Gradus.interpolate!(cache, (X1, X2, X3), vals, (1.0, 1.9, -0.1))
@test intp == f3(1.0, 1.9, -0.1)
Loading