-
Notifications
You must be signed in to change notification settings - Fork 269
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
adds missing lib/ directory with helper functions
- Loading branch information
Showing
3 changed files
with
268 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
require(compiler) | ||
# Summary statistics for a classification model with two or more classes.
#
# The (data, lev, model) signature presumably follows the caret
# `summaryFunction` convention (see caret::trainControl) -- confirm against
# the caller.
#
# Args:
#   data:  data frame with factor columns "obs" and "pred". When class
#          probabilities are available it also holds one numeric column per
#          class, named after the class level.
#   lev:   character vector of class levels. Probability-based statistics
#          are computed only when every element is a column of `data`.
#   model: passed through to caret::mnLogLoss(); not otherwise used here.
#
# Returns:
#   Named numeric vector. With class probabilities it starts with "logLoss"
#   and the one-vs-all AUC ("ROC" for two classes, "Mean_AUC" otherwise),
#   followed by Accuracy, Kappa and the class-wise statistics from
#   caret::confusionMatrix() (averaged over classes and prefixed "Mean_"
#   when there are more than two classes).
#
# Raises:
#   An error when the levels of the observed and predicted columns differ.
multiClassSummary <- compiler::cmpfun(function(data, lev = NULL, model = NULL) {

  # Check data
  pred_levels <- levels(data[, "pred"])
  if (!identical(pred_levels, levels(data[, "obs"]))) {
    stop("levels of observed and predicted data do not match")
  }
  is_two_class <- length(pred_levels) == 2

  # An empty `lev` must not count as "probabilities present":
  # all(logical(0)) is TRUE.
  has_class_probs <- length(lev) > 0 && all(lev %in% colnames(data))

  # The AUC entry is called "ROC" for two classes and "Mean_AUC" otherwise;
  # the same name is used below when selecting the returned statistics.
  auc_name <- if (is_two_class) "ROC" else "Mean_AUC"

  if (has_class_probs) {
    # Overall multinomial loss
    lloss <- caret::mnLogLoss(data = data, lev = lev, model = model)

    # One-vs-all AUC for each class; a class whose AUC cannot be computed
    # contributes NA, which makes the mean NA as well.
    class_aucs <- vapply(pred_levels, function(cls) {
      is_cls <- ifelse(data[, "obs"] == cls, 1, 0)
      tryCatch(
        as.numeric(ModelMetrics::auc(is_cls, data[, cls])),
        error = function(e) NA_real_
      )
    }, numeric(1))
    roc_stats <- mean(class_aucs)
  }

  # Calculate confusion matrix-based statistics
  CM <- caret::confusionMatrix(data[, "pred"], data[, "obs"])

  # Aggregate and average class-wise stats
  # TODO: add weights
  if (is_two_class) {
    # byClass is already a single named vector for two classes
    class_stats <- CM$byClass
  } else {
    class_stats <- colMeans(CM$byClass)
    names(class_stats) <- paste("Mean", names(class_stats))
  }

  # Aggregate overall stats
  overall_stats <- CM$overall
  if (has_class_probs) {
    overall_stats <- c(
      overall_stats,
      logLoss = as.numeric(lloss),
      stats::setNames(roc_stats, auc_name)
    )
  }

  # Combine overall with class-wise stats and remove some stats we don't want
  stats <- c(overall_stats, class_stats)
  unwanted <- c("AccuracyNull", "AccuracyLower", "AccuracyUpper",
                "AccuracyPValue", "McnemarPValue",
                "Mean Prevalence", "Mean Detection Prevalence")
  stats <- stats[!names(stats) %in% unwanted]

  # Clean names
  names(stats) <- gsub("[[:blank:]]+", "_", names(stats))

  # Change name ordering to place most useful first
  # May want to remove some of these eventually
  stat_list <- c("Accuracy", "Kappa", "Mean_F1", "Mean_Sensitivity",
                 "Mean_Specificity", "Mean_Pos_Pred_Value",
                 "Mean_Neg_Pred_Value", "Mean_Detection_Rate",
                 "Mean_Balanced_Accuracy")
  if (is_two_class) {
    stat_list <- gsub("^Mean_", "", stat_list)
  }
  if (has_class_probs) {
    stat_list <- c("logLoss", auc_name, stat_list)
  }

  stats[stat_list]
})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
# NOTES:
# - The geometric mean is not well-behaved when values are zero, which is
#   common with probabilities. AVOID it for probability voting.
# - TODO: add weights to the voting in majority_vote() based on model strength
|
||
# Build a reproducible list of integer seeds, one element per resample plus
# one for the final model.
#
# The layout (CVfolds * CVreps vectors followed by a single seed) presumably
# matches what caret::trainControl(seeds = ) expects for repeatable parallel
# runs -- confirm against the caller.
#
# Args:
#   CVfolds:    number of cross-validation folds.
#   CVreps:     number of cross-validation repeats.
#   tuneLength: each resample receives tuneLength^3 seeds.
#   init_seed:  value passed to set.seed() before drawing the seeds.
#
# Returns:
#   List of length CVfolds * CVreps + 1. The first CVfolds * CVreps elements
#   are integer vectors of length tuneLength^3; the last is a single integer.
#
# Side effect: resets the global RNG state through set.seed(init_seed).
set_seeds <- function(CVfolds, CVreps, tuneLength, init_seed = 42) {
  n_resamples <- CVfolds * CVreps
  seeds_per_resample <- tuneLength^3

  # create manual seeds vector for parallel processing repeatability
  set.seed(init_seed)

  # seq_len() keeps the zero-resample case valid (1:0 would run backwards)
  seeds <- lapply(seq_len(n_resamples), function(i) {
    sample.int(.Machine$integer.max, seeds_per_resample)
  })

  # For the last model:
  seeds[[n_resamples + 1]] <- sample.int(.Machine$integer.max, 1)
  seeds
}
|
||
# Geometric mean of a numeric vector.
#
# Args:
#   x:              numeric vector.
#   na.rm:          if TRUE, missing values are ignored, both in the sum of
#                   logs and in the number of observations.
#   zero.propagate: if TRUE, any zero in `x` makes the result 0. If FALSE,
#                   zeros are left out of the sum of logs but still count
#                   as observations in the denominator.
#
# Returns:
#   A single number; NaN when `x` contains a negative value.
geometric_mean <- function(x, na.rm = TRUE, zero.propagate = FALSE) {
  if (any(x < 0, na.rm = TRUE)) {
    return(NaN)
  }

  if (zero.propagate) {
    if (any(x == 0, na.rm = TRUE)) {
      return(0)
    }
    return(exp(mean(log(x), na.rm = na.rm)))
  }

  # Missing values dropped from the numerator must not be counted in the
  # denominator either, otherwise the mean is biased towards 1.
  n_obs <- if (na.rm) sum(!is.na(x)) else length(x)
  exp(sum(log(x[x > 0]), na.rm = na.rm) / n_obs)
}
|
||
# Combine the predictions of several models into one vote per observation.
#
# Args:
#   ordered_predList: data frame of model predictions, one row per
#                     observation. For voteType = "count" each column holds
#                     one model's predicted class. For voteType = "prob" the
#                     columns are class probabilities whose names contain
#                     the class name.
#   reference:        factor; only its levels are used, as the set and order
#                     of classes in the result.
#   voteType:         "count" takes the most frequent predicted class in
#                     each row; "prob" averages the probability columns of
#                     each class and takes the class with the highest mean.
#   meanType:         averaging used for voteType = "prob" only.
#   metric:           not used by this function.
#
# Returns:
#   Factor of votes with levels identical to levels(reference).
majority_vote <- function(ordered_predList,
                          reference,
                          voteType = c("count", "prob"),
                          meanType = c("arithmetic", "geometric"),
                          metric = "F1") {
  voteType <- match.arg(voteType)
  meanType <- match.arg(meanType) # meanType used only for voteType == "prob"
  classes <- levels(reference)

  if (voteType == "prob") {
    if (meanType == "geometric") {
      warning(paste("Geometric mean not well-behaved for small or",
                    "zero probabilities. Results may be meaningless."))
      meanFunc <- function(dat) {
        apply(dat, 1, function(x) geometric_mean(x))
      }
    } else {
      # meanType == "arithmetic"
      meanFunc <- function(dat) rowSums(dat) / ncol(dat)
    }

    # Columns are matched to classes by case-insensitive substring, which is
    # what dplyr::select(contains(class)) did, without needing dplyr or the
    # pipe to be attached.
    # NOTE(review): a class name that is a substring of another class or of
    # a model name will match too many columns -- confirm the column naming.
    lower_names <- tolower(colnames(ordered_predList))
    probs <- as.data.frame(matrix(nrow = nrow(ordered_predList), ncol = 0))

    for (cls in classes) {
      is_cls_col <- grepl(tolower(cls), lower_names, fixed = TRUE)
      if (!any(is_cls_col)) {
        stop("no probability columns found for class '", cls, "'",
             call. = FALSE)
      }
      probs[[cls]] <- meanFunc(ordered_predList[, is_cls_col, drop = FALSE])
    }

    votes <- colnames(probs)[max.col(probs, ties.method = "first")]
  } else {
    # voteType == "count"
    # Ties go to the first entry of table(), i.e. the alphabetically first
    # of the tied classes.
    votes <- apply(ordered_predList, 1,
                   function(x) names(which.max(table(x))))
  }

  # Build the factor against the reference levels directly. Assigning
  # levels() afterwards would relabel the votes by position.
  factor(votes, levels = classes)
}
|
||
# Class-averaged performance metric taken from a caret confusion matrix.
#
# Args:
#   votes:     predicted classes. A non-factor vector is converted to a
#              factor with the levels of `reference`.
#   reference: factor of true classes.
#   metric:    name of a class-wise statistic reported in the `byClass`
#              element of caret::confusionMatrix(), e.g. "F1".
#
# Returns:
#   A single unnamed number: the mean of `metric` over all classes (NA
#   values ignored). For a two-class problem, where caret reports one set of
#   statistics, that value is returned as is.
averaged_metric <- function(votes, reference, metric = "F1") {
  if (!requireNamespace("caret", quietly = TRUE)) {
    stop("package 'caret' is required by averaged_metric()", call. = FALSE)
  }

  # Predictions that went through apply() arrive as character vectors
  if (!is.factor(votes)) {
    votes <- factor(votes, levels = levels(reference))
  }

  CM <- caret::confusionMatrix(votes, reference, mode = "everything")

  # byClass is a class-by-statistic matrix for three or more classes, but a
  # plain named vector for two classes, which colMeans() cannot handle.
  by_class <- CM$byClass
  if (is.matrix(by_class)) {
    by_class <- colMeans(by_class, na.rm = TRUE)
  }
  by_class[[metric]]
}
|
||
# Predict with every model and order the results from best to worst
# performer.
#
# Models are ranked by averaged_metric() on their class predictions against
# `reference`, also when probabilities are requested.
#
# Args:
#   modelList: named list of fitted models that have a predict() method.
#   newdata:   data passed to predict().
#   reference: factor of true classes for `newdata`.
#   type:      "raw" returns predicted classes, "prob" returns class
#              probabilities (predict() is called again with type = "prob").
#   metric:    statistic name forwarded to averaged_metric().
#
# Returns:
#   Data frame of predictions, best-performing model first. For
#   type = "raw" there is one column per model; for type = "prob" the
#   columns are whatever as.data.frame() makes of the per-model predict()
#   results.
ordered_predict <- function(modelList,
                            newdata,
                            reference,
                            type = c("raw", "prob"),
                            metric = "F1") {
  type <- match.arg(type)

  predList <- as.data.frame(lapply(modelList,
                                   function(x) stats::predict(x, newdata)))

  # Iterate over columns directly: apply() would turn the data frame into a
  # character matrix and lose the factor predictions.
  perfs <- vapply(predList, function(col) {
    averaged_metric(factor(col, levels = levels(reference)),
                    reference, metric)
  }, numeric(1))
  ord <- order(perfs, decreasing = TRUE) # best performer first

  if (type == "raw") {
    # drop = FALSE keeps a data frame even when there is a single model
    ordered_predList <- predList[, ord, drop = FALSE]
  } else {
    # type == "prob"
    ordered_predList <- as.data.frame(lapply(
      modelList[ord],
      function(x) stats::predict(x, newdata, type = type)
    ))
  }

  ordered_predList
}
|
||
# set seed before calling model_combos() for reproducible results | ||
# Score every combination of models as a voting ensemble.
#
# Each non-empty subset of `modelList` is scored with averaged_metric() on
# the votes produced by majority_vote(). With voteType = "count" a single
# model is scored on its own predictions and pairs are skipped, because two
# voters cannot form a majority.
#
# Args:
#   modelList: list of fitted models, each with a `method` element. The
#              prediction columns are matched against these method names.
#   reference: factor of true classes for `newdata`.
#   newdata:   data to predict on.
#   metric:    statistic name forwarded to averaged_metric().
#   voteType:  "count" or "prob", see majority_vote().
#   meanType:  "arithmetic" or "geometric", see majority_vote().
#   plot:      if TRUE, print a ggplot2 dot plot of the scores.
#
# Returns:
#   Named numeric vector of scores sorted in decreasing order. The names are
#   the method names of the combined models joined with ".".
model_combos <- function(modelList,
                         reference,
                         newdata,
                         metric = "F1",
                         voteType = c("count", "prob"),
                         meanType = c("arithmetic", "geometric"),
                         plot = FALSE) {
  # Loads the caret namespace; presumably needed so that predict() can
  # dispatch on caret model objects -- confirm against the caller.
  requireNamespace("caret")
  voteType <- match.arg(voteType)
  meanType <- match.arg(meanType)
  nModels <- length(modelList)

  modelNames <- vapply(modelList, function(x) x$method, character(1))
  predList <- ordered_predict(
    modelList,
    newdata,
    reference,
    type = if (voteType == "prob") "prob" else "raw",
    metric = metric
  )

  # All subsets of the models, from single models up to the full set
  comboList <- unlist(
    lapply(seq_len(nModels),
           function(k) utils::combn(modelNames, k, simplify = FALSE)),
    recursive = FALSE
  )

  score_combo <- function(combo) {
    # NOTE(review): method names are used as a regular expression against
    # the prediction column names, so they must appear in the names of
    # `modelList` and must not be substrings of one another -- confirm.
    pat <- paste(combo, collapse = "|")
    is_combo_col <- grepl(pat, colnames(predList))

    if (voteType == "count" && length(combo) < 3) {
      if (length(combo) == 2) {
        return(NULL) # can't do majority vote with 2
      }
      # A single model: its own predictions are the votes
      votes <- predList[, is_combo_col]
    } else {
      votes <- majority_vote(predList[, is_combo_col, drop = FALSE],
                             reference, voteType, meanType, metric)
    }

    score <- averaged_metric(votes, reference, metric)
    names(score) <- paste(combo, collapse = ".")
    score
  }

  # Skipped combinations return NULL and vanish in unlist()
  scores <- unlist(lapply(comboList, score_combo))
  if (length(scores) > 0) {
    scores <- sort(scores, decreasing = TRUE)
  }

  if (isTRUE(plot)) {
    if (!requireNamespace("ggplot2", quietly = TRUE)) {
      stop("package 'ggplot2' is required when plot = TRUE", call. = FALSE)
    }
    df <- data.frame(model = factor(names(scores)), score = unname(scores))
    p <- ggplot2::ggplot(df) +
      ggplot2::geom_point(
        ggplot2::aes(y = stats::reorder(model, score), x = score)
      ) +
      ggplot2::labs(y = "", x = paste(metric))
    print(p)
  }

  scores
}
|
||
# Draw a tile plot of facies against depth, one panel row per well, and
# print it.
#
# Args:
#   dat: data frame with columns `Depth`, `Well.Name` and `Facies`. The fill
#        legend is labelled 1 to 9, so nine facies are expected.
#
# Returns:
#   The value of print() on the ggplot object.
plot_wells <- function(dat) {
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("package 'ggplot2' is required by plot_wells()", call. = FALSE)
  }

  # ggplot2 functions are namespace-qualified so the helper does not depend
  # on the caller having attached the package.
  p <- ggplot2::ggplot(dat, ggplot2::aes(x = Depth, y = Well.Name)) +
    ggplot2::geom_tile(ggplot2::aes(fill = Facies)) +
    ggplot2::facet_grid(Well.Name ~ ., scales = "free") +
    ggplot2::scale_fill_brewer(name = "Facies", labels = 1:9,
                               palette = "Set1") +
    ggplot2::theme_bw() +
    ggplot2::theme(panel.grid = ggplot2::element_blank(),
                   strip.text = ggplot2::element_blank(),
                   strip.background = ggplot2::element_blank(),
                   plot.title = ggplot2::element_text(hjust = 0.5)) +
    ggplot2::labs(title = "Rock Facies Types and Depths for each Well",
                  y = "",
                  x = "Depth [m]")

  print(p)
}
|
||
# Bar chart of the number of observations per facies.
#
# Args:
#   dat: data frame with a discrete `Facies` column. The x axis is labelled
#        1 to 9, so nine facies are expected.
#
# Returns:
#   A ggplot object (not printed).
facies_hist <- function(dat) {
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("package 'ggplot2' is required by facies_hist()", call. = FALSE)
  }

  # geom_bar() counts rows per x value by default, so no explicit y mapping
  # is needed.
  ggplot2::ggplot(dat, ggplot2::aes(x = Facies)) +
    ggplot2::geom_bar(ggplot2::aes(fill = Facies)) +
    ggplot2::scale_fill_brewer(palette = "Set1") +
    ggplot2::theme_bw() +
    ggplot2::theme(panel.grid.major.x = ggplot2::element_blank(),
                   legend.position = "none",
                   plot.title = ggplot2::element_text(hjust = 0.5)) +
    ggplot2::scale_x_discrete(labels = 1:9) +
    ggplot2::scale_y_continuous(breaks = seq(0, 1000, 100)) +
    ggplot2::labs(title = "Distribution of Training Data by Facies",
                  x = "Facies Type",
                  y = "Count")
}
|