This package provides functions to compute the non-negative garrote estimator with (or without) a penalized initial estimator.
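The package's own solver is not reproduced here, but the estimator itself is easy to illustrate. Start from an initial estimate β̂⁽⁰⁾, which can be least squares or a penalized fit such as ridge. The non-negative garrote then chooses factors c_j ≥ 0 that minimize ½‖y − Σ_j c_j β̂_j⁽⁰⁾ x_j‖² + λ Σ_j c_j, and it returns c_j β̂_j⁽⁰⁾. Any factor driven to zero drops its variable.

The base-R sketch below solves that problem by cyclic coordinate descent. It assumes centred `x` and `y` and fits no intercept. The name `nng.sketch` and its arguments are hypothetical. The scaling of λ may also differ from the one used by `cv.nnGarrote()`.

```r
# Minimal sketch of the non-negative garrote (NNG), solved by cyclic
# coordinate descent in base R. Hypothetical helper, NOT the nnGarrote solver.
# Assumes x and y are already centred (no intercept is fitted here).
nng.sketch <- function(x, y, beta.init, lambda, n.iter = 500) {
  # Rescaled design: column j of z is x[, j] * beta.init[j]
  z <- sweep(x, 2, beta.init, `*`)
  z.norm2 <- colSums(z^2)
  shrink <- rep(0, ncol(x))   # garrote factors c_j, started at zero
  resid <- as.numeric(y)      # current residual y - z %*% shrink
  for (iter in seq_len(n.iter)) {
    for (j in seq_len(ncol(x))) {
      if (z.norm2[j] == 0) next   # beta.init[j] == 0: factor stays at zero
      # Inner product of z_j with the partial residual that excludes variable j
      rho.j <- sum(z[, j] * resid) + z.norm2[j] * shrink[j]
      # Exact minimizer over c_j >= 0 of the penalized least-squares objective
      new.c <- max(0, (rho.j - lambda) / z.norm2[j])
      resid <- resid - z[, j] * (new.c - shrink[j])
      shrink[j] <- new.c
    }
  }
  # Final garrote coefficients: non-negative rescaling of the initial estimate
  list(shrink = shrink, coef = shrink * beta.init)
}
```

With `lambda = 0` and a least-squares initial estimate, the factors converge to 1 and the fit reproduces least squares. A larger `lambda` pushes factors to exactly zero. Because every factor is non-negative, each surviving coefficient keeps the sign of its initial estimate.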
You can install the stable version from CRAN.
```r
install.packages("nnGarrote", dependencies = TRUE)
```
You can install the development version from GitHub.
Here is some code to compute the non-negative garrote estimator with ridge regression as an initial estimator, and compare it with ridge regression without the additional garrote shrinkage.
```r
# Load the package (provides cv.nnGarrote() and its predict() method)
library(nnGarrote)

# Setting the parameters
p <- 100
n <- 500
n.test <- 5000
sparsity <- 0.2
rho <- 0.5
SNR <- 3

# Generating the coefficients: the first p.active are nonzero
p.active <- floor(p*sparsity)
a <- 4*log(n)/sqrt(n)
neg.prob <- 0.2
nonzero.betas <- (-1)^(rbinom(p.active, 1, neg.prob))*(a + abs(rnorm(p.active)))
true.beta <- c(nonzero.betas, rep(0, p - p.active))

# Two groups correlation structure: the active variables are equicorrelated
# (correlation rho); the remaining variables are uncorrelated
Sigma.rho <- matrix(0, p, p)
Sigma.rho[1:p.active, 1:p.active] <- rho
diag(Sigma.rho) <- 1
# Noise level chosen so that t(true.beta) %*% Sigma.rho %*% true.beta / sigma^2 = SNR
sigma.epsilon <- as.numeric(sqrt((t(true.beta) %*% Sigma.rho %*% true.beta)/SNR))

# Simulate some data
x.train <- mvnfast::rmvn(n, mu=rep(0,p), sigma=Sigma.rho)
y.train <- 1 + x.train %*% true.beta + rnorm(n=n, mean=0, sd=sigma.epsilon)
x.test <- mvnfast::rmvn(n.test, mu=rep(0,p), sigma=Sigma.rho)
y.test <- 1 + x.test %*% true.beta + rnorm(n.test, sd=sigma.epsilon)

# Applying the NNG with Ridge as an initial estimator
# (with initial.model="glmnet", alpha=0 requests a ridge initial fit)
nng.out <- cv.nnGarrote(x.train, y.train, intercept=TRUE,
                        initial.model="glmnet",
                        lambda.nng=NULL, lambda.initial=NULL, alpha=0)
nng.predictions <- predict(nng.out, newx=x.test)

# Ridge Regression
cv.ridge <- glmnet::cv.glmnet(x.train, y.train, alpha=0)
ridge <- glmnet::glmnet(x.train, y.train, alpha=0, lambda=cv.ridge$lambda.min)
ridge.predictions <- predict(ridge, newx=x.test)

# Comparisons of the coefficients
coef(nng.out)
coef(ridge)
```
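One way to quantify prediction accuracy is the test mean squared error divided by the noise variance (`sigma.epsilon^2` in the simulation). Under this scaling, an oracle that predicts with the true regression function scores about 1 on average. The helper `relative.mse()` below is a hypothetical name, not a function exported by nnGarrote.

```r
# Hypothetical helper (not part of nnGarrote): test mean squared prediction
# error divided by the noise variance; lower is better, and predicting with
# the true regression function gives about 1 on average.
relative.mse <- function(predictions, y.test, sigma.epsilon) {
  # as.numeric() flattens n x 1 matrices such as glmnet's predict() output
  mean((as.numeric(predictions) - as.numeric(y.test))^2) / sigma.epsilon^2
}
```

Applied to the objects above, compare `relative.mse(nng.predictions, y.test, sigma.epsilon)` with `relative.mse(ridge.predictions, y.test, sigma.epsilon)`.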
In this simulation, the non-negative garrote attains a lower test prediction error than ridge regression. Its estimated coefficients are also much closer to the true ones in terms of recall and precision. Ridge regression shrinks coefficients towards zero but does not set any of them exactly to zero. The garrote, by contrast, can shrink some coefficients to exactly zero and so drop those variables.
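Recall and precision here refer to the set of selected variables, meaning those with nonzero estimated coefficients. Precision is the fraction of selected variables that are truly active. Recall is the fraction of truly active variables that are selected. The helper `selection.metrics()` below is a hypothetical sketch, not part of nnGarrote. Both of its arguments should exclude the intercept.

```r
# Hypothetical helper (not part of nnGarrote): precision and recall of the
# variables selected by an estimate (nonzero coefficients) against the truth.
# Both vectors must exclude the intercept.
selection.metrics <- function(beta.hat, beta.true, tol = 0) {
  selected <- abs(beta.hat) > tol   # variables kept by the estimator
  active <- beta.true != 0          # truly active variables
  true.pos <- sum(selected & active)
  c(precision = if (any(selected)) true.pos / sum(selected) else NA_real_,
    recall = if (any(active)) true.pos / sum(active) else NA_real_)
}
```

For ridge, `selection.metrics(as.matrix(coef(ridge))[-1, 1], true.beta)` drops the intercept, which glmnet stores first. The same call can be applied to the slope coefficients from `coef(nng.out)`, after checking how that output stores the intercept.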
This package is free and open source software, licensed under GPL (>= 2).