# Train an autoencoder with an attention mechanism for multivariate time series
import sys
import os
import time
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.multiprocessing as mp
from torch.utils.data import DataLoader

from nn_architecture.ae_networks import TransformerAutoencoder, TransformerFlattenAutoencoder, TransformerDoubleAutoencoder, train, save
from helpers.dataloader import Dataloader
from helpers import system_inputs
from helpers.trainer import AETrainer
from helpers.ddp_training import AEDDPTrainer, run
from helpers.get_master import find_free_port


def main():

    # ------------------------------------------------------------------------------------------------------------------
    # Configure training parameters
    # ------------------------------------------------------------------------------------------------------------------

    default_args = system_inputs.parse_arguments(sys.argv, file='autoencoder_training_main.py')
    print('-----------------------------------------\n')

    # User inputs
    opt = {
        'path_dataset': default_args['path_dataset'],
        'path_checkpoint': default_args['path_checkpoint'],
        'save_name': default_args['save_name'],
        'target': default_args['target'],
        'sample_interval': default_args['sample_interval'],
        # 'conditions': default_args['conditions'],
        'channel_label': default_args['channel_label'],
        'channels_out': default_args['channels_out'],
        'timeseries_out': default_args['timeseries_out'],
        'n_epochs': default_args['n_epochs'],
        'batch_size': default_args['batch_size'],
        'train_ratio': default_args['train_ratio'],
        'learning_rate': default_args['learning_rate'],
        'hidden_dim': default_args['hidden_dim'],
        'num_heads': default_args['num_heads'],
        'num_layers': default_args['num_layers'],
        'activation': default_args['activation'],
        'ddp': default_args['ddp'],
        'ddp_backend': default_args['ddp_backend'],
        # 'n_conditions': len(default_args['conditions']) if default_args['conditions'][0] != '' else 0,
        'norm_data': True,
        'std_data': False,
        'diff_data': False,
        'kw_timestep': default_args['kw_timestep'],
        'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        # Number of parallel workers for DDP: one per GPU, or one per CPU core as a fallback
        'world_size': torch.cuda.device_count() if torch.cuda.is_available() else mp.cpu_count(),
        'history': None,
        'trained_epochs': 0
    }
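
    # norm_data/std_data/diff_data toggle the preprocessing steps applied inside
    # the Dataloader below; the exact behavior (presumably min-max normalization,
    # standardization, and differencing) is defined in helpers/dataloader.py.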

    # ----------------------------------------------------------------------------------------------------------------------
    # Load, process, and split data
    # ----------------------------------------------------------------------------------------------------------------------

    # Scale function -> not necessary; already done in the dataloader -> param: norm_data=True
    # def scale(dataset):
    #     x_min, x_max = dataset.min(), dataset.max()
    #     return (dataset-x_min)/(x_max-x_min)

    data = Dataloader(path=opt['path_dataset'],
                      channel_label=opt['channel_label'], kw_timestep=opt['kw_timestep'],
                      norm_data=opt['norm_data'], std_data=opt['std_data'], diff_data=opt['diff_data'])
    dataset = data.get_data()
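    # NOTE (assumption, inferred from the shape indexing below): get_data()
    # returns a tensor of shape (n_samples, sequence_length, n_channels).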
    # dataset = dataset[:, opt['n_conditions']:, :].to(opt['device'])  # Remove labels
    # dataset = scale(dataset)

    # Shuffle and split the dataset into train and test partitions
    def split_data(dataset, train_size=.8):
        num_samples = dataset.shape[0]
        shuffle_index = np.arange(num_samples)
        np.random.shuffle(shuffle_index)

        cutoff_index = int(num_samples*train_size)
        train = dataset[shuffle_index[:cutoff_index]]
        test = dataset[shuffle_index[cutoff_index:]]

        # Note the return order: test first, then train
        return test, train

    # Determine n_channels and sequence_length from the dataset shape
    opt['n_channels'] = dataset.shape[-1]
    opt['sequence_length'] = dataset.shape[1]

    # Split dataset and wrap the partitions in PyTorch DataLoaders
    test_dataset, train_dataset = split_data(dataset, opt['train_ratio'])
    test_dataloader = DataLoader(test_dataset, batch_size=opt['batch_size'], shuffle=True)
    train_dataloader = DataLoader(train_dataset, batch_size=opt['batch_size'], shuffle=True)
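
    # Optional sanity check (a sketch; assumes batches are plain tensors of
    # shape (batch_size, sequence_length, n_channels)):
    # sample_batch = next(iter(train_dataloader))
    # print(sample_batch.shape)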

    # ------------------------------------------------------------------------------------------------------------------
    # Initiate and train autoencoder
    # ------------------------------------------------------------------------------------------------------------------

    # Initiate autoencoder
    model_dict = None
    if default_args['load_checkpoint'] and os.path.isfile(opt['path_checkpoint']):
        # Load the checkpoint onto the current device to avoid device-mismatch errors
        model_dict = torch.load(opt['path_checkpoint'], map_location=opt['device'])
        # model_state = model_dict['state_dict']

        target_old = opt['target']
        channels_out_old = opt['channels_out']
        timeseries_out_old = opt['timeseries_out']

        opt['target'] = model_dict['configuration']['target']
        opt['channels_out'] = model_dict['configuration']['channels_out']
        opt['timeseries_out'] = model_dict['configuration']['timeseries_out']

        # Report changes to user
        print(f"Loading model {opt['path_checkpoint']}.\n\nInheriting the following parameters:")
        print("parameter:\t\told value -> new value")
        print(f"target:\t\t\t{target_old} -> {opt['target']}")
        print(f"channels_out:\t{channels_out_old} -> {opt['channels_out']}")
        print(f"timeseries_out:\t{timeseries_out_old} -> {opt['timeseries_out']}")
        print('-----------------------------------\n')
        # print(f"Target: {opt['target']}")
        # if (opt['target'] == 'channels') | (opt['target'] == 'full'):
        #     print(f"channels_out: {opt['channels_out']}")
        # if (opt['target'] == 'timeseries') | (opt['target'] == 'full'):
        #     print(f"timeseries_out: {opt['timeseries_out']}")
        # print('-----------------------------------\n')

    elif default_args['load_checkpoint'] and not os.path.isfile(opt['path_checkpoint']):
        raise FileNotFoundError(f"Checkpoint file {opt['path_checkpoint']} not found.")

    # Add parameters for tracking; these mirror the constructor arguments used below
    if opt['target'] == 'channels':
        opt['input_dim'], opt['output_dim'], opt['output_dim_2'] = opt['n_channels'], opt['channels_out'], opt['sequence_length']
    elif opt['target'] == 'time':
        opt['input_dim'], opt['output_dim'], opt['output_dim_2'] = opt['sequence_length'], opt['timeseries_out'], opt['n_channels']
    else:  # 'full'
        opt['input_dim'], opt['output_dim'], opt['output_dim_2'] = opt['n_channels'], opt['channels_out'], opt['timeseries_out']

    if opt['target'] == 'channels':
        model = TransformerAutoencoder(input_dim=opt['n_channels'],
                                       output_dim=opt['channels_out'],
                                       output_dim_2=opt['sequence_length'],
                                       target=TransformerAutoencoder.TARGET_CHANNELS,
                                       hidden_dim=opt['hidden_dim'],
                                       num_layers=opt['num_layers'],
                                       num_heads=opt['num_heads']).to(opt['device'])
    elif opt['target'] == 'time':
        model = TransformerAutoencoder(input_dim=opt['sequence_length'],
                                       output_dim=opt['timeseries_out'],
                                       output_dim_2=opt['n_channels'],
                                       target=TransformerAutoencoder.TARGET_TIMESERIES,
                                       hidden_dim=opt['hidden_dim'],
                                       num_layers=opt['num_layers'],
                                       num_heads=opt['num_heads']).to(opt['device'])
    elif opt['target'] == 'full':
        model = TransformerDoubleAutoencoder(input_dim=opt['n_channels'],
                                             output_dim=opt['output_dim'],
                                             output_dim_2=opt['output_dim_2'],
                                             sequence_length=opt['sequence_length'],
                                             hidden_dim=opt['hidden_dim'],
                                             num_layers=opt['num_layers'],
                                             num_heads=opt['num_heads']).to(opt['device'])
    else:
        raise ValueError(f"Encoding target '{opt['target']}' not recognized; options are 'channels', 'time', or 'full'.")

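    # Shape sketch (an assumption read off the constructor arguments above, not
    # verified against ae_networks.py): with target='channels', batches of shape
    # (batch, sequence_length, n_channels) are compressed along the channel axis
    # to (batch, sequence_length, channels_out); target='time' compresses the
    # time axis instead, and target='full' compresses both in sequence.
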
    # Populate model configuration
    history = {}
    for key in opt.keys():
        # Track every configuration entry except the bookkeeping fields
        if key not in ('history', 'trained_epochs'):
            history[key] = [opt[key]]
    history['trained_epochs'] = []

    if model_dict is not None:
        # Prepend the history of the loaded checkpoint to the current one
        for key in history.keys():
            history[key] = model_dict['configuration']['history'][key] + history[key]

    opt['history'] = history

    if opt['ddp']:
        trainer = AEDDPTrainer(model, opt)
        if default_args['load_checkpoint']:
            trainer.load_checkpoint(default_args['path_checkpoint'])
        mp.spawn(run, args=(opt['world_size'], find_free_port(), opt['ddp_backend'], trainer, opt),
                 nprocs=opt['world_size'], join=True)
        # The spawned DDP workers handle training in their own processes,
        # so no samples are returned to this parent process
        samples = None
    else:
        trainer = AETrainer(model, opt)
        if default_args['load_checkpoint']:
            trainer.load_checkpoint(default_args['path_checkpoint'])
        samples = trainer.training(train_dataloader, test_dataloader)
        model = trainer.model
    print("Training finished.")

    # ----------------------------------------------------------------------------------------------------------------------
    # Save autoencoder
    # ----------------------------------------------------------------------------------------------------------------------

    # Save model
    # model_dict = dict(state_dict=model.state_dict(), config=model.config)
    if opt['save_name'] is None:
        # Derive a default file name from the dataset name and a unix timestamp
        fn = os.path.splitext(os.path.basename(opt['path_dataset']))[0]
        opt['save_name'] = os.path.join("trained_ae", f"ae_{fn}_{int(time.time())}.pt")
    # save(model_dict, save_name)

    # Ensure the target directory exists before saving
    save_dir = os.path.dirname(opt['save_name'])
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    trainer.save_checkpoint(opt['save_name'], update_history=True, samples=samples)
    print(f"Model and configuration saved in {opt['save_name']}")


if __name__ == "__main__":
    main()
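
# Example invocation (hypothetical; the accepted key=value argument names are
# defined by helpers/system_inputs.py, not shown here):
#   python autoencoder_training_main.py path_dataset=data/my_data.csv target=channels n_epochs=100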