-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathpredict.R
140 lines (126 loc) · 6.22 KB
/
predict.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#' Predict the parameters or values of additional proteins
#'
#' This function can either predict the abundance matrix for proteins
#' (\code{type = "response"}) without missing values according to the
#' linear probabilistic dropout model, fitted with \code{proDA()}. Or, it
#' can predict the feature parameters for additional proteins given their
#' abundances including missing values after estimating the hyper-parameters
#' on a dataset with the same sample structure
#' (\code{type = "feature_parameters"}).
#'
#' \strong{Note:} this method behaves a little different from what one might
#' expect from the classical \code{predict.lm()} function, because
#' \code{object} is not just a single set of coefficients for one fit, but
#' many fits (ie. one for each protein) with some more hyper-parameters. The
#' classical \code{predict} function predicts the response for new samples.
#' This function does not support this, instead it is useful for getting a
#' matrix without missing values for additional proteins.
#'
#' @param object an 'proDAFit' object that is produced by \code{proDA()}.
#' @param newdata a matrix or a SummarizedExperiment which contains
#' the new abundances for which values are predicted.
#' @param newdesign a formula or design matrix that specifies the new structure
#' that will be fitted
#' @param type either "response" or "feature_parameters". Default:
#' \code{"response"}
#' @param ... additional parameters for the construction of the
#' 'proDAFit' object.
#'
#'
#' @return If \code{type = "response"} a matrix with the same dimensions
#' as \code{object}. Or, if \code{type = "feature_parameters"} a
#' 'proDAFit' object with the same hyper-parameters and column data
#' as \code{object}, but new fitted \code{rowData()}.
#'
#' @export
setMethod("predict", signature = "proDAFit", function(object, newdata, newdesign,
type=c("response", "feature_parameters"), ...){
type <- match.arg(type, c("response", "feature_parameters"))
if(type == "response" && (missing(newdata) || is.null(newdata)) &&
(missing(newdesign) || is.null(newdesign))){
return(object$coefficients %*% t(object$design))
}else if(type == "feature_parameters" && (missing(newdata) || is.null(newdata)) &&
(missing(newdesign) || is.null(newdesign))){
return(object)
}
design_formula <- NULL
if(missing(newdesign) || is.null(newdesign)){
design <- design(object)
}else if(inherits(newdesign,"formula")){
design <- convert_formula_to_model_matrix(newdesign, colData(object))
design_formula <- newdesign
}else if(is.matrix(newdesign)){
design <- newdesign
}else{
stop("Illegal newdesign of class", paste0(class(design)), ". Can only handle ",
"formula or matrix arguments")
}
if(missing(newdata) || is.null(newdata)){
newdata <- abundances(object)
}
# Check if newdata is applicable
stopifnot(ncol(object) == ncol(newdata))
stopifnot(colnames(object) == colnames(newdata))
stopifnot(ncol(object) == nrow(design))
feat_data <- fit_feature_parameters(object, newdata, design)
feat_df <- feat_data$feature_df
coef_mat <- feat_data$coefficient_matrix
coef_var_list <- feat_data$coef_var_list
new_fit_object <- proDAFit(newdata,
col_data = colData(object)[, is.na(mcols(colData(object))$type) |
mcols(colData(object))$type != "hyper_parameter"],
dropout_curve_position = object$hyper_parameters$dropout_curve_position,
dropout_curve_scale = object$hyper_parameters$dropout_curve_scale,
feature_parameters = feat_df,
coefficients = coef_mat,
coef_var = coef_var_list,
design_matrix = design,
design_formula = design_formula,
reference_level = reference_level(object),
location_prior_mean = object$hyper_parameters$location_prior_mean,
location_prior_scale = object$hyper_parameters$location_prior_scale,
location_prior_df = object$hyper_parameters$location_prior_df,
variance_prior_scale = object$hyper_parameters$variance_prior_scale,
variance_prior_df = object$hyper_parameters$variance_prior_df,
convergence = object$convergence, ...)
if(type == "response"){
predict(new_fit_object, type="response")
}else{
new_fit_object
}
})
fit_feature_parameters <- function(fit, newdata, design){
if(! is.null(fit$hyper_parameters$location_prior_mean) ||
! is.null(fit$hyper_parameters$variance_prior_scale)){
res_reg <- lapply(seq_len(nrow(newdata)), function(i){
pd_lm.fit(newdata[i, ], design,
dropout_curve_position = fit$hyper_parameters$dropout_curve_position,
dropout_curve_scale = fit$hyper_parameters$dropout_curve_scale,
location_prior_mean = fit$hyper_parameters$location_prior_mean,
location_prior_scale = fit$hyper_parameters$location_prior_scale,
variance_prior_scale = fit$hyper_parameters$variance_prior_scale,
variance_prior_df = fit$hyper_parameters$variance_prior_df,
location_prior_df = fit$hyper_parameters$location_prior_df)
})
}else{
res_unreg <- lapply(seq_len(nrow(newdata)), function(i){
pd_lm.fit(newdata[i, ], design,
dropout_curve_position = fit$hyper_parameters$dropout_curve_position,
dropout_curve_scale = fit$hyper_parameters$dropout_curve_scale)
})
res_reg <- res_unreg
}
res_reg <- lapply(res_reg, function(x) x[c("coefficients", "coef_variance_matrix", "n_approx", "df", "s2", "n_obs")])
feat_df <- as.data.frame(mply_dbl(res_reg, function(f){
unlist(f[-c(1,2)])
}, ncol = 4))
# feat_df$rss <- vapply(res_unreg, function(x) x[["rss"]], 0.0)
coef_mat <- mply_dbl(res_reg, function(f){
f$coefficients
}, ncol=ncol(design))
colnames(coef_mat) <- names(res_reg[[1]]$coefficients)
coef_var_list <- lapply(res_reg, function(.x) .x$coef_variance_matrix)
list(feature_df = feat_df,
coefficient_matrix = coef_mat,
coef_var_list = coef_var_list)
}