-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathJSDM_convergence.R
126 lines (102 loc) · 4.37 KB
/
JSDM_convergence.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
source(here::here("R", "utils_generic.R"))
source(here::here("R", "utils_JSDM.R"))
library(ape)
library(Hmsc)
library(withr)
# Sample names ------------------------------------------------------------
# File names of every fitted JSDM: one name per combination of the four
# naming factors n (the value after "thin"), p, d and m. crossing() builds
# all combinations (each input sorted), and the parts are joined with "_",
# e.g. "qeq_ma_envi_thin_100_samples_250_chains_4.rds".
mynames <- crossing(
  n = c(100, 1000),
  p = c("sort", "qeq"),
  d = c("ma", "mp"),
  m = c("full", "envi", "time")
) %>%
  with(paste(p, d, m, "thin", n, "samples_250_chains_4.rds", sep = "_"))
# JSDM MCMC convergence ---------------------------------------------------
# Gelman-Rubin diagnostic (potential scale reduction factor): values close
# to 1 indicate convergence; ideally they should be below 1.05.
# psrf_ess() is applied to every fitted model for one parameter type at a
# time, and the per-model results are row-bound into one tibble.
mybetas <- map_dfr(mynames, psrf_ess, "beta")
# NOTE(review): mygammas and myomegas are not written to disk below; they
# are only kept in the session for inspection.
mygammas <- map_dfr(mynames, psrf_ess, "gamma")
myomegas <- map_dfr(mynames, psrf_ess, "omega")
myrhos <- map_dfr(mynames, psrf_ess, "rho")

# Save convergence statistics: mean and SD over the rows of each
# model/thin combination. p.est is presumably the PSRF point estimate and
# ess the effective sample size returned by psrf_ess() -- confirm there.
# Every table is ungrouped before writing so all outputs are handled alike.
myrhos %>%
  group_by(model, thin) %>%
  summarize(p.est.m = mean(p.est, na.rm = TRUE),
            p.est.sd = sd(p.est, na.rm = TRUE)) %>%
  ungroup() %>%
  write_tsv(here::here("tables", "JSDM_convergence", "GR_rho.tsv"))
myrhos %>%
  group_by(model, thin) %>%
  summarize(ess.m = mean(ess, na.rm = TRUE),
            ess.sd = sd(ess, na.rm = TRUE)) %>%
  ungroup() %>%
  write_tsv(here::here("tables", "JSDM_convergence", "ESS_rho.tsv"))
mybetas %>%
  group_by(model, thin) %>%
  summarize(p.est.m = mean(p.est, na.rm = TRUE),
            p.est.sd = sd(p.est, na.rm = TRUE)) %>%
  ungroup() %>%
  write_tsv(here::here("tables", "JSDM_convergence", "GR_betas.tsv"))
# JSDM R2 (variance explained) --------------------------------------------
# R2fit() is applied to every fitted model; its rows carry R2, RMSE and AUC
# values, summarized here as mean and SD (ignoring NAs) for each
# model/thin/phase/distribution combination.
myr2final <- map_dfr(mynames, R2fit)
myr2final %>%
  ungroup() %>%
  group_by(model, thin, phase, distribution) %>%
  summarize(mean_R2 = mean(R2, na.rm = TRUE),
            sd_R2 = sd(R2, na.rm = TRUE),
            mean_RMSE = mean(RMSE, na.rm = TRUE),
            sd_RMSE = sd(RMSE, na.rm = TRUE),
            mean_AUC = mean(AUC, na.rm = TRUE),
            sd_AUC = sd(AUC, na.rm = TRUE)) %>%
  # summarize() leaves the result grouped; drop grouping before writing.
  ungroup() %>%
  write_tsv(here::here("tables", "JSDM_fits", "R2_summary_results.tsv"))
# JSDM WAIC ---------------------------------------------------------------
# WAIC stands for widely applicable information criterion. WAICfit() is run
# on every fitted model, and the smallest WAIC_full value of each
# phase/distribution/model/thin combination is written out (no NA removal,
# so a group containing NA yields NA).
mywaicfinal <- map_dfr(.x = mynames, .f = WAICfit)
mywaicfinal %>%
  group_by(phase, distribution, model, thin) %>%
  summarise(min = min(WAIC_full)) %>%
  write_tsv(here::here("tables", "JSDM_fits", "min_waic.tsv"))
# 5-fold cross-validation -------------------------------------------------
# 5-fold CV takes a long time, so it is only run for a subset of the models
# (the four "full" models with thin = 1000).
finalnames <- c("qeq_ma_full_thin_1000_samples_250_chains_4.rds",
                "qeq_mp_full_thin_1000_samples_250_chains_4.rds",
                "sort_ma_full_thin_1000_samples_250_chains_4.rds",
                "sort_mp_full_thin_1000_samples_250_chains_4.rds")
mycv5final <- map_dfr(finalnames, cv5fit)
# Keep only rows with positive predictive R2 (this also drops NA R2 rows).
# mutate() (not summarize()) keeps one row per CV record and appends the
# per-phase/distribution mean and SD columns to each row.
mycv5final %>%
  filter(R2 > 0) %>%
  group_by(phase, distribution) %>%
  mutate(mean_R2 = mean(R2, na.rm = TRUE),
         sd_R2 = sd(R2, na.rm = TRUE),
         mean_RMSE = mean(RMSE, na.rm = TRUE),
         sd_RMSE = sd(RMSE, na.rm = TRUE),
         mean_AUC = mean(AUC, na.rm = TRUE),
         sd_AUC = sd(AUC, na.rm = TRUE)) %>%
  ungroup() %>%
  write_tsv(here::here("tables", "JSDM_fits", "cv5f_results.tsv"))
# Table S9 ----------------------------------------------------------------
# Combine predictive R2 (5-fold CV) with explanatory R2 (thin = 1000 fits)
# into one LaTeX table.

# Predictive R2: mean and SD of positive CV R2 per phase/distribution.
tcv5f <- mycv5final %>%
  ungroup() %>%
  filter(R2 > 0) %>%
  group_by(phase, distribution) %>%
  summarize(mean_predictive_R2 = mean(R2, na.rm = TRUE),
            sd_predictive_R2 = sd(R2, na.rm = TRUE)) %>%
  ungroup() %>%
  # CV was only run for the "full" models (see finalnames), so label them
  # as such for the join below.
  mutate(model = "full")

# Explanatory R2: mean and SD per phase/distribution/model at thin = 1000.
tr2 <- myr2final %>%
  filter(thin == 1000) %>%
  group_by(phase, distribution, model) %>%
  summarize(mean_explanatory_R2 = mean(R2, na.rm = TRUE),
            sd_explanatory_R2 = sd(R2, na.rm = TRUE)) %>%
  ungroup()

# Make and format the final table. Models without CV results get NA
# predictive R2 from the left join.
left_join(tr2, tcv5f, by = c("phase", "distribution", "model")) %>%
  dplyr::rename(Included_effects = model,
                Model = distribution,
                Phase = phase) %>%
  # NOTE(review): case_when() returns NA for any model label not listed
  # here -- confirm R2fit() only emits these three labels.
  dplyr::mutate(Phase = if_else(Phase == "sorting", "Sorting", "Equil"),
                Model = if_else(Model == "normal", "COP", "PA"),
                Included_effects = case_when(
                  Included_effects == "full" ~ "Full",
                  Included_effects == "no_fixed_effects" ~ "Random only",
                  Included_effects == "no_random_effects" ~ "Fixed only")) %>%
  mutate(Phase = factor(Phase, levels = c("Sorting", "Equil")),
         Model = factor(Model, levels = c("COP", "PA")),
         Included_effects = factor(Included_effects,
                                   levels = c("Full", "Fixed only",
                                              "Random only"))) %>%
  arrange(Model, Phase, Included_effects) %>%
  xtable::xtable(auto = TRUE) %>%
  # print.xtable() echoes the LaTeX to the console and invisibly returns it
  # as a string, which write_lines() then saves.
  print() %>%
  write_lines(here::here("tables", "table_S9.tex"))