---
title: "PENALIZED LOGISTIC REGRESSION"
format: html
---
## PENALIZED LOGISTIC REGRESSION
> https://www.tidymodels.org/start/case-study/

Since our outcome variable `children` is categorical, logistic regression is a good first model to start with. Let's use a model that can perform feature selection during training. The `glmnet` R package fits generalized linear models via penalized maximum likelihood. It supports linear, logistic, multinomial, Poisson, and Cox regression models. It can also fit multi-response linear regression, generalized linear models for custom families, and relaxed lasso regression models.

This is similar to the LASSO, applied to the case where the outcome is a binary variable.
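
For reference, the objective that `glmnet` minimizes for a binary outcome can be written as below. The notation follows the glmnet documentation and is not part of the original notes: $\lambda$ corresponds to `penalty` and $\alpha$ to `mixture` in the tidymodels code later in this document.

$$
\min_{\beta_0,\,\beta}\; -\frac{1}{N}\sum_{i=1}^{N}\Big[\, y_i\,(\beta_0 + x_i^\top\beta) - \log\!\big(1 + e^{\beta_0 + x_i^\top\beta}\big) \Big] + \lambda\Big[\tfrac{1-\alpha}{2}\,\lVert\beta\rVert_2^2 + \alpha\,\lVert\beta\rVert_1\Big]
$$

With $\alpha = 1$ only the $\ell_1$ term remains (the lasso penalty), which can shrink some coefficients exactly to zero. That is what makes feature selection during training possible.
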
```{r}
library(tidymodels)
library(readr)   # provides read_csv(); not attached by library(tidymodels)
library(glmnet)
```
### Dataset
```{r, echo=FALSE, warning=FALSE, message=FALSE}
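# Read the data and convert character columns to factors
hotels <-
  read_csv('./datasets/hotels.csv') %>%
  mutate(across(where(is.character), as.factor))

dim(hotels)

# Stratified split into a training ("other") set and a test set
set.seed(123)
splits <- initial_split(hotels, strata = children)

hotel_other <- training(splits)
hotel_test  <- testing(splits)

# Single validation set: 80% of hotel_other for fitting, 20% for validation.
# Note: recent rsample releases deprecate validation_split() in favor of
# initial_validation_split().
val_set <- validation_split(hotel_other,
                            strata = children,
                            prop = 0.80)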
```
### Implementation with tidymodels

Build the model:
```{r}
# penalty is tuned; mixture = 1 requests a pure lasso penalty
lr_mod <-
  logistic_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet")
```
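Next, define the recipe for the predictors.

Do categorical variables need to be dummy-coded by hand?? In the recipe below, `step_dummy()` takes care of this.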
```{r}
holidays <- c("AllSouls", "AshWednesday", "ChristmasEve", "Easter",
              "ChristmasDay", "GoodFriday", "NewYearsDay", "PalmSunday")

lr_recipe <-
  recipe(children ~ ., data = hotel_other) %>%
  step_date(arrival_date) %>%                         # derive date features
  step_holiday(arrival_date, holidays = holidays) %>% # holiday indicator columns
  step_rm(arrival_date) %>%                           # drop the raw date
  step_dummy(all_nominal_predictors()) %>%            # dummy-code factors
  step_zv(all_predictors()) %>%                       # remove zero-variance columns
  step_normalize(all_predictors())                    # center and scale
```
```{r}
# Bundle the model specification and the recipe
lr_workflow <-
  workflow() %>%
  add_model(lr_mod) %>%
  add_recipe(lr_recipe)
```
```{r}
# 30 candidate penalties, log-spaced between 1e-4 and 1e-1
lr_reg_grid <- tibble(penalty = 10^seq(-4, -1, length.out = 30))

lr_res <-
  lr_workflow %>%
  tune_grid(val_set,
            grid = lr_reg_grid,
            control = control_grid(save_pred = TRUE),
            metrics = metric_set(roc_auc))
```
```{r}
# Validation-set ROC AUC as a function of the penalty
lr_plot <-
  lr_res %>%
  collect_metrics() %>%
  ggplot(aes(x = penalty, y = mean)) +
  geom_point() +
  geom_line() +
  ylab("Area under the ROC Curve") +
  scale_x_log10(labels = scales::label_number())

lr_plot
```
```{r}
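# Name the metric argument explicitly; recent tune releases no longer
# accept it positionally.
top_models <-
  lr_res %>%
  show_best(metric = "roc_auc", n = 15) %>%
  arrange(penalty)

top_models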
```
```{r}
# Select the candidate with the 12th-smallest penalty value
lr_best <-
  lr_res %>%
  collect_metrics() %>%
  arrange(penalty) %>%
  slice(12)

lr_best
```
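
The `slice(12)` above keeps the 12th-smallest penalty. Because `roc_auc` is the only metric, `collect_metrics()` returns one row per grid value, so that row matches the 12th element of the grid. The following minimal sketch rebuilds the grid in base R to show which value that is. The names `penalty_grid`, `log_steps`, and `twelfth_penalty` are illustrative and not part of the original code.

```{r}
# Sketch: rebuild the penalty grid with base R to see which value slice(12) selects.
penalty_grid <- 10^seq(-4, -1, length.out = 30)

# Exponents are evenly spaced, so the penalties are evenly spaced on a log10 scale.
log_steps <- diff(log10(penalty_grid))
range(log_steps)

# The 12th-smallest penalty, i.e. the row kept by arrange(penalty) %>% slice(12).
twelfth_penalty <- sort(penalty_grid)[12]
signif(twelfth_penalty, 3)
```

The source case study picks a mid-range penalty like this one rather than the numerically best one: when performance is about the same, a larger penalty is preferred because it may drop more predictors.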
```{r}
# ROC curve on the validation set for the selected penalty
lr_auc <-
  lr_res %>%
  collect_predictions(parameters = lr_best) %>%
  roc_curve(children, .pred_children) %>%
  mutate(model = "Logistic Regression")

autoplot(lr_auc)
```
### Implementation with glmnet
```{r}
library(glmnet)
```
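
A minimal sketch of what a direct `glmnet` fit might look like is given below. It runs on simulated data, not on the hotel data, so it does not reproduce the results above. The objects `x_sim`, `y_sim`, `lambda_grid`, `fit_sim`, and `prob_sim` are hypothetical names introduced only for illustration. In `glmnet`, `alpha` plays the role of `mixture` and `lambda` the role of `penalty`.

```{r}
# Sketch only: simulated data stands in for the preprocessed hotel predictors.
library(glmnet)

set.seed(123)
n <- 500
p <- 10
x_sim <- matrix(rnorm(n * p), nrow = n, ncol = p,
                dimnames = list(NULL, paste0("x", seq_len(p))))

# Only the first two predictors carry signal in this toy example.
y_sim <- rbinom(n, size = 1, prob = plogis(1.5 * x_sim[, 1] - 1 * x_sim[, 2]))

# alpha = 1 is the lasso penalty (counterpart of mixture = 1);
# lambda is supplied as a decreasing sequence, mirroring the penalty grid.
lambda_grid <- 10^seq(-1, -4, length.out = 30)
fit_sim <- glmnet(x_sim, y_sim, family = "binomial",
                  alpha = 1, lambda = lambda_grid)

# Coefficients at one penalty value; entries printed as "." are exactly zero.
coef(fit_sim, s = 0.01)

# Predicted probabilities at the same penalty value.
prob_sim <- predict(fit_sim, newx = x_sim, s = 0.01, type = "response")
head(prob_sim)
```

Unlike the recipe-based workflow, `glmnet()` needs a numeric predictor matrix, so with real data the dummy coding and normalization would have to be done before this call.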