From 9365434ef5c559cac89afe0afffefa7058e8d052 Mon Sep 17 00:00:00 2001 From: Filipe Tostevin Date: Thu, 5 Aug 2021 16:35:11 +0100 Subject: [PATCH 1/5] switch training data generator to Sequence, implement mini_epochs --- clair3/Train.py | 108 +++++++++++++++++++++++++----------------------- 1 file changed, 56 insertions(+), 52 deletions(-) diff --git a/clair3/Train.py b/clair3/Train.py index 07aa92e..a98ab8f 100644 --- a/clair3/Train.py +++ b/clair3/Train.py @@ -62,6 +62,46 @@ def call(self, y_true, y_pred): return reduce_fl +class DataSequence(tf.keras.utils.Sequence): + def __init__(self, data, chunk_list, param, tensor_shape, mini_epochs=1, add_indel_length=False, validation=False): + self.data = data + self.chunk_list = chunk_list + self.batch_size = param.trainBatchSize + self.chunk_size = param.chunk_size + self.chunks_per_batch = self.batch_size // self.chunk_size + self.label_shape_cum = list(accumulate(param.label_shape))[0:4 if add_indel_length else 2] + self.mini_epochs = mini_epochs + self.mini_epochs_count = -1 + self.validation = validation + self.position_matrix = np.empty([self.batch_size] + tensor_shape, np.int32) + self.label = np.empty((self.batch_size, param.label_size), np.float32) + self.random_offset = 0 + self.on_epoch_end() + + def __len__(self): + return int((len(self.chunk_list) // self.chunks_per_batch) // self.mini_epochs) + + def __getitem__(self, index): + chunk_batch_list = self.chunk_list[index * self.chunks_per_batch:(index + 1) * self.chunks_per_batch] + for chunk_idx, (bin_id, chunk_id) in enumerate(chunk_batch_list): + start_pos = self.random_offset + chunk_id * self.chunk_size + self.position_matrix[chunk_idx * self.chunk_size:(chunk_idx + 1) * self.chunk_size] = \ + self.data[bin_id].root.position_matrix[start_pos:start_pos + self.chunk_size] + self.label[chunk_idx * self.chunk_size:(chunk_idx + 1) * self.chunk_size] = \ + self.data[bin_id].root.label[start_pos:start_pos + self.chunk_size] + + return self.position_matrix, tuple( + np.split(self.label, self.label_shape_cum, axis=1)[:len(self.label_shape_cum)] + ) + + def on_epoch_end(self): + self.mini_epochs_count += 1 + if not self.validation and (self.mini_epochs_count % self.mini_epochs) == 0: + self.random_offset = np.random.randint(0, self.chunk_size) + np.random.shuffle(self.chunk_list) + self.mini_epochs_count = 0 + + def get_chunk_list(chunk_offset, train_chunk_num): """ get chunk list for training and validation data. we will randomly split training and validation dataset, @@ -100,7 +140,7 @@ def train_model(args): model = model_path.Clair3_F(add_indel_length=add_indel_length) tensor_shape = param.ont_input_shape if platform == 'ont' else param.input_shape - label_size, label_shape = param.label_size, param.label_shape + label_shape = param.label_shape label_shape_cum = list(accumulate(label_shape)) batch_size, chunk_size = param.trainBatchSize, param.chunk_size chunks_per_batch = batch_size // chunk_size @@ -109,9 +149,7 @@ def train_model(args): learning_rate = args.learning_rate if args.learning_rate else param.initialLearningRate max_epoch = args.maxEpoch if args.maxEpoch else param.maxEpoch task_num = 4 if add_indel_length else 2 - TensorShape = ( - tf.TensorShape([None] + tensor_shape), tuple(tf.TensorShape([None, label_shape[task]]) for task in range(task_num))) - TensorDtype = (tf.int32, tuple(tf.float32 for _ in range(task_num))) + mini_epochs = args.mini_epochs def populate_dataset_table(file_list, file_path): chunk_offset = np.zeros(len(file_list), dtype=int) @@ -155,50 +193,13 @@ def populate_dataset_table(file_list, file_path): train_chunk_num = len(train_shuffle_chunk_list) validate_chunk_num = len(validate_shuffle_chunk_list) - - def DataGenerator(x, shuffle_chunk_list, train_flag=True): - - """ - data generator for pileup or full alignment data processing, pytables with blosc:lz4hc are used for extreme fast - compression and decompression. random chunk shuffling and random start position to increase training model robustness. - - """ - - batch_num = len(shuffle_chunk_list) // chunks_per_batch - position_matrix = np.empty([batch_size] + tensor_shape, np.int32) - label = np.empty((batch_size, param.label_size), np.float32) - - random_start_position = np.random.randint(0, chunk_size) if train_flag else 0 - if train_flag: - np.random.shuffle(shuffle_chunk_list) - for batch_idx in range(batch_num): - for chunk_idx in range(chunks_per_batch): - bin_id, chunk_id = shuffle_chunk_list[batch_idx * chunks_per_batch + chunk_idx] - position_matrix[chunk_idx * chunk_size:(chunk_idx + 1) * chunk_size] = x[bin_id].root.position_matrix[ - random_start_position + chunk_id * chunk_size:random_start_position + (chunk_id + 1) * chunk_size] - label[chunk_idx * chunk_size:(chunk_idx + 1) * chunk_size] = x[bin_id].root.label[ - random_start_position + chunk_id * chunk_size:random_start_position + (chunk_id + 1) * chunk_size] - - if add_indel_length: - yield position_matrix, ( - label[:, :label_shape_cum[0]], - label[:, label_shape_cum[0]:label_shape_cum[1]], - label[:, label_shape_cum[1]:label_shape_cum[2]], - label[:, label_shape_cum[2]: ] - ) - else: - yield position_matrix, ( - label[:, :label_shape_cum[0]], - label[:, label_shape_cum[0]:label_shape_cum[1]] - ) - - - train_dataset = tf.data.Dataset.from_generator( - lambda: DataGenerator(table_dataset_list, train_shuffle_chunk_list, True), TensorDtype, - TensorShape).prefetch(buffer_size=tf.data.experimental.AUTOTUNE) - validate_dataset = tf.data.Dataset.from_generator( - lambda: DataGenerator(validate_table_dataset_list if validation_fn else table_dataset_list, validate_shuffle_chunk_list, False), TensorDtype, - TensorShape).prefetch(buffer_size=tf.data.experimental.AUTOTUNE) + train_seq = DataSequence(table_dataset_list, train_shuffle_chunk_list, param, tensor_shape, + mini_epochs=mini_epochs, add_indel_length=add_indel_length) + if add_validation_dataset: + val_seq = DataSequence(validate_table_dataset_list if validation_fn else table_dataset_list, validate_shuffle_chunk_list, param, tensor_shape, + mini_epochs=1, add_indel_length=add_indel_length, validation=True) + else: + val_seq = None total_steps = max_epoch * (train_chunk_num // chunks_per_batch) @@ -234,16 +235,16 @@ def DataGenerator(x, shuffle_chunk_list, train_flag=True): logging.info("[INFO] The training learning_rate: {}".format(learning_rate)) logging.info("[INFO] Total training steps: {}".format(total_steps)) logging.info("[INFO] Maximum training epoch: {}".format(max_epoch)) + logging.info("[INFO] Mini-epochs per epoch: {}".format(mini_epochs)) logging.info("[INFO] Start training...") - validate_dataset = validate_dataset if add_validation_dataset else None if args.chkpnt_fn is not None: model.load_weights(args.chkpnt_fn) logging.info("[INFO] Starting from model {}".format(args.chkpnt_fn)) - train_history = model.fit(x=train_dataset, - epochs=max_epoch, - validation_data=validate_dataset, + train_history = model.fit(x=train_seq, + epochs=max_epoch * mini_epochs, + validation_data=val_seq, callbacks=[early_stop_callback, model_save_callback, model_best_callback, @@ -293,6 +294,9 @@ def main(): parser.add_argument('--exclude_training_samples', type=str, default=None, help="Define training samples to be excluded") + parser.add_argument('--mini_epochs', type=int, default=1, + help="Number of mini-epochs per epoch") + # Internal process control ## In pileup training mode or not parser.add_argument('--pileup', action='store_true', From 6d13ed4fef44294b02921ff468051ee7ca12075c Mon Sep 17 00:00:00 2001 From: Filipe Tostevin Date: Thu, 5 Aug 2021 17:46:40 +0100 Subject: [PATCH 2/5] fix chunk indexing for miniepochs --- clair3/Train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/clair3/Train.py b/clair3/Train.py index a98ab8f..8f8fe1b 100644 --- a/clair3/Train.py +++ b/clair3/Train.py @@ -82,7 +82,8 @@ def __len__(self): return int((len(self.chunk_list) // self.chunks_per_batch) // self.mini_epochs) def __getitem__(self, index): - chunk_batch_list = self.chunk_list[index * self.chunks_per_batch:(index + 1) * self.chunks_per_batch] + mini_epoch_offset = self.mini_epochs_count * self.__len__() + chunk_batch_list = self.chunk_list[(mini_epoch_offset + index) * self.chunks_per_batch:(mini_epoch_offset + index + 1) * self.chunks_per_batch] for chunk_idx, (bin_id, chunk_id) in enumerate(chunk_batch_list): start_pos = self.random_offset + chunk_id * self.chunk_size self.position_matrix[chunk_idx * self.chunk_size:(chunk_idx + 1) * self.chunk_size] = \ From d8ffca96f7c40c25676b9707ba8df7bb828840a7 Mon Sep 17 00:00:00 2001 From: Filipe Tostevin Date: Fri, 6 Aug 2021 09:22:35 +0100 Subject: [PATCH 3/5] reset mini_epochs count for validation dataset so chunk indices don't go out of bounds --- clair3/Train.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/clair3/Train.py b/clair3/Train.py index 8f8fe1b..beaad2c 100644 --- a/clair3/Train.py +++ b/clair3/Train.py @@ -97,10 +97,11 @@ def __getitem__(self, index): def on_epoch_end(self): self.mini_epochs_count += 1 - if not self.validation and (self.mini_epochs_count % self.mini_epochs) == 0: - self.random_offset = np.random.randint(0, self.chunk_size) - np.random.shuffle(self.chunk_list) + if (self.mini_epochs_count % self.mini_epochs) == 0: self.mini_epochs_count = 0 + if not self.validation: + self.random_offset = np.random.randint(0, self.chunk_size) + np.random.shuffle(self.chunk_list) def get_chunk_list(chunk_offset, train_chunk_num): From dbae748f817bd9a03d7520793533948e1f5d6d9a Mon Sep 17 00:00:00 2001 From: Filipe Tostevin Date: Wed, 11 Aug 2021 13:58:20 +0100 Subject: [PATCH 4/5] label_shape_cum already exists in param file so reuse it --- clair3/Train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/clair3/Train.py b/clair3/Train.py index beaad2c..9614f18 100644 --- a/clair3/Train.py +++ b/clair3/Train.py @@ -69,7 +69,7 @@ def __init__(self, data, chunk_list, param, tensor_shape, mini_epochs=1, add_ind self.batch_size = param.trainBatchSize self.chunk_size = param.chunk_size self.chunks_per_batch = self.batch_size // self.chunk_size - self.label_shape_cum = list(accumulate(param.label_shape))[0:4 if add_indel_length else 2] + self.label_shape_cum = param.label_shape_cum[0:4 if add_indel_length else 2] self.mini_epochs = mini_epochs self.mini_epochs_count = -1 self.validation = validation @@ -143,7 +143,7 @@ def train_model(args): tensor_shape = param.ont_input_shape if platform == 'ont' else param.input_shape label_shape = param.label_shape - label_shape_cum = list(accumulate(label_shape)) + label_shape_cum = param.label_shape_cum batch_size, chunk_size = param.trainBatchSize, param.chunk_size chunks_per_batch = batch_size // chunk_size random.seed(param.RANDOM_SEED) From be48191e09cbf756df3594b71b4b11d4a98194f2 Mon Sep 17 00:00:00 2001 From: Filipe Tostevin Date: Thu, 12 Aug 2021 10:31:00 +0100 Subject: [PATCH 5/5] build PyramidPooling layer before calling fixes crash when tf tries to call with input tensor of shape [None, None, None, None] --- clair3/model.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/clair3/model.py b/clair3/model.py index ca7a93c..0419304 100644 --- a/clair3/model.py +++ b/clair3/model.py @@ -262,21 +262,27 @@ def __init__(self, spatial_pool_size=(3, 2, 1)): super(PyramidPolling, self).__init__() self.spatial_pool_size = spatial_pool_size + self.pool_len = len(self.spatial_pool_size) + self.window_h = np.empty(self.pool_len, dtype=int) + self.stride_h = np.empty(self.pool_len, dtype=int) + self.window_w = np.empty(self.pool_len, dtype=int) + self.stride_w = np.empty(self.pool_len, dtype=int) self.flatten = tf.keras.layers.Flatten() - def call(self, x): - - height = int(x.get_shape()[1]) - width = int(x.get_shape()[2]) - - for i in range(len(self.spatial_pool_size)): + def build(self, input_shape): + height = int(input_shape[1]) + width = int(input_shape[2]) - window_h = stride_h = int(np.ceil(height / self.spatial_pool_size[i])) + for i in range(self.pool_len): + self.window_h[i] = self.stride_h[i] = int(np.ceil(height / self.spatial_pool_size[i])) + self.window_w[i] = self.stride_w[i] = int(np.ceil(width / self.spatial_pool_size[i])) - window_w = stride_w = int(np.ceil(width / self.spatial_pool_size[i])) - - max_pool = tf.nn.max_pool(x, ksize=[1, window_h, window_w, 1], strides=[1, stride_h, stride_w, 1], + def call(self, x): + for i in range(self.pool_len): + max_pool = tf.nn.max_pool(x, + ksize=[1, self.window_h[i], self.window_w[i], 1], + strides=[1, self.stride_h[i], self.stride_w[i], 1], padding='SAME') if i == 0: pp = self.flatten(max_pool)