From 6a1063daed42f66214524c53d256fa7c5b2da6d7 Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Sat, 13 Aug 2022 12:22:15 -0700 Subject: [PATCH] Add `ConvergenceException` from StatsBase I've added the folks who contributed the type and improvements to it and its documentation as coauthors of the commit. (Apologies if I missed anyone.) Co-authored-by: Roger Herikstad Co-authored-by: Galen Lynch Co-authored-by: Alexander Morley --- src/statisticalmodel.jl | 38 ++++++++++++++++++++++++++++++++++++++ test/statisticalmodel.jl | 17 +++++++++++++++-- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/src/statisticalmodel.jl b/src/statisticalmodel.jl index e8f4f68..3b03adc 100644 --- a/src/statisticalmodel.jl +++ b/src/statisticalmodel.jl @@ -314,3 +314,41 @@ function adjr2(model::StatisticalModel, variant::Symbol) end const adjr² = adjr2 + +""" + ConvergenceException(iterations::Int, lastchange::Real=NaN, tolerance::Real=NaN, + message::String="") + +The fitting procedure failed to converge in `iterations` number of iterations. Typically +this is because the `lastchange` between the objective in the final and penultimate +iterations was greater than the specified `tolerance`. Further information can be provided +by `message`. +""" +struct ConvergenceException{T<:Real} <: Exception + iterations::Int + lastchange::T + tolerance::T + message::String + + function ConvergenceException(iterations, lastchange=NaN, tolerance=NaN, message="") + if tolerance > lastchange + throw(ArgumentError("can't construct `ConvergenceException` with change " * + "less than tolerance; got $lastchange and $tolerance")) + end + T = promote_type(typeof(lastchange), typeof(tolerance)) + return new{T}(iterations, lastchange, tolerance, message) + end +end + +function Base.showerror(io::IO, ce::ConvergenceException) + print(io, "failure to converge after ", ce.iterations, " iterations") + if !isnan(ce.lastchange) + print(io, "; last change between iterations (", ce.lastchange, ") was greater ", + "than tolerance (", ce.tolerance, ")") + end + print(io, '.') + if !isempty(ce.message) + print(io, ' ', ce.message) + end + return nothing +end diff --git a/test/statisticalmodel.jl b/test/statisticalmodel.jl index 114a68e..c79a17c 100644 --- a/test/statisticalmodel.jl +++ b/test/statisticalmodel.jl @@ -1,7 +1,8 @@ module TestStatisticalModel using Test, StatsAPI -using StatsAPI: StatisticalModel, stderror, aic, aicc, bic, r2, r², adjr2, adjr² +using StatsAPI: ConvergenceException, StatisticalModel, stderror, aic, aicc, bic, + r2, r², adjr2, adjr² struct MyStatisticalModel <: StatisticalModel end @@ -36,4 +37,16 @@ StatsAPI.nobs(::MyStatisticalModel) = 100 @test adjr2 === adjr² end -end # module TestStatisticalModel \ No newline at end of file +@testset "ConvergenceException" begin + fail = "failure to converge after 10 iterations" + chgtol = "last change between iterations (0.2) was greater than tolerance (0.1)" + msg = "Try changing maxiter." + @test sprint(showerror, ConvergenceException(10)) == "$fail." + @test sprint(showerror, ConvergenceException(10, 0.2, 0.1)) == "$fail; $chgtol." + @test sprint(showerror, ConvergenceException(10, 0.2, 0.1, msg)) == "$fail; $chgtol. $msg" + err = @test_throws ArgumentError ConvergenceException(10, 0.1, 0.2) + @test err.value.msg == string("can't construct `ConvergenceException` with change ", + "less than tolerance; got 0.1 and 0.2") +end + +end # module TestStatisticalModel