From 73b95f5432cdaa5fbc278a6cc6a4e74ec7ae641d Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Mon, 22 Jan 2024 13:54:48 +0100 Subject: [PATCH 1/2] Improve _dist_params_numtype --- src/transforms/distribution_transform.jl | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/transforms/distribution_transform.jl b/src/transforms/distribution_transform.jl index d1e60bb67..4a84a0f96 100644 --- a/src/transforms/distribution_transform.jl +++ b/src/transforms/distribution_transform.jl @@ -319,7 +319,12 @@ function apply_dist_trafo(trg_d::DT, src_d::DT, src_v) where {DT <: StdMvDist} end -_dist_params_numtype(d::Distribution) = promote_type(map(typeof, Distributions.params(d))...) +_dist_params_numtype(d::Distribution) = realnumtype(typeof(params(d))) + +function ChainRulesCore.rrule(::typeof(_dist_params_numtype), d::Distribution) + _dist_params_numtype_pullback(ΔΩ) = (NoTangent(), NoTangent()) + _dist_params_numtype(d), _dist_params_numtype_pullback +end @inline _trafo_cdf(d::Distribution{Univariate,Continuous}, x::Real) = _trafo_cdf_impl(_dist_params_numtype(d), d, x) @@ -370,7 +375,7 @@ end @inline function _eval_dist_trafo_func(f::typeof(_trafo_cdf), d::Distribution{Univariate,Continuous}, src_v::Real) - R_V = float(promote_type(typeof(src_v), eltype(params(d)))) + R_V = float(promote_type(typeof(src_v), _dist_params_numtype(d))) if insupport(d, src_v) trg_v = f(d, src_v) convert(R_V, trg_v) @@ -380,7 +385,7 @@ end end @inline function _eval_dist_trafo_func(f::typeof(_trafo_quantile), d::Distribution{Univariate,Continuous}, src_v::Real) - R_V = float(promote_type(typeof(src_v), eltype(params(d)))) + R_V = float(promote_type(typeof(src_v), _dist_params_numtype(d))) if 0 <= src_v <= 1 trg_v = f(d, src_v) convert(R_V, trg_v) From 7c3a123c3b538ef34610c00f1b8f172a215013e1 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 29 Dec 2023 11:02:14 +0100 Subject: [PATCH 2/2] Relax FillArrays version bound --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ad0ebd5dd..ea43738c6 100644 --- a/Project.toml +++ b/Project.toml @@ -106,7 +106,7 @@ DoubleFloats = "0.9, 1" ElasticArrays = "1.2.3" EmpiricalDistributions = "0.2, 0.3.1" FFTW = "1" -FillArrays = "1.1.1" +FillArrays = "0.13, 1.1.1" Folds = "0.2" ForwardDiff = "0.10" ForwardDiffPullbacks = "0.1.1, 0.2"