forked from boyuren158/GP-CERF
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #21 from NSAPH-Software/develop
Develop
- Loading branch information
Showing
149 changed files
with
32,354 additions
and
835 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,3 +4,6 @@ | |
^\.github$ | ||
^\_analysis | ||
^\_src | ||
^_pkgdown\.yml$ | ||
^docs$ | ||
index.md |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,3 +3,5 @@ | |
.RData | ||
.Ruserdata | ||
*.DS_Store | ||
inst/doc | ||
tests/testthat/GPCERF.log |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
Package: GPCERF | ||
Title: What the Package Does (One Line, Title Case) | ||
Version: 0.0.1 | ||
Version: 0.0.2 | ||
Authors@R: c( | ||
person("Naeem", "Khoshnevis", email = "[email protected]", | ||
role=c("aut","cre"), | ||
|
@@ -18,21 +18,30 @@ Maintainer: Naeem Khoshnevis <[email protected]> | |
Description: What the package does (one paragraph). | ||
License: GPL (>= 3) | ||
Language: en-US | ||
URL: https://github.com/fasrc/GPCERF | ||
BugReports: https://github.com/fasrc/GPCERF/issues | ||
URL: https://github.com/NSAPH-Software/GPCERF | ||
BugReports: https://github.com/NSAPH-Software/GPCERF/issues | ||
Copyright: Harvard University | ||
Imports: | ||
data.table, | ||
xgboost, | ||
stats, | ||
MASS, | ||
spatstat.geom | ||
spatstat.geom, | ||
logger, | ||
Rcpp, | ||
RcppArmadillo | ||
Encoding: UTF-8 | ||
LazyData: true | ||
Roxygen: list(markdown = TRUE) | ||
RoxygenNote: 7.1.2 | ||
Depends: | ||
R (>= 3.5.0) | ||
Suggests: | ||
rmarkdown, | ||
knitr, | ||
testthat (>= 3.0.0) | ||
Config/testthat/edition: 3 | ||
VignetteBuilder: knitr | ||
LinkingTo: | ||
RcppArmadillo, | ||
Rcpp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,31 @@ | ||
# Generated by roxygen2: do not edit by hand | ||
|
||
export(calc_ac) | ||
export(compute_deriv_nn) | ||
export(compute_deriv_weights_gp) | ||
export(compute_inverse) | ||
export(compute_m_sigma) | ||
export(compute_posterior_m_nn) | ||
export(compute_posterior_sd_nn) | ||
export(compute_rl_deriv_gp) | ||
export(compute_rl_deriv_nn) | ||
export(compute_w_corr) | ||
export(compute_weight_gp) | ||
export(estimate_cerf_gp) | ||
export(estimate_cerf_nngp) | ||
export(estimate_mean_sd_nn) | ||
export(estimate_noise) | ||
export(estimate_noise_nn) | ||
export(find_optimal_nn) | ||
export(generate_synthetic_data) | ||
export(get_logger) | ||
export(set_logger) | ||
export(train_GPS) | ||
import(MASS) | ||
import(Rcpp) | ||
import(RcppArmadillo) | ||
import(data.table) | ||
import(stats) | ||
import(xgboost) | ||
importFrom(spatstat.geom,crossdist) | ||
useDynLib(GPCERF, .registration = TRUE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Generated by using Rcpp::compileAttributes() -> do not edit by hand | ||
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 | ||
|
||
calc_cross <- function(cross, within) { | ||
.Call('_GPCERF_calc_cross', PACKAGE = 'GPCERF', cross, within) | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,42 @@ | ||
#' @title | ||
#' Calculate Covariate Balance | ||
#' | ||
#' @description | ||
#' Calculate weighted correlation between a list of covariates and an exposure. | ||
#' Weights are defined by the covariance function of the GP. | ||
#' | ||
#' @param w A vector exposure values across all subjects. | ||
#' @param X A matrix of covariate values. Subjects in rows and covariates in columns. | ||
#' @param weights A vector of weights assigned to all subjects based on the trained GP. | ||
#' | ||
#' @return | ||
#' A vector of correlations between w and each column of X. | ||
#' @export | ||
#' | ||
#' @examples | ||
#' | ||
calc.ac = function(w, X, weights){ | ||
w.mean = sum(w*weights) | ||
w.sd = sqrt(sum((w-w.mean)^2*weights)) | ||
w.trans = (w-w.mean)/w.sd | ||
|
||
X.mean = colSums(X*weights) | ||
X.cov = (t(X) - X.mean)%*%diag(weights)%*%t(t(X)-X.mean) | ||
X.trans = t(t(solve(chol(X.cov)))%*%(t(X)-X.mean)) | ||
|
||
c(w.trans%*%diag(weights)%*%X.trans) | ||
} | ||
#' @title | ||
#' Calculate Covariate Balance | ||
#' | ||
#' @description | ||
#' Calculate weighted correlation between a list of covariates and an exposure. | ||
#' Weights are defined by the covariance function of the GP. | ||
#' | ||
#' @param w A vector exposure values across all subjects. | ||
#' @param X A matrix of covariate values. Subjects in rows and covariates in columns. | ||
#' @param weights A vector of weights assigned to all subjects based on the trained GP. | ||
#' | ||
#' @return | ||
#' A vector of correlations between w and each column of X. | ||
#' @export | ||
#' | ||
#' @examples | ||
#' | ||
#' set.seed(429) | ||
#' | ||
#' # generate data | ||
#' data <- generate_synthetic_data(sample_size = 200, gps_spec = 3) | ||
#' | ||
#' # generate random weights | ||
#' weights <- runif(nrow(data)) | ||
#' weights <- weights/sum(weights) | ||
#' | ||
#' # covariate matrix | ||
#' design_mt <- model.matrix(~.-1, data = data[, 3:ncol(data)]) | ||
#' | ||
#' cb <- calc_ac(w = data$treat, X = design_mt, weights=weights) | ||
#' | ||
calc_ac <- function(w, X, weights){ | ||
w.mean = sum(w*weights) | ||
w.sd = sqrt(sum((w-w.mean)^2*weights)) | ||
w.trans = (w-w.mean)/w.sd | ||
|
||
X.mean = colSums(X*weights) | ||
X.cov = (t(X) - X.mean)%*%diag(weights)%*%t(t(X)-X.mean) | ||
X.trans = t(t(solve(chol(X.cov)))%*%(t(X)-X.mean)) | ||
|
||
c(w.trans%*%diag(weights)%*%X.trans) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
#' @title | ||
#' Calculate Derivatives of CERF for nnGP | ||
#' | ||
#' @description | ||
#' Calculates the posterior mean of the derivative of CERF at a given | ||
#' exposure level with nnGP. | ||
#' | ||
#' @param w A scalar of exposure level of interest. | ||
#' @param w_obs A vector of observed exposure levels of all samples. | ||
#' @param GPS_m A data.table of GPS vectors. | ||
#' - Column 1: GPS | ||
#' - Column 2: Prediction of exposure for covariate of each data sample (e_gps_pred). | ||
#' - Column 3: Standard deviation of e_gps (e_gps_std) | ||
#' @param y_obs A vector of observed outcome values. | ||
#' @param hyperparam A vector of hyper-parameters in the GP model. | ||
#' @param n_neighbor Number of nearest neighbours on one side (see also \code{expand}). | ||
#' @param expand Scaling factor to determine the total number of nearest neighbours. The total is \code{2*expand*n_neighbor}. | ||
#' @param block_size Number of samples included in a computation block. Mainly used to | ||
#' balance the speed and memory requirement. Larger \code{block_size} is faster, but requires more memory. | ||
#' @param kernel_fn The covariance function. The input is the square of Euclidean distance. | ||
#' @param kernel_deriv_fn The partial derivative of the covariance function. The input is the square of Euclidean distance. | ||
#' | ||
#' @return | ||
#' A scalar of estimated derivative of CERF at \code{w} in nnGP. | ||
#' @export | ||
#' | ||
#' @examples | ||
#' | ||
#' set.seed(365) | ||
#' data <- generate_synthetic_data(sample_size = 200) | ||
#' GPS_m <- train_GPS(cov.mt = as.matrix(data[,-(1:2)]), | ||
#' w.all = as.matrix(data$treat)) | ||
#' | ||
#' wi <- 4.8 | ||
#' | ||
#' deriv_val <- compute_deriv_nn(w = wi, | ||
#' w_obs = data$treat, | ||
#' GPS_m = GPS_m, | ||
#' y_obs = data$Y, | ||
#' hyperparam = c(0.1,0.2,1), | ||
#' n_neighbor = 20, | ||
#' expand = 1, | ||
#' block_size = 1000) | ||
#' | ||
compute_deriv_nn <- function(w, | ||
w_obs, | ||
GPS_m, | ||
y_obs, | ||
hyperparam, | ||
n_neighbor, | ||
expand, | ||
block_size, | ||
kernel_fn = function(x) exp(-x), | ||
kernel_deriv_fn = function(x) -exp(-x)){ | ||
|
||
|
||
alpha <- hyperparam[[1]] | ||
beta <- hyperparam[[2]] | ||
g_sigma <- hyperparam[[3]] | ||
|
||
|
||
GPS <- GPS_m$GPS | ||
e_gps_pred <- GPS_m$e_gps_pred | ||
e_gps_std <- GPS_m$e_gps_std | ||
|
||
|
||
# params[1]: alpha, params[2]: beta, params[3]: gamma | ||
# cov = gamma*h(alpha*w^2 + beta*GPS^2) + diag(1) | ||
GPS_w = dnorm(w, mean = e_gps_pred, sd = e_gps_std, log = T) | ||
|
||
n = length(GPS_w) | ||
n.block = ceiling(n/block_size) | ||
obs.raw = cbind(w_obs, GPS) | ||
obs.ord = obs.raw[order(obs.raw[,1]),] | ||
y_obs.ord = y_obs[order(obs.raw[,1])] | ||
#params: length 3, first scale for w, second scale for GPS, | ||
#third scale for exp fn | ||
if(w >= obs.ord[nrow(obs.ord),1]){ | ||
idx.all = seq( nrow(obs.ord) - expand*n_neighbor + 1, nrow(obs.ord), 1) | ||
}else{ | ||
idx.anchor = which.max(obs.ord[,1]>=w) | ||
idx.start = max(1, idx.anchor - n_neighbor*expand) | ||
idx.end = min(nrow(obs.ord), idx.anchor + n_neighbor*expand) | ||
if(idx.end == nrow(obs.ord)){ | ||
idx.all = seq(idx.end - n_neighbor*2*expand + 1, idx.end, 1) | ||
}else{ | ||
idx.all = seq(idx.start, idx.start+n_neighbor*2*expand-1, 1) | ||
} | ||
} | ||
|
||
obs.use = t(t(obs.ord[idx.all,])*(1/sqrt(c(alpha, beta)))) | ||
y.use = y_obs.ord[idx.all] | ||
|
||
obs.new = t(t(cbind(w, GPS_w))*(1/sqrt(c(alpha, beta)))) | ||
id.all = split(1:n, ceiling(seq_along(1:n)/n.block)) | ||
Sigma.obs = g_sigma*kernel_fn(as.matrix(dist(obs.use))^2) + diag(nrow(obs.use)) | ||
Sigma.obs.inv = chol2inv(chol(Sigma.obs)) | ||
|
||
all.weights = sapply(id.all, function(id.ind){ | ||
cross.dist = spatstat.geom::crossdist(obs.new[id.ind,1], obs.new[id.ind,2], | ||
obs.use[,1], obs.use[,2]) | ||
Sigma.cross = g_sigma*(1/alpha)*(2*outer(rep(w,length(id.ind))*(1/alpha), obs.use[,1], "-"))* | ||
kernel_deriv_fn(cross.dist^2) | ||
#mean | ||
wght = Sigma.cross%*%Sigma.obs.inv | ||
colSums(wght) | ||
}) | ||
weights = rowSums(all.weights)/n | ||
weights%*%y.use | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
#' @title | ||
#' Calculate Derivatives of CERF | ||
#' | ||
#' @description | ||
#' Calculate the weights assigned to each observed outcome when deriving the | ||
#' posterior mean of the first derivative of CERF at a given exposure level. | ||
#' | ||
#' @param w A scalar of exposure level of interest. | ||
#' @param w_obs A vector of observed exposure levels of all samples. | ||
#' @param GPS_m A data.table of GPS vectors. | ||
#' - Column 1: GPS | ||
#' - Column 2: Prediction of exposure for covariate of each data sample (e_gps_pred). | ||
#' - Column 3: Standard deviation of e_gps (e_gps_std) | ||
#' @param hyperparam A vector of hyper-parameters in the GP model. | ||
#' @param kernel_fn The covariance function. | ||
#' @param kernel_deriv_fn The partial derivative of the covariance function. | ||
#' | ||
#' @return | ||
#' A vector of weights for all samples, based on which the posterior mean of the derivative of CERF at the | ||
#' exposure level of interest is calculated. | ||
#' @export | ||
#' | ||
#' @examples | ||
#' | ||
#' set.seed(915) | ||
#' data <- generate_synthetic_data(sample_size = 200) | ||
#' GPS_m <- train_GPS(cov.mt = as.matrix(data[,-(1:2)]), | ||
#' w.all = as.matrix(data$treat)) | ||
#' | ||
#' wi <- 4.8 | ||
#' weights <- compute_deriv_weights_gp(w = wi, | ||
#' w_obs = data$treat, | ||
#' GPS_m = GPS_m, | ||
#' hyperparam = c(1,1,2)) | ||
#' | ||
compute_deriv_weights_gp <- function(w, | ||
w_obs, | ||
GPS_m, | ||
hyperparam, | ||
kernel_fn = function(x) exp(-x), | ||
kernel_deriv_fn = function(x) -exp(-x)){ | ||
|
||
|
||
alpha <- hyperparam[[1]] | ||
beta <- hyperparam[[2]] | ||
g_sigma <- hyperparam[[3]] | ||
|
||
|
||
GPS <- GPS_m$GPS | ||
e_gps_pred <- GPS_m$e_gps_pred | ||
e_gps_std <- GPS_m$e_gps_std | ||
|
||
|
||
# param[1]: alpha, param[2]: beta, param[3]: gamma | ||
# cov = gamma*h(alpha*w^2 + beta*GPS^2) + diag(1) | ||
GPS_w = dnorm(w, mean = e_gps_pred, sd = e_gps_std, log = TRUE) | ||
n = length(GPS_w) | ||
|
||
obs.use = cbind( w_obs*sqrt(1/alpha), GPS*sqrt(1/beta) ) | ||
obs.new = cbind( w*sqrt(1/alpha), GPS_w*sqrt(1/beta) ) | ||
Sigma.obs = g_sigma*kernel_fn(as.matrix(dist(obs.use))^2) + diag(nrow(obs.use)) | ||
cross.dist = spatstat.geom::crossdist(obs.new[,1], obs.new[,2], | ||
obs.use[,1], obs.use[,2]) | ||
Sigma.cross = g_sigma*sqrt(1/alpha)*kernel_deriv_fn(cross.dist^2)* | ||
(2*outer(rep(w,n), w_obs, "-")) | ||
weights.all = Sigma.cross%*%chol2inv(chol(Sigma.obs)) | ||
# weights.all[weights.all<0] = 0 | ||
# weights = colMeans(weights.all) | ||
# weights/sum(weights) | ||
colMeans(weights.all) | ||
} |
Oops, something went wrong.