-
Notifications
You must be signed in to change notification settings - Fork 1
/
SHAPforxgboost2
76 lines (61 loc) · 2.2 KB
/
SHAPforxgboost2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# SHAP analysis of an xgboost model for `mrfei`.
#
# Steps:
#   1. read the data and build the numeric feature matrix `X1`;
#   2. fit a small xgboost model (`mod1`);
#   3. compute SHAP values and check that, per row, SHAP scores + intercept
#      line up with the model prediction (`shap_data`);
#   4. plot mean |SHAP| per feature and the SHAPforxgboost summary plot;
#   5. colour the mean |SHAP| bars by the correlation between each feature
#      and its own SHAP values ("direction of association").
#
# The script deliberately does not call rm(list = ls()) or setwd(): both
# silently alter the session of whoever runs it. Input files are located
# through `data_dir` instead.

# Dependencies ----
library(SHAPforxgboost) # install.packages("SHAPforxgboost")
library(ggplot2) # ggplot(), aes(), geom_col(), ... are called directly below
library(varhandle) # install.packages("varhandle"); not called below, left
# attached as in the original script

# Configuration ----
data_dir <- "F:/healthy retail" # folder holding the two input CSV files
n_obs <- 3000L # use at most this many leading rows
ntree <- 5 # number of boosting rounds
set.seed(123)

# Data ----
data <- read.csv(file.path(data_dir, "rdata_mRFEI.csv"))
if (!"mrfei" %in% names(data)) {
  stop("Column 'mrfei' not found in rdata_mRFEI.csv", call. = FALSE)
}
print(head(data))

# head() never pads with NA rows when the file has fewer than n_obs rows,
# unlike data[1:n_obs, ].
data <- head(data, n_obs)

# Every column except the outcome is a predictor.
X1 <- as.matrix(data[, setdiff(names(data), "mrfei"), drop = FALSE])
if (!is.numeric(X1)) {
  stop("All predictor columns must be numeric", call. = FALSE)
}
print(head(X1))

# Model ----
# NOTE(review): this uses the xgboost(data = , label = ) calling convention;
# confirm it matches the installed xgboost version.
mod1 <- xgboost::xgboost(
  data = X1, label = data$mrfei, gamma = 0, eta = 1,
  lambda = 0, nrounds = ntree, verbose = FALSE
)

# SHAP values ----
# shap.values(model, X_dataset) returns the SHAP data matrix and the
# features ranked by mean |SHAP|.
shap_values <- shap.values(xgb_model = mod1, X_train = X1)
shap_score <- shap_values$shap_score # all SHAP scores, one column per feature

# Work on a plain data.frame so base-R column assignment applies.
shap_data <- as.data.frame(shap_score)
# Add the intercept to the score data; [[1]] pulls out the single value
# whether BIAS0 is stored as a one-cell table or as a bare number.
shap_data$BIAS <- as.numeric(shap_values$BIAS0[[1]])

# No tree limit is passed: the model was fitted with exactly `ntree` rounds,
# so the default prediction already uses all of them.
pred_mod <- predict(mod1, X1)

# rowSum is computed before pred_mod is appended, so it is the sum of the
# SHAP scores plus the intercept only; compare it with pred_mod per row.
shap_data$rowSum <- round(rowSums(shap_data), 6)
shap_data$pred_mod <- round(pred_mod, 6)

# The data viewer only exists in interactive sessions.
if (interactive()) {
  View(shap_score)
  View(shap_data)
}
print(colMeans(shap_data))

# Mean |SHAP| per feature ----
shap_mean <- shap_values$mean_shap_score # mean SHAP scores, named by feature
sm1 <- data.frame(
  lab = names(shap_mean),
  sm = round(as.numeric(shap_mean), 3),
  stringsAsFactors = FALSE
)
print(sm1)

# Plots are wrapped in print() so they are also drawn under source()/Rscript.
print(
  ggplot(data = sm1, aes(x = reorder(lab, sm), y = sm)) +
    geom_col() +
    coord_flip()
)

shap_long <- shap.prep(xgb_model = mod1, X_train = X1) # long format for graph
print(shap.plot.summary(shap_long))

# Same bar chart with the mean prediction added as a reference bar.
sm2 <- rbind(
  sm1,
  data.frame(
    lab = "mean_pred",
    sm = round(mean(pred_mod), 3),
    stringsAsFactors = FALSE
  )
)
print(
  ggplot(data = sm2, aes(x = reorder(lab, sm), y = sm)) +
    geom_col() +
    coord_flip()
)

# Direction of association ----
# Correlation between each feature's SHAP values and the feature itself.
# Columns are matched by name and the result stays numeric, which the
# continuous fill scale below requires.
shap_cor <- vapply(
  colnames(shap_score),
  function(feature) cor(shap_score[[feature]], X1[, feature]),
  numeric(1)
)

cor_by_feature <- data.frame(
  feature = names(shap_cor),
  shap_cor = unname(shap_cor),
  stringsAsFactors = FALSE
)
mean_by_feature <- data.frame(
  feature = names(shap_mean),
  shap_mean = as.numeric(shap_mean),
  stringsAsFactors = FALSE
)
sm3 <- merge(cor_by_feature, mean_by_feature, by = "feature", sort = TRUE)

# Display labels for the predictors.
# NOTE(review): assumes varlabs.csv has exactly two columns, variable name
# then label, in that order - confirm against the file.
sm4 <- read.csv(file.path(data_dir, "varlabs.csv"), stringsAsFactors = FALSE)
colnames(sm4) <- c("feature", "varlabs")
sm3 <- merge(sm3, sm4, by = "feature", sort = TRUE, all.x = TRUE)

# Features without a label fall back to their variable name; otherwise they
# would all share one "NA" bar in the plot.
sm3$varlabs <- as.character(sm3$varlabs)
unlabelled <- is.na(sm3$varlabs)
sm3$varlabs[unlabelled] <- sm3$feature[unlabelled]

print(
  ggplot(
    data = sm3,
    aes(x = reorder(varlabs, shap_mean), y = shap_mean, fill = shap_cor)
  ) +
    geom_col() +
    coord_flip() +
    ggtitle("Direction of Association") +
    xlab("Predictor") +
    ylab("Mean |Shapley|") +
    labs(fill = "Correlation") +
    scale_fill_gradient2(low = "navyblue", mid = "white", high = "red")
)