 #' Direct autoregressive classifier with covariates
 #'
-#' This is an autoregressive classification model for
-#' [epiprocess::epi_df][epiprocess::as_epi_df] data. It does "direct" forecasting, meaning
-#' that it estimates a class at a particular target horizon.
+#'
+#' @description
+#' This is an autoregressive classification model for continuous data. It does
+#' "direct" forecasting, meaning that it estimates a class at a particular
+#' target horizon.
+#'
+#' @details
+#' The `arx_classifier()` is an autoregressive classification model for
+#' `epi_df` data that is used to predict a discrete class for each case under
+#' consideration. It is a direct forecaster in that it estimates the classes
+#' at a specific horizon or ahead value.
+#'
+#' To get a sense of how the `arx_classifier()` works, let's consider a simple
+#' example with minimal inputs. For this, we will use the built-in
+#' `covid_case_death_rates` dataset, which contains confirmed COVID-19 cases
+#' and deaths from JHU CSSE for all states from Dec 31, 2020 through Dec 31,
+#' 2021. From this, we'll take a subset of data for five states from June 4,
+#' 2021 through December 31, 2021. Our objective is to predict whether the
+#' case rates are increasing, using the case rates lagged by 0, 7, and 14
+#' days:
+#'
+#' ```{r}
+#' jhu <- covid_case_death_rates %>%
+#'   filter(
+#'     time_value >= "2021-06-04",
+#'     time_value <= "2021-12-31",
+#'     geo_value %in% c("ca", "fl", "tx", "ny", "nj")
+#'   )
+#'
+#' out <- arx_classifier(jhu, outcome = "case_rate", predictors = "case_rate")
+#'
+#' out$predictions
+#' ```
+#'
+#' The key takeaway from the predictions is that there are two prediction
+#' classes: `(-Inf, 0.25]` and `(0.25, Inf)`; the classes to predict must be
+#' discrete. The discretization of the real-valued outcome is controlled by
+#' the `breaks` argument, which defaults to `0.25`. Such breaks are
+#' automatically extended to cover the entire real line. For example, the
+#' default break of `0.25` is silently extended to `breaks = c(-Inf, .25,
+#' Inf)` and therefore results in two classes: `(-Inf, 0.25]` and `(0.25,
+#' Inf)`. These two classes are used to discretize the outcome. The conversion
+#' of the outcome to such classes is handled internally. So if discrete
+#' classes already exist for the outcome in the `epi_df`, then we recommend
+#' coding a classifier from scratch using the `epi_workflow` framework for
+#' more control.
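The silent extension of `breaks` described above can be illustrated with base R alone; this is a sketch of the idea using `cut()`, not the package's internal code:

```r
# A single break, as it would be passed via arx_class_args_list(breaks = 0.25)
breaks <- 0.25

# Extended to cover the entire real line, as described above
extended <- c(-Inf, breaks, Inf)

# Discretizing a few example outcome values yields the two classes seen in
# out$predictions (cut() labels the upper class "(0.25,Inf]" rather than
# "(0.25, Inf)", but the binning is the same)
x <- c(-0.1, 0.2, 0.3, 1.5)
classes <- cut(x, breaks = extended)
levels(classes)
```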
+#'
+#' The `trainer` is a `parsnip` model describing the type of estimation such
+#' that `mode = "classification"` is enforced. The two typical trainers that
+#' are used are `parsnip::logistic_reg()` for two classes or
+#' `parsnip::multinom_reg()` for more than two classes.
+#'
+#' ```{r}
+#' workflows::extract_spec_parsnip(out$epi_workflow)
+#' ```
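Since the breaks determine the number of classes, they also determine which of these two trainers applies. A small base-R sketch (the `n_classes()` helper here is hypothetical, purely for illustration, and not part of `epipredict`):

```r
# Breaks are extended with -Inf and Inf, so k breaks yield k + 1 classes
n_classes <- function(breaks) length(c(-Inf, breaks, Inf)) - 1L

n_classes(0.25)         # 2 classes -> parsnip::logistic_reg()
n_classes(c(0.25, 0.5)) # 3 classes -> parsnip::multinom_reg()
```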
+#'
+#' From the parsnip model specification, we can see that the trainer used is
+#' logistic regression, which is expected for our binary outcome. More
+#' complicated trainers like `parsnip::naive_Bayes()` or
+#' `parsnip::rand_forest()` may also be used (however, we will stick to the
+#' basics in this gentle introduction to the classifier).
+#'
+#' If you use the default trainer of logistic regression for binary
+#' classification and you decide against using the default break of 0.25, then
+#' you should only input one break so that there are two classification bins
+#' to properly dichotomize the outcome. For example, let's set a break of 0.5
+#' instead of relying on the default of 0.25. We can do this by passing 0.5 to
+#' the `breaks` argument in `arx_class_args_list()` as follows:
+#'
+#' ```{r}
+#' out_break_0.5 <- arx_classifier(
+#'   jhu,
+#'   outcome = "case_rate",
+#'   predictors = "case_rate",
+#'   args_list = arx_class_args_list(
+#'     breaks = 0.5
+#'   )
+#' )
+#'
+#' out_break_0.5$predictions
+#' ```
+#'
+#' Indeed, we can observe that the two `.pred_class` values are now
+#' `(-Inf, 0.5]` and `(0.5, Inf)`. See `help(arx_class_args_list)` for other
+#' available modifications.
+#'
+#' Additional arguments that may be supplied to `arx_class_args_list()`
+#' include the expected `lags` and `ahead` arguments for an
+#' autoregressive-type model. These default to lags of 0, 7, and 14 days for
+#' the predictors and to forecasting the outcome 7 days ahead of the forecast
+#' date. There is also `n_training` to indicate the upper bound for the
+#' number of training rows per key. If you would like some practice with
+#' using this, then remove the filtering command that restricts the data to
+#' between "2021-06-04" and "2021-12-31", and instead set `n_training` to the
+#' number of days between these two dates, inclusive of the end points. The
+#' end result should be the same. In addition to `n_training`, there are
+#' `forecast_date` and `target_date` to specify the date that the forecast is
+#' created and the date it is intended for, respectively. We will not dwell
+#' on such arguments here, as they are not unique to this classifier or
+#' absolutely essential to understanding how it operates. The remaining
+#' arguments will be discussed organically, as they are needed to serve our
+#' purposes. For information on any arguments not discussed here, please see
+#' the function documentation for a complete list and their definitions.
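As a quick check of the `n_training` suggestion above, the inclusive day count between those two filter dates can be computed in base R (the value itself, 211, is our own arithmetic, not taken from the documentation):

```r
# Days from 2021-06-04 through 2021-12-31, inclusive of both end points
n_train <- as.integer(as.Date("2021-12-31") - as.Date("2021-06-04")) + 1L
n_train # 211

# It could then be supplied in place of the date filter, e.g.
# arx_classifier(..., args_list = arx_class_args_list(n_training = n_train))
```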
 #'
 #' @inheritParams arx_forecaster
 #' @param outcome A character (scalar) specifying the outcome (in the
@@ -68,9 +166,7 @@ arx_classifier <- function(
   }
   forecast_date <- args_list$forecast_date %||% forecast_date_default
   target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
-  preds <- forecast(
-    wf,
-  ) %>%
+  preds <- forecast(wf) %>%
     as_tibble() %>%
     select(-time_value)
 
@@ -249,7 +345,7 @@ arx_class_epi_workflow <- function(
 #' be created using growth rates (as the predictors are) or lagged
 #' differences. The second case is closer to the requirements for the
 #' [2022-23 CDC Flusight Hospitalization Experimental Target](https://github.com/cdcepi/Flusight-forecast-data/blob/745511c436923e1dc201dea0f4181f21a8217b52/data-experimental/README.md).
-#' See the Classification Vignette for details of how to create a reasonable
+#' See the [Classification chapter from the forecasting book](https://cmu-delphi.github.io/delphi-tooling-book/arx-classifier.html) for details of how to create a reasonable
 #' baseline for this case. Selecting `"growth_rate"` (the default) uses
 #' [epiprocess::growth_rate()] to create the outcome using some of the
 #' additional arguments below. Choosing `"lag_difference"` instead simply