Commit

import from private proj

k committed Jul 21, 2019
1 parent babe469 commit 707c941
Showing 36 changed files with 45,545 additions and 0 deletions.
Binary file added .DS_Store
Binary file not shown.
3 changes: 3 additions & 0 deletions README.md
@@ -1,3 +1,6 @@
# mortality_prediction
Members: Liyan Fang, Shimeng Jiang, Xujian Liang, Zhiyong Deng

## Data preprocessing

> Since the raw datasets from the MIMIC-III database are too large to ship with the repository, we provide test datasets in `data/test_data/`: `data1.zip` and `data2.zip` are sub-datasets for testing the data-preprocessing code.
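
A minimal sketch for unpacking the test archives before running the preprocessing code (the `extracted/` output directory is an assumption, not part of the repo):

```python
import zipfile
from pathlib import Path

out_dir = Path("data/test_data/extracted")
out_dir.mkdir(parents=True, exist_ok=True)

# Extract both test archives into a common working directory.
for archive in ("data/test_data/data1.zip", "data/test_data/data2.zip"):
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(out_dir)
```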
Binary file added data/.DS_Store
Binary file not shown.
Binary file added data/preprocessed_data.zip
Binary file not shown.
Binary file added data/test_data/data1.zip
Binary file not shown.
Binary file added data/test_data/data2.zip
Binary file not shown.
Binary file added src/.DS_Store
Binary file not shown.
Binary file added src/__pycache__/ml_utils.cpython-37.pyc
Binary file not shown.
87 changes: 87 additions & 0 deletions src/base_LSTM_model.py
@@ -0,0 +1,87 @@
from __future__ import print_function
from __future__ import absolute_import

from keras.models import Model
from keras.layers import Input, Dense, LSTM, Masking, Dropout
from keras.layers.wrappers import Bidirectional, TimeDistributed

from lstm_utils import ExtendMask, LastTimestep
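# Assumed semantics of the custom layers imported above (defined in
# lstm_utils, not shown in this commit): LastTimestep selects the prediction
# at the final timestep of a sequence output, and ExtendMask replaces a
# tensor's Keras mask with an explicitly supplied one.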


class base_LSTM_m(Model):

def __init__(self, dim, batch_norm, dropout, rec_dropout, deep_supervision=False, num_classes=1, target_repl=False,
depth=1, input_dim=76, **kwargs):

print("==> not used params in network class:", kwargs.keys())

self.dim = dim
self.batch_norm = batch_norm
self.dropout = dropout
self.rec_dropout = rec_dropout
self.depth = depth

activation = 'sigmoid'

# Input layers and masking
X = Input(shape=(None, input_dim), name='X')
inputs = [X]
mX = Masking()(X)
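        # Masking() marks all-zero timesteps (Keras' default mask_value=0.0)
        # as padding so the LSTM layers below skip them.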

if deep_supervision:
M = Input(shape=(None,), name='M')
inputs.append(M)

# Configurations
is_bidirectional = True
if deep_supervision:
is_bidirectional = False
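        # With deep supervision the model predicts at every timestep, so the
        # recurrence stays unidirectional, presumably to keep future
        # information out of the per-timestep targets.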

# Main part of the network
for i in range(depth - 1):
num_units = dim
if is_bidirectional:
num_units = num_units // 2
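                # Bidirectional concatenates forward and backward states, so
                # halving the units keeps the layer output size at roughly dim.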

lstm = LSTM(units=num_units,
activation='tanh',
return_sequences=True,
recurrent_dropout=rec_dropout,
dropout=dropout)

if is_bidirectional:
mX = Bidirectional(lstm)(mX)
else:
mX = lstm(mX)

# Output module of the network
return_sequences = (target_repl or deep_supervision)
L = LSTM(units=dim,
activation='tanh',
return_sequences=return_sequences,
dropout=dropout,
recurrent_dropout=rec_dropout)(mX)

if dropout > 0:
L = Dropout(dropout)(L)
if target_repl:
y = TimeDistributed(Dense(num_classes, activation=activation), name='seq')(L)
y_last = LastTimestep(name='single')(y)
outputs = [y_last, y]
elif deep_supervision:
y = TimeDistributed(Dense(num_classes, activation=activation))(L)
y = ExtendMask()([y, M]) # this way we extend mask of y to M
outputs = [y]
else:
y = Dense(num_classes, activation=activation)(L)
outputs = [y]

super(base_LSTM_m, self).__init__(inputs=inputs, outputs=outputs)

def say_name(self):
return "{}.n{}{}{}{}.dep{}".format('k_lstm',
self.dim,
".bn" if self.batch_norm else "",
".d{}".format(self.dropout) if self.dropout > 0 else "",
".rd{}".format(self.rec_dropout) if self.rec_dropout > 0 else "",
self.depth)
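
# Usage sketch (illustrative only; the hyperparameter values here are
# assumptions, and the real entry point is base_lstm.py):
#
#     model = base_LSTM_m(dim=16, batch_norm=False, dropout=0.3,
#                         rec_dropout=0.0, depth=2, input_dim=76)
#     model.compile(optimizer='adam', loss='binary_crossentropy')
#     model.summary()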
299 changes: 299 additions & 0 deletions src/base_lstm.py
@@ -0,0 +1,299 @@
from __future__ import print_function
from __future__ import absolute_import


import numpy as np
import os
import re
import argparse

from base_LSTM_model import base_LSTM_m
from utils import Reader, load_data, print_metrics_binary, save_results
import utils
from keras.callbacks import ModelCheckpoint, CSVLogger
from lstm_utils import InHospitalMortalityMetrics, Discretizer, Normalizer
import matplotlib.pyplot as plt
from sklearn.utils.fixes import signature
from sklearn import metrics
import keras_metrics

parser = argparse.ArgumentParser()

parser.add_argument('--dim', type=int, default=16,
help='number of hidden units')
parser.add_argument('--depth', type=int, default=2,
help='number of bi-LSTMs')
parser.add_argument('--epochs', type=int, default=100,
help='number of chunks to train')
parser.add_argument('--load_state', type=str, default="",
help='state file path')
parser.add_argument('--mode', type=str, default="train",
help='mode: train or test')
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--l2', type=float, default=0, help='L2 regularization')
parser.add_argument('--l1', type=float, default=0, help='L1 regularization')
parser.add_argument('--save_every', type=int, default=1,
                    help='save state every x epochs')
parser.add_argument('--prefix', type=str, default="",
help='optional prefix of network name')
parser.add_argument('--dropout', type=float, default=0.3)
parser.add_argument('--rec_dropout', type=float, default=0.0,
help="dropout rate for recurrent connections")
# argparse's type=bool treats any non-empty string (even "False") as True,
# so boolean flags use store_true instead.
parser.add_argument('--batch_norm', action='store_true',
                    help='enable batch normalization')
parser.add_argument('--timestep', type=float, default=1.0,
help="fixed timestep used in the dataset")
parser.add_argument('--imputation', type=str, default='previous')
parser.add_argument('--small_part', dest='small_part', action='store_true')
parser.add_argument('--whole_data', dest='small_part', action='store_false')
parser.add_argument('--optimizer', type=str, default='adam')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--beta_1', type=float, default=0.9,
help='beta_1 param for Adam optimizer')
parser.add_argument('--verbose', type=int, default=2)
parser.add_argument('--size_coef', type=float, default=4.0)
parser.add_argument('--normalizer_state', type=str, default=None,
                    help='Path to a state file of a normalizer. Leave unset to '
                         'use one of the provided ones.')
parser.add_argument('--deep_supervision', action='store_true',
                    help='enable deep supervision for the model')
parser.set_defaults(small_part=False)

parser.add_argument('--target_repl_coef', type=float, default=0.0)
parser.add_argument('--data', type=str, help='Path to the data of in-hospital mortality task',
default=os.path.join(os.path.dirname(__file__), '../data/preprocessed_data/'))
parser.add_argument('--output_dir', type=str, help='Directory relative to which all output files are stored',
                    default=os.path.join(os.path.dirname(__file__), '../output/'))
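
# Example invocation (illustrative; paths and values are assumptions):
#   python base_lstm.py --mode train --dim 16 --depth 2 --batch_size 8 \
#       --data ../data/preprocessed_data/ --output_dir ../output/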


args = parser.parse_args()
print(args)

drop_out = args.dropout
depth = args.depth
batch_norm = args.batch_norm
mode = args.mode
dim = args.dim
target_repl_coef = args.target_repl_coef
data = args.data
rec_dropout = args.rec_dropout
timestep = args.timestep
normalizer_state = args.normalizer_state
imputation = args.imputation
l1 = args.l1
l2 = args.l2
batch_size = args.batch_size
optimizer = args.optimizer
lr = args.lr
beta_1 = args.beta_1
load_state = args.load_state
verbose = args.verbose
epochs = args.epochs
output_dir = args.output_dir
small_part = args.small_part
save_every = args.save_every
deep_supervision = args.deep_supervision

target_repl = (target_repl_coef > 0.0 and mode == 'train')

if small_part:
save_every = 2**30
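    # (2**30 epochs between checkpoints effectively disables saving while
    # debugging on the small subset.)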


# Build readers, discretizers, normalizers
train_reader = Reader(dataset_dir=os.path.join(data, 'train'),
listfile=os.path.join(data, 'train_listfile.csv'),
period_length=48.0)

val_reader = Reader(dataset_dir=os.path.join(data, 'train'),
listfile=os.path.join(data, 'val_listfile.csv'),
period_length=48.0)

discretizer = Discretizer(timestep=float(timestep),
store_masks=True,
impute_strategy='previous',
start_time='zero')

discretizer_header = discretizer.transform(train_reader.read_example(0)["X"])[1].split(',')
cont_channels = [i for (i, x) in enumerate(discretizer_header) if x.find("->") == -1]
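# Assumed header convention: the Discretizer marks one-hot categorical
# channels with "->" in their names, so the channels without it are the
# continuous ones handed to the Normalizer below.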

normalizer = Normalizer(fields=cont_channels)  # choose here which columns to standardize
if normalizer_state is None:
    normalizer_state = './resources/ihm_ts{}.input_str:{}.start_time:zero.normalizer'.format(
        timestep, imputation)
    normalizer_state = os.path.join(os.path.dirname(__file__), normalizer_state)
normalizer.load_params(normalizer_state)

# Build the model
print("==> using model base_Lstm")
model = base_LSTM_m(dim=dim, batch_norm=batch_norm, dropout=drop_out, rec_dropout=rec_dropout,deep_supervision=deep_supervision,num_classes=1,depth=depth,target_repl=target_repl)
suffix = ".bs{}{}{}.ts{}{}".format(batch_size,
".L1{}".format(l1) if l1 > 0 else "",
".L2{}".format(l2) if l2 > 0 else "",
timestep,
".trc{}".format(target_repl_coef) if target_repl_coef > 0 else "")
model.final_name = model.say_name() + suffix
print("==> model.final_name:", model.final_name)


# Compile the model
print("==> compiling the model")
optimizer_config = {'class_name': optimizer,
'config': {'lr': lr,
'beta_1': beta_1}}

# NOTE: one can use binary_crossentropy even for (B, T, C) shaped targets.
# It will calculate the binary crossentropy for each class
# and then take the mean over axis=-1. The result is (B, T).
if target_repl:
loss = ['binary_crossentropy'] * 2
loss_weights = [1 - args.target_repl_coef, args.target_repl_coef]
else:
loss = 'binary_crossentropy'
loss_weights = None

model.compile(optimizer=optimizer_config,
loss=loss, loss_weights=loss_weights)
model.summary()

# Load model weights
n_trained_chunks = 0
if load_state != "":
model.load_weights(load_state)
n_trained_chunks = int(re.match(".*epoch([0-9]+).*", load_state).group(1))
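    # The epoch index parsed from the checkpoint filename lets training resume
    # at the right initial_epoch (see the 'epoch{epoch}' pattern below).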


# Read data
train_raw = utils.load_data(train_reader, discretizer, normalizer, small_part)
val_raw = utils.load_data(val_reader, discretizer, normalizer, small_part)

if target_repl:
T = train_raw[0][0].shape[0]
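    # Target replication (assumed intent): duplicate the per-stay label across
    # all T timesteps so the auxiliary 'seq' output can be supervised at every
    # step, in addition to the final-timestep 'single' output.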

def extend_labels(data):
data = list(data)
labels = np.array(data[1]) # (B,)
data[1] = [labels, None]
data[1][1] = np.expand_dims(labels, axis=-1).repeat(T, axis=1) # (B, T)
data[1][1] = np.expand_dims(data[1][1], axis=-1) # (B, T, 1)
return data

train_raw = extend_labels(train_raw)
val_raw = extend_labels(val_raw)

if mode == 'train':

# Prepare training
path = os.path.join(output_dir, 'keras_states/' + model.final_name + '.epoch{epoch}.test{val_loss}.state')

metrics_callback = InHospitalMortalityMetrics(train_data=train_raw,
val_data=val_raw,
target_repl=(target_repl_coef > 0),
batch_size=batch_size,
verbose=verbose)
# make sure save directory exists
dirname = os.path.dirname(path)
if not os.path.exists(dirname):
os.makedirs(dirname)
saver = ModelCheckpoint(path, verbose=1, period=save_every)

keras_logs = os.path.join(output_dir, 'keras_logs')
if not os.path.exists(keras_logs):
os.makedirs(keras_logs)
csv_logger = CSVLogger(os.path.join(keras_logs, model.final_name + '.csv'),
append=True, separator=';')

print("==> training")
history = model.fit(x=train_raw[0],
y=train_raw[1],
validation_data=val_raw,
epochs=n_trained_chunks + epochs,
initial_epoch=n_trained_chunks,
callbacks=[metrics_callback, saver, csv_logger],
shuffle=True,
verbose=verbose,
batch_size=batch_size)

    precision, recall = metrics_callback.getPrecisionRecall()

elif mode == 'test':

# ensure that the code uses test_reader
del train_reader
del val_reader
del train_raw
del val_raw

test_reader = utils.Reader(dataset_dir=os.path.join(data, 'test'),
listfile=os.path.join(data, 'test_listfile.csv'),
period_length=48.0)
ret = utils.load_data(test_reader, discretizer, normalizer, small_part,
return_names=True)

data = ret["data"][0]
labels = ret["data"][1]
    names = ret["names"]

    predictions = model.predict(data, batch_size=batch_size, verbose=1)
predictions = np.array(predictions)[:, 0]

ret = utils.print_metrics_binary(labels, predictions)

precision, recall, _ = metrics.precision_recall_curve(labels, predictions)

auprc = ret['auprc']
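    # AUPRC is the headline metric here: with a rare positive class
    # (in-hospital mortality) it is more informative than raw accuracy.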

# In matplotlib < 1.5, plt.fill_between does not have a 'step' argument

step_kwargs = ({'step': 'post'}
if 'step' in signature(plt.fill_between).parameters
else {})
plt.step(recall, precision, color='b', alpha=0.2,
where='post')
plt.fill_between(recall, precision, alpha=0.2, color='b', **step_kwargs)
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.ylim([0.0, 1.05])
plt.xlim([0.0, 1.0])
plt.title('LSTM Precision-Recall curve: auprc={0:0.3f}'.format(auprc))
    plt.show()

    path = os.path.join(output_dir, "test_predictions", os.path.basename(load_state)) + ".csv"
utils.save_results(names, predictions, labels, path)
if target_repl_coef > 0.0:
model.save('./output/base_lstm_DS.h5')
else:
model.save('./output/base_lstm_.h5')

else:
raise ValueError("Wrong value for args.mode")