Skip to content

Commit e6862a9

Browse files
committed
Commit the Autoencoder
1 parent 528c7fc commit e6862a9

22 files changed

+37503
-0
lines changed

autoencoder_training_main.py

+211
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
# train an autoencoder with attention mechanism for multivariate time series
2+
import sys
3+
import os
4+
import time
5+
import copy
6+
import numpy as np
7+
import torch
8+
import torch.nn as nn
9+
import torch.multiprocessing as mp
10+
from torch.utils.data import DataLoader
11+
12+
from nn_architecture.ae_networks import TransformerAutoencoder, TransformerFlattenAutoencoder, TransformerDoubleAutoencoder, train, save
13+
from helpers.dataloader import Dataloader
14+
from helpers import system_inputs
15+
from helpers.trainer import AETrainer
16+
from helpers.ddp_training import AEDDPTrainer, run
17+
from helpers.get_master import find_free_port
18+
19+
20+
def main():
21+
22+
# ------------------------------------------------------------------------------------------------------------------
23+
# Configure training parameters
24+
# ------------------------------------------------------------------------------------------------------------------
25+
26+
default_args = system_inputs.parse_arguments(sys.argv, file='autoencoder_training_main.py')
27+
print('-----------------------------------------\n')
28+
29+
# User inputs
30+
opt = {
31+
'path_dataset': default_args['path_dataset'],
32+
'path_checkpoint': default_args['path_checkpoint'],
33+
'save_name': default_args['save_name'],
34+
'target': default_args['target'],
35+
'sample_interval': default_args['sample_interval'],
36+
# 'conditions': default_args['conditions'],
37+
'channel_label': default_args['channel_label'],
38+
'channels_out': default_args['channels_out'],
39+
'timeseries_out': default_args['timeseries_out'],
40+
'n_epochs': default_args['n_epochs'],
41+
'batch_size': default_args['batch_size'],
42+
'train_ratio': default_args['train_ratio'],
43+
'learning_rate': default_args['learning_rate'],
44+
'hidden_dim': default_args['hidden_dim'],
45+
'num_heads': default_args['num_heads'],
46+
'num_layers': default_args['num_layers'],
47+
'activation': default_args['activation'],
48+
'learning_rate': default_args['learning_rate'],
49+
'num_heads': default_args['num_heads'],
50+
'num_layers': default_args['num_layers'],
51+
'ddp': default_args['ddp'],
52+
'ddp_backend': default_args['ddp_backend'],
53+
# 'n_conditions': len(default_args['conditions']) if default_args['conditions'][0] != '' else 0,
54+
'norm_data': True,
55+
'std_data': False,
56+
'diff_data': False,
57+
'kw_timestep': default_args['kw_timestep'],
58+
'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
59+
'world_size': torch.cuda.device_count() if torch.cuda.is_available() else mp.cpu_count(),
60+
'history': None,
61+
'trained_epochs': 0
62+
}
63+
64+
# ----------------------------------------------------------------------------------------------------------------------
65+
# Load, process, and split data
66+
# ----------------------------------------------------------------------------------------------------------------------
67+
68+
# Scale function -> Not necessary; already in dataloader -> param: norm_data=True
69+
# def scale(dataset):
70+
# x_min, x_max = dataset.min(), dataset.max()
71+
# return (dataset-x_min)/(x_max-x_min)
72+
73+
data = Dataloader(path=opt['path_dataset'],
74+
channel_label=opt['channel_label'], kw_timestep=opt['kw_timestep'],
75+
norm_data=opt['norm_data'], std_data=opt['std_data'], diff_data=opt['diff_data'],)
76+
dataset = data.get_data()
77+
# dataset = dataset[:, opt['n_conditions']:, :].to(opt['device']) #Remove labels
78+
# dataset = scale(dataset)
79+
80+
# Split data function
81+
def split_data(dataset, train_size=.8):
82+
num_samples = dataset.shape[0]
83+
shuffle_index = np.arange(num_samples)
84+
np.random.shuffle(shuffle_index)
85+
86+
cutoff_index = int(num_samples*train_size)
87+
train = dataset[shuffle_index[:cutoff_index]]
88+
test = dataset[shuffle_index[cutoff_index:]]
89+
90+
return test, train
91+
92+
# Determine n_channels, output_dim, and seq_length
93+
opt['n_channels'] = dataset.shape[-1]
94+
opt['sequence_length'] = dataset.shape[1]
95+
96+
# Split dataset and convert to pytorch dataloader class
97+
test_dataset, train_dataset = split_data(dataset, opt['train_ratio'])
98+
test_dataloader = DataLoader(test_dataset, batch_size=opt['batch_size'], shuffle=True)
99+
train_dataloader = DataLoader(train_dataset, batch_size=opt['batch_size'], shuffle=True)
100+
101+
# ------------------------------------------------------------------------------------------------------------------
102+
# Initiate and train autoencoder
103+
# ------------------------------------------------------------------------------------------------------------------
104+
105+
# Initiate autoencoder
106+
model_dict = None
107+
if default_args['load_checkpoint'] and os.path.isfile(opt['path_checkpoint']):
108+
model_dict = torch.load(opt['path_checkpoint'])
109+
# model_state = model_dict['state_dict']
110+
111+
target_old = opt['target']
112+
channels_out_old = opt['channels_out']
113+
timeseries_out_old = opt['timeseries_out']
114+
115+
opt['target'] = model_dict['configuration']['target']
116+
opt['channels_out'] = model_dict['configuration']['channels_out']
117+
opt['timeseries_out'] = model_dict['configuration']['timeseries_out']
118+
119+
# Report changes to user
120+
print(f"Loading model {opt['path_checkpoint']}.\n\nInhereting the following parameters:")
121+
print("parameter:\t\told value -> new value")
122+
print(f"target:\t\t\t{target_old} -> {opt['target']}")
123+
print(f"channels_out:\t{channels_out_old} -> {opt['channels_out']}")
124+
print(f"timeseries_out:\t{timeseries_out_old} -> {opt['timeseries_out']}")
125+
print('-----------------------------------\n')
126+
# print(f"Target: {opt['target']}")
127+
# if (opt['target'] == 'channels') | (opt['target'] == 'full'):
128+
# print(f"channels_out: {opt['channels_out']}")
129+
# if (opt['target'] == 'timeseries') | (opt['target'] == 'full'):
130+
# print(f"timeseries_out: {opt['timeseries_out']}")
131+
# print('-----------------------------------\n')
132+
133+
elif default_args['load_checkpoint'] and not os.path.isfile(opt['path_checkpoint']):
134+
raise FileNotFoundError(f"Checkpoint file {opt['path_checkpoint']} not found.")
135+
136+
# Add parameters for tracking
137+
opt['input_dim'] = opt['n_channels'] if opt['target'] in ['channels', 'full'] else opt['sequence_length']
138+
opt['output_dim'] = opt['channels_out'] if opt['target'] in ['channels', 'full'] else opt['n_channels']
139+
opt['output_dim_2'] = opt['sequence_length'] if opt['target'] in ['channels'] else opt['timeseries_out']
140+
141+
if opt['target'] == 'channels':
142+
model = TransformerAutoencoder(input_dim=opt['n_channels'],
143+
output_dim=opt['channels_out'],
144+
output_dim_2=opt['sequence_length'],
145+
target=TransformerAutoencoder.TARGET_CHANNELS,
146+
hidden_dim=opt['hidden_dim'],
147+
num_layers=opt['num_layers'],
148+
num_heads=opt['num_heads'],).to(opt['device'])
149+
elif opt['target'] == 'time':
150+
model = TransformerAutoencoder(input_dim=opt['sequence_length'],
151+
output_dim=opt['timeseries_out'],
152+
output_dim_2=opt['n_channels'],
153+
target=TransformerAutoencoder.TARGET_TIMESERIES,
154+
hidden_dim=opt['hidden_dim'],
155+
num_layers=opt['num_layers'],
156+
num_heads=opt['num_heads'],).to(opt['device'])
157+
elif opt['target'] == 'full':
158+
model = TransformerDoubleAutoencoder(input_dim=opt['n_channels'],
159+
output_dim=opt['output_dim'],
160+
output_dim_2=opt['output_dim_2'],
161+
sequence_length=opt['sequence_length'],
162+
hidden_dim=opt['hidden_dim'],
163+
num_layers=opt['num_layers'],
164+
num_heads=opt['num_heads'],).to(opt['device'])
165+
else:
166+
raise ValueError(f"Encode target '{opt['target']}' not recognized, options are 'channels', 'time', or 'full'.")
167+
168+
# Populate model configuration
169+
history = {}
170+
for key in opt.keys():
171+
if (not key == 'history') | (not key == 'trained_epochs'):
172+
history[key] = [opt[key]]
173+
history['trained_epochs'] = []
174+
175+
if model_dict is not None:
176+
# update history
177+
for key in history.keys():
178+
history[key] = model_dict['configuration']['history'][key] + history[key]
179+
180+
opt['history'] = history
181+
182+
if opt['ddp']:
183+
trainer = AEDDPTrainer(model, opt)
184+
if default_args['load_checkpoint']:
185+
trainer.load_checkpoint(default_args['path_checkpoint'])
186+
mp.spawn(run, args=(opt['world_size'], find_free_port(), opt['ddp_backend'], trainer, opt),
187+
nprocs=opt['world_size'], join=True)
188+
else:
189+
trainer = AETrainer(model, opt)
190+
if default_args['load_checkpoint']:
191+
trainer.load_checkpoint(default_args['path_checkpoint'])
192+
samples = trainer.training(train_dataloader, test_dataloader)
193+
model = trainer.model
194+
print("Training finished.")
195+
196+
# ----------------------------------------------------------------------------------------------------------------------
197+
# Save autoencoder
198+
# ----------------------------------------------------------------------------------------------------------------------
199+
200+
# Save model
201+
# model_dict = dict(state_dict=model.state_dict(), config=model.config)
202+
if opt['save_name'] is None:
203+
fn = opt['path_dataset'].split('/')[-1].split('.csv')[0]
204+
opt['save_name'] = os.path.join("trained_ae", f"ae_{fn}_{str(time.time()).split('.')[0]}.pt")
205+
# save(model_dict, save_name)
206+
207+
trainer.save_checkpoint(opt['save_name'], update_history=True, samples=samples)
208+
print(f"Model and configuration saved in {opt['save_name']}")
209+
210+
if __name__ == "__main__":
211+
main()

0 commit comments

Comments
 (0)