Commit
Merge pull request #61 from HKU-BAL/v0.1-r7
V0.1 r7
zhengzhenxian authored Oct 19, 2021
2 parents 31cdf49 + 687305d commit 6cd8994
Showing 20 changed files with 837 additions and 196 deletions.
6 changes: 3 additions & 3 deletions Dockerfile
@@ -19,9 +19,6 @@ WORKDIR /opt/bin
RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
bash Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda && \
rm Miniconda3-latest-Linux-x86_64.sh && \
wget http://www.bio8.cs.hku.hk/clair3/clair3_models/clair3_models.tar.gz -P /opt/models && \
tar -zxvf /opt/models/clair3_models.tar.gz -C /opt/models && \
rm /opt/models/clair3_models.tar.gz && \
conda config --add channels defaults && \
conda config --add channels bioconda && \
conda config --add channels conda-forge && \
@@ -48,4 +45,7 @@ COPY . .
RUN cd /opt/bin/preprocess/realign && \
g++ -std=c++14 -O1 -shared -fPIC -o realigner ssw_cpp.cpp ssw.c realigner.cpp && \
g++ -std=c++11 -shared -fPIC -o debruijn_graph -O3 debruijn_graph.cpp && \
wget http://www.bio8.cs.hku.hk/clair3/clair3_models/clair3_models.tar.gz -P /opt/models && \
tar -zxvf /opt/models/clair3_models.tar.gz -C /opt/models && \
rm /opt/models/clair3_models.tar.gz && \
echo "source activate clair3" > ~/.bashrc
53 changes: 34 additions & 19 deletions README.md

Large diffs are not rendered by default.

38 changes: 30 additions & 8 deletions clair3/Train.py
@@ -104,17 +104,34 @@ def on_epoch_end(self):
np.random.shuffle(self.chunk_list)


def get_chunk_list(chunk_offset, train_chunk_num):
def get_chunk_list(chunk_offset, train_chunk_num, chunks_per_batch=10, training_dataset_percentage=None):
"""
Get the chunk lists for the training and validation data. The training and validation sets are split randomly;
all training data is read directly from the various tensor bin files.
"""
need_split_validation_data = training_dataset_percentage is not None
all_shuffle_chunk_list = []
training_chunk_list, validation_chunk_list = [], []
for bin_idx, chunk_num in enumerate(chunk_offset):
all_shuffle_chunk_list += [(bin_idx, chunk_idx) for chunk_idx in range(chunk_num)]
np.random.seed(0)
np.random.shuffle(all_shuffle_chunk_list) # keep the same random validate dataset
current_chunk_list = [(bin_idx, chunk_idx) for chunk_idx in range(chunk_num)]
all_shuffle_chunk_list += current_chunk_list
if need_split_validation_data:
buffer_chunk_num = chunks_per_batch
if chunk_num < buffer_chunk_num:
training_chunk_list += [(bin_idx, chunk_idx) for chunk_idx in range(chunk_num)]
continue

training_chunk_num = int((chunk_num - buffer_chunk_num) * training_dataset_percentage)
validation_chunk_num = int(chunk_num - buffer_chunk_num - training_chunk_num)
if training_chunk_num > 0:
training_chunk_list += current_chunk_list[:training_chunk_num]
if validation_chunk_num > 0:
validation_chunk_list += current_chunk_list[-validation_chunk_num:]

if need_split_validation_data:
return np.array(training_chunk_list), np.array(validation_chunk_list)

return np.array(all_shuffle_chunk_list[:train_chunk_num]), np.array(all_shuffle_chunk_list[train_chunk_num:])
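
For illustration, here is a standalone sketch of the per-bin arithmetic in the new get_chunk_list() (the numbers are made up; the real function collects (bin_idx, chunk_idx) tuples rather than counts):

# Standalone sketch of the per-bin split; values are illustrative only.
chunk_offset = [25, 7]              # chunks available in two tensor bin files
chunks_per_batch = 10               # one batch of chunks reserved as a buffer per bin
training_dataset_percentage = 0.9

train_chunks, validation_chunks = 0, 0
for chunk_num in chunk_offset:
    if chunk_num < chunks_per_batch:           # bin too small to split: everything goes to training
        train_chunks += chunk_num
        continue
    training_chunk_num = int((chunk_num - chunks_per_batch) * training_dataset_percentage)
    validation_chunk_num = chunk_num - chunks_per_batch - training_chunk_num
    train_chunks += training_chunk_num         # taken from the head of the bin
    validation_chunks += validation_chunk_num  # taken from the tail; the middle batch acts as the buffer

print(train_chunks, validation_chunks)         # -> 20 2  (bin 0: 13 train / 2 validation, bin 1: 7 train)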


@@ -145,6 +162,7 @@ def train_model(args):
label_shape = param.label_shape
label_shape_cum = param.label_shape_cum
batch_size, chunk_size = param.trainBatchSize, param.chunk_size
assert batch_size % chunk_size == 0
chunks_per_batch = batch_size // chunk_size
random.seed(param.RANDOM_SEED)
np.random.seed(param.RANDOM_SEED)
@@ -159,7 +177,7 @@ def populate_dataset_table(file_list, file_path):
for bin_idx, bin_file in enumerate(file_list):
table_dataset = tables.open_file(os.path.join(file_path, bin_file), 'r')
table_dataset_list.append(table_dataset)
chunk_num = (len(table_dataset.root.label) - chunk_size) // chunk_size
chunk_num = (len(table_dataset.root.label) - batch_size) // chunk_size
chunk_offset[bin_idx] = chunk_num
return table_dataset_list, chunk_offset
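
As a quick sanity check of the revised chunk_num formula: the change from subtracting chunk_size to subtracting batch_size holds back one full batch of samples per bin file. The sizes below are illustrative; Clair3's real values come from its param module.

# Illustrative sizes only; real values come from param.trainBatchSize and param.chunk_size.
batch_size, chunk_size = 2000, 200
label_rows_in_bin = 10_000                                        # len(table_dataset.root.label) for one bin
old_chunk_num = (label_rows_in_bin - chunk_size) // chunk_size    # 49 chunks (previous behaviour)
new_chunk_num = (label_rows_in_bin - batch_size) // chunk_size    # 40 chunks: one full batch is reserved
print(old_chunk_num, new_chunk_num)                               # -> 49 40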

@@ -185,13 +203,17 @@ def populate_dataset_table(file_list, file_path):
total_chunks = train_chunk_num + validate_chunk_num
else:
total_chunks = int(sum(chunk_offset))
training_dataset_percentage = param.trainingDatasetPercentage if add_validation_dataset else None
if add_validation_dataset:
total_batches = total_chunks // chunks_per_batch
validate_chunk_num = int(max(1., np.floor(total_batches * (1 - param.trainingDatasetPercentage))) * chunks_per_batch)
validate_chunk_num = int(max(1., np.floor(total_batches * (1 - training_dataset_percentage))) * chunks_per_batch)
# +++++++++++++**----
# +:training *:buffer -:validation
# reserve one batch of chunks as a buffer in each bin file, so that training data never shifts into the validation set
train_chunk_num = int(total_chunks - validate_chunk_num)
else:
train_chunk_num = total_chunks
train_shuffle_chunk_list, validate_shuffle_chunk_list = get_chunk_list(chunk_offset, train_chunk_num)
train_shuffle_chunk_list, validate_shuffle_chunk_list = get_chunk_list(chunk_offset, train_chunk_num, chunks_per_batch, training_dataset_percentage)
train_chunk_num = len(train_shuffle_chunk_list)
validate_chunk_num = len(validate_shuffle_chunk_list)
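
A worked example of the batch-level bookkeeping above, using illustrative numbers (the real batch size, chunk size, and training percentage come from Clair3's param module):

import numpy as np

# Illustrative values only.
batch_size, chunk_size = 2000, 200
chunks_per_batch = batch_size // chunk_size                 # 10 chunks make up one training batch
total_chunks = 1000                                         # int(sum(chunk_offset)) over all bin files
total_batches = total_chunks // chunks_per_batch            # 100
training_dataset_percentage = 0.75

validate_chunk_num = int(max(1., np.floor(total_batches * (1 - training_dataset_percentage))) * chunks_per_batch)
train_chunk_num = total_chunks - validate_chunk_num
print(train_chunk_num, validate_chunk_num)                  # -> 750 250

get_chunk_list() then refines these counts per bin, as sketched earlier.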

@@ -223,7 +245,7 @@ def populate_dataset_table(file_list, file_path):
metrics=metrics,
optimizer=optimizer
)
early_stop_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, mode="min")
early_stop_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10*mini_epochs, mode="min")
model_save_callback = tf.keras.callbacks.ModelCheckpoint(ochk_prefix + ".{epoch:02d}", period=1, save_weights_only=False)
model_best_callback = tf.keras.callbacks.ModelCheckpoint("best_val_loss", monitor='val_loss', save_best_only=True, mode="min")
train_log_callback = tf.keras.callbacks.CSVLogger("training.log", separator='\t')
91 changes: 82 additions & 9 deletions clair3/utils.py
@@ -105,17 +105,48 @@ def item_from(row):
f.wait()


def remove_common_suffix(ref_base, alt_base):
min_length = min(len(ref_base) - 1, min([len(item) - 1 for item in alt_base])) # keep at least one base
prefix = ref_base[::-1]
for string in alt_base:
string = string[::-1]
while string[:len(prefix)] != prefix and prefix:
prefix = prefix[:len(prefix) - 1]
if not prefix:
break
res_length = len(prefix)
if res_length > min_length:
return ref_base, alt_base
return ref_base[:len(ref_base) - res_length], [item[:len(item) - res_length] for item in alt_base]

return ref_base[-min_length], [item[-min_length] for item in alt_base]


def decode_alt(ref_base, alt_base):
if ',' not in alt_base:
return [ref_base], [alt_base]
alt_base = alt_base.split(',')
ref_base_list, alt_base_list = [], []
for ab in alt_base:
rb,ab = remove_common_suffix(ref_base, [ab])
ref_base_list.append(rb)
alt_base_list.append(ab[0])
return ref_base_list, alt_base_list
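
As an illustration of the suffix trimming (assuming clair3/utils.py and its dependencies are importable in your environment; the truth record below is hypothetical):

# Hypothetical multi-allelic truth record: REF=ATG, ALT=AG,ATGTG
from clair3.utils import decode_alt

ref_base_list, alt_base_list = decode_alt("ATG", "AG,ATGTG")
# "AG"    shares the suffix "G" with the reference  -> ("AT", "A")   (a 1 bp deletion)
# "ATGTG" shares the suffix "TG" with the reference -> ("A", "ATG")  (a 2 bp insertion)
assert ref_base_list == ["AT", "A"]
assert alt_base_list == ["A", "ATG"]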


def variant_map_from(var_fn, tree, is_tree_empty):
Y = {}
truth_alt_dict = {}
miss_variant_set = set()
if var_fn is None:
return Y, miss_variant_set
return Y, miss_variant_set, truth_alt_dict

f = subprocess_popen(shlex.split("gzip -fdc %s" % (var_fn)))
for row in f.stdout:
columns = row.split()
ctg_name, position_str = columns[0], columns[1]
genotype1, genotype2 = columns[-2], columns[-1]
if row[0] == "#":
continue
columns = row.strip().split()
ctg_name, position_str, ref_base, alt_base, genotype1, genotype2 = columns
key = ctg_name + ":" + position_str
if genotype1 == '-1' or genotype2 == '-1':
miss_variant_set.add(key)
@@ -124,11 +155,41 @@ def variant_map_from(var_fn, tree, is_tree_empty):
continue

Y[key] = output_labels_from_vcf_columns(columns)

ref_base_list, alt_base_list = decode_alt(ref_base, alt_base)
truth_alt_dict[int(position_str)] = (ref_base_list, alt_base_list)
f.stdout.close()
f.wait()
return Y, miss_variant_set

return Y, miss_variant_set, truth_alt_dict

def find_read_support(pos, truth_alt_dict, alt_info):
alt_info = alt_info.rstrip().split('-')
seqs = alt_info[1].split(' ') if len(alt_info) > 1 else ''
seq_alt_bases_dict = dict(zip(seqs[::2], [int(item) for item in seqs[1::2]])) if len(seqs) else {}

pos = int(pos)
if pos not in truth_alt_dict:
# candidate position not in the truth vcf or unified truth vcf
return None
ref_base_list, alt_base_list = truth_alt_dict[pos]
found = 0
for alt_type in seq_alt_bases_dict:
if '*' in alt_type or '#' in alt_type or 'R' in alt_type:
continue
if alt_type[0] == 'X':
if alt_type[1] in alt_base_list:
found += 1
elif alt_type[0] == 'I':
if alt_type[1:] in alt_base_list:
found += 1
elif alt_type[0] == 'D':
del_cigar = alt_type[1:]
for rb, ab in zip(ref_base_list, alt_base_list):
if rb[1:] == del_cigar and len(ab) == 1:
found += 1
if found >= len(alt_base_list):
return True
# return False if any truth alternative base lacks read support in the (subsampled) bam; the position is then removed from training
return False
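
Here is a hedged example of the check; the alt_info layout ("<depth>-<alt type> <count> ...", with 'X'+base for a substitution, 'I'+sequence for an insertion allele, and 'D'+sequence for the deleted bases) is inferred from the parsing code above rather than from a specification, so treat the concrete string as hypothetical.

from clair3.utils import find_read_support

# truth_alt_dict entry as produced by decode_alt("ATG", "AG,ATGTG") at position 12345
truth_alt_dict = {12345: (["AT", "A"], ["A", "ATG"])}

# Hypothetical candidate summary: depth 30, 7 reads support deleting "T",
# 5 reads support the insertion allele "ATG", and 2 reads carry an unrelated substitution.
alt_info = "30-DT 7 IATG 5 XC 2"

# Both truth alternatives are supported by reads, so the candidate is kept for training.
assert find_read_support(pos="12345", truth_alt_dict=truth_alt_dict, alt_info=alt_info) is True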

def write_table_dict(table_dict, string, label, pos, total, alt_info, tensor_shape, pileup):
"""
@@ -207,7 +268,7 @@ def print_bin_size(path, prefix=None):
print('[INFO] total: {}'.format(total))


def bin_reader_generator_from(tensor_fn, Y_true_var, Y, is_tree_empty, tree, miss_variant_set, is_allow_duplicate_chr_pos=False, maximum_non_variant_ratio=None):
def bin_reader_generator_from(tensor_fn, Y_true_var, Y, is_tree_empty, tree, miss_variant_set, truth_alt_dict, is_allow_duplicate_chr_pos=False, maximum_non_variant_ratio=None):

"""
Bin reader generator for bin file generation.
@@ -216,6 +277,7 @@ def bin_reader_generator_from(tensor_fn, Y_true_var, Y, is_tree_empty, tree, mis
Y: dictionary (contig name: label information) to store all variant and non variant information.
tree: dictionary(contig name : intervaltree) for quick region querying.
miss_variant_set: true variants that may go missing after read downsampling.
truth_alt_dict: unified truth reference bases and alternative bases, used to check read support.
is_allow_duplicate_chr_pos: whether to allow duplicate positions during training; when downsampled data exists, the lower-depth copies get a random prefix character.
maximum_non_variant_ratio: maximum non-variant ratio for training. Using more non-variant data is generally desirable, but it greatly increases training
time, especially for ONT data, so a 1:1 or 1:2 variant candidate : non-variant candidate ratio is typically used.
@@ -224,6 +286,8 @@ def bin_reader_generator_from(tensor_fn, Y_true_var, Y, is_tree_empty, tree, mis
X = {}
ref_list = []
total = 0
variant_set_with_read_support = set()
variants_without_read_support = 0
for row_idx, row in enumerate(tensor_fn):
chrom, coord, seq, string, alt_info = row.split("\t")
alt_info = alt_info.rstrip()
@@ -238,6 +302,13 @@ def bin_reader_generator_from(tensor_fn, Y_true_var, Y, is_tree_empty, tree, mis
if key in miss_variant_set:
continue

have_read_support = find_read_support(pos=coord, truth_alt_dict=truth_alt_dict, alt_info=alt_info)
if have_read_support is not None and not have_read_support:
miss_variant_set.add(key)
variants_without_read_support += 1
continue

variant_set_with_read_support.add(key)
if key not in X:
X[key] = (string, alt_info, seq)
if is_reference:
@@ -267,6 +338,7 @@ def bin_reader_generator_from(tensor_fn, Y_true_var, Y, is_tree_empty, tree, mis
if total % 100000 == 0:
print("[INFO] Processed %d tensors" % total, file=sys.stderr)

print("[INFO] Variants with read support/variants without read support: {}/{}".format(len(variant_set_with_read_support), variants_without_read_support))
if maximum_non_variant_ratio is not None:
_filter_non_variants(X, ref_list, maximum_non_variant_ratio)
yield X, total, True
@@ -306,7 +378,7 @@ def get_training_array(tensor_fn, var_fn, bed_fn, bin_fn, shuffle=True, is_allow

tree = bed_tree_from(bed_file_path=bed_fn)
is_tree_empty = len(tree.keys()) == 0
Y_true_var, miss_variant_set = variant_map_from(var_fn, tree, is_tree_empty)
Y_true_var, miss_variant_set, truth_alt_dict = variant_map_from(var_fn, tree, is_tree_empty)
Y = copy.deepcopy(Y_true_var)

global param
@@ -367,6 +439,7 @@ def get_training_array(tensor_fn, var_fn, bed_fn, bin_fn, shuffle=True, is_allow
is_tree_empty=is_tree_empty,
tree=tree,
miss_variant_set=miss_variant_set,
truth_alt_dict=truth_alt_dict,
is_allow_duplicate_chr_pos=is_allow_duplicate_chr_pos,
maximum_non_variant_ratio=maximum_non_variant_ratio)

2 changes: 1 addition & 1 deletion docs/full_alignment_training.md
@@ -1,4 +1,4 @@
# Train a model for Clair3 full-alignment calling
# Train a model for Clair3 full-alignment calling (revision 0)

This document shows how to train and fine-tune a deep learning model for Clair3 full-alignment calling. For training a model for pileup calling, please check [here](pileup_training.md). Clair3 needs both a pileup model and a full-alignment model to work. Compared to Clair3's pileup model training, training a full-alignment model takes much longer, and the disk space requirement also increases significantly. The training materials are grouped by sample, coverage, and chromosome. The groups are converted into tensor binaries, which are much more space-efficient and easier to process. As required, multiple tensor binaries can be used together for model training and fine-tuning.
