Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
getzze committed Sep 25, 2024
1 parent 4cf6987 commit e658a4c
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 80 deletions.
4 changes: 3 additions & 1 deletion src/ipod.jl
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,9 @@ From Maronna et al., Robust Statistics: Theory and Methods, Equation 4.49
"""
location_variance(r::IPODRegression, sqr::Bool=false) = dispersion(r, sqr)

StatsAPI.stderror(r::IPODRegression) = location_variance(r, false) .* sqrt.(abs.(diag(vcov(r))))
function StatsAPI.stderror(r::IPODRegression)
return location_variance(r, false) .* sqrt.(abs.(diag(vcov(r))))
end

## Loglikelihood of the full model
## l = Σi log fi = Σi log( 1/(σ * Z) exp( - ρ(ri/σ) ) ) = -n (log σ + log Z) - Σi ρ(ri/σ)
Expand Down
4 changes: 1 addition & 3 deletions src/penalties.jl
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,7 @@ function proximal!(
# out[i] = zero(T)
# else
x1 = sign(x[i]) * min(λ, max(0, abs(x[i]) - λ / step))
x2 =
sign(x[i]) *
min(λ * γ, max(λ, (abs(x[i]) * (γ - 1) - λ * γ / step) / (γ - 2)))
x2 = sign(x[i]) * min(λ * γ, max(λ, (abs(x[i]) * (γ - 1) - λ * γ / step) / (γ - 2)))
x3 = sign(x[i]) * min(λ * γ, abs(x[i]))
sols = [x1, x2, x3]
ind = argmin(convex_approx.(p, sols, x[i], step))
Expand Down
83 changes: 37 additions & 46 deletions src/regularizedpred.jl
Original file line number Diff line number Diff line change
Expand Up @@ -447,27 +447,20 @@ struct CGDRegPred{T<:BlasReal,M<:AbstractMatrix{T},V<:Vector{T},P<:PenaltyFuncti
if !isseparable(penalty)
error(
"Coordinate Gradient Descent is only allowed with a separable penalty: " *
"$(penalty)"
"$(penalty)",
)
end

params = CGDParams(outerloop)

return new{T,M,typeof(invvar),P}(
X,
zeros(T, m),
zeros(T, m),
Σ,
penalty,
invvar,
params,
zeros(T, m),
X, zeros(T, m), zeros(T, m), Σ, penalty, invvar, params, zeros(T, m)
)
end
end

function CGDRegPred(
X::M, penalty::P, wts::V, outerloop::Bool=true,
X::M, penalty::P, wts::V, outerloop::Bool=true
) where {M<:AbstractMatrix{T},V<:AbstractVector,P<:PenaltyFunction} where {T<:BlasReal}
n, m = size(X)
ll = size(wts, 1)
Expand Down Expand Up @@ -525,7 +518,7 @@ function update_βμ!(
outerM = p.params.outerloop ? m : 1
update_all = p.params.update_all

# copyto!(newβ, β)
# copyto!(newβ, β)
# outer loop
@inbounds for _ in 1:outerM

Expand All @@ -539,19 +532,19 @@ function update_βμ!(

# remove component due to index j in μ
# µ -= X[:, j] * β[j]
# if !iszero(βj)
# @inbounds @simd for i in eachindex(μ, Xj)
# μ[i] -= Xj[i] * β[j]
# end
## broadcast!(muladd, μ, -β[j], Xj, μ)
# end
# if !iszero(βj)
# @inbounds @simd for i in eachindex(μ, Xj)
# μ[i] -= Xj[i] * β[j]
# end
## broadcast!(muladd, μ, -β[j], Xj, μ)
# end

newβ[j] = βj
if isempty(wts)
@inbounds @simd for i in eachindex(axes(X, 1), res, Xj)
newβ[j] += Xj[i] * res[i]
# @inbounds @simd for i in eachindex(axes(X, 1), y, μ, Xj)
# newβ[j] += Xj[i] * (y[i] - μ[i] + Xj[i] * βj)
# @inbounds @simd for i in eachindex(axes(X, 1), y, μ, Xj)
# newβ[j] += Xj[i] * (y[i] - μ[i] + Xj[i] * βj)
end
else
@inbounds @simd for i in eachindex(axes(X, 1), res, Xj, wts)
Expand All @@ -564,18 +557,18 @@ function update_βμ!(
# verbose && println("∇βj before prox: $(gradβ)")
# Proximal operator on a single coordinate
proximal!(pen, newβ, j, newβ, invvar[j] * σ2)
# proximal!(pen, β, j, newβ, invvar[j] * σ2)
# proximal!(pen, β, j, newβ, invvar[j] * σ2)
# verbose && println("β after prox: $(p.β)")

# re-add component due to index j in μ
# μj = p.X[:, j] * p.β[j]
# p.µ .+= μj
## if !iszero(β[j])
## @inbounds @simd for i in eachindex(μ, Xj)
## μ[i] += Xj[i] * β[j]
## end
### broadcast!(muladd, μ, β[j], view(X, :, j), μ)
## end
## if !iszero(β[j])
## @inbounds @simd for i in eachindex(μ, Xj)
## μ[i] += Xj[i] * β[j]
## end
### broadcast!(muladd, μ, β[j], view(X, :, j), μ)
## end
Δβ = newβ[j] - βj
if !iszero(Δβ)
β[j] = newβ[j]
Expand Down Expand Up @@ -630,9 +623,7 @@ mutable struct FISTAParams{T<:AbstractFloat}
function FISTAParams{T}(
restart::Bool, use_backtracking::Bool, bt_maxiter::Integer=20, bt_delta::Real=0.5
) where {T<:AbstractFloat}
return new{T}(
restart, use_backtracking, bt_maxiter, bt_delta, one(T), one(T), 1,
)
return new{T}(restart, use_backtracking, bt_maxiter, bt_delta, one(T), one(T), 1)
end
end

Expand Down Expand Up @@ -913,7 +904,7 @@ struct AMARegPred{
T<:BlasReal,
M<:AbstractMatrix{T},
V<:Vector{T},
# C,
# C,
P<:PenaltyFunction,
M2<:AbstractMatrix{T},
V2<:AbstractVector{T},
Expand All @@ -932,8 +923,8 @@ struct AMARegPred{
scratchbeta::V
"`penbeta0`: vector of length `p`, used in [`penalized_coef`](@ref) method"
penbeta0::V
# "`chol`: cholesky factorization"
# chol::C
# "`chol`: cholesky factorization"
# chol::C
"`wXt`: transpose of the (weighted) model matrix"
wXt::M
"`A`: matrix of the constraint equation `A . u - b - v = 0`"
Expand Down Expand Up @@ -961,12 +952,12 @@ struct AMARegPred{
) where {M<:AbstractMatrix{T},P<:PenaltyFunction} where {T<:BlasReal}
m = size(Σ, 1)

# chol = cholesky(Σ)
# chol = cholesky(Σ)
bhat = A * b
params = AMAParams{T}(restart)

beta0 = zeros(T, m)
# return new{T,M,typeof(beta0),typeof(chol),P,typeof(A),typeof(b)}(
# return new{T,M,typeof(beta0),typeof(chol),P,typeof(A),typeof(b)}(
return new{T,M,typeof(beta0),P,typeof(A),typeof(b)}(
X,
beta0,
Expand All @@ -975,7 +966,7 @@ struct AMARegPred{
penalty,
zeros(T, m),
zeros(T, m),
# chol,
# chol,
wXt,
A,
bhat,
Expand Down Expand Up @@ -1055,7 +1046,7 @@ function update_beta!(
wXt = p.wXt
βkp1 = p.delbeta
scratch = p.scratchbeta
# chol = p.chol
# chol = p.chol
A = p.A
b = p.b

Expand Down Expand Up @@ -1085,7 +1076,7 @@ function update_beta!(
# βkp1 = chol \ (wXt * wrky - ρ * σ2 * A' * whatk)
scratch = mul!(scratch, A', whatk)
scratch = mul!(scratch, wXt, y, 1, -ρ * σ2)
# βkp1 = chol \ scratch
# βkp1 = chol \ scratch
cg!(βkp1, p.Σ, scratch)
# vkp1 = proximal(penalty(p), A * βkp1 - b + whatk, 1/ρ)
scratch = mul!(copyto!(scratch, b), A, βkp1, 1, -1)
Expand Down Expand Up @@ -1217,7 +1208,7 @@ struct ADMMRegPred{
T<:BlasReal,
M<:AbstractMatrix{T},
V<:Vector{T},
# C,
# C,
P<:PenaltyFunction,
M2<:AbstractMatrix{T},
V2<:AbstractVector{T},
Expand All @@ -1236,8 +1227,8 @@ struct ADMMRegPred{
scratchbeta::V
"`penbeta0`: vector of length `p`, used in [`penalized_coef`](@ref) method"
penbeta0::V
# "`chol`: cholesky factorization"
# chol::C
# "`chol`: cholesky factorization"
# chol::C
"`wXt`: transpose of the (weighted) model matrix"
wXt::M
"`A`: matrix of the constraint equation `A . u - b - v = 0`"
Expand Down Expand Up @@ -1269,12 +1260,12 @@ struct ADMMRegPred{
) where {M<:AbstractMatrix{T},P<:PenaltyFunction} where {T<:BlasReal}
m = size(Σ, 1)

# chol = cholesky(sparse(Σ))
# chol = cholesky(sparse(Σ))
bhat = A * b
params = ADMMParams{T}(restart, adapt)

beta0 = zeros(T, m)
# return new{T,M,typeof(beta0),typeof(chol),P,typeof(A),typeof(b)}(
# return new{T,M,typeof(beta0),typeof(chol),P,typeof(A),typeof(b)}(
return new{T,M,typeof(beta0),P,typeof(A),typeof(b)}(
X,
beta0,
Expand All @@ -1283,7 +1274,7 @@ struct ADMMRegPred{
penalty,
zeros(T, m),
zeros(T, m),
# chol,
# chol,
wXt,
A,
bhat,
Expand Down Expand Up @@ -1408,7 +1399,7 @@ function update_beta!(
broadcast!(+, βkp1, broadcast!(-, βkp1, vhatk, whatk), b) # dumb βkp1
mul!(scratch, A', βkp1) # dumb βkp1
mul!(scratch, wXt, y, 1, ρ * σ2)
# βkp1 = facΣρ \ scratch
# βkp1 = facΣρ \ scratch
cg!(βkp1, wrkΣ, scratch)
# proximal!(penalty(p), vkp1, A * βkp1 - b + whatk, 1/ρ)
scratch = mul!(copyto!(scratch, b), A, βkp1, 1, -1)
Expand Down Expand Up @@ -1457,8 +1448,8 @@ function updatepred!(p::ADMMRegPred, σ::Real; verbose::Bool=false, force::Bool=

# Refactorize
# TODO: should be improved
# cholesky!(p.chol, sparse(Σ + ρ * σ^2 * Hermitian(A'A)))
# p.wrkΣ = Hermitian(p.Σ + σ2 * ρ * A'A)
# cholesky!(p.chol, sparse(Σ + ρ * σ^2 * Hermitian(A'A)))
# p.wrkΣ = Hermitian(p.Σ + σ2 * ρ * A'A)
copyto!(p.wrkΣ, p.Σ)
mul!(p.wrkΣ, A', A, ρ * σ^2, 1)
copyto!(p.wrkΣ, Hermitian(p.wrkΣ))
Expand Down
8 changes: 7 additions & 1 deletion src/robustlinearmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1665,7 +1665,13 @@ function StatsAPI.fit!(
m::RobustLinearModel{T,R,P};
correct_leverage::Bool=false,
updatescale::Bool=false,
maxiter::Integer=(P<:CGDRegPred) ? 10_000 : (P<:FISTARegPred) ? 1_000 : 100,
maxiter::Integer=if (P <: CGDRegPred)
10_000
elseif (P <: FISTARegPred)
1_000
else
100
end,
minstepfac::Real=1e-3,
atol::Real=1e-8,
rtol::Real=1e-7,
Expand Down
3 changes: 1 addition & 2 deletions test/data/Animals2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,5 +214,4 @@ sX = SparseMatrixCSC(X)
y = data.logBrain
nt = (; logBrain=data.logBrain, logBody=data.logBody)

data_tuples = ((form, data), (form, nt), (X, y), (sX, y))
;
data_tuples = ((form, data), (form, nt), (X, y), (sX, y));
3 changes: 1 addition & 2 deletions test/data/starsCYG.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,5 +111,4 @@ sX2 = SparseMatrixCSC(X2)
y2 = data2.logLight
nt2 = (; logTemp=logTemp, logLight=logLight)

data2_tuples = ((form2, data2), (form2, nt2), (X2, y2), (sX2, y2))
;
data2_tuples = ((form2, data2), (form2, nt2), (X2, y2), (sX2, y2));
38 changes: 29 additions & 9 deletions test/ipod.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,13 @@ end
@testset "data type: $(typeof(A))" for (A, b) in ((X, y), (sX, y))
name = "Θ-IPOD($(l), $(pen); method=$(method)),\t"
name *= (
(A isa FormulaTerm) ? "formula" :
(A isa SparseMatrixCSC) ? "sparse " : "dense "
if (A isa FormulaTerm)
"formula"
elseif (A isa SparseMatrixCSC)
"sparse "
else
"dense "
end
)

m0 = fit(RobustLinearModel, A, b, est; method=:chol, initial_scale=σ)
Expand Down Expand Up @@ -130,7 +135,7 @@ end
def_rtol = 1e-5

kwargs = (; initial_scale=1, maxiter=10_000)

@testset "solver method $(method)" for method in pen_methods
rtol = def_rtol
if method === :cgd
Expand All @@ -145,8 +150,13 @@ end
@testset "data type: $(typeof(A))" for (A, b) in data_tuples
name = "Θ-IPOD(L2Loss, $(pen2); method=$(method)),\t"
name *= (
(A isa FormulaTerm) ? "formula" :
(A isa SparseMatrixCSC) ? "sparse " : "dense "
if (A isa FormulaTerm)
"formula"
elseif (A isa SparseMatrixCSC)
"sparse "
else
"dense "
end
)

m2 = ipod(A, b, loss1, pen2; method=method, kwargs...)
Expand Down Expand Up @@ -208,8 +218,13 @@ end
@testset "Omit intercept, data type: $(typeof(A))" for (A, b) in data_tuples
name = "Θ-IPOD(L2Loss, $(pen); method=:auto),\t"
name *= (
(A isa FormulaTerm) ? "formula" :
(A isa SparseMatrixCSC) ? "sparse " : "dense "
if (A isa FormulaTerm)
"formula"
elseif (A isa SparseMatrixCSC)
"sparse "
else
"dense "
end
)

m1 = fit(IPODRegression, A, b, loss1, pen; kwargs...)
Expand Down Expand Up @@ -255,8 +270,13 @@ end
@testset "data type: $(typeof(A))" for (A, b) in data_tuples
name = "Θ-IPOD(L2Loss, $(pen); method=$(method)),\t"
name *= (
(A isa FormulaTerm) ? "formula" :
(A isa SparseMatrixCSC) ? "sparse " : "dense "
if (A isa FormulaTerm)
"formula"
elseif (A isa SparseMatrixCSC)
"sparse "
else
"dense "
end
)

m1 = fit(IPODRegression, A, b, loss1, pen; method=method, kwargs...)
Expand Down
13 changes: 9 additions & 4 deletions test/penalties.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,13 @@ t = 10 * randn(rng, 30)
@testset "data type: $(typeof(A))" for (A, b) in data2_tuples
name = "rlm($(pen); method=$(method)),\t"
name *= (
(A isa FormulaTerm) ? "formula" :
(A isa SparseMatrixCSC) ? "sparse " : "dense "
if (A isa FormulaTerm)
"formula"
elseif (A isa SparseMatrixCSC)
"sparse "
else
"dense "
end
)

m2 = rlm(A, b, pen; method=method, kwargs...)
Expand Down Expand Up @@ -166,8 +171,8 @@ t = 10 * randn(rng, 30)
@test all(abs.(var2) .>= abs.(var1))
elseif f in (leverage_weights,)
@test all(abs.(var2) .<= abs.(var1))
# elseif f in (confint,)
# @test isapprox(var1, var2; rtol=1e-1)
# elseif f in (confint,)
# @test isapprox(var1, var2; rtol=1e-1)
elseif f in (confint, projectionmatrix)
continue
else
Expand Down
Loading

0 comments on commit e658a4c

Please sign in to comment.