Skip to content

Commit

Permalink
Merge pull request #89 from TESTgroup-BNL/data_split_fix
Browse files Browse the repository at this point in the history
A bug fix to correctly stratify cal/val datasets based on user-selected groupings.  The issue was identified by @asierrl who also illustrated a potential fix, which has been included in the bug fix along with new tests to check that samples are not duplicated between cal and val.
  • Loading branch information
Shawn P. Serbin authored Dec 29, 2021
2 parents 4e472c6 + 56b4b7d commit 952a194
Show file tree
Hide file tree
Showing 22 changed files with 6,866 additions and 132 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: spectratrait
Title: A simple add-on package to aid in the fitting of leaf-level spectra-trait PLSR models
Version: 1.0.3
Version: 1.0.5
Authors@R:
c(person(given = "Julien",
family = "Lamour",
Expand Down
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Generated by roxygen2: do not edit by hand

export("%notin%")
export(VIP)
export(VIPjh)
export(create_data_split)
Expand All @@ -12,3 +13,8 @@ export(percent_rmse)
export(pls_permutation)
export(source_GitHubData)
import(httr)
importFrom(pls,plsr)
importFrom(utils,flush.console)
importFrom(utils,read.table)
importFrom(utils,setTxtProgressBar)
importFrom(utils,txtProgressBar)
43 changes: 30 additions & 13 deletions R/create_data_split.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,30 @@
create_data_split <- function(dataset=NULL, approach=NULL, split_seed=123456789, prop=0.8,
group_variables=NULL) {
set.seed(split_seed)

# outer if/else to stop if approach set to NULL
if(!is.null(approach)) {

## base R data split function
if (approach=="base") {
dataset$CalVal <- NA
split_var <- group_variables
if(length(group_variables) > 1){
if(length(group_variables) > 1) {
dataset$ID <- apply(dataset[, group_variables], MARGIN = 1, FUN = function(x) paste(x, collapse = " "))
} else {
dataset$ID <- dataset[, group_variables]
}
split_var_list <- unique(dataset$ID)
for(i in 1:length(split_var_list)){
temp <- row.names(dataset[ dataset$ID == split_var_list[i], ])
for(i in 1:length(split_var_list)) {
temp <- row.names(dataset[dataset$ID == split_var_list[i], ])
## there should probably be more than 4 obs I'm guessing, so this may need adjusting
if(length(temp) > 3){
if(length(temp) > 3) {
Cal <- sample(temp,round(prop*length(temp)))
Val <- temp[!temp %in% Cal]
Val <- temp[temp %notin% Cal]
dataset$CalVal[ row.names(dataset) %in% Cal ] <- "Cal"
dataset$CalVal[ row.names(dataset) %in% Val ] <- "Val"
p_cal <- length(Cal)/length(temp) * 100
message(paste0(split_var_list[i], " ", "Cal", ": ", p_cal, "%"))
message(paste0(split_var_list[i], " ", "Cal", ": ", round(p_cal,3), "%"))
} else {
message(paste(split_var_list[i], "Not enough observations"))
}
Expand All @@ -44,16 +48,29 @@ create_data_split <- function(dataset=NULL, approach=NULL, split_seed=123456789,
dataset <- dataset[!is.na(dataset$CalVal), ]
cal.plsr.data <- dataset[dataset$CalVal== "Cal",]
val.plsr.data <- dataset[dataset$CalVal== "Val",]
} else if (approach=="dplyr")
cal.plsr.data <- dataset %>%

# Remove temporary CalVal column
cal.plsr.data <- cal.plsr.data[,-which(names(cal.plsr.data)=="CalVal")]
val.plsr.data <- val.plsr.data[,-which(names(val.plsr.data)=="CalVal")]

# dplyr based data split function
} else if (approach=="dplyr") {
dataset <- dataset %>% mutate(ids=row_number())
cal.plsr.data <- dataset %>%
group_by_at(vars(all_of(group_variables))) %>%
slice(sample(1:n(), prop*n())) %>%
data.frame()
val.plsr.data <-dataset[!row.names(dataset) %in% row.names(cal.plsr.data),]

val.plsr.data <- dataset[dataset$ids %notin% cal.plsr.data$ids,]
cal.plsr.data <- cal.plsr.data[,-which(colnames(cal.plsr.data)=="ids")]
val.plsr.data <- val.plsr.data[,-which(colnames(val.plsr.data)=="ids")]
} else {
stop("**** Please choose either base R or dplyr data split ****")
stop("**** Please set approach to either base R or dplyr data split ****")
}
output_list <- list(cal_data=cal.plsr.data, val_data=val.plsr.data)
return(output_list)
}
output_list <- list(cal_data=cal.plsr.data, val_data=val.plsr.data)
return(output_list)

# if approach is set to NULL (i.e. not set) return error message
stop("**** Please set approach to either base R or dplyr data split ****")

}
7 changes: 5 additions & 2 deletions R/pls_permutation.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
##' @return output a list containing the PRESS and coef_array.
##' output <- list(PRESS=press.out, coef_array=coefs)
##'
##' @importFrom pls plsr
##' @importFrom utils flush.console read.table setTxtProgressBar txtProgressBar
##'
##' @author Julien Lamour, Shawn P. Serbin
##' @export
pls_permutation <- function(dataset=NULL, maxComps=20, iterations=20, prop=0.70,
Expand All @@ -25,8 +28,8 @@ pls_permutation <- function(dataset=NULL, maxComps=20, iterations=20, prop=0.70,

if (verbose) {
j <- 1 # <--- Numeric counter for progress bar
pb <- utils::txtProgressBar(min = 0, max = iterations,
char="*",width=70,style = 3)
pb <- txtProgressBar(min = 0, max = iterations,
char="*",width=70,style = 3)
}

for (i in seq_along(1:iterations)) {
Expand Down
6 changes: 6 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ percent_rmse <- function(plsr_dataset = NULL, inVar = NULL,
return(output)
}


##' Not %in% function
##'
##' @export
`%notin%` <- Negate(`%in%`)

##' Function to check for installed package
##' not presently used
testForPackage <- function(pkg) {
Expand Down
4 changes: 3 additions & 1 deletion inst/scripts/spectra-trait_kit_sla_plsr_example.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ sample_info <- dat_raw[,names(dat_raw) %notin% seq(350,2500,1)]
head(sample_info)

sample_info2 <- sample_info %>%
select(Plant_Species=species,Growth_Form=`growth form`,timestamp,SLA_g_cm=`SLA (g/cm )`)
select(Plant_Species=species,Growth_Form=`growth form`,timestamp,
SLA_g_cm=`SLA (g/cm )`) %>%
mutate(SLA_g_cm=as.numeric(SLA_g_cm)) # ensure SLA is numeric
head(sample_info2)

plsr_data <- data.frame(sample_info2,Spectra)
Expand Down
11 changes: 11 additions & 0 deletions man/grapes-notin-grapes.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Binary file removed spectratrait_1.0.3.pdf
Binary file not shown.
Binary file added spectratrait_1.0.5.pdf
Binary file not shown.
1 change: 1 addition & 0 deletions tests/testthat.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
library(testthat)
library(dplyr)
library(spectratrait)

test_check("spectratrait")
33 changes: 33 additions & 0 deletions tests/testthat/test.create_data_split.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
context("Test that the create data split function has the expected behavior")

test_that("Generating a data split using the dplyr approach doesn't throw an error or generate duplicates between cal. and val. data", {
plot<- rep(c("plot1", "plot2", "plot3"),each=42)
season<- rep(1:6, 21)
disease<- c(rep(0,84), rep(1,42))
d<- seq(1:126)
df <- data.frame(plot,season,disease,d)
df <- df %>% mutate(id=row_number())

split_data <- spectratrait::create_data_split(dataset=df, approach="dplyr",
split_seed=7529075, prop=0.8,
group_variables=c("plot",
"season",
"disease"))
expect_false(sum(split_data$cal_data$id %in% split_data$val_data$id)>0)
})

test_that("Generating a data split using the base approach doesn't throw an error or generate duplicates between cal. and val. data", {
plot<- rep(c("plot1", "plot2", "plot3"),each=42)
season<- rep(1:6, 21)
disease<- c(rep(0,84), rep(1,42))
d<- seq(1:126)
df <- data.frame(plot,season,disease,d)
df <- df %>% mutate(id=row_number())

split_data <- spectratrait::create_data_split(dataset=df, approach="base",
split_seed=7529075, prop=0.8,
group_variables=c("plot",
"season",
"disease"))
expect_false(sum(split_data$cal_data$id %in% split_data$val_data$id)>0)
})
8 changes: 5 additions & 3 deletions vignettes/kit_sla_plsr_example.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ title: Spectra-trait PLSR example using leaf-level spectra and specific leaf are
author: "Shawn P. Serbin, Julien Lamour, & Jeremiah Anderson"
output:
github_document: default
html_document:
df_print: paged
html_notebook: default
pdf_document: default
html_document:
df_print: paged
params:
date: !r Sys.Date()
---
Expand Down Expand Up @@ -80,7 +80,9 @@ sample_info <- dat_raw[,names(dat_raw) %notin% seq(350,2500,1)]
head(sample_info)
sample_info2 <- sample_info %>%
select(Plant_Species=species,Growth_Form=`growth form`,timestamp,SLA_g_cm=`SLA (g/cm )`)
select(Plant_Species=species,Growth_Form=`growth form`,timestamp,
SLA_g_cm=`SLA (g/cm )`) %>%
mutate(SLA_g_cm=as.numeric(SLA_g_cm)) # ensure SLA is numeric
head(sample_info2)
plsr_data <- data.frame(sample_info2,Spectra)
Expand Down
Loading

0 comments on commit 952a194

Please sign in to comment.