Skip to content

Commit

Permalink
refactor: add validate_detail function
Browse files Browse the repository at this point in the history
  • Loading branch information
overdodactyl committed Jan 6, 2024
1 parent cfb3891 commit e1f5c72
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
18 changes: 9 additions & 9 deletions R/dx_metrics.R
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,7 @@ get_kappa_interpretation <- function(kappa) {
#' @export
#' @concept metrics
dx_cohens_kappa <- function(cm, detail = "full") {
validate_detail(detail)
# Calculate observed agreement (po)
po <- (cm$tp + cm$tn) / cm$n

Expand Down Expand Up @@ -606,13 +607,11 @@ calculate_mcc <- function(cm) {
#' @noRd
metric_binomial <- function(x, n, name, detail = "full", citype = "exact", ...) {
# if (check_zero_denominator(n, name)) return(NA)

validate_detail(detail)
if (detail == "simple") {
return(x / n)
} else if (detail == "full") {
return(ci_binomial(x, n, measure = name, citype, ...))
} else {
stop("Invalid detail parameter: should be 'simple' or 'full'")
}
}

Expand Down Expand Up @@ -756,6 +755,9 @@ dx_lrt_pos <- function(cm, detail = "full", ...) {
#' data frame with the metric and confidence intervals.
#' @noRd
metric_ratio <- function(cm, dx_ratio_func, dx_sd_func, detail = "full", ...) {

validate_detail(detail)

# Extract counts from confusion matrix
tp <- cm$tp
tn <- cm$tn
Expand All @@ -779,8 +781,6 @@ metric_ratio <- function(cm, dx_ratio_func, dx_sd_func, detail = "full", ...) {
return(ratio)
} else if (detail == "full") {
return(ci_ratio(tp, tn, fp, fn, ratio, ratio_sd, continuity_correction=continuity_correction, ...))
} else {
stop("Invalid detail parameter: should be 'simple' or 'full'")
}
}

Expand Down Expand Up @@ -857,6 +857,7 @@ ci_ratio <- function(tp, tn, fp, fn, ratio, ratio_sd, name, continuity_correctio
#' @export
#' @concept metrics
dx_auc_pr <- function(precision, recall, detail = "full") {
validate_detail(detail)
# Remove any NA values that could cause issues in the calculation
valid_indices <- !is.na(precision) & !is.na(recall)
precision <- precision[valid_indices]
Expand Down Expand Up @@ -921,6 +922,7 @@ dx_auc_pr <- function(precision, recall, detail = "full") {
#' @export
#' @concept metrics
dx_auc <- function(truth, predprob, detail = "full") {
validate_detail(detail)
rocest <- pROC::roc(truth, predprob, ci = T, quiet = TRUE)
aucest <- pROC::auc(rocest)
auctext <- as.character(pROC::ci(aucest))
Expand All @@ -941,8 +943,6 @@ dx_auc <- function(truth, predprob, detail = "full") {
lci_raw = auc_lci,
uci_raw = auc_uci
))
} else {
stop("Invalid detail parameter: should be 'simple' or 'full'")
}
}

Expand Down Expand Up @@ -1340,6 +1340,7 @@ get_roc <- function(true_varname, pred_varname, data, direction) {
#' @concept metrics
#' @export
dx_brier <- function(predprob, truth, detail = "full") {
validate_detail(detail)
# Ensuring that the length of predicted probabilities and actual outcomes are the same
if (length(predprob) != length(truth)) {
stop("The length of predicted probabilities and actual outcomes must be the same.")
Expand All @@ -1355,8 +1356,6 @@ dx_brier <- function(predprob, truth, detail = "full") {
estimate_raw = brier,
notes = "CIs not yet implemented"
))
} else {
stop("Invalid detail parameter: should be 'simple' or 'full'")
}
}

Expand All @@ -1381,6 +1380,7 @@ calculate_brier <- function(truth, predprob) {
#' @export
#' @concept metrics
dx_nir <- function(cm, detail = "full") {
validate_detail(detail)
# Calculate the total number of actual positives and negatives
dispos <- cm$dispos # Number of actual positives
disneg <- cm$disneg # Number of actual negatives
Expand Down
5 changes: 5 additions & 0 deletions R/dx_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ boot_metric <- function(truth, predprob, metric_func, metric_args, bootreps, mea
}

evaluate_metric <- function(cm, metric_func, measure_name, detail, boot, bootreps, ...) {
validate_detail(detail)
# Calculate the metric using the provided function
metric_raw <- metric_func(cm, ...)

Expand Down Expand Up @@ -367,4 +368,8 @@ compare_df <- function(models = "",

}

validate_detail <- function(detail) {
check <- match.arg(detail, choices = c("full", "simple"))
}


0 comments on commit e1f5c72

Please sign in to comment.