This is the official PyTorch implementation of the paper "Spatial–Temporal Synchronous Graph Transformer network (STSGT) for COVID-19 forecasting", which was presented at the IEEE/ACM CHASE 2022 conference and published in the Elsevier journal Smart Health (2022).
- python >= 3.6
- pytorch >= 1.8.0
The following steps are required to replicate our work:
- Download datasets.
  - JHU Dataset - Download the JHU COVID time-series data (time_series_covid19_confirmed_US.csv for daily US infected cases and time_series_covid19_deaths_US.csv for daily US death cases) and save it in the data/COVID_JHU directory. This project used Mar 15, 2020 - Nov 30, 2021 for analysis.
  - NYT Dataset - Download the NYT COVID time-series data (us-states.csv for daily US infected and death cases) and save it in the data/COVID_NYT directory. This project used Mar 18, 2020 - Nov 30, 2021 for analysis.
- Generate Feature Matrix (X) and Adjacency Matrix (W) from downloaded datasets.
  - JHU Dataset (US) - Inside the folder data/COVID_JHU, run the file Generate_51_states_X_W.py to generate the X and W matrices for the 50 US states and Washington D.C. (51 graph nodes).
  - JHU Dataset (Michigan) - Inside the folder data/COVID_JHU, run the file Generate_51_states_X_W_Michigan.py to generate the X and W matrices for the 83 counties of the state of Michigan (83 graph nodes).
  - NYT Dataset (US) - Inside the folder data/COVID_NYT, run the file Generate_51_states_X_W_NYT.py to generate the X matrix for the 50 US states and Washington D.C. (51 graph nodes). We used the same adjacency matrix (W) as generated from the JHU dataset.
- Generate Train, Validation and Test datasets from the generated X matrix.
  - We divided the entire dataset in chronological order into 80% training, 10% validation and 10% testing.
  - Run the file generate_training_data.py to generate the processed files train.npz, val.npz and test.npz from the X matrix, and save them in data/COVID_JHU/processed or data/COVID_NYT/processed. Pass the confirmed or deaths X matrix file in the --traffic_df_filename argument to generate the processed files for infected or death cases, respectively.
# For JHU Daily Infected cases data
python generate_training_data.py --traffic_df_filename "data/COVID_JHU/covid19_confirmed_US_51_states_X_matrix_final.csv"
# For NYT Daily Death cases data
python generate_training_data.py --traffic_df_filename "data/COVID_NYT/covid19_NYT_deaths_US_51_states_X_matrix_final.csv"
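The chronological 80/10/10 split described in this step can be sketched as follows. This is a minimal illustration and not the code in generate_training_data.py: the function name `chronological_split` and the use of rounding to set the boundaries are assumptions made for this example.

```python
def chronological_split(num_samples, train_ratio=0.8, val_ratio=0.1):
    """Return (train, val, test) index ranges that preserve time order.

    Hypothetical helper for illustration only; the boundary rounding is an
    assumption and may differ from the repository's own script.
    """
    num_train = round(num_samples * train_ratio)
    num_val = round(num_samples * val_ratio)
    train = range(0, num_train)                    # earliest samples
    val = range(num_train, num_train + num_val)    # next block in time
    test = range(num_train + num_val, num_samples) # most recent samples
    return train, val, test


# Toy example with 600 time-ordered samples (not a real dataset size).
train_idx, val_idx, test_idx = chronological_split(600)
print(len(train_idx), len(val_idx), len(test_idx))  # 480 60 60
```

Because the samples are never shuffled, every validation sample is later in time than every training sample, and every test sample is later than both, which avoids leaking future case counts into training.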
- Define paths and hyper-parameters in configuration files.
  - Refer to the files config/COVID_JHU.conf and config/COVID_NYT.conf for the data paths, hyper-parameters and model configurations used for training and testing.
  - The sensors_distance entry in the config files indicates the path to the adjacency matrix W.
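As a rough sketch of how a sensors_distance entry could be read from a .conf file, the example below uses Python's standard configparser on an in-memory string. The section name `data`, the placeholder path and the helper `load_adjacency_path` are all assumptions for illustration; check config/COVID_JHU.conf for the actual layout and values.

```python
import configparser

# Hypothetical config content: the [data] section name and the path are
# placeholders, not copied from the repository's config files.
EXAMPLE_CONF = """
[data]
sensors_distance = path/to/W_matrix.csv
"""


def load_adjacency_path(conf_text, section="data"):
    """Return the sensors_distance value from INI-style config text."""
    parser = configparser.ConfigParser()
    parser.read_string(conf_text)
    return parser[section]["sensors_distance"]


adjacency_path = load_adjacency_path(EXAMPLE_CONF)
print(adjacency_path)
```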
- Train the model
python train.py --epochs 100 --learning_rate 0.001 --expid 1 --print_every 20
  - The pre-trained models can be found in checkpoints/pretrained_models.
  - Refer to the required folder (JHU or NYT, and Infected or Deaths for infected or death cases respectively); our model is in the folder STST.
- Test the model
  - An example of testing on the COVID_JHU dataset's daily infected cases and the COVID_NYT dataset's daily death cases with our model STST (the name used in the code for the STSGT model) is given here. The ... _best_model.pth file is the model with the lowest Mean Absolute Error (MAE) on the validation set.
# For JHU Daily Infected cases data with our trained model
python test.py --checkpoint "checkpoints/pretrained_models/JHU_States_Infected/STST/exp_2_1654.67_best_model.pth"
# For NYT Daily Death cases data with our trained model
python test.py --checkpoint "checkpoints/pretrained_models/NYT_States_Deaths/STST/exp_1_19.06_best_model.pth"
  - Please choose the correct configuration file with the DATASET variable in both train.py and test.py.
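The best-model criterion mentioned in the testing step (lowest MAE on the validation set) can be illustrated with a small sketch. The helpers `mean_absolute_error` and `select_best_epoch` and the toy numbers are hypothetical; this is not the logic in train.py, only one plausible way to apply the criterion.

```python
def mean_absolute_error(predictions, targets):
    """Average of |prediction - target| over all samples."""
    if len(predictions) != len(targets):
        raise ValueError("predictions and targets must have the same length")
    return sum(abs(p - t) for p, t in zip(predictions, targets)) / len(targets)


def select_best_epoch(val_mae_per_epoch):
    """Return (epoch_index, mae) for the epoch with the lowest validation MAE."""
    best_epoch = min(range(len(val_mae_per_epoch)),
                     key=val_mae_per_epoch.__getitem__)
    return best_epoch, val_mae_per_epoch[best_epoch]


# Toy values for illustration only (not results from the paper).
print(mean_absolute_error([10.0, 12.0, 9.0], [11.0, 12.0, 7.0]))  # 1.0
print(select_best_epoch([5.0, 3.5, 4.2]))  # (1, 3.5)
```

In a training loop, the checkpoint would be saved whenever the current epoch's validation MAE is lower than every earlier epoch's, so the saved file always corresponds to the best epoch seen so far.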
Please cite our paper if you find this work useful for your research:
@article{banerjee2022spatial,
title={Spatial--temporal synchronous graph transformer network (STSGT) for COVID-19 forecasting},
author={Banerjee, Soumyanil and Dong, Ming and Shi, Weisong},
journal={Smart Health},
volume={26},
pages={100348},
year={2022},
publisher={Elsevier}
}