Skip to content

Commit

Permalink
Improve _dist_params_numtype
Browse files Browse the repository at this point in the history
  • Loading branch information
oschulz committed Jan 22, 2024
1 parent 63553c3 commit 73b95f5
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/transforms/distribution_transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,12 @@ function apply_dist_trafo(trg_d::DT, src_d::DT, src_v) where {DT <: StdMvDist}
end


_dist_params_numtype(d::Distribution) = promote_type(map(typeof, Distributions.params(d))...)
_dist_params_numtype(d::Distribution) = realnumtype(typeof(params(d)))

function ChainRulesCore.rrule(::typeof(_dist_params_numtype), d::Distribution)
_dist_params_numtype_pullback(ΔΩ) = (NoTangent(), NoTangent())
_dist_params_numtype(d), _dist_params_numtype_pullback
end


@inline _trafo_cdf(d::Distribution{Univariate,Continuous}, x::Real) = _trafo_cdf_impl(_dist_params_numtype(d), d, x)
Expand Down Expand Up @@ -370,7 +375,7 @@ end


@inline function _eval_dist_trafo_func(f::typeof(_trafo_cdf), d::Distribution{Univariate,Continuous}, src_v::Real)
R_V = float(promote_type(typeof(src_v), eltype(params(d))))
R_V = float(promote_type(typeof(src_v), _dist_params_numtype(d)))
if insupport(d, src_v)
trg_v = f(d, src_v)
convert(R_V, trg_v)
Expand All @@ -380,7 +385,7 @@ end
end

@inline function _eval_dist_trafo_func(f::typeof(_trafo_quantile), d::Distribution{Univariate,Continuous}, src_v::Real)
R_V = float(promote_type(typeof(src_v), eltype(params(d))))
R_V = float(promote_type(typeof(src_v), _dist_params_numtype(d)))
if 0 <= src_v <= 1
trg_v = f(d, src_v)
convert(R_V, trg_v)
Expand Down

0 comments on commit 73b95f5

Please sign in to comment.