Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bring VIF back #307

Merged
merged 4 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions SDeMo/Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
name = "SDeMo"
uuid = "3e5feb82-bcca-434d-9cd5-c11731a21467"
authors = ["Timothée Poisot <[email protected]>"]
version = "0.0.6"
version = "0.0.7"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
Expand All @@ -17,11 +16,12 @@ TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"

[compat]
Distributions = "0.25"
GLM = "1.9"
JSON = "0.21"
LinearAlgebra = "1"
MultivariateStats = "0.10"
Random = "1"
Statistics = "1"
StatsAPI = "1.7"
StatsBase = "0.34"
TestItems = "1.0"
Statistics = "1.10"
julia = "1.8"
7 changes: 7 additions & 0 deletions SDeMo/docs/src/features.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,10 @@ forwardselection!
```@docs
variableimportance
```

## Variance Inflation Factor

```@docs
stepwisevif!
vif
```
10 changes: 6 additions & 4 deletions SDeMo/src/SDeMo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@ module SDeMo

using TestItems

import GLM
import JSON
import StatsAPI
using Distributions
using LinearAlgebra
using MultivariateStats
using StatsBase
using Random
using Statistics
import JSON
using LinearAlgebra
using StatsBase

# Demo data
include("utilities/demodata.jl")
Expand Down Expand Up @@ -76,6 +75,9 @@ export ci
include("variables/selection.jl")
export noselection!, forwardselection!, backwardselection!

include("variables/vif.jl")
export stepwisevif!, vif

include("variables/importance.jl")
export variableimportance

Expand Down
45 changes: 32 additions & 13 deletions SDeMo/src/variables/vif.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,39 @@
vif(m) = 1 / (1 - r²(m))
demean = (x) -> (x .- mean(x; dims = 2))

stepwisevif!(model::SDM, limit; kwargs...) = stepwisevif!(model, variables(model), limit)
"""
vif(::Matrix)

function stepwisevif!(model::SDM, limit; kwargs...)
Xv = features(model)[variables(model), :]
X = (Xv .- mean(Xv; dims = 2)) ./ std(Xv; dims = 2)
vifs = zeros(Float64, length(model.v))
for i in eachindex(model.v)
linreg = GLM.lm(X[setdiff(eachindex(model.v), i), :]', X[i, :])
vifs[i] = vif(linreg)
end
if all(vifs .<= threshold)
Returns the variance inflation factor for each variable in a matrix, as the diagonal of the inverse of the correlation matrix between predictors.
"""
vif(X::Matrix{T}) where {T <: Number} = diag(inv(cor(X)))

"""
vif(::AbstractSDM, tr=:)

Returns the VIF for the variables used in a SDM, optionally restricting to some training instances (defaults to `:` for all points). The VIF is calculated on the de-meaned predictors.
"""
vif(sdm::T, tr = :) where {T <: AbstractSDM} =
vif(permutedims(demean(features(sdm)[variables(sdm),:])))

"""
stepwisevif!(model::SDM, limit, tr=:;kwargs...)

Drops the variables with the largest variance inflation from the model, until all VIFs are under the threshold. The last positional argument (defaults to `:`) is the indices to use for the VIF calculation. All keyword arguments are passed to `train!`.
"""
function stepwisevif!(model::SDM, limit, tr = :; kwargs...)
vifs = vif(model, tr)
if all(vifs .<= limit)
train!(model; kwargs...)
return model
end
drop = last(findmax(vifs))
popat!(variables(model), drop)
return stepwisevif!(model, limit; kwargs...)
end
return stepwisevif!(model, limit, tr; kwargs...)
end

@testitem "We can select variables using the VIF" begin
X, y = SDeMo.__demodata()
model = SDM(RawData, NaiveBayes, X, y)
stepwisevif!(model, 10.0)
@test length(variables(model)) < size(X, 1)
end
Loading