Skip to content

Commit 49271da

Browse files
committed
geos with all NAs play poorly (can't calculate growth rate)
1 parent 41aa2f1 commit 49271da

File tree

2 files changed

+48
-48
lines changed

2 files changed

+48
-48
lines changed

tests/testthat/_snaps/snapshots.md

Lines changed: 41 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,54 +1237,52 @@
12371237
# arx_classifier snapshots
12381238

12391239
structure(list(geo_value = c("ak", "al", "ar", "az", "ca", "co",
1240-
"ct", "dc", "de", "fl", "ga", "gu", "hi", "ia", "id", "il", "in",
1241-
"ks", "ky", "la", "ma", "me", "mi", "mn", "mo", "mp", "ms", "mt",
1242-
"nc", "nd", "ne", "nh", "nj", "nm", "nv", "ny", "oh", "ok", "or",
1243-
"pa", "pr", "ri", "sc", "sd", "tn", "tx", "ut", "va", "vt", "wa",
1244-
"wi", "wv", "wy"), .pred_class = structure(c(1L, 1L, 1L, 1L,
1240+
"ct", "dc", "de", "fl", "ga", "hi", "ia", "id", "il", "in", "ks",
1241+
"ky", "la", "ma", "me", "mi", "mn", "mo", "ms", "mt", "nc", "nd",
1242+
"ne", "nh", "nj", "nm", "nv", "ny", "oh", "ok", "or", "pa", "pr",
1243+
"ri", "sc", "sd", "tn", "tx", "ut", "va", "vt", "wa", "wi", "wv",
1244+
"wy"), .pred_class = structure(c(1L, 1L, 1L, 1L, 1L, 1L, 1L,
12451245
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L,
1246-
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L,
1247-
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L,
1248-
1L), levels = c("(-Inf,0.25]", "(0.25, Inf]"), class = "factor"),
1249-
forecast_date = structure(c(18992, 18992, 18992, 18992, 18992,
1250-
18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992,
1251-
18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992,
1252-
18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992,
1253-
18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992,
1254-
18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992,
1255-
18992, 18992, 18992), class = "Date"), target_date = structure(c(18999,
1256-
18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999,
1257-
18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999,
1258-
18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999,
1259-
18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999,
1260-
18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999,
1261-
18999, 18999, 18999, 18999, 18999, 18999, 18999), class = "Date")), row.names = c(NA,
1262-
-53L), class = c("tbl_df", "tbl", "data.frame"))
1246+
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 2L,
1247+
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L), levels = c("(-Inf,0.25]",
1248+
"(0.25, Inf]"), class = "factor"), forecast_date = structure(c(18992,
1249+
18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992,
1250+
18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992,
1251+
18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992,
1252+
18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992,
1253+
18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992,
1254+
18992, 18992, 18992, 18992, 18992), class = "Date"), target_date = structure(c(18999,
1255+
18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999,
1256+
18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999,
1257+
18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999,
1258+
18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999,
1259+
18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999,
1260+
18999, 18999, 18999, 18999, 18999), class = "Date")), row.names = c(NA,
1261+
-51L), class = c("tbl_df", "tbl", "data.frame"))
12631262

12641263
---
12651264

12661265
structure(list(geo_value = c("ak", "al", "ar", "az", "ca", "co",
1267-
"ct", "dc", "de", "fl", "ga", "gu", "hi", "ia", "id", "il", "in",
1268-
"ks", "ky", "la", "ma", "me", "mi", "mn", "mo", "mp", "ms", "mt",
1269-
"nc", "nd", "ne", "nh", "nj", "nm", "nv", "ny", "oh", "ok", "or",
1270-
"pa", "pr", "ri", "sc", "sd", "tn", "tx", "ut", "va", "vt", "wa",
1271-
"wi", "wv", "wy"), .pred_class = structure(c(1L, 1L, 1L, 1L,
1272-
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L,
1266+
"ct", "dc", "de", "fl", "ga", "hi", "ia", "id", "il", "in", "ks",
1267+
"ky", "la", "ma", "me", "mi", "mn", "mo", "ms", "mt", "nc", "nd",
1268+
"ne", "nh", "nj", "nm", "nv", "ny", "oh", "ok", "or", "pa", "pr",
1269+
"ri", "sc", "sd", "tn", "tx", "ut", "va", "vt", "wa", "wi", "wv",
1270+
"wy"), .pred_class = structure(c(1L, 1L, 1L, 1L, 1L, 1L, 1L,
12731271
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L,
12741272
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L,
1275-
1L), levels = c("(-Inf,0.25]", "(0.25, Inf]"), class = "factor"),
1276-
forecast_date = structure(c(18994, 18994, 18994, 18994, 18994,
1277-
18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994,
1278-
18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994,
1279-
18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994,
1280-
18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994,
1281-
18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994,
1282-
18994, 18994, 18994), class = "Date"), target_date = structure(c(19001,
1283-
19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001,
1284-
19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001,
1285-
19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001,
1286-
19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001,
1287-
19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001,
1288-
19001, 19001, 19001, 19001, 19001, 19001, 19001), class = "Date")), row.names = c(NA,
1289-
-53L), class = c("tbl_df", "tbl", "data.frame"))
1273+
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L), levels = c("(-Inf,0.25]",
1274+
"(0.25, Inf]"), class = "factor"), forecast_date = structure(c(18994,
1275+
18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994,
1276+
18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994,
1277+
18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994,
1278+
18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994,
1279+
18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994,
1280+
18994, 18994, 18994, 18994, 18994), class = "Date"), target_date = structure(c(19001,
1281+
19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001,
1282+
19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001,
1283+
19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001,
1284+
19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001,
1285+
19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001,
1286+
19001, 19001, 19001, 19001, 19001), class = "Date")), row.names = c(NA,
1287+
-51L), class = c("tbl_df", "tbl", "data.frame"))
12901288

tests/testthat/test-snapshots.R

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,16 +146,18 @@ test_that("arx_forecaster output format snapshots", {
146146
})
147147

148148
test_that("arx_classifier snapshots", {
149+
train <- covid_case_death_rates %>%
150+
filter(geo_value %nin% c("as", "gu", "mp", "vi"))
149151
arc1 <- arx_classifier(
150-
covid_case_death_rates %>%
152+
train %>%
151153
dplyr::filter(time_value >= as.Date("2021-11-01")),
152154
"death_rate",
153155
c("case_rate", "death_rate")
154156
)
155157
expect_snapshot_tibble(arc1$predictions)
156-
max_date <- covid_case_death_rates$time_value %>% max()
158+
max_date <- train$time_value %>% max()
157159
arc2 <- arx_classifier(
158-
covid_case_death_rates %>%
160+
train %>%
159161
dplyr::filter(time_value >= as.Date("2021-11-01")),
160162
"death_rate",
161163
c("case_rate", "death_rate"),
@@ -164,7 +166,7 @@ test_that("arx_classifier snapshots", {
164166
expect_snapshot_tibble(arc2$predictions)
165167
expect_error(
166168
arc3 <- arx_classifier(
167-
covid_case_death_rates %>%
169+
train %>%
168170
dplyr::filter(time_value >= as.Date("2021-11-01")),
169171
"death_rate",
170172
c("case_rate", "death_rate"),
@@ -174,7 +176,7 @@ test_that("arx_classifier snapshots", {
174176
)
175177
expect_error(
176178
arc4 <- arx_classifier(
177-
covid_case_death_rates %>%
179+
train %>%
178180
dplyr::filter(time_value >= as.Date("2021-11-01")),
179181
"death_rate",
180182
c("case_rate", "death_rate"),

0 commit comments

Comments
 (0)