diff --git a/src/transforms/distribution_transform.jl b/src/transforms/distribution_transform.jl index d1e60bb67..4a84a0f96 100644 --- a/src/transforms/distribution_transform.jl +++ b/src/transforms/distribution_transform.jl @@ -319,7 +319,12 @@ function apply_dist_trafo(trg_d::DT, src_d::DT, src_v) where {DT <: StdMvDist} end -_dist_params_numtype(d::Distribution) = promote_type(map(typeof, Distributions.params(d))...) +_dist_params_numtype(d::Distribution) = realnumtype(typeof(params(d))) + +function ChainRulesCore.rrule(::typeof(_dist_params_numtype), d::Distribution) + _dist_params_numtype_pullback(ΔΩ) = (NoTangent(), NoTangent()) + _dist_params_numtype(d), _dist_params_numtype_pullback +end @inline _trafo_cdf(d::Distribution{Univariate,Continuous}, x::Real) = _trafo_cdf_impl(_dist_params_numtype(d), d, x) @@ -370,7 +375,7 @@ end @inline function _eval_dist_trafo_func(f::typeof(_trafo_cdf), d::Distribution{Univariate,Continuous}, src_v::Real) - R_V = float(promote_type(typeof(src_v), eltype(params(d)))) + R_V = float(promote_type(typeof(src_v), _dist_params_numtype(d))) if insupport(d, src_v) trg_v = f(d, src_v) convert(R_V, trg_v) @@ -380,7 +385,7 @@ end end @inline function _eval_dist_trafo_func(f::typeof(_trafo_quantile), d::Distribution{Univariate,Continuous}, src_v::Real) - R_V = float(promote_type(typeof(src_v), eltype(params(d)))) + R_V = float(promote_type(typeof(src_v), _dist_params_numtype(d))) if 0 <= src_v <= 1 trg_v = f(d, src_v) convert(R_V, trg_v)