Skip to content

Commit

Permalink
Merge pull request #60 from nanoporetech/miniepochs
Browse files Browse the repository at this point in the history
implement mini-epochs in training
  • Loading branch information
zhengzhenxian authored Oct 17, 2021
2 parents 70ee1e7 + be48191 commit 31cdf49
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 63 deletions.
112 changes: 59 additions & 53 deletions clair3/Train.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,48 @@ def call(self, y_true, y_pred):
return reduce_fl


class DataSequence(tf.keras.utils.Sequence):
    """Keras Sequence that serves training batches out of pytables bins.

    Supports "mini-epochs": each epoch seen by Keras covers only
    1/mini_epochs of the full chunk list, while shuffling and the random
    read offset are refreshed just once every `mini_epochs` epochs
    (and never for validation data).

    NOTE(review): `__getitem__` returns the same preallocated buffers on
    every call (overwritten in place) — callers must consume a batch
    before requesting the next one.
    """

    def __init__(self, data, chunk_list, param, tensor_shape, mini_epochs=1, add_indel_length=False, validation=False):
        self.data = data
        self.chunk_list = chunk_list
        self.batch_size = param.trainBatchSize
        self.chunk_size = param.chunk_size
        self.chunks_per_batch = self.batch_size // self.chunk_size
        # 4 output heads when indel length is predicted, otherwise 2.
        self.label_shape_cum = param.label_shape_cum[:4 if add_indel_length else 2]
        self.mini_epochs = mini_epochs
        # Starts at -1 so the on_epoch_end() below brings it to 0.
        self.mini_epochs_count = -1
        self.validation = validation
        # Reusable output buffers, filled in place by __getitem__.
        self.position_matrix = np.empty([self.batch_size] + tensor_shape, np.int32)
        self.label = np.empty((self.batch_size, param.label_size), np.float32)
        self.random_offset = 0
        self.on_epoch_end()

    def __len__(self):
        # Number of batches exposed per (mini-)epoch.
        batches_total = len(self.chunk_list) // self.chunks_per_batch
        return int(batches_total // self.mini_epochs)

    def __getitem__(self, index):
        # Translate the index within the current mini-epoch into an
        # absolute position in the (possibly shuffled) chunk list.
        base = (self.mini_epochs_count * len(self) + index) * self.chunks_per_batch
        for slot, (bin_id, chunk_id) in enumerate(self.chunk_list[base:base + self.chunks_per_batch]):
            source = self.data[bin_id].root
            begin = self.random_offset + chunk_id * self.chunk_size
            end = begin + self.chunk_size
            dest = slice(slot * self.chunk_size, (slot + 1) * self.chunk_size)
            self.position_matrix[dest] = source.position_matrix[begin:end]
            self.label[dest] = source.label[begin:end]

        # Split the flat label matrix into one sub-array per output head;
        # np.split yields len+1 pieces, so drop the trailing remainder.
        pieces = np.split(self.label, self.label_shape_cum, axis=1)
        return self.position_matrix, tuple(pieces[:len(self.label_shape_cum)])

    def on_epoch_end(self):
        self.mini_epochs_count += 1
        if self.mini_epochs_count % self.mini_epochs == 0:
            # A full pass over the chunk list completed: restart the
            # mini-epoch cycle and (for training only) reshuffle chunks
            # and draw a fresh random read offset.
            self.mini_epochs_count = 0
            if not self.validation:
                self.random_offset = np.random.randint(0, self.chunk_size)
                np.random.shuffle(self.chunk_list)


def get_chunk_list(chunk_offset, train_chunk_num):
"""
get chunk list for training and validation data. we will randomly split training and validation dataset,
Expand Down Expand Up @@ -100,18 +142,16 @@ def train_model(args):
model = model_path.Clair3_F(add_indel_length=add_indel_length)

tensor_shape = param.ont_input_shape if platform == 'ont' else param.input_shape
label_size, label_shape = param.label_size, param.label_shape
label_shape_cum = list(accumulate(label_shape))
label_shape = param.label_shape
label_shape_cum = param.label_shape_cum
batch_size, chunk_size = param.trainBatchSize, param.chunk_size
chunks_per_batch = batch_size // chunk_size
random.seed(param.RANDOM_SEED)
np.random.seed(param.RANDOM_SEED)
learning_rate = args.learning_rate if args.learning_rate else param.initialLearningRate
max_epoch = args.maxEpoch if args.maxEpoch else param.maxEpoch
task_num = 4 if add_indel_length else 2
TensorShape = (
tf.TensorShape([None] + tensor_shape), tuple(tf.TensorShape([None, label_shape[task]]) for task in range(task_num)))
TensorDtype = (tf.int32, tuple(tf.float32 for _ in range(task_num)))
mini_epochs = args.mini_epochs

def populate_dataset_table(file_list, file_path):
chunk_offset = np.zeros(len(file_list), dtype=int)
Expand Down Expand Up @@ -155,50 +195,13 @@ def populate_dataset_table(file_list, file_path):
train_chunk_num = len(train_shuffle_chunk_list)
validate_chunk_num = len(validate_shuffle_chunk_list)


def DataGenerator(x, shuffle_chunk_list, train_flag=True):
    """Yield (position_matrix, label_tuple) batches from pytables handles.

    Pytables bins (blosc:lz4hc compressed) are read chunk by chunk for fast
    decompression. In training mode the chunk order is shuffled and a random
    start offset is applied to every read, so successive epochs do not see
    identical windows; validation reads are deterministic.
    """
    batch_count = len(shuffle_chunk_list) // chunks_per_batch
    # Reusable buffers, overwritten in place for every yielded batch.
    position_matrix = np.empty([batch_size] + tensor_shape, np.int32)
    label = np.empty((batch_size, param.label_size), np.float32)

    offset = np.random.randint(0, chunk_size) if train_flag else 0
    if train_flag:
        np.random.shuffle(shuffle_chunk_list)

    for batch_idx in range(batch_count):
        batch_chunks = shuffle_chunk_list[batch_idx * chunks_per_batch:(batch_idx + 1) * chunks_per_batch]
        for slot, (bin_id, chunk_id) in enumerate(batch_chunks):
            begin = offset + chunk_id * chunk_size
            end = offset + (chunk_id + 1) * chunk_size
            dest = slice(slot * chunk_size, (slot + 1) * chunk_size)
            position_matrix[dest] = x[bin_id].root.position_matrix[begin:end]
            label[dest] = x[bin_id].root.label[begin:end]

        # Slice the flat label matrix into one view per output head.
        heads = [label[:, :label_shape_cum[0]],
                 label[:, label_shape_cum[0]:label_shape_cum[1]]]
        if add_indel_length:
            heads.append(label[:, label_shape_cum[1]:label_shape_cum[2]])
            heads.append(label[:, label_shape_cum[2]:])
        yield position_matrix, tuple(heads)


train_dataset = tf.data.Dataset.from_generator(
lambda: DataGenerator(table_dataset_list, train_shuffle_chunk_list, True), TensorDtype,
TensorShape).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
validate_dataset = tf.data.Dataset.from_generator(
lambda: DataGenerator(validate_table_dataset_list if validation_fn else table_dataset_list, validate_shuffle_chunk_list, False), TensorDtype,
TensorShape).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
train_seq = DataSequence(table_dataset_list, train_shuffle_chunk_list, param, tensor_shape,
mini_epochs=mini_epochs, add_indel_length=add_indel_length)
if add_validation_dataset:
val_seq = DataSequence(validate_table_dataset_list if validation_fn else table_dataset_list, validate_shuffle_chunk_list, param, tensor_shape,
mini_epochs=1, add_indel_length=add_indel_length, validation=True)
else:
val_seq = None

total_steps = max_epoch * (train_chunk_num // chunks_per_batch)

Expand Down Expand Up @@ -234,16 +237,16 @@ def DataGenerator(x, shuffle_chunk_list, train_flag=True):
logging.info("[INFO] The training learning_rate: {}".format(learning_rate))
logging.info("[INFO] Total training steps: {}".format(total_steps))
logging.info("[INFO] Maximum training epoch: {}".format(max_epoch))
logging.info("[INFO] Mini-epochs per epoch: {}".format(mini_epochs))
logging.info("[INFO] Start training...")

validate_dataset = validate_dataset if add_validation_dataset else None
if args.chkpnt_fn is not None:
model.load_weights(args.chkpnt_fn)
logging.info("[INFO] Starting from model {}".format(args.chkpnt_fn))

train_history = model.fit(x=train_dataset,
epochs=max_epoch,
validation_data=validate_dataset,
train_history = model.fit(x=train_seq,
epochs=max_epoch * mini_epochs,
validation_data=val_seq,
callbacks=[early_stop_callback,
model_save_callback,
model_best_callback,
Expand Down Expand Up @@ -293,6 +296,9 @@ def main():
parser.add_argument('--exclude_training_samples', type=str, default=None,
help="Define training samples to be excluded")

parser.add_argument('--mini_epochs', type=int, default=1,
help="Number of mini-epochs per epoch")

# Internal process control
## In pileup training mode or not
parser.add_argument('--pileup', action='store_true',
Expand Down
26 changes: 16 additions & 10 deletions clair3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,21 +262,27 @@ def __init__(self, spatial_pool_size=(3, 2, 1)):
super(PyramidPolling, self).__init__()

self.spatial_pool_size = spatial_pool_size
self.pool_len = len(self.spatial_pool_size)
self.window_h = np.empty(self.pool_len, dtype=int)
self.stride_h = np.empty(self.pool_len, dtype=int)
self.window_w = np.empty(self.pool_len, dtype=int)
self.stride_w = np.empty(self.pool_len, dtype=int)

self.flatten = tf.keras.layers.Flatten()

def call(self, x):

height = int(x.get_shape()[1])
width = int(x.get_shape()[2])

for i in range(len(self.spatial_pool_size)):
def build(self, input_shape):
    """Precompute per-level pooling windows/strides from the static input size.

    Moving this out of call() means the ceil-division is done once at build
    time rather than on every forward pass.
    """
    # assumes input_shape is (batch, height, width, channels) — TODO confirm
    feat_h = int(input_shape[1])
    feat_w = int(input_shape[2])
    for i, bins in enumerate(self.spatial_pool_size):
        self.window_h[i] = self.stride_h[i] = int(np.ceil(feat_h / bins))
        self.window_w[i] = self.stride_w[i] = int(np.ceil(feat_w / bins))

window_w = stride_w = int(np.ceil(width / self.spatial_pool_size[i]))

max_pool = tf.nn.max_pool(x, ksize=[1, window_h, window_w, 1], strides=[1, stride_h, stride_w, 1],
def call(self, x):
for i in range(self.pool_len):
max_pool = tf.nn.max_pool(x,
ksize=[1, self.window_h[i], self.window_w[i], 1],
strides=[1, self.stride_h[i], self.stride_w[i], 1],
padding='SAME')
if i == 0:
pp = self.flatten(max_pool)
Expand Down

0 comments on commit 31cdf49

Please sign in to comment.