-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtransformer_age2.py
130 lines (92 loc) · 3.63 KB
/
transformer_age2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#!/usr/bin/env python
# coding: utf-8
import tensorflow as tf
import numpy as np
from pathlib import Path
from datetime import datetime
import os
# Directory holding the TFRecord shards; the train/validation splits are
# selected purely by filename prefix.
DATA_PATH = Path("/scratch/ajk4yq/ecg/tfrecords/")
TRAIN_RECS = [*DATA_PATH.glob("train*.tfrecords")]
VAL_RECS = [*DATA_PATH.glob("val*.tfrecords")]

# Number of examples per batch produced by get_dataset.
BATCH_SIZE = 8

# Feature spec handed to tf.io.parse_single_example in _parse_record.
# Note: 'gender' is parsed but _parse_record does not return it.
record_format = {
    'ecg/data': tf.io.FixedLenSequenceFeature(
        [], tf.float32, allow_missing=True
    ),
    'age': tf.io.FixedLenFeature([], tf.float32),
    'gender': tf.io.FixedLenFeature([], tf.int64),
}
def _parse_record(record):
    """Decode one serialized example into an ``(ecg_data, age)`` pair.

    Args:
        record: A serialized ``tf.train.Example`` (scalar string tensor).

    Returns:
        A tuple ``(ecg_data, label)``. ``ecg_data`` is the float32
        ``'ecg/data'`` feature reshaped to ``[5000, 12]`` (presumably
        time steps x leads -- TODO confirm). ``label`` is the scalar float32
        ``'age'`` feature.
    """
    parsed = tf.io.parse_single_example(record, record_format)
    signal = tf.reshape(parsed['ecg/data'], [5000, 12])
    return signal, parsed['age']
def drop_na_ages(x, y):
    """Dataset filter predicate: keep examples whose age label has no NaN.

    Args:
        x: The ECG tensor. It is unused; the parameter exists because
            ``Dataset.filter`` passes each tuple component as its own argument.
        y: The age label tensor.

    Returns:
        A boolean tensor that is True when ``y`` contains no NaN values.
    """
    # Use tf.math.logical_not instead of Python `not`. `not` on a symbolic
    # tensor only works when AutoGraph rewrites it; in plain graph mode it
    # raises "using a tf.Tensor as a Python bool is not allowed".
    return tf.math.logical_not(tf.math.reduce_any(tf.math.is_nan(y)))
def age_lt_90(x, y):
    """Dataset filter predicate: keep examples whose age is at most 90.

    Despite the name, the comparison is inclusive (``age <= 90``). NaN ages
    also fail the comparison and are therefore dropped.

    Args:
        x: The ECG tensor (unused; supplied by ``Dataset.filter``).
        y: The age label tensor.

    Returns:
        A scalar boolean tensor, True when every value in ``y`` is <= 90.
    """
    within_limit = y <= 90.0
    return tf.math.reduce_all(within_limit)
def load_dataset(filenames):
    """Read TFRecord files into an unbatched dataset of ``(ecg_data, age)`` pairs.

    Each record is decoded with ``_parse_record``. Examples are then filtered
    through ``drop_na_ages`` and ``age_lt_90``.

    Args:
        filenames: TFRecord file name(s), passed to ``tf.data.TFRecordDataset``.

    Returns:
        A ``tf.data.Dataset`` of ``(ecg_data, age)`` tuples.
    """
    opts = tf.data.Options()
    # Allow parallel map to emit elements out of order for throughput;
    # element order may therefore vary between runs.
    opts.experimental_deterministic = False

    return (
        tf.data.TFRecordDataset(filenames)
        .with_options(opts)
        .map(_parse_record, num_parallel_calls=tf.data.AUTOTUNE)
        .filter(drop_na_ages)
        .filter(age_lt_90)
    )
def get_dataset(filenames, labeled=True, batch_size=None, shuffle_buffer=2048):
    """Return a shuffled, batched, prefetched dataset ready for ``model.fit``.

    Args:
        filenames: TFRecord file name(s), forwarded to ``load_dataset``.
        labeled: Unused; accepted for compatibility with existing callers.
        batch_size: Examples per batch. ``None`` (the default) uses the
            module-level ``BATCH_SIZE``, looked up at call time.
        shuffle_buffer: Number of elements held in the shuffle buffer.

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``(ecg_data, age)``.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    dataset = load_dataset(filenames)
    dataset = dataset.shuffle(shuffle_buffer)
    dataset = dataset.batch(batch_size)
    # Prefetch last, so whole batches are prepared while the model trains.
    # The original prefetched single unbatched examples before batching.
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset
# In[ ]:
# Build the training and validation input pipelines (evaluated left to right).
train_dataset, val_dataset = get_dataset(TRAIN_RECS), get_dataset(VAL_RECS)
# In[ ]:
# Encoder block adapted from the Keras time-series Transformer classification example:
# https://keras.io/examples/timeseries/timeseries_transformer_classification/
def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0):
    """Apply one pre-norm Transformer encoder block to a sequence tensor.

    The block has two residual sub-blocks:
    1. LayerNorm -> self-attention -> Dropout, added back to ``inputs``.
    2. LayerNorm -> Conv1D(ff_dim, relu) -> Dropout -> Conv1D(channels),
       added back to the output of sub-block 1.

    Args:
        inputs: Keras tensor of shape ``(batch, steps, channels)``.
        head_size: ``key_dim`` (per-head projection width) for attention.
        num_heads: Number of attention heads.
        ff_dim: Filters in the first pointwise (kernel_size=1) convolution.
        dropout: Dropout rate used in attention and after each sub-block.

    Returns:
        A Keras tensor with the same shape as ``inputs``.
    """
    layers = tf.keras.layers

    # Self-attention sub-block.
    normed = layers.LayerNormalization(epsilon=1e-6)(inputs)
    attended = layers.MultiHeadAttention(
        key_dim=head_size, num_heads=num_heads, dropout=dropout
    )(normed, normed)
    attended = layers.Dropout(dropout)(attended)
    residual = attended + inputs

    # Position-wise feed-forward sub-block; the last conv maps back to the
    # input channel count so the residual addition lines up.
    hidden = layers.LayerNormalization(epsilon=1e-6)(residual)
    hidden = layers.Conv1D(filters=ff_dim, kernel_size=1, activation='relu')(hidden)
    hidden = layers.Dropout(dropout)(hidden)
    channels = inputs.shape[-1]
    hidden = layers.Conv1D(filters=channels, kernel_size=1)(hidden)
    return hidden + residual
input_layer = tf.keras.layers.Input(shape=(5000, 12))

# NOTE(review): head_size is passed to MultiHeadAttention as key_dim, the
# per-head projection width, not the sequence length. With 5000 here the
# query/key projections are (batch, 5000, num_heads, 5000), which is very
# memory-hungry. Confirm this value is intended.
x = transformer_encoder(
    input_layer, head_size=5000, num_heads=2, ff_dim=4 * 12, dropout=0
)
x = tf.keras.layers.MaxPooling1D()(x)
x = tf.keras.layers.Flatten()(x)

# Two fully connected stages: Dense -> BatchNorm -> ReLU -> Dropout(0.5).
for units in (128, 64):
    x = tf.keras.layers.Dense(units)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.Dropout(0.5)(x)

# Single linear output: regressed age.
x = tf.keras.layers.Dense(1)(x)

model = tf.keras.models.Model(inputs=input_layer, outputs=x)
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.MeanSquaredError(),
    metrics=['mse', 'mae'],
)
# In[ ]:
# Print the layer-by-layer architecture and parameter counts of the compiled model.
model.summary()
# In[ ]:
def make_checkpoint_dir(data_path, label):
    """Create and return a timestamped directory for model checkpoints.

    The directory is ``{data_path}/{label}-YYYY-mm-dd_HH-MM-SS`` (local time).
    Missing parent directories are created too. An existing directory with
    the same name is reused rather than treated as an error.

    Args:
        data_path: Parent directory (str or path-like) to create it under.
        label: Prefix for the directory name, e.g. a run/model identifier.

    Returns:
        The directory path as a string.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    output_path = f"{data_path}/{label}-{timestamp}"
    # exist_ok=True replaces the racy "check os.path.exists, then makedirs"
    # pattern. Two runs started in the same second could otherwise both pass
    # the check, and the second makedirs would raise FileExistsError.
    os.makedirs(output_path, exist_ok=True)
    return output_path
# Training callbacks:
# - TerminateOnNaN: abort as soon as the loss becomes NaN.
# - ReduceLROnPlateau: lower the learning rate when the monitored metric
#   stalls (Keras defaults).
# - ModelCheckpoint: save the model into a fresh timestamped directory.
checkpoint_dir = make_checkpoint_dir("data/models", "transformer-age")
callbacks = [
    tf.keras.callbacks.TerminateOnNaN(),
    tf.keras.callbacks.ReduceLROnPlateau(),
    tf.keras.callbacks.ModelCheckpoint(checkpoint_dir),
]

model.fit(
    train_dataset,
    epochs=10,
    validation_data=val_dataset,
    callbacks=callbacks,
)