Skip to content

Commit

Permalink
Merge pull request #203 from astro-group-bristol/fergus/fix-thick-dis…
Browse files Browse the repository at this point in the history
…c-tf

Attempt to fix thick disc transfer functions
  • Loading branch information
fjebaker authored Jun 27, 2024
2 parents ba632e6 + c9c7907 commit 4ff82e9
Show file tree
Hide file tree
Showing 10 changed files with 131 additions and 45 deletions.
2 changes: 1 addition & 1 deletion lib/GradusSpectralModels/src/GradusSpectralModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function LineProfile(
E₀ = FitParam(1.0),
kwargs...,
)
setup = integration_setup(profile, table((get_value(θ), get_value(a))); kwargs...)
setup = integration_setup(profile, table((get_value(a), get_value(θ))); kwargs...)
LineProfile((; setup = setup, table = table), K, a, θ, rin, rout, E₀)
end

Expand Down
8 changes: 4 additions & 4 deletions src/corona/profiles/radial.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
struct RadialDiscProfile{T,I} <: AbstractDiscProfile
radii::Vector{T}
ε::Vector{T}
t::Vector{T}
struct RadialDiscProfile{V<:AbstractVector,I} <: AbstractDiscProfile
radii::V
ε::V
t::V
interp_ε::I
interp_t::I
end
Expand Down
2 changes: 1 addition & 1 deletion src/interpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function _tuple_set(tuple::NTuple{N}, index, v)::NTuple{N} where {N}
elseif index == N
(tuple[1:end-1]..., v)
else
(tuple[1:index-1], v, tuple[index+1:end])
(tuple[1:index-1]..., v, tuple[index+1:end]...)
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/reverberation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ function lag_frequency(
kwargs...,
) where {T}
other_kwargs, em_setup = EmissivityProfileSetup(T, spectrum; kwargs...)
solver_kwargs, tf_setup = _TransferFunctionSetup(T; other_kwargs...)
solver_kwargs, tf_setup = _TransferFunctionSetup(m, d; other_kwargs...)

prof = emissivity_profile(em_setup, m, d, model; solver_kwargs...)
t0 = continuum_time(m, x, model; solver_kwargs...)
Expand Down
2 changes: 2 additions & 0 deletions src/tracing/precision-solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ function _find_offset_for_measure(
d::AbstractAccretionGeometry,
θₒ;
zero_atol = 1e-7,
root_solver = DEFAULT_ROOT_SOLVER(),
offset_max = 20.0,
initial_r = offset_max / 2,
max_time = 2 * x[2],
Expand Down Expand Up @@ -48,6 +49,7 @@ function _find_offset_for_measure(

best = eltype(x)[0.0, 1.0]
r0_candidate, resid = root_solve(
root_solver,
_offset_objective,
initial_r,
(measure, _velfunc, _solve_geodesic, best);
Expand Down
72 changes: 48 additions & 24 deletions src/transfer-functions/cunningham-transfer-functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function _promote_disc_for_transfer_functions(::ThinDisc{T}) where {T}
plane, plane
end

struct _TransferFunctionSetup{T}
struct _TransferFunctionSetup{T,A}
h::T
θ_offset::T
"Tolerance for root finding"
Expand All @@ -15,20 +15,33 @@ struct _TransferFunctionSetup{T}
β₀::T
N::Int
N_extrema::Int
root_solver::A
end

function _TransferFunctionSetup(
m::AbstractMetric{T};
m::AbstractMetric{T},
d::AbstractAccretionGeometry;
θ_offset = T(0.6),
zero_atol = T(1e-7),
N = 80,
N_extrema = 17,
α₀ = 0,
β₀ = 0,
h = T(1e-6),
root_solver = nothing,
kwargs...,
) where {T}
setup = _TransferFunctionSetup{T}(
# specialize the algorithm depending on whether we are calculating for a thin or thick disc
_alg = if isnothing(root_solver)
if d isa AbstractThickAccretionDisc
RootsAlg()
else
NonLinearAlg()
end
else
root_solver
end
setup = _TransferFunctionSetup{T,typeof(_alg)}(
h,
θ_offset,
zero_atol,
Expand All @@ -37,6 +50,7 @@ function _TransferFunctionSetup(
convert(T, β₀),
N,
N_extrema,
_alg,
)
kwargs, setup
end
Expand Down Expand Up @@ -196,6 +210,7 @@ function _setup_workhorse_jacobian_with_kwargs(
zero_atol = setup.zero_atol,
offset_max = offset_max,
max_time = max_time,
root_solver = setup.root_solver,
β₀ = setup.β₀,
α₀ = setup.α₀,
tracer_kwargs...,
Expand Down Expand Up @@ -260,22 +275,31 @@ function _rear_workhorse(
)
function _thick_workhorse::T)::NTuple{4,T} where {T}
g, gp, r = datum_workhorse(θ)
r₊, _ = _find_offset_for_radius(
m,
x,
d,
rₑ,
θ;
initial_r = r,
zero_atol = setup.zero_atol,
offset_max = offset_max,
max_time = max_time,
β₀ = setup.β₀,
α₀ = setup.α₀,
tracer_kwargs...,
# don't echo warnings
warn = false,
)
r₊ = try
r_thick, _ = _find_offset_for_radius(
m,
x,
d,
rₑ,
θ;
initial_r = r,
zero_atol = setup.zero_atol,
root_solver = setup.root_solver,
offset_max = offset_max,
max_time = max_time,
β₀ = setup.β₀,
α₀ = setup.α₀,
tracer_kwargs...,
# don't echo warnings
warn = false,
)
r_thick
catch
# if we fail, for whatever reason, to root solve on the thick discs,
# we don't care, we just need a NaN value and then set that point to
# "not visible"
NaN
end
is_visible, J = if !isnan(r₊) && isapprox(r, r₊, atol = 1e-3)
# trace jacobian on updated impact parameters
α, β = _rθ_to_αβ(r₊, θ; α₀ = setup.α₀, β₀ = setup.β₀)
Expand All @@ -290,11 +314,12 @@ end

function _cunningham_transfer_function!(
data::_TransferDataAccumulator,
setup::_TransferFunctionSetup,
workhorse,
θiterator,
θ_offset,
rₑ,
)
θ_offset = setup.θ_offset
for (i, θ) in enumerate(θiterator)
θ_corrected = θ + 1e-4
insert_data!(data, i, θ_corrected, workhorse(θ))
Expand Down Expand Up @@ -328,7 +353,7 @@ function cunningham_transfer_function(
rₑ::T;
kwargs...,
) where {Q,T}
solver_kwargs, setup = _TransferFunctionSetup(m; kwargs...)
solver_kwargs, setup = _TransferFunctionSetup(m, d; kwargs...)
cunningham_transfer_function(setup, m, x, d, rₑ; solver_kwargs...)
end

Expand Down Expand Up @@ -362,8 +387,7 @@ function cunningham_transfer_function(
chart = chart,
solver_kwargs...,
)
gmin, gmax =
_cunningham_transfer_function!(data, workhorse, θiterator, setup.θ_offset, rₑ)
gmin, gmax = _cunningham_transfer_function!(data, setup, workhorse, θiterator, rₑ)
CunninghamTransferData(
data.data[2, :],
data.data[3, :],
Expand Down Expand Up @@ -424,7 +448,7 @@ function interpolated_transfer_branches(
radii;
kwargs...,
)
solver_kwargs, setup = _TransferFunctionSetup(m; kwargs...)
solver_kwargs, setup = _TransferFunctionSetup(m, d; kwargs...)
interpolated_transfer_branches(setup, m, x, d, radii; solver_kwargs...)
end

Expand Down
6 changes: 4 additions & 2 deletions src/transfer-functions/integration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ function _integrate_transfer_problem!(
transfer_function_radial_interpolation,
r_limits,
g_grid;
pure_radial = setup.pure_radial,
g_scale = 1,
) where {T}
g_grid_view = @views g_grid[1:end-1]
Expand All @@ -327,7 +328,7 @@ function _integrate_transfer_problem!(

Δrₑ = rₑ - r_prev
# integration weight for this annulus
θ = Δrₑ * rₑ * setup.pure_radial(rₑ) * π / (branch.gmax - branch.gmin)
θ = Δrₑ * rₑ * pure_radial(rₑ) * π / (branch.gmax - branch.gmin)

@inbounds for j in eachindex(g_grid_view)
glo = g_grid[j] / g_scale
Expand Down Expand Up @@ -355,6 +356,7 @@ function _integrate_transfer_problem!(
r_limits,
g_grid,
t_grid;
pure_radial = setup.pure_radial,
g_scale = 1,
) where {T}
g_grid_view = @views g_grid[1:end-1]
Expand All @@ -371,7 +373,7 @@ function _integrate_transfer_problem!(

Δrₑ = rₑ - r_prev
# integration weight for this annulus
θ = Δrₑ * rₑ * setup.pure_radial(rₑ) * π / (branch.gmax - branch.gmin)
θ = Δrₑ * rₑ * pure_radial(rₑ) * π / (branch.gmax - branch.gmin)

# time delay for this annuli
t_source_disc = setup.time(rₑ)
Expand Down
44 changes: 35 additions & 9 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,49 @@
abstract type AbstractRootAlgorithm end
struct RootsAlg <: AbstractRootAlgorithm end
Base.@kwdef struct NonLinearAlg{A} <: AbstractRootAlgorithm
alg::A = SimpleNonlinearSolve.SimpleBroyden()
end

DEFAULT_ROOT_SOLVER() = NonLinearAlg()

"""
root_solve(f_objective, initial_value, args)
Wrapper to different root solving backends to make root solve fast and efficient
"""
function root_solve(
f_objective,
initial_value::T,
args;
kwargs...,
) where {T<:Union{<:Number,<:SVector{1}}}
root_solve(DEFAULT_ROOT_SOLVER(), f_objective, initial_value, args; kwargs...)
end
function root_solve(
::RootsAlg,
f_objective,
initial_value::T,
args;
abstol = 1e-9,
kwargs...,
) where {T<:Union{<:Number,<:SVector{1}}}
x0 = Roots.find_zero(
r -> f_objective(r, args),
initial_value,
Roots.Order0();
atol = abstol,
)
resid = f_objective(x0, args)
x0, resid
end
function root_solve(
alg::NonLinearAlg,
f_objective,
initial_value::T,
args;
abstol = 1e-9,
kwargs...,
) where {T<:Union{<:Number,<:SVector{1}}}
# Roots.find_zero(r -> f_objective(r, args), initial_value, Roots.Order0(); atol = abstol)
x0, f = if T <: Number
function _obj_wrapper(x::SVector, p)
@inbounds SVector{1,eltype(x)}(f_objective(x[1], p))
Expand All @@ -20,14 +53,7 @@ function root_solve(
initial_value, f_objective
end
prob = SimpleNonlinearSolve.NonlinearProblem{false}(f, x0, args)
sol = solve(
prob,
SimpleNonlinearSolve.SimpleBroyden();
abstol = abstol,
reltol = abstol,
maxiters = 500,
kwargs...,
)
sol = solve(prob, alg.alg, abstol = abstol, reltol = abstol, maxiters = 500, kwargs...)
sol.u[1], sol.resid[1]
end

Expand Down
15 changes: 14 additions & 1 deletion test/transfer-functions/test-thick-disc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,17 @@ d = ShakuraSunyaev(m)
tf = cunningham_transfer_function(m, x, d, 3.0; β₀ = 1.0)

total = sum(filter(!isnan, tf.f))
@test total 12.245276038643347 atol = 1e-4
@test total 12.253422180875667 atol = 1e-4

m = KerrMetric(1.0, 0.2)
x = SVector(0.0, 10_000, deg2rad(20), 0.0)
d = ShakuraSunyaev(m; eddington_ratio = 0.2)

tf = cunningham_transfer_function(m, x, d, 5.469668466100368; β₀ = 1.0)
total = sum(filter(!isnan, tf.f))
@test total 20.83469 atol = 1e-4

# the transfer function here is pretty horrible as it's almost impossible to actually
# see; this is a test to make sure it doesn't error
# an offset to the isco of 4-e2 resolves this, but that's quite a lot
tf = cunningham_transfer_function(m, x, d, Gradus.isco(m) + 1e-2; β₀ = 1.0)
23 changes: 21 additions & 2 deletions test/unit/interpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ ff(x, y) = 3x^2 + x * y - sin(y)
ff(x) = SVector(ff(x[1], x[2]))
X1 = collect(range(0, 1, 1000))
X2 = collect(range(0, 1, 1000))
vals = reshape([ff(x, y) for y in X1, x in X2], (length(X1), length(X2)))
vals = reshape([ff(x, y) for x in X1, y in X2], (length(X1), length(X2)))
cache = MultilinearInterpolator{2}(vals)

# check that dual cache works too
function _interpolate_wrapper(cache, X1, X2, vals)
function _f(x)
x2, x1 = x
x1, x2 = x
SVector(interpolate!(cache, (X1, X2), vals, (x1, x2)))
end
end
Expand Down Expand Up @@ -89,3 +89,22 @@ intp = Gradus.interpolate!(cache, (X1, X2), vals, (0.0, 1.5))
intp = Gradus.interpolate!(cache, (X1, X2), vals, (0.0, 1.5))
@test intp.a == [1.5]
@test intp.b == [1.5 0; 0 1.5]

# try higher dimensional interpolation

X1 = range(0.0, 1.0, 3)
X2 = range(1.0, 2.0, 4)
X3 = range(-1.0, 0.0, 2)

f3(x, y, z) = 2x + 3y + 7z
vals = [f3(x, y, z) for x in X1, y in X2, z in X3]

cache = MultilinearInterpolator{3}(vals)
intp = Gradus.interpolate!(cache, (X1, X2, X3), vals, (0.5, 1.5, -0.5))
@test intp == f3(0.5, 1.5, -0.5)

intp = Gradus.interpolate!(cache, (X1, X2, X3), vals, (1.0, 1.1, -0.1))
@test intp == f3(1.0, 1.1, -0.1)

intp = Gradus.interpolate!(cache, (X1, X2, X3), vals, (1.0, 1.9, -0.1))
@test intp == f3(1.0, 1.9, -0.1)

0 comments on commit 4ff82e9

Please sign in to comment.